|
@@ -13,31 +13,29 @@
 
 # Configs
 import deep_gemm_cpp
+from deep_gemm_cpp import (
+    set_num_sms,
+    get_num_sms,
+    set_tc_util,
+    get_tc_util,
+)
 
 # Kernels
 from deep_gemm_cpp import (
-    bf16_gemm_nn,
-    # BF16 GEMMs
-    bf16_gemm_nt,
-    bf16_gemm_tn,
-    bf16_gemm_tt,
-    fp8_gemm_nn,
     # FP8 GEMMs
-    fp8_gemm_nt,
-    fp8_gemm_tn,
-    fp8_gemm_tt,
-    get_num_sms,
-    get_tc_util,
+    fp8_gemm_nt, fp8_gemm_nn,
+    fp8_gemm_tn, fp8_gemm_tt,
+    m_grouped_fp8_gemm_nt_contiguous,
+    m_grouped_fp8_gemm_nn_contiguous,
+    m_grouped_fp8_gemm_nt_masked,
     k_grouped_fp8_gemm_tn_contiguous,
+    # BF16 GEMMs
+    bf16_gemm_nt, bf16_gemm_nn,
+    bf16_gemm_tn, bf16_gemm_tt,
     m_grouped_bf16_gemm_nt_contiguous,
     m_grouped_bf16_gemm_nt_masked,
-    m_grouped_fp8_gemm_nn_contiguous,
-    m_grouped_fp8_gemm_nt_contiguous,
-    m_grouped_fp8_gemm_nt_masked,
-    set_num_sms,
-    set_tc_util,
     # Layout kernels
-    transform_sf_into_required_layout,
+    transform_sf_into_required_layout
 )
 
 # Some alias for legacy supports
@@ -74,5 +72,3 @@ def _find_cuda_home() -> str:
     os.path.dirname(os.path.abspath(__file__)), # Library root directory path
     _find_cuda_home() # CUDA home
 )
-
-__version__ = "2.0.0"
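
For context, a minimal usage sketch of the config helpers that this change re-exports at the package level (`set_num_sms`, `get_num_sms`, `set_tc_util`, `get_tc_util`). The argument types below are assumptions (an integer SM count and an integer tensor-core utilization value), not confirmed by this diff; check the `deep_gemm_cpp` bindings for the actual signatures.

```python
# Minimal sketch, assuming the bindings behave as plain getters/setters.
# The SM count and tensor-core utilization values below are placeholders.
import deep_gemm

deep_gemm.set_num_sms(78)           # assumed: takes an integer SM count
print(deep_gemm.get_num_sms())      # assumed: returns the value set above

deep_gemm.set_tc_util(90)           # assumed: takes an integer utilization value
print(deep_gemm.get_tc_util())
```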