@@ -7,60 +7,137 @@ function error_exit() {
77 exit 1
88}
99
# Defaults for the TransformerEngine checkout and the JUnit XML output dir.
: "${TE_PATH:=/opt/transformerengine}"
: "${XML_LOG_DIR:=/logs}"
mkdir -p "$XML_LOG_DIR"

set -x

pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"

# ── Parallel test infrastructure ────────────────────────────────────────────
# Detect GPUs and run tests in parallel waves (one test per GPU per wave).
# With 1 GPU, behavior is identical to sequential execution.

# Failures are recorded in a file rather than a shell variable so that
# background subshells spawned by run_test can report back to the parent.
FAIL_DIR=$(mktemp -d)
#######################################
# Record a sub-test failure without aborting the run.
# Globals:   FAIL_DIR (read) — failures are appended to $FAIL_DIR/failures,
#            which survives background-subshell boundaries.
# Arguments: $1 - label of the failed sub-test
# Outputs:   error line to stdout
#######################################
function test_fail() {
  echo "$1" >> "$FAIL_DIR/failures"
  echo "Error: sub-test failed: $1"
}
1528
# Detect available GPUs.
# Priority: CUDA_VISIBLE_DEVICES (honor an explicit restriction), else ask
# nvidia-smi. Falls back to a single GPU "0" when neither yields anything,
# so the script still works on machines without the NVIDIA tooling.
if [ -n "${CUDA_VISIBLE_DEVICES:-}" ]; then
  IFS=',' read -ra GPU_LIST <<< "$CUDA_VISIBLE_DEVICES"
  NUM_GPUS=${#GPU_LIST[@]}
else
  NUM_GPUS=$(nvidia-smi -L 2>/dev/null | wc -l)
  NUM_GPUS=${NUM_GPUS:-1}
  GPU_LIST=()
  for ((i = 0; i < NUM_GPUS; i++)); do GPU_LIST+=("$i"); done
fi
if [ "$NUM_GPUS" -lt 1 ]; then
  NUM_GPUS=1
  GPU_LIST=(0)
fi
echo "Detected $NUM_GPUS GPU(s): ${GPU_LIST[*]}"

# Round-robin GPU assignment counter and size of the current dispatch wave.
GPU_COUNTER=0
WAVE_COUNT=0
#######################################
# Dispatch one pytest invocation, round-robin across detected GPUs.
# Globals:   NUM_GPUS, GPU_LIST, GPU_COUNTER, WAVE_COUNT, XML_LOG_DIR (read/written)
# Arguments: $1 - env-var prefix prepended to the command (may be empty)
#            $2 - JUnit XML file name (written under $XML_LOG_DIR)
#            $3 - pytest target path
#            $4 - label passed to test_fail on failure
# Notes:     eval is required so the env prefix ("A=1 B=2") applies to the
#            pytest process; arguments are trusted (script-local literals).
#######################################
function run_test() {
  local env_prefix="$1"
  local xml_name="$2"
  local test_path="$3"
  local fail_label="$4"
  local gpu_id=$((GPU_COUNTER % NUM_GPUS))
  GPU_COUNTER=$((GPU_COUNTER + 1))
  WAVE_COUNT=$((WAVE_COUNT + 1))

  if [ "$NUM_GPUS" -le 1 ]; then
    # Single GPU: run synchronously (identical to original sequential behavior)
    eval "${env_prefix} python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/${xml_name} ${test_path}" \
      || test_fail "$fail_label"
  else
    # Multi GPU: run in background on the assigned GPU, capture output per-test
    # so interleaved logs stay readable. test_fail appends to a file, so the
    # failure record survives this subshell.
    (
      eval "CUDA_VISIBLE_DEVICES=${GPU_LIST[$gpu_id]} ${env_prefix} python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/${xml_name} ${test_path}" \
        > "$XML_LOG_DIR/${xml_name%.xml}.log" 2>&1 \
        || test_fail "$fail_label"
    ) &
  fi

  # When every GPU has a job, wait for the whole wave to complete.
  if [ "$WAVE_COUNT" -ge "$NUM_GPUS" ] && [ "$NUM_GPUS" -gt 1 ]; then
    wait
    WAVE_COUNT=0
  fi
}
76+
# ── Checkpoint pre-step (must run before test_checkpoint.py) ────────────────
# Generates checkpoint artifacts once; skipped if they already exist.

export NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint
if [ ! -d "$NVTE_TEST_CHECKPOINT_ARTIFACT_PATH" ]; then
  python3 $TE_PATH/tests/pytorch/test_checkpoint.py --save-checkpoint all \
    || error_exit "Failed to generate checkpoint files"
fi
84+
# ── Tests ───────────────────────────────────────────────────────────────────
# Each run_test call: env_prefix, xml_name, test_path, fail_label
# Tests are dispatched in waves of NUM_GPUS, one per GPU.

run_test "" "pytest_test_sanity.xml" "$TE_PATH/tests/pytorch/test_sanity.py" "test_sanity.py"
run_test "" "pytest_test_recipe.xml" "$TE_PATH/tests/pytorch/test_recipe.py" "test_recipe.py"
run_test "" "pytest_test_deferred_init.xml" "$TE_PATH/tests/pytorch/test_deferred_init.py" "test_deferred_init.py"
run_test "PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0" "pytest_test_numerics.xml" "$TE_PATH/tests/pytorch/test_numerics.py" "test_numerics.py"
run_test "PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0" "pytest_test_cuda_graphs.xml" "$TE_PATH/tests/pytorch/test_cuda_graphs.py" "test_cuda_graphs.py"
run_test "" "pytest_test_jit.xml" "$TE_PATH/tests/pytorch/test_jit.py" "test_jit.py"
run_test "" "pytest_test_fused_rope.xml" "$TE_PATH/tests/pytorch/test_fused_rope.py" "test_fused_rope.py"
run_test "" "pytest_test_nvfp4.xml" "$TE_PATH/tests/pytorch/nvfp4" "test_nvfp4"
run_test "" "pytest_test_mxfp8.xml" "$TE_PATH/tests/pytorch/mxfp8" "test_mxfp8"
run_test "" "pytest_test_quantized_tensor.xml" "$TE_PATH/tests/pytorch/test_quantized_tensor.py" "test_quantized_tensor.py"
run_test "" "pytest_test_float8blockwisetensor.xml" "$TE_PATH/tests/pytorch/test_float8blockwisetensor.py" "test_float8blockwisetensor.py"
run_test "" "pytest_test_float8_blockwise_scaling_exact.xml" "$TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py" "test_float8_blockwise_scaling_exact.py"
run_test "" "pytest_test_float8_blockwise_gemm_exact.xml" "$TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py" "test_float8_blockwise_gemm_exact.py"
run_test "" "test_grouped_tensor.xml" "$TE_PATH/tests/pytorch/test_grouped_tensor.py" "test_grouped_tensor.py"
run_test "" "pytest_test_gqa.xml" "$TE_PATH/tests/pytorch/test_gqa.py" "test_gqa.py"
run_test "" "pytest_test_fused_optimizer.xml" "$TE_PATH/tests/pytorch/test_fused_optimizer.py" "test_fused_optimizer.py"
run_test "" "pytest_test_multi_tensor.xml" "$TE_PATH/tests/pytorch/test_multi_tensor.py" "test_multi_tensor.py"
run_test "" "pytest_test_fusible_ops.xml" "$TE_PATH/tests/pytorch/test_fusible_ops.py" "test_fusible_ops.py"
run_test "" "pytest_test_permutation.xml" "$TE_PATH/tests/pytorch/test_permutation.py" "test_permutation.py"
run_test "" "pytest_test_parallel_cross_entropy.xml" "$TE_PATH/tests/pytorch/test_parallel_cross_entropy.py" "test_parallel_cross_entropy.py"
run_test "" "pytest_test_cpu_offloading.xml" "$TE_PATH/tests/pytorch/test_cpu_offloading.py" "test_cpu_offloading.py"
run_test "NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1" "pytest_test_cpu_offloading_v1.xml" "$TE_PATH/tests/pytorch/test_cpu_offloading_v1.py" "test_cpu_offloading_v1.py"
run_test "" "pytest_test_attention.xml" "$TE_PATH/tests/pytorch/attention/test_attention.py" "test_attention.py"
run_test "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0" "pytest_test_attention_deterministic.xml" "$TE_PATH/tests/pytorch/attention/test_attention.py" "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention.py"
run_test "" "pytest_test_kv_cache.xml" "$TE_PATH/tests/pytorch/attention/test_kv_cache.py" "test_kv_cache.py"
run_test "" "pytest_test_hf_integration.xml" "$TE_PATH/tests/pytorch/test_hf_integration.py" "test_hf_integration.py"
run_test "" "pytest_test_checkpoint.xml" "$TE_PATH/tests/pytorch/test_checkpoint.py" "test_checkpoint.py"
run_test "" "pytest_test_fused_router.xml" "$TE_PATH/tests/pytorch/test_fused_router.py" "test_fused_router.py"
run_test "" "pytest_test_partial_cast.xml" "$TE_PATH/tests/pytorch/test_partial_cast.py" "test_partial_cast.py"
118+
# ── Wait for remaining background jobs ──────────────────────────────────────
# The last wave may be partially filled, so a final barrier is needed.

if [ "$NUM_GPUS" -gt 1 ]; then
  wait
fi

# ── Display per-test logs from parallel runs ────────────────────────────────
# In multi-GPU mode each test's output was redirected to its own .log file;
# replay them here so the CI console shows everything.

if [ "$NUM_GPUS" -gt 1 ]; then
  for logfile in "$XML_LOG_DIR"/*.log; do
    [ -f "$logfile" ] && echo "=== $(basename "$logfile") ===" && cat "$logfile"
  done
fi
132+
# ── Report results ──────────────────────────────────────────────────────────
# The failures file is non-empty iff at least one sub-test called test_fail.

if [ -s "$FAIL_DIR/failures" ]; then
  # Join the one-per-line failure labels into a single space-separated list.
  FAILED_CASES=$(tr '\n' ' ' < "$FAIL_DIR/failures")
  echo "Error in the following test cases: $FAILED_CASES"
  rm -rf "$FAIL_DIR"
  exit 1
fi
rm -rf "$FAIL_DIR"
echo "All tests passed"
exit 0