diff --git a/dtreeviz/models/shadow_decision_tree.py b/dtreeviz/models/shadow_decision_tree.py index c7522c94..c5761f12 100644 --- a/dtreeviz/models/shadow_decision_tree.py +++ b/dtreeviz/models/shadow_decision_tree.py @@ -464,7 +464,7 @@ def get_shadow_tree(tree_model, x_data, y_data, feature_names, target_name, clas elif (str(type(tree_model)).endswith("pyspark.ml.classification.DecisionTreeClassificationModel'>") or str(type(tree_model)).endswith("pyspark.ml.classification.DecisionTreeClassificationModel'>")): from dtreeviz.models import spark_decision_tree - return spark_decision_tree.ShadowSparkTree(tree_model, tree_index, x_data, y_data, + return spark_decision_tree.ShadowSparkTree(tree_model, x_data, y_data, feature_names, target_name, class_names) elif "lightgbm.basic.Booster" in str(type(tree_model)): from dtreeviz.models import lightgbm_decision_tree