How to Extract Feature Information for Tree-based Apache SparkML Pipeline Models
When you fit a tree-based model, such as a decision tree, random forest, or gradient-boosted tree, it is helpful to review the feature importance levels along with the feature names. SparkML models are typically fit as the last stage of a pipeline, so to get the relevant feature information you must extract the correct stage from the fitted pipeline model. In this example, the tree model is the last stage, and you can extract the feature names from the VectorAssembler in the second-to-last stage:
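from pyspark.ml import Pipeline
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.feature import StringIndexer, VectorAssembler
# indexer (StringIndexer), assembler (VectorAssembler), decision_tree (DecisionTreeClassifier),
# and the train DataFrame are assumed to be defined earlier
pipeline = Pipeline(stages=[indexer, assembler, decision_tree])
DTmodel = pipeline.fit(train)
va = DTmodel.stages[-2]    # VectorAssembler stage: its input columns are the feature names
tree = DTmodel.stages[-1]  # fitted DecisionTreeClassificationModel
display(tree)              # visualize the decision tree model
print(tree.toDebugString)  # print the nodes of the decision tree model
# pair each feature name with its importance score
list(zip(va.getInputCols(), tree.featureImportances.toArray()))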
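The last line of the example pairs each assembler input column with its importance score. Sorting those pairs makes the most influential features easy to spot. The following is a minimal plain-Python sketch; the helper name rank_feature_importances and the sample names and scores are hypothetical stand-ins for va.getInputCols() and tree.featureImportances.toArray(). It assumes every assembler input is a single numeric column, so feature names and importance slots line up one to one.

```python
from typing import List, Sequence, Tuple


def rank_feature_importances(
    feature_names: Sequence[str],
    importances: Sequence[float],
) -> List[Tuple[str, float]]:
    """Pair each feature name with its importance and sort, highest first."""
    if len(feature_names) != len(importances):
        # A mismatch usually means an assembler input was itself a vector
        # column, so one name covers several importance slots.
        raise ValueError(
            f"{len(feature_names)} names but {len(importances)} importances"
        )
    pairs = [(name, float(score)) for name, score in zip(feature_names, importances)]
    # sorted() is stable even with reverse=True, so ties keep assembler order.
    return sorted(pairs, key=lambda pair: pair[1], reverse=True)


# Hypothetical values standing in for va.getInputCols() and
# tree.featureImportances.toArray()
names = ["age_idx", "income", "region_idx"]
scores = [0.2, 0.7, 0.1]
for name, score in rank_feature_importances(names, scores):
    print(f"{name}: {score:.2f}")
```

The explicit length check is a deliberate design choice: zip() silently truncates to the shorter input, which would otherwise hide a misaligned mapping.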
You can also tune a tree-based model by placing a CrossValidator in the last stage of the pipeline. To visualize the decision tree and print the feature importance levels, extract the bestModel from the fitted CrossValidatorModel, which is the last stage of the fitted pipeline:
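from pyspark.ml import Pipeline
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
# Illustrative grid and evaluator; adjust the parameters and metric for your data
paramGrid = ParamGridBuilder().addGrid(decision_tree.maxDepth, [3, 5, 7]).build()
evaluator = MulticlassClassificationEvaluator()  # uses the default label and prediction columns
cv = CrossValidator(estimator=decision_tree, estimatorParamMaps=paramGrid, evaluator=evaluator, numFolds=3)
pipelineCV = Pipeline(stages=[indexer, assembler, cv])
DTmodelCV = pipelineCV.fit(train)
va = DTmodelCV.stages[-2]                # VectorAssembler stage
treeCV = DTmodelCV.stages[-1].bestModel  # best tree found by the CrossValidator
display(treeCV)                          # visualize the best decision tree model
print(treeCV.toDebugString)              # print the nodes of the decision tree model
list(zip(va.getInputCols(), treeCV.featureImportances.toArray()))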
The display function is available in Databricks notebooks and visualizes decision tree models only. See Machine learning visualizations.
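When a CrossValidator is the last pipeline stage, it evaluates every parameter map on each fold, averages the metric per map (exposed as avgMetrics on the fitted CrossValidatorModel), picks the map with the best average according to the evaluator, and refits the estimator on the full training data to produce bestModel. The sketch below reproduces only the selection step in plain Python, to help reason about why a particular tree was chosen. The helper select_best_params, the string-keyed dictionaries, and the metric values are hypothetical; real Spark ParamMap objects are keyed by Param instances, not strings.

```python
from typing import Dict, List, Sequence


def select_best_params(
    param_maps: Sequence[Dict[str, object]],
    fold_metrics: Sequence[Sequence[float]],
    larger_is_better: bool = True,
) -> Dict[str, object]:
    """Average each candidate's per-fold metrics and return the winning map.

    fold_metrics[i] holds the metric for param_maps[i] on every fold.
    """
    # One average per candidate, analogous to CrossValidatorModel.avgMetrics
    avg_metrics: List[float] = [sum(m) / len(m) for m in fold_metrics]
    pick = max if larger_is_better else min
    # list.index returns the first match, so ties go to the earlier map here
    best_index = avg_metrics.index(pick(avg_metrics))
    return param_maps[best_index]


# Hypothetical grid over maxDepth and made-up per-fold scores (3 folds)
grid = [{"maxDepth": 3}, {"maxDepth": 5}, {"maxDepth": 7}]
metrics = [[0.80, 0.82, 0.78], [0.85, 0.86, 0.84], [0.83, 0.85, 0.84]]
print(select_best_params(grid, metrics))  # prints {'maxDepth': 5}
```

To inspect the real values for the cross-validated pipeline, pair DTmodelCV.stages[-1].getEstimatorParamMaps() with DTmodelCV.stages[-1].avgMetrics.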