Skip to content

Commit ab2a107

Browse files
Run L0 pytorch tests in parallel across multiple GPUs
Detect available GPUs and dispatch pytest invocations in waves, one test per GPU per wave. On single-GPU machines, behavior is identical to the original sequential execution.

Design:
- GPU detection from CUDA_VISIBLE_DEVICES or nvidia-smi
- Wave-based round-robin: launch NUM_GPUS background jobs, wait, repeat
- File-based error tracking (shell variables don't propagate out of subshells)
- Per-test log files in multi-GPU mode to prevent stdout interleaving
- Checkpoint pre-step runs synchronously before the parallel section

With 30 tests on 8 GPUs (B200), expect ~4 waves instead of 30 sequential runs — roughly a 4-8x speedup, depending on the spread of test durations.
1 parent 15cf65a commit ab2a107

File tree

1 file changed

+118
-41
lines changed

1 file changed

+118
-41
lines changed

qa/L0_pytorch_unittest/test.sh

Lines changed: 118 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -7,60 +7,137 @@ function error_exit() {
77
exit 1
88
}
99

10+
# Default locations; both may be overridden from the caller's environment.
: "${TE_PATH:=/opt/transformerengine}"
: "${XML_LOG_DIR:=/logs}"
mkdir -p "$XML_LOG_DIR"

# Trace every command from here on — standard CI debugging aid.
set -x

pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
17+
18+
# ── Parallel test infrastructure ────────────────────────────────────────────
19+
# Detect GPUs and run tests in parallel waves (one test per GPU per wave).
20+
# With 1 GPU, behavior is identical to sequential execution.
21+
22+
# Failures are recorded in a file rather than a shell variable: in
# multi-GPU mode each pytest invocation runs inside a backgrounded
# subshell, and variable assignments made there never reach the parent.
FAIL_DIR=$(mktemp -d)

# test_fail <label> — record a failed sub-test and announce it on stdout.
function test_fail() {
    local label=$1
    printf '%s\n' "$label" >> "$FAIL_DIR/failures"
    printf 'Error: sub-test failed: %s\n' "$label"
}
1528

16-
RET=0
17-
FAILED_CASES=""
29+
# ── Detect available GPUs ───────────────────────────────────────────────────
# Honor an explicit CUDA_VISIBLE_DEVICES mask when present; otherwise count
# devices with `nvidia-smi -L`. If detection yields nothing (no driver, no
# devices), degrade gracefully to a single GPU so the script still behaves
# like the original sequential run.
if [ -n "${CUDA_VISIBLE_DEVICES:-}" ]; then
    IFS=',' read -ra GPU_LIST <<< "$CUDA_VISIBLE_DEVICES"
    NUM_GPUS=${#GPU_LIST[@]}
else
    GPU_LIST=()
    NUM_GPUS=$(nvidia-smi -L 2>/dev/null | wc -l)
    NUM_GPUS=${NUM_GPUS:-1}
    for ((idx = 0; idx < NUM_GPUS; idx++)); do
        GPU_LIST+=("$idx")
    done
fi
if [ "$NUM_GPUS" -lt 1 ]; then
    # nvidia-smi absent or reported zero devices.
    NUM_GPUS=1
    GPU_LIST=(0)
fi
echo "Detected $NUM_GPUS GPU(s): ${GPU_LIST[*]}"
1844

19-
set -x
45+
# Round-robin dispatch state.
GPU_COUNTER=0   # total tests dispatched so far (selects the GPU slot)
WAVE_COUNT=0    # tests launched in the current wave

# run_test <env_prefix> <xml_name> <test_path> <fail_label>
#   env_prefix : KEY=VAL assignments prepended to the command (may be empty)
#   xml_name   : junit XML file name written under $XML_LOG_DIR
#   test_path  : pytest target (file or directory)
#   fail_label : label recorded via test_fail when pytest fails
#
# Single GPU: runs synchronously with output on stdout (identical to the
# original sequential behavior). Multi GPU: runs in the background pinned
# to one GPU via CUDA_VISIBLE_DEVICES, with output captured to a per-test
# log file; after every NUM_GPUS launches we `wait` for the wave to drain.
function run_test() {
    local env_prefix=$1 xml_name=$2 test_path=$3 fail_label=$4
    local gpu_idx
    gpu_idx=$((GPU_COUNTER % NUM_GPUS))
    GPU_COUNTER=$((GPU_COUNTER + 1))
    WAVE_COUNT=$((WAVE_COUNT + 1))

    if [ "$NUM_GPUS" -le 1 ]; then
        eval "${env_prefix} python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/${xml_name} ${test_path}" \
            || test_fail "$fail_label"
    else
        # Subshell so redirection and failure bookkeeping stay per-test;
        # test_fail works from here because it appends to a file.
        (
            eval "CUDA_VISIBLE_DEVICES=${GPU_LIST[$gpu_idx]} ${env_prefix} python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/${xml_name} ${test_path}" \
                > "$XML_LOG_DIR/${xml_name%.xml}.log" 2>&1 \
                || test_fail "$fail_label"
        ) &
    fi

    # Once every GPU has a job, barrier until the whole wave completes.
    if [ "$NUM_GPUS" -gt 1 ] && [ "$WAVE_COUNT" -ge "$NUM_GPUS" ]; then
        wait
        WAVE_COUNT=0
    fi
}
76+
77+
# ── Checkpoint pre-step (must run before test_checkpoint.py) ────────────────
# test_checkpoint.py consumes pre-generated checkpoint artifacts. Generate
# them once, synchronously, before any (possibly parallel) test dispatch so
# a background test never races the artifact creation.
export NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint
if [ ! -d "$NVTE_TEST_CHECKPOINT_ARTIFACT_PATH" ]; then
    python3 $TE_PATH/tests/pytorch/test_checkpoint.py --save-checkpoint all \
        || error_exit "Failed to generate checkpoint files"
fi
84+
85+
# ── Tests ───────────────────────────────────────────────────────────────────
# Each run_test call: env_prefix, xml_name, test_path, fail_label.
# Tests are dispatched in waves of NUM_GPUS, one per GPU.

# Core/sanity
run_test "" "pytest_test_sanity.xml" "$TE_PATH/tests/pytorch/test_sanity.py" "test_sanity.py"
run_test "" "pytest_test_recipe.xml" "$TE_PATH/tests/pytorch/test_recipe.py" "test_recipe.py"
run_test "" "pytest_test_deferred_init.xml" "$TE_PATH/tests/pytorch/test_deferred_init.py" "test_deferred_init.py"
run_test "PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0" "pytest_test_numerics.xml" "$TE_PATH/tests/pytorch/test_numerics.py" "test_numerics.py"
run_test "PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0" "pytest_test_cuda_graphs.xml" "$TE_PATH/tests/pytorch/test_cuda_graphs.py" "test_cuda_graphs.py"
run_test "" "pytest_test_jit.xml" "$TE_PATH/tests/pytorch/test_jit.py" "test_jit.py"
run_test "" "pytest_test_fused_rope.xml" "$TE_PATH/tests/pytorch/test_fused_rope.py" "test_fused_rope.py"

# Quantization / tensor formats
run_test "" "pytest_test_nvfp4.xml" "$TE_PATH/tests/pytorch/nvfp4" "test_nvfp4"
run_test "" "pytest_test_mxfp8.xml" "$TE_PATH/tests/pytorch/mxfp8" "test_mxfp8"
run_test "" "pytest_test_quantized_tensor.xml" "$TE_PATH/tests/pytorch/test_quantized_tensor.py" "test_quantized_tensor.py"
run_test "" "pytest_test_float8blockwisetensor.xml" "$TE_PATH/tests/pytorch/test_float8blockwisetensor.py" "test_float8blockwisetensor.py"
run_test "" "pytest_test_float8_blockwise_scaling_exact.xml" "$TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py" "test_float8_blockwise_scaling_exact.py"
run_test "" "pytest_test_float8_blockwise_gemm_exact.xml" "$TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py" "test_float8_blockwise_gemm_exact.py"
# NOTE(review): xml name lacks the usual "pytest_" prefix; kept as-is because
# CI dashboards may already key on the historical file name.
run_test "" "test_grouped_tensor.xml" "$TE_PATH/tests/pytorch/test_grouped_tensor.py" "test_grouped_tensor.py"

# Ops / optimizers
run_test "" "pytest_test_gqa.xml" "$TE_PATH/tests/pytorch/test_gqa.py" "test_gqa.py"
run_test "" "pytest_test_fused_optimizer.xml" "$TE_PATH/tests/pytorch/test_fused_optimizer.py" "test_fused_optimizer.py"
run_test "" "pytest_test_multi_tensor.xml" "$TE_PATH/tests/pytorch/test_multi_tensor.py" "test_multi_tensor.py"
run_test "" "pytest_test_fusible_ops.xml" "$TE_PATH/tests/pytorch/test_fusible_ops.py" "test_fusible_ops.py"
run_test "" "pytest_test_permutation.xml" "$TE_PATH/tests/pytorch/test_permutation.py" "test_permutation.py"
run_test "" "pytest_test_parallel_cross_entropy.xml" "$TE_PATH/tests/pytorch/test_parallel_cross_entropy.py" "test_parallel_cross_entropy.py"

# Offloading / attention / integration
run_test "" "pytest_test_cpu_offloading.xml" "$TE_PATH/tests/pytorch/test_cpu_offloading.py" "test_cpu_offloading.py"
run_test "NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1" "pytest_test_cpu_offloading_v1.xml" "$TE_PATH/tests/pytorch/test_cpu_offloading_v1.py" "test_cpu_offloading_v1.py"
run_test "" "pytest_test_attention.xml" "$TE_PATH/tests/pytorch/attention/test_attention.py" "test_attention.py"
run_test "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0" "pytest_test_attention_deterministic.xml" "$TE_PATH/tests/pytorch/attention/test_attention.py" "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention.py"
run_test "" "pytest_test_kv_cache.xml" "$TE_PATH/tests/pytorch/attention/test_kv_cache.py" "test_kv_cache.py"
run_test "" "pytest_test_hf_integration.xml" "$TE_PATH/tests/pytorch/test_hf_integration.py" "test_hf_integration.py"
run_test "" "pytest_test_checkpoint.xml" "$TE_PATH/tests/pytorch/test_checkpoint.py" "test_checkpoint.py"
run_test "" "pytest_test_fused_router.xml" "$TE_PATH/tests/pytorch/test_fused_router.py" "test_fused_router.py"
run_test "" "pytest_test_partial_cast.xml" "$TE_PATH/tests/pytorch/test_partial_cast.py" "test_partial_cast.py"

# ── Wait for remaining background jobs ──────────────────────────────────────
# The final wave may be partial (fewer tests than GPUs), so run_test never
# issued its own barrier for it.
if [ "$NUM_GPUS" -gt 1 ]; then
    wait
fi
57-
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
58-
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py"
59-
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_partial_cast.xml $TE_PATH/tests/pytorch/test_partial_cast.py || test_fail "test_partial_cast.py"
60124

61-
if [ "$RET" -ne 0 ]; then
62-
echo "Error in the following test cases:$FAILED_CASES"
125+
# ── Display per-test logs from parallel runs ────────────────────────────────
# In multi-GPU mode pytest output was captured to per-test files to avoid
# interleaving; replay them here so CI logs still show every test's output.
if [ "$NUM_GPUS" -gt 1 ]; then
    for logfile in "$XML_LOG_DIR"/*.log; do
        # -f guards the unmatched-glob case (literal pattern remains).
        [ -f "$logfile" ] && echo "=== $(basename "$logfile") ===" && cat "$logfile"
    done
fi

# ── Report results ──────────────────────────────────────────────────────────
# The failure file is non-empty iff test_fail recorded at least one label
# (works across subshell boundaries, unlike a shell variable).
if [ -s "$FAIL_DIR/failures" ]; then
    # Fix: read the file directly instead of the useless `cat | tr`.
    FAILED_CASES=$(tr '\n' ' ' < "$FAIL_DIR/failures")
    echo "Error in the following test cases: $FAILED_CASES"
    rm -rf "$FAIL_DIR"
    exit 1
fi
rm -rf "$FAIL_DIR"
echo "All tests passed"
exit 0

0 commit comments

Comments
 (0)