From 917c7fa8ce098323f3fd085753051c451573d65d Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Fri, 31 May 2024 13:51:22 +1200 Subject: [PATCH 1/3] bump 0.4.3 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 8d2d1ec..a069e8f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJEnsembles" uuid = "50ed68f4-41fd-4504-931a-ed422449fee0" authors = ["Anthony D. Blaom "] -version = "0.4.2" +version = "0.4.3" [deps] CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" From 4b91b754a9386fd7bd6449f284e3ea4582fdc56f Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Fri, 31 May 2024 13:52:01 +1200 Subject: [PATCH 2/3] bump [compat] MLJModelInterface = "1.10" --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index a069e8f..704def5 100644 --- a/Project.toml +++ b/Project.toml @@ -21,7 +21,7 @@ CategoricalArrays = "0.8, 0.9, 0.10" CategoricalDistributions = "0.1.2" ComputationalResources = "0.3" Distributions = "0.21, 0.22, 0.23, 0.24, 0.25" -MLJModelInterface = "0.4.1, 1.1" +MLJModelInterface = "1.10" ProgressMeter = "1.1" ScientificTypesBase = "2,3" StatisticalMeasuresBase = "0.1" From 615b95ae86efd5bbd697cc127c7f3c14a6d7024e Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Mon, 3 Jun 2024 19:50:08 +1200 Subject: [PATCH 3/3] overload constructor trait for EnsembleModel types --- src/ensembles.jl | 8 ++------ test/ensembles.jl | 3 +++ 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/ensembles.jl b/src/ensembles.jl index 122a87b..28cb2e0 100644 --- a/src/ensembles.jl +++ b/src/ensembles.jl @@ -200,7 +200,6 @@ _reducer(p, q) = vcat(p, q) _reducer(p::Tuple, q::Tuple) = (vcat(p[1], q[1]), vcat(p[2], q[2])) - # # ENSEMBLE MODEL TYPES mutable struct DeterministicEnsembleModel{Atom<:Deterministic} <: Deterministic @@ -638,11 +637,8 @@ end # Note: input and target traits are inherited from atom -MMI.load_path(::Type{<:ProbabilisticEnsembleModel}) = - "MLJ.ProbabilisticEnsembleModel" -MMI.load_path(::Type{<:DeterministicEnsembleModel}) = - "MLJ.DeterministicEnsembleModel" - +MMI.load_path(::Type{<:EitherEnsembleModel}) = "MLJEnsembles.EnsembleModel" +MMI.constructor(::Type{<:EitherEnsembleModel}) = EnsembleModel MMI.is_wrapper(::Type{<:EitherEnsembleModel}) = true MMI.supports_weights(::Type{<:EitherEnsembleModel{Atom}}) where Atom = MMI.supports_weights(Atom) diff --git a/test/ensembles.jl b/test/ensembles.jl index f371dde..1cb1dfc 100644 --- a/test/ensembles.jl +++ b/test/ensembles.jl @@ -63,6 +63,9 @@ X = MLJEnsembles.table(ones(5,3)) y = categorical(collect("asdfa")) train, test = partition(1:length(y), 0.8); ensemble_model = EnsembleModel(model=atom) +@test constructor(ensemble_model) == EnsembleModel +@test load_path(ensemble_model) == "MLJEnsembles.EnsembleModel" +@test package_name(ensemble_model) == "MLJEnsembles" ensemble_model.n = 10 fitresult, cache, report = MLJEnsembles.fit(ensemble_model, 0, X, y) predict(ensemble_model, fitresult, MLJEnsembles.selectrows(X, test))