Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions perf/hand_cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,11 @@ const OUT_DIM = 2
struct Buffers{M<:AbstractMatrix}
y_1::M # h × n = tanh.(W1 * X)
J_1::M # h × n = 1 .- y_1.^2
J_2::M # m × n = 2 .* (W2*y_1 .- y)
J_2::M # m × n = W2*y_1 - y, then scaled in place by 2
W2T_J2::M # h × n = W2' * J_2, then ⊙= J_1
grad::M # h × d = W2T_J2 * X'
loss::M # 1 × 1 = sum((W2*y_1 - y).^2), kept on-device to match
# `arraydiff.jl`'s `forward_storage[1]`
end

function Buffers{M}(h::Int, d::Int, n::Int) where {M}
Expand All @@ -43,6 +45,7 @@ function Buffers{M}(h::Int, d::Int, n::Int) where {M}
M(undef, OUT_DIM, n),
M(undef, h, n),
M(undef, h, d),
M(undef, 1, 1),
)
end

Expand All @@ -51,11 +54,16 @@ function gradient!(buf::Buffers, W1, W2, X, y)
buf.y_1 .= tanh.(buf.y_1) # y_1 = tanh.(y_1)
buf.J_1 .= 1 .- buf.y_1 .^ 2
LinearAlgebra.mul!(buf.J_2, W2, buf.y_1)
buf.J_2 .= 2 .* (buf.J_2 .- y)
buf.J_2 .-= y # J_2 = residual (W2*y_1 - y)
# Fold the squaring into the reduction kernel so we don't materialise an
# extra (m × n) temporary. Result stays on device — ArrayDiff's tape leaves
# `forward_storage[1]` untouched on the device too.
Base.sum!(buf.loss, buf.J_2 .^ 2)
buf.J_2 .*= 2 # J_2 = 2 * residual = dL/dY
LinearAlgebra.mul!(buf.W2T_J2, W2', buf.J_2)
buf.W2T_J2 .= buf.J_1 .* buf.W2T_J2
LinearAlgebra.mul!(buf.grad, buf.W2T_J2, X')
return buf.grad
return buf.loss, buf.grad
end

# Allocating reverse pass — same arithmetic, but every intermediate is a
Expand All @@ -64,8 +72,10 @@ end
function gradient_alloc(W1, W2, X, y)
y_1 = tanh.(W1 * X)
J_1 = 1 .- y_1 .^ 2
J_2 = 2 .* (W2 * y_1 .- y)
return (J_1 .* (W2' * J_2)) * X'
residual = W2 * y_1 .- y
loss = sum(abs2, residual)
J_2 = 2 .* residual
return loss, (J_1 .* (W2' * J_2)) * X'
end

"""
Expand Down
Loading