Skip to content

Commit 52e58b9

Browse files
committed
[shape_poly] Fix handling of shape polymorphism for pallas_call_batching
The previous code did not handle the case of a symbolic batch dimension (one of the most common uses of shape polymorphism).
1 parent 5479321 commit 52e58b9

3 files changed

Lines changed: 44 additions & 7 deletions

File tree

jax/_src/pallas/core.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1163,7 +1163,8 @@ def get_grid_mapping(
11631163
dim_check : Any = jax_core.is_constant_dim # type: ignore[no-redef]
11641164
assert all(i is None or dim_check(i) for i in grid_spec.grid)
11651165
grid_mapping_grid = tuple(
1166-
dynamic_grid_dim if d is None else d for d in grid_spec.grid
1166+
dynamic_grid_dim if (d is None or not jax_core.is_constant_dim(d)) else d
1167+
for d in grid_spec.grid
11671168
)
11681169
# The inputs for the index maps
11691170
index_map_avals = (

jax/_src/pallas/pallas_call.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -717,8 +717,11 @@ def temp_f(*args):
717717
batched_index_map_avals, batched_index_map_tree = tree_util.tree_flatten(
718718
(batched_index_map_args, {}))
719719

720+
axis_size_is_dynamic = not jax_core.is_constant_dim(axis_size)
721+
new_grid_dim = pallas_core.dynamic_grid_dim if axis_size_is_dynamic else axis_size
722+
720723
batched_grid_mapping = grid_mapping.replace(
721-
grid=(axis_size, *grid_mapping.grid),
724+
grid=(new_grid_dim, *grid_mapping.grid),
722725
block_mappings=tuple(batched_block_mappings),
723726
index_map_avals=tuple(batched_index_map_avals),
724727
index_map_tree=batched_index_map_tree,
@@ -729,7 +732,7 @@ def temp_f(*args):
729732
# Avoid scaling the cost estimate by the batch size if the batch size is a
730733
# dynamic shape (DimExpr).
731734
# https://docs.jax.dev/en/latest/export/shape_poly.html#computing-with-dimension-variables
732-
if cost_estimate is not None and isinstance(axis_size, int):
735+
if cost_estimate is not None and not axis_size_is_dynamic:
733736
batched_cost_estimate = CostEstimate(
734737
flops=cost_estimate.flops * axis_size,
735738
bytes_accessed=cost_estimate.bytes_accessed * axis_size,
@@ -762,6 +765,8 @@ def temp_f(*args):
762765
with jax_core.remove_explicit_mesh_axis_names(ema):
763766
bind = shard_map(bind, out_specs=P(ema), axis_names=set(ema))
764767

768+
if axis_size_is_dynamic:
769+
dynamic_grid_args = [axis_size, *dynamic_grid_args]
765770
out = bind(*dynamic_grid_args, *args)
766771
return out, (0,) * len(out)
767772

tests/pallas/export_pallas_test.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,6 @@ def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
9797

9898
class ExportTestWithMosaicTPU(jtu.JaxTestCase):
9999
def test_dynamic_shapes_export(self):
100-
if jtu.device_under_test() != "tpu":
101-
self.skipTest("Mosaic TPU test only runs on TPU")
102-
103100
def add_vectors_kernel(x_ref, y_ref, o_ref):
104101
block_b = x_ref.shape[0]
105102

@@ -142,11 +139,45 @@ def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
142139
with pallas_export_experimental(dynamic_shapes=True):
143140
f_k = f_e(x_shape, y_shape)
144141

145-
print(f_k.mlir_module())
146142
self.assertRegex(
147143
f_k.mlir_module(),
148144
r"stablehlo.custom_call @tpu_custom_call.+kernel_name\s*=\s*\"my_custom_kernel_name\"")
149145

146+
def test_export_vmap(self):
147+
def add_vectors_kernel(x_ref, y_ref, o_ref):
148+
o_ref[...] = x_ref[...] + y_ref[...]
149+
150+
def add_vectors(x, y):
151+
block_size = 128
152+
# Grid depends on input shape, which will be symbolic
153+
grid = (x.shape[0] // block_size, x.shape[1] // block_size)
154+
block_spec = pl.BlockSpec((block_size, block_size), lambda i, j: (i, j))
155+
return pl.pallas_call(
156+
add_vectors_kernel,
157+
out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
158+
grid=grid,
159+
in_specs=[block_spec, block_spec],
160+
out_specs=block_spec
161+
)(x, y)
162+
163+
b, m, n = jax.export.symbolic_shape("b,m,n")
164+
x_info = jax.ShapeDtypeStruct((b, m, n), jnp.float32)
165+
166+
exporter = jax.export.export(jax.jit(jax.vmap(add_vectors)),
167+
platforms=["tpu"])
168+
169+
with pallas_export_experimental(dynamic_shapes=True):
170+
exp = exporter(x_info, x_info) # No crash
171+
172+
if jtu.device_under_test() == "tpu":
173+
x = y = jnp.ones((4, 128, 128))
174+
res = exp.call(x, y)
175+
self.assertAllClose(res, x + y)
176+
177+
x = y = jnp.ones((4, 192, 192)) # Not multiple of 128
178+
res = exp.call(x, y)
179+
self.assertAllClose(res, x + y)
180+
150181

151182
class ExportTestWithMosaicGpu(ExportTestWithTriton):
152183

0 commit comments

Comments (0)