Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ steps:
dirs:
- src
- lib
- examples
agents:
queue: "juliaecosystem"
os: "macos"
Expand Down Expand Up @@ -84,7 +83,6 @@ steps:
dirs:
- src
- lib
- examples
env:
MTL_DEBUG_LAYER: '1'
MTL_SHADER_VALIDATION: '1'
Expand Down Expand Up @@ -113,7 +111,6 @@ steps:
dirs:
- src
- lib
- examples
env:
JULIA_LLVM_ARGS: '--opaque-pointers'
agents:
Expand Down
1 change: 1 addition & 0 deletions lib/mps/MPS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ is_supported(dev::MTLDevice) = ccall(:MPSSupportsMTLDevice, Bool, (id{MTLDevice}
include("libmps.jl")

include("size.jl")
include("datatype.jl")

# high-level wrappers
include("command_buf.jl")
Expand Down
23 changes: 23 additions & 0 deletions lib/mps/datatype.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
## Some extra definitions for MPSDataType defined in libmps.jl

## bitwise operations lose type information, so allow conversions
Base.convert(::Type{MPSDataType}, x::Integer) = MPSDataType(x)

# Conversions for MPSDataTypes with Julia equivalents
const jl_mps_to_typ = Dict{MPSDataType, DataType}()
for type in [
:Bool, :UInt8, :UInt16, :UInt32, :UInt64, :Int8, :Int16, :Int32, :Int64,
:Float16, :BFloat16, :Float32, (:ComplexF16, :MPSDataTypeComplexFloat16),
(:ComplexF32, :MPSDataTypeComplexFloat32),
]
jltype, mpstype = if type isa Symbol
type, Symbol(:MPSDataType, type)
else
type
end
@eval Base.convert(::Type{MPSDataType}, ::Type{$jltype}) = $(mpstype)
@eval jl_mps_to_typ[$(mpstype)] = $jltype
end
Base.sizeof(t::MPSDataType) = sizeof(jl_mps_to_typ[t])

Base.convert(::Type{DataType}, mpstyp::MPSDataType) = jl_mps_to_typ[mpstyp]
34 changes: 8 additions & 26 deletions lib/mps/matrix.jl
Original file line number Diff line number Diff line change
@@ -1,28 +1,3 @@
## Some extra definitions for MPSDataType defined in libmps.jl

## bitwise operations lose type information, so allow conversions
Base.convert(::Type{MPSDataType}, x::Integer) = MPSDataType(x)

# Conversions for MPSDataTypes with Julia equivalents
const jl_mps_to_typ = Dict{MPSDataType, DataType}()
for type in [
:Bool, :UInt8, :UInt16, :UInt32, :UInt64, :Int8, :Int16, :Int32, :Int64,
:Float16, :BFloat16, :Float32, (:ComplexF16, :MPSDataTypeComplexFloat16),
(:ComplexF32, :MPSDataTypeComplexFloat32),
]
jltype, mpstype = if type isa Symbol
type, Symbol(:MPSDataType, type)
else
type
end
@eval Base.convert(::Type{MPSDataType}, ::Type{$jltype}) = $(mpstype)
@eval jl_mps_to_typ[$(mpstype)] = $jltype
end
Base.sizeof(t::MPSDataType) = sizeof(jl_mps_to_typ[t])

Base.convert(::Type{DataType}, mpstyp::MPSDataType) = jl_mps_to_typ[mpstyp]


## descriptor

export MPSMatrixDescriptor
Expand Down Expand Up @@ -119,6 +94,13 @@ function MPSMatrix(arr::MtlArray{T,3}) where T
return MPSMatrix(arr, desc, offset)
end

function Base.size(mat::MPS.MPSMatrix)
if mat.matrices > 1
return Int64.((mat.matrices, mat.rows, mat.columns))
else
return Int64.((mat.rows, mat.columns))
end
end

## matrix multiplication

Expand Down Expand Up @@ -160,7 +142,7 @@ with any `MtlArray` and it should be accelerated using Metal Performance Shaders
"""
function matmul!(c::MtlArray{T1,N}, a::MtlArray{T2,N}, b::MtlArray{T3,N},
alpha::Number=true, beta::Number=true,
transpose_a=false, transpose_b=false) where {T1, T2, T3, N}
transpose_a=false, transpose_b=false) where {T1, T2, T3, N}
# NOTE: MPS uses row major, while Julia is col-major. Instead of transposing
# the inputs (by passing !transpose_[ab]) and afterwards transposing
# the output, we use the property that (AB)ᵀ = BᵀAᵀ
Expand Down
2 changes: 1 addition & 1 deletion lib/mps/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ function MPSNDArray(arr::MtlArray{T,N}) where {T,N}
arrsize = size(arr)
@assert arrsize[1] * sizeof(T) % 16 == 0 "First dimension of input MtlArray must have a byte size divisible by 16"
desc = MPSNDArrayDescriptor(T, arrsize)
return MPSNDArray(arr.data[], UInt(arr.offset), desc)
return MPSNDArray(arr.data[], UInt(arr.offset) * sizeof(T), desc)
end

function Metal.MtlArray(ndarr::MPSNDArray; storage = Metal.DefaultStorageMode, async = false)
Expand Down
2 changes: 2 additions & 0 deletions src/MetalKernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ end

## indexing

## COV_EXCL_START
@device_override @inline function KA.__index_Local_Linear(ctx)
return thread_position_in_threadgroup_1d()
end
Expand Down Expand Up @@ -191,5 +192,6 @@ end
@device_override @inline function KA.__print(args...)
# TODO
end
## COV_EXCL_STOP

end
2 changes: 2 additions & 0 deletions src/accumulate.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
## COV_EXCL_START
function partial_scan(op::Function, output::AbstractArray{T}, input::AbstractArray,
Rdim, Rpre, Rpost, Rother, neutral, init,
::Val{maxthreads}, ::Val{inclusive}=Val(true)) where {T, maxthreads, inclusive}
Expand Down Expand Up @@ -100,6 +101,7 @@ function aggregate_partial_scan(op::Function, output::AbstractArray, aggregates:

return
end
## COV_EXCL_STOP

function scan!(f::Function, output::WrappedMtlArray{T}, input::WrappedMtlArray;
dims::Integer, init=nothing, neutral=GPUArrays.neutral_element(f, T)) where {T}
Expand Down
10 changes: 10 additions & 0 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ end
_broadcast_shapes[Is] += 1
end
if _broadcast_shapes[Is] > BROADCAST_SPECIALIZATION_THRESHOLD
## COV_EXCL_START
function broadcast_cartesian_static(dest, bc, Is)
i = thread_position_in_grid_1d()
stride = threads_per_grid_1d()
Expand All @@ -69,6 +70,7 @@ end
end
return
end
## COV_EXCL_STOP

Is = StaticCartesianIndices(Is)
kernel = @metal launch=false broadcast_cartesian_static(dest, bc, Is)
Expand All @@ -82,6 +84,7 @@ end
# try to use the most appropriate hardware index to avoid integer division
if ndims(dest) == 1 ||
(isa(IndexStyle(dest), IndexLinear) && isa(IndexStyle(bc), IndexLinear))
## COV_EXCL_START
function broadcast_linear(dest, bc)
i = thread_position_in_grid_1d()
stride = threads_per_grid_1d()
Expand All @@ -91,12 +94,14 @@ end
end
return
end
## COV_EXCL_STOP

kernel = @metal launch=false broadcast_linear(dest, bc)
elements = cld(length(dest), 4)
threads = min(elements, kernel.pipeline.maxTotalThreadsPerThreadgroup)
groups = cld(elements, threads)
elseif ndims(dest) == 2
## COV_EXCL_START
function broadcast_2d(dest, bc)
is = Tuple(thread_position_in_grid_2d())
stride = threads_per_grid_2d()
Expand All @@ -107,13 +112,15 @@ end
end
return
end
## COV_EXCL_STOP

kernel = @metal launch=false broadcast_2d(dest, bc)
w = min(size(dest, 1), kernel.pipeline.threadExecutionWidth)
h = min(size(dest, 2), kernel.pipeline.maxTotalThreadsPerThreadgroup ÷ w)
threads = (w, h)
groups = cld.(size(dest), threads)
elseif ndims(dest) == 3
## COV_EXCL_START
function broadcast_3d(dest, bc)
is = Tuple(thread_position_in_grid_3d())
stride = threads_per_grid_3d()
Expand All @@ -126,6 +133,7 @@ end
end
return
end
## COV_EXCL_STOP

kernel = @metal launch=false broadcast_3d(dest, bc)
w = min(size(dest, 1), kernel.pipeline.threadExecutionWidth)
Expand All @@ -135,6 +143,7 @@ end
threads = (w, h, d)
groups = cld.(size(dest), threads)
else
## COV_EXCL_START
function broadcast_cartesian(dest, bc)
i = thread_position_in_grid_1d()
stride = threads_per_grid_1d()
Expand All @@ -145,6 +154,7 @@ end
end
return
end
## COV_EXCL_STOP

kernel = @metal launch=false broadcast_cartesian(dest, bc)
elements = cld(length(dest), 4)
Expand Down
4 changes: 3 additions & 1 deletion src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ end


## profile macro

## COV_EXCL_START
function profile_dir()
root = pwd()
i = 1
Expand Down Expand Up @@ -239,3 +239,5 @@ macro profile(ex...)
end
end
end
## COV_EXCL_START

8 changes: 8 additions & 0 deletions test/mps/datatype.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@

@testset "MPSDataType" begin

@test convert(MPS.MPSDataType, 0x0000000010000020) == MPS.MPSDataTypeFloat32
@test sizeof(MPS.MPSDataTypeFloat16) == 2
@test convert(DataType, MPS.MPSDataTypeInt64) == Int64

end
Loading