@@ -26,16 +26,18 @@ export InfoNode, InfoLeaf, wrap
2626# ##########################
2727# ######### Types ##########
2828
29- struct Leaf{T}
30- majority :: T
31- values :: Vector{T}
29+ struct Leaf{T, N}
30+ features :: NTuple{N, T}
31+ majority :: Int
32+ values :: NTuple{N, Int}
33+ total :: Int
3234end
3335
34- struct Node{S, T}
36+ struct Node{S, T, N }
3537 featid :: Int
3638 featval :: S
37- left :: Union{Leaf{T}, Node{S, T}}
38- right :: Union{Leaf{T}, Node{S, T}}
39+ left :: Union{Leaf{T, N }, Node{S, T, N }}
40+ right :: Union{Leaf{T, N }, Node{S, T, N }}
3941end
4042
4143const LeafOrNode{S, T} = Union{Leaf{T}, Node{S, T}}
@@ -52,13 +54,15 @@ struct Ensemble{S, T}
5254 featim :: Vector{Float64}
5355end
5456
57+ Leaf (features:: NTuple{T, N} ) where {T, N} =
58+ Leaf (features, 0 , Tuple (zeros (T, N)), 0 )
5559
5660is_leaf (l:: Leaf ) = true
5761is_leaf (n:: Node ) = false
5862
5963_zero (:: Type{String} ) = " "
6064_zero (x:: Any ) = zero (x)
61- convert (:: Type{Node{S, T}} , lf:: Leaf{T} ) where {S, T} = Node (0 , _zero (S), lf, Leaf (_zero (T), [ _zero (T)] ))
65+ convert (:: Type{Node{S, T}} , lf:: Leaf{T} ) where {S, T} = Node (0 , _zero (S), lf, Leaf (lf . features ))
6266convert (:: Type{Root{S, T}} , node:: LeafOrNode{S, T} ) where {S, T} = Root {S, T} (node, 0 , Float64[])
6367convert (:: Type{LeafOrNode{S, T}} , tree:: Root{S, T} ) where {S, T} = tree. node
6468promote_rule (:: Type{Node{S, T}} , :: Type{Leaf{T}} ) where {S, T} = Node{S, T}
@@ -97,9 +101,8 @@ depth(tree::Node) = 1 + max(depth(tree.left), depth(tree.right))
97101depth (tree:: Root ) = depth (tree. node)
98102
99103function print_tree (io:: IO , leaf:: Leaf , depth= - 1 , indent= 0 ; sigdigits= 4 , feature_names= nothing )
100- n_matches = count (leaf. values .== leaf. majority)
101- ratio = string (n_matches, " /" , length (leaf. values))
102- println (io, " $(leaf. majority) : $(ratio) " )
104+ println (io, leaf. features[leaf. majority], " : " ,
105+ leaf. values[leaf. majority], ' /' , leaf. total)
103106end
104107function print_tree (leaf:: Leaf , depth= - 1 , indent= 0 ; sigdigits= 4 , feature_names= nothing )
105108 return print_tree (stdout , leaf, depth, indent; sigdigits, feature_names)
162165
163166function show (io:: IO , leaf:: Leaf )
164167 println (io, " Decision Leaf" )
165- println (io, " Majority: $( leaf. majority) " )
166- print (io, " Samples: $( length ( leaf. values)) " )
168+ println (io, " Majority: " , leaf. features[leaf . majority] )
169+ print (io, " Samples: " , leaf. total )
167170end
168171
169172function show (io:: IO , tree:: Node )
0 commit comments