
Commit ee7434b

anijain2305 authored and pytorchmergebot committed
[dynamo][guards] 1/N Guard selectively for DTensor (pytorch#165824)
A few internal jobs are observing very high guard overhead for DTensor. Since we own DTensor, we can make those guards way faster.

Pull Request resolved: pytorch#165824
Approved by: https://github.com/Lucaskabela, https://github.com/bdhirsh
1 parent d049ed2 commit ee7434b

File tree

4 files changed: +93, -14 lines

test/distributed/tensor/test_dtensor_compile.py

Lines changed: 19 additions & 0 deletions
@@ -464,6 +464,25 @@ def g(x):
         run(g, 64, 8)
         self.assertEqual(cnt.frame_count, 2)
 
+    def test_dtensor_requires_grad_recompile(self):
+        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
+        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
+
+        @torch.compile(backend=cnt, fullgraph=True)
+        def f(x):
+            y = x * x
+            return y.to_local()
+
+        full_x = torch.randn(8, 8, requires_grad=False)
+        x = distribute_tensor(full_x, mesh, [Shard(0)])
+        f(x)
+
+        full_x = torch.randn(8, 8, requires_grad=True)
+        x = distribute_tensor(full_x, mesh, [Shard(0)])
+        f(x)
+
+        self.assertEqual(cnt.frame_count, 2)
+
     def test_dtensor_attribute_access_on_intermediate(self):
         mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
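The new test pins down the recompile boundary the selective guards must preserve: the first call compiles with requires_grad=False, the second flips requires_grad on the input DTensor, and frame_count == 2 asserts that this causes exactly one recompile. For a quick local look outside the test harness, a minimal single-process sketch follows; it assumes the internal fake process group helper that other tests in this file rely on, and uses torch._logging.set_logs(recompiles=True) to print which guard failed.

# Minimal sketch, single process. Assumption: the internal fake process group
# helper keeps its current import path; values are illustrative only.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor, Shard
from torch.testing._internal.distributed.fake_pg import FakeStore

dist.init_process_group("fake", store=FakeStore(), rank=0, world_size=2)
mesh = DeviceMesh("cpu", torch.arange(2))

torch._logging.set_logs(recompiles=True)  # print the failing guard on recompile

@torch.compile(backend="aot_eager", fullgraph=True)
def f(x):
    return (x * x).to_local()

f(DTensor.from_local(torch.randn(4, 8, requires_grad=False), mesh, [Shard(0)], run_check=False))
f(DTensor.from_local(torch.randn(4, 8, requires_grad=True), mesh, [Shard(0)], run_check=False))
# Expected: one recompile, attributed to the requires_grad guard on the input DTensor.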

torch/_dynamo/guards.py

Lines changed: 13 additions & 0 deletions
@@ -2150,6 +2150,19 @@ def metadata_checker(x: Any) -> bool:
             metadata_checker, get_verbose_code_parts(global_name, guard)
         )
 
+    def DTENSOR_SPEC_MATCH(self, guard: Guard) -> None:
+        # Copied from DTensor __metadata_guard__
+        # TODO - Consider moving this to C++ if stable
+        value = deepcopy(self.get(guard.name))
+
+        def guard_fn(x: Any) -> bool:
+            return x._check_equals(value, skip_shapes=True)
+
+        code = f"__dtensor_spec_{id(guard_fn)}"
+        self.get_guard_manager(guard).add_lambda_guard(
+            guard_fn, get_verbose_code_parts(code, guard)
+        )
+
     def EQUALS_MATCH(self, guard: Guard, recompile_hint: Optional[str] = None) -> None:
         ref = self.arg_ref(guard)
         val = self.get(guard.name)
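DTENSOR_SPEC_MATCH follows Dynamo's lambda-guard pattern: snapshot the DTensorSpec with deepcopy when the guard is built, then on every later call compare the incoming spec against that snapshot, with skip_shapes=True leaving shape checks to the guards installed on the inner _local_tensor. Below is a toy, self-contained sketch of that snapshot-and-compare pattern; ToySpec, _check_equals as written here, and make_spec_guard are stand-ins for illustration, not Dynamo or DTensor APIs.

# Toy illustration of the snapshot-and-compare lambda guard pattern.
# Real guards are built by GuardBuilder and evaluated by the guard manager.
from copy import deepcopy
from dataclasses import dataclass
from typing import Callable

@dataclass
class ToySpec:
    placements: tuple  # e.g. ("Shard(0)",)
    mesh_shape: tuple  # e.g. (2,)
    shape: tuple       # ignored when skip_shapes=True

    def _check_equals(self, other: "ToySpec", skip_shapes: bool = False) -> bool:
        # Compare everything except (optionally) shapes, mirroring the idea that
        # shape/stride are already guarded on the inner local tensor.
        same = (self.placements == other.placements
                and self.mesh_shape == other.mesh_shape)
        return same if skip_shapes else (same and self.shape == other.shape)

def make_spec_guard(spec: ToySpec) -> Callable[[ToySpec], bool]:
    snapshot = deepcopy(spec)  # captured when the guard is installed
    return lambda incoming: incoming._check_equals(snapshot, skip_shapes=True)

guard = make_spec_guard(ToySpec(("Shard(0)",), (2,), (8, 8)))
print(guard(ToySpec(("Shard(0)",), (2,), (16, 8))))    # True: only the shape differs
print(guard(ToySpec(("Replicate()",), (2,), (8, 8))))  # False: placement changed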

torch/_dynamo/variables/builder.py

Lines changed: 59 additions & 14 deletions
@@ -2225,25 +2225,70 @@ def wrap_tensor(self, value: torch.Tensor):
         if isinstance(source, GradSource) and is_from_optimizer_source(source):
             guard_type = GuardBuilder.NOT_NONE_MATCH
 
-        self.install_guards(
-            functools.partial(
-                guard_type,
-                value=(
-                    value
-                    if isinstance(source, NumpyTensorSource)
-                    else TensorWeakRef(value)
-                ),
-            )
+        is_dtensor = torch.distributed.is_available() and isinstance(
+            value, torch.distributed.tensor.DTensor
         )
+        if not is_dtensor:
+            # We guard on the _local_tensor and the _spec, and therefore we don't
+            # have to guard on the outer DTensor.
+            self.install_guards(
+                functools.partial(
+                    guard_type,
+                    value=(
+                        value
+                        if isinstance(source, NumpyTensorSource)
+                        else TensorWeakRef(value)
+                    ),
+                )
+            )
 
         # We install TYPE_MATCH guards for traceable wrapper subclass object,
         # and recursively install corresponding guard for each inner attribute.
         if is_traceable_wrapper_subclass(value):
-            self.install_guards(GuardBuilder.TENSOR_SUBCLASS_METADATA_MATCH)
-            self.install_guards(GuardBuilder.TYPE_MATCH)
-            install_guard(
-                SubclassAttrListSource(source).make_guard(GuardBuilder.EQUALS_MATCH)
-            )
+            # Tensor subclass guards are very expensive because they are
+            # implemented in Python. Since DTensor is a PyTorch-maintained class,
+            # we can skip a lot of these guards.
+            if is_dtensor:
+                self.install_guards(GuardBuilder.TYPE_MATCH)
+
+                # The inner tensor name is always _local_tensor. If it's not, we
+                # raise an assertion so the check gets updated accordingly.
+                inner_tensor_name = value.__tensor_flatten__()[0][0]
+                if inner_tensor_name != "_local_tensor":
+                    raise RuntimeError(
+                        "Expecting DTensor inner tensor name to be _local_tensor"
+                    )
+
+                # Now selectively guard on the flattening context
+                flattening_ctx = value.__tensor_flatten__()[1]
+                # This is supposed to be (self._spec, self.requires_grad)
+                if not (
+                    len(flattening_ctx) == 2
+                    and flattening_ctx[0] == value._spec
+                    and flattening_ctx[1] == value.requires_grad
+                ):
+                    # If not, raise an assertion to update to the new guards
+                    raise RuntimeError(
+                        "Expecting DTensor flattening ctx to be _spec, requires_grad"
+                    )
+                # Guard on the dtensor spec
+                install_guard(
+                    AttrSource(self.source, "_spec").make_guard(
+                        GuardBuilder.DTENSOR_SPEC_MATCH
+                    )
+                )
+                # TODO - Move this to C++
+                install_guard(
+                    AttrSource(self.source, "requires_grad").make_guard(
+                        GuardBuilder.EQUALS_MATCH
+                    )
+                )
+            else:
+                self.install_guards(GuardBuilder.TENSOR_SUBCLASS_METADATA_MATCH)
+                self.install_guards(GuardBuilder.TYPE_MATCH)
+                install_guard(
+                    SubclassAttrListSource(source).make_guard(GuardBuilder.EQUALS_MATCH)
+                )
 
         attrs, _ = value.__tensor_flatten__()
         for attr in attrs:
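The builder change leans on DTensor's __tensor_flatten__ contract: the inner-tensor attribute list is exactly ["_local_tensor"], and the context is (spec, requires_grad). Given that, the expensive Python-level TENSOR_SUBCLASS_METADATA_MATCH is replaced with a cheaper set: TYPE_MATCH on the DTensor, DTENSOR_SPEC_MATCH on ._spec, and EQUALS_MATCH on .requires_grad, while the inner ._local_tensor still receives the usual tensor guards through the attrs loop that follows. Below is a hedged sketch of the contract check and the resulting guard selection, using a stand-in class rather than a real DTensor.

# Sketch of the flattening contract that builder.py asserts before installing
# the selective guards. FakeDTensor and select_guards are stand-ins for
# illustration only; they are not PyTorch APIs.
class FakeSpec:
    pass

class FakeDTensor:
    def __init__(self, local, spec, requires_grad):
        self._local_tensor = local
        self._spec = spec
        self.requires_grad = requires_grad

    def __tensor_flatten__(self):
        # (inner tensor attribute names, opaque context)
        return ["_local_tensor"], (self._spec, self.requires_grad)

def select_guards(value) -> list[str]:
    attrs, ctx = value.__tensor_flatten__()
    # Same sanity checks as builder.py: fail loudly if the contract drifts.
    if attrs[0] != "_local_tensor":
        raise RuntimeError("Expecting DTensor inner tensor name to be _local_tensor")
    if not (len(ctx) == 2 and ctx[0] == value._spec and ctx[1] == value.requires_grad):
        raise RuntimeError("Expecting DTensor flattening ctx to be _spec, requires_grad")
    # The cheap guard set installed instead of TENSOR_SUBCLASS_METADATA_MATCH.
    return [
        "TYPE_MATCH on the DTensor",
        "DTENSOR_SPEC_MATCH on ._spec",
        "EQUALS_MATCH on .requires_grad",
        "tensor guards on ._local_tensor (via the attrs loop)",
    ]

print(select_guards(FakeDTensor(local="<tensor>", spec=FakeSpec(), requires_grad=True)))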

torch/distributed/tensor/_api.py

Lines changed: 2 additions & 0 deletions
@@ -671,6 +671,8 @@ def __get_tensor_shard__(self, index):
     def __metadata_guard__(
         cls, orig: tuple[DTensorSpec, bool], other: tuple[DTensorSpec, bool]
     ) -> bool:
+        # TODO - delete this - This is now unused after the PR -
+        # https://github.com/pytorch/pytorch/pull/165824
        orig_spec, orig_requires_grad = orig
        other_spec, other_requires_grad = other
        return (
