Skip to content

Commit 5af63ff

Browse files
authored
faster rand(TaskLocalRNG(), 1:n) by outlining throw (#58306)
In #58089, this method took a small performance hit in some contexts. It turns out that by outlining the unlikely branch which throws on empty ranges, this hit can be recovered. In #50509 (comment), a graph of the performance improvement of the "speed-up randperm by using our current rand(1:n)" was posted, but I realized it was only true when calls to `rand(1:n)` were prefixed by `@inline`; without `@inline` it was overall slower for `TaskLocalRNG()` for very big arrays (but still faster otherwise). An alternative to these `@inline` annotation is to outline `throw` like here, for equivalent benefits as `@inline` in that `randperm` PR. Assuming that PR is merged, this PR improves roughly performance by 2x for `TaskLocalRNG()` (no change for other RNGs): ![new-shuffle-outlinethrow](https://github.com/user-attachments/assets/8c0d4740-3bb4-4bcf-a49d-9af09426bec7) While at it, I outlined a bunch of other unliky throwing branches. After that, #50509 can probably be merged, finally!
1 parent 62d3371 commit 5af63ff

File tree

4 files changed

+15
-10
lines changed

4 files changed

+15
-10
lines changed

stdlib/Random/src/Random.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@ using Base.GMP.MPZ
1515
using Base.GMP: Limb
1616
using SHA: SHA, SHA2_256_CTX, SHA2_512_CTX, SHA_CTX
1717

18-
using Base: BitInteger, BitInteger_types, BitUnsigned, require_one_based_indexing
18+
using Base: BitInteger, BitInteger_types, BitUnsigned, require_one_based_indexing,
19+
_throw_argerror
20+
1921
import Base: copymutable, copy, copy!, ==, hash, convert,
2022
rand, randn, show
2123

stdlib/Random/src/Xoshiro.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ rng_native_52(::TaskLocalRNG) = UInt64
232232
# this variant of setstate! initializes the internal splitmix state, a.k.a. `s4`
233233
@inline function initstate!(x::Union{TaskLocalRNG, Xoshiro}, state)
234234
length(state) == 4 && eltype(state) == UInt64 ||
235-
throw(ArgumentError("initstate! expects a list of 4 `UInt64` values"))
235+
_throw_argerror("initstate! expects a list of 4 `UInt64` values")
236236
s0, s1, s2, s3 = state
237237
setstate!(x, (s0, s1, s2, s3, 1s0 + 3s1 + 5s2 + 7s3))
238238
end

stdlib/Random/src/generation.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ Sampler(::Type{<:AbstractRNG}, I::FloatInterval{BigFloat}, ::Repetition) =
5757
SamplerBigFloat{typeof(I)}(precision(BigFloat))
5858

5959
function _rand!(rng::AbstractRNG, z::BigFloat, sp::SamplerBigFloat)
60-
precision(z) == sp.prec || throw(ArgumentError("incompatible BigFloat precision"))
60+
precision(z) == sp.prec || _throw_argerror("incompatible BigFloat precision")
6161
limbs = sp.limbs
6262
rand!(rng, limbs)
6363
@inbounds begin
@@ -229,6 +229,9 @@ uint_sup(::Type{<:Base.BitInteger32}) = UInt32
229229
uint_sup(::Type{<:Union{Int64,UInt64}}) = UInt64
230230
uint_sup(::Type{<:Union{Int128,UInt128}}) = UInt128
231231

232+
@noinline empty_collection_error() = throw(ArgumentError("collection must be non-empty"))
233+
234+
232235
#### Fast
233236

234237
struct SamplerRangeFast{U<:BitUnsigned,T<:BitInteger} <: Sampler{T}
@@ -242,7 +245,7 @@ SamplerRangeFast(r::AbstractUnitRange{T}) where T<:BitInteger =
242245
SamplerRangeFast(r, uint_sup(T))
243246

244247
function SamplerRangeFast(r::AbstractUnitRange{T}, ::Type{U}) where {T,U}
245-
isempty(r) && throw(ArgumentError("collection must be non-empty"))
248+
isempty(r) && empty_collection_error()
246249
m = (last(r) - first(r)) % unsigned(T) % U # % unsigned(T) to not propagate sign bit
247250
bw = (Base.top_set_bit(m)) % UInt # bit-width
248251
mask = ((1 % U) << bw) - (1 % U)
@@ -316,7 +319,7 @@ SamplerRangeInt(r::AbstractUnitRange{T}) where T<:BitInteger =
316319
SamplerRangeInt(r, uint_sup(T))
317320

318321
function SamplerRangeInt(r::AbstractUnitRange{T}, ::Type{U}) where {T,U}
319-
isempty(r) && throw(ArgumentError("collection must be non-empty"))
322+
isempty(r) && empty_collection_error()
320323
a = first(r)
321324
m = (last(r) - first(r)) % unsigned(T) % U
322325
k = m + one(U)
@@ -362,7 +365,7 @@ struct SamplerRangeNDL{U<:Unsigned,T} <: Sampler{T}
362365
end
363366

364367
function SamplerRangeNDL(r::AbstractUnitRange{T}) where {T}
365-
isempty(r) && throw(ArgumentError("collection must be non-empty"))
368+
isempty(r) && empty_collection_error()
366369
a = first(r)
367370
U = uint_sup(T)
368371
s = (last(r) - first(r)) % unsigned(T) % U + one(U) # overflow ok
@@ -405,7 +408,7 @@ end
405408
function SamplerBigInt(::Type{RNG}, r::AbstractUnitRange{BigInt}, N::Repetition=Val(Inf)
406409
) where {RNG<:AbstractRNG}
407410
m = last(r) - first(r)
408-
m.size < 0 && throw(ArgumentError("collection must be non-empty"))
411+
m.size < 0 && empty_collection_error()
409412
nlimbs = Int(m.size)
410413
hm = nlimbs == 0 ? Limb(0) : GC.@preserve m unsafe_load(m.d, nlimbs)
411414
highsp = Sampler(RNG, Limb(0):hm, N)
@@ -461,7 +464,7 @@ rand(rng::AbstractRNG, sp::SamplerSimple{<:AbstractArray,<:Sampler}) =
461464
## random values from Dict
462465

463466
function Sampler(::Type{RNG}, t::Dict, ::Repetition) where RNG<:AbstractRNG
464-
isempty(t) && throw(ArgumentError("collection must be non-empty"))
467+
isempty(t) && empty_collection_error()
465468
# we use Val(Inf) below as rand is called repeatedly internally
466469
# even for generating only one random value from t
467470
SamplerSimple(t, Sampler(RNG, LinearIndices(t.slots), Val(Inf)))
@@ -490,7 +493,7 @@ rand(rng::AbstractRNG, sp::SamplerTag{<:Set,<:Sampler}) = rand(rng, sp.data).fir
490493
## random values from BitSet
491494

492495
function Sampler(RNG::Type{<:AbstractRNG}, t::BitSet, n::Repetition)
493-
isempty(t) && throw(ArgumentError("collection must be non-empty"))
496+
isempty(t) && empty_collection_error()
494497
SamplerSimple(t, Sampler(RNG, minimum(t):maximum(t), Val(Inf)))
495498
end
496499

stdlib/Random/src/misc.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ end
9797
# size-m subset of A where m is fixed!)
9898
function randsubseq!(r::AbstractRNG, S::AbstractArray, A::AbstractArray, p::Real)
9999
require_one_based_indexing(S, A)
100-
0 <= p <= 1 || throw(ArgumentError("probability $p not in [0,1]"))
100+
0 <= p <= 1 || _throw_argerror(LazyString("probability ", p, " not in [0,1]"))
101101
n = length(A)
102102
p == 1 && return copyto!(resize!(S, n), A)
103103
empty!(S)

0 commit comments

Comments
 (0)