@@ -179,12 +179,13 @@ function _build_tree(
179179 impurity_importance:: Bool
180180) where {S, T}
181181 node = _convert (tree. root, tree. list, labels[tree. labels])
182+ n_classes = unique (labels) |> length
182183 if ! impurity_importance
183- return Root {S, T} (node, n_features, Float64[])
184+ return Root {S, T, n_classes } (node, n_features, Float64[])
184185 else
185186 fi = zeros (Float64, n_features)
186187 update_using_impurity! (fi, tree. root)
187- return Root {S, T} (node, n_features, fi ./ n_samples)
188+ return Root {S, T, n_classes } (node, n_features, fi ./ n_samples)
188189 end
189190end
190191
@@ -241,10 +242,10 @@ function prune_tree(
241242 return tree
242243 end
243244 end
244- function _prune_run (tree:: Root{S, T} , purity_thresh:: Real ) where {S, T}
245+ function _prune_run (tree:: Root{S, T, N } , purity_thresh:: Real ) where {S, T, N }
245246 fi = deepcopy (tree. featim) # # recalculate feature importances
246247 node = _prune_run (tree. node, purity_thresh, fi)
247- return Root {S, T} (node, tree . n_feat , fi)
248+ return Root {S, T, N } (node, fi)
248249 end
249250 function _prune_run (
250251 tree:: LeafOrNode{S, T, N} ,
0 commit comments