@@ -224,22 +224,19 @@ function prune_tree(
224224 end
225225 ntt = nsample (tree)
226226 function _prune_run_stump (
227- tree:: LeafOrNode{S, T} ,
227+ tree:: LeafOrNode{S, T, N } ,
228228 purity_thresh:: Real ,
229229 fi:: Vector{Float64} = Float64[]
230- ) where {S, T}
231- all_labels = [ tree. left. values; tree. right. values]
232- majority = majority_vote (all_labels)
233- matches = findall (all_labels .== majority )
234- purity = length (matches) / length (all_labels)
230+ ) where {S, T, N }
231+ combined = tree. left. values .+ tree. right. values
232+ total = tree . left . total + tree . right . total
233+ majority = argmax (combined )
234+ purity = combined[majority] / total
235235 if purity >= purity_thresh
236236 if ! isempty (fi)
237237 update_pruned_impurity! (tree, fi, ntt, loss)
238238 end
239- features = Tuple (unique (all_labels))
240- featfreq = Tuple (sum (all_labels .== f) for f in features)
241- return Leaf {T} (features, argmax (featfreq),
242- featfreq, length (all_labels))
239+ return Leaf {T, N} (tree. left. features, majority, combined, total)
243240 else
244241 return tree
245242 end
@@ -250,19 +247,20 @@ function prune_tree(
250247 return Root {S, T} (node, tree. n_feat, fi)
251248 end
252249 function _prune_run (
253- tree:: LeafOrNode{S, T} ,
250+ tree:: LeafOrNode{S, T, N } ,
254251 purity_thresh:: Real ,
255252 fi:: Vector{Float64} = Float64[]
256- ) where {S, T}
257- N = length (tree)
258- if N == 1 # # a Leaf
253+ ) where {S, T, N }
254+ L = length (tree)
255+ if L == 1 # # a Leaf
259256 return tree
260- elseif N == 2 # # a stump
257+ elseif L == 2 # # a stump
261258 return _prune_run_stump (tree, purity_thresh, fi)
262259 else
263- left = _prune_run (tree. left, purity_thresh, fi)
264- right = _prune_run (tree. right, purity_thresh, fi)
265- return Node {S, T} (tree. featid, tree. featval, left, right)
260+ return Node {S, T, N} (
261+ tree. featid, tree. featval,
262+ _prune_run (tree. left, purity_thresh),
263+ _prune_run (tree. right, purity_thresh))
266264 end
267265 end
268266 pruned = _prune_run (tree, purity_thresh)
0 commit comments