@@ -2225,25 +2225,70 @@ def wrap_tensor(self, value: torch.Tensor):
         if isinstance(source, GradSource) and is_from_optimizer_source(source):
             guard_type = GuardBuilder.NOT_NONE_MATCH
 
-        self.install_guards(
-            functools.partial(
-                guard_type,
-                value=(
-                    value
-                    if isinstance(source, NumpyTensorSource)
-                    else TensorWeakRef(value)
-                ),
-            )
+        is_dtensor = torch.distributed.is_available() and isinstance(
+            value, torch.distributed.tensor.DTensor
         )
+        if not is_dtensor:
+            # We guard on the _local_tensor and the _spec, and therefore we don't
+            # have to guard on the outer DTensor.
+            self.install_guards(
+                functools.partial(
+                    guard_type,
+                    value=(
+                        value
+                        if isinstance(source, NumpyTensorSource)
+                        else TensorWeakRef(value)
+                    ),
+                )
+            )
 
         # We install TYPE_MATCH guards for traceable wrapper subclass object,
         # and recursively install corresponding guard for each inner attribute.
         if is_traceable_wrapper_subclass(value):
-            self.install_guards(GuardBuilder.TENSOR_SUBCLASS_METADATA_MATCH)
-            self.install_guards(GuardBuilder.TYPE_MATCH)
-            install_guard(
-                SubclassAttrListSource(source).make_guard(GuardBuilder.EQUALS_MATCH)
-            )
+            # Tensor subclass guards are very expensive because they are
+            # implemented in Python. Since DTensor is a PyTorch-maintained class,
+            # we can skip a lot of these guards.
+            if is_dtensor:
+                self.install_guards(GuardBuilder.TYPE_MATCH)
+
+                # The inner tensor name is always _local_tensor. If it's not, we
+                # raise an assertion to update the check accordingly.
+                inner_tensor_name = value.__tensor_flatten__()[0][0]
+                if inner_tensor_name != "_local_tensor":
+                    raise RuntimeError(
+                        "Expecting DTensor inner tensor name to be _local_tensor"
+                    )
+
+                # Now selectively guard on the flattening context
+                flattening_ctx = value.__tensor_flatten__()[1]
+                # This is supposed to be (self._spec, self.requires_grad)
+                if not (
+                    len(flattening_ctx) == 2
+                    and flattening_ctx[0] == value._spec
+                    and flattening_ctx[1] == value.requires_grad
+                ):
+                    # If not, raise an assertion to update to the new guards
+                    raise RuntimeError(
+                        "Expecting DTensor flattening ctx to be _spec, requires_grad"
+                    )
+                # Guard on the DTensor spec
+                install_guard(
+                    AttrSource(self.source, "_spec").make_guard(
+                        GuardBuilder.DTENSOR_SPEC_MATCH
+                    )
+                )
+                # Move this to C++
+                install_guard(
+                    AttrSource(self.source, "requires_grad").make_guard(
+                        GuardBuilder.EQUALS_MATCH
+                    )
+                )
+            else:
+                self.install_guards(GuardBuilder.TENSOR_SUBCLASS_METADATA_MATCH)
+                self.install_guards(GuardBuilder.TYPE_MATCH)
+                install_guard(
+                    SubclassAttrListSource(source).make_guard(GuardBuilder.EQUALS_MATCH)
+                )
 
         attrs, _ = value.__tensor_flatten__()
         for attr in attrs:
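
For reference, the new DTensor branch relies on the __tensor_flatten__ contract that the two RuntimeError checks above assert: the first returned element is the list of inner tensor attribute names (expected to be just _local_tensor), and the second is the flatten context (expected to be (self._spec, self.requires_grad)). Below is a minimal standalone sketch of those assumptions, not part of the patch; the helper name and the dt argument are illustrative, and dt is assumed to be an already-constructed DTensor.

def check_dtensor_flatten_contract(dt):
    # __tensor_flatten__ returns (inner_tensor_attr_names, flatten_context).
    inner_names, ctx = dt.__tensor_flatten__()
    # DTensor's only inner tensor is expected to be _local_tensor, which the
    # builder recurses into instead of guarding on the outer DTensor object.
    assert inner_names[0] == "_local_tensor"
    # The flatten context is expected to be (self._spec, self.requires_grad);
    # guarding on these two plus the local tensor is what lets the builder skip
    # the generic, Python-implemented TENSOR_SUBCLASS_METADATA_MATCH guard.
    assert len(ctx) == 2
    assert ctx[0] == dt._spec
    assert ctx[1] == dt.requires_grad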