diff --git a/perf/hand_cuda.jl b/perf/hand_cuda.jl index b69de98..7c032b9 100644 --- a/perf/hand_cuda.jl +++ b/perf/hand_cuda.jl @@ -31,9 +31,11 @@ const OUT_DIM = 2 struct Buffers{M<:AbstractMatrix} y_1::M # h × n = tanh.(W1 * X) J_1::M # h × n = 1 .- y_1.^2 - J_2::M # m × n = 2 .* (W2*y_1 .- y) + J_2::M # m × n = W2*y_1 - y, then scaled in place by 2 W2T_J2::M # h × n = W2' * J_2, then ⊙= J_1 grad::M # h × d = W2T_J2 * X' + loss::M # 1 × 1 = sum((W2*y_1 - y).^2), kept on-device to match + # `arraydiff.jl`'s `forward_storage[1]` end function Buffers{M}(h::Int, d::Int, n::Int) where {M} @@ -43,6 +45,7 @@ function Buffers{M}(h::Int, d::Int, n::Int) where {M} M(undef, OUT_DIM, n), M(undef, h, n), M(undef, h, d), + M(undef, 1, 1), ) end @@ -51,11 +54,16 @@ function gradient!(buf::Buffers, W1, W2, X, y) buf.y_1 .= tanh.(buf.y_1) # y_1 = tanh.(y_1) buf.J_1 .= 1 .- buf.y_1 .^ 2 LinearAlgebra.mul!(buf.J_2, W2, buf.y_1) - buf.J_2 .= 2 .* (buf.J_2 .- y) + buf.J_2 .-= y # J_2 = residual (W2*y_1 - y) + # Fold the squaring into the reduction kernel so we don't materialise an + # extra (m × n) temporary. Result stays on device — ArrayDiff's tape leaves + # `forward_storage[1]` untouched on the device too. + Base.sum!(buf.loss, buf.J_2 .^ 2) + buf.J_2 .*= 2 # J_2 = 2 * residual = dL/dY LinearAlgebra.mul!(buf.W2T_J2, W2', buf.J_2) buf.W2T_J2 .= buf.J_1 .* buf.W2T_J2 LinearAlgebra.mul!(buf.grad, buf.W2T_J2, X') - return buf.grad + return buf.loss, buf.grad end # Allocating reverse pass — same arithmetic, but every intermediate is a @@ -64,8 +72,10 @@ end function gradient_alloc(W1, W2, X, y) y_1 = tanh.(W1 * X) J_1 = 1 .- y_1 .^ 2 - J_2 = 2 .* (W2 * y_1 .- y) - return (J_1 .* (W2' * J_2)) * X' + residual = W2 * y_1 .- y + loss = sum(abs2, residual) + J_2 = 2 .* residual + return loss, (J_1 .* (W2' * J_2)) * X' end """