Skip to content

Commit

Permalink
Merge pull request #37 from JuliaAI/dev
Browse files Browse the repository at this point in the history
For a 0.4.1 release
  • Loading branch information
ablaom authored Apr 18, 2024
2 parents 6ab6c66 + 794335c commit af03f16
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 64 deletions.
9 changes: 9 additions & 0 deletions .github/codecov.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Codecov status-check configuration.
coverage:
  status:
    # Status check on overall project coverage.
    project:
      default:
        threshold: 0.5%
        removed_code_behavior: fully_covered_patch
    # Status check on the lines touched by a change.
    patch:
      default:
        target: 80%
48 changes: 0 additions & 48 deletions .github/workflows/CI-nightly.yml

This file was deleted.

3 changes: 1 addition & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,5 @@ jobs:
- uses: julia-actions/julia-runtest@latest
- uses: julia-actions/[email protected]
continue-on-error: true
- uses: julia-actions/[email protected]
continue-on-error: true


2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJEnsembles"
uuid = "50ed68f4-41fd-4504-931a-ed422449fee0"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.4.0"
version = "0.4.1"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand Down
51 changes: 38 additions & 13 deletions src/ensembles.jl
Original file line number Diff line number Diff line change
Expand Up @@ -408,31 +408,56 @@ function _fit(res::CPUProcesses, func, verbosity, stuff)
end
end

# Return a version of `rng` that is safe to hand to a concurrently running task.
#
# `Random.default_rng()` and the `Random._GLOBAL_RNG` object are thread-safe by default:
# their state is thread-local on Julia 1.3 to 1.6 and task-local on Julia >= 1.7, so they
# are returned unchanged. Any other `rng` is deep-copied, so that the caller receives an
# object which is not shared with anyone else.
function threadsafe_rng(rng::typeof(Random.default_rng()))
    return rng
end

function threadsafe_rng(rng::Random._GLOBAL_RNG)
    return rng
end

function threadsafe_rng(rng)
    return deepcopy(rng)
end

function _fit(res::CPUThreads, func, verbosity, stuff)
atom, n, n_patterns, n_train, rng, progress_meter, args = stuff
if verbosity > 0
println("Ensemble-building in parallel on $(Threads.nthreads()) threads.")
end

nthreads = Threads.nthreads()

if nthreads == 1
return _fit(CPU1(), func, verbosity, stuff)
end

chunk_size = div(n, nthreads)
left_over = mod(n, nthreads)
resvec = Vector(undef, nthreads) # FIXME: Make this type-stable?

Threads.@threads for i = 1:nthreads
resvec[i] = if i != nworkers()
func(atom, 0, chunk_size, n_patterns, n_train, rng, progress_meter, args...)
else
func(
atom,
0,
chunk_size + left_over,
n_patterns,
n_train,
rng,
progress_meter,
args...,
@sync begin
for i in 1:nthreads-1
Threads.@spawn(
resvec[i] = func(
atom,
0,
chunk_size,
n_patterns,
n_train,
threadsafe_rng(rng),
progress_meter,
args...
)
)
end
Threads.@spawn(
resvec[nthreads] = func(
atom,
0,
chunk_size + left_over,
n_patterns,
n_train,
threadsafe_rng(rng),
progress_meter,
args...
)
)
end

return reduce(_reducer, resvec)
Expand Down
15 changes: 15 additions & 0 deletions test/ensembles.jl
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,23 @@ end
@test length(ensemble.fitresult.ensemble) == 5

@test !isnan(predict(ensemble, MLJEnsembles.selectrows(X, test))[1])

# tests using integer rngs (see issue 27)
X_, y_ = @load_iris
atom = KNNClassifier(K = 7)
ensemble_model = EnsembleModel(
atom;
bagging_fraction=0.6,
rng=123,
out_of_bag_measure = [log_loss, brier_score]
)
ensemble = machine(ensemble_model, X_, y_)
fit!(ensemble)
@test length(ensemble.fitresult.ensemble) == ensemble_model.n

end


end

true
43 changes: 43 additions & 0 deletions test/serialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,50 @@ end
@test predict(smach, X) == predict(mach, X)

rm(filename)
end

# A supervised model whose `fitresult` is ephemeral (it does not survive plain
# serialization), but which works around this by overloading `save`/`restore`.
# `thing` is the shared object whose `objectid` gets recorded in each `fitresult`.
thing = Any[]

struct EphemeralRegressor <: Deterministic
end
# Fit by recording the global `thing`, its current `objectid`, and the mean of `y`.
# The `objectid` changes if `thing` is serialized and then deserialized, which is what
# makes this `fitresult` ephemeral.
function MLJBase.fit(::EphemeralRegressor, verbosity, X, y)
    fitresult = (thing, objectid(thing), mean(y))
    return fitresult, nothing, NamedTuple()
end
# Predict the stored mean for every row of `X`, but only while the `objectid` recorded
# in `fitresult` still matches the object stored alongside it; otherwise throw an
# `ErrorException`.
function MLJBase.predict(::EphemeralRegressor, fitresult, X)
    payload, stored_id, μ = fitresult
    if stored_id != objectid(payload)
        throw(ErrorException("dead fitresult"))
    end
    return fill(μ, nrows(X))
end
# Declare the target scitype for `EphemeralRegressor`: a vector of `Continuous` values.
function MLJBase.target_scitype(::Type{<:EphemeralRegressor})
    return AbstractVector{Continuous}
end
# Produce the serializable form of `fitresult`: the stored `objectid` (second entry) is
# dropped, keeping only the object and the mean.
function MLJBase.save(::EphemeralRegressor, fitresult)
    payload = fitresult[1]
    μ = fitresult[3]
    return (payload, μ)
end
# Rebuild a usable `fitresult` from its serialized form by recomputing the `objectid`
# of the stored object and re-inserting it as the second entry.
function MLJBase.restore(::EphemeralRegressor, serialized_fitresult)
    payload, μ = serialized_fitresult
    return (payload, objectid(payload), μ)
end

@testset "serialization for atomic models with non-persistent fitresults" begin
    # https://github.com/alan-turing-institute/MLJ.jl/issues/1099
    #
    # Round-trip a machine wrapping an ensemble of `EphemeralRegressor` atoms through
    # `MLJBase.save` / `machine(io)` and check the restored machine can still predict.
    # NOTE(review): `X` has 10 rows while `y` has 3 -- confirm this mismatch is intended.
    X, y = (; x = rand(10)), fill(42.0, 3)
    ensemble = EnsembleModel(
        EphemeralRegressor(),
        bagging_fraction=0.7,
        n=2,
    )
    mach = machine(ensemble, X, y)
    fit!(mach, verbosity=0)

    # serialize to an in-memory buffer and read the machine back:
    io = IOBuffer()
    MLJBase.save(io, mach)
    seekstart(io)
    mach2 = machine(io)
    close(io)

    # `y` is constant, so every atom's stored mean is 42.0; compare approximately, as
    # these are floating point values:
    @test MLJBase.predict(mach2, (; x = rand(2))) ≈ fill(42.0, 2)
end

end
Expand Down

0 comments on commit af03f16

Please sign in to comment.