Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions python/CuTeDSL/cutlass/base_dsl/ast_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,9 +304,21 @@ def transform_function(self, func_name, function_pointer):
return []

# Step 1. Parse the given function
file_name = inspect.getsourcefile(function_pointer)
lines, start_line = inspect.getsourcelines(function_pointer)
dedented_source = textwrap.dedent("".join(lines))
try:
file_name = inspect.getsourcefile(function_pointer)
lines, start_line = inspect.getsourcelines(function_pointer)
dedented_source = textwrap.dedent("".join(lines))
except (OSError, TypeError):
# Handle interactive mode (e.g., Python shell, Jupyter) where source is not available
# In this case, we cannot preprocess the function
log().warning(
"Cannot get source code for function [%s]. "
"Interactive mode (Python shell/Jupyter) is not supported for preprocessing. "
"Please define functions in a .py file instead.",
func_name
)
return []

tree = ast.parse(dedented_source, filename=file_name)
# Bump the line numbers so they match the real source file
ast.increment_lineno(tree, start_line - 1)
Expand Down Expand Up @@ -446,7 +458,12 @@ def transform(self, original_function, exec_globals):
"""
Transforms the provided function using the preprocessor.
"""
self.file_name = inspect.getsourcefile(original_function)
try:
self.file_name = inspect.getsourcefile(original_function)
except (OSError, TypeError):
# Interactive mode - use placeholder filename
self.file_name = "<interactive>"

self.function_globals = exec_globals
transformed_tree = self.transform_function(
original_function.__name__, original_function
Expand Down
7 changes: 6 additions & 1 deletion python/CuTeDSL/cutlass/base_dsl/dsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1259,7 +1259,12 @@ def run_preprocessor(self, funcBody):
return None

def get_function_ptr(self, original_function):
file_name = inspect.getsourcefile(original_function)
try:
file_name = inspect.getsourcefile(original_function)
except (OSError, TypeError):
# Interactive mode - use placeholder filename
file_name = "<interactive>"

code_object = compile(
original_function._transformed_ast, filename=file_name, mode="exec"
)
Expand Down