Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
pasq-cat committed Jul 23, 2024
1 parent 2dd7f07 commit 52fa6ae
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 30 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ ComputationalResources = "0.3.2"
Distributions = "0.25.109"
Flux = "0.12, 0.13, 0.14"
LinearAlgebra = "1.7, 1.10"
MLJBase = " 1.6.0"
MLJBase = "< 1.4.0"
MLJFlux = "0.5"
MLJModelInterface = "1.8.0"
MLUtils = "0.4"
Expand Down
2 changes: 1 addition & 1 deletion src/LaplaceRedux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,6 @@ include("calibration_functions.jl")
export empirical_frequency_binary_classification,
sharpness_classification,
empirical_frequency_regression,
sharpness_regression,
sharpness_regression, extract_mean_and_variance,
sigma_scaling
end
28 changes: 1 addition & 27 deletions src/calibration_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -328,30 +328,4 @@ function sigma_scaling(
sigma = sqrt(1 / length(y_cal) * sum(norm.(y_cal .- means) ./ variances))

return sigma
end
@doc raw"""
sigma_scaling(la::Laplace, x_cal::Vector{<:AbstractFloat}, y_cal::Vector{<:AbstractFloat})
Compute the value of Σ that maximize the conditional log-likelihood:
```math
m ln(Σ) +1/2 * Σ^{-2} ∑_{i=1}^{i=m} || y_cal_i - ̄y_mean_i ||^2 / σ^2_i
```
where m is the number of elements in the calibration set (x_cal,y_cal). \
Source: [Laves,Ihler,Fast, Kahrs, Ortmaier,2020](https://proceedings.mlr.press/v121/laves20a.html)
Inputs: \
- `la`: the Laplace object \
- `x_cal`: a Vector of inputs. \
- `y_cal`: a Vector of true results.
Outputs: \
- `sigma`: the scalar that maximize the likelihood.
"""
function sigma_scaling(
la::Laplace, x_cal::Vector{<:AbstractFloat}, y_cal::Vector{<:AbstractFloat}
)
distrs, means, variances = glm_predictive_distribution(la, x_cal')

sigma = sigma_scaling(distrs, y_cal)

return sigma
end
end
34 changes: 33 additions & 1 deletion test/calibration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ end
y_int, sampled_distributions; n_bins=20
)

@test isapprox(mean(emp_avg), 0.5; atol=0.01)
@test isapprox(mean(emp_avg), 0.5; atol=0.2)

# Test 3: Invalid Y_cal input
Y_cal = [0, 1, 0, 1.2, 4]
Expand Down Expand Up @@ -325,3 +325,35 @@ end
Y_cal, distributions, n_bins=0
)
end


# Test for `empirical_frequency_binary_classification` function
@testset "sigma scaling" begin
@info "testing sigma scaling technique"
# Test 1: testing function extract_mean_and_variance
# Create 3 different Normal distributions with known means and variances
known_distributions = [Normal(0.0, 1.0), Normal(2.0, 3.0), Normal(-1.0, 0.5)]
expected_means = [0.0, 2.0, -1.0]
expected_variances = [1.0, 9.0, 0.25]
# Execution: Call the function
actual_means, actual_variances = extract_mean_and_variance(known_distributions)
@test actual_means expected_means
@test actual_variances expected_variances
# Test 2: testing sigma_scaling
# Step 1: Define the parameters for the sine wave
start_point = 0.0 # Start of the interval
end_point = 2 * π # End of the interval, 2π for a full sine wave cycle
sample_points = 2000 # Number of sample points between 0 and 2π

# Step 2: Generate the sample points
x = LinRange(start_point, end_point, sample_points)

# Step 3: Generate the sine wave data
y = sin.(x)
distrs = Distributions.Normal.(y, 0.01)
#fake miscalibrated predictions
predicted_elements = rand.(distrs) .+ rand((1,2))

sigma = sigma_scaling( distrs ,predicted_elements)
@test typeof(sigma) <: Number
end

0 comments on commit 52fa6ae

Please sign in to comment.