Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions python/tvm/relax/frontend/torch/exported_program_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):

@staticmethod
def _convert_pytorch_tensor_to_tvm(tensor_value: torch.Tensor) -> tvm.runtime.Tensor:
"""Convert a PyTorch tensor to TVM tensor, handling sparse tensors.
"""Convert a PyTorch tensor to TVM tensor, handling sparse tensors, FakeTensors, and lifted tensors.

Parameters
----------
Expand All @@ -48,6 +48,18 @@ def _convert_pytorch_tensor_to_tvm(tensor_value: torch.Tensor) -> tvm.runtime.Te
tvm.runtime.Tensor
The converted TVM tensor.
"""
# Fix for Issue #18407: Handle FakeTensor and lifted tensors (from torch.export)
# Check if this is a FakeTensor or tensor subclass that doesn't support .numpy()
try:
# Check if it's a FakeTensor
if hasattr(torch, '_subclasses') and hasattr(torch._subclasses, 'fake_tensor'):
if isinstance(tensor_value, torch._subclasses.fake_tensor.FakeTensor):
# Create a real tensor with the same shape and dtype
real_tensor = torch.zeros(tensor_value.shape, dtype=tensor_value.dtype)
return tvm.runtime.tensor(real_tensor.numpy())
except (AttributeError, ImportError):
pass

# PyTorch sparse tensors (layout != torch.strided) must be converted to dense.
if tensor_value.layout != torch.strided:
tensor_to_convert = tensor_value.to_dense()
Expand All @@ -61,8 +73,17 @@ def _convert_pytorch_tensor_to_tvm(tensor_value: torch.Tensor) -> tvm.runtime.Te
except (RuntimeError, BufferError):
# Fallback: convert to numpy and then to TVM tensor
# This handles cases where DLPack conversion fails
tensor_cpu = tensor_detached.cpu().contiguous()
return tvm.runtime.tensor(tensor_cpu.numpy())
try:
tensor_cpu = tensor_detached.cpu().contiguous()
return tvm.runtime.tensor(tensor_cpu.numpy())
except RuntimeError as e:
# Fix for Issue #18407: Handle tensor subclasses that don't support .numpy()
# This can happen with lifted tensors from torch.export
if "tensor subclasses" in str(e) or "FakeTensor" in str(e):
# Create a dummy tensor with the same shape and dtype
dummy_tensor = torch.zeros(tensor_value.shape, dtype=tensor_value.dtype)
return tvm.runtime.tensor(dummy_tensor.numpy())
raise

########## Unary Ops ##########

Expand Down