@@ -117,6 +117,7 @@ function build_stump(
117117 labels :: AbstractVector{T} ,
118118 features :: AbstractMatrix{S} ,
119119 weights = nothing ;
120+ n_classes :: Int = length (unique (labels)),
120121 rng = Random. GLOBAL_RNG,
121122 impurity_importance :: Bool = true ) where {S, T}
122123
@@ -133,7 +134,7 @@ function build_stump(
133134 min_purity_increase = 0.0 ;
134135 rng = rng)
135136
136- return _build_tree (t, labels, size (features, 2 ), size (features, 1 ), impurity_importance)
137+ return _build_tree (t, labels, n_classes, size (features, 2 ), size (features, 1 ), impurity_importance)
137138end
138139
139140function build_tree (
@@ -144,6 +145,7 @@ function build_tree(
144145 min_samples_leaf = 1 ,
145146 min_samples_split = 2 ,
146147 min_purity_increase = 0.0 ;
148+ n_classes :: Int = length (unique (labels)),
147149 loss = util. entropy :: Function ,
148150 rng = Random. GLOBAL_RNG,
149151 impurity_importance :: Bool = true ) where {S, T}
@@ -168,18 +170,18 @@ function build_tree(
168170 min_purity_increase = Float64 (min_purity_increase),
169171 rng = rng)
170172
171- return _build_tree (t, labels, size (features, 2 ), size (features, 1 ), impurity_importance)
173+ return _build_tree (t, labels, n_classes, size (features, 2 ), size (features, 1 ), impurity_importance)
172174end
173175
174176function _build_tree (
175177 tree:: treeclassifier.Tree{S, T} ,
176178 labels:: AbstractVector{T} ,
179+ n_classes:: Int ,
177180 n_features,
178181 n_samples,
179182 impurity_importance:: Bool
180183) where {S, T}
181184 node = _convert (tree. root, tree. list, labels[tree. labels])
182- n_classes = unique (labels) |> length
183185 if ! impurity_importance
184186 return Root {S, T, n_classes} (node, n_features, Float64[])
185187 else
@@ -237,15 +239,15 @@ function prune_tree(
237239 if ! isempty (fi)
238240 update_pruned_impurity! (tree, fi, ntt, loss)
239241 end
240- return Leaf {T, N} (tree. left. features , majority, combined, total)
242+ return Leaf {T, N} (tree. left. classes , majority, combined, total)
241243 else
242244 return tree
243245 end
244246 end
245247 function _prune_run (tree:: Root{S, T, N} , purity_thresh:: Real ) where {S, T, N}
246248 fi = deepcopy (tree. featim) # # recalculate feature importances
247249 node = _prune_run (tree. node, purity_thresh, fi)
248- return Root {S, T, N} (node, fi)
250+ return Root {S, T, N} (node, tree . n_feat, fi)
249251 end
250252 function _prune_run (
251253 tree:: LeafOrNode{S, T, N} ,
@@ -273,7 +275,7 @@ function prune_tree(
273275end
274276
275277
276- apply_tree (leaf:: Leaf , feature:: AbstractVector ) = leaf. features [leaf. majority]
278+ apply_tree (leaf:: Leaf , feature:: AbstractVector ) = leaf. classes [leaf. majority]
277279apply_tree (
278280 tree:: Root{S, T} ,
279281 features:: AbstractVector{S}
@@ -369,10 +371,11 @@ function build_forest(
369371
370372 t_samples = length (labels)
371373 n_samples = floor (Int, partial_sampling * t_samples)
374+ n_classes = length (unique (labels))
372375
373376 forest = impurity_importance ?
374- Vector {Root{S, T}} (undef, n_trees) :
375- Vector {LeafOrNode{S, T}} (undef, n_trees)
377+ Vector {Root{S, T, n_classes }} (undef, n_trees) :
378+ Vector {LeafOrNode{S, T, n_classes }} (undef, n_trees)
376379
377380 entropy_terms = util. compute_entropy_terms (n_samples)
378381 loss = (ns, n) -> util. entropy (ns, n, entropy_terms)
@@ -392,7 +395,8 @@ function build_forest(
392395 max_depth,
393396 min_samples_leaf,
394397 min_samples_split,
395- min_purity_increase,
398+ min_purity_increase;
399+ n_classes,
396400 loss = loss,
397401 rng = _rng,
398402 impurity_importance = impurity_importance)
@@ -408,7 +412,8 @@ function build_forest(
408412 max_depth,
409413 min_samples_leaf,
410414 min_samples_split,
411- min_purity_increase,
415+ min_purity_increase;
416+ n_classes,
412417 loss = loss,
413418 impurity_importance = impurity_importance)
414419 end
@@ -418,13 +423,13 @@ function build_forest(
418423end
419424
420425function _build_forest (
421- forest :: Vector{<: Union{Root{S, T}, LeafOrNode{S, T}}} ,
426+ forest :: Vector{<: Union{Root{S, T, N }, LeafOrNode{S, T, N }}} ,
422427 n_features ,
423428 n_trees ,
424- impurity_importance :: Bool ) where {S, T}
429+ impurity_importance :: Bool ) where {S, T, N }
425430
426431 if ! impurity_importance
427- return Ensemble {S, T} (forest, n_features, Float64[])
432+ return Ensemble {S, T, N } (forest, n_features, Float64[])
428433 else
429434 fi = zeros (Float64, n_features)
430435 for tree in forest
@@ -434,12 +439,12 @@ function _build_forest(
434439 end
435440 end
436441
437- forest_new = Vector {LeafOrNode{S, T}} (undef, n_trees)
442+ forest_new = Vector {LeafOrNode{S, T, N }} (undef, n_trees)
438443 Threads. @threads for i in 1 : n_trees
439444 forest_new[i] = forest[i]. node
440445 end
441446
442- return Ensemble {S, T} (forest_new, n_features, fi ./ n_trees)
447+ return Ensemble {S, T, N } (forest_new, n_features, fi ./ n_trees)
443448 end
444449end
445450
@@ -516,11 +521,13 @@ function build_adaboost_stumps(
516521 stumps = Node{S, T}[]
517522 coeffs = Float64[]
518523 n_features = size (features, 2 )
524+ n_classes = length (unique (labels))
519525 for i in 1 : n_iterations
520526 new_stump = build_stump (
521527 labels,
522528 features,
523529 weights;
530+ n_classes,
524531 rng= mk_rng (rng),
525532 impurity_importance= false
526533 )
0 commit comments