@@ -10,12 +10,13 @@ def torch_add_norm_quant_bf16_fp8(X, R, W, eps=1e-6):
10
10
# 1. Add residual
11
11
X = X .add_ (R )
12
12
# 2. rmsnorm
13
- normalized = torch .nn .functional .rms_norm (X , (N , ), W , eps = eps )
13
+ normalized = torch .nn .functional .rms_norm (X , (N ,), W , eps = eps )
14
14
# 3. per token quant
15
15
quantized , scales = ops .scaled_fp8_quant (normalized , scale = None , use_per_token_if_dynamic = True )
16
16
17
17
return quantized , scales
18
18
19
+
19
20
class TestFusedAddNormQuantBF16 (unittest .TestCase ):
20
21
def setUp (self ):
21
22
"""Set up common test parameters."""
@@ -31,40 +32,65 @@ def test_accuracy(self):
31
32
for batch in self .batchs :
32
33
for seqLen in self .seqLens :
33
34
for embed_dim in self .embed_dims :
34
- with self .subTest (shape = [batch , seqLen , embed_dim ]):
35
- X1 = torch .rand (size = [batch , seqLen , embed_dim ], device = self .device , dtype = self .dtype ) - 0.5
36
- X2 = X1 .clone ()
37
- R1 = torch .rand (size = [batch , seqLen , embed_dim ], device = self .device , dtype = self .dtype ) - 0.5
38
- R2 = R1 .clone ()
39
- W = torch .rand (size = [embed_dim ], device = self .device , dtype = self .dtype ) - 0.5
40
- output_real , scales_real = torch_add_norm_quant_bf16_fp8 (X1 .reshape (- 1 , X1 .shape [2 ]), R1 .reshape (- 1 , R1 .shape [2 ]), W , self .eps )
41
- output_pred , scales_pred = add_norm_quant_bf16_fp8 (X2 .reshape (- 1 , X1 .shape [2 ]), R2 .reshape (- 1 , R2 .shape [2 ]), W , self .eps )
35
+ with self .subTest (shape = [batch , seqLen , embed_dim ]):
36
+ X1 = torch .rand (size = [batch , seqLen , embed_dim ], device = self .device , dtype = self .dtype ) - 0.5
37
+ X2 = X1 .clone ()
38
+ R1 = torch .rand (size = [batch , seqLen , embed_dim ], device = self .device , dtype = self .dtype ) - 0.5
39
+ R2 = R1 .clone ()
40
+ W = torch .rand (size = [embed_dim ], device = self .device , dtype = self .dtype ) - 0.5
41
+ output_real , scales_real = torch_add_norm_quant_bf16_fp8 (
42
+ X1 .reshape (- 1 , X1 .shape [2 ]), R1 .reshape (- 1 , R1 .shape [2 ]), W , self .eps
43
+ )
44
+ output_pred , scales_pred = add_norm_quant_bf16_fp8 (
45
+ X2 .reshape (- 1 , X1 .shape [2 ]), R2 .reshape (- 1 , R2 .shape [2 ]), W , self .eps
46
+ )
42
47
43
- self .assertTrue (
44
- error (output_real , output_pred ) < 0.01 ,
45
- f"Accuracy test failed for size { batch } , { seqLen } , { embed_dim } . output_real={ output_real } , output_pred={ output_pred } "
46
- )
47
- self .assertTrue (
48
- error (scales_real , scales_pred ) < 0.01 ,
49
- f"Accuracy test failed for size { batch } , { seqLen } , { embed_dim } . scales_real={ scales_real } , scales_pred={ scales_pred } "
50
- )
48
+ self .assertTrue (
49
+ error (output_real , output_pred ) < 0.01 ,
50
+ f"Accuracy test failed for size { batch } , { seqLen } , { embed_dim } . "
51
+ f"output_real={ output_real } , output_pred={ output_pred } " ,
52
+ )
53
+ self .assertTrue (
54
+ error (scales_real , scales_pred ) < 0.01 ,
55
+ f"Accuracy test failed for size { batch } , { seqLen } , { embed_dim } . "
56
+ f"scales_real={ scales_real } , scales_pred={ scales_pred } " ,
57
+ )
51
58
52
59
def test_performance (self ):
53
60
"""Test the performance of FusedAddNormQuant using benchmark."""
54
61
for batch in self .batchs :
55
62
for seqLen in self .seqLens :
56
63
for embed_dim in self .embed_dims :
57
- with self .subTest (shape = [batch , seqLen , embed_dim ]):
58
- X1 = torch .rand (size = [batch , seqLen , embed_dim ], device = self .device , dtype = self .dtype ) - 0.5
59
- X2 = torch .rand (size = [batch , seqLen , embed_dim ], device = self .device , dtype = self .dtype ) - 0.5
60
- R1 = torch .rand (size = [batch , seqLen , embed_dim ], device = self .device , dtype = self .dtype ) - 0.5
61
- R2 = R1 .clone ()
62
- W = torch .rand (size = [embed_dim ], device = self .device , dtype = self .dtype ) - 0.5
64
+ with self .subTest (shape = [batch , seqLen , embed_dim ]):
65
+ X1 = torch .rand (size = [batch , seqLen , embed_dim ], device = self .device , dtype = self .dtype ) - 0.5
66
+ X2 = torch .rand (size = [batch , seqLen , embed_dim ], device = self .device , dtype = self .dtype ) - 0.5
67
+ R1 = torch .rand (size = [batch , seqLen , embed_dim ], device = self .device , dtype = self .dtype ) - 0.5
68
+ R2 = R1 .clone ()
69
+ W = torch .rand (size = [embed_dim ], device = self .device , dtype = self .dtype ) - 0.5
70
+
71
+ shape = [[batch , seqLen , embed_dim ]]
72
+ tflops = 0.0
73
+ benchmark (
74
+ torch_add_norm_quant_bf16_fp8 ,
75
+ shape ,
76
+ tflops ,
77
+ 100 ,
78
+ X1 .reshape (- 1 , X1 .shape [2 ]),
79
+ R1 .reshape (- 1 , R1 .shape [2 ]),
80
+ W ,
81
+ self .eps ,
82
+ )
83
+ benchmark (
84
+ add_norm_quant_bf16_fp8 ,
85
+ shape ,
86
+ tflops ,
87
+ 100 ,
88
+ X2 .reshape (- 1 , X1 .shape [2 ]),
89
+ R2 .reshape (- 1 , R2 .shape [2 ]),
90
+ W ,
91
+ self .eps ,
92
+ )
63
93
64
- shape = [[batch , seqLen , embed_dim ]]
65
- tflops = 0.0
66
- benchmark (torch_add_norm_quant_bf16_fp8 , shape , tflops , 100 , X1 .reshape (- 1 , X1 .shape [2 ]), R1 .reshape (- 1 , R1 .shape [2 ]), W , self .eps )
67
- benchmark (add_norm_quant_bf16_fp8 , shape , tflops , 100 , X2 .reshape (- 1 , X1 .shape [2 ]), R2 .reshape (- 1 , R2 .shape [2 ]), W , self .eps )
68
94
69
95
if __name__ == "__main__":
    # Discover and run every TestCase defined in this module.
    unittest.main()