diff --git a/build_tools/cmake/TorchMLIRPyTorch.cmake b/build_tools/cmake/TorchMLIRPyTorch.cmake
index 53253c8c7e14..751c72b8efaf 100644
--- a/build_tools/cmake/TorchMLIRPyTorch.cmake
+++ b/build_tools/cmake/TorchMLIRPyTorch.cmake
@@ -39,7 +39,10 @@ endfunction()
 # Separately, pybind11 keeps an internal variable which records its ABI info
 # (PYBIND11_INTERNALS_ID in include/pybind11/detail/internals.h). Differences
 # in this variable between torch-mlir and PyTorch will cause type errors.
-# Thus, our best option is to:
+# Note: as of version 2.9.0.dev20250826, torch has updated to pybind11 version 3.0.
+# This simplifies compatibility considerably. For reference, see
+# https://github.com/pybind/pybind11/pull/5439
+# For pre-3.0 pybind11, our best option is to:
 # a) Identify which ABI version PyTorch was compiled with
 # b) Tell gcc to use that version
 # or
@@ -70,23 +73,27 @@ function(TorchMLIRConfigurePyTorch)
     # Check ABI compatibility version
     execute_process(
       COMMAND ${Python3_EXECUTABLE}
-      -c "import torch; import sys; abi=torch._C._PYBIND11_BUILD_ABI; abi.startswith('_cxxabi10') or sys.exit(1); sys.stdout.write(str(abi[-2:]))"
+      -c "import torch; import sys; abi=getattr(torch._C, '_PYBIND11_BUILD_ABI', '-1'); abi=='-1' or abi.startswith('_cxxabi10') or sys.exit(1); sys.stdout.write(str(abi[-2:]))"
       RESULT_VARIABLE _result
       WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
       OUTPUT_VARIABLE _cxx_abi_version)
     if(_result)
       message(FATAL_ERROR "Failed to determine C++ ABI version")
-    endif()
-    message(STATUS "PyTorch C++ ABI version: \"${_cxx_abi_version}\"")
-
-    # Specialize compile flags for compiler
-    if(${CMAKE_CXX_COMPILER_ID} STREQUAL "GNU")
-      set(TORCH_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=${_use_cxx11_abi} -fabi-version=${_cxx_abi_version}")
-    elseif(${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang")
-      set(TORCH_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=${_use_cxx11_abi} -U__GXX_ABI_VERSION -D__GXX_ABI_VERSION=10${_cxx_abi_version} '-DPYBIND11_COMPILER_TYPE=\"_gcc\"'")
+    elseif(${_cxx_abi_version} STREQUAL "-1")
+      message(STATUS "Could not find `torch._C._PYBIND11_BUILD_ABI`. This was removed in torch 2.9.0 (as of nightly release: dev20250826), and the TORCH_CXX_FLAGS manipulation is no longer required.")
+      # Everyone involved should be using the cxx11 ABI by default, but specify this just in case.
+      set(TORCH_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=${_use_cxx11_abi}")
     else()
-      message(WARNING "Unrecognized compiler. Cannot determine ABI flags.")
-      return()
+      message(STATUS "PyTorch C++ ABI version: \"${_cxx_abi_version}\"")
+      # Specialize compile flags for compiler
+      if(${CMAKE_CXX_COMPILER_ID} STREQUAL "GNU")
+        set(TORCH_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=${_use_cxx11_abi} -fabi-version=${_cxx_abi_version}")
+      elseif(${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang")
+        set(TORCH_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=${_use_cxx11_abi} -U__GXX_ABI_VERSION -D__GXX_ABI_VERSION=10${_cxx_abi_version} '-DPYBIND11_COMPILER_TYPE=\"_gcc\"'")
+      else()
+        message(WARNING "Unrecognized compiler. Cannot determine ABI flags.")
Cannot determine ABI flags.") + return() + endif() endif() set(TORCH_CXXFLAGS "${TORCH_CXXFLAGS}" PARENT_SCOPE) endif() diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index bb7572395ba3..1fd3141637b4 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -15960,9 +15960,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %0 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.linalg_vector_norm\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.optional>, %arg3: !torch.bool, %arg4: !torch.optional) -> !torch.int {\n" -" %int6 = torch.constant.int 6\n" -" %int15 = torch.constant.int 15\n" -" %int5 = torch.constant.int 5\n" " %true = torch.constant.bool true\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" @@ -16011,22 +16008,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " torch.prim.If.yield %9 : !torch.int\n" " } else {\n" -" %5 = torch.prim.ListConstruct %int5, %int15 : (!torch.int, !torch.int) -> !torch.list\n" -" %6 = torch.aten.__contains__.int_list %5, %0#1 : !torch.list, !torch.int -> !torch.bool\n" -" %7 = torch.prim.If %6 -> (!torch.int) {\n" -" torch.prim.If.yield %int6 : !torch.int\n" -" } else {\n" -" %8 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" -" torch.prim.If.yield %8 : !torch.int\n" -" }\n" -" torch.prim.If.yield %7 : !torch.int\n" +" %5 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" +" torch.prim.If.yield %5 : !torch.int\n" " }\n" " return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.linalg_norm\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional>, %arg3: !torch.bool, %arg4: !torch.optional) -> !torch.int {\n" -" %int6 = torch.constant.int 6\n" -" %int15 = torch.constant.int 15\n" -" %int5 = torch.constant.int 5\n" " %true = torch.constant.bool true\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" @@ -16075,15 +16062,8 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " torch.prim.If.yield %9 : !torch.int\n" " } else {\n" -" %5 = torch.prim.ListConstruct %int5, %int15 : (!torch.int, !torch.int) -> !torch.list\n" -" %6 = torch.aten.__contains__.int_list %5, %0#1 : !torch.list, !torch.int -> !torch.bool\n" -" %7 = torch.prim.If %6 -> (!torch.int) {\n" -" torch.prim.If.yield %int6 : !torch.int\n" -" } else {\n" -" %8 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" -" torch.prim.If.yield %8 : !torch.int\n" -" }\n" -" torch.prim.If.yield %7 : !torch.int\n" +" %5 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" +" torch.prim.If.yield %5 : !torch.int\n" " }\n" " return %4 : !torch.int\n" " }\n" @@ -16107,8 +16087,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.norm.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %true = torch.constant.bool true\n" -" %int6 = torch.constant.int 6\n" -" %int15 = torch.constant.int 15\n" " %int5 = torch.constant.int 5\n" " %int8 = torch.constant.int 8\n" " %none = torch.constant.none\n" @@ -16126,15 +16104,8 @@ StringRef 
 "    %4 = torch.prim.If %3 -> (!torch.int) {\n"
 "      torch.prim.If.yield %int5 : !torch.int\n"
 "    } else {\n"
-"      %5 = torch.prim.ListConstruct %int5, %int15 : (!torch.int, !torch.int) -> !torch.list<int>\n"
-"      %6 = torch.aten.__contains__.int_list %5, %0#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
-"      %7 = torch.prim.If %6 -> (!torch.int) {\n"
-"        torch.prim.If.yield %int6 : !torch.int\n"
-"      } else {\n"
-"        %8 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
-"        torch.prim.If.yield %8 : !torch.int\n"
-"      }\n"
-"      torch.prim.If.yield %7 : !torch.int\n"
+"      %5 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
+"      torch.prim.If.yield %5 : !torch.int\n"
 "    }\n"
 "    return %4 : !torch.int\n"
 "  }\n"
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index e894d49d4dd6..fc6e52b0b2a4 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -5544,8 +5544,6 @@ def aten〇linalg_vector_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Uni
             return aten〇std〡dtype((self_rank, dtype))
         assert not is_complex_dtype(dtype)
         return dtype
-    if self_dtype in [torch.float16, torch.bfloat16]:
-        return torch.float32
     return aten〇std〡dtype(self_rank_dtype)
 
 @check_dtype_function(
@@ -5569,8 +5567,6 @@ def aten〇linalg_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Optional[U
             return aten〇std〡dtype((self_rank, dtype))
         assert not is_complex_dtype(dtype)
         return dtype
-    if self_dtype in [torch.float16, torch.bfloat16]:
-        return torch.float32
     return aten〇std〡dtype(self_rank_dtype)
 
 def aten〇binary_cross_entropy_with_logits〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]] = None, pos_weight_rank_dtype: Optional[Tuple[int, int]] = None, reduction: int = 1) -> int:
@@ -5604,8 +5600,6 @@ def aten〇norm〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], p: Union[int,
     # Should possibly be added to aten〇std〡dtype.
     if self_dtype == torch.complex32:
         return torch.half
-    if self_dtype in [torch.float16, torch.bfloat16]:
-        return torch.float32
     return aten〇std〡dtype(self_rank_dtype)
 
 @check_dtype_function([Invocation(0.0),
diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py
index 396d43638a42..a116a94dabd3 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py
@@ -149,9 +149,12 @@ def _export_run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
         )
         module = self._backend.compile(module)
         backend_module = self._backend.load(module)
+        input_buffers = prog.graph_signature.inputs_to_buffers.values()
         params = {
             # **dict(artifact.named_parameters(remove_duplicate=False)),
-            **dict(artifact.named_buffers(remove_duplicate=False)),
+            name: value
+            for (name, value) in artifact.named_buffers(remove_duplicate=False)
+            if name in input_buffers
         }
         params_flat, params_spec = pytree.tree_flatten(params)
         params_flat = list(params_flat)
diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py
index 5461dc04c0d1..7bead331ea04 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py
@@ -85,6 +85,7 @@ def convert_onnx(model, inputs):
         input_names=input_names,
         dynamic_axes=dynamic_tensors,
         opset_version=max_opset_ver,
+        dynamo=False,
     )
     buffer = buffer.getvalue()
     return import_onnx(buffer)
diff --git a/pytorch-hash.txt b/pytorch-hash.txt
index b1886c1abddd..582695ddc3c7 100644
--- a/pytorch-hash.txt
+++ b/pytorch-hash.txt
@@ -1 +1 @@
-7956a1d1d0dc7cdaaaa42d0863eebb1b1e75eb65
+0dfcb1a118dd45c544a156e1d86566368e528e69
diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt
index 87cbf28f5a98..ac81781a6b2b 100644
--- a/pytorch-requirements.txt
+++ b/pytorch-requirements.txt
@@ -1,3 +1,3 @@
 -f https://download.pytorch.org/whl/nightly/cpu/torch/
 --pre
-torch==2.9.0.dev20250820
+torch==2.10.0.dev20251016
diff --git a/torchvision-requirements.txt b/torchvision-requirements.txt
index 68c96010c96f..546bfb138e43 100644
--- a/torchvision-requirements.txt
+++ b/torchvision-requirements.txt
@@ -1,3 +1,3 @@
 -f https://download.pytorch.org/whl/nightly/cpu/torchvision/
 --pre
-torchvision==0.24.0.dev20250820
+torchvision==0.25.0.dev20251016