Commit f203299 ("precommit fix")
1 parent eae4675, commit f203299

File tree: 2 files changed, +96 / -77 lines

flashinfer/cute_dsl/blockwise_gemm.py (50 additions, 51 deletions)
@@ -34,7 +34,6 @@
 import cutlass
 import cutlass.cute as cute
 from cutlass.cute.nvgpu import cpasync, tcgen05
-from cutlass.cute.runtime import from_dlpack
 import cutlass.torch as cutlass_torch
 import cutlass.utils as utils
 import cutlass.pipeline as pipeline
@@ -620,7 +619,7 @@ class SharedStorage:
     grid=grid,
     block=[self.threads_per_cta, 1, 1],
     cluster=(*self.cluster_shape_mn, 1),
-    smem=self.shared_storage.size_in_bytes(),
+    smem=self.shared_storage.size_in_bytes(),  # type: ignore[attr-defined]
     stream=stream,
     min_blocks_per_mp=1,
 )
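
A quick reference for the suppression added above: "# type: ignore[attr-defined]" silences only mypy's attr-defined diagnostic on that single line, which is raised when an attribute is not visible on the statically inferred type. A minimal plain-Python sketch of the same situation (illustrative names, not the CuTe DSL types from this diff):

    class Storage:
        pass

    s = Storage()
    # The method is attached dynamically, so a static checker cannot see it;
    # the per-line ignore silences exactly that diagnostic and nothing else.
    s.size_in_bytes = lambda: 1024  # type: ignore[attr-defined]
    print(s.size_in_bytes())  # type: ignore[attr-defined]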
@@ -1095,7 +1094,7 @@ def kernel(
 #
 # Tma load loop
 #
-for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
+for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):  # noqa: B007
     tAgA_k = tAgA_slice[(None, ab_producer_state.count)]
     tBgB_k = tBgB_slice[(None, ab_producer_state.count)]
     tAsA_pipe = tAsA[(None, ab_producer_state.index)]
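
Context for the "# noqa: B007" suppressions added in this hunk and in the later loop hunks: B007 is flake8-bugbear's "loop control variable not used within loop body" check. The loop bodies above index with ab_producer_state rather than reading k_tile itself, so the counter only drives the trip count and the warning is suppressed instead of renaming the variable. A minimal plain-Python sketch of the pattern (names are illustrative, not from this file):

    k_tile_cnt = 4
    processed = 0
    # B007 flags k_tile because the body never reads it; the loop only needs
    # to run k_tile_cnt times while separate state objects advance.
    for k_tile in range(k_tile_cnt):  # noqa: B007
        processed += 1  # stand-in for advancing producer/consumer state
    assert processed == k_tile_cnt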
@@ -1187,7 +1186,6 @@ def kernel(
 is_valid_tile = tile_info[3] == 1

 while is_valid_tile:
-
     #
     # Prepare the mask for scaleA/scaleB
     #
@@ -1219,7 +1217,7 @@ def kernel(
 #
 # load loop
 #
-for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
+for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):  # noqa: B007
     #
     # Slice to per mma tile index
     #
@@ -1390,7 +1388,6 @@ def kernel(
 is_valid_tile = tile_info[3] == 1

 while is_valid_tile:
-
     # Peek (try_wait) AB buffer full for k_tile = 0
     ab_consumer_state.reset_count()
     peek_ab_full_status = cutlass.Boolean(1)
@@ -1410,7 +1407,7 @@ def kernel(
 #
 # Mma mainloop
 #
-for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
+for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):  # noqa: B007
     # Set tensor memory buffer for current tile
     # (MMA, MMA_M, MMA_N)
     tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)]
@@ -1591,7 +1588,6 @@ def kernel(
 is_valid_tile = tile_info[3] == 1

 while is_valid_tile:
-
     # initialize the final accumulator
     tTR_rAcc_final.fill(0.0)

@@ -1618,7 +1614,7 @@ def kernel(
     acc_consumer_state
 )

-for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
+for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):  # noqa: B007
     # Set tensor memory buffer for current tile
     # (T2R, T2R_M, T2R_N, EPI_M, EPI_M)
     tTR_tAcc = tTR_tAcc_base[
@@ -1794,7 +1790,6 @@ def kernel(

 tTR_rC = None
 tiled_copy_r2s = None
-simt_atom = None
 tRS_rC = None
 tRS_sC = None
 bSG_sC = None
@@ -2301,7 +2296,7 @@ def _compute_stages(
     sfb_count: int,
     num_smem_capacity: int,
     occupancy: int,
-) -> Tuple[int, int, int]:
+) -> Tuple[int, int, int, int, int]:
     """Computes the number of stages for A/B/C operands based on heuristics.

     :param tiled_mma: The tiled MMA object defining the core computation.
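
The widened annotation above brings the return type in line with a helper that returns five values instead of three; the docstring shown in the hunk still says "A/B/C operands", so which two counts were added is not visible in this diff. A hedged sketch of the shape such a signature implies (the element names and values are assumptions, not taken from the file):

    from typing import Tuple

    def compute_stages_stub() -> Tuple[int, int, int, int, int]:
        # Hypothetical stage counts; the real _compute_stages derives its values
        # from the tiled MMA and the available shared-memory capacity.
        ab_stage, c_stage, sfa_stage, sfb_stage, acc_stage = 4, 2, 4, 4, 2
        return ab_stage, c_stage, sfa_stage, sfb_stage, acc_stage

    print(compute_stages_stub())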
@@ -2687,7 +2682,7 @@ def __init__(
     self._use_2cta_instrs = use_2cta_instrs
     self._mma_tiler_mn = mma_tiler_mn
     self._cluster_shape_mn = cluster_shape_mn
-
+
     if not BlockwiseGemmKernel.can_implement(
         ab_dtype,
         acc_dtype,
@@ -2709,13 +2704,13 @@ def __init__(

     hardware_info = cutlass.utils.HardwareInfo()
     self._max_active_clusters = min(
-            hardware_info.get_max_active_clusters(
-                self._cluster_shape_mn[0] * self._cluster_shape_mn[1]
-            ),
-            sm_count,
+        hardware_info.get_max_active_clusters(
+            self._cluster_shape_mn[0] * self._cluster_shape_mn[1]
+        ),
+        sm_count,
     )
     self._sm_version = sm_version
-
+
 @cute.jit
 def __call__(
     self,
@@ -2726,7 +2721,6 @@ def __call__(
     c_ptr: cute.Pointer,
     current_stream: cuda.CUstream,
 ):
-    #TODO(asamani): double check the shapes and layouts
     a_tensor = cute.make_tensor(
         a_ptr,
         layout=cute.make_ordered_layout(
@@ -2752,14 +2746,14 @@ def __call__(
     sfa_ptr,
     layout=cute.make_ordered_layout(
         (self._m, math.ceil(self._k / 128), self._l),
-        order=(0, 1, 2), #if self._a_major == "m" else (1, 0, 2)
+        order=(0, 1, 2),
     ),
 )
 sfb_tensor = cute.make_tensor(
     sfb_ptr,
     layout=cute.make_ordered_layout(
-        (math.ceil(self._n / 128), math.ceil(self._k / 128),self._l),
-        order=(1, 0, 2), #if self._b_major == "n" else (1, 0, 2),
+        (math.ceil(self._n / 128), math.ceil(self._k / 128), self._l),
+        order=(1, 0, 2),
     ),
 )

@@ -2777,7 +2771,7 @@ def __call__(
     self._max_active_clusters,
     current_stream,
 )
-
+

 @functools.cache
 def get_cute_dsl_compiled_blockwise_gemm_kernel(
@@ -2831,7 +2825,7 @@ def get_cute_pointers(
         sfb_tensor_gpu.data_ptr(),
         c_tensor_gpu.data_ptr(),
     )
-
+
     a_ptr = make_ptr(
         ab_dtype,
         a_data_ptr,
@@ -2863,7 +2857,7 @@ def get_cute_pointers(
         assumed_align=16,
     )
     return [a_ptr, b_ptr, sfa_ptr, sfb_ptr, c_ptr]
-
+
 kernel = cute.compile(
     BlockwiseGemmCuteDSL(
         m=m,
@@ -2887,7 +2881,6 @@ def get_cute_pointers(
     cutlass_torch.current_stream(),
 )

-
 def tensor_api(
     a_tensor_gpu: torch.Tensor,
     b_tensor_gpu: torch.Tensor,
@@ -2907,7 +2900,13 @@ def tensor_api(
     nonlocal kernel
     kernel(
         *get_cute_pointers(
-            [a_tensor_gpu, b_tensor_gpu, sfa_tensor_gpu, sfb_tensor_gpu, c_tensor_gpu]
+            [
+                a_tensor_gpu,
+                b_tensor_gpu,
+                sfa_tensor_gpu,
+                sfb_tensor_gpu,
+                c_tensor_gpu,
+            ]
         ),
         current_stream,
     )
@@ -2932,7 +2931,7 @@ def blockwise_gemm(
 ):
     m, k, l = a_torch.shape
     n, _, _ = b_torch.shape
-
+
     mma_tiler_mn = kwargs.pop("mma_tiler_mn", (128, 128))
     cluster_shape_mn = kwargs.pop("cluster_shape_mn", (1, 1))
     if sm_count is None:
@@ -2943,28 +2942,28 @@
     major, minor = get_compute_capability(a_torch.device)
     if major == 11 and minor == 0:
         raise ValueError("SM110 is not supported for cute-dsl backend.")
-
+
     return get_cute_dsl_compiled_blockwise_gemm_kernel(
-        m=m,
-        n=n,
-        k=k,
-        l=l,
-        a_major="k",
-        b_major="k",
-        c_major="n",
-        ab_dtype=get_cutlass_dtype(ab_dtype),
-        sf_dtype=get_cutlass_dtype(sf_dtype),
-        c_dtype=get_cutlass_dtype(c_dtype),
-        acc_dtype=get_cutlass_dtype(acc_dtype),
-        use_2cta_instrs = use_2cta_instrs,
-        mma_tiler_mn=mma_tiler_mn,
-        cluster_shape_mn=cluster_shape_mn,
-        sm_count=sm_count,
-        sm_version=f"sm_{major}{minor}",
-    )(
-        a_tensor_gpu=a_torch,
-        b_tensor_gpu=b_torch,
-        sfa_tensor_gpu=sfa_torch,
-        sfb_tensor_gpu=sfb_torch,
-        c_tensor_gpu=c_torch,
-    )
+        m=m,
+        n=n,
+        k=k,
+        l=l,
+        a_major="k",
+        b_major="k",
+        c_major="n",
+        ab_dtype=get_cutlass_dtype(ab_dtype),
+        sf_dtype=get_cutlass_dtype(sf_dtype),
+        c_dtype=get_cutlass_dtype(c_dtype),
+        acc_dtype=get_cutlass_dtype(acc_dtype),
+        use_2cta_instrs=use_2cta_instrs,
+        mma_tiler_mn=mma_tiler_mn,
+        cluster_shape_mn=cluster_shape_mn,
+        sm_count=sm_count,
+        sm_version=f"sm_{major}{minor}",
+    )(
+        a_tensor_gpu=a_torch,
+        b_tensor_gpu=b_torch,
+        sfa_tensor_gpu=sfa_torch,
+        sfb_tensor_gpu=sfb_torch,
+        c_tensor_gpu=c_torch,
+    )
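
For orientation, here is a hedged usage sketch of the blockwise_gemm entry point reformatted at the end of this file. Shapes follow the layouts visible in __call__ and blockwise_gemm (A: (m, k, l), B: (n, k, l), SFA: (m, ceil(k/128), l), SFB: (ceil(n/128), ceil(k/128), l)); the import path, dtype strings, fp8 conversion, and the (m, n, l) output shape are assumptions, not confirmed by this diff:

    import math
    import torch
    from flashinfer.cute_dsl.blockwise_gemm import blockwise_gemm  # assumed path

    m, n, k, l = 256, 256, 512, 1
    a = torch.randn(m, k, l, device="cuda").to(torch.float8_e4m3fn)
    b = torch.randn(n, k, l, device="cuda").to(torch.float8_e4m3fn)
    sfa = torch.randn(m, math.ceil(k / 128), l, device="cuda")
    sfb = torch.randn(math.ceil(n / 128), math.ceil(k / 128), l, device="cuda")
    c = torch.empty(m, n, l, device="cuda", dtype=torch.bfloat16)

    blockwise_gemm(
        a, sfa, b, sfb, c,
        ab_dtype="float8_e4m3fn",  # assumed dtype names accepted by get_cutlass_dtype
        sf_dtype="float32",
        c_dtype="bfloat16",
        acc_dtype="float32",
        mma_tiler_mn=(128, 128),   # defaults shown in blockwise_gemm above
        cluster_shape_mn=(1, 1),
        use_2cta_instrs=False,
    )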

tests/gemm/test_cute_dsl_blockwise_gemm.py (46 additions, 26 deletions)
@@ -21,35 +21,54 @@ def create_tensors(

     a_torch_cpu = cutlass_torch.matrix(
         l, m, k, a_major == "m", get_cutlass_dtype(ab_dtype), device=device
-        )
+    )
     b_torch_cpu = cutlass_torch.matrix(
         l, n, k, b_major == "n", get_cutlass_dtype(ab_dtype), device=device
-        )
+    )
     c_torch_cpu = cutlass_torch.matrix(
         l, m, n, cd_major == "m", get_cutlass_dtype(c_dtype), device=device
-        )
+    )
     sfa_torch_cpu = cutlass_torch.matrix(
         l, m, math.ceil(k / 128), True, get_cutlass_dtype(scale_dtype), device=device
-        )
+    )
     sfb_torch_cpu = cutlass_torch.matrix(
-        l, math.ceil(n / 128), math.ceil(k / 128), False,
-        get_cutlass_dtype(scale_dtype), device=device,
+        l,
+        math.ceil(n / 128),
+        math.ceil(k / 128),
+        False,
+        get_cutlass_dtype(scale_dtype),
+        device=device,
     )

     a_tensor, a_torch = cutlass_torch.cute_tensor_like(
-        a_torch_cpu, get_cutlass_dtype(ab_dtype), is_dynamic_layout=True, assumed_align=16
+        a_torch_cpu,
+        get_cutlass_dtype(ab_dtype),
+        is_dynamic_layout=True,
+        assumed_align=16,
     )
     b_tensor, b_torch = cutlass_torch.cute_tensor_like(
-        b_torch_cpu, get_cutlass_dtype(ab_dtype), is_dynamic_layout=True, assumed_align=16
+        b_torch_cpu,
+        get_cutlass_dtype(ab_dtype),
+        is_dynamic_layout=True,
+        assumed_align=16,
    )
     c_tensor, c_torch = cutlass_torch.cute_tensor_like(
-        c_torch_cpu, get_cutlass_dtype(c_dtype), is_dynamic_layout=True, assumed_align=16
+        c_torch_cpu,
+        get_cutlass_dtype(c_dtype),
+        is_dynamic_layout=True,
+        assumed_align=16,
     )
     sfa_tensor, sfa_torch = cutlass_torch.cute_tensor_like(
-        sfa_torch_cpu, get_cutlass_dtype(scale_dtype), is_dynamic_layout=True, assumed_align=16
+        sfa_torch_cpu,
+        get_cutlass_dtype(scale_dtype),
+        is_dynamic_layout=True,
+        assumed_align=16,
     )
     sfb_tensor, sfb_torch = cutlass_torch.cute_tensor_like(
-        sfb_torch_cpu, get_cutlass_dtype(scale_dtype), is_dynamic_layout=True, assumed_align=16
+        sfb_torch_cpu,
+        get_cutlass_dtype(scale_dtype),
+        is_dynamic_layout=True,
+        assumed_align=16,
     )

     return (
@@ -138,7 +157,7 @@ def test_blockwise_gemm_python_interface(
 pytest.skip(
     f"Unsupported testcase {ab_dtype}, {sf_dtype}, {c_dtype}, {acc_dtype}, {use_2cta_instrs} ,{mma_tiler_mn}, {cluster_shape_mn}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {c_major}"
 )
-
+
 (
     a_tensor,
     a_torch,
@@ -161,22 +180,23 @@ def test_blockwise_gemm_python_interface(

 for _ in range(iterations):
     blockwise_gemm(
-            a_torch,
-            sfa_torch,
-            b_torch,
-            sfb_torch,
-            c_torch,
-            ab_dtype=ab_dtype,
-            sf_dtype=sf_dtype,
-            c_dtype=c_dtype,
-            acc_dtype=acc_dtype,
-            sm_count=sm_count,
-            mma_tiler_mn=mma_tiler_mn,
-            cluster_shape_mn=cluster_shape_mn,
-            use_2cta_instrs=use_2cta_instrs,
+        a_torch,
+        sfa_torch,
+        b_torch,
+        sfb_torch,
+        c_torch,
+        ab_dtype=ab_dtype,
+        sf_dtype=sf_dtype,
+        c_dtype=c_dtype,
+        acc_dtype=acc_dtype,
+        sm_count=sm_count,
+        mma_tiler_mn=mma_tiler_mn,
+        cluster_shape_mn=cluster_shape_mn,
+        use_2cta_instrs=use_2cta_instrs,
     )

 torch.cuda.synchronize()
+
 def pad_and_multiply(scale, tensor):
     cm, ck, _ = scale.shape
     m, k, _ = tensor.shape
@@ -228,4 +248,4 @@ def pad_and_multiply(scale, tensor):
     cluster_shape_mn=(1, 1),
     tolerance=1e-01,
     iterations=3,
-)
+)
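
The tail of the test references a pad_and_multiply(scale, tensor) helper; only its name and the two shape unpacks (cm, ck, _ and m, k, _) appear in this diff. A hedged sketch of what such a reference helper could look like, expanding per-128 block scales over the operand before a plain reference matmul; the body is an assumption:

    import torch

    def pad_and_multiply(scale, tensor, block=128):
        cm, ck, _ = scale.shape  # e.g. (m, ceil(k/128), l) for SFA, (ceil(n/128), ceil(k/128), l) for SFB
        m, k, _ = tensor.shape   # operand laid out as (rows, k, l)
        s = scale.float()
        if cm != m:  # scales given per 128-row block: expand, then trim to m rows
            s = s.repeat_interleave(block, dim=0)[:m]
        if ck != k:  # scales given per 128-column block: expand, then trim to k columns
            s = s.repeat_interleave(block, dim=1)[:, :k]
        return s * tensor.float()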
