
Commit a878536
Fixing tensor.numpy on wrapped tensors
Fixes pytorch#626

Description:
- Fixing tensor.numpy on wrapped tensors
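For context: the patch below keeps ordinary tensors on their original fast path. When functorch's _C.maybe_get_level reports level -1 (no active wrapper), the saved torch.Tensor.numpy is called unchanged. A minimal sanity check of that unchanged path (plain PyTorch, no functorch transform active; the tensor values are illustrative only):

import numpy as np
import torch

t = torch.arange(4.0)
arr = t.numpy()  # no wrapper present, so the original fast path runs
assert isinstance(arr, np.ndarray)

# .numpy() returns a view sharing memory with the tensor, patched or not
t[0] = 10.0
assert arr[0] == 10.0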
1 parent 9d6ee76 · commit a878536

File tree

1 file changed: +31 −0 lines changed

functorch/_src/monkey_patching.py

Lines changed: 31 additions & 0 deletions
@@ -98,3 +98,34 @@ def _backward(*args, **kwargs):
 
 
 setattr(torch.Tensor, 'backward', _backward)
+
+
+# Monkeypatch .numpy() to fetch the underlying tensor and call .numpy() on it
+_old_numpy = torch.Tensor.numpy
+
+
+@functools.wraps(_old_numpy)
+def _numpy(tensor):
+    level = _C.maybe_get_level(tensor)
+    if level == -1:
+        return _old_numpy(tensor)
+
+    if _C.is_functionaltensor(tensor):
+        # Since we're unwrapping the FunctionalTensorWrapper, we need to make sure
+        # that it's up to date first
+        torch._sync(tensor)
+
+    value = _C.get_unwrapped(tensor)
+    dl_enabled = _C.tls_set_is_included()
+    try:
+        # Temporarily disable kDynamicLayerFrontModeKey/kDynamicLayerBackModeKey as included dispatch keys
+        if dl_enabled:
+            _C._set_dynamic_layer_keys_included(False)
+        return value.numpy()
+    finally:
+        # Re-enable kDynamicLayerFrontModeKey/kDynamicLayerBackModeKey as included dispatch keys
+        if dl_enabled:
+            _C._set_dynamic_layer_keys_included(True)
+
+
+setattr(torch.Tensor, 'numpy', _numpy)
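The save/wrap/reinstall shape used in this patch (capture the original bound method, wrap it with functools.wraps so its name and docstring survive, then reinstall the replacement with setattr) is a generic monkeypatching pattern. A minimal self-contained sketch of the same shape, using a stand-in class and flag instead of torch.Tensor and the dispatch-key thread-local state so it runs without functorch:

import functools


class _State:
    # Stand-in for the thread-local dispatch-key inclusion toggled above
    enabled = True


class Box:
    # Stand-in for a wrapper type holding an underlying value
    def __init__(self, value):
        self._value = value

    def unwrap(self):
        return self._value


_old_unwrap = Box.unwrap


@functools.wraps(_old_unwrap)
def _unwrap(box):
    # Mirror the patched .numpy(): flip the state off around the original
    # call and restore it in a finally block, even if the call raises
    was_enabled = _State.enabled
    try:
        if was_enabled:
            _State.enabled = False
        return _old_unwrap(box)
    finally:
        if was_enabled:
            _State.enabled = True


setattr(Box, 'unwrap', _unwrap)

assert Box(42).unwrap() == 42
assert _State.enabled  # state restored after the call

The try/finally mirrors the real patch's key handling: the keys are restored even if the wrapped call raises, so a failing .numpy() cannot leave the dispatcher in a disabled state.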

0 commit comments