diff --git a/src/reverse_mode.jl b/src/reverse_mode.jl index 7a9484e..86c158e 100644 --- a/src/reverse_mode.jl +++ b/src/reverse_mode.jl @@ -9,38 +9,51 @@ # the right view type for each node and `LinearAlgebra.mul!` covers both # shape combinations. function _matmul_reverse!(f, k::Int, ix1::Int, ix2::Int) - _reshape_call( - f.forward_storage, - f.sizes, - (ix1, ix2), - _matmul_reverse_outer, - (f.reverse_storage, f.sizes, ix1, ix2, k), - ) + if f.nodes[ix1].type != NODE_VALUE_BLOCK + _reshape_call( + f.forward_storage, + f.sizes, + (ix2,), + _matmul_reverse_outer, + (f.reverse_storage, f.sizes, true, ix1, k), + ) + end + if f.nodes[ix2].type != NODE_VALUE_BLOCK + _reshape_call( + f.forward_storage, + f.sizes, + (ix1,), + _matmul_reverse_outer, + (f.reverse_storage, f.sizes, false, ix2, k), + ) + end return end function _matmul_reverse_outer( reverse_storage, sizes::Sizes, - ix1::Int, - ix2::Int, + lhs::Bool, + ix::Int, k::Int, - v1, - v2, + v, ) _reshape_call( reverse_storage, sizes, - (ix1, ix2, k), + (ix, k), _matmul_reverse_inner!, - (v1, v2), + (lhs, v), ) return end -function _matmul_reverse_inner!(v1, v2, rev_v1, rev_v2, rev_parent) - LinearAlgebra.mul!(rev_v1, rev_parent, transpose(v2)) - LinearAlgebra.mul!(rev_v2, transpose(v1), rev_parent) +function _matmul_reverse_inner!(lhs::Bool, v, rev_v, rev_parent) + if lhs + LinearAlgebra.mul!(rev_v, rev_parent, transpose(v)) + else + LinearAlgebra.mul!(rev_v, transpose(v), rev_parent) + end return end