From a621e707fb129f73f617e8d48a37daa392dcf7c8 Mon Sep 17 00:00:00 2001 From: Joel Jennings Date: Thu, 14 Nov 2024 06:01:04 -0800 Subject: [PATCH] Make conv2d tag graph matcher more general PiperOrigin-RevId: 696493059 --- kfac_jax/_src/tag_graph_matcher.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/kfac_jax/_src/tag_graph_matcher.py b/kfac_jax/_src/tag_graph_matcher.py index 63ef795..98a13eb 100644 --- a/kfac_jax/_src/tag_graph_matcher.py +++ b/kfac_jax/_src/tag_graph_matcher.py @@ -1261,8 +1261,9 @@ def _conv2d(x: Array, params: Sequence[Array]) -> Array: # No bias return y + bias = params[1] # Add bias - return y + params[1][None, None, None] + return y + bias.reshape((1,) * (y.ndim - bias.ndim) + bias.shape) def _conv2d_parameter_extractor(