Skip to content

Index set ranges #291

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 72 additions & 37 deletions src/ranges.jl
Original file line number Diff line number Diff line change
@@ -1,42 +1,53 @@

"""
OptionallyStaticUnitRange(start, stop) <: AbstractUnitRange{Int}
OptionallyStaticUnitRange(start, stop, check_lower_bound=True(), check_upper_bound=True()) <: AbstractUnitRange{Int}

Similar to `UnitRange` except each field may be an `Int` or `StaticInt`. An
`OptionallyStaticUnitRange` is intended to be constructed internally from other valid
indices. Therefore, users should not expect the same checks are used to ensure construction
of a valid `OptionallyStaticUnitRange` as a `UnitRange`.
Similar to `UnitRange` except each field may be an `Int` or `StaticInt`. `check_lower_bound`
and `check_upper_bound` determine whether `start` and `stop` are bounds checked when
`OptionallyStaticUnitRange` is used as an index. This type is intended to be constructed
internally from other valid indices. Therefore, users should not expect the same checks are
used to ensure construction of a valid `OptionallyStaticUnitRange` as a `UnitRange`.

!!! warning

Manually setting `check_lower_bound` and `check_upper_bound` to `False()` has similar
behavior as `@inbounds` and may result in incorrect results/crashes/corruption.
"""
struct OptionallyStaticUnitRange{F<:CanonicalInt,L<:CanonicalInt} <: AbstractUnitRange{Int}
struct OptionallyStaticUnitRange{F<:CanonicalInt,L<:CanonicalInt,CLB<:Union{False,True},CUB<:Union{False,True}} <: AbstractUnitRange{Int}
start::F
stop::L
check_lower_bound::CLB
check_upper_bound::CUB

function OptionallyStaticUnitRange(start::CanonicalInt, stop::CanonicalInt)
new{typeof(start),typeof(stop)}(start, stop)
function OptionallyStaticUnitRange(start::CanonicalInt, stop::CanonicalInt, check_lower_bound=True(), check_upper_bound=True())
new{typeof(start),typeof(stop),typeof(check_lower_bound),typeof(check_upper_bound)}(start, stop, check_lower_bound, check_upper_bound)
end
function OptionallyStaticUnitRange(start, stop)
OptionallyStaticUnitRange(canonicalize(start), canonicalize(stop))
function OptionallyStaticUnitRange(start, stop, check_lower_bound=True(), check_upper_bound=True())
OptionallyStaticUnitRange(canonicalize(start), canonicalize(stop), check_lower_bound, check_upper_bound)
end
function OptionallyStaticUnitRange(x::AbstractRange)
step(x) == 1 && return OptionallyStaticUnitRange(static_first(x), static_last(x))
function OptionallyStaticUnitRange(x::AbstractRange, check_lower_bound=check_lower_bound(x), check_upper_bound=check_upper_bound(x))
step(x) == 1 && return OptionallyStaticUnitRange(static_first(x), static_last(x), check_lower_bound, check_upper_bound)

errmsg(x) = throw(ArgumentError("step must be 1, got $(step(x))")) # avoid GC frame
errmsg(x)
end
OptionallyStaticUnitRange{F,L}(x::AbstractRange) where {F,L} = OptionallyStaticUnitRange(x)
function OptionallyStaticUnitRange{StaticInt{F},StaticInt{L}}() where {F,L}
new{StaticInt{F},StaticInt{L}}()
function OptionallyStaticUnitRange{F,L}(x::AbstractRange, check_lower_bound=check_lower_bound(x), check_upper_bound=check_upper_bound(x)) where {F,L}
OptionallyStaticUnitRange(x)
end
function OptionallyStaticUnitRange{StaticInt{F},StaticInt{L}}(check_lower_bound=True(), check_upper_bound=True()) where {F,L}
new{StaticInt{F},StaticInt{L},typeof(check_lower_bound),typeof(check_upper_bound)}(StaticInt(F), StaticInt(L), check_lower_bound, check_upper_bound)
end
end

"""
OptionallyStaticStepRange(start, step, stop) <: OrdinalRange{Int,Int}
OptionallyStaticStepRange(start, step, stop, check_lower_bound=True(), check_upper_bound=True(), check_lower_bound=True(), check_upper_bound=True()) <: OrdinalRange{Int,Int}

Similarly to [`OptionallyStaticUnitRange`](@ref), `OptionallyStaticStepRange` permits
a combination of static and standard primitive `Int`s to construct a range. It
specifically enables the use of ranges without a step size of 1. It may be constructed
through the use of `OptionallyStaticStepRange` directly or using static integers with
the range operator (i.e., `:`).
through the use of `OptionallyStaticStepRange` directly or using static integers with the
range operator (i.e., `:`). `check_lower_bound` and `check_upper_bound` determine whether
`start` and `stop` are bounds checked when `OptionallyStaticStepRange` is used as an index.

```julia
julia> using ArrayInterface
Expand All @@ -51,20 +62,22 @@ static(2):static(2):10

```
"""
struct OptionallyStaticStepRange{F<:CanonicalInt,S<:CanonicalInt,L<:CanonicalInt} <: OrdinalRange{Int,Int}
struct OptionallyStaticStepRange{F<:CanonicalInt,S<:CanonicalInt,L<:CanonicalInt,CLB<:Union{False,True},CUB<:Union{False,True}} <: OrdinalRange{Int,Int}
start::F
step::S
stop::L
check_lower_bound::CLB
check_upper_bound::CUB

function OptionallyStaticStepRange(start::CanonicalInt, step::CanonicalInt, stop::CanonicalInt)
function OptionallyStaticStepRange(start::CanonicalInt, step::CanonicalInt, stop::CanonicalInt, check_lower_bound=True(), check_upper_bound=True())
lst = _steprange_last(start, step, stop)
new{typeof(start),typeof(step),typeof(lst)}(start, step, lst)
new{typeof(start),typeof(step),typeof(lst),typeof(check_lower_bound),typeof(check_upper_bound)}(start, step, lst, check_lower_bound, check_upper_bound)
end
function OptionallyStaticStepRange(start, step, stop)
OptionallyStaticStepRange(canonicalize(start), canonicalize(step), canonicalize(stop))
function OptionallyStaticStepRange(start, step, stop, check_lower_bound=True(), check_upper_bound=True())
OptionallyStaticStepRange(canonicalize(start), canonicalize(step), canonicalize(stop), check_lower_bound, check_upper_bound)
end
function OptionallyStaticStepRange(x::AbstractRange)
return OptionallyStaticStepRange(static_first(x), static_step(x), static_last(x))
function OptionallyStaticStepRange(x::AbstractRange, check_lower_bound=check_lower_bound(x), check_upper_bound=check_upper_bound(x))
OptionallyStaticStepRange(static_first(x), static_step(x), static_last(x), check_lower_bound, check_upper_bound)
end
end

Expand Down Expand Up @@ -123,6 +136,15 @@ SOneTo(n::Int) = SOneTo{n}()

const OptionallyStaticRange = Union{<:OptionallyStaticUnitRange,<:OptionallyStaticStepRange}

check_lower_bound(x) = True()
check_lower_bound(x::OptionallyStaticRange) = getfield(x, :check_lower_bound)
check_lower_bound(x::Base.IdentityUnitRange) = check_lower_bound(getfield(x, :indices))
check_lower_bound(::Base.Slice) = False()

check_upper_bound(x) = True()
check_upper_bound(x::OptionallyStaticRange) = getfield(x, :check_upper_bound)
check_upper_bound(x::Base.IdentityUnitRange) = check_upper_bound(getfield(x, :indices))
check_upper_bound(::Base.Slice) = False()

ArrayInterfaceCore.known_first(::Type{<:OptionallyStaticUnitRange{StaticInt{F}}}) where {F} = F::Int
ArrayInterfaceCore.known_first(::Type{<:OptionallyStaticStepRange{StaticInt{F}}}) where {F} = F::Int
Expand All @@ -139,7 +161,7 @@ ArrayInterfaceCore.known_last(::Type{<:OptionallyStaticStepRange{<:Any,<:Any,Sta
return known_first(r)
end
end
function Base.step(r::OptionallyStaticStepRange)::Int
@inline function Base.step(r::OptionallyStaticStepRange)::Int
if known_step(r) === nothing
return getfield(r, :step)
else
Expand Down Expand Up @@ -193,25 +215,38 @@ function Base.isempty(r::OptionallyStaticStepRange)
(r.start != r.stop) & ((r.step > 0) != (r.stop > r.start))
end

function Base.checkindex(
::Type{Bool},
::SUnitRange{F1,L1},
::SUnitRange{F2,L2}
) where {F1,L1,F2,L2}
const CheckBoundsRange{CLB,CUB} = Union{OptionallyStaticUnitRange{<:CanonicalInt,<:CanonicalInt,CLB,CUB},OptionallyStaticStepRange{<:CanonicalInt,<:CanonicalInt,<:CanonicalInt,CLB,CUB}}

(F1::Int <= F2::Int) && (L1::Int >= L2::Int)
Base.checkindex(::Type{Bool}, x::SUnitRange{F,L}, ::StaticInt{I}) where {F,L,I} = F <= I <= L
Base.checkindex(::Type{Bool}, x::AbstractUnitRange, i::CheckBoundsRange{False,False}) = true
@inline function Base.checkindex(::Type{Bool}, x::AbstractUnitRange, i::CheckBoundsRange{False,True})
checkindex(Bool, x, getfield(i, :stop)) || isempty(i)
end
@inline function Base.checkindex(::Type{Bool}, x::AbstractUnitRange, i::CheckBoundsRange{True,False})
checkindex(Bool, x, getfield(i, :start)) || isempty(i)
end
@inline function Base.checkindex(::Type{Bool}, x::AbstractUnitRange, i::CheckBoundsRange{True,True})
(checkindex(Bool, x, getfield(i, :stop)) && checkindex(Bool, x, getfield(i, :start))) || isempty(i)
end

@propagate_inbounds function Base.getindex(
r::OptionallyStaticUnitRange,
s::AbstractUnitRange{<:Integer},
)
@inline function Base.getindex(r::OptionallyStaticUnitRange, s::AbstractUnitRange{<:Integer})
@boundscheck checkbounds(r, s)
f = static_first(r)
fnew = f - one(f)
return (fnew+static_first(s)):(fnew+static_last(s))
# propagate bounds checking directives in case this is a subset of a known inbounds range
return OptionallyStaticUnitRange((fnew+static_first(s)), (fnew+static_last(s)), getfield(r, :check_lower_bound), getfield(r, :check_upper_bound))
end

@inline function Base.getindex(x::OptionallyStaticRange, i::AbstractRange{T}) where {T<:Integer}
@boundscheck checkbounds(x, i)
fi = static_first(i)
sx = static_step(x)
si = static_step(i)
start = static_first(x) + (fi - one(fi)) * sx
st = sx * si
len = static_length(i)
return OptionallyStaticStepRange(start, st, (start + (len - one(len)) * st), getfield(r, :check_lower_bound), getfield(r, :check_upper_bound))
end
@propagate_inbounds function Base.getindex(x::OptionallyStaticUnitRange{StaticInt{1}}, i::Int)
@boundscheck checkbounds(x, i)
i
Expand Down
20 changes: 13 additions & 7 deletions test/ranges.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,19 @@
@test AbstractUnitRange{UInt}(ArrayInterface.OptionallyStaticUnitRange(static(1), static(10))) isa Base.OneTo
@test AbstractUnitRange{UInt}(ArrayInterface.OptionallyStaticUnitRange(static(2), static(10))) isa UnitRange

@test @inferred((static(1):static(10))[static(2):static(3)]) === static(2):static(3)
@test @inferred((static(1):static(10))[static(2):3]) === static(2):3
@test @inferred((static(1):static(10))[2:3]) === 2:3
@test @inferred((1:static(10))[static(2):static(3)]) === 2:3

@test Base.checkindex(Bool, static(1):static(10), static(1):static(5))
@test @inferred((static(1):static(10))[static(2):static(3)]) == static(2):static(3)
@test @inferred((static(1):static(10))[static(2):3]) == static(2):3
@test @inferred((static(1):static(10))[2:3]) == 2:3
@test @inferred((1:static(10))[static(2):static(3)]) == 2:3

@test !Base.checkindex(Bool, 1:5, ArrayInterface.OptionallyStaticUnitRange(1, 10, True(), True()))
@test !Base.checkindex(Bool, 1:5, ArrayInterface.OptionallyStaticUnitRange(1, 10, False(), True()))
@test Base.checkindex(Bool, 1:10, ArrayInterface.OptionallyStaticUnitRange(1, 5, False(), True()))
@test Base.checkindex(Bool, 1:10, ArrayInterface.OptionallyStaticUnitRange(1, 5, True(), False()))
@test Base.checkindex(Bool, 1:10, ArrayInterface.OptionallyStaticUnitRange(1, 5, True(), True()))
# these are actually out of bounds but we want to ensure that we can actually elide bounds checking
@test Base.checkindex(Bool, 1:5, ArrayInterface.OptionallyStaticUnitRange(1, 10, True(), False()))
@test Base.checkindex(Bool, 1:5, ArrayInterface.OptionallyStaticUnitRange(1, 10, False(), False()))
@test -(static(1):static(10)) === static(-1):static(-1):static(-10)

@test reverse(static(1):static(10)) === static(10):static(-1):static(1)
Expand Down Expand Up @@ -149,4 +156,3 @@ end
@test ArrayInterface.indices((x',y'),StaticInt(1)) === Base.Slice(StaticInt(1):StaticInt(1))
@test ArrayInterface.indices((x,y), StaticInt(2)) === Base.Slice(StaticInt(1):StaticInt(1))
end