|
1 | 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. |
2 | 2 |
|
| 3 | +import ast |
3 | 4 | import importlib |
4 | 5 | import importlib.util |
5 | 6 | import json |
@@ -403,51 +404,74 @@ def _generate_import_statements(kernel_info) -> tuple[str, str]: |
403 | 404 |
|
404 | 405 | def _parse_kernel_signature(kernel_source_code: str) -> tuple[list[str], list[str]]: |
405 | 406 | """ |
406 | | - Parses a Triton kernel's source code to distinguish positional args |
| 407 | + Parses a Triton kernel's source code using AST to distinguish positional args |
407 | 408 | from keyword args (those with default values). |
| 409 | +
|
| 410 | + This implementation uses Python's ast module for robust parsing that handles: |
| 411 | + - Return type annotations (e.g., -> None) |
| 412 | + - Complex type annotations (e.g., Callable[[dict[str, int]], list[Tensor]]) |
| 413 | + - Decorators (e.g., @triton.jit) |
| 414 | + - Keyword-only arguments (after *) |
| 415 | + - All Python syntax variations |
| 416 | +
|
| 417 | + Args: |
| 418 | + kernel_source_code: Python source code containing the kernel function |
| 419 | +
|
| 420 | + Returns: |
| 421 | + tuple[list[str], list[str]]: (positional_args, keyword_args) |
| 422 | +
|
| 423 | + Raises: |
| 424 | + ValueError: If parsing fails or no function definition is found |
408 | 425 | """ |
409 | | - signature_lines = [] |
410 | | - in_signature = False |
411 | | - for line in kernel_source_code.splitlines(): |
412 | | - # Mark beginning of signature when function definition is found |
413 | | - if line.strip().startswith("def "): |
414 | | - in_signature = True |
415 | | - if in_signature: |
416 | | - # Strip comments and leading/trailing whitespace |
417 | | - clean_line = line.split("#")[0].strip() |
418 | | - signature_lines.append(clean_line) |
419 | | - # Stop capturing after the signature ends |
420 | | - if "):" in line: |
| 426 | + try: |
| 427 | + # Parse source code into AST |
| 428 | + tree = ast.parse(kernel_source_code) |
| 429 | + |
| 430 | + # Find the first function definition |
| 431 | + func_def = None |
| 432 | + for node in ast.walk(tree): |
| 433 | + if isinstance(node, ast.FunctionDef): |
| 434 | + func_def = node |
421 | 435 | break |
422 | 436 |
|
423 | | - full_signature = "".join(signature_lines) |
424 | | - # Extract content between the first '(' and the last '):' |
425 | | - try: |
426 | | - params_str = full_signature[ |
427 | | - full_signature.find("(") + 1 : full_signature.rfind("):") |
428 | | - ] |
429 | | - except IndexError as exc: |
430 | | - raise ValueError("Could not parse kernel signature.") from exc |
431 | | - |
432 | | - # Clean up and split the parameters string |
433 | | - params = [p.strip() for p in params_str.replace("\n", "").split(",") if p.strip()] |
434 | | - |
435 | | - positional_args = [] |
436 | | - keyword_args = [] |
437 | | - |
438 | | - for param in params: |
439 | | - if "=" in param: |
440 | | - # Keyword arguments have a default value |
441 | | - arg_name = param.split("=")[0].strip() |
442 | | - keyword_args.append(arg_name) |
443 | | - else: |
444 | | - # Positional arguments do not have a default value |
445 | | - arg_name = param.split(":")[0].strip() |
446 | | - positional_args.append(arg_name) |
| 437 | + if not func_def: |
| 438 | + raise ValueError("No function definition found in source code") |
| 439 | + |
| 440 | + positional_args = [] |
| 441 | + keyword_args = [] |
447 | 442 |
|
448 | | - logger.debug("Parsed positional args: %s", positional_args) |
449 | | - logger.debug("Parsed keyword args: %s", keyword_args) |
450 | | - return positional_args, keyword_args |
| 443 | + # Extract function arguments |
| 444 | + args = func_def.args |
| 445 | + |
| 446 | + # Calculate number of positional arguments |
| 447 | + # defaults are right-aligned with args, so: |
| 448 | + # num_positional = total_args - num_defaults |
| 449 | + num_defaults = len(args.defaults) |
| 450 | + num_args = len(args.args) |
| 451 | + num_positional = num_args - num_defaults |
| 452 | + |
| 453 | + # Classify regular arguments |
| 454 | + for i, arg in enumerate(args.args): |
| 455 | + arg_name = arg.arg |
| 456 | + if i < num_positional: |
| 457 | + positional_args.append(arg_name) |
| 458 | + else: |
| 459 | + keyword_args.append(arg_name) |
| 460 | + |
| 461 | + # Handle keyword-only arguments (after *) |
| 462 | + for arg in args.kwonlyargs: |
| 463 | + keyword_args.append(arg.arg) |
| 464 | + |
| 465 | + logger.debug("Parsed positional args: %s", positional_args) |
| 466 | + logger.debug("Parsed keyword args: %s", keyword_args) |
| 467 | + return positional_args, keyword_args |
| 468 | + |
| 469 | + except SyntaxError as e: |
| 470 | + raise ValueError( |
| 471 | + f"Invalid Python syntax in kernel source at line {e.lineno}: {e.msg}" |
| 472 | + ) from e |
| 473 | + except Exception as e: |
| 474 | + raise ValueError(f"Failed to parse kernel signature: {e}") from e |
451 | 475 |
|
452 | 476 |
|
453 | 477 | def _generate_invocation_snippet( |
|
0 commit comments