From 9f79bcc5e15f216a2890542bc72fab124efe296c Mon Sep 17 00:00:00 2001 From: Han Qi Date: Mon, 19 May 2025 21:11:28 -0700 Subject: [PATCH] Copy different device --- torchax/test/test_core_aten_ops.py | 7 +++++++ torchax/torchax/ops/jaten.py | 10 ++++++++-- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/torchax/test/test_core_aten_ops.py b/torchax/test/test_core_aten_ops.py index 67d2ab8c0c55..e8b6d8effeb3 100644 --- a/torchax/test/test_core_aten_ops.py +++ b/torchax/test/test_core_aten_ops.py @@ -4524,6 +4524,13 @@ def test_aten_linear(self): rtol=1e-2, check_dtype=True) + def test_aten_copy_different_device(self): + cpu_tensor = torch.tensor([1, 2, 3]) + + with self.env: + xla_tensor = torch.tensor([0, 0, 0]) + xla_tensor.copy_(cpu_tensor) + if __name__ == "__main__": test_base.main() diff --git a/torchax/torchax/ops/jaten.py b/torchax/torchax/ops/jaten.py index fc8dcc71e466..370d1f3b9d28 100644 --- a/torchax/torchax/ops/jaten.py +++ b/torchax/torchax/ops/jaten.py @@ -125,8 +125,14 @@ def _aten_add(x, y, *, alpha=1): return res -@op(torch.ops.aten.copy_, is_jax_function=False, is_view_op=True) -def _aten_copy(x, y, memory_format=None): +@op(torch.ops.aten.copy_, + is_jax_function=False, + is_view_op=True, + needs_env=True) +def _aten_copy(x, y, memory_format=None, env=None): + + if y.device.type == 'cpu': + y = env.to_xla(y) if isinstance(x, View): x.update(y)