Skip to content

Commit

Permalink
Make conv2d tag graph matcher more general
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 696493059
  • Loading branch information
joeljennings authored and KfacJaxDev committed Nov 14, 2024
1 parent e302ee3 commit a621e70
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion kfac_jax/_src/tag_graph_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -1261,8 +1261,9 @@ def _conv2d(x: Array, params: Sequence[Array]) -> Array:
# No bias
return y

bias = params[1]
# Add bias
return y + params[1][None, None, None]
return y + bias.reshape((1,) * (y.ndim - bias.ndim) + bias.shape)


def _conv2d_parameter_extractor(
Expand Down

0 comments on commit a621e70

Please sign in to comment.