Describe the bug
After patching scikit-learn with Intelex (sklearnex), export_graphviz() raises "ValueError: cannot convert float NaN to integer" for trees taken from ExtraTrees and RandomForest classifiers.
python3.10/site-packages/sklearn/tree/_export.py:258, in <listcomp>(.0)
254 alpha = (value - self.colors["bounds"][0]) / (
255 self.colors["bounds"][1] - self.colors["bounds"][0]
256 )
257 # compute the color as alpha against white
--> 258 color = [int(round(alpha * c + (1 - alpha) * 255, 0)) for c in color]
259 # Return html color code in #RRGGBB format
260 return "#%2x%2x%2x" % tuple(color)
ValueError: cannot convert float NaN to integer
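To make the failure mode concrete, here is a minimal sketch that isolates the expression from the failing line of the traceback. The function name blend_with_white and the sample RGB values are illustrative assumptions, not part of scikit-learn. It only shows that a NaN alpha is enough to trigger this exact error; it does not establish where the NaN originates in the patched estimator.

```python
def blend_with_white(alpha, color):
    """Blend an RGB color with white, using the same expression as the
    failing line in sklearn/tree/_export.py (hypothetical helper name)."""
    return [int(round(alpha * c + (1 - alpha) * 255, 0)) for c in color]


# A finite alpha works as intended.
print(blend_with_white(0.5, [229, 129, 57]))

# A NaN alpha propagates through the arithmetic and int() rejects it.
try:
    blend_with_white(float("nan"), [229, 129, 57])
except ValueError as exc:
    print(exc)
```

If the values or bounds that export_graphviz() reads from the patched tree are NaN (or the two bounds coincide as NumPy floats, so the division yields NaN), the color computation would fail in this way.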
With the original scikit-learn, classifier.estimators_[0].classes_ returns [0.0 1.0], but after patching with Intelex it returns 0. Maybe this is linked to the code in sklearnex/ensemble/_forest.py referenced at the end of this report?
To Reproduce
from sklearnex import patch_sklearn
patch_sklearn()

import graphviz
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.model_selection import train_test_split
from sklearn.tree import export_graphviz

# Load the dataset
data = load_breast_cancer()
X = data['data']
y = data['target']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1)

# Train the model
clf = ExtraTreesClassifier(n_estimators=100, random_state=42)
clf.fit(X_train, y_train)

# Print the classes seen by each tree in the forest
for each in clf.estimators_:
    print(each.classes_)

# Export a single tree from the forest
tree = clf.estimators_[0]
dot_data = export_graphviz(tree, out_file=None,
                           feature_names=data.feature_names,
                           class_names=data.target_names,
                           filled=True, rounded=True,
                           special_characters=True)
graph = graphviz.Source(dot_data)
graph.render("extratree")  # Saves the tree as a .pdf file

# Display the tree (in a notebook)
graph
Expected behavior
Print [0 1] and show a tree plot.
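The expectation that every tree reports both classes can also be checked programmatically. The sketch below is a hypothetical diagnostic helper (the names as_tuple and trees_with_unexpected_classes are made up for illustration). It assumes only that the fitted forest exposes estimators_ and that each sub-estimator may expose classes_, and it is demonstrated on stand-in objects rather than real estimators.

```python
from types import SimpleNamespace


def as_tuple(classes):
    """Normalize classes_ (array, list, or bare scalar) to a tuple."""
    try:
        return tuple(classes)
    except TypeError:  # a bare scalar such as 0, or a 0-d array
        return (classes,)


def trees_with_unexpected_classes(forest, expected_n_classes):
    """Return (index, classes) for sub-estimators whose classes_ has the wrong length."""
    bad = []
    for i, tree in enumerate(forest.estimators_):
        classes = as_tuple(getattr(tree, "classes_", ()))
        if len(classes) != expected_n_classes:
            bad.append((i, classes))
    return bad


# Stand-in objects mimicking the two behaviours described in this report.
healthy = SimpleNamespace(estimators_=[SimpleNamespace(classes_=[0.0, 1.0])])
broken = SimpleNamespace(estimators_=[SimpleNamespace(classes_=0)])

print(trees_with_unexpected_classes(healthy, 2))  # []
print(trees_with_unexpected_classes(broken, 2))   # [(0, (0,))]
```

With the reproduction script, this could be called as trees_with_unexpected_classes(clf, len(clf.classes_)); an empty list would mean every tree reports the expected number of classes.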
YoochanMyung changed the title from "Cannot visualize a tree plot with ExtraTrees and Randomforest" to "Cannot visualize a tree plot with ExtraTrees and Randomforest classifiers" on Jul 5, 2024.
Possibly related code: https://github.com/intel/scikit-learn-intelex/blob/01def265ba59d7d4e1eb2e5944d938e274d1bde8/sklearnex/ensemble/_forest.py#L478