diff --git a/kfac_jax/_src/tag_graph_matcher.py b/kfac_jax/_src/tag_graph_matcher.py index 63ef795..98a13eb 100644 --- a/kfac_jax/_src/tag_graph_matcher.py +++ b/kfac_jax/_src/tag_graph_matcher.py @@ -1261,8 +1261,9 @@ def _conv2d(x: Array, params: Sequence[Array]) -> Array: # No bias return y + bias = params[1] # Add bias - return y + params[1][None, None, None] + return y + bias.reshape((1,) * (y.ndim - bias.ndim) + bias.shape) def _conv2d_parameter_extractor(