Allow Fakified subclass to have different device for inner and outer tensor (pytorch#141839)

Previously, if a wrapper tensor subclass was fakified, its inner tensors would end up with the same device as the outer tensor. This PR makes it possible for the inner and outer tensors to have different devices.
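
For example (a minimal sketch of the new behavior; it assumes the DifferentDeviceTensor wrapper defined in the test added below, whose outer device is cpu while its inner tensor is on cuda):

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode

    # DifferentDeviceTensor is the wrapper subclass from the new test below:
    # the outer tensor reports device cpu, the inner tensor lives on cuda.
    wrapped = DifferentDeviceTensor(torch.ones(2, 2, 768, device="cuda"))

    with FakeTensorMode() as fake_mode:
        fake_wrapped = fake_mode.from_tensor(wrapped)

    # With this PR both devices survive fakification; previously the fake inner
    # tensor would also have ended up on the outer (cpu) device.
    assert fake_wrapped.is_cpu
    assert not fake_wrapped.inner_tensor.is_cpu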

See the OffloadTensor PR https://github.com/pytorch/pytorch/pull/141840/files#diff-3bc0cf540b694f4ec0a3749f78b047456657a53a5657e495ffb68e5970c5fdaaR1955 for an application. A simpler test has been added in this PR.

This is technically BC-breaking because the callback passed to MetaConverter now needs to accept an extra argument, but no one external should be using this anyway?
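
Roughly, the callback contract changes like this (a simplified sketch based on the mk_fake_tensor and _identity_callable changes in the diff below; the names old_style_callback / new_style_callback are made up for illustration):

    from typing import Callable

    import torch

    # Before: the callback only received a thunk that builds the meta tensor.
    def old_style_callback(make_meta_t: Callable[[], torch.Tensor]) -> torch.Tensor:
        return make_meta_t()

    # After: it also receives the device the resulting fake tensor should report.
    # MetaConverter now binds this per tensor via functools.partial(callback, device=...),
    # so the outer wrapper and each inner tensor can be handed different devices.
    def new_style_callback(
        make_meta_t: Callable[[], torch.Tensor], device: torch.device
    ) -> torch.Tensor:
        return make_meta_t()  # a real callback would also record `device` on the result
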
Pull Request resolved: pytorch#141839
Approved by: https://github.com/bdhirsh
ghstack dependencies: pytorch#141166
soulitzer authored and pytorchmergebot committed Dec 3, 2024
1 parent 9830e7b commit e41a0b3
Showing 3 changed files with 82 additions and 15 deletions.
54 changes: 54 additions & 0 deletions test/test_fake_tensor.py
@@ -1951,6 +1951,60 @@ def test_inference_mode(self):
             extract_tensor_metadata(res4),
         )
 
+
+    @unittest.skipIf(not RUN_CUDA, "requires cuda")
+    def test_wrapper_tensor_subclass_different_device(self):
+        class DifferentDeviceTensor(torch.Tensor):
+            @staticmethod
+            def __new__(cls, a):
+                kwargs = {}
+                kwargs["strides"] = a.stride()
+                kwargs["storage_offset"] = a.storage_offset()
+                kwargs["device"] = torch.device("cpu")
+                kwargs["layout"] = a.layout
+                kwargs["requires_grad"] = a.requires_grad
+                kwargs["dtype"] = a.dtype
+                out = torch.Tensor._make_wrapper_subclass(cls, a.size(), **kwargs)
+                return out
+
+            def __init__(self, a):
+                self.inner_tensor = a
+
+            def __repr__(self):
+                return f"DifferentDeviceTensor({repr(self.inner_tensor)})"
+
+            def __tensor_flatten__(self):
+                return ["inner_tensor"], None
+
+            @staticmethod
+            def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
+                assert meta is None
+                return DifferentDeviceTensor(inner_tensors["inner_tensor"])
+
+            @classmethod
+            def __torch_dispatch__(cls, func, types, args, kwargs):
+                if kwargs is None:
+                    kwargs = {}
+                args = pytree.tree_map_only(DifferentDeviceTensor, lambda x: x.inner_tensor, args)
+                kwargs = pytree.tree_map_only(DifferentDeviceTensor, lambda x: x.inner_tensor, kwargs)
+                # Returns unwrapped tensor
+                return func(*args, **kwargs)
+
+        a = torch.ones(2, 2, 768, device="cuda")
+        wrapped_a = DifferentDeviceTensor(a)
+
+        # Outer Tensor is on cpu, inner is on cuda
+        self.assertTrue(wrapped_a.is_cpu)
+        self.assertFalse(wrapped_a.inner_tensor.is_cpu)
+
+        with FakeTensorMode() as fake_mode:
+            fake_wrapped_a = fake_mode.from_tensor(wrapped_a)
+
+        self.assertTrue(fake_wrapped_a.is_cpu)
+        assert isinstance(fake_wrapped_a, DifferentDeviceTensor)
+        self.assertFalse(fake_wrapped_a.inner_tensor.is_cpu)
+
+
     def test_cache_tuple_outputs(self):
         """
         Test to check that ops with tuple outputs work.
16 changes: 11 additions & 5 deletions torch/_subclasses/fake_tensor.py
@@ -354,14 +354,20 @@ def from_real_tensor(
         maybe_memo = self._get_memo(t)
         if maybe_memo is not None:
             return maybe_memo
-        existing_device = t.device
         # not yet supported in metatensors
         if t.is_quantized:
             raise UnsupportedFakeTensorException("quantized nyi in meta tensors")
         if type(t) is torch.nn.Parameter:
             assert not make_constant
 
-        def mk_fake_tensor(make_meta_t: Callable[[], object]) -> FakeTensor:
+        constant = t if make_constant else None
+
+        # This callback is used by both subclass and inner tensors. Require the
+        # caller to explicitly specify the device in case outer and inner tensors
+        # have different devices.
+        def mk_fake_tensor(
+            make_meta_t: Callable[[], object], device: torch.device
+        ) -> FakeTensor:
             # NB: don't use in_kernel_invocation_manager. to
             # ensure FakeTensor can internally do constant computation
             # as necessary. Invocation manager is "more correct" as
@@ -373,16 +379,16 @@ def mk_fake_tensor(make_meta_t: Callable[[], object]) -> FakeTensor:
             return FakeTensor(
                 fake_mode,
                 make_meta_t(),
-                existing_device,
+                device,
                 # TODO: callback might be used in recursive contexts, in
                 # which case using t is wrong! BUG!
-                constant=t if make_constant else None,
+                constant=constant,
             )
 
         out = self.meta_converter(
             t,
             shape_env=shape_env,
-            callback=mk_fake_tensor,
+            callback=mk_fake_tensor,  # type: ignore[arg-type]
             source=source,
             symbolic_context=symbolic_context,
             trace=trace,
27 changes: 17 additions & 10 deletions torch/_subclasses/meta_utils.py
@@ -2,6 +2,7 @@
 
 import contextlib
 import dataclasses
+import functools
 import typing
 import warnings
 import weakref
@@ -709,7 +710,7 @@ def set_storage_memo(self, s: MetaStorageDesc, v: torch.UntypedStorage) -> None:
     def meta_storage(
         self,
         s: MetaStorageDesc,
-        callback: Callable[[Callable[[], torch.Tensor]], _TensorT],
+        callback: Callable[[Callable], _TensorT],
     ) -> torch.UntypedStorage:
         # If we are fakeifying a tensor that has a secretly-zero-sized storage,
         # Need to make sure to resize the meta storage too.
@@ -734,7 +735,9 @@ def _checked_cast_tensor_t(cls, t: torch.Tensor) -> _TensorT:
         return typing.cast(_TensorT, t)
 
     @classmethod
-    def _identity_callable(cls, t: Callable[[], torch.Tensor]) -> _TensorT:
+    def _identity_callable(
+        cls, t: Callable[[], torch.Tensor], device: torch.device
+    ) -> _TensorT:
         return cls._checked_cast_tensor_t(t())
 
     @classmethod
@@ -756,10 +759,11 @@ def meta_tensor(
         self,
         t: MetaTensorDesc,
         shape_env: Optional[ShapeEnv],
-        callback: Callable[[Callable[[], torch.Tensor]], _TensorT],
+        callback: Callable[[Callable], _TensorT],
         source: Optional[Source],
         symbolic_context: Optional[SymbolicContext],
     ) -> _TensorT:
+        callback = functools.partial(callback, device=t.device)  # type: ignore[call-arg]
         if source is None:
             from torch._dynamo.source import ConstantSource
 
@@ -905,7 +909,7 @@ def _empty_create_subclass(
             symbolic_context: Optional[
                 torch.fx.experimental.symbolic_shapes.SymbolicContext
             ],
-            callback: Callable[[Callable[[], torch.Tensor]], _TensorT],
+            callback: Callable[[Callable], _TensorT],
             source: torch._guards.Source,
         ) -> _TensorT:
             # We are hitting plain meta_desc tensor so actually
@@ -933,12 +937,15 @@ def _empty_create_subclass(
                 )
 
                 current_source = AttrSource(source, attr)
+                inner_callback = functools.partial(
+                    callback, device=meta_tensor_desc.device  # type: ignore[call-arg]
+                )
                 new_empty_tensor = _empty_create_subclass(
                     meta_tensor_desc,
                     meta_tensor_desc.size,
                     meta_tensor_desc.stride,
                     current_context,
-                    callback,
+                    inner_callback,
                     current_source,
                 )
                 inner_tensors[attr] = new_empty_tensor
@@ -975,7 +982,7 @@ def all_dynamic_symbolic_context(
             t: MetaTensorDesc,
             source: torch._guards.Source,
             shape_env: Optional[torch.fx.experimental.symbolic_shapes.ShapeEnv],
-            callback: Callable[[Callable[[], torch.Tensor]], _TensorT],
+            callback: Callable[[Callable], _TensorT],
         ) -> torch.fx.experimental.symbolic_shapes.SymbolicContext:
             from torch._dynamo.source import AttrSource
             from torch.fx.experimental.symbolic_shapes import (
@@ -1137,7 +1144,7 @@ def tensor_visitor_fn(
                 shape_env: Optional[
                     torch.fx.experimental.symbolic_shapes.ShapeEnv
                 ] = shape_env,
-                callback: Callable[[Callable[[], torch.Tensor]], _TensorT] = callback,  # type: ignore[assignment]
+                callback: Callable[[Callable], _TensorT] = callback,  # type: ignore[assignment]
             ) -> torch.Tensor:
                 # It's possible to close over an undefined tensor (e.g. NJT's lengths).
                 if visited_t is None:
@@ -1723,17 +1730,17 @@ def __call__(
         t: torch.Tensor,
         shape_env: Optional[ShapeEnv] = None,
         *,
-        callback: Optional[Callable[[Callable[[], torch.Tensor]], _TensorT]] = None,
+        callback: Optional[Callable[[Callable], _TensorT]] = None,
         source: Optional[Source] = None,
         symbolic_context: Optional[SymbolicContext] = None,
         # Controls whether or not we should dump the tensor metadata to structured logs
         # when source is not None. Because we refakify after Dynamo is done,
         # we don't want to dump info again from AOTAutograd, it is redundant.
         trace: bool = True,
     ) -> _TensorT:
-        callback_: Callable[[Callable[[], torch.Tensor]], _TensorT]
+        callback_: Callable[[Callable], _TensorT]
         if callback is None:
-            callback_ = self._identity_callable
+            callback_ = self._identity_callable  # type: ignore[assignment]
         else:
             callback_ = callback
         # TODO: zero tensors? We appear to have eliminated them by
