diff --git a/test/utils.jl b/test/utils.jl index 2cef162fe..0d66e9f67 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -6,6 +6,10 @@ v = randn(rng, D) w = randn(rng, N) + @testset "check_args macro" begin + @test_throws ArgumentError GammaExponentialKernel(-1.0, Euclidean()) + end + @testset "VecOfVecs" begin @test vec_of_vecs(X; obsdim=2) == ColVecs(X) @test vec_of_vecs(X; obsdim=1) == RowVecs(X) @@ -42,6 +46,10 @@ KernelFunctions.pairwise!(K, SqEuclidean(), DX, DY) @test K ≈ pairwise(SqEuclidean(), X, Y; dims=2) + y = rand(N, 1) + yv = y[:] + @test RowVecs(y) == RowVecs(yv) + let @test Zygote.pullback(ColVecs, X)[1] == DX DX, back = Zygote.pullback(ColVecs, X) @@ -98,14 +106,40 @@ x_rowvecs = RowVecs(randn(7, 3)) @test isapprox( - pairwise(SqEuclidean(), x_colvecs, x_rowvecs), + KernelFunctions.pairwise(SqEuclidean(), x_colvecs, x_rowvecs), pairwise(SqEuclidean(), collect(x_colvecs), collect(x_rowvecs)), ) @test isapprox( - pairwise(SqEuclidean(), x_rowvecs, x_colvecs), + KernelFunctions.pairwise(SqEuclidean(), x_rowvecs, x_colvecs), pairwise(SqEuclidean(), collect(x_rowvecs), collect(x_colvecs)), ) end + @testset "AbstractVector + RowVecs" begin + x = [randn(3) for _ in 1:5] + x_rowvecs = RowVecs(randn(7, 3)) + + @test isapprox( + KernelFunctions.pairwise(SqEuclidean(), x, x_rowvecs), + pairwise(SqEuclidean(), x, collect(x_rowvecs)), + ) + @test isapprox( + KernelFunctions.pairwise(SqEuclidean(), x_rowvecs, x), + pairwise(SqEuclidean(), collect(x_rowvecs), x), + ) + end + @testset "AbstractVector + ColVecs" begin + x = [randn(3) for _ in 1:5] + x_colvecs = ColVecs(randn(3, 7)) + + @test isapprox( + KernelFunctions.pairwise(SqEuclidean(), x, x_colvecs), + pairwise(SqEuclidean(), x, collect(x_colvecs)), + ) + @test isapprox( + KernelFunctions.pairwise(SqEuclidean(), x_colvecs, x), + pairwise(SqEuclidean(), collect(x_colvecs), x), + ) + end @testset "input checks" begin D = 3 D⁻ = 2