We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent a4e8216 commit a86a55dCopy full SHA for a86a55d
lib/mpsgraphs/matmul.jl
@@ -10,9 +10,9 @@ function _matmul!(c::MtlArray{Tc}, a::MtlArray{Tab}, b::MtlArray{Tab}, alpha::Nu
10
placeB => MPSGraphTensorData(b)
11
)
12
13
- castA, castB = if Tc != Tab
14
- castTensor(graph, placeA, Tc, "castA"),
15
- castTensor(graph, placeB, Tc, "castB")
+ castA, castB = if Tab != Float32
+ castTensor(graph, placeA, Float32, "castA"),
+ castTensor(graph, placeB, Float32, "castB")
16
else
17
placeA, placeB
18
end
@@ -48,8 +48,14 @@ function _matmul!(c::MtlArray{Tc}, a::MtlArray{Tab}, b::MtlArray{Tab}, alpha::Nu
48
additionWithPrimaryTensor(graph, afteralpha, betaC)
49
50
51
+ castC = if Tc != Float32
52
+ castTensor(graph, afterbeta, Tc, "castC")
53
+ else
54
+ afterbeta
55
+ end
56
+
57
resultdict = Dict{MPSGraphTensor, MPSGraphTensorData}(
- afterbeta => outputTensorData
58
+ castC => outputTensorData
59
60
61
cmdbuf = MPSCommandBuffer(Metal.global_queue(device()))
0 commit comments