diff --git a/CHANGELOG.md b/CHANGELOG.md index 0122db9b2..e7bb0e414 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), ### Changed - Slight change to `FlattenedCE` and `unflatten` to ensure that basic functionality remains intact. [#505] +- Fixed small issue in `benchmark` function. ## Version [1.4.1] - 2024-12-19 diff --git a/src/evaluation/benchmark.jl b/src/evaluation/benchmark.jl index 4ab964773..e78a4f802 100644 --- a/src/evaluation/benchmark.jl +++ b/src/evaluation/benchmark.jl @@ -176,7 +176,8 @@ function benchmark( ) # Unflatten for evaluation: - @assert typeof(ces) == Vector{FlattenedCE} "Expecting a vector of `FlattenedCE`. Did you accidentally set `return_flattened=false`?" + @assert all(typeof.(ces) .== FlattenedCE) "Expecting a vector of `FlattenedCE`. Did you accidentally set `return_flattened=false`?" + ces = convert(Vector{FlattenedCE}, ces) ces = unflatten_for_eval(ces, data, Ms, gens, kwrgs) # Meta Data: @@ -432,7 +433,8 @@ function benchmark( ) # Unflatten for evaluation: - @assert typeof(ces) == Vector{FlattenedCE} "Expecting a vector of `FlattenedCE`. Did you accidentally set `return_flattened=false`?" + @assert all(typeof.(ces) .== FlattenedCE) "Expecting a vector of `FlattenedCE`. Did you accidentally set `return_flattened=false`?" + ces = convert(Vector{FlattenedCE}, ces) ces = unflatten_for_eval(ces, data, Ms, gens, kwrgs) # Free up memory: diff --git a/test/Project.toml b/test/Project.toml index 76e134799..ac5f071d7 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -27,6 +27,7 @@ Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" TaijaData = "9d524318-b4e6-4a65-86d2-b2b72d07866c" +TaijaParallel = "bf1c2c22-5e42-4e78-8b6b-92e6c673eeb0" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] diff --git a/test/other/evaluation.jl b/test/other/evaluation.jl index 5413376cc..042a75cb9 100644 --- a/test/other/evaluation.jl +++ b/test/other/evaluation.jl @@ -3,6 +3,7 @@ using CounterfactualExplanations.Evaluation: using CounterfactualExplanations.Objectives: distance using Serialization: serialize using TaijaData: load_moons, load_circles +using TaijaParallel: ThreadsParallelizer # Dataset data = TaijaData.load_overlapping() @@ -70,6 +71,17 @@ end @testset "Benchmarking" begin bmk = Evaluation.benchmark(counterfactual_data; convergence=:generator_conditions) + @testset "Parallelization" begin + @testset "Threads" begin + parallelizer = ThreadsParallelizer() + bmk = benchmark( + counterfactual_data; + convergence=:generator_conditions, + parallelizer=parallelizer, + ) + end + end + @testset "Basics" begin @test typeof(bmk()) <: DataFrame @test typeof(bmk(; agg=nothing)) <: DataFrame