diff --git a/base/abstractarray.jl b/base/abstractarray.jl index bf2a6ebabecba..7523dc10f036f 100644 --- a/base/abstractarray.jl +++ b/base/abstractarray.jl @@ -2255,6 +2255,72 @@ end typed_hvcat(::Type{T}, rows::Tuple{Vararg{Int}}, as...) where T = typed_hvncat(T, rows_to_dimshape(rows), true, as...) +# A fast version of hvcat for the case where we have static size information of xs +# and the number of rows is known at compile time -- we can eliminate all the runtime +# size checks. For cases that static size information is not beneficial, we fall back to +# the general hvcat/typed_hvcat methods. +@generated function typed_hvcat_static(::Type{T}, ::Val{rows}, xs::Number...) where {T<:Number, rows} + nr = length(rows) + nc = rows[1] + for i = 2:nr + if nc != rows[i] + return quote + msg = "row " * string($i) * " has mismatched number of columns (expected " * string($nc) * ", got " * string($rows[$i]) * ")" + throw(DimensionMismatch(msg)) + end + end + end + + len = length(xs) + if nr*nc != len + return quote + msg = "argument count " * string($len) * " does not match specified shape " * string(($nr, $nc)) + throw(ArgumentError(msg)) + end + end + + if len <= 16 + # For small array construction, manually unroll the loop for better performance + assigns = Expr[] + k = 1 + for i in 1:nr + for j in 1:nc + ex = :(a[$i, $j] = xs[$k]) + push!(assigns, ex) + k += 1 + end + end + + return quote + a = Matrix{$T}(undef, $nr, $nc) + $(assigns...) + return a + end + end + + # For generic fallback case, directly call a normal `hvcat_fill!` function + # to avoid the overhead of generating a large number of duplicated expressions. + quote + a = Matrix{$T}(undef, $nr, $nc) + hvcat_fill!(a, xs) + end +end +@inline function hvcat_static(::Val{rows}, x::T, xs::Vararg{T}) where {rows, T<:Number} + typed_hvcat_static(T, Val{rows}(), x, xs...) +end +@inline function hvcat_static(::Val{rows}, xs::Number...) where {rows} + typed_hvcat_static(promote_typeof(xs...), Val{rows}(), xs...) +end +@inline function typed_hvcat_static(::Type{T}, ::Val{rows}, xs...) where {T, rows} + # fallback to the general case + typed_hvcat(T, rows, xs...) +end +@inline function hvcat_static(::Val{rows}, xs...) where {rows} + # fallback to the general case + hvcat(rows, xs...) +end + + ## N-dimensional concatenation ## """ @@ -2750,6 +2816,94 @@ end Ai end +# Static version of hvncat for better performance with scalar numbers +# See the comments for hvcat_static for more details. +@generated function typed_hvncat_static(::Type{T}, ::Val{dims}, ::Val{row_first}, xs::Number...) where {T<:Number, dims, row_first} + for d in dims + if d <= 0 + return quote + throw(ArgumentError("`dims` argument must contain positive integers")) + end + end + end + + N = length(dims) + lengtha = prod(dims) + lengthx = length(xs) + if lengtha != lengthx + return quote + msg = "argument count does not match specified shape (expected " * string($lengtha) * ", got " * string($lengthx) * ")" + throw(ArgumentError(msg)) + end + end + + if lengthx <= 16 + # For small array construction, manually unroll the loop + assigns = Expr[] + nr, nc = dims[1], dims[2] + na = if N > 2 + n = 1 + for d in 3:N + n *= dims[d] + end + n + else + 1 + end + nrc = nr * nc + + if row_first + k = 1 + for d in 1:na + dd = nrc * (d - 1) + for i in 1:nr + Ai = dd + i + for j in 1:nc + ex = :(A[$Ai] = xs[$k]) + push!(assigns, ex) + k += 1 + Ai += nr + end + end + end + else + k = 1 + for i in 1:lengtha + ex = :(A[$i] = xs[$k]) + push!(assigns, ex) + k += 1 + end + end + + return quote + A = Array{$T, $N}(undef, $dims...) + $(assigns...) + return A + end + end + + # For larger arrays, use the regular loop + quote + A = Array{$T, $N}(undef, $dims...) + hvncat_fill!(A, $row_first, xs) + return A + end +end +@inline function hvncat_static(::Val{dims}, ::Val{row_first}, x::T, xs::Vararg{T}) where {dims, row_first, T<:Number} + typed_hvncat_static(T, Val{dims}(), Val{row_first}(), x, xs...) +end +@inline function hvncat_static(::Val{dims}, ::Val{row_first}, xs::Number...) where {dims, row_first} + typed_hvncat_static(promote_typeof(xs...), Val{dims}(), Val{row_first}(), xs...) +end +@inline function typed_hvncat_static(::Type{T}, ::Val{dims}, ::Val{row_first}, xs...) where {T, dims, row_first} + # fallback to the general case + typed_hvncat(T, dims, row_first, xs...) +end +@inline function hvncat_static(::Val{dims}, ::Val{row_first}, xs...) where {dims, row_first} + # fallback to the general case + hvncat(dims, row_first, xs...) +end + """ stack(iter; [dims]) diff --git a/base/expr.jl b/base/expr.jl index bf9e2ef2bf92c..56c035a2524f4 100644 --- a/base/expr.jl +++ b/base/expr.jl @@ -1684,15 +1684,23 @@ function generated_body_to_codeinfo(ex::Expr, defmod::Module, isva::Bool) end ci.isva = isva code = ci.code - bindings = IdSet{Core.Binding}() + bindings = Core.svec() for i = 1:length(code) stmt = code[i] if isa(stmt, GlobalRef) - push!(bindings, convert(Core.Binding, stmt)) + # plain loop is used rather than fancy `any` or `IdSet` to avoid world-age issues + # if we want to use generated function during the early bootstrap stage when + # these functions are not available. (e.g., using `@generated` in `abstractarray.jl`) + for x in bindings + if x === stmt + continue + end + end + bindings = Core.svec(bindings..., convert(Core.Binding, stmt)) end end if !isempty(bindings) - ci.edges = Core.svec(bindings...) + ci.edges = bindings end return ci end diff --git a/src/julia-syntax.scm b/src/julia-syntax.scm index 8f0bcd55ac194..3d5c41a15c0d8 100644 --- a/src/julia-syntax.scm +++ b/src/julia-syntax.scm @@ -2254,7 +2254,7 @@ (define (expand-vcat e (vcat '((top vcat))) - (hvcat '((top hvcat))) + (hvcat '((top hvcat_static))) (hvcat_rows '((top hvcat_rows)))) (let ((a (cdr e))) (if (any assignment? a) @@ -2276,11 +2276,13 @@ (if (any (lambda (row) (any vararg? row)) rows) `(call ,@hvcat_rows ,@(map (lambda (x) `(tuple ,@x)) rows)) `(call ,@hvcat - (tuple ,@(map length rows)) + (new (curly (top Val) (call (core tuple) ,@(map length rows)))) ,@(apply append rows)))) `(call ,@vcat ,@a)))))) -(define (expand-ncat e (hvncat '((top hvncat)))) +(define (expand-ncat e + (hvncat '((top hvncat))) + (hvncat_static '((top hvncat_static)))) (define (is-row a) (and (pair? a) (or (eq? (car a) 'row) (eq? (car a) 'nrow)))) @@ -2384,7 +2386,7 @@ (let ((shape (get-shape a is-row-first d))) (if (is-balanced shape) (let ((dims `(tuple ,@(reverse (get-dims a is-row-first d))))) - `(call ,@hvncat ,dims ,(tf is-row-first) ,@aflat)) + `(call ,@hvncat_static (new (curly (top Val) ,dims)) (new (curly (top Val) ,(tf is-row-first))) ,@aflat)) `(call ,@hvncat ,(tuplize shape) ,(tf is-row-first) ,@aflat)))))))) (define (maybe-ssavalue lhss x in-lhs?) @@ -2899,13 +2901,13 @@ (lambda (e) (let ((t (cadr e)) (e (cdr e))) - (expand-vcat e `((top typed_vcat) ,t) `((top typed_hvcat) ,t) `((top typed_hvcat_rows) ,t)))) + (expand-vcat e `((top typed_vcat) ,t) `((top typed_hvcat_static) ,t) `((top typed_hvcat_rows) ,t)))) 'typed_ncat (lambda (e) (let ((t (cadr e)) (e (cdr e))) - (expand-ncat e `((top typed_hvncat) ,t)))) + (expand-ncat e `((top typed_hvncat) ,t) `((top typed_hvncat_static) ,t)))) '|'| (lambda (e) (expand-forms `(call |'| ,(cadr e)))) diff --git a/test/abstractarray.jl b/test/abstractarray.jl index 01e13f17460b5..fdf4600494a77 100644 --- a/test/abstractarray.jl +++ b/test/abstractarray.jl @@ -1745,6 +1745,153 @@ using Base: typed_hvncat @test [["A";"B"];;"C";"D"] == ["A" "C"; "B" "D"] end +@testset "array construction using numbers" begin + # test array construction using hvcat [x x; y y] and hvncat [x x;;; y y] + + @testset "hvcat array construction" begin + function test_hvcat(x1, x2, x3, x4) + # small arrays are constructed differently (manually unrolled) + A = [x1 x2; x3 x4] + @test A[1,1] == x1 + @test A[1,2] == x2 + @test A[2,1] == x3 + @test A[2,2] == x4 + + AT = Float64[x1 x2; x3 x4] + @test AT == A + + # large arrays are constructed using a loop + A = [x1 x2 x3 x4; x2 x3 x4 x1; x3 x4 x1 x2; x4 x1 x2 x3; x1 x2 x3 x4] + @test A[1,1] == x1 + @test A[1,2] == x2 + @test A[1,3] == x3 + @test A[1,4] == x4 + @test A[2,1] == x2 + @test A[2,2] == x3 + @test A[2,3] == x4 + @test A[2,4] == x1 + @test A[3,1] == x3 + @test A[3,2] == x4 + @test A[3,3] == x1 + @test A[3,4] == x2 + @test A[4,1] == x4 + @test A[4,2] == x1 + @test A[4,3] == x2 + @test A[4,4] == x3 + @test A[5,1] == x1 + @test A[5,2] == x2 + @test A[5,3] == x3 + @test A[5,4] == x4 + AT = Float64[x1 x2 x3 x4; x2 x3 x4 x1; x3 x4 x1 x2; x4 x1 x2 x3; x1 x2 x3 x4] + @test AT == A + end + + test_hvcat(1, 2, 3, 4) + test_hvcat(1.0, 2.0, 3.0, 4.0) + test_hvcat(1.0, 2, 3.0, 4) + end + + @testset "hvncat array construction" begin + # Test hvncat with dims as Int and Tuple, row_first true/false + # Testing 3D and 4D outputs, same and mixed element types, and different sizes + + function test_hvncat(x1, x2, x3, x4) + # 3D arrays with dims as Int (row-first by default) + A = [x1 x2;;; x3 x4] # 1x2x2 Array (row-first) + @test size(A) == (1, 2, 2) + @test A[1,1,1] == x1 + @test A[1,1,2] == x3 + @test A[1,2,1] == x2 + @test A[1,2,2] == x4 + + AT = Float64[x1 x2;;; x3 x4] + @test AT == A + + A = [x1 x2; x3 x4;;; x2 x3; x4 x1;;; x3 x4; x1 x2;;; x4 x1; x2 x3;;; x1 x2; x3 x4] + @test size(A) == (2, 2, 5) + + @test A[:, :, 1] == [x1 x2; x3 x4] + @test A[:, :, 2] == [x2 x3; x4 x1] + @test A[:, :, 3] == [x3 x4; x1 x2] + @test A[:, :, 4] == [x4 x1; x2 x3] + @test A[:, :, 5] == [x1 x2; x3 x4] + + AT = Float64[x1 x2; x3 x4;;; x2 x3; x4 x1;;; x3 x4; x1 x2;;; x4 x1; x2 x3;;; x1 x2; x3 x4] + @test AT == A + end + + test_hvncat(1, 2, 3, 4) + test_hvncat(1.0, 2.0, 3.0, 4.0) + test_hvncat(1.0, 2, 3.0, 4) + end + + @testset "hvcat vs hvcat_static" begin + # number cases generate the same result + @test Base.hvcat_static(Val{(2,2)}(), 1, 2, 3, 4) == hvcat((2, 2), 1, 2, 3, 4) + @test Base.hvcat_static(Val{(2,2)}(), 1, 2, 3.0, 4.0) == hvcat((2, 2), 1, 2, 3.0, 4.0) + @test Base.typed_hvcat_static(Float64, Val{(2,2)}(), 1, 2, 3, 4) == Base.typed_hvcat(Float64, (2, 2), 1, 2, 3, 4) + @test Base.typed_hvcat_static(Float64, Val{(2,2)}(), 1, 2, 3.0, 4.0) == Base.typed_hvcat(Float64, (2, 2), 1, 2, 3.0, 4.0) + + # non-number cases will be fallbacks to hvcat + @test Base.hvcat_static(Val{(2,2)}(), "a", "b", "c", "d") == hvcat((2, 2), "a", "b", "c", "d") + @test Base.typed_hvcat_static(String, Val{(2,2)}(), "a", "b", "c", "d") == Base.typed_hvcat(String, (2, 2), "a", "b", "c", "d") + + # non-scalar cases will be fallbacks to hvcat + @test Base.hvcat_static(Val{(2,2)}(), [1 2], [2 2], [3 3], [4 4]) == hvcat((2, 2), [1 2], [2 2], [3 3], [4 4]) + @test Base.typed_hvcat_static(Float64, Val{(2,2)}(), [1 2], [2 2], [3 3], [4 4]) == Base.typed_hvcat(Float64, (2, 2), [1 2], [2 2], [3 3], [4 4]) + + @test_throws DimensionMismatch hvcat((2,4), 2, 3, 4, 5) + @test_throws DimensionMismatch Base.hvcat_static(Val{(2,4)}(), 2, 3, 4, 5) + end + + @testset "hvncat vs hvncat_static" begin + # basic test + x = rand(8) + + A = Base.hvncat_static(Val{(2, 2, 2)}(), Val{true}(), x...) + B = hvncat((2, 2, 2), true, x...) + @test A == B + + A = Base.hvncat_static(Val{(2, 2, 2)}(), Val{false}(), x...) + B = hvncat((2, 2, 2), false, x...) + @test A == B + + # test different eltypes + x, y = rand(4), rand(1:10, 4) + A = Base.hvncat_static(Val{(2, 2, 2)}(), Val{true}(), x..., y...) + B = hvncat((2, 2, 2), true, x..., y...) + @test A == B + + A = Base.hvncat_static(Val{(2, 2, 2)}(), Val{false}(), x..., y...) + B = hvncat((2, 2, 2), false, x..., y...) + @test A == B + + # test large array + x = rand(24) + A = Base.hvncat_static(Val{(2, 3, 4)}(), Val{true}(), x...) + B = hvncat((2, 3, 4), true, x...) + @test A == B + + A = Base.hvncat_static(Val{(2, 3, 4)}(), Val{false}(), x...) + B = hvncat((2, 3, 4), false, x...) + @test A == B + end + + @testset "Static method doesn't apply to non-number types" begin + # StaticArrays "abused" the syntax SA[1 2; 3 4] to create a static array + # but unfortunately SA isn't a number-like type. + # We need to ensure that our typed_hvcat_static specialization doesn't affect its usage + # and returns the same result as the non-static version. + + struct FOO end + Base.typed_hvcat(::Type{FOO}, dims::Dims, xs::Number...) = "Foo typed_hvcat" + Base.typed_hvncat(::Type{FOO}, dims::Dims, row_first::Bool, xs::Number...) = "Foo typed_hvncat" + + @test FOO[1 2; 3 4] == "Foo typed_hvcat" + @test FOO[1 2;;; 3 4] == "Foo typed_hvncat" + end +end + @testset "stack" begin # Basics for args in ([[1, 2]], [1:2, 3:4], [[1 2; 3 4], [5 6; 7 8]],