Skip to content

Commit

Permalink
Make conv2d tag graph matcher more general
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 696653103
  • Loading branch information
joeljennings authored and KfacJaxDev committed Nov 14, 2024
1 parent e302ee3 commit 59fea08
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions kfac_jax/_src/tag_graph_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -1245,7 +1245,7 @@ def _make_general_dense_pattern(
)


def _conv2d(x: Array, params: Sequence[Array]) -> Array:
def _conv2d(x: Array, params: Sequence[Array], flax_style: bool) -> Array:
"""Example of a conv2d layer function."""

w = params[0]
Expand All @@ -1262,6 +1262,9 @@ def _conv2d(x: Array, params: Sequence[Array]) -> Array:
return y

# Add bias
if flax_style:
bias = params[1]
return y + bias.reshape((1,) * (y.ndim - bias.ndim) + bias.shape)
return y + params[1][None, None, None]


Expand Down Expand Up @@ -1290,6 +1293,7 @@ def _conv2d_parameter_extractor(

def _make_conv2d_pattern(
with_bias: bool,
flax_style: bool,
) -> GraphPattern:

x_shape = [2, 8, 8, 5]
Expand All @@ -1300,7 +1304,7 @@ def _make_conv2d_pattern(
return GraphPattern(
name="conv2d_with_bias" if with_bias else "conv2d_no_bias",
tag_primitive=tags.layer_tag,
compute_func=_conv2d,
compute_func=functools.partial(_conv2d, flax_style=flax_style),
parameters_extractor_func=_conv2d_parameter_extractor,
example_args=[np.zeros(x_shape), [np.zeros(s) for s in p_shapes]],
)
Expand Down Expand Up @@ -1513,8 +1517,9 @@ def _make_normalization_haiku_pattern(
_make_general_dense_pattern(False, False, 0),
_make_general_dense_pattern(False, False, 1),
_make_general_dense_pattern(False, False, 2),
_make_conv2d_pattern(True),
_make_conv2d_pattern(False),
_make_conv2d_pattern(True, False),
_make_conv2d_pattern(True, True),
_make_conv2d_pattern(False, False),
_make_scale_and_shift_pattern(1, True, True),
_make_scale_and_shift_pattern(0, True, True),
_make_normalization_haiku_pattern(1, False),
Expand Down

0 comments on commit 59fea08

Please sign in to comment.