diff --git a/src/SDiagonal.jl b/src/SDiagonal.jl index d1e92caa..d726eaba 100644 --- a/src/SDiagonal.jl +++ b/src/SDiagonal.jl @@ -17,6 +17,10 @@ SDiagonal(a::StaticMatrix{N,N,T}) where {N,T} = Diagonal(diag(a)) size(::Type{<:SDiagonal{N}}) where {N} = (N,N) size(::Type{<:SDiagonal{N}}, d::Int) where {N} = d > 2 ? 1 : N +Base.axes(D::SDiagonal, d) = d <= 2 ? axes(D)[d] : SOneTo(1) + +Base.reshape(a::SDiagonal, s::Tuple{SOneTo,Vararg{SOneTo}}) = reshape(a, homogenize_shape(s)) + # define specific methods to avoid allocating mutable arrays \(D::SDiagonal, b::AbstractVector) = D.diag .\ b \(D::SDiagonal, b::StaticVector) = D.diag .\ b # catch ambiguity @@ -56,3 +60,5 @@ function inv(D::SDiagonal) check_singular(D) SDiagonal(inv.(D.diag)) end + +Base.copy(D::SDiagonal) = Diagonal(copy(diag(D))) diff --git a/src/indexing.jl b/src/indexing.jl index dffb0cef..98315e4d 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -377,3 +377,21 @@ Base.unsafe_view(A::AbstractArray, i1::StaticIndexing, indices::StaticIndexing.. # the tuple indices has to have at least one element to prevent infinite # recursion when viewing a zero-dimensional array (see issue #705) Base.SubArray(A::AbstractArray, indices::Tuple{StaticIndexing, Vararg{StaticIndexing}}) = Base.SubArray(A, map(unwrap, indices)) + +########################################################### +# SDiagonal +########################################################### + +# SDiagonal uses Cartesian indexing, and the canonical indexing methods shadow getindex for Diagonal +# these are needed for ambiguity resolution +@inline function getindex(D::SDiagonal, i::Int, j::Int) + invoke(getindex, Tuple{Diagonal, Int, Int}, D, i, j) +end +@inline function getindex(D::SDiagonal, i::Int...) + invoke(getindex, Tuple{Diagonal, Vararg{Int}}, D, i...) +end +# Ensure that vector indexing with static types lead to SArrays +@propagate_inbounds function getindex(a::SDiagonal, inds::Union{Int, StaticArray{<:Tuple, Int}, SOneTo, Colon}...) + ar = reshape(a, Val(length(inds))) + _getindex(ar, index_sizes(Size(ar), inds...), inds) +end diff --git a/test/SDiagonal.jl b/test/SDiagonal.jl index ff94f849..7d1090c6 100644 --- a/test/SDiagonal.jl +++ b/test/SDiagonal.jl @@ -70,6 +70,15 @@ using StaticArrays, Test, LinearAlgebra @test length(m) === 4*4 + m2 = SMatrix{4,4}(m) + @test (@inferred axes(m)) === axes(m2) + @test axes(m, 1) === axes(m2, 1) + @test axes(m, 3) == SOneTo(1) + + @test m[:, 1] === SVector{4}(m[1,1], 0, 0, 0) + @test m[:, :] === m2 + @test m[2, 2, 1] === m[2, 2] + @test_throws Exception m[1] = 1 b = @SVector [2,-1,2,1] @@ -114,5 +123,7 @@ using StaticArrays, Test, LinearAlgebra @test m + zero(m) == m @test m + zero(typeof(m)) == m + + @test copy(m) === m end end