Commit 38d3b47

0609

1 parent 6a4a100 commit 38d3b47
4 files changed: +77 additions, -41 deletions


lightllm-kernel/test/fusion/add_norm_quant_test.py

Lines changed: 54 additions & 28 deletions
@@ -10,12 +10,13 @@ def torch_add_norm_quant_bf16_fp8(X, R, W, eps=1e-6):
     # 1. Add residual
     X = X.add_(R)
     # 2. rmsnorm
-    normalized = torch.nn.functional.rms_norm(X, (N, ), W, eps=eps)
+    normalized = torch.nn.functional.rms_norm(X, (N,), W, eps=eps)
     # 3. per token quant
     quantized, scales = ops.scaled_fp8_quant(normalized, scale=None, use_per_token_if_dynamic=True)

     return quantized, scales

+
 class TestFusedAddNormQuantBF16(unittest.TestCase):
     def setUp(self):
         """Set up common test parameters."""
@@ -31,40 +32,65 @@ def test_accuracy(self):
         for batch in self.batchs:
             for seqLen in self.seqLens:
                 for embed_dim in self.embed_dims:
-                    with self.subTest(shape=[batch, seqLen, embed_dim]):
-                        X1 = torch.rand(size=[batch, seqLen, embed_dim], device=self.device, dtype=self.dtype) - 0.5
-                        X2 = X1.clone()
-                        R1 = torch.rand(size=[batch, seqLen, embed_dim], device=self.device, dtype=self.dtype) - 0.5
-                        R2 = R1.clone()
-                        W = torch.rand(size=[embed_dim], device=self.device, dtype=self.dtype) - 0.5
-                        output_real, scales_real = torch_add_norm_quant_bf16_fp8(X1.reshape(-1, X1.shape[2]), R1.reshape(-1, R1.shape[2]), W, self.eps)
-                        output_pred, scales_pred = add_norm_quant_bf16_fp8(X2.reshape(-1, X1.shape[2]), R2.reshape(-1, R2.shape[2]), W, self.eps)
+                    with self.subTest(shape=[batch, seqLen, embed_dim]):
+                        X1 = torch.rand(size=[batch, seqLen, embed_dim], device=self.device, dtype=self.dtype) - 0.5
+                        X2 = X1.clone()
+                        R1 = torch.rand(size=[batch, seqLen, embed_dim], device=self.device, dtype=self.dtype) - 0.5
+                        R2 = R1.clone()
+                        W = torch.rand(size=[embed_dim], device=self.device, dtype=self.dtype) - 0.5
+                        output_real, scales_real = torch_add_norm_quant_bf16_fp8(
+                            X1.reshape(-1, X1.shape[2]), R1.reshape(-1, R1.shape[2]), W, self.eps
+                        )
+                        output_pred, scales_pred = add_norm_quant_bf16_fp8(
+                            X2.reshape(-1, X1.shape[2]), R2.reshape(-1, R2.shape[2]), W, self.eps
+                        )

-                        self.assertTrue(
-                            error(output_real, output_pred) < 0.01,
-                            f"Accuracy test failed for size {batch}, {seqLen}, {embed_dim}. output_real={output_real}, output_pred={output_pred}"
-                        )
-                        self.assertTrue(
-                            error(scales_real, scales_pred) < 0.01,
-                            f"Accuracy test failed for size {batch}, {seqLen}, {embed_dim}. scales_real={scales_real}, scales_pred={scales_pred}"
-                        )
+                        self.assertTrue(
+                            error(output_real, output_pred) < 0.01,
+                            f"Accuracy test failed for size {batch}, {seqLen}, {embed_dim}. "
+                            f"output_real={output_real}, output_pred={output_pred}",
+                        )
+                        self.assertTrue(
+                            error(scales_real, scales_pred) < 0.01,
+                            f"Accuracy test failed for size {batch}, {seqLen}, {embed_dim}. "
+                            f"scales_real={scales_real}, scales_pred={scales_pred}",
+                        )

     def test_performance(self):
         """Test the performance of FusedAddNormQuant using benchmark."""
         for batch in self.batchs:
             for seqLen in self.seqLens:
                 for embed_dim in self.embed_dims:
-                    with self.subTest(shape=[batch, seqLen, embed_dim]):
-                        X1 = torch.rand(size=[batch, seqLen, embed_dim], device=self.device, dtype=self.dtype) - 0.5
-                        X2 = torch.rand(size=[batch, seqLen, embed_dim], device=self.device, dtype=self.dtype) - 0.5
-                        R1 = torch.rand(size=[batch, seqLen, embed_dim], device=self.device, dtype=self.dtype) - 0.5
-                        R2 = R1.clone()
-                        W = torch.rand(size=[embed_dim], device=self.device, dtype=self.dtype) - 0.5
+                    with self.subTest(shape=[batch, seqLen, embed_dim]):
+                        X1 = torch.rand(size=[batch, seqLen, embed_dim], device=self.device, dtype=self.dtype) - 0.5
+                        X2 = torch.rand(size=[batch, seqLen, embed_dim], device=self.device, dtype=self.dtype) - 0.5
+                        R1 = torch.rand(size=[batch, seqLen, embed_dim], device=self.device, dtype=self.dtype) - 0.5
+                        R2 = R1.clone()
+                        W = torch.rand(size=[embed_dim], device=self.device, dtype=self.dtype) - 0.5
+
+                        shape = [[batch, seqLen, embed_dim]]
+                        tflops = 0.0
+                        benchmark(
+                            torch_add_norm_quant_bf16_fp8,
+                            shape,
+                            tflops,
+                            100,
+                            X1.reshape(-1, X1.shape[2]),
+                            R1.reshape(-1, R1.shape[2]),
+                            W,
+                            self.eps,
+                        )
+                        benchmark(
+                            add_norm_quant_bf16_fp8,
+                            shape,
+                            tflops,
+                            100,
+                            X2.reshape(-1, X1.shape[2]),
+                            R2.reshape(-1, R2.shape[2]),
+                            W,
+                            self.eps,
+                        )

-                        shape = [[batch, seqLen, embed_dim]]
-                        tflops = 0.0
-                        benchmark(torch_add_norm_quant_bf16_fp8, shape, tflops, 100, X1.reshape(-1, X1.shape[2]), R1.reshape(-1, R1.shape[2]), W, self.eps)
-                        benchmark(add_norm_quant_bf16_fp8, shape, tflops, 100, X2.reshape(-1, X1.shape[2]), R2.reshape(-1, R2.shape[2]), W, self.eps)

 if __name__ == "__main__":
-    unittest.main()
+    unittest.main()
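
For context, the eager reference that the fused add_norm_quant_bf16_fp8 kernel is checked against composes three steps: an in-place residual add, an RMSNorm over the last dimension, and a dynamic per-token FP8 quantization. Below is a minimal self-contained sketch of that path; the quantization step assumes the common max-abs/FP8-max per-token scale recipe and is not a verbatim copy of vLLM's ops.scaled_fp8_quant internals.

import torch

def add_norm_quant_reference(X, R, W, eps=1e-6):
    # 1. add the residual in place
    X = X.add_(R)
    # 2. RMSNorm over the last dimension with weight W
    normalized = torch.nn.functional.rms_norm(X, (X.shape[-1],), W, eps=eps)
    # 3. dynamic per-token FP8 quantization (assumed recipe): one scale per token,
    #    chosen so the token's max |value| maps to the FP8 e4m3 max (448)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scales = normalized.abs().amax(dim=-1, keepdim=True).to(torch.float32) / fp8_max
    scales = scales.clamp(min=1e-12)  # guard against all-zero tokens
    quantized = (normalized.to(torch.float32) / scales).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return quantized, scales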

lightllm-kernel/test/fusion/gelu_per_token_quant_test.py

Lines changed: 12 additions & 6 deletions
@@ -4,10 +4,12 @@
 from lightllm_kernel.ops import per_token_quant_bf16_fp8, gelu_per_token_quant_bf16_fp8
 from test.utils import benchmark, error

+
 def gelu_quant(x):
     y = gelu_fwd(x)
     return per_token_quant_bf16_fp8(y)

+
 class TestGeluQuantBF16(unittest.TestCase):
     def setUp(self):
         """Set up common test parameters."""
@@ -21,20 +23,23 @@ def test_accuracy(self):
         for token in self.tokens:
             for hiddenDim in self.hiddenDims:
                 with self.subTest(shape=[token, hiddenDim]):
-                    input = torch.normal(mean=0.0, std=10, size=[token, hiddenDim], device=self.device, dtype=self.dtype)
+                    input = torch.normal(
+                        mean=0.0, std=10, size=[token, hiddenDim], device=self.device, dtype=self.dtype
+                    )

                     y_real, scales_real = gelu_quant(input)
                     y_pred, scales_pred = gelu_per_token_quant_bf16_fp8(input)
-
+
                     self.assertTrue(
                         error(scales_real, scales_pred) < 0.01,
-                        f"Accuracy test failed for size {token}, {hiddenDim}. scales_real={scales_real}, scales_pred={scales_pred}"
+                        f"Accuracy test failed for size {token}, {hiddenDim}. "
+                        f"scales_real={scales_real}, scales_pred={scales_pred}",
                     )
                     self.assertTrue(
                         error(y_real, y_pred) < 0.01,
-                        f"Accuracy test failed for size {token}, {hiddenDim}. y_real={y_real}, y_pred={y_pred}"
+                        f"Accuracy test failed for size {token}, {hiddenDim}." f"y_real={y_real}, y_pred={y_pred}",
                     )
-
+
     def test_performance(self):
         """Test the performance of gelu_per_token_quant using benchmark."""
         for token in self.tokens:
@@ -46,5 +51,6 @@ def test_performance(self):
                 benchmark(gelu_per_token_quant_bf16_fp8, shape, tflops, 100, input)
                 benchmark(gelu_quant, shape, tflops, 100, input)

+
 if __name__ == "__main__":
-    unittest.main()
+    unittest.main()
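
Both accuracy tests gate on error(real, pred) < 0.01. The test.utils.error helper is not part of this diff; the following is a hypothetical stand-in for a relative-error metric of that kind, shown only to make the threshold concrete.

import torch

def relative_error(real: torch.Tensor, pred: torch.Tensor) -> float:
    # hypothetical stand-in for test.utils.error: relative L2 distance between tensors
    real32 = real.to(torch.float32)
    pred32 = pred.to(torch.float32)
    return (torch.norm(real32 - pred32) / (torch.norm(real32) + 1e-8)).item()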

lightllm-kernel/test/fusion/pre_tp_norm_test.py

Lines changed: 5 additions & 3 deletions
@@ -9,6 +9,7 @@ def pre_tp_norm(input):
     tp_variance = input.pow(2).sum(-1, keepdim=False)
     return tp_variance

+
 class TestPreTpNormBF16(unittest.TestCase):
     def setUp(self):
         """Set up common test parameters."""
@@ -27,20 +28,21 @@ def test_accuracy(self):
                 y_pred = pre_tp_norm_bf16(X)
                 self.assertTrue(
                     error(y_pred, y_real) < 0.01,
-                    f"Accuracy test failed for size {batch}, {size}. y_real={y_real}, y_pred={y_pred}"
+                    f"Accuracy test failed for size {batch}, {size}. y_real={y_real}, y_pred={y_pred}",
                 )

     def test_performance(self):
         for batch in self.batchs:
             for size in self.sizes:
                 with self.subTest(shape=[batch, size]):
                     X = torch.rand(size=[batch, size], device=self.device, dtype=self.dtype) - 0.5
-                    W = torch.rand(size=[size], device=self.device, dtype=self.dtype) - 0.5
+                    # W = torch.rand(size=[size], device=self.device, dtype=self.dtype) - 0.5

                     shape = [[batch, size], [size], [batch, size]]
                     tflops = 0.0
                     benchmark(pre_tp_norm_bf16, shape, tflops, 100, X)
                     benchmark(pre_tp_norm, shape, tflops, 100, X)

+
 if __name__ == "__main__":
-    unittest.main()
+    unittest.main()
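
The pre_tp_norm reference returns a per-rank partial sum of squares over the local hidden slice. In a tensor-parallel RMSNorm this partial statistic is typically all-reduced before normalizing; the sketch below illustrates that assumed downstream use (the collective call, global_hidden_dim, and weight handling are assumptions, not part of this diff).

import torch
import torch.distributed as dist

def tp_rmsnorm_from_partial(x_local, weight_local, global_hidden_dim, eps=1e-6):
    # per-rank partial sum of squares, as in pre_tp_norm / pre_tp_norm_bf16
    partial = x_local.float().pow(2).sum(-1, keepdim=True)
    # combine partials across tensor-parallel ranks (assumed usage)
    dist.all_reduce(partial, op=dist.ReduceOp.SUM)
    variance = partial / global_hidden_dim
    normed = x_local.float() * torch.rsqrt(variance + eps)
    return (normed * weight_local.float()).to(x_local.dtype)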

lightllm-kernel/test/quant/quant_test.py

Lines changed: 6 additions & 4 deletions
@@ -25,11 +25,12 @@ def test_accuracy(self):
                 y_pred, scales_pred = per_token_quant_bf16_fp8(input)
                 self.assertTrue(
                     error(scales_real, scales_pred) < 0.01,
-                    f"Accuracy test failed for size {token}, {hiddenDim}. scales_real={scales_real}, scales_pred={scales_pred}"
+                    f"Accuracy test failed for size {token}, {hiddenDim}."
+                    f"scales_real={scales_real}, scales_pred={scales_pred}",
                 )
                 self.assertTrue(
                     error(y_real, y_pred) < 0.01,
-                    f"Accuracy test failed for size {token}, {hiddenDim}. y_real={y_real}, y_pred={y_pred}"
+                    f"Accuracy test failed for size {token}, {hiddenDim}. y_real={y_real}, y_pred={y_pred}",
                 )

     def test_performance(self):
@@ -39,9 +40,10 @@ def test_performance(self):
                 with self.subTest(shape=[token, size]):
                     input = torch.rand(size=[token, size], device=self.device, dtype=self.dtype) - 0.5
                     shape = [[token, size]]
-                    tflops = token * size / 1024**4
+                    tflops = token * size / 1024 ** 4
                     benchmark(per_token_quant_bf16_fp8, shape, tflops, 100, input)
                     benchmark(ops.scaled_fp8_quant, shape, tflops, 100, input, None, True)

+
 if __name__ == "__main__":
-    unittest.main()
+    unittest.main()
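
As a worked example of the work estimate above: per-token quantization touches each of the token * size elements once, so with illustrative sizes token = 1024 and size = 12288 the estimate is 1024 * 12288 / 1024 ** 4 ≈ 1.14e-5. In other words, the benchmark counts one op per element and scales by 1024^4 ≈ 1.1e12, so the reported figure is roughly tera-element-ops per second.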
