如何使用 Apache Spark 执行组 K-折交叉验证

交叉验证会将训练数据随机拆分为指定数量的折叠。 为了防止相同数据在多个折叠中出现时发生数据泄漏,可以使用组。 scikit-learn 支持组 K-折交叉验证,以确保折叠不同且不重叠。

在 Spark 上,你可以使用 spark-sklearn 存储 spark-sklearn优化的库来利用此方法。 此示例通过 grp 变量在 Spark 上使用组 K-折方法优化 scikit-learn 随机林模型:

from sklearn.ensemble import RandomForestClassifier
from spark_sklearn import GridSearchCV
from sklearn.model_selection import GroupKFold
param_grid = {"max_depth": [8, 12, None],
              "max_features": [1, 3, 10],
              "min_samples_split": [1, 3, 10],
              "min_samples_leaf": [1, 3, 10],
              "bootstrap": [True, False],
              "criterion": ["gini", "entropy"],
              "n_estimators": [20, 40, 80]}
group_kfold = GroupKFold(n_splits=3)
gs = GridSearchCV(sc, estimator = RandomForestClassifier(random_state=42), param_grid=param_grid, cv = group_kfold)
gs.fit(X1, y1 ,grp)

注意

  • 用于运行网格搜索的库称为 spark-sklearn,因此必须首先传入 Spark 上下文(sc 参数)。
  • X1y1 参数必须是 pandas 数据帧。 此网格搜索选项只对适合驱动程序的数据有效。