From 615b95ae86efd5bbd697cc127c7f3c14a6d7024e Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Mon, 3 Jun 2024 19:50:08 +1200 Subject: [PATCH] overload constructor trait for EnsembleModel types --- src/ensembles.jl | 8 ++------ test/ensembles.jl | 3 +++ 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/ensembles.jl b/src/ensembles.jl index 122a87b..28cb2e0 100644 --- a/src/ensembles.jl +++ b/src/ensembles.jl @@ -200,7 +200,6 @@ _reducer(p, q) = vcat(p, q) _reducer(p::Tuple, q::Tuple) = (vcat(p[1], q[1]), vcat(p[2], q[2])) - # # ENSEMBLE MODEL TYPES mutable struct DeterministicEnsembleModel{Atom<:Deterministic} <: Deterministic @@ -638,11 +637,8 @@ end # Note: input and target traits are inherited from atom -MMI.load_path(::Type{<:ProbabilisticEnsembleModel}) = - "MLJ.ProbabilisticEnsembleModel" -MMI.load_path(::Type{<:DeterministicEnsembleModel}) = - "MLJ.DeterministicEnsembleModel" - +MMI.load_path(::Type{<:EitherEnsembleModel}) = "MLJEnsembles.EnsembleModel" +MMI.constructor(::Type{<:EitherEnsembleModel}) = EnsembleModel MMI.is_wrapper(::Type{<:EitherEnsembleModel}) = true MMI.supports_weights(::Type{<:EitherEnsembleModel{Atom}}) where Atom = MMI.supports_weights(Atom) diff --git a/test/ensembles.jl b/test/ensembles.jl index f371dde..1cb1dfc 100644 --- a/test/ensembles.jl +++ b/test/ensembles.jl @@ -63,6 +63,9 @@ X = MLJEnsembles.table(ones(5,3)) y = categorical(collect("asdfa")) train, test = partition(1:length(y), 0.8); ensemble_model = EnsembleModel(model=atom) +@test constructor(ensemble_model) == EnsembleModel +@test load_path(ensemble_model) == "MLJEnsembles.EnsembleModel" +@test package_name(ensemble_model) == "MLJEnsembles" ensemble_model.n = 10 fitresult, cache, report = MLJEnsembles.fit(ensemble_model, 0, X, y) predict(ensemble_model, fitresult, MLJEnsembles.selectrows(X, test))