1 file changed: +31 −0 lines changed

@@ -98,3 +98,34 @@ def _backward(*args, **kwargs):


 setattr(torch.Tensor, 'backward', _backward)
+
+
+# Monkeypatch .numpy() to fetch underlying tensor and call .numpy()
+_old_numpy = torch.Tensor.numpy
+
+
+@functools.wraps(_old_numpy)
+def _numpy(tensor):
+    level = _C.maybe_get_level(tensor)
+    if level == -1:
+        return _old_numpy(tensor)
+
+    if _C.is_functionaltensor(tensor):
+        # Since we're unwrapping the FunctionalTensorWrapper, we need to make sure
+        # that it's up to date first
+        torch._sync(tensor)
+
+    value = _C.get_unwrapped(tensor)
+    dl_enabled = _C.tls_set_is_included()
+    try:
+        # Disable temporarily kDynamicLayerFrontModeKey/kDynamicLayerBackModeKey as included dispatch keys
+        if (dl_enabled):
+            _C._set_dynamic_layer_keys_included(False)
+        return value.numpy()
+    finally:
+        # Reenable kDynamicLayerFrontModeKey/kDynamicLayerBackModeKey as included dispatch keys
+        if (dl_enabled):
+            _C._set_dynamic_layer_keys_included(True)
+
+
+setattr(torch.Tensor, 'numpy', _numpy)
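
For context on how the patched method behaves: `maybe_get_level()` returns -1 for tensors that are not wrapped by a functorch transform, so plain tensors keep taking the original `.numpy()` path, and only wrapped tensors go through the sync/unwrap step with the dynamic-layer dispatch keys temporarily excluded. A minimal usage sketch, assuming these monkey patches get installed when functorch is imported (the import mechanism itself is not part of this diff):

import torch
import functorch  # assumed to apply the monkey patches above on import

t = torch.arange(3.0)
# Outside any transform the level check returns -1, so this falls back to
# the original torch.Tensor.numpy and behaves exactly as before the patch.
print(t.numpy())  # array([0., 1., 2.], dtype=float32)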