Skip to content

Commit 52d8d29

Browse files
authored
Avoid overflow in indexing and sum (#217)
* Avoid overflow in indexing and sum * Change _third_prod to _onethird_prod * Define _half_prod only for integers * Remove unused _half methods
1 parent 1e6542b commit 52d8d29

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

src/cumsum.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,18 @@ Base.parent(r::RangeCumsum) = r.range
1515
==(a::RangeCumsum, b::RangeCumsum) = a.range == b.range
1616
BroadcastStyle(::Type{<:RangeCumsum{<:Any,RR}}) where RR = BroadcastStyle(RR)
1717

18-
_half(x::Integer) = x ÷ 2
19-
_half(x) = x / 2
18+
function _half_prod(a::Integer, b::Integer)
19+
iseven(a) ? (a÷2) * b : a * (b÷2)
20+
end
21+
function _onethird_prod(a::Integer, b::Integer)
22+
mod(a, 3) == 0 ? (a÷3) * b : a * (b÷3)
23+
end
2024

2125
function _getindex(r::AbstractRange{<:Real}, k)
2226
v = first(r)
2327
s = step(r)
24-
_half(k * (2v - s + s*k))
28+
# avoid overflow, if possible
29+
k * v + s * _half_prod(k, k-1)
2530
end
2631
Base.@propagate_inbounds _getindex(r::AbstractRange, k) = sum(r[range(firstindex(r), length=k)])
2732

@@ -44,7 +49,9 @@ function Base.sum(r::RangeCumsum{<:Real})
4449
N = length(r)
4550
v = first(r)
4651
s = step(r.range)
47-
_half((2v-s)*(N*(N+1)÷2) + s*(N*(N+1)*(2N+1)÷6))
52+
# avoid overflow, if possible
53+
halfnnp1 = _half_prod(N, N+1)
54+
v * halfnnp1 + s * _onethird_prod(halfnnp1, N-1)
4855
end
4956

5057
union(a::RangeCumsum{<:Any,<:OneTo}, b::RangeCumsum{<:Any,<:OneTo}) =

test/test_cumsum.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,17 @@ cmpop(p) = isinteger(real(first(p))) && isinteger(real(step(p))) ? (==) : (≈)
5151
r = RangeCumsum(InfiniteArrays.OneToInf())
5252
@test axes(r, 1) == InfiniteArrays.OneToInf()
5353

54+
@testset "overflow" begin
55+
r = RangeCumsum(typemax(Int)÷2 .+ (0:1))
56+
@test last(r) == typemax(Int)
57+
r = RangeCumsum(typemin(Int)÷2 .- (1:1))
58+
@test first(r) == typemin(Int)÷2 - 1
59+
r = RangeCumsum(typemax(Int) .+ (0:0))
60+
@test sum(r) == typemax(Int)
61+
r = RangeCumsum(typemin(Int) .+ (0:0))
62+
@test sum(r) == typemin(Int)
63+
end
64+
5465
@testset "multiplication by a number" begin
5566
function test_broadcast(n, r)
5667
w = Vector(r)

0 commit comments

Comments
 (0)