diff --git a/src/ranges.jl b/src/ranges.jl index b2b799d0b..303601678 100644 --- a/src/ranges.jl +++ b/src/ranges.jl @@ -1,42 +1,53 @@ """ - OptionallyStaticUnitRange(start, stop) <: AbstractUnitRange{Int} + OptionallyStaticUnitRange(start, stop, check_lower_bound=True(), check_upper_bound=True()) <: AbstractUnitRange{Int} -Similar to `UnitRange` except each field may be an `Int` or `StaticInt`. An -`OptionallyStaticUnitRange` is intended to be constructed internally from other valid -indices. Therefore, users should not expect the same checks are used to ensure construction -of a valid `OptionallyStaticUnitRange` as a `UnitRange`. +Similar to `UnitRange` except each field may be an `Int` or `StaticInt`. `check_lower_bound` +and `check_upper_bound` determine whether `start` and `stop` are bounds checked when +`OptionallyStaticUnitRange` is used as an index. This type is intended to be constructed +internally from other valid indices. Therefore, users should not expect the same checks are +used to ensure construction of a valid `OptionallyStaticUnitRange` as a `UnitRange`. + +!!! warning + + Manually setting `check_lower_bound` and `check_upper_bound` to `False()` has similar + behavior as `@inbounds` and may result in incorrect results/crashes/corruption. """ -struct OptionallyStaticUnitRange{F<:CanonicalInt,L<:CanonicalInt} <: AbstractUnitRange{Int} +struct OptionallyStaticUnitRange{F<:CanonicalInt,L<:CanonicalInt,CLB<:Union{False,True},CUB<:Union{False,True}} <: AbstractUnitRange{Int} start::F stop::L + check_lower_bound::CLB + check_upper_bound::CUB - function OptionallyStaticUnitRange(start::CanonicalInt, stop::CanonicalInt) - new{typeof(start),typeof(stop)}(start, stop) + function OptionallyStaticUnitRange(start::CanonicalInt, stop::CanonicalInt, check_lower_bound=True(), check_upper_bound=True()) + new{typeof(start),typeof(stop),typeof(check_lower_bound),typeof(check_upper_bound)}(start, stop, check_lower_bound, check_upper_bound) end - function OptionallyStaticUnitRange(start, stop) - OptionallyStaticUnitRange(canonicalize(start), canonicalize(stop)) + function OptionallyStaticUnitRange(start, stop, check_lower_bound=True(), check_upper_bound=True()) + OptionallyStaticUnitRange(canonicalize(start), canonicalize(stop), check_lower_bound, check_upper_bound) end - function OptionallyStaticUnitRange(x::AbstractRange) - step(x) == 1 && return OptionallyStaticUnitRange(static_first(x), static_last(x)) + function OptionallyStaticUnitRange(x::AbstractRange, check_lower_bound=check_lower_bound(x), check_upper_bound=check_upper_bound(x)) + step(x) == 1 && return OptionallyStaticUnitRange(static_first(x), static_last(x), check_lower_bound, check_upper_bound) errmsg(x) = throw(ArgumentError("step must be 1, got $(step(x))")) # avoid GC frame errmsg(x) end - OptionallyStaticUnitRange{F,L}(x::AbstractRange) where {F,L} = OptionallyStaticUnitRange(x) - function OptionallyStaticUnitRange{StaticInt{F},StaticInt{L}}() where {F,L} - new{StaticInt{F},StaticInt{L}}() + function OptionallyStaticUnitRange{F,L}(x::AbstractRange, check_lower_bound=check_lower_bound(x), check_upper_bound=check_upper_bound(x)) where {F,L} + OptionallyStaticUnitRange(x) + end + function OptionallyStaticUnitRange{StaticInt{F},StaticInt{L}}(check_lower_bound=True(), check_upper_bound=True()) where {F,L} + new{StaticInt{F},StaticInt{L},typeof(check_lower_bound),typeof(check_upper_bound)}(StaticInt(F), StaticInt(L), check_lower_bound, check_upper_bound) end end """ - OptionallyStaticStepRange(start, step, stop) <: OrdinalRange{Int,Int} + OptionallyStaticStepRange(start, step, stop, check_lower_bound=True(), check_upper_bound=True(), check_lower_bound=True(), check_upper_bound=True()) <: OrdinalRange{Int,Int} Similarly to [`OptionallyStaticUnitRange`](@ref), `OptionallyStaticStepRange` permits a combination of static and standard primitive `Int`s to construct a range. It specifically enables the use of ranges without a step size of 1. It may be constructed -through the use of `OptionallyStaticStepRange` directly or using static integers with -the range operator (i.e., `:`). +through the use of `OptionallyStaticStepRange` directly or using static integers with the +range operator (i.e., `:`). `check_lower_bound` and `check_upper_bound` determine whether +`start` and `stop` are bounds checked when `OptionallyStaticStepRange` is used as an index. ```julia julia> using ArrayInterface @@ -51,20 +62,22 @@ static(2):static(2):10 ``` """ -struct OptionallyStaticStepRange{F<:CanonicalInt,S<:CanonicalInt,L<:CanonicalInt} <: OrdinalRange{Int,Int} +struct OptionallyStaticStepRange{F<:CanonicalInt,S<:CanonicalInt,L<:CanonicalInt,CLB<:Union{False,True},CUB<:Union{False,True}} <: OrdinalRange{Int,Int} start::F step::S stop::L + check_lower_bound::CLB + check_upper_bound::CUB - function OptionallyStaticStepRange(start::CanonicalInt, step::CanonicalInt, stop::CanonicalInt) + function OptionallyStaticStepRange(start::CanonicalInt, step::CanonicalInt, stop::CanonicalInt, check_lower_bound=True(), check_upper_bound=True()) lst = _steprange_last(start, step, stop) - new{typeof(start),typeof(step),typeof(lst)}(start, step, lst) + new{typeof(start),typeof(step),typeof(lst),typeof(check_lower_bound),typeof(check_upper_bound)}(start, step, lst, check_lower_bound, check_upper_bound) end - function OptionallyStaticStepRange(start, step, stop) - OptionallyStaticStepRange(canonicalize(start), canonicalize(step), canonicalize(stop)) + function OptionallyStaticStepRange(start, step, stop, check_lower_bound=True(), check_upper_bound=True()) + OptionallyStaticStepRange(canonicalize(start), canonicalize(step), canonicalize(stop), check_lower_bound, check_upper_bound) end - function OptionallyStaticStepRange(x::AbstractRange) - return OptionallyStaticStepRange(static_first(x), static_step(x), static_last(x)) + function OptionallyStaticStepRange(x::AbstractRange, check_lower_bound=check_lower_bound(x), check_upper_bound=check_upper_bound(x)) + OptionallyStaticStepRange(static_first(x), static_step(x), static_last(x), check_lower_bound, check_upper_bound) end end @@ -123,6 +136,15 @@ SOneTo(n::Int) = SOneTo{n}() const OptionallyStaticRange = Union{<:OptionallyStaticUnitRange,<:OptionallyStaticStepRange} +check_lower_bound(x) = True() +check_lower_bound(x::OptionallyStaticRange) = getfield(x, :check_lower_bound) +check_lower_bound(x::Base.IdentityUnitRange) = check_lower_bound(getfield(x, :indices)) +check_lower_bound(::Base.Slice) = False() + +check_upper_bound(x) = True() +check_upper_bound(x::OptionallyStaticRange) = getfield(x, :check_upper_bound) +check_upper_bound(x::Base.IdentityUnitRange) = check_upper_bound(getfield(x, :indices)) +check_upper_bound(::Base.Slice) = False() ArrayInterfaceCore.known_first(::Type{<:OptionallyStaticUnitRange{StaticInt{F}}}) where {F} = F::Int ArrayInterfaceCore.known_first(::Type{<:OptionallyStaticStepRange{StaticInt{F}}}) where {F} = F::Int @@ -139,7 +161,7 @@ ArrayInterfaceCore.known_last(::Type{<:OptionallyStaticStepRange{<:Any,<:Any,Sta return known_first(r) end end -function Base.step(r::OptionallyStaticStepRange)::Int +@inline function Base.step(r::OptionallyStaticStepRange)::Int if known_step(r) === nothing return getfield(r, :step) else @@ -193,25 +215,38 @@ function Base.isempty(r::OptionallyStaticStepRange) (r.start != r.stop) & ((r.step > 0) != (r.stop > r.start)) end -function Base.checkindex( - ::Type{Bool}, - ::SUnitRange{F1,L1}, - ::SUnitRange{F2,L2} -) where {F1,L1,F2,L2} +const CheckBoundsRange{CLB,CUB} = Union{OptionallyStaticUnitRange{<:CanonicalInt,<:CanonicalInt,CLB,CUB},OptionallyStaticStepRange{<:CanonicalInt,<:CanonicalInt,<:CanonicalInt,CLB,CUB}} - (F1::Int <= F2::Int) && (L1::Int >= L2::Int) +Base.checkindex(::Type{Bool}, x::SUnitRange{F,L}, ::StaticInt{I}) where {F,L,I} = F <= I <= L +Base.checkindex(::Type{Bool}, x::AbstractUnitRange, i::CheckBoundsRange{False,False}) = true +@inline function Base.checkindex(::Type{Bool}, x::AbstractUnitRange, i::CheckBoundsRange{False,True}) + checkindex(Bool, x, getfield(i, :stop)) || isempty(i) +end +@inline function Base.checkindex(::Type{Bool}, x::AbstractUnitRange, i::CheckBoundsRange{True,False}) + checkindex(Bool, x, getfield(i, :start)) || isempty(i) +end +@inline function Base.checkindex(::Type{Bool}, x::AbstractUnitRange, i::CheckBoundsRange{True,True}) + (checkindex(Bool, x, getfield(i, :stop)) && checkindex(Bool, x, getfield(i, :start))) || isempty(i) end -@propagate_inbounds function Base.getindex( - r::OptionallyStaticUnitRange, - s::AbstractUnitRange{<:Integer}, -) +@inline function Base.getindex(r::OptionallyStaticUnitRange, s::AbstractUnitRange{<:Integer}) @boundscheck checkbounds(r, s) f = static_first(r) fnew = f - one(f) - return (fnew+static_first(s)):(fnew+static_last(s)) + # propagate bounds checking directives in case this is a subset of a known inbounds range + return OptionallyStaticUnitRange((fnew+static_first(s)), (fnew+static_last(s)), getfield(r, :check_lower_bound), getfield(r, :check_upper_bound)) end +@inline function Base.getindex(x::OptionallyStaticRange, i::AbstractRange{T}) where {T<:Integer} + @boundscheck checkbounds(x, i) + fi = static_first(i) + sx = static_step(x) + si = static_step(i) + start = static_first(x) + (fi - one(fi)) * sx + st = sx * si + len = static_length(i) + return OptionallyStaticStepRange(start, st, (start + (len - one(len)) * st), getfield(r, :check_lower_bound), getfield(r, :check_upper_bound)) +end @propagate_inbounds function Base.getindex(x::OptionallyStaticUnitRange{StaticInt{1}}, i::Int) @boundscheck checkbounds(x, i) i diff --git a/test/ranges.jl b/test/ranges.jl index 098bfc435..c9f0377f6 100644 --- a/test/ranges.jl +++ b/test/ranges.jl @@ -38,12 +38,19 @@ @test AbstractUnitRange{UInt}(ArrayInterface.OptionallyStaticUnitRange(static(1), static(10))) isa Base.OneTo @test AbstractUnitRange{UInt}(ArrayInterface.OptionallyStaticUnitRange(static(2), static(10))) isa UnitRange - @test @inferred((static(1):static(10))[static(2):static(3)]) === static(2):static(3) - @test @inferred((static(1):static(10))[static(2):3]) === static(2):3 - @test @inferred((static(1):static(10))[2:3]) === 2:3 - @test @inferred((1:static(10))[static(2):static(3)]) === 2:3 - - @test Base.checkindex(Bool, static(1):static(10), static(1):static(5)) + @test @inferred((static(1):static(10))[static(2):static(3)]) == static(2):static(3) + @test @inferred((static(1):static(10))[static(2):3]) == static(2):3 + @test @inferred((static(1):static(10))[2:3]) == 2:3 + @test @inferred((1:static(10))[static(2):static(3)]) == 2:3 + + @test !Base.checkindex(Bool, 1:5, ArrayInterface.OptionallyStaticUnitRange(1, 10, True(), True())) + @test !Base.checkindex(Bool, 1:5, ArrayInterface.OptionallyStaticUnitRange(1, 10, False(), True())) + @test Base.checkindex(Bool, 1:10, ArrayInterface.OptionallyStaticUnitRange(1, 5, False(), True())) + @test Base.checkindex(Bool, 1:10, ArrayInterface.OptionallyStaticUnitRange(1, 5, True(), False())) + @test Base.checkindex(Bool, 1:10, ArrayInterface.OptionallyStaticUnitRange(1, 5, True(), True())) + # these are actually out of bounds but we want to ensure that we can actually elide bounds checking + @test Base.checkindex(Bool, 1:5, ArrayInterface.OptionallyStaticUnitRange(1, 10, True(), False())) + @test Base.checkindex(Bool, 1:5, ArrayInterface.OptionallyStaticUnitRange(1, 10, False(), False())) @test -(static(1):static(10)) === static(-1):static(-1):static(-10) @test reverse(static(1):static(10)) === static(10):static(-1):static(1) @@ -149,4 +156,3 @@ end @test ArrayInterface.indices((x',y'),StaticInt(1)) === Base.Slice(StaticInt(1):StaticInt(1)) @test ArrayInterface.indices((x,y), StaticInt(2)) === Base.Slice(StaticInt(1):StaticInt(1)) end -