From 5f0404dc7f9f513a39d40f861e6dd0cf230eb5f0 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Sun, 16 Mar 2025 18:39:19 -0300 Subject: [PATCH 01/15] NDArray construction bug fix --- lib/mps/ndarray.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/mps/ndarray.jl b/lib/mps/ndarray.jl index 34d1a69f9..6fe2418a0 100644 --- a/lib/mps/ndarray.jl +++ b/lib/mps/ndarray.jl @@ -114,7 +114,7 @@ function MPSNDArray(arr::MtlArray{T,N}) where {T,N} arrsize = size(arr) @assert arrsize[1] * sizeof(T) % 16 == 0 "First dimension of input MtlArray must have a byte size divisible by 16" desc = MPSNDArrayDescriptor(T, arrsize) - return MPSNDArray(arr.data[], UInt(arr.offset), desc) + return MPSNDArray(arr.data[], UInt(arr.offset) * sizeof(T), desc) end function Metal.MtlArray(ndarr::MPSNDArray; storage = Metal.DefaultStorageMode, async = false) From a685b449fde03cd6f03e04366736a550fa9022e6 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Sun, 16 Mar 2025 16:16:26 -0300 Subject: [PATCH 02/15] `size` for MPSMatrix --- lib/mps/matrix.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/lib/mps/matrix.jl b/lib/mps/matrix.jl index f1b156cd2..102db9f22 100644 --- a/lib/mps/matrix.jl +++ b/lib/mps/matrix.jl @@ -119,6 +119,13 @@ function MPSMatrix(arr::MtlArray{T,3}) where T return MPSMatrix(arr, desc, offset) end +function Base.size(mat::MPS.MPSMatrix) + if mat.matrices > 1 + return Int64.((mat.matrices, mat.rows, mat.columns)) + else + return Int64.((mat.rows, mat.columns)) + end +end ## matrix multiplication From 0c1bd0951476486504863f5ba23df50473d7c81c Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Thu, 20 Mar 2025 10:54:20 -0300 Subject: [PATCH 03/15] Add and fix MPSMatrix tests --- test/mps/matrix.jl | 148 +++++++++++++++++++++++---------------------- 1 file changed, 77 insertions(+), 71 deletions(-) diff --git a/test/mps/matrix.jl b/test/mps/matrix.jl index 75e4c2fc4..b0c232e16 100644 --- a/test/mps/matrix.jl +++ b/test/mps/matrix.jl @@ -45,77 +45,83 @@ using .MPS: MPSMatrix rowBytes = sizeof(T) * cols mats = 4 - desc = MPSMatrixDescriptor(rows, cols, rowBytes, T) - devmat = MPSMatrix(dev, desc) - @test devmat isa MPSMatrix - @test devmat.device == dev - @test devmat.rows == rows - @test devmat.columns == cols - @test devmat.rowBytes == rowBytes - @test devmat.matrices == 1 - @test devmat.dataType == DT - @test devmat.matrixBytes == rowBytes * rows - @test devmat.offset == 0 - - mat = MtlMatrix{T}(undef, rows, cols) - acols, arows = size(mat) - arowBytes = sizeof(T) * acols - abufmat = MPSMatrix(mat) - @test abufmat isa MPSMatrix - @test abufmat.device == dev - @test abufmat.rows == arows - @test abufmat.columns == acols - @test abufmat.rowBytes == arowBytes - @test abufmat.matrices == 1 - @test abufmat.dataType == DT - @test abufmat.matrixBytes == arowBytes * arows - @test abufmat.offset == 0 - @test abufmat.data == mat.data[] - - vmat = @view mat[:, 2:3] - vcols, vrows = size(vmat) - vrowBytes = sizeof(T) * vcols - vbufmat = MPSMatrix(vmat) - @test vbufmat isa MPSMatrix - @test vbufmat.device == dev - @test vbufmat.rows == vrows - @test vbufmat.columns == vcols - @test vbufmat.rowBytes == vrowBytes - @test vbufmat.matrices == 1 - @test vbufmat.dataType == DT - @test vbufmat.matrixBytes == vrowBytes * vrows - @test vbufmat.offset == vmat.offset * sizeof(T) - @test vbufmat.data == vmat.data[] - - arr = MtlArray{T,3}(undef, rows, cols, mats) - mcols, mrows, mmats = size(arr) - mrowBytes = sizeof(T) * mcols - mpsmat = MPSMatrix(mat) - @test mpsmat isa MPSMatrix - @test mpsmat.device == dev - @test mpsmat.rows == mrows - @test mpsmat.columns == mcols - @test mpsmat.rowBytes == mrowBytes - @test mpsmat.matrices == 1 - @test mpsmat.dataType == DT - @test mpsmat.matrixBytes == mrowBytes * mrows - @test mpsmat.offset == 0 - @test mpsmat.data == mat.data[] - - vec = MtlVector{T}(undef, rows) - veccols, vecrows = length(vec), 1 - vecrowBytes = sizeof(T)*veccols - vmpsmat = MPSMatrix(vec) - @test vmpsmat isa MPSMatrix - @test vmpsmat.device == dev - @test vmpsmat.rows == vecrows - @test vmpsmat.columns == veccols - @test vmpsmat.rowBytes == vecrowBytes - @test vmpsmat.matrices == 1 - @test vmpsmat.dataType == DT - @test vmpsmat.matrixBytes == vecrowBytes*vecrows - @test vmpsmat.offset == 0 - @test vmpsmat.data == vec.data[] + let desc = MPSMatrixDescriptor(rows, cols, rowBytes, T) + devmat = MPSMatrix(dev, desc) + @test devmat isa MPSMatrix + @test devmat.device == dev + @test devmat.rows == rows + @test devmat.columns == cols + @test devmat.rowBytes == rowBytes + @test devmat.matrices == 1 + @test devmat.dataType == DT + @test devmat.matrixBytes == rowBytes * rows + @test devmat.offset == 0 + @test size(devmat) == (rows, cols) + end + + let mat = MtlMatrix{T}(undef, rows, cols) + acols, arows = size(mat) + arowBytes = sizeof(T) * acols + abufmat = MPSMatrix(mat) + @test abufmat isa MPSMatrix + @test abufmat.device == dev + @test abufmat.rows == arows + @test abufmat.columns == acols + @test abufmat.rowBytes == arowBytes + @test abufmat.matrices == 1 + @test abufmat.dataType == DT + @test abufmat.matrixBytes == arowBytes * arows + @test abufmat.offset == 0 + @test abufmat.data == mat.data[] + + vmat = @view mat[:, 2:3] + vcols, vrows = size(vmat) + vrowBytes = sizeof(T) * vcols + vbufmat = MPSMatrix(vmat) + @test vbufmat isa MPSMatrix + @test vbufmat.device == dev + @test vbufmat.rows == vrows + @test vbufmat.columns == vcols + @test vbufmat.rowBytes == vrowBytes + @test vbufmat.matrices == 1 + @test vbufmat.dataType == DT + @test vbufmat.matrixBytes == vrowBytes * vrows + @test vbufmat.offset == vmat.offset * sizeof(T) + @test vbufmat.data == vmat.data[] + end + + let arr = MtlArray{T,3}(undef, rows, cols, mats) + mcols, mrows, mmats = size(arr) + mrowBytes = sizeof(T) * mcols + mpsmat = MPSMatrix(arr) + @test mpsmat isa MPSMatrix + @test mpsmat.device == dev + @test mpsmat.rows == mrows + @test mpsmat.columns == mcols + @test mpsmat.rowBytes == mrowBytes + @test mpsmat.matrices == mmats + @test mpsmat.dataType == DT + @test mpsmat.matrixBytes == mrowBytes * mrows + @test mpsmat.offset == 0 + @test mpsmat.data == arr.data[] + @test size(mpsmat) == (mmats, mrows, mcols) + end + + let vec = MtlVector{T}(undef, rows) + veccols, vecrows = length(vec), 1 + vecrowBytes = sizeof(T)*veccols + vmpsmat = MPSMatrix(vec) + @test vmpsmat isa MPSMatrix + @test vmpsmat.device == dev + @test vmpsmat.rows == vecrows + @test vmpsmat.columns == veccols + @test vmpsmat.rowBytes == vecrowBytes + @test vmpsmat.matrices == 1 + @test vmpsmat.dataType == DT + @test vmpsmat.matrixBytes == vecrowBytes*vecrows + @test vmpsmat.offset == 0 + @test vmpsmat.data == vec.data[] + end end From db39faebd109a11c4e0d9cc93cdee8ce8e3b11f3 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Thu, 20 Mar 2025 12:40:04 -0300 Subject: [PATCH 04/15] Don't track coverage for examples --- .buildkite/pipeline.yml | 3 --- 1 file changed, 3 deletions(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 1a4e416a8..2d69ec0e5 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -14,7 +14,6 @@ steps: dirs: - src - lib - - examples agents: queue: "juliaecosystem" os: "macos" @@ -84,7 +83,6 @@ steps: dirs: - src - lib - - examples env: MTL_DEBUG_LAYER: '1' MTL_SHADER_VALIDATION: '1' @@ -113,7 +111,6 @@ steps: dirs: - src - lib - - examples env: JULIA_LLVM_ARGS: '--opaque-pointers' agents: From ad60e0752d8aad8075844780ffb13a6dab41a43b Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Thu, 20 Mar 2025 13:02:34 -0300 Subject: [PATCH 05/15] Exclude device-only code from coverage --- src/MetalKernels.jl | 2 ++ src/accumulate.jl | 2 ++ src/broadcast.jl | 10 ++++++++++ 3 files changed, 14 insertions(+) diff --git a/src/MetalKernels.jl b/src/MetalKernels.jl index cdde3d0f2..7fa17e77f 100644 --- a/src/MetalKernels.jl +++ b/src/MetalKernels.jl @@ -131,6 +131,7 @@ end ## indexing +## COV_EXCL_START @device_override @inline function KA.__index_Local_Linear(ctx) return thread_position_in_threadgroup_1d() end @@ -191,5 +192,6 @@ end @device_override @inline function KA.__print(args...) # TODO end +## COV_EXCL_STOP end diff --git a/src/accumulate.jl b/src/accumulate.jl index 1cf4c1b51..31e2dc4fe 100644 --- a/src/accumulate.jl +++ b/src/accumulate.jl @@ -1,3 +1,4 @@ +## COV_EXCL_START function partial_scan(op::Function, output::AbstractArray{T}, input::AbstractArray, Rdim, Rpre, Rpost, Rother, neutral, init, ::Val{maxthreads}, ::Val{inclusive}=Val(true)) where {T, maxthreads, inclusive} @@ -100,6 +101,7 @@ function aggregate_partial_scan(op::Function, output::AbstractArray, aggregates: return end +## COV_EXCL_STOP function scan!(f::Function, output::WrappedMtlArray{T}, input::WrappedMtlArray; dims::Integer, init=nothing, neutral=GPUArrays.neutral_element(f, T)) where {T} diff --git a/src/broadcast.jl b/src/broadcast.jl index 9ece5fc1b..05ea7c196 100644 --- a/src/broadcast.jl +++ b/src/broadcast.jl @@ -59,6 +59,7 @@ end _broadcast_shapes[Is] += 1 end if _broadcast_shapes[Is] > BROADCAST_SPECIALIZATION_THRESHOLD + ## COV_EXCL_START function broadcast_cartesian_static(dest, bc, Is) i = thread_position_in_grid_1d() stride = threads_per_grid_1d() @@ -69,6 +70,7 @@ end end return end + ## COV_EXCL_STOP Is = StaticCartesianIndices(Is) kernel = @metal launch=false broadcast_cartesian_static(dest, bc, Is) @@ -82,6 +84,7 @@ end # try to use the most appropriate hardware index to avoid integer division if ndims(dest) == 1 || (isa(IndexStyle(dest), IndexLinear) && isa(IndexStyle(bc), IndexLinear)) + ## COV_EXCL_START function broadcast_linear(dest, bc) i = thread_position_in_grid_1d() stride = threads_per_grid_1d() @@ -91,12 +94,14 @@ end end return end + ## COV_EXCL_STOP kernel = @metal launch=false broadcast_linear(dest, bc) elements = cld(length(dest), 4) threads = min(elements, kernel.pipeline.maxTotalThreadsPerThreadgroup) groups = cld(elements, threads) elseif ndims(dest) == 2 + ## COV_EXCL_START function broadcast_2d(dest, bc) is = Tuple(thread_position_in_grid_2d()) stride = threads_per_grid_2d() @@ -107,6 +112,7 @@ end end return end + ## COV_EXCL_STOP kernel = @metal launch=false broadcast_2d(dest, bc) w = min(size(dest, 1), kernel.pipeline.threadExecutionWidth) @@ -114,6 +120,7 @@ end threads = (w, h) groups = cld.(size(dest), threads) elseif ndims(dest) == 3 + ## COV_EXCL_START function broadcast_3d(dest, bc) is = Tuple(thread_position_in_grid_3d()) stride = threads_per_grid_3d() @@ -126,6 +133,7 @@ end end return end + ## COV_EXCL_STOP kernel = @metal launch=false broadcast_3d(dest, bc) w = min(size(dest, 1), kernel.pipeline.threadExecutionWidth) @@ -135,6 +143,7 @@ end threads = (w, h, d) groups = cld.(size(dest), threads) else + ## COV_EXCL_START function broadcast_cartesian(dest, bc) i = thread_position_in_grid_1d() stride = threads_per_grid_1d() @@ -145,6 +154,7 @@ end end return end + ## COV_EXCL_STOP kernel = @metal launch=false broadcast_cartesian(dest, bc) elements = cld(length(dest), 4) From ed75455b266810331895a381d88a1a697c0873d6 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Thu, 20 Mar 2025 13:02:57 -0300 Subject: [PATCH 06/15] Exclude profiler tests from coverage --- src/utilities.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/utilities.jl b/src/utilities.jl index a5285a337..a6ee9c042 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -146,7 +146,7 @@ end ## profile macro - +## COV_EXCL_START function profile_dir() root = pwd() i = 1 @@ -239,3 +239,5 @@ macro profile(ex...) end end end +## COV_EXCL_START + From 2b4a002935a19082c8ba7a0e97610f51cb43398d Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Thu, 20 Mar 2025 13:13:00 -0300 Subject: [PATCH 07/15] format --- test/mps/matrix.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/mps/matrix.jl b/test/mps/matrix.jl index b0c232e16..da61417c6 100644 --- a/test/mps/matrix.jl +++ b/test/mps/matrix.jl @@ -90,7 +90,7 @@ using .MPS: MPSMatrix @test vbufmat.data == vmat.data[] end - let arr = MtlArray{T,3}(undef, rows, cols, mats) + let arr = MtlArray{T, 3}(undef, rows, cols, mats) mcols, mrows, mmats = size(arr) mrowBytes = sizeof(T) * mcols mpsmat = MPSMatrix(arr) @@ -109,7 +109,7 @@ using .MPS: MPSMatrix let vec = MtlVector{T}(undef, rows) veccols, vecrows = length(vec), 1 - vecrowBytes = sizeof(T)*veccols + vecrowBytes = sizeof(T) * veccols vmpsmat = MPSMatrix(vec) @test vmpsmat isa MPSMatrix @test vmpsmat.device == dev @@ -118,7 +118,7 @@ using .MPS: MPSMatrix @test vmpsmat.rowBytes == vecrowBytes @test vmpsmat.matrices == 1 @test vmpsmat.dataType == DT - @test vmpsmat.matrixBytes == vecrowBytes*vecrows + @test vmpsmat.matrixBytes == vecrowBytes * vecrows @test vmpsmat.offset == 0 @test vmpsmat.data == vec.data[] end From 5a5bc19226775450da8ab413b41a6c78db09c61c Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Thu, 20 Mar 2025 13:43:22 -0300 Subject: [PATCH 08/15] Struct coverage --- test/mps/size.jl | 65 ++++++++++++++++++++++++++++++++++++++++-------- test/mtl/size.jl | 35 ++++++++++++++++++-------- 2 files changed, 79 insertions(+), 21 deletions(-) diff --git a/test/mps/size.jl b/test/mps/size.jl index 12687dce9..230b27266 100644 --- a/test/mps/size.jl +++ b/test/mps/size.jl @@ -1,13 +1,18 @@ # ## size @testset "size" begin + siz1 = MPS.MPSSize() + @test siz1.width == 0 + @test siz1.height == 0 + @test siz1.depth == 0 + dim1 = rand() dim2 = rand() dim3 = rand() @test MPS.MPSSize(dim1) == MPS.MPSSize((dim1,)) - @test MPS.MPSSize(dim1,dim2) == MPS.MPSSize((dim1,dim2)) - @test MPS.MPSSize(dim1,dim2,dim3) == MPS.MPSSize((dim1,dim2,dim3)) + @test MPS.MPSSize(dim1, dim2) == MPS.MPSSize((dim1, dim2)) + @test MPS.MPSSize(dim1, dim2, dim3) == MPS.MPSSize((dim1, dim2, dim3)) end @testset "origin" begin @@ -15,10 +20,25 @@ end dim2 = rand() dim3 = rand() - orig = MPS.MPSOrigin(dim1,dim2,dim3) - @test orig.x == dim1 - @test orig.y == dim2 - @test orig.z == dim3 + orig1 = MPS.MPSOrigin(dim1, dim2, dim3) + @test orig1.x == dim1 + @test orig1.y == dim2 + @test orig1.z == dim3 + + orig2 = MPS.MPSOrigin(dim1, dim2) + @test orig2.x == dim1 + @test orig2.y == dim2 + @test orig2.z == 0.0 + + orig3 = MPS.MPSOrigin(dim1) + @test orig3.x == dim1 + @test orig3.y == 0.0 + @test orig3.z == 0.0 + + orig4 = MPS.MPSOrigin() + @test orig4.x == 0.0 + @test orig4.y == 0.0 + @test orig4.z == 0.0 end @testset "offset" begin @@ -26,8 +46,33 @@ end dim2 = rand(Int) dim3 = rand(Int) - off = MPS.MPSOffset(dim1,dim2,dim3) - @test off.x == dim1 - @test off.y == dim2 - @test off.z == dim3 + off1 = MPS.MPSOffset(dim1, dim2, dim3) + @test off1.x == dim1 + @test off1.y == dim2 + @test off1.z == dim3 + + off2 = MPS.MPSOffset(dim1, dim2) + @test off2.x == dim1 + @test off2.y == dim2 + @test off2.z == 0 + + off3 = MPS.MPSOffset(dim1) + @test off3.x == dim1 + @test off3.y == 0 + @test off3.z == 0 + + off4 = MPS.MPSOffset() + @test off4.x == 0 + @test off4.y == 0 + @test off4.z == 0 +end + +@testset "region" begin + reg1 = MPS.MPSRegion() + @test reg1.origin isa MPS.MPSOrigin + @test reg1.size isa MPS.MPSSize + + reg2 = MPS.MPSRegion(MPS.MPSOrigin()) + @test reg1.origin isa MPS.MPSOrigin + @test reg1.size isa MPS.MPSSize end diff --git a/test/mtl/size.jl b/test/mtl/size.jl index f75e8cce6..b43fc35c2 100644 --- a/test/mtl/size.jl +++ b/test/mtl/size.jl @@ -1,12 +1,12 @@ -@testset "size.jl" begin + @testset "size" begin dim1 = rand(UInt64) dim2 = rand(UInt64) dim3 = rand(UInt64) @test MTL.MTLSize(dim1) == MTL.MTLSize((dim1,)) - @test MTL.MTLSize(dim1,dim2) == MTL.MTLSize((dim1,dim2)) - @test MTL.MTLSize(dim1,dim2,dim3) == MTL.MTLSize((dim1,dim2,dim3)) + @test MTL.MTLSize(dim1, dim2) == MTL.MTLSize((dim1, dim2)) + @test MTL.MTLSize(dim1, dim2, dim3) == MTL.MTLSize((dim1, dim2, dim3)) end @testset "origin" begin @@ -14,15 +14,28 @@ end dim2 = rand(UInt64) dim3 = rand(UInt64) - orig = MTL.MTLOrigin(dim1,dim2,dim3) - @test orig.x == dim1 - @test orig.y == dim2 - @test orig.z == dim3 + orig1 = MTL.MTLOrigin(dim1, dim2, dim3) + @test orig1.x == dim1 + @test orig1.y == dim2 + @test orig1.z == dim3 + + orig2 = MTL.MTLOrigin(dim1, dim2) + @test orig2.x == dim1 + @test orig2.y == dim2 + @test orig2.z == 0 + + orig3 = MTL.MTLOrigin(dim1) + @test orig3.x == dim1 + @test orig3.y == 0 + @test orig3.z == 0 end @testset "region" begin - reg = MTL.MTLRegion() - @test reg.origin isa MTL.MTLOrigin - @test reg.size isa MTL.MTLSize -end + reg1 = MTL.MTLRegion() + @test reg1.origin isa MTL.MTLOrigin + @test reg1.size isa MTL.MTLSize + + reg2 = MTL.MTLRegion(MTL.MTLOrigin()) + @test reg1.origin isa MTL.MTLOrigin + @test reg1.size isa MTL.MTLSize end From 2a0cb805d8274216cf2abc670dc930ca6a21d487 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Thu, 20 Mar 2025 14:01:19 -0300 Subject: [PATCH 09/15] MTLDevice tests --- test/mtl/metal.jl | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/test/mtl/metal.jl b/test/mtl/metal.jl index cfb93a45d..722735f50 100644 --- a/test/mtl/metal.jl +++ b/test/mtl/metal.jl @@ -9,7 +9,7 @@ using .MTL devs = devices() @test length(devs) > 0 -dev = first(devs) +dev = MTLDevice(1) @test dev == devs[1] if length(devs) > 1 @@ -34,6 +34,13 @@ full_str = sprint(io->show(io, MIME"text/plain"(), dev)) @test dev.currentAllocatedSize isa Integer +@test is_m1(dev) isa Bool +@test is_m2(dev) isa Bool +@test is_m3(dev) isa Bool +@test is_m4(dev) isa Bool + +@test MTL.MTLCreateSystemDefaultDevice() isa MTLDevice + end @testset "compile options" begin From 6c7d98ef425603720cae3703451b9322ecdec1b3 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Thu, 20 Mar 2025 14:26:25 -0300 Subject: [PATCH 10/15] Move MPSDataType out of matrix file and test --- lib/mps/MPS.jl | 1 + lib/mps/datatype.jl | 23 +++++++++++++++++++++++ lib/mps/matrix.jl | 25 ------------------------- test/mps/datatype.jl | 8 ++++++++ 4 files changed, 32 insertions(+), 25 deletions(-) create mode 100644 lib/mps/datatype.jl create mode 100644 test/mps/datatype.jl diff --git a/lib/mps/MPS.jl b/lib/mps/MPS.jl index 9097a3f77..dacc8817e 100644 --- a/lib/mps/MPS.jl +++ b/lib/mps/MPS.jl @@ -43,6 +43,7 @@ is_supported(dev::MTLDevice) = ccall(:MPSSupportsMTLDevice, Bool, (id{MTLDevice} include("libmps.jl") include("size.jl") +include("datatype.jl") # high-level wrappers include("command_buf.jl") diff --git a/lib/mps/datatype.jl b/lib/mps/datatype.jl new file mode 100644 index 000000000..560501480 --- /dev/null +++ b/lib/mps/datatype.jl @@ -0,0 +1,23 @@ +## Some extra definitions for MPSDataType defined in libmps.jl + +## bitwise operations lose type information, so allow conversions +Base.convert(::Type{MPSDataType}, x::Integer) = MPSDataType(x) + +# Conversions for MPSDataTypes with Julia equivalents +const jl_mps_to_typ = Dict{MPSDataType, DataType}() +for type in [ + :Bool, :UInt8, :UInt16, :UInt32, :UInt64, :Int8, :Int16, :Int32, :Int64, + :Float16, :BFloat16, :Float32, (:ComplexF16, :MPSDataTypeComplexFloat16), + (:ComplexF32, :MPSDataTypeComplexFloat32), + ] + jltype, mpstype = if type isa Symbol + type, Symbol(:MPSDataType, type) + else + type + end + @eval Base.convert(::Type{MPSDataType}, ::Type{$jltype}) = $(mpstype) + @eval jl_mps_to_typ[$(mpstype)] = $jltype +end +Base.sizeof(t::MPSDataType) = sizeof(jl_mps_to_typ[t]) + +Base.convert(::Type{DataType}, mpstyp::MPSDataType) = jl_mps_to_typ[mpstyp] diff --git a/lib/mps/matrix.jl b/lib/mps/matrix.jl index 102db9f22..f28927d2f 100644 --- a/lib/mps/matrix.jl +++ b/lib/mps/matrix.jl @@ -1,28 +1,3 @@ -## Some extra definitions for MPSDataType defined in libmps.jl - -## bitwise operations lose type information, so allow conversions -Base.convert(::Type{MPSDataType}, x::Integer) = MPSDataType(x) - -# Conversions for MPSDataTypes with Julia equivalents -const jl_mps_to_typ = Dict{MPSDataType, DataType}() -for type in [ - :Bool, :UInt8, :UInt16, :UInt32, :UInt64, :Int8, :Int16, :Int32, :Int64, - :Float16, :BFloat16, :Float32, (:ComplexF16, :MPSDataTypeComplexFloat16), - (:ComplexF32, :MPSDataTypeComplexFloat32), - ] - jltype, mpstype = if type isa Symbol - type, Symbol(:MPSDataType, type) - else - type - end - @eval Base.convert(::Type{MPSDataType}, ::Type{$jltype}) = $(mpstype) - @eval jl_mps_to_typ[$(mpstype)] = $jltype -end -Base.sizeof(t::MPSDataType) = sizeof(jl_mps_to_typ[t]) - -Base.convert(::Type{DataType}, mpstyp::MPSDataType) = jl_mps_to_typ[mpstyp] - - ## descriptor export MPSMatrixDescriptor diff --git a/test/mps/datatype.jl b/test/mps/datatype.jl new file mode 100644 index 000000000..81cf8f606 --- /dev/null +++ b/test/mps/datatype.jl @@ -0,0 +1,8 @@ + +@testset "MPSDataType" begin + +@test convert(MPS.MPSDataType, 0x0000000010000020) == MPS.MPSDataTypeFloat32 +@test sizeof(MPS.MPSDataTypeFloat16) == 2 +@test convert(DataType, MPS.MPSDataTypeInt64) == Int64 + +end From 67a19fc124e34aee59268e2ae09f31b349c4a498 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Thu, 20 Mar 2025 14:40:58 -0300 Subject: [PATCH 11/15] Storage type tests --- test/mtl/metal.jl | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/test/mtl/metal.jl b/test/mtl/metal.jl index 722735f50..278ce34e7 100644 --- a/test/mtl/metal.jl +++ b/test/mtl/metal.jl @@ -43,6 +43,28 @@ full_str = sprint(io->show(io, MIME"text/plain"(), dev)) end +@testset "storage_type" begin + @test convert(MTL.MTLStorageMode, MTL.SharedStorage) == MTL.MTLStorageModeShared + @test convert(MTL.MTLStorageMode, MTL.ManagedStorage) == MTL.MTLStorageModeManaged + @test convert(MTL.MTLStorageMode, MTL.PrivateStorage) == MTL.MTLStorageModePrivate + @test convert(MTL.MTLStorageMode, MTL.Memoryless) == MTL.MTLStorageModeMemoryless + + @test convert(MTL.MTLResourceOptions, MTL.SharedStorage) == MTL.MTLResourceStorageModeShared + @test convert(MTL.MTLResourceOptions, MTL.ManagedStorage) == MTL.MTLResourceStorageModeManaged + @test convert(MTL.MTLResourceOptions, MTL.PrivateStorage) == MTL.MTLResourceStorageModePrivate + @test convert(MTL.MTLResourceOptions, MTL.Memoryless) == MTL.MTLResourceStorageModeMemoryless + + @test convert(MTL.MTLResourceOptions, MTL.MTLStorageModeShared) == MTL.MTLResourceStorageModeShared + @test convert(MTL.MTLResourceOptions, MTL.MTLStorageModeManaged) == MTL.MTLResourceStorageModeManaged + @test convert(MTL.MTLResourceOptions, MTL.MTLStorageModePrivate) == MTL.MTLResourceStorageModePrivate + @test convert(MTL.MTLResourceOptions, MTL.MTLStorageModeMemoryless) == MTL.MTLResourceStorageModeMemoryless + + @test MTL.MTLResourceStorageModeShared == MTL.MTLStorageModeShared + @test MTL.MTLStorageModeManaged == MTL.MTLResourceStorageModeManaged + @test MTL.MTLResourceStorageModePrivate == MTL.MTLStorageModePrivate + @test MTL.MTLStorageModeMemoryless == MTL.MTLResourceStorageModeMemoryless +end + @testset "compile options" begin opts = MTLCompileOptions() From 52b200d7e190efea9bf5706e567e5c9eba1db8e1 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Thu, 20 Mar 2025 14:48:26 -0300 Subject: [PATCH 12/15] Format --- lib/mps/matrix.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/mps/matrix.jl b/lib/mps/matrix.jl index f28927d2f..99508a48c 100644 --- a/lib/mps/matrix.jl +++ b/lib/mps/matrix.jl @@ -142,7 +142,7 @@ with any `MtlArray` and it should be accelerated using Metal Performance Shaders """ function matmul!(c::MtlArray{T1,N}, a::MtlArray{T2,N}, b::MtlArray{T3,N}, alpha::Number=true, beta::Number=true, - transpose_a=false, transpose_b=false) where {T1, T2, T3, N} + transpose_a=false, transpose_b=false) where {T1, T2, T3, N} # NOTE: MPS uses row major, while Julia is col-major. Instead of transposing # the inputs (by passing !transpose_[ab]) and afterwards transposing # the output, we use the property that (AB)ᵀ = BᵀAᵀ From 98b406685b94137a2b719b5a5f74c86b01582b54 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Thu, 20 Mar 2025 14:59:15 -0300 Subject: [PATCH 13/15] Fix --- test/mps/size.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/mps/size.jl b/test/mps/size.jl index 230b27266..0bd48c7db 100644 --- a/test/mps/size.jl +++ b/test/mps/size.jl @@ -2,9 +2,9 @@ @testset "size" begin siz1 = MPS.MPSSize() - @test siz1.width == 0 - @test siz1.height == 0 - @test siz1.depth == 0 + @test siz1.width == 1.0 + @test siz1.height == 1.0 + @test siz1.depth == 1.0 dim1 = rand() dim2 = rand() From 74bdbf1bda840c844d0637facb76d1001a314528 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Fri, 21 Mar 2025 08:42:38 -0300 Subject: [PATCH 14/15] Update lib/mps/matrix.jl Co-authored-by: Tim Besard --- lib/mps/matrix.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/mps/matrix.jl b/lib/mps/matrix.jl index 99508a48c..5953377e5 100644 --- a/lib/mps/matrix.jl +++ b/lib/mps/matrix.jl @@ -96,9 +96,9 @@ end function Base.size(mat::MPS.MPSMatrix) if mat.matrices > 1 - return Int64.((mat.matrices, mat.rows, mat.columns)) + return Int.((mat.matrices, mat.rows, mat.columns)) else - return Int64.((mat.rows, mat.columns)) + return Int.((mat.rows, mat.columns)) end end From 129392c779b9e044ae5a92e845275d24da646e72 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Fri, 21 Mar 2025 10:39:52 -0300 Subject: [PATCH 15/15] Remove unused code and tests --- lib/mps/datatype.jl | 6 ------ test/mps/datatype.jl | 8 -------- 2 files changed, 14 deletions(-) delete mode 100644 test/mps/datatype.jl diff --git a/lib/mps/datatype.jl b/lib/mps/datatype.jl index 560501480..df37928d9 100644 --- a/lib/mps/datatype.jl +++ b/lib/mps/datatype.jl @@ -1,8 +1,5 @@ ## Some extra definitions for MPSDataType defined in libmps.jl -## bitwise operations lose type information, so allow conversions -Base.convert(::Type{MPSDataType}, x::Integer) = MPSDataType(x) - # Conversions for MPSDataTypes with Julia equivalents const jl_mps_to_typ = Dict{MPSDataType, DataType}() for type in [ @@ -18,6 +15,3 @@ for type in [ @eval Base.convert(::Type{MPSDataType}, ::Type{$jltype}) = $(mpstype) @eval jl_mps_to_typ[$(mpstype)] = $jltype end -Base.sizeof(t::MPSDataType) = sizeof(jl_mps_to_typ[t]) - -Base.convert(::Type{DataType}, mpstyp::MPSDataType) = jl_mps_to_typ[mpstyp] diff --git a/test/mps/datatype.jl b/test/mps/datatype.jl deleted file mode 100644 index 81cf8f606..000000000 --- a/test/mps/datatype.jl +++ /dev/null @@ -1,8 +0,0 @@ - -@testset "MPSDataType" begin - -@test convert(MPS.MPSDataType, 0x0000000010000020) == MPS.MPSDataTypeFloat32 -@test sizeof(MPS.MPSDataTypeFloat16) == 2 -@test convert(DataType, MPS.MPSDataTypeInt64) == Int64 - -end