From 27390db3ea3429e5fc4ef8b257c0bb9c4869f31b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Thu, 28 May 2026 08:23:23 +0200 Subject: [PATCH 1/4] Avoid computing derivative of constant nodes in matmul --- src/reverse_mode.jl | 45 +++++++++++++++++++++++++++++---------------- 1 file changed, 29 insertions(+), 16 deletions(-) diff --git a/src/reverse_mode.jl b/src/reverse_mode.jl index 7a9484e..00c5791 100644 --- a/src/reverse_mode.jl +++ b/src/reverse_mode.jl @@ -9,38 +9,51 @@ # the right view type for each node and `LinearAlgebra.mul!` covers both # shape combinations. function _matmul_reverse!(f, k::Int, ix1::Int, ix2::Int) - _reshape_call( - f.forward_storage, - f.sizes, - (ix1, ix2), - _matmul_reverse_outer, - (f.reverse_storage, f.sizes, ix1, ix2, k), - ) + if f.nodes[ix1].type != CONSTANT + _reshape_call( + f.forward_storage, + f.sizes, + (ix2,), + _matmul_reverse_outer, + (f.reverse_storage, f.sizes, false, ix1, k), + ) + end + if f.nodes[ix2].type != CONSTANT + _reshape_call( + f.forward_storage, + f.sizes, + (ix1,), + _matmul_reverse_outer, + (f.reverse_storage, f.sizes, true, ix2, k), + ) + end return end function _matmul_reverse_outer( reverse_storage, sizes::Sizes, - ix1::Int, - ix2::Int, + lhs::Bool, + ix::Int, k::Int, - v1, - v2, + v, ) _reshape_call( reverse_storage, sizes, - (ix1, ix2, k), + (ix, k), _matmul_reverse_inner!, - (v1, v2), + (lhs, v,), ) return end -function _matmul_reverse_inner!(v1, v2, rev_v1, rev_v2, rev_parent) - LinearAlgebra.mul!(rev_v1, rev_parent, transpose(v2)) - LinearAlgebra.mul!(rev_v2, transpose(v1), rev_parent) +function _matmul_reverse_inner!(lhs::Bool, v, rev_v, rev_parent) + if lhs + LinearAlgebra.mul!(rev_v, rev_parent, transpose(v)) + else + LinearAlgebra.mul!(rev_v, transpose(v), rev_parent) + end return end From 179b4287923e3bd93e7e72230251d2444ef68b62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Thu, 28 May 2026 08:59:16 +0200 Subject: [PATCH 2/4] Fix --- src/reverse_mode.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/reverse_mode.jl b/src/reverse_mode.jl index 00c5791..d535841 100644 --- a/src/reverse_mode.jl +++ b/src/reverse_mode.jl @@ -15,7 +15,7 @@ function _matmul_reverse!(f, k::Int, ix1::Int, ix2::Int) f.sizes, (ix2,), _matmul_reverse_outer, - (f.reverse_storage, f.sizes, false, ix1, k), + (f.reverse_storage, f.sizes, true, ix1, k), ) end if f.nodes[ix2].type != CONSTANT @@ -24,7 +24,7 @@ function _matmul_reverse!(f, k::Int, ix1::Int, ix2::Int) f.sizes, (ix1,), _matmul_reverse_outer, - (f.reverse_storage, f.sizes, true, ix2, k), + (f.reverse_storage, f.sizes, false, ix2, k), ) end return From 87c598653c7649072f7d203423f66f62fde42068 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Thu, 28 May 2026 11:26:28 +0200 Subject: [PATCH 3/4] Fix --- src/reverse_mode.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/reverse_mode.jl b/src/reverse_mode.jl index d535841..96e3350 100644 --- a/src/reverse_mode.jl +++ b/src/reverse_mode.jl @@ -9,7 +9,7 @@ # the right view type for each node and `LinearAlgebra.mul!` covers both # shape combinations. function _matmul_reverse!(f, k::Int, ix1::Int, ix2::Int) - if f.nodes[ix1].type != CONSTANT + if f.nodes[ix1].type != NODE_VALUE_BLOCK _reshape_call( f.forward_storage, f.sizes, @@ -18,7 +18,7 @@ function _matmul_reverse!(f, k::Int, ix1::Int, ix2::Int) (f.reverse_storage, f.sizes, true, ix1, k), ) end - if f.nodes[ix2].type != CONSTANT + if f.nodes[ix2].type != NODE_VALUE_BLOCK _reshape_call( f.forward_storage, f.sizes, From 2fea3963be7ac913a92c0dc0094913ebbcc1015e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Thu, 28 May 2026 11:26:46 +0200 Subject: [PATCH 4/4] Fix --- src/reverse_mode.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/reverse_mode.jl b/src/reverse_mode.jl index 96e3350..86c158e 100644 --- a/src/reverse_mode.jl +++ b/src/reverse_mode.jl @@ -43,7 +43,7 @@ function _matmul_reverse_outer( sizes, (ix, k), _matmul_reverse_inner!, - (lhs, v,), + (lhs, v), ) return end