Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ steps:
dirs:
- src
- lib
- examples
agents:
queue: "juliaecosystem"
os: "macos"
Expand Down Expand Up @@ -84,7 +83,6 @@ steps:
dirs:
- src
- lib
- examples
env:
MTL_DEBUG_LAYER: '1'
MTL_SHADER_VALIDATION: '1'
Expand Down Expand Up @@ -113,7 +111,6 @@ steps:
dirs:
- src
- lib
- examples
env:
JULIA_LLVM_ARGS: '--opaque-pointers'
agents:
Expand Down
7 changes: 7 additions & 0 deletions lib/mps/matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,13 @@ function MPSMatrix(arr::MtlArray{T,3}) where T
return MPSMatrix(arr, desc, offset)
end

function Base.size(mat::MPS.MPSMatrix)
if mat.matrices > 1
return Int64.((mat.matrices, mat.rows, mat.columns))
else
return Int64.((mat.rows, mat.columns))
end
end

## matrix multiplication

Expand Down
2 changes: 1 addition & 1 deletion lib/mps/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ function MPSNDArray(arr::MtlArray{T,N}) where {T,N}
arrsize = size(arr)
@assert arrsize[1] * sizeof(T) % 16 == 0 "First dimension of input MtlArray must have a byte size divisible by 16"
desc = MPSNDArrayDescriptor(T, arrsize)
return MPSNDArray(arr.data[], UInt(arr.offset), desc)
return MPSNDArray(arr.data[], UInt(arr.offset) * sizeof(T), desc)
end

function Metal.MtlArray(ndarr::MPSNDArray; storage = Metal.DefaultStorageMode, async = false)
Expand Down
2 changes: 2 additions & 0 deletions src/MetalKernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ end

## indexing

## COV_EXCL_START
@device_override @inline function KA.__index_Local_Linear(ctx)
return thread_position_in_threadgroup_1d()
end
Expand Down Expand Up @@ -191,5 +192,6 @@ end
@device_override @inline function KA.__print(args...)
# TODO
end
## COV_EXCL_STOP

end
2 changes: 2 additions & 0 deletions src/accumulate.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
## COV_EXCL_START
function partial_scan(op::Function, output::AbstractArray{T}, input::AbstractArray,
Rdim, Rpre, Rpost, Rother, neutral, init,
::Val{maxthreads}, ::Val{inclusive}=Val(true)) where {T, maxthreads, inclusive}
Expand Down Expand Up @@ -100,6 +101,7 @@ function aggregate_partial_scan(op::Function, output::AbstractArray, aggregates:

return
end
## COV_EXCL_STOP

function scan!(f::Function, output::WrappedMtlArray{T}, input::WrappedMtlArray;
dims::Integer, init=nothing, neutral=GPUArrays.neutral_element(f, T)) where {T}
Expand Down
10 changes: 10 additions & 0 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ end
_broadcast_shapes[Is] += 1
end
if _broadcast_shapes[Is] > BROADCAST_SPECIALIZATION_THRESHOLD
## COV_EXCL_START
function broadcast_cartesian_static(dest, bc, Is)
i = thread_position_in_grid_1d()
stride = threads_per_grid_1d()
Expand All @@ -69,6 +70,7 @@ end
end
return
end
## COV_EXCL_STOP

Is = StaticCartesianIndices(Is)
kernel = @metal launch=false broadcast_cartesian_static(dest, bc, Is)
Expand All @@ -82,6 +84,7 @@ end
# try to use the most appropriate hardware index to avoid integer division
if ndims(dest) == 1 ||
(isa(IndexStyle(dest), IndexLinear) && isa(IndexStyle(bc), IndexLinear))
## COV_EXCL_START
function broadcast_linear(dest, bc)
i = thread_position_in_grid_1d()
stride = threads_per_grid_1d()
Expand All @@ -91,12 +94,14 @@ end
end
return
end
## COV_EXCL_STOP

kernel = @metal launch=false broadcast_linear(dest, bc)
elements = cld(length(dest), 4)
threads = min(elements, kernel.pipeline.maxTotalThreadsPerThreadgroup)
groups = cld(elements, threads)
elseif ndims(dest) == 2
## COV_EXCL_START
function broadcast_2d(dest, bc)
is = Tuple(thread_position_in_grid_2d())
stride = threads_per_grid_2d()
Expand All @@ -107,13 +112,15 @@ end
end
return
end
## COV_EXCL_STOP

kernel = @metal launch=false broadcast_2d(dest, bc)
w = min(size(dest, 1), kernel.pipeline.threadExecutionWidth)
h = min(size(dest, 2), kernel.pipeline.maxTotalThreadsPerThreadgroup ÷ w)
threads = (w, h)
groups = cld.(size(dest), threads)
elseif ndims(dest) == 3
## COV_EXCL_START
function broadcast_3d(dest, bc)
is = Tuple(thread_position_in_grid_3d())
stride = threads_per_grid_3d()
Expand All @@ -126,6 +133,7 @@ end
end
return
end
## COV_EXCL_STOP

kernel = @metal launch=false broadcast_3d(dest, bc)
w = min(size(dest, 1), kernel.pipeline.threadExecutionWidth)
Expand All @@ -135,6 +143,7 @@ end
threads = (w, h, d)
groups = cld.(size(dest), threads)
else
## COV_EXCL_START
function broadcast_cartesian(dest, bc)
i = thread_position_in_grid_1d()
stride = threads_per_grid_1d()
Expand All @@ -145,6 +154,7 @@ end
end
return
end
## COV_EXCL_STOP

kernel = @metal launch=false broadcast_cartesian(dest, bc)
elements = cld(length(dest), 4)
Expand Down
4 changes: 3 additions & 1 deletion src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ end


## profile macro

## COV_EXCL_START
function profile_dir()
root = pwd()
i = 1
Expand Down Expand Up @@ -239,3 +239,5 @@ macro profile(ex...)
end
end
end
## COV_EXCL_START

148 changes: 77 additions & 71 deletions test/mps/matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,77 +45,83 @@ using .MPS: MPSMatrix
rowBytes = sizeof(T) * cols
mats = 4

desc = MPSMatrixDescriptor(rows, cols, rowBytes, T)
devmat = MPSMatrix(dev, desc)
@test devmat isa MPSMatrix
@test devmat.device == dev
@test devmat.rows == rows
@test devmat.columns == cols
@test devmat.rowBytes == rowBytes
@test devmat.matrices == 1
@test devmat.dataType == DT
@test devmat.matrixBytes == rowBytes * rows
@test devmat.offset == 0

mat = MtlMatrix{T}(undef, rows, cols)
acols, arows = size(mat)
arowBytes = sizeof(T) * acols
abufmat = MPSMatrix(mat)
@test abufmat isa MPSMatrix
@test abufmat.device == dev
@test abufmat.rows == arows
@test abufmat.columns == acols
@test abufmat.rowBytes == arowBytes
@test abufmat.matrices == 1
@test abufmat.dataType == DT
@test abufmat.matrixBytes == arowBytes * arows
@test abufmat.offset == 0
@test abufmat.data == mat.data[]

vmat = @view mat[:, 2:3]
vcols, vrows = size(vmat)
vrowBytes = sizeof(T) * vcols
vbufmat = MPSMatrix(vmat)
@test vbufmat isa MPSMatrix
@test vbufmat.device == dev
@test vbufmat.rows == vrows
@test vbufmat.columns == vcols
@test vbufmat.rowBytes == vrowBytes
@test vbufmat.matrices == 1
@test vbufmat.dataType == DT
@test vbufmat.matrixBytes == vrowBytes * vrows
@test vbufmat.offset == vmat.offset * sizeof(T)
@test vbufmat.data == vmat.data[]

arr = MtlArray{T,3}(undef, rows, cols, mats)
mcols, mrows, mmats = size(arr)
mrowBytes = sizeof(T) * mcols
mpsmat = MPSMatrix(mat)
@test mpsmat isa MPSMatrix
@test mpsmat.device == dev
@test mpsmat.rows == mrows
@test mpsmat.columns == mcols
@test mpsmat.rowBytes == mrowBytes
@test mpsmat.matrices == 1
@test mpsmat.dataType == DT
@test mpsmat.matrixBytes == mrowBytes * mrows
@test mpsmat.offset == 0
@test mpsmat.data == mat.data[]

vec = MtlVector{T}(undef, rows)
veccols, vecrows = length(vec), 1
vecrowBytes = sizeof(T)*veccols
vmpsmat = MPSMatrix(vec)
@test vmpsmat isa MPSMatrix
@test vmpsmat.device == dev
@test vmpsmat.rows == vecrows
@test vmpsmat.columns == veccols
@test vmpsmat.rowBytes == vecrowBytes
@test vmpsmat.matrices == 1
@test vmpsmat.dataType == DT
@test vmpsmat.matrixBytes == vecrowBytes*vecrows
@test vmpsmat.offset == 0
@test vmpsmat.data == vec.data[]
let desc = MPSMatrixDescriptor(rows, cols, rowBytes, T)
devmat = MPSMatrix(dev, desc)
@test devmat isa MPSMatrix
@test devmat.device == dev
@test devmat.rows == rows
@test devmat.columns == cols
@test devmat.rowBytes == rowBytes
@test devmat.matrices == 1
@test devmat.dataType == DT
@test devmat.matrixBytes == rowBytes * rows
@test devmat.offset == 0
@test size(devmat) == (rows, cols)
end

let mat = MtlMatrix{T}(undef, rows, cols)
acols, arows = size(mat)
arowBytes = sizeof(T) * acols
abufmat = MPSMatrix(mat)
@test abufmat isa MPSMatrix
@test abufmat.device == dev
@test abufmat.rows == arows
@test abufmat.columns == acols
@test abufmat.rowBytes == arowBytes
@test abufmat.matrices == 1
@test abufmat.dataType == DT
@test abufmat.matrixBytes == arowBytes * arows
@test abufmat.offset == 0
@test abufmat.data == mat.data[]

vmat = @view mat[:, 2:3]
vcols, vrows = size(vmat)
vrowBytes = sizeof(T) * vcols
vbufmat = MPSMatrix(vmat)
@test vbufmat isa MPSMatrix
@test vbufmat.device == dev
@test vbufmat.rows == vrows
@test vbufmat.columns == vcols
@test vbufmat.rowBytes == vrowBytes
@test vbufmat.matrices == 1
@test vbufmat.dataType == DT
@test vbufmat.matrixBytes == vrowBytes * vrows
@test vbufmat.offset == vmat.offset * sizeof(T)
@test vbufmat.data == vmat.data[]
end

let arr = MtlArray{T, 3}(undef, rows, cols, mats)
mcols, mrows, mmats = size(arr)
mrowBytes = sizeof(T) * mcols
mpsmat = MPSMatrix(arr)
@test mpsmat isa MPSMatrix
@test mpsmat.device == dev
@test mpsmat.rows == mrows
@test mpsmat.columns == mcols
@test mpsmat.rowBytes == mrowBytes
@test mpsmat.matrices == mmats
@test mpsmat.dataType == DT
@test mpsmat.matrixBytes == mrowBytes * mrows
@test mpsmat.offset == 0
@test mpsmat.data == arr.data[]
@test size(mpsmat) == (mmats, mrows, mcols)
end

let vec = MtlVector{T}(undef, rows)
veccols, vecrows = length(vec), 1
vecrowBytes = sizeof(T) * veccols
vmpsmat = MPSMatrix(vec)
@test vmpsmat isa MPSMatrix
@test vmpsmat.device == dev
@test vmpsmat.rows == vecrows
@test vmpsmat.columns == veccols
@test vmpsmat.rowBytes == vecrowBytes
@test vmpsmat.matrices == 1
@test vmpsmat.dataType == DT
@test vmpsmat.matrixBytes == vecrowBytes * vecrows
@test vmpsmat.offset == 0
@test vmpsmat.data == vec.data[]
end
end


Expand Down