From e3816f0b8e589f45e0208779319ee129e8941c79 Mon Sep 17 00:00:00 2001
From: zjgarvey
Date: Mon, 20 Oct 2025 15:49:51 -0700
Subject: [PATCH 1/7] Don't customize torch cxx flags when building jit ir
 importer.

The old path is still left in case we need to build against older
versions of pytorch.

Signed-off-by: zjgarvey
---
 build_tools/cmake/TorchMLIRPyTorch.cmake | 25 ++++++++++++------------
 1 file changed, 13 insertions(+), 12 deletions(-)

diff --git a/build_tools/cmake/TorchMLIRPyTorch.cmake b/build_tools/cmake/TorchMLIRPyTorch.cmake
index 53253c8c7e14..0bc313ce8adc 100644
--- a/build_tools/cmake/TorchMLIRPyTorch.cmake
+++ b/build_tools/cmake/TorchMLIRPyTorch.cmake
@@ -75,20 +75,21 @@ function(TorchMLIRConfigurePyTorch)
       WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
       OUTPUT_VARIABLE _cxx_abi_version)
     if(_result)
-      message(FATAL_ERROR "Failed to determine C++ ABI version")
-    endif()
-    message(STATUS "PyTorch C++ ABI version: \"${_cxx_abi_version}\"")
-
-    # Specialize compile flags for compiler
-    if(${CMAKE_CXX_COMPILER_ID} STREQUAL "GNU")
-      set(TORCH_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=${_use_cxx11_abi} -fabi-version=${_cxx_abi_version}")
-    elseif(${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang")
-      set(TORCH_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=${_use_cxx11_abi} -U__GXX_ABI_VERSION -D__GXX_ABI_VERSION=10${_cxx_abi_version} '-DPYBIND11_COMPILER_TYPE=\"_gcc\"'")
+      message(WARNING "Could not infer torch._C._PYBIND_BUILD_ABI. This was removed after pytorch updated to pybind11 version 3.0.1, and the TORCH_CXX_FLAGS manipulation is no longer required.")
     else()
-      message(WARNING "Unrecognized compiler. Cannot determine ABI flags.")
-      return()
+      message(STATUS "PyTorch C++ ABI version: \"${_cxx_abi_version}\"")
+
+      # Specialize compile flags for compiler
+      if(${CMAKE_CXX_COMPILER_ID} STREQUAL "GNU")
+        set(TORCH_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=${_use_cxx11_abi} -fabi-version=${_cxx_abi_version}")
+      elseif(${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang")
+        set(TORCH_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=${_use_cxx11_abi} -U__GXX_ABI_VERSION -D__GXX_ABI_VERSION=10${_cxx_abi_version} '-DPYBIND11_COMPILER_TYPE=\"_gcc\"'")
+      else()
+        message(WARNING "Unrecognized compiler. Cannot determine ABI flags.")
+        return()
+      endif()
+      set(TORCH_CXXFLAGS "${TORCH_CXXFLAGS}" PARENT_SCOPE)
     endif()
-    set(TORCH_CXXFLAGS "${TORCH_CXXFLAGS}" PARENT_SCOPE)
   endif()
 endfunction()
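Note on PATCH 1/7: the probe that computes `_cxx_abi_version` reads
`torch._C._PYBIND11_BUILD_ABI`, and torch builds that bundle pybind11 3.x no
longer expose that attribute, so a failed probe is downgraded from a fatal
configure error to a warning. A minimal sketch of how to see what a given
torch build reports, assuming `torch` is importable in the configure-time
interpreter:

```python
# Prints the torch version and whether the legacy pybind11 ABI tag exists.
# On builds where it is missing (pybind11 >= 3.0), the cmake path above now
# just warns and skips the TORCH_CXXFLAGS specialization.
import torch

print(torch.__version__)
print(hasattr(torch._C, "_PYBIND11_BUILD_ABI"))
```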
From 136b91669e04104ddbb25393ad20c85d7c1e7b03 Mon Sep 17 00:00:00 2001
From: zjgarvey
Date: Mon, 20 Oct 2025 15:52:27 -0700
Subject: [PATCH 2/7] Update nightly pins

Signed-off-by: zjgarvey
---
 pytorch-hash.txt             | 2 +-
 pytorch-requirements.txt     | 2 +-
 torchvision-requirements.txt | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/pytorch-hash.txt b/pytorch-hash.txt
index b1886c1abddd..582695ddc3c7 100644
--- a/pytorch-hash.txt
+++ b/pytorch-hash.txt
@@ -1 +1 @@
-7956a1d1d0dc7cdaaaa42d0863eebb1b1e75eb65
+0dfcb1a118dd45c544a156e1d86566368e528e69
diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt
index 87cbf28f5a98..ac81781a6b2b 100644
--- a/pytorch-requirements.txt
+++ b/pytorch-requirements.txt
@@ -1,3 +1,3 @@
 -f https://download.pytorch.org/whl/nightly/cpu/torch/
 --pre
-torch==2.9.0.dev20250820
+torch==2.10.0.dev20251016
diff --git a/torchvision-requirements.txt b/torchvision-requirements.txt
index 68c96010c96f..546bfb138e43 100644
--- a/torchvision-requirements.txt
+++ b/torchvision-requirements.txt
@@ -1,3 +1,3 @@
 -f https://download.pytorch.org/whl/nightly/cpu/torchvision/
 --pre
-torchvision==0.24.0.dev20250820
+torchvision==0.25.0.dev20251016

From 8ebb0e26a5f088296a680f1d822e8c32b4bd8edc Mon Sep 17 00:00:00 2001
From: zjgarvey
Date: Mon, 20 Oct 2025 15:55:10 -0700
Subject: [PATCH 3/7] fix fx_importer input signature

Signed-off-by: zjgarvey
---
 .../torch_mlir_e2e_test/configs/fx_importer_backend.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py
index 396d43638a42..a116a94dabd3 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py
@@ -149,9 +149,12 @@ def _export_run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
         )
         module = self._backend.compile(module)
         backend_module = self._backend.load(module)
+        input_buffers = prog.graph_signature.inputs_to_buffers.values()
         params = {
             # **dict(artifact.named_parameters(remove_duplicate=False)),
-            **dict(artifact.named_buffers(remove_duplicate=False)),
+            name: value
+            for (name, value) in artifact.named_buffers(remove_duplicate=False)
+            if name in input_buffers
         }
         params_flat, params_spec = pytree.tree_flatten(params)
         params_flat = list(params_flat)

From 3143abd63eb8b54c8adf15205ec2baac4b52be2d Mon Sep 17 00:00:00 2001
From: zjgarvey
Date: Mon, 20 Oct 2025 15:57:42 -0700
Subject: [PATCH 4/7] Set dynamo=False explicitly in onnx backend

Signed-off-by: zjgarvey
---
 projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py
index 5461dc04c0d1..7bead331ea04 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py
@@ -85,6 +85,7 @@ def convert_onnx(model, inputs):
         input_names=input_names,
         dynamic_axes=dynamic_tensors,
         opset_version=max_opset_ver,
+        dynamo=False,
     )
     buffer = buffer.getvalue()
     return import_onnx(buffer)
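Note on PATCH 3/7: newer `torch.export` may lift only the buffers the traced
graph actually reads into the program's input signature, so handing every
`named_buffers()` entry to the compiled module can mismatch the expected
argument list; the fix keys the params dict off
`graph_signature.inputs_to_buffers`. A self-contained sketch of that
filtering, with a made-up module for illustration:

```python
# Demonstrates filtering named buffers against the export graph signature.
# The module and buffer names are hypothetical; whether an unused buffer is
# lifted into the signature depends on the installed torch version.
import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("used", torch.ones(3))
        self.register_buffer("unused", torch.zeros(3))

    def forward(self, x):
        return x + self.used

m = M()
prog = torch.export.export(m, (torch.randn(3),))
# Maps lifted graph-input names to buffer FQNs; only these buffers belong in
# the flattened argument list handed to the compiled module.
input_buffers = prog.graph_signature.inputs_to_buffers.values()
params = {
    name: value
    for name, value in m.named_buffers(remove_duplicate=False)
    if name in input_buffers
}
print(sorted(params))  # expect only the buffers the graph consumes
```

PATCH 4/7 is related housekeeping: newer torch nightlies steer
`torch.onnx.export` toward the dynamo-based exporter, and `dynamo=False` pins
the e2e config to the legacy TorchScript export path it was written against.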
From c7bfa8f9498ac31b9dc57479a449d3b90f086128 Mon Sep 17 00:00:00 2001
From: zjgarvey
Date: Mon, 20 Oct 2025 18:57:53 -0700
Subject: [PATCH 5/7] Update interp lib

Signed-off-by: zjgarvey
---
 .../Transforms/AbstractInterpLibrary.cpp      | 41 +++++----------
 .../build_tools/abstract_interp_lib_gen.py    |  6 ---
 2 files changed, 6 insertions(+), 41 deletions(-)

diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index bb7572395ba3..1fd3141637b4 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -15960,9 +15960,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    return %0 : !torch.int\n"
 "  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.linalg_vector_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.optional<list<int>>, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.int {\n"
-"    %int6 = torch.constant.int 6\n"
-"    %int15 = torch.constant.int 15\n"
-"    %int5 = torch.constant.int 5\n"
 "    %true = torch.constant.bool true\n"
 "    %none = torch.constant.none\n"
 "    %str = torch.constant.str \"AssertionError: \"\n"
@@ -16011,22 +16008,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "      }\n"
 "      torch.prim.If.yield %9 : !torch.int\n"
 "    } else {\n"
-"      %5 = torch.prim.ListConstruct %int5, %int15 : (!torch.int, !torch.int) -> !torch.list<int>\n"
-"      %6 = torch.aten.__contains__.int_list %5, %0#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
-"      %7 = torch.prim.If %6 -> (!torch.int) {\n"
-"        torch.prim.If.yield %int6 : !torch.int\n"
-"      } else {\n"
-"        %8 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
-"        torch.prim.If.yield %8 : !torch.int\n"
-"      }\n"
-"      torch.prim.If.yield %7 : !torch.int\n"
+"      %5 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
+"      torch.prim.If.yield %5 : !torch.int\n"
 "    }\n"
 "    return %4 : !torch.int\n"
 "  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.linalg_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<number>, %arg2: !torch.optional<list<int>>, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.int {\n"
-"    %int6 = torch.constant.int 6\n"
-"    %int15 = torch.constant.int 15\n"
-"    %int5 = torch.constant.int 5\n"
 "    %true = torch.constant.bool true\n"
 "    %none = torch.constant.none\n"
 "    %str = torch.constant.str \"AssertionError: \"\n"
@@ -16075,15 +16062,8 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "      }\n"
 "      torch.prim.If.yield %9 : !torch.int\n"
 "    } else {\n"
-"      %5 = torch.prim.ListConstruct %int5, %int15 : (!torch.int, !torch.int) -> !torch.list<int>\n"
-"      %6 = torch.aten.__contains__.int_list %5, %0#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
-"      %7 = torch.prim.If %6 -> (!torch.int) {\n"
-"        torch.prim.If.yield %int6 : !torch.int\n"
-"      } else {\n"
-"        %8 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
-"        torch.prim.If.yield %8 : !torch.int\n"
-"      }\n"
-"      torch.prim.If.yield %7 : !torch.int\n"
+"      %5 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
+"      torch.prim.If.yield %5 : !torch.int\n"
 "    }\n"
 "    return %4 : !torch.int\n"
 "  }\n"
@@ -16107,8 +16087,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.norm.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
 "    %true = torch.constant.bool true\n"
-"    %int6 = torch.constant.int 6\n"
-"    %int15 = torch.constant.int 15\n"
 "    %int5 = torch.constant.int 5\n"
 "    %int8 = torch.constant.int 8\n"
 "    %none = torch.constant.none\n"
@@ -16126,15 +16104,8 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %4 = torch.prim.If %3 -> (!torch.int) {\n"
 "      torch.prim.If.yield %int5 : !torch.int\n"
 "    } else {\n"
-"      %5 = torch.prim.ListConstruct %int5, %int15 : (!torch.int, !torch.int) -> !torch.list<int>\n"
-"      %6 = torch.aten.__contains__.int_list %5, %0#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
-"      %7 = torch.prim.If %6 -> (!torch.int) {\n"
-"        torch.prim.If.yield %int6 : !torch.int\n"
-"      } else {\n"
-"        %8 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
-"        torch.prim.If.yield %8 : !torch.int\n"
-"      }\n"
-"      torch.prim.If.yield %7 : !torch.int\n"
+"      %5 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
+"      torch.prim.If.yield %5 : !torch.int\n"
 "    }\n"
 "    return %4 : !torch.int\n"
 "  }\n"
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index e894d49d4dd6..fc6e52b0b2a4 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -5544,8 +5544,6 @@ def aten〇linalg_vector_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Uni
             return aten〇std〡dtype((self_rank, dtype))
         assert not is_complex_dtype(dtype)
         return dtype
-    if self_dtype in [torch.float16, torch.bfloat16]:
-        return torch.float32
     return aten〇std〡dtype(self_rank_dtype)
 
 @check_dtype_function(
@@ -5569,8 +5567,6 @@ def aten〇linalg_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Optional[U
             return aten〇std〡dtype((self_rank, dtype))
         assert not is_complex_dtype(dtype)
         return dtype
-    if self_dtype in [torch.float16, torch.bfloat16]:
-        return torch.float32
     return aten〇std〡dtype(self_rank_dtype)
 
 def aten〇binary_cross_entropy_with_logits〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]] = None, pos_weight_rank_dtype: Optional[Tuple[int, int]] = None, reduction: int = 1) -> int:
@@ -5604,8 +5600,6 @@ def aten〇norm〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], p: Union[int,
     # Should possibly be added to aten〇std〡dtype.
     if self_dtype == torch.complex32:
         return torch.half
-    if self_dtype in [torch.float16, torch.bfloat16]:
-        return torch.float32
     return aten〇std〡dtype(self_rank_dtype)
 
 @check_dtype_function([Invocation(0.0),
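Note on PATCH 5/7: AbstractInterpLibrary.cpp is regenerated from the Python
dtype functions in abstract_interp_lib_gen.py, so both files change together.
The removed special case had `linalg.vector_norm`, `linalg.norm`, and
`norm.Scalar` promote fp16/bf16 results to float32; the dtype rule now defers
to `aten〇std〡dtype`, i.e. the result keeps the input dtype. A quick local
check, where the expected output dtypes are an assumption inferred from the
removed special case and should be verified against the pinned nightly:

```python
# On the new pin these norms are expected to return the input dtype for
# half-precision inputs rather than promoting to float32; verify locally.
import torch

for dt in (torch.float16, torch.bfloat16):
    out = torch.linalg.vector_norm(torch.ones(4, dtype=dt))
    print(dt, "->", out.dtype)
```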
From 57d91aa3905e7538745e2c3dd90ffbb0ec06b823 Mon Sep 17 00:00:00 2001
From: zjgarvey
Date: Tue, 21 Oct 2025 14:43:58 -0700
Subject: [PATCH 6/7] Remove some configure warnings, and update comments.

This change also more accurately screens for when `torch._C` doesn't
have the attribute `_PYBIND11_BUILD_ABI`.

Signed-off-by: zjgarvey
---
 build_tools/cmake/TorchMLIRPyTorch.cmake | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/build_tools/cmake/TorchMLIRPyTorch.cmake b/build_tools/cmake/TorchMLIRPyTorch.cmake
index 0bc313ce8adc..35b8be5a6438 100644
--- a/build_tools/cmake/TorchMLIRPyTorch.cmake
+++ b/build_tools/cmake/TorchMLIRPyTorch.cmake
@@ -39,7 +39,10 @@ endfunction()
 # Separately, pybind11 keeps an internal variable which records its ABI info
 # (PYBIND11_INTERNALS_ID in include/pybind11/detail/internals.h). Differences
 # in this variable between torch-mlir and PyTorch will cause type errors.
-# Thus, our best option is to:
+# Note: as of version 2.9.0.dev20250826, torch has updated to pybind11 ver 3.0.
+# This simplifies compatibility considerably. For reference, see 
+# https://github.com/pybind/pybind11/pull/5439
+# For pre-version 3.0 pybind11, our best option is to:
 # a) Identify which ABI version PyTorch was compiled with
 # b) Tell gcc to use that version
 # or
@@ -70,15 +73,18 @@ function(TorchMLIRConfigurePyTorch)
     # Check ABI compatibility version
     execute_process(
       COMMAND ${Python3_EXECUTABLE}
-      -c "import torch; import sys; abi=torch._C._PYBIND11_BUILD_ABI; abi.startswith('_cxxabi10') or sys.exit(1); sys.stdout.write(str(abi[-2:]))"
+      -c "import torch; import sys; abi=getattr(torch._C, '_PYBIND11_BUILD_ABI', '-1'); abi=='-1' or abi.startswith('_cxxabi10') or sys.exit(1); sys.stdout.write(str(abi[-2:]))"
       RESULT_VARIABLE _result
       WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
       OUTPUT_VARIABLE _cxx_abi_version)
     if(_result)
-      message(WARNING "Could not infer torch._C._PYBIND_BUILD_ABI. This was removed after pytorch updated to pybind11 version 3.0.1, and the TORCH_CXX_FLAGS manipulation is no longer required.")
+      message(FATAL_ERROR "Failed to determine C++ ABI version")
+    elseif(${_cxx_abi_version} STREQUAL "-1")
+      message(STATUS "Could not find `torch._C._PYBIND_BUILD_ABI`. This was removed in torch 2.9.0 (as of nightly release: dev20250826), and the TORCH_CXX_FLAGS manipulation is no longer required.")
+      # Everyone involved should be using cxx11 abi by default, but specify this just in case.
+      set(TORCH_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=${_use_cxx11_abi}")
     else()
       message(STATUS "PyTorch C++ ABI version: \"${_cxx_abi_version}\"")
-
       # Specialize compile flags for compiler
       if(${CMAKE_CXX_COMPILER_ID} STREQUAL "GNU")
         set(TORCH_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=${_use_cxx11_abi} -fabi-version=${_cxx_abi_version}")
@@ -88,8 +94,8 @@ function(TorchMLIRConfigurePyTorch)
         message(WARNING "Unrecognized compiler. Cannot determine ABI flags.")
         return()
       endif()
-      set(TORCH_CXXFLAGS "${TORCH_CXXFLAGS}" PARENT_SCOPE)
     endif()
+    set(TORCH_CXXFLAGS "${TORCH_CXXFLAGS}" PARENT_SCOPE)
   endif()
 endfunction()
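Note on PATCH 6/7: the probe now distinguishes "attribute missing" (sentinel
`-1`, expected on torch >= 2.9 nightlies) from "attribute present but
malformed" (still fatal). An expanded, multi-line equivalent of the `-c`
one-liner above, assuming `torch` is importable:

```python
# Mirrors the updated execute_process() probe: emit "-1" when the attribute
# is gone so CMake can take the new STATUS branch, fail only on an
# unrecognized tag, and otherwise emit the two-digit ABI version.
import sys

import torch

abi = getattr(torch._C, "_PYBIND11_BUILD_ABI", "-1")
if abi != "-1" and not abi.startswith("_cxxabi10"):
    sys.exit(1)  # unexpected tag format -> FATAL_ERROR in cmake
sys.stdout.write(str(abi[-2:]))  # "-1" sentinel passes through unchanged
```

Note that `abi[-2:]` of the sentinel is `-1` itself, which is exactly what the
`elseif(${_cxx_abi_version} STREQUAL "-1")` branch matches.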
From 310d2f7b369a0478817d0d030c23afc2802aed85 Mon Sep 17 00:00:00 2001
From: zjgarvey
Date: Tue, 21 Oct 2025 14:46:55 -0700
Subject: [PATCH 7/7] lint

Signed-off-by: zjgarvey
---
 build_tools/cmake/TorchMLIRPyTorch.cmake | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/build_tools/cmake/TorchMLIRPyTorch.cmake b/build_tools/cmake/TorchMLIRPyTorch.cmake
index 35b8be5a6438..751c72b8efaf 100644
--- a/build_tools/cmake/TorchMLIRPyTorch.cmake
+++ b/build_tools/cmake/TorchMLIRPyTorch.cmake
@@ -40,7 +40,7 @@ endfunction()
 # (PYBIND11_INTERNALS_ID in include/pybind11/detail/internals.h). Differences
 # in this variable between torch-mlir and PyTorch will cause type errors.
 # Note: as of version 2.9.0.dev20250826, torch has updated to pybind11 ver 3.0.
-# This simplifies compatibility considerably. For reference, see 
+# This simplifies compatibility considerably. For reference, see
 # https://github.com/pybind/pybind11/pull/5439
 # For pre-version 3.0 pybind11, our best option is to:
 # a) Identify which ABI version PyTorch was compiled with