@@ -113,27 +113,25 @@ function prune_tree(tree::LeafOrNode{S, T}, purity_thresh=1.0) where {S, T}
113113    if  purity_thresh >=  1.0 
114114        return  tree
115115    end 
116-     function  _prune_run (tree:: LeafOrNode{S, T} , purity_thresh:: Real ) where  {S, T}
117-         N  =  length (tree)
118-         if  N  ==  1         # # a Leaf
116+     function  _prune_run (tree:: LeafOrNode{S, T, N } , purity_thresh:: Real ) where  {S, T, N }
117+         L  =  length (tree)
118+         if  L  ==  1         # # a Leaf
119119            return  tree
120-         elseif  N  ==  2     # # a stump
121-             all_labels  =  [ tree. left. values;  tree. right. values] 
122-             majority  =  majority_vote (all_labels) 
123-             matches  =  findall (all_labels  .==  majority )
124-             purity =  length (matches)  /  length (all_labels) 
120+         elseif  L  ==  2     # # a stump
121+             combined  =  tree. left. values  .+   tree. right. values
122+             total  =  tree . left . total  +  tree . right . total 
123+             majority  =  argmax (combined )
124+             purity =  combined[majority]  /  total 
125125            if  purity >=  purity_thresh
126-                 features =  Tuple (unique (all_labels))
127-                 featfreq =  Tuple (sum (all_labels .==  f) for  f in  features)
128-                 return  Leaf {T} (features, argmax (featfreq),
129-                                featfreq, length (all_labels))
126+                 return  Leaf {T, N} (tree. left. features, majority, combined, total)
130127            else 
131128                return  tree
132129            end 
133130        else 
134-             return  Node {S, T} (tree. featid, tree. featval,
135-                         _prune_run (tree. left, purity_thresh),
136-                         _prune_run (tree. right, purity_thresh))
131+             return  Node {S, T, N} (
132+                 tree. featid, tree. featval,
133+                 _prune_run (tree. left, purity_thresh),
134+                 _prune_run (tree. right, purity_thresh))
137135        end 
138136    end 
139137    pruned =  _prune_run (tree, purity_thresh)
0 commit comments