@@ -1569,64 +1569,6 @@ def kernel(x_ref, y_ref, o_ref):
1569
1569
y = jax .lax .iota (jnp .float32 , 128 ) * 3
1570
1570
np .testing .assert_array_equal (kernel (x , y ), x + y )
1571
1571
1572
def test_warp_specialization_axis_index(self):
  """Check that `lax.axis_index("warp")` distinguishes warps under a WarpMesh.

  Two SMEM buffers are filled with ones and threes. Inside `core_map`, the
  branch guarded by `warp_id == 1` copies the ones into output row 0 and the
  branch guarded by `warp_id == 3` copies the threes into output row 1. The
  output is then compared against the stacked (ones, threes) rows.
  """
  lane_only = plgpu.LoweringSemantics.Lane
  if self.LOWERING_SEMANTICS != lane_only:
    self.skipTest("Test only works on Lane semantics")

  mesh = plgpu.WarpMesh(axis_name="warp")
  out_struct = jax.ShapeDtypeStruct((2, 128), jnp.int32)

  @functools.partial(plgpu.kernel, out_shape=out_struct)
  def kernel(y_ref):
    def scope(ones_smem_ref, threes_smem_ref):
      # Stage the two source rows in SMEM before any warp copies them out.
      ones_smem_ref[:] = jnp.ones((1, 128), jnp.int32)
      threes_smem_ref[:] = jnp.ones((1, 128), jnp.int32) * 3
      plgpu.commit_smem()

      @pl.core_map(mesh)
      def _per_warp():
        warp_id = lax.axis_index("warp")

        # We cannot load/store inside of core_map, so we issue async
        # copies instead to produce a testable result.
        @pl.when(warp_id == 1)
        def _copy_ones():
          plgpu.copy_smem_to_gmem(ones_smem_ref, y_ref.at[0:1])

        @pl.when(warp_id == 3)
        def _copy_threes():
          plgpu.copy_smem_to_gmem(threes_smem_ref, y_ref.at[1:2])

      # NOTE(review): indentation is not recoverable from the diff view; this
      # wait is assumed to sit after the core_map region — confirm.
      plgpu.wait_smem_to_gmem(0)

    first_buffer = plgpu.SMEM((1, 128), jnp.int32)
    second_buffer = plgpu.SMEM((1, 128), jnp.int32)
    pl.run_scoped(scope, first_buffer, second_buffer)

  result = kernel()
  ones_row = jnp.ones((128,), jnp.int32)
  threes_row = jnp.ones((128,), jnp.int32) * 3
  expected = jnp.stack((ones_row, threes_row), axis=0)
  np.testing.assert_array_equal(result, expected)
1604
-
1605
def test_warp_mesh_errors_when_closing_over_array(self):
  """Expect a LoweringError when a WarpMesh `core_map` body captures an array.

  The kernel builds a (32, 32) float32 array outside the `core_map` body and
  reads it inside; calling the kernel must raise
  `mgpu_lowering.LoweringError` with the "Can only close over scalars and
  Refs" message.
  """
  lane_only = plgpu.LoweringSemantics.Lane
  if self.LOWERING_SEMANTICS != lane_only:
    self.skipTest("Test only works on Lane semantics")

  # We currently do not allow closing over arrays when mapping over
  # a mesh, since we would need to present a view of the array local
  # to each warp.
  mesh = plgpu.WarpMesh(axis_name="warp")
  out_struct = jax.ShapeDtypeStruct((32, 32), jnp.float32)
  scratch = [plgpu.SMEM((32, 32), jnp.float32)]

  @functools.partial(
      plgpu.kernel, out_shape=out_struct, scratch_shapes=scratch
  )
  def kernel(out_ref, smem_ref):
    arr = jnp.ones((32, 32), dtype=jnp.float32)

    @pl.core_map(mesh)
    def _per_warp():
      # `arr` is the closed-over array that should trigger the error.
      smem_ref[...] = arr + 1

    # NOTE(review): indentation is not recoverable from the diff view; these
    # calls are assumed to follow the core_map region — confirm.
    plgpu.commit_smem()
    plgpu.copy_smem_to_gmem(smem_ref, out_ref)
    plgpu.wait_smem_to_gmem(0)

  expected_message = (
      "Can only close over scalars and Refs when using core_map with "
      "WarpMesh"
  )
  with self.assertRaisesRegex(mgpu_lowering.LoweringError, expected_message):
    kernel()
1629
-
1630
1572
def test_smem_aliasing_works (self ):
1631
1573
self .skip_if_wg_semantics ()
1632
1574
@@ -1825,6 +1767,118 @@ def body(idx, _):
1825
1767
)
1826
1768
1827
1769
1770
class PallasCallWarpPrimitiveSemanticsTest(PallasTest):
  """Tests for primitives used inside `pl.core_map` over a `plgpu.WarpMesh`.

  Every test in this class is skipped unless the lowering semantics is Lane.
  """

  def setUp(self):
    """Run the base set-up, then skip when not using Lane semantics."""
    super().setUp()
    if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane:
      return
    self.skipTest("Test only works on Lane semantics")

  def test_axis_index(self):
    """Warp ids 1 and 3 each copy a distinct SMEM row to the output.

    The branch guarded by `warp_id == 1` writes a row of ones to output row 0
    and the branch guarded by `warp_id == 3` writes a row of threes to output
    row 1.
    """
    mesh = plgpu.WarpMesh(axis_name="warp")
    out_struct = jax.ShapeDtypeStruct((2, 128), jnp.int32)

    @functools.partial(plgpu.kernel, out_shape=out_struct)
    def kernel(y_ref):
      def scope(ones_smem_ref, threes_smem_ref):
        # Stage the two source rows in SMEM before any warp copies them out.
        ones_smem_ref[:] = jnp.ones((1, 128), jnp.int32)
        threes_smem_ref[:] = jnp.ones((1, 128), jnp.int32) * 3
        plgpu.commit_smem()

        @pl.core_map(mesh)
        def _per_warp():
          warp_id = lax.axis_index("warp")

          # We cannot load/store inside of core_map, so we issue async
          # copies instead to produce a testable result.
          @pl.when(warp_id == 1)
          def _copy_ones():
            plgpu.copy_smem_to_gmem(ones_smem_ref, y_ref.at[0:1])

          @pl.when(warp_id == 3)
          def _copy_threes():
            plgpu.copy_smem_to_gmem(threes_smem_ref, y_ref.at[1:2])

        # NOTE(review): indentation is not recoverable from the diff view;
        # this wait is assumed to sit after the core_map region — confirm.
        plgpu.wait_smem_to_gmem(0)

      first_buffer = plgpu.SMEM((1, 128), jnp.int32)
      second_buffer = plgpu.SMEM((1, 128), jnp.int32)
      pl.run_scoped(scope, first_buffer, second_buffer)

    result = kernel()
    ones_row = jnp.ones((128,), jnp.int32)
    threes_row = jnp.ones((128,), jnp.int32) * 3
    expected = jnp.stack((ones_row, threes_row), axis=0)
    np.testing.assert_array_equal(result, expected)

  def test_errors_when_closing_over_array(self):
    """Capturing an array inside a WarpMesh `core_map` raises LoweringError."""
    # We currently do not allow closing over arrays when mapping over
    # a mesh, since we would need to present a view of the array local
    # to each warp.
    mesh = plgpu.WarpMesh(axis_name="warp")
    out_struct = jax.ShapeDtypeStruct((32, 32), jnp.float32)
    scratch = [plgpu.SMEM((32, 32), jnp.float32)]

    @functools.partial(
        plgpu.kernel, out_shape=out_struct, scratch_shapes=scratch
    )
    def kernel(out_ref, smem_ref):
      arr = jnp.ones((32, 32), dtype=jnp.float32)

      @pl.core_map(mesh)
      def _per_warp():
        # `arr` is the closed-over array that should trigger the error.
        smem_ref[...] = arr + 1

      # NOTE(review): indentation is not recoverable from the diff view;
      # these calls are assumed to follow the core_map region — confirm.
      plgpu.commit_smem()
      plgpu.copy_smem_to_gmem(smem_ref, out_ref)
      plgpu.wait_smem_to_gmem(0)

    expected_message = (
        "Can only close over scalars and Refs when using core_map with "
        "WarpMesh"
    )
    with self.assertRaisesRegex(
        mgpu_lowering.LoweringError, expected_message
    ):
      kernel()

  def test_single_warp_scan(self):
    """A `fori_loop` in the `warp_id == 0` branch copies ten rows one by one.

    Row `i` of the SMEM buffer is filled with the value `i`; the loop then
    issues one SMEM-to-GMEM copy per row, and the output must equal the ten
    stacked constant rows.
    """
    mesh = plgpu.WarpMesh(axis_name="warp")
    out_struct = jax.ShapeDtypeStruct((10, 128), jnp.int32)

    @functools.partial(plgpu.kernel, out_shape=out_struct)
    def kernel(y_ref):
      def scope(smem_ref):
        # Fill row i with the constant i so each copied row is identifiable.
        for i in range(10):
          smem_ref[i, :] = jnp.ones_like(smem_ref.at[i]) * i
        plgpu.commit_smem()

        @pl.core_map(mesh)
        def _per_warp():
          warp_id = lax.axis_index("warp")

          @pl.when(warp_id == 0)
          def _copy_rows():
            def loop_body(step, _carry):
              # Single-row dynamic slice starting at the loop index.
              row = pl.ds(step, 1)
              plgpu.copy_smem_to_gmem(smem_ref.at[row], y_ref.at[row])

            lax.fori_loop(0, 10, loop_body, None)

        # NOTE(review): indentation is not recoverable from the diff view;
        # this wait is assumed to sit after the core_map region — confirm.
        plgpu.wait_smem_to_gmem(0)

      pl.run_scoped(scope, plgpu.SMEM((10, 128), jnp.int32))

    result = kernel()
    rows = [jnp.ones((128,), jnp.int32) * i for i in range(10)]
    expected = jnp.stack(rows, axis=0)
    np.testing.assert_array_equal(result, expected)

  def test_debug_print(self):
    """`pl.debug_print` inside `core_map` reports warp ids 0 through 3.

    The captured stdout is compared as a set of lines, so ordering and
    repeated lines do not affect the result.
    """
    mesh = plgpu.WarpMesh(axis_name="warp")
    out_template = jnp.zeros(128, np.int32)

    @functools.partial(plgpu.kernel, out_shape=out_template)
    def kernel(ref):
      ref[...] = ref[...]  # Prevent kernel from being DCE'd

      @pl.core_map(mesh)
      def _per_warp():
        warp_id = lax.axis_index("warp")
        pl.debug_print("warp: {}", warp_id)

    with self.capture_stdout() as output:
      jax.block_until_ready(kernel())

    printed_lines = set(output().splitlines())
    expected_lines = {"warp: 0", "warp: 1", "warp: 2", "warp: 3"}
    self.assertEqual(printed_lines, expected_lines)
1880
+
1881
+
1828
1882
class PallasCallWGTest (
1829
1883
PallasCallTest , lowering_semantics = plgpu .LoweringSemantics .Warpgroup
1830
1884
):
0 commit comments