Skip to content

Commit 071d10a

Browse files
authored
Add tests for StatsAPI modeling functions (#767)
Ensure that we don't accidentally stop exporting some functions and that fallbacks defined in StatsAPI work. Stop importing `params` and `params!` as they are neither used nor reexported
1 parent d9d20f0 commit 071d10a

File tree

1 file changed

+61
-1
lines changed

1 file changed

+61
-1
lines changed

test/statmodels.jl

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using StatsBase
22
using StatsBase: PValue, TestStat
3-
using Test, Random
3+
using Test, Random, StatsAPI, LinearAlgebra
44

55
v1 = [1.45666, -23.14, 1.56734e-13]
66
v2 = ["Good", "Great", "Bad"]
@@ -131,3 +131,63 @@ end
131131

132132
err = @test_throws ArgumentError ConvergenceException(10,.1,.2)
133133
@test err.value.msg == "Change must be greater than tol."
134+
135+
struct MyStatisticalModel <: StatisticalModel
136+
end
137+
138+
StatsAPI.vcov(::MyStatisticalModel) = [1 2; 3 4]
139+
StatsAPI.loglikelihood(::MyStatisticalModel) = 3
140+
StatsAPI.nullloglikelihood(::MyStatisticalModel) = 4
141+
StatsAPI.deviance(::MyStatisticalModel) = 25
142+
StatsAPI.nulldeviance(::MyStatisticalModel) = 40
143+
StatsAPI.dof(::MyStatisticalModel) = 5
144+
StatsAPI.nobs(::MyStatisticalModel) = 100
145+
146+
@testset "StatisticalModel" begin
147+
m = MyStatisticalModel()
148+
149+
@test stderror(m) == [1, 2]
150+
@test aic(m) == 4
151+
@test aicc(m) 4.638297872340425
152+
@test bic(m) 17.02585092994046
153+
@test r2(m, :McFadden) 0.25
154+
@test r2(m, :CoxSnell) -0.020201340026755776
155+
@test r2(m, :Nagelkerke) 0.24255074155803877
156+
@test r2(m, :devianceratio) 0.375
157+
158+
@test_throws Union{ErrorException, ArgumentError} r2(m, :err)
159+
@test_throws MethodError r2(m)
160+
@test adjr2(m, :McFadden) 1.5
161+
@test adjr2(m, :devianceratio) 0.3486842105263158
162+
@test_throws Union{ErrorException, ArgumentError} adjr2(m, :err)
163+
164+
@test r2 ===
165+
@test adjr2 === adjr²
166+
end
167+
168+
struct MyRegressionModel <: RegressionModel
169+
end
170+
171+
StatsAPI.modelmatrix(::MyRegressionModel) = [1 2; 3 4]
172+
173+
@testset "TestRegressionModel" begin
174+
m = MyRegressionModel()
175+
176+
@test crossmodelmatrix(m) == [10 14; 14 20]
177+
@test crossmodelmatrix(m) isa Symmetric
178+
end
179+
180+
@testset "StatsAPI model reexports" begin
181+
for f in (fitted, response, responsename, meanresponse,
182+
modelmatrix, crossmodelmatrix, leverage, cooksdistance, residuals,
183+
predict, predict!, dof_residual, coef, coefnames, coeftable, confint,
184+
deviance, islinear, nulldeviance, loglikelihood, nullloglikelihood,
185+
loglikelihood, loglikelihood, score, nobs, dof, mss, rss,
186+
informationmatrix, stderror, vcov, weights, isfitted, fit, fit!,
187+
aic, aicc, bic, r2, r², adjr2, adjr²)
188+
@test f isa Function
189+
end
190+
# Defined but not reexported
191+
@test StatsBase.params isa Function
192+
@test StatsBase.params! isa Function
193+
end

0 commit comments

Comments
 (0)