31 changes: 19 additions & 12 deletions build_tools/cmake/TorchMLIRPyTorch.cmake
@@ -39,7 +39,10 @@ endfunction()
 # Separately, pybind11 keeps an internal variable which records its ABI info
 # (PYBIND11_INTERNALS_ID in include/pybind11/detail/internals.h). Differences
 # in this variable between torch-mlir and PyTorch will cause type errors.
-# Thus, our best option is to:
+# Note: as of version 2.9.0.dev20250826, torch has updated to pybind11 ver 3.0.
+# This simplifies compatibility considerably. For reference, see
+# https://github.com/pybind/pybind11/pull/5439
+# For pre-version 3.0 pybind11, our best option is to:
 # a) Identify which ABI version PyTorch was compiled with
 # b) Tell gcc to use that version
 # or
@@ -70,23 +73,27 @@ function(TorchMLIRConfigurePyTorch)
     # Check ABI compatibility version
     execute_process(
       COMMAND ${Python3_EXECUTABLE}
-      -c "import torch; import sys; abi=torch._C._PYBIND11_BUILD_ABI; abi.startswith('_cxxabi10') or sys.exit(1); sys.stdout.write(str(abi[-2:]))"
+      -c "import torch; import sys; abi=getattr(torch._C, '_PYBIND11_BUILD_ABI', '-1'); abi=='-1' or abi.startswith('_cxxabi10') or sys.exit(1); sys.stdout.write(str(abi[-2:]))"
       RESULT_VARIABLE _result
       WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
       OUTPUT_VARIABLE _cxx_abi_version)
     if(_result)
       message(FATAL_ERROR "Failed to determine C++ ABI version")
-    endif()
-    message(STATUS "PyTorch C++ ABI version: \"${_cxx_abi_version}\"")
-
-    # Specialize compile flags for compiler
-    if(${CMAKE_CXX_COMPILER_ID} STREQUAL "GNU")
-      set(TORCH_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=${_use_cxx11_abi} -fabi-version=${_cxx_abi_version}")
-    elseif(${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang")
-      set(TORCH_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=${_use_cxx11_abi} -U__GXX_ABI_VERSION -D__GXX_ABI_VERSION=10${_cxx_abi_version} '-DPYBIND11_COMPILER_TYPE=\"_gcc\"'")
+    elseif(${_cxx_abi_version} STREQUAL "-1")
+      message(STATUS "Could not find `torch._C._PYBIND_BUILD_ABI`. This was removed in torch 2.9.0 (as of nightly release: dev20250826), and the TORCH_CXX_FLAGS manipulation is no longer required.")
+      # Everyone involved should be using cxx11 abi by default, but specify this just in case.
+      set(TORCH_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=${_use_cxx11_abi}")
     else()
-      message(WARNING "Unrecognized compiler. Cannot determine ABI flags.")
-      return()
+      message(STATUS "PyTorch C++ ABI version: \"${_cxx_abi_version}\"")
+      # Specialize compile flags for compiler
+      if(${CMAKE_CXX_COMPILER_ID} STREQUAL "GNU")
+        set(TORCH_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=${_use_cxx11_abi} -fabi-version=${_cxx_abi_version}")
+      elseif(${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang")
+        set(TORCH_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=${_use_cxx11_abi} -U__GXX_ABI_VERSION -D__GXX_ABI_VERSION=10${_cxx_abi_version} '-DPYBIND11_COMPILER_TYPE=\"_gcc\"'")
+      else()
+        message(WARNING "Unrecognized compiler. Cannot determine ABI flags.")
+        return()
+      endif()
     endif()
     set(TORCH_CXXFLAGS "${TORCH_CXXFLAGS}" PARENT_SCOPE)
   endif()
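The updated one-liner is easier to follow expanded. Below is a minimal standalone sketch of the same probe, assuming only an importable `torch`; the branch behavior mirrors the new CMake logic above.

```python
# Standalone form of the probe in the execute_process() call above. On
# torch >= 2.9.0.dev20250826 (which bundles pybind11 3.0), torch._C no
# longer exposes _PYBIND11_BUILD_ABI, so getattr falls back to "-1" and
# the ABI flag manipulation can be skipped.
import sys
import torch

abi = getattr(torch._C, "_PYBIND11_BUILD_ABI", "-1")
if abi == "-1":
    print("no _PYBIND11_BUILD_ABI: TORCH_CXX_FLAGS manipulation not required")
elif abi.startswith("_cxxabi10"):
    # e.g. "_cxxabi1016" -> "16", which CMake turns into -fabi-version=16
    print("gcc ABI version:", abi[-2:])
else:
    sys.exit(1)  # unexpected tag: fail configure, as the one-liner does
```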
41 changes: 6 additions & 35 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -15960,9 +15960,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.linalg_vector_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.optional<list<int>>, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.int {\n"
" %int6 = torch.constant.int 6\n"
" %int15 = torch.constant.int 15\n"
" %int5 = torch.constant.int 5\n"
" %true = torch.constant.bool true\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
@@ -16011,22 +16008,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" torch.prim.If.yield %9 : !torch.int\n"
" } else {\n"
" %5 = torch.prim.ListConstruct %int5, %int15 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %6 = torch.aten.__contains__.int_list %5, %0#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
" %7 = torch.prim.If %6 -> (!torch.int) {\n"
" torch.prim.If.yield %int6 : !torch.int\n"
" } else {\n"
" %8 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
" torch.prim.If.yield %8 : !torch.int\n"
" }\n"
" torch.prim.If.yield %7 : !torch.int\n"
" %5 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
" torch.prim.If.yield %5 : !torch.int\n"
" }\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.linalg_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<number>, %arg2: !torch.optional<list<int>>, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.int {\n"
" %int6 = torch.constant.int 6\n"
" %int15 = torch.constant.int 15\n"
" %int5 = torch.constant.int 5\n"
" %true = torch.constant.bool true\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
@@ -16075,15 +16062,8 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" torch.prim.If.yield %9 : !torch.int\n"
" } else {\n"
" %5 = torch.prim.ListConstruct %int5, %int15 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %6 = torch.aten.__contains__.int_list %5, %0#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
" %7 = torch.prim.If %6 -> (!torch.int) {\n"
" torch.prim.If.yield %int6 : !torch.int\n"
" } else {\n"
" %8 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
" torch.prim.If.yield %8 : !torch.int\n"
" }\n"
" torch.prim.If.yield %7 : !torch.int\n"
" %5 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
" torch.prim.If.yield %5 : !torch.int\n"
" }\n"
" return %4 : !torch.int\n"
" }\n"
@@ -16107,8 +16087,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.norm.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
" %true = torch.constant.bool true\n"
" %int6 = torch.constant.int 6\n"
" %int15 = torch.constant.int 15\n"
" %int5 = torch.constant.int 5\n"
" %int8 = torch.constant.int 8\n"
" %none = torch.constant.none\n"
@@ -16126,15 +16104,8 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = torch.prim.If %3 -> (!torch.int) {\n"
" torch.prim.If.yield %int5 : !torch.int\n"
" } else {\n"
" %5 = torch.prim.ListConstruct %int5, %int15 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %6 = torch.aten.__contains__.int_list %5, %0#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
" %7 = torch.prim.If %6 -> (!torch.int) {\n"
" torch.prim.If.yield %int6 : !torch.int\n"
" } else {\n"
" %8 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
" torch.prim.If.yield %8 : !torch.int\n"
" }\n"
" torch.prim.If.yield %7 : !torch.int\n"
" %5 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
" torch.prim.If.yield %5 : !torch.int\n"
" }\n"
" return %4 : !torch.int\n"
" }\n"
@@ -5544,8 +5544,6 @@ def aten〇linalg_vector_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Uni
             return aten〇std〡dtype((self_rank, dtype))
         assert not is_complex_dtype(dtype)
         return dtype
-    if self_dtype in [torch.float16, torch.bfloat16]:
-        return torch.float32
     return aten〇std〡dtype(self_rank_dtype)
 
 @check_dtype_function(
@@ -5569,8 +5567,6 @@ def aten〇linalg_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Optional[U
             return aten〇std〡dtype((self_rank, dtype))
         assert not is_complex_dtype(dtype)
         return dtype
-    if self_dtype in [torch.float16, torch.bfloat16]:
-        return torch.float32
     return aten〇std〡dtype(self_rank_dtype)
 
 def aten〇binary_cross_entropy_with_logits〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]] = None, pos_weight_rank_dtype: Optional[Tuple[int, int]] = None, reduction: int = 1) -> int:
@@ -5604,8 +5600,6 @@ def aten〇norm〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], p: Union[int,
     # Should possibly be added to aten〇std〡dtype.
     if self_dtype == torch.complex32:
         return torch.half
-    if self_dtype in [torch.float16, torch.bfloat16]:
-        return torch.float32
     return aten〇std〡dtype(self_rank_dtype)
 
 @check_dtype_function([Invocation(0.0),
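All three deletions in this Python section mirror the regenerated C++ above: half and bfloat16 inputs are no longer special-cased to promote to float32 in these dtype functions. A quick eager-mode sanity check of the semantics this assumes follows; the expected dtypes are assumptions about the pinned nightly, not guarantees on older releases.

```python
# Assumed post-change semantics: norm ops preserve reduced-precision input
# dtypes rather than promoting them to float32.
import torch

x16 = torch.randn(8, dtype=torch.float16)
xbf = torch.randn(8, dtype=torch.bfloat16)

print(torch.linalg.vector_norm(x16).dtype)  # expected: torch.float16
print(torch.linalg.norm(xbf).dtype)         # expected: torch.bfloat16
print(torch.norm(x16, p=2).dtype)           # expected: torch.float16
```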
@@ -149,9 +149,12 @@ def _export_run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
         )
         module = self._backend.compile(module)
         backend_module = self._backend.load(module)
+        input_buffers = prog.graph_signature.inputs_to_buffers.values()
         params = {
             # **dict(artifact.named_parameters(remove_duplicate=False)),
-            **dict(artifact.named_buffers(remove_duplicate=False)),
+            name: value
+            for (name, value) in artifact.named_buffers(remove_duplicate=False)
+            if name in input_buffers
         }
         params_flat, params_spec = pytree.tree_flatten(params)
         params_flat = list(params_flat)
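The rewritten dict comprehension keeps only the buffers that the exported graph actually takes as inputs. A small sketch of the `graph_signature` field it relies on, with an illustrative module that is not from the test suite:

```python
# inputs_to_buffers maps each lifted graph input name to a buffer FQN.
# Buffers the traced graph never reads may be absent from this mapping,
# so passing every named_buffers() entry to the runner can over-supply
# arguments; filtering by the signature keeps the two lists aligned.
import torch

class Scaler(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("scale", torch.full((3,), 2.0))

    def forward(self, x):
        return x * self.scale

artifact = Scaler()
prog = torch.export.export(artifact, (torch.randn(3),))
input_buffers = prog.graph_signature.inputs_to_buffers.values()
params = {
    name: value
    for (name, value) in artifact.named_buffers(remove_duplicate=False)
    if name in input_buffers
}
print(sorted(params))  # e.g. ['scale']
```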
@@ -85,6 +85,7 @@ def convert_onnx(model, inputs):
         input_names=input_names,
         dynamic_axes=dynamic_tensors,
         opset_version=max_opset_ver,
+        dynamo=False,
     )
     buffer = buffer.getvalue()
     return import_onnx(buffer)
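The one added keyword pins the exporter choice: recent nightlies flip the default of `dynamo` in `torch.onnx.export` to True, while this harness is written against the legacy TorchScript-based exporter. A minimal sketch of the pinned call, where the model and opset are placeholders:

```python
# dynamo=False selects the legacy TorchScript-based ONNX exporter rather
# than the newer torch.export/dynamo-based one that nightlies default to.
import io
import torch

model = torch.nn.Linear(4, 2)
buffer = io.BytesIO()
torch.onnx.export(
    model,
    (torch.randn(1, 4),),
    buffer,
    opset_version=17,
    dynamo=False,
)
print(f"exported {len(buffer.getvalue())} bytes of ONNX")
```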
2 changes: 1 addition & 1 deletion pytorch-hash.txt
@@ -1 +1 @@
-7956a1d1d0dc7cdaaaa42d0863eebb1b1e75eb65
+0dfcb1a118dd45c544a156e1d86566368e528e69
2 changes: 1 addition & 1 deletion pytorch-requirements.txt
@@ -1,3 +1,3 @@
 -f https://download.pytorch.org/whl/nightly/cpu/torch/
 --pre
-torch==2.9.0.dev20250820
+torch==2.10.0.dev20251016
2 changes: 1 addition & 1 deletion torchvision-requirements.txt
@@ -1,3 +1,3 @@
 -f https://download.pytorch.org/whl/nightly/cpu/torchvision/
 --pre
-torchvision==0.24.0.dev20250820
+torchvision==0.25.0.dev20251016