Skip to content

add static size information for hvcat and hvncat #58422

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 154 additions & 0 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2255,6 +2255,72 @@ end

# Generic `typed_hvcat` entry point: convert `rows` with `rows_to_dimshape` and
# delegate to `typed_hvncat`, passing `true` as its third (row-first) argument.
function typed_hvcat(::Type{T}, rows::Tuple{Vararg{Int}}, as...) where T
    shape = rows_to_dimshape(rows)
    return typed_hvncat(T, shape, true, as...)
end

# Statically shaped counterpart of `typed_hvcat` for scalar `Number` arguments.
#
# The row layout arrives as a type parameter (`Val{rows}`) and the argument count is
# known from the signature, so every shape check is done while the body is generated
# and the emitted code contains no runtime size checks. Inputs for which static size
# information is not beneficial are handled by the general hvcat/typed_hvcat methods.
#
# Generated code, depending on the static shape:
#   * a row whose length differs from the first row -> throws `DimensionMismatch`
#   * argument count different from `nrows * ncols` -> throws `ArgumentError`
#   * at most 16 elements                           -> fully unrolled stores
#   * otherwise                                     -> a call to `hvcat_fill!`
@generated function typed_hvcat_static(::Type{T}, ::Val{rows}, xs::Number...) where {T<:Number, rows}
    nrows = length(rows)
    ncols = rows[1]

    # Every row has to be as long as the first one.
    for r = 2:nrows
        rows[r] == ncols && continue
        return quote
            msg = "row " * string($r) * " has mismatched number of columns (expected " * string($ncols) * ", got " * string($rows[$r]) * ")"
            throw(DimensionMismatch(msg))
        end
    end

    nargs = length(xs)
    if nargs != nrows * ncols
        return quote
            msg = "argument count " * string($nargs) * " does not match specified shape " * string(($nrows, $ncols))
            throw(ArgumentError(msg))
        end
    end

    if nargs > 16
        # Large literal: call the ordinary `hvcat_fill!` instead of emitting one
        # store expression per element.
        return quote
            a = Matrix{$T}(undef, $nrows, $ncols)
            hvcat_fill!(a, xs)
        end
    end

    # Small literal: emit one store per element. The arguments are listed row by
    # row, so element (r, c) is argument number (r - 1) * ncols + c.
    stores = Expr[]
    for r in 1:nrows, c in 1:ncols
        src = (r - 1) * ncols + c
        push!(stores, :(a[$r, $c] = xs[$src]))
    end

    return quote
        a = Matrix{$T}(undef, $nrows, $ncols)
        $(stores...)
        return a
    end
end
# Homogeneous scalar case: all arguments share one `Number` type `T`, which is passed
# on as the element type without any promotion step.
@inline hvcat_static(::Val{rows}, x::T, xs::Vararg{T}) where {rows, T<:Number} =
    typed_hvcat_static(T, Val{rows}(), x, xs...)
Comment on lines +2308 to +2310
Copy link
Member Author

@johnnychen94 johnnychen94 May 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only regression happens here:

using BenchmarkTools

f(x, y) = [
    x y y y y y
    y x y y y y
    y y x y y y
    y y y x y y
    y y y y x y
    y y y y y x
]

@btime f(x, y) setup=(x = rand(); y = rand())

It takes 984.615ns on my local laptop. In comparison, it takes only 128.225ns for Julia 1.12.

But this can be fixed if we duplicate the generated function definition:

@generated function hvcat_static(::Val{rows}, xs::T...) where {T<:Number, rows}
    # just copied everything from the typed_hvcat_static generated function's implementation
end

By applying this fix, we get 24.995ns.

The only problems with this fix are:

  1. This introduces an ambiguity issue caused by xs::T.... That forces us to write hvcat_static(::Val{rows}, x::T, xs::Vararg{T}) rather than hvcat_static(::Val{rows}, xs::Vararg{T}).
  2. This duplicates the code, and I have almost no idea why this duplication works.
  3. Doing this doesn't fix the performance for mixed element types, i.e., @btime f(x, y) setup=(x = rand(); y = rand(1:10)) (add static size information for hvcat and hvncat #58422 (comment))

Thus, I think the "fix" isn't elegant enough, so I did not commit it in this review version. If we find it necessary, I can add it later.

# Mixed scalar case: the element type is the promoted type of all the arguments.
@inline function hvcat_static(::Val{rows}, xs::Number...) where {rows}
    T = promote_typeof(xs...)
    return typed_hvcat_static(T, Val{rows}(), xs...)
end
# Catch-all method: unwrap the statically known `rows` and defer to the general
# `typed_hvcat`.
@inline typed_hvcat_static(::Type{T}, ::Val{rows}, xs...) where {T, rows} =
    typed_hvcat(T, rows, xs...)
# Catch-all method: unwrap the statically known `rows` and defer to the general
# `hvcat`.
@inline hvcat_static(::Val{rows}, xs...) where {rows} = hvcat(rows, xs...)


## N-dimensional concatenation ##

"""
Expand Down Expand Up @@ -2750,6 +2816,94 @@ end
Ai
end

# Static version of `typed_hvncat` for scalar `Number` arguments.
#
# `dims` and `row_first` arrive as type parameters (`Val{dims}`, `Val{row_first}`) and
# the argument count is known from the signature, so all validation is done while the
# body is generated:
#   * any non-positive entry in `dims`       -> generated body throws `ArgumentError`
#   * `prod(dims)` differs from the argcount -> generated body throws `ArgumentError`
#   * at most 16 elements                    -> fully unrolled linear-index stores
#   * otherwise                              -> a call to `hvncat_fill!`
#
# `dims` may have fewer than two entries (e.g. `Val{(3,)}`); missing trailing
# dimensions are treated as having length 1 when computing the unrolled row-first
# layout. Previously `dims[2]` was read unconditionally, so such a call failed with a
# `BoundsError` inside the generator whenever the unrolled path was taken.
@generated function typed_hvncat_static(::Type{T}, ::Val{dims}, ::Val{row_first}, xs::Number...) where {T<:Number, dims, row_first}
    for d in dims
        if d <= 0
            return quote
                throw(ArgumentError("`dims` argument must contain positive integers"))
            end
        end
    end

    N = length(dims)
    lengtha = prod(dims)
    lengthx = length(xs)
    if lengtha != lengthx
        return quote
            msg = "argument count does not match specified shape (expected " * string($lengtha) * ", got " * string($lengthx) * ")"
            throw(ArgumentError(msg))
        end
    end

    if lengthx <= 16
        # For small array construction, manually unroll the loop
        assigns = Expr[]

        if row_first
            # Guard the lookups so that 0- and 1-element `dims` do not index out of
            # bounds; an absent dimension contributes a factor of 1.
            nr = N >= 1 ? dims[1] : 1
            nc = N >= 2 ? dims[2] : 1
            nrc = nr * nc
            # Number of nr-by-nc slices: the product of all dimensions from the third
            # one on. Exact because every entry of `dims` is positive here.
            na = lengtha ÷ nrc

            # Arguments are listed row by row within each slice, while `A` is filled
            # through linear indices: moving one column to the right advances the
            # linear index by `nr`.
            k = 1
            for d in 1:na
                dd = nrc * (d - 1)
                for i in 1:nr
                    Ai = dd + i
                    for j in 1:nc
                        push!(assigns, :(A[$Ai] = xs[$k]))
                        k += 1
                        Ai += nr
                    end
                end
            end
        else
            # Arguments are stored in the order given: the i-th argument goes to
            # linear index i.
            for i in 1:lengtha
                push!(assigns, :(A[$i] = xs[$i]))
            end
        end

        return quote
            A = Array{$T, $N}(undef, $dims...)
            $(assigns...)
            return A
        end
    end

    # For larger arrays, use the regular loop
    quote
        A = Array{$T, $N}(undef, $dims...)
        hvncat_fill!(A, $row_first, xs)
        return A
    end
end
# Homogeneous scalar case: all arguments share one `Number` type `T`, which is passed
# on as the element type without any promotion step.
@inline hvncat_static(::Val{dims}, ::Val{row_first}, x::T, xs::Vararg{T}) where {dims, row_first, T<:Number} =
    typed_hvncat_static(T, Val{dims}(), Val{row_first}(), x, xs...)
# Mixed scalar case: the element type is the promoted type of all the arguments.
@inline function hvncat_static(::Val{dims}, ::Val{row_first}, xs::Number...) where {dims, row_first}
    T = promote_typeof(xs...)
    return typed_hvncat_static(T, Val{dims}(), Val{row_first}(), xs...)
end
# Catch-all method: unwrap the statically known `dims` and `row_first` and defer to
# the general `typed_hvncat`.
@inline typed_hvncat_static(::Type{T}, ::Val{dims}, ::Val{row_first}, xs...) where {T, dims, row_first} =
    typed_hvncat(T, dims, row_first, xs...)
# Catch-all method: unwrap the statically known `dims` and `row_first` and defer to
# the general `hvncat`.
@inline hvncat_static(::Val{dims}, ::Val{row_first}, xs...) where {dims, row_first} =
    hvncat(dims, row_first, xs...)

"""
stack(iter; [dims])

Expand Down
14 changes: 11 additions & 3 deletions base/expr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1684,15 +1684,23 @@ function generated_body_to_codeinfo(ex::Expr, defmod::Module, isva::Bool)
end
ci.isva = isva
code = ci.code
bindings = IdSet{Core.Binding}()
bindings = Core.svec()
for i = 1:length(code)
stmt = code[i]
if isa(stmt, GlobalRef)
push!(bindings, convert(Core.Binding, stmt))
# plain loop is used rather than fancy `any` or `IdSet` to avoid world-age issues
# if we want to use generated function during the early bootstrap stage when
# these functions are not available. (e.g., using `@generated` in `abstractarray.jl`)
for x in bindings
if x === stmt
continue
end
end
bindings = Core.svec(bindings..., convert(Core.Binding, stmt))
end
end
if !isempty(bindings)
ci.edges = Core.svec(bindings...)
ci.edges = bindings
end
return ci
end
Expand Down
14 changes: 8 additions & 6 deletions src/julia-syntax.scm
Original file line number Diff line number Diff line change
Expand Up @@ -2254,7 +2254,7 @@

(define (expand-vcat e
(vcat '((top vcat)))
(hvcat '((top hvcat)))
(hvcat '((top hvcat_static)))
(hvcat_rows '((top hvcat_rows))))
(let ((a (cdr e)))
(if (any assignment? a)
Expand All @@ -2276,11 +2276,13 @@
(if (any (lambda (row) (any vararg? row)) rows)
`(call ,@hvcat_rows ,@(map (lambda (x) `(tuple ,@x)) rows))
`(call ,@hvcat
(tuple ,@(map length rows))
(new (curly (top Val) (call (core tuple) ,@(map length rows))))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see an inconsistency when checking the result of @which:

julia> @which [2 2; 3 3]
hvcat(rows::Tuple{Vararg{Int64}}, xs::T...) where T<:Number
     @ Base abstractarray.jl:2203

julia> Meta.@lower [2 2; 3 3]
:($(Expr(:thunk, CodeInfo(
1 ─ %1 =   builtin Core.tuple(2, 2)
│   %2 =   builtin Core.apply_type(Base.Val, %1)
│   %3 = %new(%2)
│   %4 =   dynamic Base.hvcat_static(%3, 2, 2, 3, 3)
└──      return %4
))))

is it because I need to update JuliaSyntax as well?

,@(apply append rows))))
`(call ,@vcat ,@a))))))

(define (expand-ncat e (hvncat '((top hvncat))))
(define (expand-ncat e
(hvncat '((top hvncat)))
(hvncat_static '((top hvncat_static))))
(define (is-row a) (and (pair? a)
(or (eq? (car a) 'row)
(eq? (car a) 'nrow))))
Expand Down Expand Up @@ -2384,7 +2386,7 @@
(let ((shape (get-shape a is-row-first d)))
(if (is-balanced shape)
(let ((dims `(tuple ,@(reverse (get-dims a is-row-first d)))))
`(call ,@hvncat ,dims ,(tf is-row-first) ,@aflat))
`(call ,@hvncat_static (new (curly (top Val) ,dims)) (new (curly (top Val) ,(tf is-row-first))) ,@aflat))
`(call ,@hvncat ,(tuplize shape) ,(tf is-row-first) ,@aflat))))))))

(define (maybe-ssavalue lhss x in-lhs?)
Expand Down Expand Up @@ -2899,13 +2901,13 @@
(lambda (e)
(let ((t (cadr e))
(e (cdr e)))
(expand-vcat e `((top typed_vcat) ,t) `((top typed_hvcat) ,t) `((top typed_hvcat_rows) ,t))))
(expand-vcat e `((top typed_vcat) ,t) `((top typed_hvcat_static) ,t) `((top typed_hvcat_rows) ,t))))

'typed_ncat
(lambda (e)
(let ((t (cadr e))
(e (cdr e)))
(expand-ncat e `((top typed_hvncat) ,t))))
(expand-ncat e `((top typed_hvncat) ,t) `((top typed_hvncat_static) ,t))))

'|'| (lambda (e) (expand-forms `(call |'| ,(cadr e))))

Expand Down
147 changes: 147 additions & 0 deletions test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1745,6 +1745,153 @@ using Base: typed_hvncat
@test [["A";"B"];;"C";"D"] == ["A" "C"; "B" "D"]
end

@testset "array construction using numbers" begin
    # Array literals made of scalars only: `[x x; y y]` (hvcat) and
    # `[x x;;; y y]` (hvncat), followed by direct checks of the `*_static` methods.

    @testset "hvcat array construction" begin
        function test_hvcat(x1, x2, x3, x4)
            vals = (x1, x2, x3, x4)

            # 2x2 literal: small enough for the manually unrolled construction
            A = [x1 x2; x3 x4]
            for (pos, expected) in zip(((1, 1), (1, 2), (2, 1), (2, 2)), vals)
                @test A[pos...] == expected
            end

            AT = Float64[x1 x2; x3 x4]
            @test AT == A

            # 5x4 literal: large enough for the loop-based construction.
            # Row r holds (x1, x2, x3, x4) cyclically shifted left by r - 1.
            A = [x1 x2 x3 x4; x2 x3 x4 x1; x3 x4 x1 x2; x4 x1 x2 x3; x1 x2 x3 x4]
            for r in 1:5, c in 1:4
                @test A[r, c] == vals[mod1(r + c - 1, 4)]
            end
            AT = Float64[x1 x2 x3 x4; x2 x3 x4 x1; x3 x4 x1 x2; x4 x1 x2 x3; x1 x2 x3 x4]
            @test AT == A
        end

        for args in ((1, 2, 3, 4), (1.0, 2.0, 3.0, 4.0), (1.0, 2, 3.0, 4))
            test_hvcat(args...)
        end
    end

    @testset "hvncat array construction" begin
        # 3D literals written with `;;;`, using same and mixed element types and
        # both a small and a larger element count.

        function test_hvncat(x1, x2, x3, x4)
            # 1x2x2 literal: one row of two elements per slice
            A = [x1 x2;;; x3 x4]
            @test size(A) == (1, 2, 2)
            @test A[1,1,1] == x1
            @test A[1,1,2] == x3
            @test A[1,2,1] == x2
            @test A[1,2,2] == x4

            AT = Float64[x1 x2;;; x3 x4]
            @test AT == A

            # 2x2x5 literal: each slice is checked against the matching 2x2 literal
            A = [x1 x2; x3 x4;;; x2 x3; x4 x1;;; x3 x4; x1 x2;;; x4 x1; x2 x3;;; x1 x2; x3 x4]
            @test size(A) == (2, 2, 5)

            slices = (
                [x1 x2; x3 x4],
                [x2 x3; x4 x1],
                [x3 x4; x1 x2],
                [x4 x1; x2 x3],
                [x1 x2; x3 x4],
            )
            for (k, expected) in enumerate(slices)
                @test A[:, :, k] == expected
            end

            AT = Float64[x1 x2; x3 x4;;; x2 x3; x4 x1;;; x3 x4; x1 x2;;; x4 x1; x2 x3;;; x1 x2; x3 x4]
            @test AT == A
        end

        for args in ((1, 2, 3, 4), (1.0, 2.0, 3.0, 4.0), (1.0, 2, 3.0, 4))
            test_hvncat(args...)
        end
    end

    @testset "hvcat vs hvcat_static" begin
        static22 = Val{(2,2)}()

        # scalar numbers: static and regular entry points agree
        for nums in ((1, 2, 3, 4), (1, 2, 3.0, 4.0))
            @test Base.hvcat_static(static22, nums...) == hvcat((2, 2), nums...)
        end
        for nums in ((1, 2, 3, 4), (1, 2, 3.0, 4.0))
            @test Base.typed_hvcat_static(Float64, static22, nums...) == Base.typed_hvcat(Float64, (2, 2), nums...)
        end

        # non-number elements fall back to hvcat
        strs = ("a", "b", "c", "d")
        @test Base.hvcat_static(static22, strs...) == hvcat((2, 2), strs...)
        @test Base.typed_hvcat_static(String, static22, strs...) == Base.typed_hvcat(String, (2, 2), strs...)

        # non-scalar (array) elements fall back to hvcat
        @test Base.hvcat_static(static22, [1 2], [2 2], [3 3], [4 4]) == hvcat((2, 2), [1 2], [2 2], [3 3], [4 4])
        @test Base.typed_hvcat_static(Float64, static22, [1 2], [2 2], [3 3], [4 4]) == Base.typed_hvcat(Float64, (2, 2), [1 2], [2 2], [3 3], [4 4])

        # inconsistent row lengths are rejected by both entry points
        @test_throws DimensionMismatch hvcat((2,4), 2, 3, 4, 5)
        @test_throws DimensionMismatch Base.hvcat_static(Val{(2,4)}(), 2, 3, 4, 5)
    end

    @testset "hvncat vs hvncat_static" begin
        # 8 elements of a single type
        x = rand(8)
        for rf in (true, false)
            A = Base.hvncat_static(Val{(2, 2, 2)}(), Val{rf}(), x...)
            B = hvncat((2, 2, 2), rf, x...)
            @test A == B
        end

        # 8 elements of mixed types
        x, y = rand(4), rand(1:10, 4)
        for rf in (true, false)
            A = Base.hvncat_static(Val{(2, 2, 2)}(), Val{rf}(), x..., y...)
            B = hvncat((2, 2, 2), rf, x..., y...)
            @test A == B
        end

        # 24 elements: large array
        x = rand(24)
        for rf in (true, false)
            A = Base.hvncat_static(Val{(2, 3, 4)}(), Val{rf}(), x...)
            B = hvncat((2, 3, 4), rf, x...)
            @test A == B
        end
    end

    @testset "Static method doesn't apply to non-number types" begin
        # StaticArrays uses the typed-literal syntax `SA[1 2; 3 4]` to create a static
        # array even though `SA` is not a number-like type. A user-defined
        # `typed_hvcat`/`typed_hvncat` method for such a type must still be the one
        # that a typed literal ends up calling.

        struct FOO end
        Base.typed_hvcat(::Type{FOO}, dims::Dims, xs::Number...) = "Foo typed_hvcat"
        Base.typed_hvncat(::Type{FOO}, dims::Dims, row_first::Bool, xs::Number...) = "Foo typed_hvncat"

        @test FOO[1 2; 3 4] == "Foo typed_hvcat"
        @test FOO[1 2;;; 3 4] == "Foo typed_hvncat"
    end
end

@testset "stack" begin
# Basics
for args in ([[1, 2]], [1:2, 3:4], [[1 2; 3 4], [5 6; 7 8]],
Expand Down