Skip to content

Commit d24ae1d

Browse files
FindHao authored and meta-codesync[bot] committed
Fix kernel signature parsing to support return type annotations
Summary: Replace string-based parsing with AST for robust function signature extraction.

The previous implementation used string matching to find '):', which failed when functions had return type annotations like '-> None'. Such annotations are common in modern Python and Triton kernel code.

This change uses Python's ast module for standards-compliant parsing that correctly handles:
- Return type annotations (-> Type)
- Complex type hints (Callable, generics, etc.)
- Decorators (e.g., triton.jit)
- Keyword-only arguments (after *)
- All valid Python function signatures

The AST approach is more robust and uses Python's built-in parser to correctly identify positional vs keyword arguments using the right-aligned defaults algorithm. This requires no dependencies on triton/torch imports since AST is static analysis, making it future-proof against Python syntax changes.

Reviewed By: wychi

Differential Revision: D87368842

fbshipit-source-id: 9b77d440d479274c3e637af3e3a2972310f2832d
1 parent a6a8856 commit d24ae1d

File tree

1 file changed

+64
-40
lines changed

1 file changed

+64
-40
lines changed

tritonparse/reproducer/utils.py

Lines changed: 64 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22

3+
import ast
34
import importlib
45
import importlib.util
56
import json
@@ -403,51 +404,74 @@ def _generate_import_statements(kernel_info) -> tuple[str, str]:
403404

404405
def _parse_kernel_signature(kernel_source_code: str) -> tuple[list[str], list[str]]:
405406
"""
406-
Parses a Triton kernel's source code to distinguish positional args
407+
Parses a Triton kernel's source code using AST to distinguish positional args
407408
from keyword args (those with default values).
409+
410+
This implementation uses Python's ast module for robust parsing that handles:
411+
- Return type annotations (e.g., -> None)
412+
- Complex type annotations (e.g., Callable[[dict[str, int]], list[Tensor]])
413+
- Decorators (e.g., @triton.jit)
414+
- Keyword-only arguments (after *)
415+
- All Python syntax variations
416+
417+
Args:
418+
kernel_source_code: Python source code containing the kernel function
419+
420+
Returns:
421+
tuple[list[str], list[str]]: (positional_args, keyword_args)
422+
423+
Raises:
424+
ValueError: If parsing fails or no function definition is found
408425
"""
409-
signature_lines = []
410-
in_signature = False
411-
for line in kernel_source_code.splitlines():
412-
# Mark beginning of signature when function definition is found
413-
if line.strip().startswith("def "):
414-
in_signature = True
415-
if in_signature:
416-
# Strip comments and leading/trailing whitespace
417-
clean_line = line.split("#")[0].strip()
418-
signature_lines.append(clean_line)
419-
# Stop capturing after the signature ends
420-
if "):" in line:
426+
try:
427+
# Parse source code into AST
428+
tree = ast.parse(kernel_source_code)
429+
430+
# Find the first function definition
431+
func_def = None
432+
for node in ast.walk(tree):
433+
if isinstance(node, ast.FunctionDef):
434+
func_def = node
421435
break
422436

423-
full_signature = "".join(signature_lines)
424-
# Extract content between the first '(' and the last '):'
425-
try:
426-
params_str = full_signature[
427-
full_signature.find("(") + 1 : full_signature.rfind("):")
428-
]
429-
except IndexError as exc:
430-
raise ValueError("Could not parse kernel signature.") from exc
431-
432-
# Clean up and split the parameters string
433-
params = [p.strip() for p in params_str.replace("\n", "").split(",") if p.strip()]
434-
435-
positional_args = []
436-
keyword_args = []
437-
438-
for param in params:
439-
if "=" in param:
440-
# Keyword arguments have a default value
441-
arg_name = param.split("=")[0].strip()
442-
keyword_args.append(arg_name)
443-
else:
444-
# Positional arguments do not have a default value
445-
arg_name = param.split(":")[0].strip()
446-
positional_args.append(arg_name)
437+
if not func_def:
438+
raise ValueError("No function definition found in source code")
439+
440+
positional_args = []
441+
keyword_args = []
447442

448-
logger.debug("Parsed positional args: %s", positional_args)
449-
logger.debug("Parsed keyword args: %s", keyword_args)
450-
return positional_args, keyword_args
443+
# Extract function arguments
444+
args = func_def.args
445+
446+
# Calculate number of positional arguments
447+
# defaults are right-aligned with args, so:
448+
# num_positional = total_args - num_defaults
449+
num_defaults = len(args.defaults)
450+
num_args = len(args.args)
451+
num_positional = num_args - num_defaults
452+
453+
# Classify regular arguments
454+
for i, arg in enumerate(args.args):
455+
arg_name = arg.arg
456+
if i < num_positional:
457+
positional_args.append(arg_name)
458+
else:
459+
keyword_args.append(arg_name)
460+
461+
# Handle keyword-only arguments (after *)
462+
for arg in args.kwonlyargs:
463+
keyword_args.append(arg.arg)
464+
465+
logger.debug("Parsed positional args: %s", positional_args)
466+
logger.debug("Parsed keyword args: %s", keyword_args)
467+
return positional_args, keyword_args
468+
469+
except SyntaxError as e:
470+
raise ValueError(
471+
f"Invalid Python syntax in kernel source at line {e.lineno}: {e.msg}"
472+
) from e
473+
except Exception as e:
474+
raise ValueError(f"Failed to parse kernel signature: {e}") from e
451475

452476

453477
def _generate_invocation_snippet(

0 commit comments

Comments
 (0)