diff --git a/python/CuTeDSL/cutlass/base_dsl/ast_preprocessor.py b/python/CuTeDSL/cutlass/base_dsl/ast_preprocessor.py index 11f2d1ae84..dd1b4b7311 100644 --- a/python/CuTeDSL/cutlass/base_dsl/ast_preprocessor.py +++ b/python/CuTeDSL/cutlass/base_dsl/ast_preprocessor.py @@ -304,9 +304,21 @@ def transform_function(self, func_name, function_pointer): return [] # Step 1. Parse the given function - file_name = inspect.getsourcefile(function_pointer) - lines, start_line = inspect.getsourcelines(function_pointer) - dedented_source = textwrap.dedent("".join(lines)) + try: + file_name = inspect.getsourcefile(function_pointer) + lines, start_line = inspect.getsourcelines(function_pointer) + dedented_source = textwrap.dedent("".join(lines)) + except (OSError, TypeError): + # Handle interactive mode (e.g., Python shell, Jupyter) where source is not available + # In this case, we cannot preprocess the function + log().warning( + "Cannot get source code for function [%s]. " + "Interactive mode (Python shell/Jupyter) is not supported for preprocessing. " + "Please define functions in a .py file instead.", + func_name + ) + return [] + tree = ast.parse(dedented_source, filename=file_name) # Bump the line numbers so they match the real source file ast.increment_lineno(tree, start_line - 1) @@ -446,7 +458,12 @@ def transform(self, original_function, exec_globals): """ Transforms the provided function using the preprocessor. """ - self.file_name = inspect.getsourcefile(original_function) + try: + self.file_name = inspect.getsourcefile(original_function) + except (OSError, TypeError): + # Interactive mode - use placeholder filename + self.file_name = "" + self.function_globals = exec_globals transformed_tree = self.transform_function( original_function.__name__, original_function diff --git a/python/CuTeDSL/cutlass/base_dsl/dsl.py b/python/CuTeDSL/cutlass/base_dsl/dsl.py index 2b17d22b1e..7cf0f7579c 100644 --- a/python/CuTeDSL/cutlass/base_dsl/dsl.py +++ b/python/CuTeDSL/cutlass/base_dsl/dsl.py @@ -1259,7 +1259,12 @@ def run_preprocessor(self, funcBody): return None def get_function_ptr(self, original_function): - file_name = inspect.getsourcefile(original_function) + try: + file_name = inspect.getsourcefile(original_function) + except (OSError, TypeError): + # Interactive mode - use placeholder filename + file_name = "" + code_object = compile( original_function._transformed_ast, filename=file_name, mode="exec" )