Skip to content

Commit 86fa46b

Browse files
FindHaometa-codesync[bot]
authored andcommitted
Add support for Triton dtype parameters in reproducer
Summary: Add support for parsing and handling Triton dtype parameters (like `tl.bfloat16`, `tl.float16`) in TritonParse reproducer scripts. **Problem**: TritonParse reproducer scripts failed when Triton kernels had dtype parameters (e.g., `ab_dtype`, `c_dtype`), showing "Warning: Unhandled argument type 'dtype'. Returning None" and passing `None` instead of the actual dtype object like `tl.bfloat16`. **Solution**: 1. Added `TRITON_DTYPE_MAP` dictionary mapping dtype string representations (e.g., 'bf16') to Triton dtype objects (e.g., `tl.bfloat16`) 2. Added dtype handling branch in `_create_arg_from_info()` to parse dtype arguments 3. Updated function extractor to include triton.language import and TRITON_DTYPE_MAP in generated reproducers The mapping covers all Triton dtypes: - Signed/unsigned integers: int1, int8-64, uint8-64 - Standard floating point: fp16, bf16, fp32, fp64 - FP8 variants: fp8e4b15, fp8e4nv, fp8e4b8, fp8e5, fp8e5b16 **Impact**: Future generated reproducer scripts will correctly handle dtype parameters, allowing kernels with dtype arguments to run successfully. Reviewed By: htyu Differential Revision: D87575372 fbshipit-source-id: b68e1c14a4945fd73c73d434267aa95f9a7db333
1 parent 984089b commit 86fa46b

File tree

2 files changed

+42
-0
lines changed

2 files changed

+42
-0
lines changed

tritonparse/reproducer/function_extractor.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ def extract_utility_functions() -> str:
5858
if constant:
5959
extracted_parts.append(constant)
6060

61+
# Extract TRITON_DTYPE_MAP constant
62+
dtype_map = _extract_assignment(utils_tree, utils_lines, "TRITON_DTYPE_MAP")
63+
if dtype_map:
64+
extracted_parts.append(dtype_map)
65+
6166
# Extract load_tensor functions
6267
extracted_parts.extend(
6368
_extract_functions(
@@ -218,5 +223,6 @@ def _generate_imports() -> str:
218223
"from typing import Union",
219224
"",
220225
"import torch",
226+
"import triton.language as tl",
221227
]
222228
return "\n".join(imports)

tritonparse/reproducer/utils.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from pathlib import Path
1212

1313
import torch
14+
import triton.language as tl
15+
1416
from tritonparse.tools.load_tensor import load_tensor
1517
from tritonparse.tp_logger import logger
1618

@@ -19,6 +21,32 @@
1921
and importlib.util.find_spec("triton_kernels.tensor") is not None
2022
)
2123

24+
# Mapping from dtype string representation to Triton dtype objects
25+
TRITON_DTYPE_MAP = {
26+
# Signed integers
27+
"int8": tl.int8,
28+
"int16": tl.int16,
29+
"int32": tl.int32,
30+
"int64": tl.int64,
31+
# Unsigned integers
32+
"int1": tl.int1,
33+
"uint8": tl.uint8,
34+
"uint16": tl.uint16,
35+
"uint32": tl.uint32,
36+
"uint64": tl.uint64,
37+
# Standard floating point types
38+
"fp16": tl.float16,
39+
"bf16": tl.bfloat16,
40+
"fp32": tl.float32,
41+
"fp64": tl.float64,
42+
# FP8 variants
43+
"fp8e4b15": tl.float8e4b15,
44+
"fp8e4nv": tl.float8e4nv,
45+
"fp8e4b8": tl.float8e4b8,
46+
"fp8e5": tl.float8e5,
47+
"fp8e5b16": tl.float8e5b16,
48+
}
49+
2250

2351
@lru_cache(maxsize=1)
2452
def _get_triton_tensor_types():
@@ -322,6 +350,14 @@ def _create_arg_from_info(arg_info):
322350
)
323351
Tensor, Storage, StridedLayout = _get_triton_tensor_types()
324352
return StridedLayout(shape=arg_info.get("initial_shape"))
353+
354+
elif arg_type == "dtype":
355+
dtype_repr = arg_info.get("repr")
356+
if dtype_repr in TRITON_DTYPE_MAP:
357+
return TRITON_DTYPE_MAP[dtype_repr]
358+
else:
359+
raise NotImplementedError(f"Unsupported Triton dtype: {dtype_repr}")
360+
325361
else:
326362
print(f"Warning: Unhandled argument type '{arg_type}'. Returning None.")
327363
return None

0 commit comments

Comments
 (0)