
Commit 561a323

Consolidate the distributed fused attention tests to use shared input generation and execution logic.
Signed-off-by: Michael Goldfarb <[email protected]>
1 parent: 2402406

3 files changed: 323 additions, 392 deletions


tests/jax/distributed_test_base.py

Lines changed: 16 additions & 13 deletions
@@ -18,36 +18,39 @@
 def generate_configs():
     configs = []
     if is_devices_enough(2):
-        configs.append([2, (2,), "dp", MeshResource(dp_resource="dp")])
-        configs.append([2, (2,), "tp", MeshResource(tp_resource="tp")])
+        configs.append(
+            pytest.param(2, (2,), ("dp",), MeshResource(dp_resource="dp"), id="n2_dp2_tp1")
+        )
+        configs.append(
+            pytest.param(2, (2,), ("tp",), MeshResource(tp_resource="tp"), id="n2_dp1_tp2")
+        )
 
     if is_devices_enough(4):
-        TP_size = 2
-        DP_size = 2
         configs.append(
-            [4, (DP_size, TP_size), ("dp", "tp"), MeshResource(dp_resource="dp", tp_resource="tp")]
+            pytest.param(
+                4,
+                (2, 2),
+                ("dp", "tp"),
+                MeshResource(dp_resource="dp", tp_resource="tp"),
+                id=f"n4_dp2_tp2",
+            )
         )
 
     return configs
 
 
 def generate_context_parallel_configs():
     configs = []
-
+    mr = MeshResource(dp_resource="dp", cp_resource="cp", tp_resource="tp")
+    axes = ("dp", "cp", "tp")
     DP_sizes = (1, 2)
     CP_sizes = (1, 2, 4, 8)
     TP_sizes = (1, 2)
     for dp, cp, tp in product(DP_sizes, CP_sizes, TP_sizes):
         ndev = cp * tp * dp
         if is_devices_enough(ndev):
             configs.append(
-                pytest.param(
-                    ndev,
-                    (dp, cp, tp),
-                    ("dp", "cp", "tp"),
-                    MeshResource(dp_resource="dp", cp_resource="cp", tp_resource="tp"),
-                    id=f"n{ndev}_dp{dp}_cp{cp}_tp{tp}",
-                )
+                pytest.param(ndev, (dp, cp, tp), axes, mr, id=f"n{ndev}_dp{dp}_cp{cp}_tp{tp}")
             )
 
     return configs
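
The switch to pytest.param entries means each mesh layout surfaces as a readable test id (e.g. "n4_dp2_tp2") when the consolidated tests parametrize over these generators. Below is a minimal, standalone sketch of that consumption pattern; the stubbed MeshResource, the simplified generate_configs, and the test name are illustrative stand-ins, not code from this commit (the real definitions live in tests/jax/distributed_test_base.py and the distributed fused attention test modules).

import math
from dataclasses import dataclass

import pytest


@dataclass
class MeshResource:
    # Stand-in for transformer_engine's MeshResource, so this sketch runs on its own.
    dp_resource: str = None
    tp_resource: str = None


def generate_configs():
    # Simplified stand-in for the generator in the diff above: each pytest.param
    # packs (num_devices, mesh_shape, mesh_axes, mesh_resource) plus an id.
    return [
        pytest.param(2, (2,), ("dp",), MeshResource(dp_resource="dp"), id="n2_dp2_tp1"),
        pytest.param(2, (2,), ("tp",), MeshResource(tp_resource="tp"), id="n2_dp1_tp2"),
    ]


@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
def test_mesh_config_ids(device_count, mesh_shape, mesh_axes, mesh_resource):
    # The id= on each pytest.param becomes part of the collected test name,
    # e.g. test_mesh_config_ids[n2_dp2_tp1].
    assert device_count == math.prod(mesh_shape)

Keeping the id string in one place alongside the config is what lets both the fused attention tests and any other distributed test share the same readable parametrization without duplicating the mesh setup.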
