
Commit a827a27

justinjfu authored and Google-ML-Automation committed

[Mosaic GPU] Add support for loops, debug_print, and unary ops to Warp semantics.

PiperOrigin-RevId: 762036132

1 parent 210b5fc commit a827a27
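
Taken together, the changes below let code mapped over a plgpu.WarpMesh (via pl.core_map) use scalar unary ops, lax.fori_loop, and pl.debug_print. A minimal end-to-end sketch, closely mirroring the new tests in mosaic_gpu_test.py below (the two-placeholder debug_print format is an assumption; the tests themselves print a single value):

import functools

import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

warp_mesh = plgpu.WarpMesh(axis_name="warp")

@functools.partial(plgpu.kernel, out_shape=jnp.zeros(128, jnp.int32))
def kernel(ref):
  ref[...] = ref[...]  # Prevent the kernel from being DCE'd.
  @pl.core_map(warp_mesh)
  def _():
    warp_id = lax.axis_index("warp")
    def body(i, _):
      # Scalar unary negation (lax.neg_p) and debug_print now lower per-warp.
      pl.debug_print("warp {} step {}", -warp_id, i)
    lax.fori_loop(0, 2, body, None)  # fori_loop traces to scan_p.

jax.block_until_ready(kernel())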

3 files changed: +139 −60 lines changed


jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 21 additions & 0 deletions
@@ -1698,6 +1698,19 @@ def convert(ty, x):
     lax.not_p: lambda ctx, x: ~x,
 })
 
+def _unary_warp_lowering_rule(impl):
+  def _lowering_rule(ctx: LoweringRuleContext, x):
+    if not all(aval_in.shape == () for aval_in in ctx.avals_in):
+      raise NotImplementedError(
+          "Non-scalar arithmetic is not supported in warp-level lowering.")
+    return impl(x)
+  return _lowering_rule
+
+mosaic_lowering_rules[gpu_core.LANExWARP_SEMANTICS].update({
+    lax.neg_p: _unary_warp_lowering_rule(lambda x: -x),
+    lax.not_p: _unary_warp_lowering_rule(lambda x: ~x)
+})
+
 mosaic_lowering_rules[gpu_core.WGxWG_SEMANTICS].update({
     lax.neg_p: _lower_fun(lambda x: jnp.subtract(0, x), multiple_results=False),
     lax.not_p: _lower_fun(
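
Note the rank-0 guard: under warp-level semantics these rules accept scalars only, so negating a scalar lowers while negating an array raises the NotImplementedError above. A hedged fragment of warp-mapped user code (reusing warp_mesh from the sketch near the top):

@pl.core_map(warp_mesh)
def _():
  warp_id = lax.axis_index("warp")
  neg_id = -warp_id  # OK: rank-0, dispatched via _unary_warp_lowering_rule.
  # -jnp.ones((128,), jnp.float32)  # Would raise: non-scalar arithmetic.
  pl.debug_print("{}", neg_id)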
@@ -2163,6 +2176,8 @@ def _axis_index_warp_rule(ctx: LoweringRuleContext, *, axis_name: Hashable):
 
 
 @register_lowering_rule(primitives.debug_print_p, mgpu.LoweringSemantics.Lane)
+@register_lowering_rule(primitives.debug_print_p, mgpu.LoweringSemantics.Lane,
+                        gpu_core.PrimitiveSemantics.Warp)
 def _debug_print_lowering_rule(
     ctx: LoweringRuleContext,
     *args,
@@ -2171,13 +2186,17 @@
 ):
   del has_placeholders  # Unused.
   primitives.check_debug_print_format(fmt, *args)
+  scope = mgpu.ThreadSubset.WARPGROUP
+  if ctx.module_ctx.primitive_semantics == gpu_core.PrimitiveSemantics.Warp:
+    scope = mgpu.ThreadSubset.WARP
   if not any(aval.shape for aval in ctx.avals_in):
     mgpu.debug_print(
         fmt,
         *(
             _ensure_ir_value(arg, aval.dtype)
             for arg, aval in zip(args, ctx.avals_in)
         ),
+        scope=scope
     )
   elif len(ctx.avals_in) == 1:
     [arg] = args
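
When lowering under Warp primitive semantics, this rule narrows the print scope from one lane per warpgroup to one lane per warp; since a warpgroup is 128 threads (four warps), a scalar print inside warp-mapped code emits four lines, which is exactly what test_debug_print below asserts. A sketch of the observable effect:

# Inside warp-mapped code (see test_debug_print below):
pl.debug_print("warp: {}", lax.axis_index("warp"))
# -> prints "warp: 0" ... "warp: 3", one line per warp in the warpgroup.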
@@ -2461,6 +2480,8 @@ def loop(loop_index, body_args):
 
 @register_lowering_rule(lax.scan_p, mgpu.LoweringSemantics.Lane)
 @register_lowering_rule(lax.scan_p, mgpu.LoweringSemantics.Warpgroup)
+@register_lowering_rule(lax.scan_p, mgpu.LoweringSemantics.Lane,
+                        gpu_core.PrimitiveSemantics.Warp)
 def _scan_lowering_rule(
     ctx: LoweringRuleContext,
     *args,
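
lax.fori_loop and lax.scan both trace to scan_p, so registering the existing Lane rule under Warp primitive semantics is what makes loops inside warp-mapped code lower. The copy loop from test_single_warp_scan below is the canonical shape:

def loop_body(i, _):
  _slice = pl.ds(i, 1)
  plgpu.copy_smem_to_gmem(smem_ref.at[_slice], y_ref.at[_slice])
lax.fori_loop(0, 10, loop_body, None)  # Lowered through the scan_p rule above.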

jax/experimental/mosaic/gpu/utils.py

Lines changed: 6 additions & 2 deletions
@@ -144,7 +144,11 @@ def _debug_scalar_ty_format(arg):
     return "%f", arg
   raise NotImplementedError(f"Can't print the type {arg.type}")
 
-def debug_print(fmt, *args, uniform=True):
+def debug_print(fmt, *args, uniform=True, scope=None):
+  if not uniform and scope is not None:
+    raise ValueError("Cannot specify scope to a non-uniform debug_print.")
+  if scope is None:
+    scope = ThreadSubset.WARPGROUP
   type_formats = []
   new_args = []
   for arg in args:
@@ -168,7 +172,7 @@ def debug_print(fmt, *args, uniform=True):
       raise NotImplementedError(arg.type)
     type_formats.append(ty_format)
   ctx = (
-      functools.partial(single_thread, scope=ThreadSubset.WARPGROUP)
+      functools.partial(single_thread, scope=scope)
       if uniform
       else contextlib.nullcontext
   )
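
For callers using the Mosaic GPU utilities directly, the new keyword narrows a uniform print to one lane of a smaller thread subset, and combining it with uniform=False is rejected. A hypothetical use, assuming idx is a scalar ir.Value in a kernel body and mgpu aliases jax.experimental.mosaic.gpu (as in the lowering above):

mgpu.debug_print("idx: {}", idx, scope=mgpu.ThreadSubset.WARP)  # One print per warp.
# mgpu.debug_print("idx: {}", idx, uniform=False, scope=mgpu.ThreadSubset.WARP)
#   -> ValueError: Cannot specify scope to a non-uniform debug_print.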

tests/pallas/mosaic_gpu_test.py

Lines changed: 112 additions & 58 deletions
@@ -1569,64 +1569,6 @@ def kernel(x_ref, y_ref, o_ref):
     y = jax.lax.iota(jnp.float32, 128) * 3
     np.testing.assert_array_equal(kernel(x, y), x + y)
 
-  def test_warp_specialization_axis_index(self):
-    if self.LOWERING_SEMANTICS != plgpu.LoweringSemantics.Lane:
-      self.skipTest("Test only works on Lane semantics")
-    warp_mesh = plgpu.WarpMesh(axis_name="warp")
-    @functools.partial(plgpu.kernel,
-                       out_shape=jax.ShapeDtypeStruct((2, 128), jnp.int32))
-    def kernel(y_ref):
-      def scope(ones_smem_ref, threes_smem_ref):
-        # Prepare data to copy.
-        ones_smem_ref[:] = jnp.ones((1, 128), jnp.int32)
-        threes_smem_ref[:] = jnp.ones((1, 128), jnp.int32) * 3
-        plgpu.commit_smem()
-        @pl.core_map(warp_mesh)
-        def _():
-          warp_id = lax.axis_index("warp")
-          # We cannot load/store inside of core_map, so we issue async
-          # copies instead to produce a testable result.
-          @pl.when(warp_id == 1)
-          def _():
-            plgpu.copy_smem_to_gmem(ones_smem_ref, y_ref.at[0:1])
-          @pl.when(warp_id == 3)
-          def _():
-            plgpu.copy_smem_to_gmem(threes_smem_ref, y_ref.at[1:2])
-        plgpu.wait_smem_to_gmem(0)
-      pl.run_scoped(scope,
-                    plgpu.SMEM((1, 128), jnp.int32),
-                    plgpu.SMEM((1, 128), jnp.int32)
-                    )
-    result = kernel()
-    expected = jnp.stack((jnp.ones((128,), jnp.int32),
-                          jnp.ones((128,), jnp.int32) * 3), axis=0)
-    np.testing.assert_array_equal(result, expected)
-
-  def test_warp_mesh_errors_when_closing_over_array(self):
-    if self.LOWERING_SEMANTICS != plgpu.LoweringSemantics.Lane:
-      self.skipTest("Test only works on Lane semantics")
-    # We currently do not allow closing over arrays when mapping over
-    # a mesh, since we would need to present a view of the array local
-    # to each warp.
-    warp_mesh = plgpu.WarpMesh(axis_name="warp")
-    @functools.partial(plgpu.kernel,
-                       out_shape=jax.ShapeDtypeStruct((32, 32), jnp.float32),
-                       scratch_shapes=[plgpu.SMEM((32, 32), jnp.float32)])
-    def kernel(out_ref, smem_ref):
-      arr = jnp.ones((32, 32), dtype=jnp.float32)
-      @pl.core_map(warp_mesh)
-      def _():
-        smem_ref[...] = arr + 1
-      plgpu.commit_smem()
-      plgpu.copy_smem_to_gmem(smem_ref, out_ref)
-      plgpu.wait_smem_to_gmem(0)
-    with self.assertRaisesRegex(
-        mgpu_lowering.LoweringError,
-        "Can only close over scalars and Refs when using core_map with "
-        "WarpMesh",
-    ):
-      kernel()
-
   def test_smem_aliasing_works(self):
     self.skip_if_wg_semantics()

@@ -1825,6 +1767,118 @@ def body(idx, _):
     )
 
 
+class PallasCallWarpPrimitiveSemanticsTest(PallasTest):
+  def setUp(self):
+    super().setUp()
+    if self.LOWERING_SEMANTICS != plgpu.LoweringSemantics.Lane:
+      self.skipTest("Test only works on Lane semantics")
+
+  def test_axis_index(self):
+    warp_mesh = plgpu.WarpMesh(axis_name="warp")
+    @functools.partial(plgpu.kernel,
+                       out_shape=jax.ShapeDtypeStruct((2, 128), jnp.int32))
+    def kernel(y_ref):
+      def scope(ones_smem_ref, threes_smem_ref):
+        # Prepare data to copy.
+        ones_smem_ref[:] = jnp.ones((1, 128), jnp.int32)
+        threes_smem_ref[:] = jnp.ones((1, 128), jnp.int32) * 3
+        plgpu.commit_smem()
+        @pl.core_map(warp_mesh)
+        def _():
+          warp_id = lax.axis_index("warp")
+          # We cannot load/store inside of core_map, so we issue async
+          # copies instead to produce a testable result.
+          @pl.when(warp_id == 1)
+          def _():
+            plgpu.copy_smem_to_gmem(ones_smem_ref, y_ref.at[0:1])
+          @pl.when(warp_id == 3)
+          def _():
+            plgpu.copy_smem_to_gmem(threes_smem_ref, y_ref.at[1:2])
+        plgpu.wait_smem_to_gmem(0)
+      pl.run_scoped(scope,
+                    plgpu.SMEM((1, 128), jnp.int32),
+                    plgpu.SMEM((1, 128), jnp.int32)
+                    )
+    result = kernel()
+    expected = jnp.stack((jnp.ones((128,), jnp.int32),
+                          jnp.ones((128,), jnp.int32) * 3), axis=0)
+    np.testing.assert_array_equal(result, expected)
+
+  def test_errors_when_closing_over_array(self):
+    # We currently do not allow closing over arrays when mapping over
+    # a mesh, since we would need to present a view of the array local
+    # to each warp.
+    warp_mesh = plgpu.WarpMesh(axis_name="warp")
+    @functools.partial(plgpu.kernel,
+                       out_shape=jax.ShapeDtypeStruct((32, 32), jnp.float32),
+                       scratch_shapes=[plgpu.SMEM((32, 32), jnp.float32)])
+    def kernel(out_ref, smem_ref):
+      arr = jnp.ones((32, 32), dtype=jnp.float32)
+      @pl.core_map(warp_mesh)
+      def _():
+        smem_ref[...] = arr + 1
+      plgpu.commit_smem()
+      plgpu.copy_smem_to_gmem(smem_ref, out_ref)
+      plgpu.wait_smem_to_gmem(0)
+    with self.assertRaisesRegex(
+        mgpu_lowering.LoweringError,
+        "Can only close over scalars and Refs when using core_map with "
+        "WarpMesh",
+    ):
+      kernel()
+
+  def test_single_warp_scan(self):
+    warp_mesh = plgpu.WarpMesh(axis_name="warp")
+    @functools.partial(plgpu.kernel,
+                       out_shape=jax.ShapeDtypeStruct((10, 128), jnp.int32))
+    def kernel(y_ref):
+      def scope(smem_ref):
+        # Prepare data to copy.
+        for i in range(10):
+          smem_ref[i, :] = jnp.ones_like(smem_ref.at[i]) * i
+        plgpu.commit_smem()
+        @pl.core_map(warp_mesh)
+        def _():
+          warp_id = lax.axis_index("warp")
+          @pl.when(warp_id == 0)
+          def _():
+            def loop_body(i, _):
+              _slice = pl.ds(i, 1)
+              plgpu.copy_smem_to_gmem(smem_ref.at[_slice], y_ref.at[_slice])
+            lax.fori_loop(0, 10, loop_body, None)
+        plgpu.wait_smem_to_gmem(0)
+      pl.run_scoped(scope, plgpu.SMEM((10, 128), jnp.int32))
+    result = kernel()
+    expected = jnp.stack(
+        [jnp.ones((128,), jnp.int32) * i for i in range(10)], axis=0)
+    np.testing.assert_array_equal(result, expected)
+
+  def test_debug_print(self):
+    warp_mesh = plgpu.WarpMesh(axis_name="warp")
+    @functools.partial(
+        plgpu.kernel,
+        out_shape=jnp.zeros(128, np.int32),
+    )
+    def kernel(ref):
+      ref[...] = ref[...]  # Prevent kernel from being DCE'd
+      @pl.core_map(warp_mesh)
+      def _():
+        warp_id = lax.axis_index("warp")
+        pl.debug_print("warp: {}", warp_id)
+
+    with self.capture_stdout() as output:
+      jax.block_until_ready(kernel())
+    self.assertEqual(
+        set(output().splitlines()),
+        {
+            "warp: 0",
+            "warp: 1",
+            "warp: 2",
+            "warp: 3",
+        },
+    )
+
+
 class PallasCallWGTest(
     PallasCallTest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup
 ):
