99from typing import Optional
1010
1111import deep_ep
12- from utils import init_dist , bench , bench_kineto , calc_diff , hash_tensor , cast_fp8_to_fp32 , cast_nvfp4_to_fp32
12+ from utils import init_dist , bench , bench_kineto , calc_diff , hash_tensor , per_token_cast_back
1313
1414MAX_E4M3 = 448
1515MAX_NVFP4 = 6.0
@@ -54,7 +54,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
5454 for dispatch_data_type in ('bf16' , 'fp8' , 'nvfp4' ):
5555 dispatch_use_fp8 = dispatch_data_type == 'fp8'
5656 dispatch_use_nvfp4 = dispatch_data_type == 'nvfp4'
57- use_ue8m0_for_nvfp4_sf = False
57+ use_ue8m0_for_sf = False
5858 for round_scale in (False , True ) if dispatch_use_fp8 else (False , ):
5959 for use_ue8m0 in (False , True ) if round_scale else (False , ):
6060 num_times += 1
@@ -66,20 +66,20 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
6666 packed_recv_x , packed_recv_count , handle , event , hook = \
6767 buffer .low_latency_dispatch (current_x , topk_idx , num_tokens , num_experts ,
6868 use_fp8 = dispatch_use_fp8 , round_scale = round_scale , use_ue8m0 = use_ue8m0 ,
69- use_nvfp4 = dispatch_use_nvfp4 , use_ue8m0_for_nvfp4_sf = use_ue8m0_for_nvfp4_sf ,
69+ use_nvfp4 = dispatch_use_nvfp4 , use_ue8m0_for_sf = use_ue8m0_for_sf ,
7070 cumulative_local_expert_recv_stats = cumulative_local_expert_recv_stats ,
7171 x_sf_scale = x_sf_scale ,
7272 async_finish = not return_recv_hook , return_recv_hook = return_recv_hook )
7373 hook () if return_recv_hook else event .current_stream_wait ()
7474 if dispatch_use_fp8 :
7575 packed_recv_x = (packed_recv_x [0 ], packed_recv_x [1 ].contiguous ())
76- simulated_gemm_x = cast_fp8_to_fp32 (packed_recv_x [0 ].view (- 1 , hidden ), packed_recv_x [1 ].view (- 1 , hidden // 128 )).view (packed_recv_x [0 ].shape )
76+ simulated_gemm_x = per_token_cast_back (packed_recv_x [0 ].view (- 1 , hidden ), packed_recv_x [1 ].view (- 1 , hidden // 128 )).view (packed_recv_x [0 ].shape )
7777 elif dispatch_use_nvfp4 :
7878 recv_x_scale_view = packed_recv_x [1 ]
7979 recv_x_scale_view = recv_x_scale_view .permute (5 , 2 , 0 , 1 , 4 , 3 )
8080 recv_x_scale_view = recv_x_scale_view .contiguous ().view (num_local_experts , int (num_ranks * num_tokens ), hidden // 16 )
8181 packed_recv_x = (packed_recv_x [0 ], recv_x_scale_view )
82- simulated_gemm_x = cast_nvfp4_to_fp32 (packed_recv_x [0 ], packed_recv_x [1 ], x_sf_scale , use_ue8m0_for_nvfp4_sf = use_ue8m0_for_nvfp4_sf )
82+ simulated_gemm_x = per_token_cast_back (packed_recv_x [0 ], packed_recv_x [1 ], x_sf_scale , use_ue8m0_for_sf = use_ue8m0_for_sf , src_data_format = 'nvfp4' )
8383 else :
8484 packed_recv_x = packed_recv_x
8585 simulated_gemm_x = packed_recv_x .clone ()
0 commit comments