@@ -28,16 +28,18 @@ export InfoNode, InfoLeaf, wrap
2828# ##########################
2929# ######### Types ##########
3030
31- struct Leaf{T}
32- majority :: T
33- values :: Vector{T}
31+ struct Leaf{T, N}
32+ features :: NTuple{N, T}
33+ majority :: Int
34+ values :: NTuple{N, Int}
35+ total :: Int
3436end
3537
36- struct Node{S, T}
38+ struct Node{S, T, N }
3739 featid :: Int
3840 featval :: S
39- left :: Union{Leaf{T}, Node{S, T}}
40- right :: Union{Leaf{T}, Node{S, T}}
41+ left :: Union{Leaf{T, N }, Node{S, T, N }}
42+ right :: Union{Leaf{T, N }, Node{S, T, N }}
4143end
4244
4345const LeafOrNode{S, T} = Union{Leaf{T}, Node{S, T}}
@@ -46,11 +48,15 @@ struct Ensemble{S, T}
4648 trees :: Vector{LeafOrNode{S, T}}
4749end
4850
51+ Leaf (features:: NTuple{T, N} ) where {T, N} =
52+ Leaf (features, 0 , Tuple (zeros (T, N)), 0 )
53+
4954is_leaf (l:: Leaf ) = true
5055is_leaf (n:: Node ) = false
5156
5257zero (String) = " "
53- convert (:: Type{Node{S, T}} , lf:: Leaf{T} ) where {S, T} = Node (0 , zero (S), lf, Leaf (zero (T), [zero (T)]))
58+ convert (:: Type{Node{S, T}} , lf:: Leaf{T} ) where {S, T} =
59+ Node (0 , zero (S), lf, Leaf (lf. features))
5460promote_rule (:: Type{Node{S, T}} , :: Type{Leaf{T}} ) where {S, T} = Node{S, T}
5561promote_rule (:: Type{Leaf{T}} , :: Type{Node{S, T}} ) where {S, T} = Node{S, T}
5662
@@ -81,9 +87,8 @@ depth(leaf::Leaf) = 0
8187depth (tree:: Node ) = 1 + max (depth (tree. left), depth (tree. right))
8288
8389function print_tree (io:: IO , leaf:: Leaf , depth= - 1 , indent= 0 ; feature_names= nothing )
84- n_matches = count (leaf. values .== leaf. majority)
85- ratio = string (n_matches, " /" , length (leaf. values))
86- println (io, " $(leaf. majority) : $(ratio) " )
90+ println (io, " $(leaf. features[leaf. majority]) : " ,
91+ leaf. values[leaf. majority], ' /' , leaf. total)
8792end
8893function print_tree (leaf:: Leaf , depth= - 1 , indent= 0 ; feature_names= nothing )
8994 return print_tree (stdout , leaf, depth, indent; feature_names= feature_names)
139144
140145function show (io:: IO , leaf:: Leaf )
141146 println (io, " Decision Leaf" )
142- println (io, " Majority: $( leaf. majority) " )
143- print (io, " Samples: $( length ( leaf. values)) " )
147+ println (io, " Majority: " , leaf. features[leaf . majority] )
148+ print (io, " Samples: " , leaf. total )
144149end
145150
146151function show (io:: IO , tree:: Node )
0 commit comments