diff --git a/API_GUIDE.md b/API_GUIDE.md index 47f0e674b798..414f14b8b32d 100644 --- a/API_GUIDE.md +++ b/API_GUIDE.md @@ -15,14 +15,14 @@ import torch import torch_xla import torch_xla.core.xla_model as xm -t = torch.randn(2, 2, device=xm.xla_device()) +t = torch.randn(2, 2, device="xla") print(t.device) print(t) ``` This code should look familiar. PyTorch/XLA uses the same interface as regular PyTorch with a few additions. Importing `torch_xla` initializes PyTorch/XLA, and -`xm.xla_device()` returns the current XLA device. This may be a CPU or TPU +`torch_xla.device()` returns the current XLA device. This may be a CPU or TPU depending on your environment. ## XLA Tensors are PyTorch Tensors @@ -32,8 +32,8 @@ PyTorch operations can be performed on XLA tensors just like CPU or CUDA tensors For example, XLA tensors can be added together: ```python -t0 = torch.randn(2, 2, device=xm.xla_device()) -t1 = torch.randn(2, 2, device=xm.xla_device()) +t0 = torch.randn(2, 2, device="xla") +t1 = torch.randn(2, 2, device="xla") print(t0 + t1) ``` @@ -46,8 +46,8 @@ print(t0.mm(t1)) Or used with neural network modules: ```python -l_in = torch.randn(10, device=xm.xla_device()) -linear = torch.nn.Linear(10, 20).to(xm.xla_device()) +l_in = torch.randn(10, device="xla") +linear = torch.nn.Linear(10, 20).to(torch_xla.device()) l_out = linear(l_in) print(l_out) ``` @@ -56,7 +56,7 @@ Like other device types, XLA tensors only work with other XLA tensors on the same device. So code like ```python -l_in = torch.randn(10, device=xm.xla_device()) +l_in = torch.randn(10, device="xla") linear = torch.nn.Linear(10, 20) l_out = linear(l_in) print(l_out) @@ -109,10 +109,10 @@ class MNIST(nn.Module): batch_size = 128 train_loader = xu.SampleGenerator( data=(torch.zeros(batch_size, 1, 28, 28), - torch.zeros(batch_size, dtype=torch.int64)), + torch.zeros(batch_size, dtype=torch.int64)), sample_count=60000 // batch_size // xr.world_size()) -device = xm.xla_device() # Get the XLA device (TPU). +device = torch_xla.device() # Get the XLA device (TPU). model = MNIST().train().to(device) # Create a model and move it to the device. loss_fn = nn.NLLLoss() optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5) @@ -169,7 +169,7 @@ def _mp_fn(index): index: Index of the process. """ - device = xm.xla_device() # Get the device assigned to this process. + device = torch_xla.device() # Get the device assigned to this process. # Wrap the loader for multi-device. mp_device_loader = pl.MpDeviceLoader(train_loader, device) @@ -197,7 +197,7 @@ single device snippet. Let's go over them one by one. - `torch_xla.launch()` - Creates the processes that each run an XLA device. - This function is a wrapper around multiprocessing spawn that also allows the user to run the script with the torchrun command line. Each process will only be able to access the device assigned to the current process. For example, on a TPU v4-8 there will be 4 processes spawned, and each process will own a TPU device. - - Note that if you print the `xm.xla_device()` on each process you will see `xla:0` on all devices. This is because each process can only see one device. This does not mean multi-process is not functioning. The only exeption is with PJRT runtime on TPU v2 and TPU v3 since there will be `#devices/2` processes and each process will have 2 threads (check this [doc](https://github.com/pytorch/xla/blob/master/docs/pjrt.md#tpus-v2v3-vs-v4) for more details). + - Note that if you print `torch_xla.device()` in each process you will see `xla:0` on all devices.
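A minimal sketch of that behavior (assumptions: a multi-device host with `PJRT_DEVICE` configured; the helper below is illustrative, not part of the guide):

```python
import torch_xla
import torch_xla.runtime as xr


def _mp_fn(index):
  # Every process reports `xla:0` as its device, while its runtime ordinal
  # still distinguishes it from the other processes.
  print(f"process {index}: device={torch_xla.device()}, "
        f"ordinal={xr.global_ordinal()}, world_size={xr.world_size()}")


if __name__ == '__main__':
  torch_xla.launch(_mp_fn, args=())
```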
This is because each process can only see one device. This does not mean multi-process is not functioning. The only exception is with PJRT runtime on TPU v2 and TPU v3 since there will be `#devices/2` processes and each process will have 2 threads (check this [doc](https://github.com/pytorch/xla/blob/master/docs/pjrt.md#tpus-v2v3-vs-v4) for more details). - `MpDeviceLoader` - Loads the training data onto each device. - `MpDeviceLoader` can wrap a torch dataloader. It can preload the data to the device and overlap the data loading with device execution to improve performance. @@ -290,7 +290,7 @@ import torch import torch_xla import torch_xla.core.xla_model as xm -device = xm.xla_device() +device = torch_xla.device() t0 = torch.randn(2, 2, device=device) t1 = torch.randn(2, 2, device=device) diff --git a/README.md b/README.md index eca641f9602d..6938155361e7 100644 --- a/README.md +++ b/README.md @@ -190,7 +190,7 @@ If you're using `DistributedDataParallel`, make the following changes: + # Rank and world size are inferred from the XLA device runtime + dist.init_process_group("xla", init_method='xla://') + -+ model.to(xm.xla_device()) ++ model.to(torch_xla.device()) + ddp_model = DDP(model, gradient_as_bucket_view=True) - model = model.to(rank) diff --git a/benchmarks/benchmark_experiment.py b/benchmarks/benchmark_experiment.py index a82490d373b1..e1fab48334a8 100644 --- a/benchmarks/benchmark_experiment.py +++ b/benchmarks/benchmark_experiment.py @@ -208,7 +208,7 @@ def update_process_env(self, process_env: Dict[str, str]): def get_device(self): if self.torch_xla2: # Initiate the model in CPU first for xla2. We will move the model to jax device later. - # This is because we don't have xm.xla_device() function in torch_xla2. + # This is because we don't have torch_xla.device() function in torch_xla2.
return torch.device("cpu") if self.xla: return xm.xla_device(devkind=self.accelerator.upper()) diff --git a/benchmarks/experiment_runner.py b/benchmarks/experiment_runner.py index a3b09c7cd7e0..b784af68e47b 100644 --- a/benchmarks/experiment_runner.py +++ b/benchmarks/experiment_runner.py @@ -255,7 +255,7 @@ def _default_iter_fn(self, benchmark_experiment: BenchmarkExperiment, def _pure_wall_time_iter_fn(self, benchmark_experiment: BenchmarkExperiment, benchmark_model: BenchmarkModel, input_tensor): - device = xm.xla_device() if benchmark_experiment.xla else 'cuda' + device = torch_xla.device() if benchmark_experiment.xla else 'cuda' sync_fn = xm.wait_device_ops if benchmark_experiment.xla else torch.cuda.synchronize timing, output = bench.do_bench( lambda: benchmark_model.model_iter_fn( diff --git a/benchmarks/matmul_bench.py b/benchmarks/matmul_bench.py index af518f355ca2..51c595054152 100644 --- a/benchmarks/matmul_bench.py +++ b/benchmarks/matmul_bench.py @@ -39,10 +39,7 @@ def main(): """ xla_bench_fn = lambda fn: do_bench( - fn, - return_mode='min', - sync_fn=lambda: xm.wait_device_ops(), - device=xm.xla_device()) + fn, return_mode='min', sync_fn=lambda: xm.wait_device_ops(), device="xla") ind_bench_fn = lambda fn: do_bench( fn, return_mode='min', @@ -53,7 +50,7 @@ def main(): for dtype in dtypes: for inductor_matmul, xla_matmul in zip( get_matmuls(device='cuda', dtype=dtype, backend='inductor'), - get_matmuls(device=xm.xla_device(), dtype=dtype, backend='openxla')): + get_matmuls(device="xla", dtype=dtype, backend='openxla')): ind_lhs_shape, ind_rhs_shape, ind_fn = inductor_matmul xla_lhs_shape, xla_rhs_shape, xla_fn = xla_matmul assert ind_lhs_shape == xla_lhs_shape, f"Expect matmul shapes to match for benchmarking. Mismatch lhs: {ind_lhs_shape}, rhs: {xla_rhs_shape}" diff --git a/contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb b/contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb index d4d676f745e5..d72b31350d7b 100644 --- a/contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb +++ b/contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb @@ -188,7 +188,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "To get the current process/thread's default XLA device, use `xm.xla_device()`. XLA devices are numbered as `xla:i`, where `i` is the index of the device within the current process. Since each process has two devices on a TPU v3, this will be `xla:0` or `xla:1`." + "To get the current process/thread's default XLA device, use `torch_xla.device()`. XLA devices are numbered as `xla:i`, where `i` is the index of the device within the current process. Since each process has two devices on a TPU v3, this will be `xla:0` or `xla:1`." 
] }, { @@ -210,7 +210,7 @@ "lock = mp.Manager().Lock()\n", "\n", "def print_device(i, lock):\n", - " device = xm.xla_device()\n", + " device = torch_xla.device()\n", " with lock:\n", " print('process', i, device)" ] @@ -273,7 +273,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": { "execution": { "iopub.execute_input": "2024-01-10T19:30:33.219878Z", @@ -318,12 +318,12 @@ ], "source": [ "def add_ones(i, lock):\n", - " x = torch.ones((3, 3), device=xm.xla_device())\n", + " x = torch.ones((3, 3), device=\"xla\")\n", " y = x + x\n", - " \n", + "\n", " # Run graph to compute `y` before printing\n", " torch_xla.sync()\n", - " \n", + "\n", " with lock:\n", " print(i, y)\n", "\n", @@ -340,7 +340,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": { "execution": { "iopub.execute_input": "2024-01-10T19:30:35.656796Z", @@ -378,10 +378,10 @@ "source": [ "def gather_ids(i, lock):\n", " # Create a tensor on each device with the device ID\n", - " t = torch.tensor([i], device=xm.xla_device())\n", + " t = torch.tensor([i], device=\"xla\")\n", " with lock:\n", " print(i, t)\n", - " \n", + "\n", " # Collect and concatenate the IDs\n", " ts = xm.all_gather(t)\n", " torch_xla.sync()\n", @@ -402,7 +402,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": { "execution": { "iopub.execute_input": "2024-01-10T19:30:38.315927Z", @@ -454,7 +454,7 @@ "import torch_xla.experimental.pjrt_backend # Required for torch.distributed on TPU v2 and v3\n", "\n", "def toy_model(index, lock):\n", - " device = xm.xla_device()\n", + " device = torch_xla.device()\n", " dist.init_process_group('xla', init_method='xla://')\n", "\n", " # Initialize a basic toy model\n", @@ -479,7 +479,7 @@ " loss.backward()\n", "\n", " optimizer.step()\n", - " \n", + "\n", " # Run the pending graph\n", " torch_xla.sync()\n", "\n", diff --git a/contrib/kaggle/pytorch-xla-2-0-on-kaggle.ipynb b/contrib/kaggle/pytorch-xla-2-0-on-kaggle.ipynb index a2c2f0d099e2..a0c5d6d3d769 100644 --- a/contrib/kaggle/pytorch-xla-2-0-on-kaggle.ipynb +++ b/contrib/kaggle/pytorch-xla-2-0-on-kaggle.ipynb @@ -172,7 +172,7 @@ "\n", "pipeline = DiffusionPipeline.from_pretrained(\"runwayml/stable-diffusion-v1-5\")\n", "# Move the model to the first TPU core\n", - "pipeline = pipeline.to(xm.xla_device())" + "pipeline = pipeline.to(torch_xla.device())" ] }, { diff --git a/docs/source/learn/_pjrt.md b/docs/source/learn/_pjrt.md index 5531ce8824a0..58548bf5676b 100644 --- a/docs/source/learn/_pjrt.md +++ b/docs/source/learn/_pjrt.md @@ -73,7 +73,7 @@ import torch_xla.distributed.xla_backend def _mp_fn(index): - device = xm.xla_device() + device = torch_xla.device() - dist.init_process_group('xla', rank=xr.global_ordinal(), world_size=xr.world_size()) + dist.init_process_group('xla', init_method='xla://') @@ -377,7 +377,7 @@ def _all_gather(index: int): # No need to pass in `rank` or `world_size` dist.init_process_group('xla', init_method='xla://') - t = torch.tensor([index], dtype=torch.int32, device=xm.xla_device()) + t = torch.tensor([index], dtype=torch.int32, device="xla") output = [torch.zeros_like(t) for _ in range(dist.get_world_size())] dist.all_gather(output, t) diff --git a/docs/source/learn/pytorch-on-xla-devices.md b/docs/source/learn/pytorch-on-xla-devices.md index 48e0da77bd70..d55b7161929d 100644 --- a/docs/source/learn/pytorch-on-xla-devices.md +++ b/docs/source/learn/pytorch-on-xla-devices.md @@ -14,14 +14,14 @@ import torch import 
torch_xla import torch_xla.core.xla_model as xm -t = torch.randn(2, 2, device=xm.xla_device()) +t = torch.randn(2, 2, device="xla") print(t.device) print(t) ``` This code should look familiar. PyTorch/XLA uses the same interface as regular PyTorch with a few additions. Importing `torch_xla` initializes -PyTorch/XLA, and `xm.xla_device()` returns the current XLA device. This +PyTorch/XLA, and `torch_xla.device()` returns the current XLA device. This may be a CPU or TPU depending on your environment. ## XLA Tensors are PyTorch Tensors @@ -32,8 +32,8 @@ tensors. For example, XLA tensors can be added together: ``` python -t0 = torch.randn(2, 2, device=xm.xla_device()) -t1 = torch.randn(2, 2, device=xm.xla_device()) +t0 = torch.randn(2, 2, device="xla") +t1 = torch.randn(2, 2, device="xla") print(t0 + t1) ``` @@ -46,8 +46,8 @@ print(t0.mm(t1)) Or used with neural network modules: ``` python -l_in = torch.randn(10, device=xm.xla_device()) -linear = torch.nn.Linear(10, 20).to(xm.xla_device()) +l_in = torch.randn(10, device="xla") +linear = torch.nn.Linear(10, 20).to(torch_xla.device()) l_out = linear(l_in) print(l_out) ``` @@ -56,7 +56,7 @@ Like other device types, XLA tensors only work with other XLA tensors on the same device. So code like ``` python -l_in = torch.randn(10, device=xm.xla_device()) +l_in = torch.randn(10, device="xla") linear = torch.nn.Linear(10, 20) l_out = linear(l_in) print(l_out) @@ -79,7 +79,7 @@ The following snippet shows a network training on a single XLA device: ``` python import torch_xla.core.xla_model as xm -device = xm.xla_device() +device = torch_xla.device() model = MNIST().train().to(device) loss_fn = nn.NLLLoss() optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum) @@ -116,7 +116,7 @@ import torch_xla.core.xla_model as xm import torch_xla.distributed.parallel_loader as pl def _mp_fn(index): - device = xm.xla_device() + device = torch_xla.device() mp_device_loader = pl.MpDeviceLoader(train_loader, device) model = MNIST().train().to(device) @@ -144,7 +144,7 @@ previous single device snippet. Let's go over then one by one. will only be able to access the device assigned to the current process. For example on a TPU v4-8, there will be 4 processes being spawn up and each process will own a TPU device. - - Note that if you print the `xm.xla_device()` on each process you + - Note that if you print the `torch_xla.device()` on each process you will see `xla:0` on all devices. This is because each process can only see one device. This does not mean multi-process is not functioning. 
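One way to confirm that the processes are still distinct even though each one reports `xla:0` is to gather every process's ordinal. A minimal sketch (assuming a multi-device host and the `torch_xla.launch()` and `xm.all_gather()` APIs used elsewhere in this guide):

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr


def _mp_fn(index):
  # Each process creates a tensor on its single visible device ...
  t = torch.tensor([xr.global_ordinal()], device="xla")
  # ... yet all-gather still collects one entry per process.
  gathered = xm.all_gather(t)
  torch_xla.sync()
  if index == 0:
    print(gathered)  # one distinct ordinal per process, e.g. 4 entries on a TPU v4-8


if __name__ == '__main__':
  torch_xla.launch(_mp_fn, args=())
```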
The only execution is with PJRT runtime on TPU v2 @@ -279,7 +279,7 @@ import torch import torch_xla import torch_xla.core.xla_model as xm -device = xm.xla_device() +device = torch_xla.device() t0 = torch.randn(2, 2, device=device) t1 = torch.randn(2, 2, device=device) diff --git a/docs/source/learn/troubleshoot.md b/docs/source/learn/troubleshoot.md index 3768104ccf98..264fc822d561 100644 --- a/docs/source/learn/troubleshoot.md +++ b/docs/source/learn/troubleshoot.md @@ -32,8 +32,8 @@ vm:~$ export PJRT_DEVICE=TPU vm:~$ python3 >>> import torch >>> import torch_xla.core.xla_model as xm ->>> t1 = torch.tensor(100, device=xm.xla_device()) ->>> t2 = torch.tensor(200, device=xm.xla_device()) +>>> t1 = torch.tensor(100, device="xla") +>>> t2 = torch.tensor(200, device="xla") >>> print(t1 + t2) tensor(300, device='xla:0') ``` diff --git a/docs/source/learn/xla-overview.md b/docs/source/learn/xla-overview.md index 987d7b9629f2..f6b0761fd69a 100644 --- a/docs/source/learn/xla-overview.md +++ b/docs/source/learn/xla-overview.md @@ -184,7 +184,7 @@ repo. contains examples for training and serving many LLM and diffusion models. General guidelines to modify your code: -- Replace `cuda` with `xm.xla_device()` +- Replace `cuda` with `torch_xla.device()` - Remove progress bar, printing that would access the XLA tensor values - Reduce logging and callbacks that would access the XLA tensor values @@ -227,7 +227,7 @@ tutorial, but you can pass the `device` value to the function as well. ``` python import torch_xla.core.xla_model as xm - self.device = xm.xla_device() + self.device = torch_xla.device() ``` Another place in the code that has cuda specific code is DDIM scheduler. @@ -244,7 +244,7 @@ if attr.device != torch.device("cuda"): with ``` python -device = xm.xla_device() +device = torch_xla.device() attr = attr.to(torch.device(device)) ``` @@ -339,7 +339,7 @@ with the following lines: ``` python import torch_xla.core.xla_model as xm -device = xm.xla_device() +device = torch_xla.device() pipe.to(device) ``` diff --git a/docs/source/perf/amp.md b/docs/source/perf/amp.md index b5b1a3ffa790..3261624816ab 100644 --- a/docs/source/perf/amp.md +++ b/docs/source/perf/amp.md @@ -16,7 +16,7 @@ example is below: ``` python # Creates model and optimizer in default precision -model = Net().to(xm.xla_device()) +model = Net().to(torch_xla.device()) # Pytorch/XLA provides sync-free optimizers for improved performance optimizer = syncfree.SGD(model.parameters(), ...) @@ -24,7 +24,7 @@ for input, target in data: optimizer.zero_grad() # Enables autocasting for the forward pass - with autocast(xm.xla_device()): + with autocast(torch_xla.device()): output = model(input) loss = loss_fn(output, target) @@ -33,7 +33,7 @@ for input, target in data: xm.optimizer_step.(optimizer) ``` -`autocast(xm.xla_device())` aliases `torch.autocast('xla')` when the XLA +`autocast(torch_xla.device())` aliases `torch.autocast('xla')` when the XLA Device is a TPU. Alternatively, if a script is only used with TPUs, then `torch.autocast('xla', dtype=torch.bfloat16)` can be directly used. @@ -100,7 +100,7 @@ specific behavior. A simple CUDA AMP example is below: ``` python # Creates model and optimizer in default precision -model = Net().to(xm.xla_device()) +model = Net().to(torch_xla.device()) # Pytorch/XLA provides sync-free optimizers for improved performance optimizer = syncfree.SGD(model.parameters(), ...) 
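# GradScaler provides float16 loss scaling for AMP on XLA:GPU; the TPU example
# above omits it because TPU autocast uses bfloat16 and needs no loss scaling.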
scaler = GradScaler() @@ -109,7 +109,7 @@ for input, target in data: optimizer.zero_grad() # Enables autocasting for the forward pass - with autocast(xm.xla_device()): + with autocast(torch_xla.device()): output = model(input) loss = loss_fn(output, target) @@ -121,12 +121,12 @@ for input, target in data: scaler.update() ``` -`autocast(xm.xla_device())` aliases `torch.cuda.amp.autocast()` when the +`autocast(torch_xla.device())` aliases `torch.cuda.amp.autocast()` when the XLA Device is a CUDA device (XLA:GPU). Alternatively, if a script is only used with CUDA devices, then `torch.cuda.amp.autocast` can be directly used, but requires `torch` is compiled with `cuda` support for datatype of `torch.bfloat16`. We recommend using -`autocast(xm.xla_device())` on XLA:GPU as it does not require +`autocast(torch_xla.device())` on XLA:GPU as it does not require `torch.cuda` support for any datatypes, including `torch.bfloat16`. ### AMP for XLA:GPU Best Practices diff --git a/docs/source/perf/ddp.md b/docs/source/perf/ddp.md index 42880a3f5986..b713e6c79d23 100644 --- a/docs/source/perf/ddp.md +++ b/docs/source/perf/ddp.md @@ -105,7 +105,7 @@ def demo_basic(rank): setup(rank, world_size) # create model and move it to XLA device - device = xm.xla_device() + device = torch_xla.device() model = ToyModel().to(device) ddp_model = DDP(model, gradient_as_bucket_view=True) diff --git a/docs/source/perf/dynamo.md b/docs/source/perf/dynamo.md index a34a162d89bc..2ef86a2a5f9d 100644 --- a/docs/source/perf/dynamo.md +++ b/docs/source/perf/dynamo.md @@ -23,8 +23,8 @@ import torch import torch_xla.core.xla_model as xm def add(a, b): - a_xla = a.to(xm.xla_device()) - b_xla = b.to(xm.xla_device()) + a_xla = a.to(torch_xla.device()) + b_xla = b.to(torch_xla.device()) return a_xla + b_xla compiled_code = torch.compile(add, backend='openxla') @@ -41,7 +41,7 @@ import torchvision import torch_xla.core.xla_model as xm def eval_model(loader): - device = xm.xla_device() + device = torch_xla.device() xla_resnet18 = torchvision.models.resnet18().to(device) xla_resnet18.eval() dynamo_resnet18 = torch.compile( @@ -129,7 +129,7 @@ def train_model(model, data, target, optimizer): return pred def train_model_main(loader): - device = xm.xla_device() + device = torch_xla.device() xla_resnet18 = torchvision.models.resnet18().to(device) xla_resnet18.train() dynamo_train_model = torch.compile( diff --git a/docs/source/perf/fori_loop.md b/docs/source/perf/fori_loop.md index 1fb80be8a3f9..bfdd2bf318ab 100644 --- a/docs/source/perf/fori_loop.md +++ b/docs/source/perf/fori_loop.md @@ -30,7 +30,7 @@ result = while_loop(cond_fn, body_fn, init) >>> from torch._higher_order_ops.while_loop import while_loop >>> import torch_xla.core.xla_model as xm >>> ->>> device = xm.xla_device() +>>> device = torch_xla.device() >>> >>> def cond_fn(iteri, x): ... 
return iteri > 0 @@ -60,7 +60,7 @@ with similar logic: cumulative plus 1 for ten times: >>> import torch_xla >>> import torch_xla.core.xla_model as xm >>> ->>> device = xm.xla_device() +>>> device = torch_xla.device() >>> >>> init_val = torch.tensor(1, device=device) >>> iteri = torch.tensor(50, device=device) diff --git a/docs/source/perf/quantized_ops.md b/docs/source/perf/quantized_ops.md index 41e37f6a0708..6d44b05e433b 100644 --- a/docs/source/perf/quantized_ops.md +++ b/docs/source/perf/quantized_ops.md @@ -48,7 +48,7 @@ scaler = torch.randn((N_OUTPUT_FEATURES,), dtype=torch.bfloat16) # Call with torch CPU tensor (For debugging purpose) matmul_output = torch.ops.xla.quantized_matmul(x, w_int, scaler) -device = xm.xla_device() +device = torch_xla.device() x_xla = x.to(device) w_int_xla = w_int.to(device) scaler_xla = scaler.to(device) diff --git a/docs/source/perf/spmd_basic.md b/docs/source/perf/spmd_basic.md index dbeaea77a95f..ef714beb6298 100644 --- a/docs/source/perf/spmd_basic.md +++ b/docs/source/perf/spmd_basic.md @@ -41,7 +41,7 @@ mesh_shape = (num_devices, 1) device_ids = np.array(range(num_devices)) mesh = Mesh(device_ids, mesh_shape, ('data', 'model')) -t = torch.randn(8, 4).to(xm.xla_device()) +t = torch.randn(8, 4).to(torch_xla.device()) # Mesh partitioning, each device holds 1/8-th of the input partition_spec = ('data', 'model') diff --git a/examples/scan/scan_examples.py b/examples/scan/scan_examples.py index 5a4097d029ee..8fb0dd2e2b8f 100644 --- a/examples/scan/scan_examples.py +++ b/examples/scan/scan_examples.py @@ -18,8 +18,8 @@ def cumsum(accumulated, element): return accumulated, accumulated # 2) Define an initial carry and the input tensor. - init_sum = torch.tensor([0.0], device=torch_xla.device()) - xs = torch.tensor([1.0, 2.0, 3.0], device=torch_xla.device()) + init_sum = torch.tensor([0.0], device="xla") + xs = torch.tensor([1.0, 2.0, 3.0], device="xla") torch_xla.sync() # 3) Call `scan` with our combine function, initial carry, and input tensor. @@ -40,15 +40,15 @@ def scan_example_pytree(): # - 'sum' to accumulate the sum of all seen values # - 'count' to count how many values have been seen carry = { - 'sum': torch.tensor([0.0], device=torch_xla.device()), - 'count': torch.tensor([0.0], device=torch_xla.device()) + 'sum': torch.tensor([0.0], device="xla"), + 'count': torch.tensor([0.0], device="xla") } # 2) Define our input PyTree, which in this case is just a dictionary with one leaf: # - 'values' is a 1D tensor representing data points we want to scan over. xs = { 'values': - torch.arange(1, 6, dtype=torch.float32, device=torch_xla.device()) + torch.arange(1, 6, dtype=torch.float32, device="xla") } # Here, xs['values'] has shape [5]. The `scan` function will automatically slice diff --git a/examples/train_resnet_amp.py b/examples/train_resnet_amp.py index ae541705d717..8082d01524e9 100644 --- a/examples/train_resnet_amp.py +++ b/examples/train_resnet_amp.py @@ -17,7 +17,7 @@ def train_loop_fn(self, loader, epoch): for step, (data, target) in enumerate(loader): self.optimizer.zero_grad() # Enables autocasting for the forward pass - with autocast(xm.xla_device()): + with autocast(torch_xla.device()): output = self.model(data) loss = self.loss_fn(output, target) # TPU amp uses bf16 hence gradient scaling is not necessary. 
If runnign with XLA:GPU diff --git a/plugins/cpu/README.md b/plugins/cpu/README.md index c46a315e31c3..76c9d0b7c88e 100644 --- a/plugins/cpu/README.md +++ b/plugins/cpu/README.md @@ -38,5 +38,5 @@ plugins.use_dynamic_plugins() plugins.register_plugin('CPU', torch_xla_cpu_plugin.CpuPlugin()) xr.set_device_type('CPU') -print(xm.xla_device()) +print(torch_xla.device()) ``` diff --git a/plugins/cuda/README.md b/plugins/cuda/README.md index f86caf8e0d27..45a002e06f6c 100644 --- a/plugins/cuda/README.md +++ b/plugins/cuda/README.md @@ -35,5 +35,5 @@ plugins.use_dynamic_plugins() plugins.register_plugin('CUDA', torch_xla_cuda_plugin.CudaPlugin()) xr.set_device_type('CUDA') -print(xm.xla_device()) +print(torch_xla.device()) ``` diff --git a/test/bench.py b/test/bench.py index 97fb24f0a5ae..e5eff86a34d5 100644 --- a/test/bench.py +++ b/test/bench.py @@ -29,7 +29,7 @@ class BaseBench(object): def __init__(self, args): self.args = args - self.device = xm.xla_device() + self.device = torch_xla.device() self.test_time = xu.getenv_as('BENCH_TEST_TIME', float, 5.0) torch.manual_seed(42) diff --git a/test/debug_tool/test_mp_pt_xla_debug.py b/test/debug_tool/test_mp_pt_xla_debug.py index 45a9502af795..785554657b14 100644 --- a/test/debug_tool/test_mp_pt_xla_debug.py +++ b/test/debug_tool/test_mp_pt_xla_debug.py @@ -16,7 +16,7 @@ def _mp_fn(index): assert False, "This test should be run with PT_XLA_DEBUG_FILE" if index == 0: open(debug_file_name, 'w').close() - device = xm.xla_device() + device = torch_xla.device() t1 = torch.randn(10, 10, device=device) t2 = t1 * 100 torch_xla.sync() diff --git a/test/debug_tool/test_pt_xla_debug.py b/test/debug_tool/test_pt_xla_debug.py index 57864cd74657..4ebcb2cd1bb9 100644 --- a/test/debug_tool/test_pt_xla_debug.py +++ b/test/debug_tool/test_pt_xla_debug.py @@ -31,7 +31,7 @@ def setUpClass(cls): def test_eager_sync(self): with torch_xla.experimental.eager_mode_context(True): - device = xm.xla_device() + device = torch_xla.device() t1 = torch.randn(5, 9, device=device) torch_xla.sync() with open(self.debug_file_name, 'rb') as f: @@ -41,7 +41,7 @@ def test_eager_sync(self): open(self.debug_file_name, 'w').close() def test_user_sync(self): - device = xm.xla_device() + device = torch_xla.device() t1 = torch.randn(2, 2, device=device) torch_xla.sync() with open(self.debug_file_name, 'rb') as f: @@ -79,7 +79,7 @@ def test_user_sync(self): open(self.debug_file_name, 'w').close() def test_step_trace(self): - device = xm.xla_device() + device = torch_xla.device() with xp.StepTrace('train_pt_xla_debug'): t1 = torch.randn(3, 3, device=device) with open(self.debug_file_name, 'rb') as f: @@ -111,7 +111,7 @@ def test_step_trace(self): open(self.debug_file_name, 'w').close() def test_dynamo(self): - device = xm.xla_device() + device = torch_xla.device() t1 = torch.randn(4, 4, device=device) def toy_program(t1): @@ -161,7 +161,7 @@ def toy_program(t1): open(self.debug_file_name, 'w').close() def test_torch_xla_compile(self): - device = xm.xla_device() + device = torch_xla.device() t1 = torch.randn(12, 4, device=device) def toy_program(t1): @@ -209,7 +209,7 @@ def toy_program(t1): open(self.debug_file_name, 'w').close() def test_torch_xla_compile_custom_name(self): - device = xm.xla_device() + device = torch_xla.device() t1 = torch.randn(18, 4, device=device) def toy_program2(t1): @@ -239,7 +239,7 @@ def toy_program2(t1): open(self.debug_file_name, 'w').close() def test_parallel_loader(self): - device = xm.xla_device() + device = torch_xla.device() train_dataset_len = 100 batch_size = 
10 @@ -287,7 +287,7 @@ def test_parallel_loader(self): open(self.debug_file_name, 'w').close() def test_print(self): - device = xm.xla_device() + device = torch_xla.device() t1 = torch.randn(5, 5, device=device) print(t1) with open(self.debug_file_name, 'rb') as f: @@ -315,7 +315,7 @@ def test_print(self): open(self.debug_file_name, 'w').close() def test_frame(self): - device = xm.xla_device() + device = torch_xla.device() t1 = torch.randn(6, 6, device=device) torch_xla.sync() with open(self.debug_file_name, 'rb') as f: diff --git a/test/distributed_util.py b/test/distributed_util.py index 040ea451b4e2..85069aaabc82 100644 --- a/test/distributed_util.py +++ b/test/distributed_util.py @@ -101,7 +101,7 @@ def ddp_correctness(init_method: str = 'env://', dist.init_process_group("xla", init_method=init_method) rank, world_size = dist.get_rank(), dist.get_world_size() - device = xm.xla_device() + device = torch_xla.device() # Module initialization is not thread safe. Force threads to initialize one # at a time with the same seed diff --git a/test/ds/test_dynamic_shape_models.py b/test/ds/test_dynamic_shape_models.py index b6d06ea65f7f..2c1c827e7fac 100644 --- a/test/ds/test_dynamic_shape_models.py +++ b/test/ds/test_dynamic_shape_models.py @@ -17,7 +17,7 @@ # It enables us to run python implementations of CompositeAutogradImplicit ops. # CompositeAutogradImplicit means we don't have an explicit backward formula for an op instead an op is composed of a bunch of ops that do have backward formulas and combines this formulas is equivalent to differentiating the op explicitly. pd = torch._C._EnablePythonDispatcher() -xla_dev = xm.xla_device() +xla_dev = torch_xla.device() class Feedforward(torch.nn.Module): diff --git a/test/ds/test_dynamic_shapes.py b/test/ds/test_dynamic_shapes.py index 1c1a62e2724c..5139c2412a33 100644 --- a/test/ds/test_dynamic_shapes.py +++ b/test/ds/test_dynamic_shapes.py @@ -9,7 +9,7 @@ import test_utils pd = torch._C._EnablePythonDispatcher() -dev = xm.xla_device() +dev = torch_xla.device() class TestDynamicShapes(test_utils.XlaTestCase): @@ -163,7 +163,7 @@ def test_t_copy(self): self.assertEqual(t2_t.shape[1], 7) def test_nonzero_shape(self): - x = torch.tensor((0, 1, 2, 0, 3, 4), device=xm.xla_device()) + x = torch.tensor((0, 1, 2, 0, 3, 4), device="xla") x_dim0_shape = torch_xla._XLAC._get_xla_tensor_dimension_size( torch.nonzero(x, as_tuple=False), 0) self.assertEqual(x_dim0_shape.item(), 4) @@ -176,14 +176,14 @@ def test_nonzero_correctness(self): self.assertEqual(t2.cpu(), t2_aten) def test_masked_select_shape(self): - x = torch.tensor((0, 1, 2, 0, 3, 4), device=xm.xla_device()) + x = torch.tensor((0, 1, 2, 0, 3, 4), device="xla") mask = x.ge(2) x_dim0_shape = torch_xla._XLAC._get_xla_tensor_dimension_size( torch.masked_select(x, mask), 0) self.assertEqual(x_dim0_shape.item(), 3) def test_nonzero_cast(self): - t1 = torch.ones(5, 2, device=xm.xla_device()) + t1 = torch.ones(5, 2, device="xla") # Result of the nonzero should be the index type. Currently # index type is s64 on cpu and gpu, but s32 on TPU. We should be # able to cast it to any other type without error. 
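A minimal, self-contained sketch of the cast that the comment above describes (not part of the test file; assumes an XLA runtime is available):

```python
import torch
import torch_xla

t = torch.ones(5, 2, device="xla")
idx = torch.nonzero(t)            # index dtype: s64 on CPU/GPU, s32 on TPU
idx_f32 = idx.to(torch.float32)   # casting the result should not raise an error
torch_xla.sync()
print(idx_f32.dtype, idx_f32.shape)
```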
@@ -191,7 +191,7 @@ def test_nonzero_cast(self): torch_xla.sync() def test_expand_symint_correctness(self): - dev = xm.xla_device() + dev = torch_xla.device() size1 = 5 size2 = 2 t1 = torch.ones([size1, size2]) diff --git a/test/dynamo/test_bridge.py b/test/dynamo/test_bridge.py index 52c9cbf1053b..da36031759a2 100644 --- a/test/dynamo/test_bridge.py +++ b/test/dynamo/test_bridge.py @@ -116,7 +116,7 @@ def unwrap(cont): def make_reuse_graph_test(module_class, niter=100): def test_wrapper(self): - xla_dev = xm.xla_device() + xla_dev = torch_xla.device() xla_module = module_class().to(device=xla_dev) inputs = tuple(x.to(device=xla_dev) for x in xla_module.get_random_inputs()) metrics.clear_counters() @@ -187,7 +187,7 @@ def make_training_test(model_cls): def test_wrapper(self): import torch_xla.core.xla_model as xm - xla_dev = xm.xla_device() + xla_dev = torch_xla.device() model = model_cls() inputs = model.get_random_inputs() @@ -240,7 +240,7 @@ class Emb(torch.nn.Embedding): def __init__(self): super().__init__(num_embeddings=10, embedding_dim=10, padding_idx=0) - device = xm.xla_device() + device = torch_xla.device() module = Emb() module.to(device) @@ -255,7 +255,7 @@ def test_inputs_not_computed(self): def foo(x): return x * 2 - device = xm.xla_device() + device = torch_xla.device() x = torch.rand(5, device=device) x = x.unsqueeze(dim=-1) self._compile_and_check(foo, (x,)) @@ -265,7 +265,7 @@ def test_factory_copy(self): def foo(device): return torch.arange(5, device="cpu").to(device) - self._compile_and_check(foo, (xm.xla_device(),)) + self._compile_and_check(foo, (torch_xla.device(),)) def test_index_flag_unsupported(self): # The indices of the index operation are represented as @@ -277,7 +277,7 @@ def test_index_flag_unsupported(self): def foo(xt, t): return xt[t] - device = xm.xla_device() + device = torch_xla.device() xt = torch.rand(5, device=device) t = torch.randint(0, 5, (3,)) self._compile_and_check(foo, (xt, t)) @@ -299,7 +299,7 @@ def test_cpu_flag_unsupported(self): def foo(t): return t.cpu() - device = xm.xla_device() + device = torch_xla.device() t = torch.randint(0, 5, (3,), device=device) self._compile_and_check(foo, (t,)) diff --git a/test/dynamo/test_dynamo.py b/test/dynamo/test_dynamo.py index 0bf4b9110f2a..815b4ab05a9d 100644 --- a/test/dynamo/test_dynamo.py +++ b/test/dynamo/test_dynamo.py @@ -49,7 +49,7 @@ def inplace_update(self, a): def test_inplace_update_correctness(self, backend): dynamo_inplace = torch.compile( self.inplace_update, backend=backend, fullgraph=True) - t = torch.tensor([0, 1, 2], device=xm.xla_device()) + t = torch.tensor([0, 1, 2], device="xla") for i in range(10): t = dynamo_inplace(t) self.assertTrue(torch.all(torch.eq(t.cpu(), torch.tensor([10, 11, 12])))) @@ -66,7 +66,7 @@ def test_random_op_different_result_each_run(self, backend): met.clear_all() dynamo_random_op = torch.compile( self.random_op, backend=backend, fullgraph=True) - t = torch.randn(5, 5).to(xm.xla_device()) + t = torch.randn(5, 5).to(torch_xla.device()) dynamo_res_1 = dynamo_random_op(t) dynamo_res_2 = dynamo_random_op(t) dynamo_res_3 = dynamo_random_op(t) @@ -89,7 +89,7 @@ def test_sync_after_dynamo(self): head_dim = 128 running = 16 - device = xm.xla_device() + device = torch_xla.device() cache = torch.rand((cache_len, kv_heads, head_dim)).to(device) update_indices = torch.randint( 0, cache_len, (running,), dtype=torch.long).to(device) @@ -131,7 +131,7 @@ def dummy_fn(self, a): def test_dynamo_with_trace(self): dynamo_dummy = torch.compile( self.dummy_fn, 
backend="openxla", fullgraph=True) - t = torch.randn(2, 3, 4, device=xm.xla_device()) + t = torch.randn(2, 3, 4, device="xla") for i in range(10): with xp.Trace('build_graph'): t = dynamo_dummy(t) @@ -150,7 +150,7 @@ def fn_simple(self, x, y): def _choose_proper_device(self, initialize_on_cuda): if not initialize_on_cuda: - return xm.xla_device() + return torch_xla.device() assert initialize_on_cuda if xr.device_type() != "CUDA" or not torch.cuda.is_available(): @@ -164,7 +164,7 @@ def _choose_proper_device(self, initialize_on_cuda): @skipOnNeuron def test_simple_model(self): - device = xm.xla_device() + device = torch_xla.device() x = torch.tensor(100.0) y = torch.tensor(200.0) xla_x = x.to(device) @@ -448,7 +448,7 @@ def fn_fallback(t): torch._dynamo.reset() met.clear_all() - device = xm.xla_device() + device = torch_xla.device() # Initial tracing dynamo_fn = torch.compile(fn_fallback, backend="openxla") @@ -488,7 +488,7 @@ def fn_fallback(t): torch._dynamo.reset() met.clear_all() - device = xm.xla_device() + device = torch_xla.device() # Initial tracing dynamo_fn = torch.compile(fn_fallback, backend="openxla") @@ -541,7 +541,7 @@ def train_model(self, model, data, target): def test_simple_model(self): torch._dynamo.reset() - device = xm.xla_device() + device = torch_xla.device() input = torch.randn(3, 5, requires_grad=True) xla_input = input.detach().to(device) xla_input.requires_grad = True @@ -577,7 +577,7 @@ def test_simple_model(self): def test_resnet18(self): torch._dynamo.reset() met.clear_counters() - device = xm.xla_device() + device = torch_xla.device() batch_size = xu.getenv_as('BATCH_SIZE', int, defval=4) sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10) loader = xu.SampleGenerator( @@ -650,7 +650,7 @@ def train_model(self, model, data, target, optimizer): def test_simple_model(self): torch._dynamo.reset() - device = xm.xla_device() + device = torch_xla.device() input = torch.randn(3, 5, requires_grad=True) saved_input = input.detach().to(device).cpu() xla_input = input.detach().to(device) @@ -673,7 +673,7 @@ def test_simple_model(self): def test_resnet18(self): torch._dynamo.reset() met.clear_counters() - device = xm.xla_device() + device = torch_xla.device() batch_size = xu.getenv_as('BATCH_SIZE', int, defval=4) sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10) loader = xu.SampleGenerator( @@ -732,7 +732,7 @@ def test_resnet18(self): class DynamoErrorMessageTest(parameterized.TestCase): def test_mixed_cpu_tensor(self): - device = xm.xla_device() + device = torch_xla.device() input = torch.randn(4, 3, 224, 224) input_xla = input.clone().to(device) resnet18 = torchvision.models.resnet18() @@ -783,7 +783,7 @@ def foo(x): optfoo = torch.compile(backend=backend)(foo) t = torch.arange(9) - Xt = t.to(xm.xla_device()) + Xt = t.to(torch_xla.device()) expected = foo(t) actual = optfoo(Xt).cpu() @@ -803,7 +803,7 @@ def foo(x): optfoo = torch.compile(backend=backend)(foo) t = torch.arange(10) - Xt = t.to(xm.xla_device()) + Xt = t.to(torch_xla.device()) expected = foo(t) actual = optfoo(Xt) diff --git a/test/dynamo/test_dynamo_aliasing.py b/test/dynamo/test_dynamo_aliasing.py index a28567f42c25..36bfb5744bd4 100644 --- a/test/dynamo/test_dynamo_aliasing.py +++ b/test/dynamo/test_dynamo_aliasing.py @@ -11,7 +11,7 @@ class TestBufferDonationUtil(unittest.TestCase): def test_hash_with_buffer_donor(self): - device = xm.xla_device() + device = torch_xla.device() input = torch.randn(5, 5).to(device) res = torch.cos(input) hash_no_donor = 
torch_xla._XLAC._get_graph_hash([res]) @@ -40,7 +40,7 @@ def dummy_mul(self, input): return input * 1.1 def test_manual_buffer_donation(self): - device = xm.xla_device() + device = torch_xla.device() input = torch.randn(5, 5).to(device) input_cloned = input.cpu().to(device) dummy_inplace_mul_compiled = torch.compile( @@ -55,7 +55,7 @@ def test_manual_buffer_donation(self): torch.allclose(input_cloned.cpu() * 1.1, input.cpu()) def test_manual_buffer_donation_for_non_inplce_op(self): - device = xm.xla_device() + device = torch_xla.device() input = torch.randn(5, 5).to(device) input_cloned = input.cpu().to(device) dummy_mul_compiled = torch.compile(self.dummy_mul, backend='openxla') @@ -81,7 +81,7 @@ def dummy_inplace(input): torch.ops.xla.dynamo_set_buffer_donor_(input, True) input += (0.5 * torch.sin(input)) - device = xm.xla_device() + device = torch_xla.device() input = torch.randn(5, 5).to(device) input_cloned = input.cpu().to(device) dummy_inplace_add_compiled = torch.compile(dummy_inplace, backend='openxla') @@ -109,7 +109,7 @@ def dummy_add(self, input): return input + 1 def test_manual_buffer_donation(self): - device = xm.xla_device() + device = torch_xla.device() input = torch.randn(5, 5).to(device) input_cloned = input.cpu().to(device) dummy_inplace_add_compiled = torch.compile( @@ -127,7 +127,7 @@ def test_manual_buffer_donation(self): self.assertFalse(torch_xla._XLAC._get_buffer_donation(input)) def test_manual_buffer_donation_for_non_inplce_op(self): - device = xm.xla_device() + device = torch_xla.device() input = torch.randn(5, 5).to(device) input_cloned = input.cpu().to(device) dummy_add_compiled = torch.compile(self.dummy_add, backend='openxla') @@ -152,7 +152,7 @@ def test_manual_buffer_donation_for_inplce_op_repeat(self): def dummy_inplace(input): input += (0.3 * torch.cos(input)) - device = xm.xla_device() + device = torch_xla.device() input = torch.randn(5, 5).to(device) input_cloned = input.cpu().to(device) dummy_inplace_add_compiled = torch.compile(dummy_inplace, backend='openxla') @@ -174,7 +174,7 @@ def dummy_inplace(input): self.assertEqual(met.metric_data('CompileTime')[0], 1) def test_buffer_donation_on_non_data_tensor(self): - device = xm.xla_device() + device = torch_xla.device() input = torch.randn(5, 5).to(device) res = input + 1 diff --git a/test/dynamo/test_dynamo_graph_dump.py b/test/dynamo/test_dynamo_graph_dump.py index 6ce95dcbeff3..ae0383a47963 100644 --- a/test/dynamo/test_dynamo_graph_dump.py +++ b/test/dynamo/test_dynamo_graph_dump.py @@ -27,7 +27,7 @@ def test_dump_graph_with_dynamo_execution(self): if not save_file: assert False, "This test should be run with XLA_SAVE_TENSORS_FILE" save_file += '.0' - device = xm.xla_device() + device = torch_xla.device() xla_x = torch.tensor(100.0).to(device) xla_y = torch.tensor(200.0).to(device) res_xla_dynamo = self.fn_simple_dynamo(xla_x, xla_y) diff --git a/test/dynamo/test_dynamo_integrations_util.py b/test/dynamo/test_dynamo_integrations_util.py index b18262fc113a..293bef17ec05 100644 --- a/test/dynamo/test_dynamo_integrations_util.py +++ b/test/dynamo/test_dynamo_integrations_util.py @@ -20,7 +20,7 @@ class PybindTest(unittest.TestCase): def test_get_tensors_xla_device_data_node(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() t1 = torch.randn(20, 5).to(xla_device) t2 = torch.randn(20, 5).to(xla_device) t3 = t2 + t1 @@ -42,7 +42,7 @@ def test_get_tensors_xla_device_data_node(self): assert (expected_tensor_ids == sorted(res_pair[0])) def test_get_base_seed_as_tensor(self): - device = 
xm.xla_device() + device = torch_xla.device() xm.set_rng_state(23, str(device)) base_seed = torch_xla._XLAC._get_base_seed_as_tensor(str(device)).item() self.assertEqual(23, base_seed) @@ -51,7 +51,7 @@ def test_get_seed_info_id(self): self.assertEqual(torch_xla._XLAC._get_seed_info_id(), -127389) def test_check_tensor_need_materialization(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() t1 = torch.randn(20, 5) assert (torch_xla._XLAC._check_tensor_need_materialization([t1]) == [False]) t1 = t1.to(xla_device) @@ -67,7 +67,7 @@ def test_check_tensor_need_materialization(self): assert (torch_xla._XLAC._check_tensor_need_materialization([t1]) == [True]) def test_get_graph_hash(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() xla_input = torch.randn(64, 256, 14, 14).to(xla_device) xla_dummy_model = dummy_model.to(xla_device) xla_out = xla_dummy_model(xla_input) @@ -85,7 +85,7 @@ def test_get_graph_hash(self): assert (hash == torch_xla._XLAC._get_graph_hash([xla_out_2])) def test_clear_pending_irs(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() torch_xla.sync() t1 = torch.randn(20, 5).to(xla_device) t2 = torch.randn(20, 5).to(xla_device) @@ -104,7 +104,7 @@ def test_clear_pending_irs(self): self.assertEqual(met.metric_data('ExecuteTime')[0], 1) def test_run_cached_graph(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() xla_input = torch.randn(64, 256, 14, 14).to(xla_device) xla_dummy_model = dummy_model.to(xla_device) xla_out = xla_dummy_model(xla_input) diff --git a/test/dynamo/test_graph_input_matcher.py b/test/dynamo/test_graph_input_matcher.py index 9060248d32de..70dd0be73f57 100644 --- a/test/dynamo/test_graph_input_matcher.py +++ b/test/dynamo/test_graph_input_matcher.py @@ -24,7 +24,7 @@ def get_example_inputs(self): class TestGraphInputMatcher(unittest.TestCase): def test_no_cache_fx_gragh_inputs(self): - xla_dev = xm.xla_device() + xla_dev = torch_xla.device() model = M().to(device=xla_dev) inputs = tree_map_only(torch.Tensor, lambda x: x.to(device=xla_dev), model.get_example_inputs()) diff --git a/test/dynamo/test_num_output.py b/test/dynamo/test_num_output.py index ab86df8a6c50..b540e0691643 100644 --- a/test/dynamo/test_num_output.py +++ b/test/dynamo/test_num_output.py @@ -59,7 +59,7 @@ def get_example_inputs(self): class TestNumOutput(unittest.TestCase): def do_test(self, model_class, expected_num_output): - xla_dev = xm.xla_device() + xla_dev = torch_xla.device() model = model_class().to(device=xla_dev) inputs = tree_map_only(torch.Tensor, lambda x: x.to(device=xla_dev), model.get_example_inputs()) diff --git a/test/dynamo/test_traceable_collectives.py b/test/dynamo/test_traceable_collectives.py index 948491bddcc4..45bd89266604 100644 --- a/test/dynamo/test_traceable_collectives.py +++ b/test/dynamo/test_traceable_collectives.py @@ -18,7 +18,7 @@ def collective_broadcast_and_cos(input, src): def _mp_fn(index): - device = xm.xla_device() + device = torch_xla.device() world_size = xr.world_size() if xm.xla_device_hw(device) not in ('TPU', 'CUDA', 'NEURON'): print(f'skip this test for hw {xm.xla_device_hw(device)}') diff --git a/test/metrics_compare_utils_test.py b/test/metrics_compare_utils_test.py index 69a942646a84..b06453a00de2 100644 --- a/test/metrics_compare_utils_test.py +++ b/test/metrics_compare_utils_test.py @@ -275,7 +275,7 @@ def test_compare_metrics_reports_new_counters(self): def test_parse_real_metrics(self): print( 'Testing against TPU. 
If this hangs, check that $XRT_TPU_CONFIG is set') - x = torch.rand(3, 5, device=xm.xla_device()) + x = torch.rand(3, 5, device="xla") x = torch.flatten(x, 1) x = torch.roll(x, 1, 0) x = torch.flip(x, [0, 1]) diff --git a/test/neuron/test_neuron_data_types.py b/test/neuron/test_neuron_data_types.py index 4b8fb76c001b..105e38497118 100644 --- a/test/neuron/test_neuron_data_types.py +++ b/test/neuron/test_neuron_data_types.py @@ -9,8 +9,8 @@ class NeuronXlaDataTypeTest(unittest.TestCase): def _test_datatypes(self, dtype, op_xla_dtype, op): - t1 = torch.tensor([2, 3], dtype=dtype, device=xm.xla_device()) - t2 = torch.tensor([2, 3], dtype=dtype, device=xm.xla_device()) + t1 = torch.tensor([2, 3], dtype=dtype, device="xla") + t2 = torch.tensor([2, 3], dtype=dtype, device="xla") t3 = op(t1, t2) diff --git a/test/pjrt/test_collective_ops_tpu.py b/test/pjrt/test_collective_ops_tpu.py index ccb020b8bd8a..614040a81dc7 100644 --- a/test/pjrt/test_collective_ops_tpu.py +++ b/test/pjrt/test_collective_ops_tpu.py @@ -17,7 +17,7 @@ class TestXMCollectiveOpsTpu(parameterized.TestCase): @staticmethod def _broadcast(sync): torch.manual_seed(xr.global_ordinal()) - device = xm.xla_device() + device = torch_xla.device() model = nn.Linear(5, 5).to(device) if sync: xm.broadcast_master_param(model) @@ -41,7 +41,7 @@ def test_broadcast_master_param(self, sync): @staticmethod def _all_reduce(pin_layout): - device = xm.xla_device() + device = torch_xla.device() # Prevent 0 and 1 from being converted to constants ordinal = xm.send_cpu_data_to_device( torch.tensor( @@ -63,7 +63,7 @@ def test_all_reduce(self, pin_layout): @staticmethod def _all_gather(pin_layout): - device = xm.xla_device() + device = torch_xla.device() ordinal = torch.tensor([xr.global_ordinal()], device=device) out = xm.all_gather(ordinal, pin_layout=pin_layout) torch_xla.sync() @@ -80,7 +80,7 @@ def test_all_gather(self, pin_layout): @staticmethod def _reduce_scatter(pin_layout): - device = xm.xla_device() + device = torch_xla.device() world_size = xr.world_size() tensor = -torch.arange(world_size, dtype=torch.float32).to(device) @@ -105,7 +105,7 @@ def test_reduce_scatter(self, pin_layout): @staticmethod def _all_to_all(pin_layout): - device = xm.xla_device() + device = torch_xla.device() world_size = xr.world_size() tensor = torch.cat( @@ -151,7 +151,7 @@ def callable(input): return input dist.init_process_group("xla", init_method='xla://') - device = xm.xla_device() + device = torch_xla.device() input = torch.tensor([xr.global_ordinal()], dtype=torch.float, device=device) @@ -175,7 +175,7 @@ def callable(output, input): return output_tensor dist.init_process_group("xla", init_method='xla://') - device = xm.xla_device() + device = torch_xla.device() input = torch.tensor([xr.global_ordinal()], dtype=torch.float, device=device) @@ -194,7 +194,7 @@ def callable(output, input): def _all_gather(use_dynamo: bool): met.clear_all() dist.init_process_group("xla", init_method='xla://') - device = xm.xla_device() + device = torch_xla.device() def callable(input): output_tensor = [ @@ -223,7 +223,7 @@ def callable(input): def _reduce_scatter(use_dynamo: bool): met.clear_all() dist.init_process_group("xla", init_method='xla://') - device = xm.xla_device() + device = torch_xla.device() def callable(output, input): dist.reduce_scatter_tensor(output, input) @@ -248,7 +248,7 @@ def callable(output, input): def _all_to_all_single(use_dynamo: bool, split_size: int = 1): met.clear_all() dist.init_process_group("xla", init_method='xla://') - device = xm.xla_device() + 
device = torch_xla.device() def callable(output, input): dist.all_to_all_single(output, input) diff --git a/test/pjrt/test_ddp.py b/test/pjrt/test_ddp.py index 0be8835ddb36..d236b8e11ea1 100644 --- a/test/pjrt/test_ddp.py +++ b/test/pjrt/test_ddp.py @@ -25,7 +25,7 @@ class TestPjRtDistributedDataParallel(parameterized.TestCase): @staticmethod def _ddp_init(index: int = ...): dist.init_process_group('xla', init_method='xla://') - device = xm.xla_device() + device = torch_xla.device() model = nn.Linear(10, 10).to(device) ddp_model = DDP(model) diff --git a/test/pjrt/test_dtypes.py b/test/pjrt/test_dtypes.py index ebac882efdf4..dd6a4344c94b 100644 --- a/test/pjrt/test_dtypes.py +++ b/test/pjrt/test_dtypes.py @@ -10,7 +10,7 @@ class TestDtypes(parameterized.TestCase): torch.bfloat16, torch.complex64) def test_float_round_trip(self, dtype: torch.dtype): t = torch.randn((3, 3), dtype=dtype) - xt = t.to(xm.xla_device()) + xt = t.to(torch_xla.device()) torch.testing.assert_close(xt.cpu(), t) @parameterized.parameters( @@ -22,12 +22,12 @@ def test_float_round_trip(self, dtype: torch.dtype): ) def test_int_round_trip(self, dtype: torch.dtype): t = torch.randint(0, 128, (3, 3), dtype=dtype) - xt = t.to(xm.xla_device()) + xt = t.to(torch_xla.device()) torch.testing.assert_close(xt.cpu(), t) def test_bool_round_trip(self): t = torch.randint(0, 2, (3, 3), dtype=torch.bool) - xt = t.to(xm.xla_device()) + xt = t.to(torch_xla.device()) torch.testing.assert_close(xt.cpu(), t) diff --git a/test/pjrt/test_metrics.py b/test/pjrt/test_metrics.py index 5cee1b7ea5da..66d26f3b33fd 100644 --- a/test/pjrt/test_metrics.py +++ b/test/pjrt/test_metrics.py @@ -27,7 +27,7 @@ def test_metrics_report(self): self.assertEmpty(met.metrics_report()) # Move a tensor to the XLA device and back - torch.rand(3, 3, device=xm.xla_device()).cpu() + torch.rand(3, 3, device="xla").cpu() metrics = met.metrics_report() self.assertNotEmpty(metrics) diff --git a/test/pjrt/test_profiler.py b/test/pjrt/test_profiler.py index 3be3d4a06c40..15e799473b3d 100644 --- a/test/pjrt/test_profiler.py +++ b/test/pjrt/test_profiler.py @@ -32,12 +32,12 @@ class TestPjRtProfiler(absltest.TestCase): def setUp(self): # HACK: ensure libtpu is loaded if using TPU - xm.xla_device() + torch_xla.device() def test_profiler_output(self): tempdir = self.create_tempdir().full_path - device = xm.xla_device() + device = torch_xla.device() ones = torch.ones([5]) with _profile(tempdir): xones = ones.to(device) diff --git a/test/pjrt/test_runtime.py b/test/pjrt/test_runtime.py index fcb44e2cb939..6529b5e826e1 100644 --- a/test/pjrt/test_runtime.py +++ b/test/pjrt/test_runtime.py @@ -59,7 +59,7 @@ def test_num_global_devices(self): def test_xla_device_error(self): with self.assertRaises(IndexError): - xm.xla_device(10) + torch_xla.device(10) @parameterized.named_parameters(('default', {}, True), ('no_default', { 'PJRT_SELECT_DEFAULT_DEVICE': '0' diff --git a/test/pjrt/test_runtime_multi_cpu.py b/test/pjrt/test_runtime_multi_cpu.py index 54da40346ff5..36100b0c90ff 100644 --- a/test/pjrt/test_runtime_multi_cpu.py +++ b/test/pjrt/test_runtime_multi_cpu.py @@ -65,10 +65,10 @@ def forward(ctx, x): def backward(ctx, grad_output): results['forward_ordinal'] = ctx.forward_ordinal results['backward_ordinal'] = xr.global_ordinal() - results['device'] = str(xm.xla_device()) + results['device'] = str(torch_xla.device()) return grad_output - x = torch.ones(1, requires_grad=True, device=xm.xla_device()) + x = torch.ones(1, requires_grad=True, device="xla") y = _CustomBackwards.apply(x) 
   y.backward()
   torch_xla.sync()
@@ -110,7 +110,7 @@ def _hlo_dump(tmpdir: str):
     os.environ['XLA_SAVE_TENSORS_FMT'] = 'hlo'
     os.environ['XLA_SAVE_TENSORS_FILE'] = os.path.join(tmpdir, 'save.hlo')
-    x = torch.randn((3, 3), device=xm.xla_device())
+    x = torch.randn((3, 3), device="xla")
     torch_xla.sync()
     x.cpu()
@@ -124,7 +124,7 @@ def test_hlo_dump(self):
   @staticmethod
   def _all_reduce_hlo():
-    ones = torch.ones((3, 3), device=xm.xla_device())
+    ones = torch.ones((3, 3), device="xla")
     torch_xla.sync()
     reduced = xm.all_reduce(xm.REDUCE_SUM, ones)
diff --git a/test/pjrt/test_runtime_multi_gpu.py b/test/pjrt/test_runtime_multi_gpu.py
index 6609bc39d282..f47c6637c441 100644
--- a/test/pjrt/test_runtime_multi_gpu.py
+++ b/test/pjrt/test_runtime_multi_gpu.py
@@ -122,10 +122,10 @@ def forward(ctx, x):
       def backward(ctx, grad_output):
         results['forward_ordinal'] = ctx.forward_ordinal
         results['backward_ordinal'] = xr.global_ordinal()
-        results['device'] = str(xm.xla_device())
+        results['device'] = str(torch_xla.device())
         return grad_output
-    x = torch.ones(1, requires_grad=True, device=xm.xla_device())
+    x = torch.ones(1, requires_grad=True, device="xla")
     y = _CustomBackwards.apply(x)
     y.backward()
     torch_xla.sync()
@@ -166,7 +166,7 @@ def test_spawn(self, spawn):
   @staticmethod
   def _broadcast(sync):
     torch.manual_seed(xr.global_ordinal())
-    device = xm.xla_device()
+    device = torch_xla.device()
     model = nn.Linear(5, 5).to(device)
     if sync:
       xm.broadcast_master_param(model)
@@ -188,7 +188,7 @@ def test_broadcast_master_param(self, sync):
   @staticmethod
   def _all_gather(pin_layout):
-    device = xm.xla_device()
+    device = torch_xla.device()
     ordinal = torch.tensor([xr.global_ordinal()], device=device)
     out = xm.all_gather(ordinal, pin_layout=pin_layout)
     torch_xla.sync()
@@ -205,7 +205,7 @@ def test_all_gather(self, pin_layout):
   @staticmethod
   def _reduce_scatter(pin_layout):
-    device = xm.xla_device()
+    device = torch_xla.device()
     world_size = xr.world_size()
     tensor = -torch.arange(world_size, dtype=torch.float32).to(device)
@@ -231,7 +231,7 @@ def test_reduce_scatter(self, pin_layout):
   @staticmethod
   def _all_to_all(pin_layout):
-    device = xm.xla_device()
+    device = torch_xla.device()
     world_size = xr.world_size()
     tensor = torch.cat(
diff --git a/test/pjrt/test_runtime_tpu.py b/test/pjrt/test_runtime_tpu.py
index 021de719adb6..930c9f5ddb7d 100644
--- a/test/pjrt/test_runtime_tpu.py
+++ b/test/pjrt/test_runtime_tpu.py
@@ -172,7 +172,7 @@ def test_local_ordinal_with_discontiguous_global_ordinal_v4_threaded(self):
   @staticmethod
   def _spawn_threads() -> Dict[int, torch.device]:
     results = {}
-    pjrt.spawn_threads(lambda i: results.setdefault(i, xm.xla_device()))
+    pjrt.spawn_threads(lambda i: results.setdefault(i, torch_xla.device()))
     return results
@@ -187,7 +187,7 @@ def test_spawn_threads(self):
   @staticmethod
   def _spawn_error():
     # Initialize the client in the parent process
-    xm.xla_device()
+    torch_xla.device()
     torch_xla.launch(xm.xla_device)
@@ -199,7 +199,7 @@ def test_spawn_error(self):
   @staticmethod
   def _runtime_device_attributes():
-    return xr.runtime_device_attributes(str(xm.xla_device()))
+    return xr.runtime_device_attributes(str(torch_xla.device()))
   def test_runtime_device_attributes(self):
     result = pjrt.run_multiprocess(self._runtime_device_attributes)
@@ -226,12 +226,12 @@ def test_global_runtime_device_attributes(self):
   @staticmethod
   def _execute_time_metric():
     # Initialize the client before starting the timer.
-    xm.xla_device()
+    torch_xla.device()
     begin = time.perf_counter_ns()
     value = (
-        torch.randn(10000, 10000, device=xm.xla_device()) *
-        torch.randn(10000, 10000, device=xm.xla_device()))
+        torch.randn(10000, 10000, device="xla") *
+        torch.randn(10000, 10000, device="xla"))
     value_mean = value.mean()
     torch_xla.sync()
     cpu_value = value_mean.cpu()
diff --git a/test/pjrt/test_torchrun.py b/test/pjrt/test_torchrun.py
index 3939f7f6c582..02cc60982e48 100644
--- a/test/pjrt/test_torchrun.py
+++ b/test/pjrt/test_torchrun.py
@@ -26,9 +26,7 @@ def test_all_gather(self):
     expected_world_size = dist_world_size * devices_per_thread
-    rank = torch.tensor([dist.get_rank()],
-                        dtype=torch.float32,
-                        device=xm.xla_device())
+    rank = torch.tensor([dist.get_rank()], dtype=torch.float32, device="xla")
     output = [rank.clone() for _ in range(expected_world_size)]
     dist.all_gather(output, rank)
     result = torch.concat(output)
@@ -52,7 +50,7 @@ def test_all_reduce(self):
     expected = sum(tensors)
     xla_tensor = torch.arange(
-        2, dtype=torch.int64, device=xm.xla_device()) + 1 + 2 * dist.get_rank()
+        2, dtype=torch.int64, device="xla") + 1 + 2 * dist.get_rank()
     dist.all_reduce(xla_tensor, op=dist.ReduceOp.SUM)
     torch_xla.sync()
@@ -69,10 +67,9 @@ def test_reduce_scatter(self):
         world_size * world_size, dtype=torch.int64)
     expected = torch.split(tensor, world_size)[dist.get_rank()]
-    tensor_out = torch.zeros(
-        world_size, dtype=torch.int64, device=xm.xla_device())
+    tensor_out = torch.zeros(world_size, dtype=torch.int64, device="xla")
     tensor_in = torch.arange(
-        world_size * world_size, dtype=torch.int64, device=xm.xla_device())
+        world_size * world_size, dtype=torch.int64, device="xla")
     dist.reduce_scatter(tensor_out, [tensor_in], op=dist.ReduceOp.SUM)
     torch_xla.sync()
diff --git a/test/pjrt/test_train_hf_transformer.py b/test/pjrt/test_train_hf_transformer.py
index d2c113e9b5eb..d484edc0a6ce 100644
--- a/test/pjrt/test_train_hf_transformer.py
+++ b/test/pjrt/test_train_hf_transformer.py
@@ -55,7 +55,7 @@ def finetune(rank, train_dataset, test_dataset, tokenizer, flags):
       drop_last=True,
       generator=rng)
-  device = xm.xla_device()
+  device = torch_xla.device()
   model = AutoModelForSequenceClassification.from_pretrained(
       'google-bert/bert-base-cased', num_labels=5)
   model.to(device)
diff --git a/test/pytorch_test_base.py b/test/pytorch_test_base.py
index b47ae3f3de6d..3355f8efba99 100644
--- a/test/pytorch_test_base.py
+++ b/test/pytorch_test_base.py
@@ -559,7 +559,7 @@ def _alt_lookup(d, keys, defval):
   def instantiate_test(cls, name, test, *, generic_cls):
     test_name = name + '_' + cls.device_type
     class_name = cls.__name__
-    real_device_type = xm.xla_device_hw(str(xm.xla_device()))
+    real_device_type = xm.xla_device_hw(str(torch_xla.device()))
     assert real_device_type in DISABLED_TORCH_TESTS, 'Unsupported device type:' + real_device_type
     disabled_torch_tests = DISABLED_TORCH_TESTS[real_device_type]
@@ -632,7 +632,7 @@ def get_primary_device(cls):
   @classmethod
   def setUpClass(cls):
     # Sets the primary test device to the xla_device (CPU or TPU)
-    cls.primary_device = str(xm.xla_device())
+    cls.primary_device = str(torch_xla.device())
     torch_xla._XLAC._xla_set_mat_mul_precision('highest')
   def setUp(self):
diff --git a/test/quantized_ops/test_quantized_matmul.py b/test/quantized_ops/test_quantized_matmul.py
index b7f415a82b60..88a34c69a4ae 100644
--- a/test/quantized_ops/test_quantized_matmul.py
+++ b/test/quantized_ops/test_quantized_matmul.py
@@ -12,7 +12,7 @@
 torch.manual_seed(123456)
-device = xm.xla_device()
+device = torch_xla.device()
 class M(torch.nn.Module):
diff --git a/test/scan/test_scan.py b/test/scan/test_scan.py
index b61d8648fa2d..dad0e5c1ce06 100644
--- a/test/scan/test_scan.py
+++ b/test/scan/test_scan.py
@@ -273,7 +273,7 @@ def test_scan_external_in_place_mutation(self):
     giving wrong results.
     """
     # TODO(yifeit): Modify this test when external in-place mutation is eventually supported.
-    weird_global = torch.tensor([0.0, 0.0], device=torch_xla.device())
+    weird_global = torch.tensor([0.0, 0.0], device="xla")
     def step_fn(carry, x):
       new_carry = carry + x
@@ -281,9 +281,8 @@ def step_fn(carry, x):
       y = new_carry + weird_global
       return new_carry, y
-    init = torch.tensor([0.0, 0.0], device=torch_xla.device())
-    xs = torch.tensor([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]],
-                      device=torch_xla.device())
+    init = torch.tensor([0.0, 0.0], device="xla")
+    xs = torch.tensor([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], device="xla")
     with self.assertRaisesRegex(AssertionError, "FakeTensor"):
       scan(step_fn, init, xs)
@@ -351,12 +350,11 @@ def test_scan_rand_in_fn(self):
     def step_fn(carry, x):
       new_carry = carry + x
-      y = new_carry + torch.rand(2, device=torch_xla.device())
+      y = new_carry + torch.rand(2, device="xla")
       return new_carry, y
-    init = torch.tensor([0.0, 0.0], device=torch_xla.device())
-    xs = torch.tensor([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]],
-                      device=torch_xla.device())
+    init = torch.tensor([0.0, 0.0], device="xla")
+    xs = torch.tensor([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], device="xla")
     _, ys = scan(step_fn, init, xs)
     # ys should be a 2D tensor with this shape.
     self.assertEqual(ys.shape, (3, 2))
diff --git a/test/scan/test_scan_debug.py b/test/scan/test_scan_debug.py
index d800a36998df..d105c22a6798 100644
--- a/test/scan/test_scan_debug.py
+++ b/test/scan/test_scan_debug.py
@@ -36,12 +36,10 @@ def fn2(carry, x):
       y = x + 42
       return carry, y
-    init = torch.tensor([0.0, 0.0],
-                        requires_grad=True,
-                        device=torch_xla.device())
+    init = torch.tensor([0.0, 0.0], requires_grad=True, device="xla")
     xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                       requires_grad=True,
-                      device=torch_xla.device())
+                      device="xla")
     # Run some graph involving a scan operation two times.
for i in range(2): diff --git a/test/scan/test_scan_layers.py b/test/scan/test_scan_layers.py index f239e1a51d25..f419b2330d03 100644 --- a/test/scan/test_scan_layers.py +++ b/test/scan/test_scan_layers.py @@ -265,15 +265,13 @@ def test_heterogenous_layers(self): layer1 = nn.Linear(128, 128).to(torch_xla.device()) layer2 = nn.Sequential(nn.Linear(128, 128).to(torch_xla.device())) with self.assertRaisesRegex(ValueError, "mismatched keys"): - scan_layers([layer1, layer2], - torch.zeros((128,), device=torch_xla.device())) + scan_layers([layer1, layer2], torch.zeros((128,), device="xla")) def test_mismatched_shapes(self): layer1 = nn.Linear(128, 128).to(torch_xla.device()) layer2 = nn.Linear(128, 129).to(torch_xla.device()) with self.assertRaisesRegex(ValueError, "Shape mismatch"): - scan_layers([layer1, layer2], - torch.zeros((128,), device=torch_xla.device())) + scan_layers([layer1, layer2], torch.zeros((128,), device="xla")) if __name__ == '__main__': diff --git a/test/scan/test_scan_pallas.py b/test/scan/test_scan_pallas.py index 2613cc66f217..a267886cd3f7 100644 --- a/test/scan/test_scan_pallas.py +++ b/test/scan/test_scan_pallas.py @@ -72,7 +72,7 @@ def fake_fa_wrapper(self, has_model_weight, use_scan): torch.manual_seed(12) torch_xla.manual_seed(12) hidden_states = torch.randn((8, 4, 256, 256)).requires_grad_().to('xla') - with xm.xla_device(): + with torch_xla.device(): attention_layers = AttentionLayers( has_model_weight, num_layer=3, use_scan=use_scan) hidden_states.retain_grad() diff --git a/test/spmd/test_dtensor_integration.py b/test/spmd/test_dtensor_integration.py index d7d03a536899..e83aead9fa27 100644 --- a/test/spmd/test_dtensor_integration.py +++ b/test/spmd/test_dtensor_integration.py @@ -30,10 +30,7 @@ def test_xla_distribute_tensor(self): for requires_grad in [True, False]: tensor_to_shard = torch.randn( - 3 * device_count, - 3, - requires_grad=requires_grad, - device=xm.xla_device()) + 3 * device_count, 3, requires_grad=requires_grad, device="xla") dist_tensor = distribute_tensor(tensor_to_shard, device_mesh, shard_spec) # TODO(yeounoh) switch to DTensor API when XLAShardedTensor inherits DTensor assert type(dist_tensor).__name__ == "XLAShardedTensor" @@ -49,7 +46,7 @@ def test_xla_distribute_tensor(self): def test_optimizer_step_with_sharding(self): # Use simple linear model to test model parameter sharding - model = self.SimpleLinear().to(xm.xla_device()) + model = self.SimpleLinear().to(torch_xla.device()) # Running the same mark_sharding test with xla_distribute_tensor instead device_count = xr.global_runtime_device_count() @@ -60,8 +57,8 @@ def test_optimizer_step_with_sharding(self): model.train() optimizer = optim.SGD(model.parameters(), lr=0.1) - data = torch.randn(128, 128).to(xm.xla_device()) - target = torch.zeros(128).to(xm.xla_device()) + data = torch.randn(128, 128).to(torch_xla.device()) + target = torch.zeros(128).to(torch_xla.device()) loss_fn = nn.CrossEntropyLoss() for _ in range(3): optimizer.zero_grad() @@ -76,7 +73,7 @@ def test_optimizer_step_with_sharding(self): torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight)) def test_xla_distribute_module(self): - model = self.SimpleLinear().to(xm.xla_device()) + model = self.SimpleLinear().to(torch_xla.device()) device_count = xr.global_runtime_device_count() device_mesh = init_device_mesh("xla", mesh_shape=(device_count,)) @@ -96,8 +93,8 @@ def shard_params(mod_name, mod, mesh): sharded_model.train() optimizer = optim.SGD(sharded_model.parameters(), lr=0.1) - data = torch.randn(128, 
128).to(xm.xla_device()) - target = torch.zeros(128).to(xm.xla_device()) + data = torch.randn(128, 128).to(torch_xla.device()) + target = torch.zeros(128).to(torch_xla.device()) loss_fn = nn.CrossEntropyLoss() for _ in range(3): optimizer.zero_grad() diff --git a/test/spmd/test_dtensor_integration2.py b/test/spmd/test_dtensor_integration2.py index 2d1329cdf4dc..0c729fbb91c7 100644 --- a/test/spmd/test_dtensor_integration2.py +++ b/test/spmd/test_dtensor_integration2.py @@ -38,8 +38,8 @@ def test_xla_distribute_module_auto(self): self.assertTrue(torch_xla._XLAC._xla_get_auto_sharding()) optimizer = optim.SGD(sharded_model.parameters(), lr=0.1) - data = torch.randn(128, 128).to(xm.xla_device()) - target = torch.zeros(128).to(xm.xla_device()) + data = torch.randn(128, 128).to(torch_xla.device()) + target = torch.zeros(128).to(torch_xla.device()) loss_fn = nn.CrossEntropyLoss() for _ in range(5): optimizer.zero_grad() diff --git a/test/spmd/test_dynamo_spmd.py b/test/spmd/test_dynamo_spmd.py index b4375e145c85..518e4203b459 100644 --- a/test/spmd/test_dynamo_spmd.py +++ b/test/spmd/test_dynamo_spmd.py @@ -42,7 +42,7 @@ def setUpClass(cls): super().setUpClass() def test_dynamo_spmd_basic(self): - device = xm.xla_device() + device = torch_xla.device() linear = SimpleLinear().to(device) linear.eval() xla_x = torch.randn(1, 128, device=device) @@ -58,7 +58,7 @@ def test_dynamo_spmd_basic(self): # a ExecuteMetric. def test_dynamo_spmd_output_sharding_spec(self): - device = xm.xla_device() + device = torch_xla.device() linear = SimpleLinear().to(device) linear.eval() xla_x = torch.randn(1, 128, device=device) @@ -74,7 +74,7 @@ def test_dynamo_spmd_output_sharding_spec(self): ) def test_dynamo_spmd_output_sharding_cache(self): met.clear_all() - device = xm.xla_device() + device = torch_xla.device() linear = SimpleLinear().to(device) linear.eval() xla_x = torch.randn(1, 128, device=device) @@ -90,7 +90,7 @@ def test_dynamo_spmd_output_sharding_cache(self): self.assertEqual(met.counter_value('UncachedOutputSharding'), 1) def test_dynamo_sharded_input(self): - device = xm.xla_device() + device = torch_xla.device() linear = SimpleLinear().to(device) linear.eval() xla_x = torch.randn(8, 128, device=device) @@ -103,7 +103,7 @@ def test_dynamo_sharded_input(self): torch.allclose(xla_res.cpu(), dynamo_res.cpu()) def test_dynamo_input_sharding_changed(self): - device = xm.xla_device() + device = torch_xla.device() linear = SimpleLinear().to(device) linear.eval() xla_x = torch.randn(8, 128, device=device) @@ -142,7 +142,7 @@ def test_dynamo_input_sharding_changed(self): @unittest.skipIf(xr.global_runtime_device_count() == 1, "Multiple devices needed to test the mesh change") def test_dynamo_input_sharding_threashold(self): - device = xm.xla_device() + device = torch_xla.device() linear = SimpleLinear().to(device) linear.eval() xla_x = torch.randn(8, 128, device=device) @@ -183,7 +183,7 @@ def test_dynamo_input_sharding_threashold(self): del os.environ['XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD'] def test_dynamo_spmd_basic_with_dynamo_mark_sharding(self): - device = xm.xla_device() + device = torch_xla.device() linear = SimpleLinear().to(device) linear.eval() xla_x = torch.randn(1, 128, device=device) @@ -202,7 +202,7 @@ def test_dynamo_spmd_basic_with_dynamo_mark_sharding(self): torch.allclose(xla_res.cpu(), dynamo_res.cpu()) def test_dynamo_spmd_activation_sharding_with_dynamo_mark_sharding(self): - device = xm.xla_device() + device = torch_xla.device() mesh = self._get_mesh((1, self.n_devices)) device_ids = 
mesh.device_ids.tolist() mesh_shape = list(mesh.mesh_shape) diff --git a/test/spmd/test_fsdp_v2.py b/test/spmd/test_fsdp_v2.py index d4a85f531a31..0f8a4d088ef2 100644 --- a/test/spmd/test_fsdp_v2.py +++ b/test/spmd/test_fsdp_v2.py @@ -24,7 +24,7 @@ def setUpClass(cls): super().setUpClass() def test_fsdp_v2_basic(self): - model = self.SimpleLinear().to(xm.xla_device()) + model = self.SimpleLinear().to(torch_xla.device()) mesh = self._get_mesh((self.n_devices, 1), None, ('fsdp', 'tensor')) model.fc1 = FSDPv2(model.fc1, mesh=mesh) model.fc2 = FSDPv2(model.fc2, mesh=mesh) @@ -39,7 +39,7 @@ def test_fsdp_v2_basic(self): self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(model.fc2.weight)) - x = torch.randn(16, 128).to(xm.xla_device()) + x = torch.randn(16, 128).to(torch_xla.device()) xs.mark_sharding(x, mesh, ('fsdp', None)) output = model(x) # Make sure output are sharded. @@ -63,7 +63,7 @@ def test_fsdp_v2_basic(self): xm.wait_device_ops() def test_fsdp_v2_output_correctness(self): - model_expected = self.SimpleLinear().to(xm.xla_device()) + model_expected = self.SimpleLinear().to(torch_xla.device()) model = copy.deepcopy(model_expected) mesh = self._get_mesh((self.n_devices, 1), None, ('fsdp', 'tensor')) @@ -71,7 +71,7 @@ def test_fsdp_v2_output_correctness(self): model.fc2 = FSDPv2(model.fc2, mesh=mesh) model = FSDPv2(model, mesh=mesh) - x_expected = torch.randn(16, 128).to(xm.xla_device()) + x_expected = torch.randn(16, 128).to(torch_xla.device()) x = copy.deepcopy(x_expected) xs.mark_sharding(x, mesh, ('fsdp', None)) @@ -81,7 +81,7 @@ def test_fsdp_v2_output_correctness(self): self.assertTrue(torch.allclose(output_expected.cpu(), output.cpu())) def test_fsdp_v2_auto_wrap_basic(self): - model = self.SimpleLinear().to(xm.xla_device()) + model = self.SimpleLinear().to(torch_xla.device()) mesh = self._get_mesh((self.n_devices, 1), None, ('fsdp', 'tensor')) auto_wrap_policy = functools.partial( transformer_auto_wrap_policy, @@ -93,7 +93,7 @@ def test_fsdp_v2_auto_wrap_basic(self): self.assertTrue(isinstance(model.fc2, FSDPv2)) def test_fsdp_v2_auto_wrap_callable(self): - model = self.SimpleLinear().to(xm.xla_device()) + model = self.SimpleLinear().to(torch_xla.device()) mesh = self._get_mesh((self.n_devices, 1), None, ('fsdp', 'tensor')) auto_wrap_policy = functools.partial( transformer_auto_wrap_policy, @@ -115,7 +115,7 @@ def auto_wrapper_callable(m, *args, **kwargs): self.assertFalse(isinstance(model.fc2, FSDPv2)) def test_fsdp_v2_global_mesh(self): - model = self.SimpleLinear().to(xm.xla_device()) + model = self.SimpleLinear().to(torch_xla.device()) mesh = self._get_mesh((self.n_devices, 1), None, ('fsdp', 'tensor')) xs.set_global_mesh(mesh) @@ -123,7 +123,7 @@ def test_fsdp_v2_global_mesh(self): self.assertEqual(id(model._mesh), id(mesh)) def test_fsdp_v2_global_mesh_error(self): - model = self.SimpleLinear().to(xm.xla_device()) + model = self.SimpleLinear().to(torch_xla.device()) xs.set_global_mesh(None) with self.assertRaises(ValueError): @@ -141,7 +141,7 @@ def test_fsdp_v2_cpu_model(self): @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.") def test_fsdp_v2_multi_slice(self): - model = self.SimpleLinear().to(xm.xla_device()) + model = self.SimpleLinear().to(torch_xla.device()) mesh = self._get_mesh((2, self.n_devices // 2, 1), None, ('data', 'fsdp', 'tensor')) model = FSDPv2(model, mesh=mesh, extra_data_axis="data") @@ -155,7 +155,7 @@ def test_fsdp_v2_multi_slice(self): self.assertEqual(annotation, 
torch_xla._XLAC._get_xla_sharding_spec(model.fc2.weight)) - x = torch.randn(16, 128).to(xm.xla_device()) + x = torch.randn(16, 128).to(torch_xla.device()) xs.mark_sharding(x, mesh, (('data', 'fsdp'), None)) output = model(x) # Make sure output are sharded. @@ -171,14 +171,14 @@ def test_fsdp_v2_multi_slice(self): @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.") def test_fsdp_v2_multi_slice_output_correctness(self): - model_expected = self.SimpleLinear().to(xm.xla_device()) + model_expected = self.SimpleLinear().to(torch_xla.device()) model = copy.deepcopy(model_expected) mesh = self._get_mesh((2, self.n_devices // 2, 1), None, ('data', 'fsdp', 'tensor')) model = FSDPv2(model, mesh=mesh, extra_data_axis="data") - x_expected = torch.randn(16, 128).to(xm.xla_device()) + x_expected = torch.randn(16, 128).to(torch_xla.device()) x = copy.deepcopy(x_expected) xs.mark_sharding(x, mesh, (('data', 'fsdp'), None)) @@ -188,7 +188,7 @@ def test_fsdp_v2_multi_slice_output_correctness(self): self.assertTrue(torch.allclose(output_expected.cpu(), output.cpu())) def test_fsdp_v2_multi_slice_error(self): - model = self.SimpleLinear().to(xm.xla_device()) + model = self.SimpleLinear().to(torch_xla.device()) xs.set_global_mesh( self._get_mesh((2, self.n_devices // 2, 1), None, ('data', 'fsdp', 'tensor'))) diff --git a/test/spmd/test_mp_input_sharding.py b/test/spmd/test_mp_input_sharding.py index 6b78a3714e79..dc1e4aba12b0 100644 --- a/test/spmd/test_mp_input_sharding.py +++ b/test/spmd/test_mp_input_sharding.py @@ -34,7 +34,7 @@ def __next__(self): @unittest.skipUnless(xr.global_runtime_device_count() > 1, "Multiple devices required for tupled partition spec") def test_multiple_inputs(self): - device = xm.xla_device() + device = torch_xla.device() batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128, 128))} train_loader = self.fake_dataloader(batch) num_devices = xr.global_runtime_device_count() @@ -61,7 +61,7 @@ def test_multiple_inputs(self): @unittest.skipUnless(xr.global_runtime_device_count() > 1, "Multiple devices required for tupled partition spec") def test_single_tensor(self): - device = xm.xla_device() + device = torch_xla.device() batch = torch.randn((16, 128)) train_loader = self.fake_dataloader(batch) num_devices = xr.global_runtime_device_count() @@ -78,7 +78,7 @@ def test_single_tensor(self): @unittest.skipUnless(xr.global_runtime_device_count() > 1, "Multiple devices required for tupled partition spec") def test_error_single_tensor_with_input_sharding_dict(self): - device = xm.xla_device() + device = torch_xla.device() batch = torch.randn((16, 128)) train_loader = self.fake_dataloader(batch) num_devices = xr.global_runtime_device_count() @@ -95,7 +95,7 @@ def test_error_single_tensor_with_input_sharding_dict(self): @unittest.skipUnless(xr.global_runtime_device_count() > 1, "Multiple devices required for tupled partition spec") def test_input_sharding_none(self): - device = xm.xla_device() + device = torch_xla.device() batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128, 128))} train_loader = self.fake_dataloader(batch) num_devices = xr.global_runtime_device_count() @@ -112,7 +112,7 @@ def test_input_sharding_none(self): @unittest.skipUnless(xr.global_runtime_device_count() > 1, "Multiple devices required for tupled partition spec") def test_error_missing_keys(self): - device = xm.xla_device() + device = torch_xla.device() batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128, 128))} train_loader = self.fake_dataloader(batch) mesh = 
xs.get_1d_mesh('x') @@ -127,7 +127,7 @@ def test_error_missing_keys(self): @unittest.skipUnless(xr.global_runtime_device_count() > 1, "Multiple devices required for tupled partition spec") def test_input_sharding_not_dict(self): - device = xm.xla_device() + device = torch_xla.device() num_devices = xr.global_runtime_device_count() batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128))} train_loader = self.fake_dataloader(batch) diff --git a/test/spmd/test_sharding_strategies.py b/test/spmd/test_sharding_strategies.py index d6bc4221b811..2dd09580a5a6 100644 --- a/test/spmd/test_sharding_strategies.py +++ b/test/spmd/test_sharding_strategies.py @@ -146,7 +146,7 @@ def training_step(data): torch.manual_seed(42) tries = 5 -device = xm.xla_device() +device = torch_xla.device() if args.profile: print("Profiler server started at port 9012") server = xp.start_server(9012) diff --git a/test/spmd/test_spmd_debugging.py b/test/spmd/test_spmd_debugging.py index a8113b2ae532..34221d375e9c 100644 --- a/test/spmd/test_spmd_debugging.py +++ b/test/spmd/test_spmd_debugging.py @@ -209,7 +209,7 @@ def test_single_host_replicated_tpu(self): f"Requires PJRT_DEVICE set to `CPU`.") def test_debugging_spmd_single_host_tiled_cpu(self): from torch_xla.distributed.spmd.debugging import visualize_sharding - device = xm.xla_device() + device = torch_xla.device() num_devices = self.n_devices mesh_shape = (1, num_devices) device_ids = np.array(range(num_devices)) @@ -252,7 +252,7 @@ def test_debugging_spmd_single_host_tiled_cpu(self): f"Requires PJRT_DEVICE set to `CPU`.") def test_single_host_partial_replication_cpu(self): from torch_xla.distributed.spmd.debugging import visualize_sharding - device = xm.xla_device() + device = torch_xla.device() num_devices = self.n_devices mesh_shape = (1, num_devices) device_ids = np.array(range(num_devices)) @@ -295,7 +295,7 @@ def test_single_host_partial_replication_cpu(self): f"Requires PJRT_DEVICE set to `CPU`.") def test_single_host_replicated_cpu(self): from torch_xla.distributed.spmd.debugging import visualize_sharding - device = xm.xla_device() + device = torch_xla.device() num_devices = self.n_devices mesh_shape = (1, num_devices) device_ids = np.array(range(num_devices)) diff --git a/test/spmd/test_spmd_graph_dump.py b/test/spmd/test_spmd_graph_dump.py index 2d1c2f84a4cf..45af3b154934 100644 --- a/test/spmd/test_spmd_graph_dump.py +++ b/test/spmd/test_spmd_graph_dump.py @@ -26,7 +26,7 @@ def test_dump_with_output_sharding(self): assert save_file, "This test should be run with XLA_SAVE_TENSORS_FILE" should_dump_output_sharding = (save_format == 'hlo') save_file += '.0' - device = xm.xla_device() + device = torch_xla.device() xla_x = torch.randn(8, 32).to(device) xla_y = torch.randn(8, 32).to(device) # shard one of the input tensor diff --git a/test/spmd/test_spmd_lowering_context.py b/test/spmd/test_spmd_lowering_context.py index 5cc0ac464bda..9bc80194318f 100644 --- a/test/spmd/test_spmd_lowering_context.py +++ b/test/spmd/test_spmd_lowering_context.py @@ -38,7 +38,7 @@ def test_basic(self): mesh_shape = (data_axis, model_axis) spmd_mesh = self._get_mesh(mesh_shape, axis_names=('x', 'y')) - device = xm.xla_device() + device = torch_xla.device() a = torch.zeros(2048, device=device, requires_grad=True) xs.mark_sharding(a, spmd_mesh, ('x',)) b = torch.randn([32, 2048], device=device, requires_grad=True) @@ -108,7 +108,7 @@ def test_device_parameter_id_tensor_mapping(self): mesh_shape = (data_axis, model_axis) spmd_mesh = self._get_mesh(mesh_shape, axis_names=('x', 
'y')) - device = xm.xla_device() + device = torch_xla.device() a = torch.randn([32, 2048]).to(device) xs.mark_sharding(a, spmd_mesh, ('x', 'y')) b = torch.ones(2048).to(device) diff --git a/test/spmd/test_train_spmd_imagenet.py b/test/spmd/test_train_spmd_imagenet.py index 727103586d1e..f814810b2eb9 100644 --- a/test/spmd/test_train_spmd_imagenet.py +++ b/test/spmd/test_train_spmd_imagenet.py @@ -206,7 +206,7 @@ def train_imagenet(): torch.manual_seed(42) - device = xm.xla_device() + device = torch_xla.device() model = get_model_property('model_fn')().to(device) if FLAGS.use_gradient_checkpointing: @@ -313,8 +313,8 @@ def train_loop_fn(loader, epoch): tracker = xm.RateTracker() model.train() for step, (data, target) in enumerate(loader): - x = data.to(xm.xla_device()) - y = target.to(xm.xla_device()) + x = data.to(torch_xla.device()) + y = target.to(torch_xla.device()) with xp.StepTrace('train_imagenet'): with xp.Trace('build_graph'): optimizer.zero_grad() @@ -344,8 +344,8 @@ def test_loop_fn(loader, epoch): total_samples, correct = 0, 0 model.eval() for step, (data, target) in enumerate(loader): - data = data.to(xm.xla_device()) - target = target.to(xm.xla_device()) + data = data.to(torch_xla.device()) + target = target.to(torch_xla.device()) output = model(data) pred = output.max(1, keepdim=True)[1] correct += pred.eq(target.view_as(pred)).sum() diff --git a/test/spmd/test_xla_auto_sharding.py b/test/spmd/test_xla_auto_sharding.py index 40b0566f8b28..b30fc0c0e88e 100644 --- a/test/spmd/test_xla_auto_sharding.py +++ b/test/spmd/test_xla_auto_sharding.py @@ -39,11 +39,11 @@ def setUpClass(cls): xr.use_spmd(auto=True) def init_test_variables(cls): - xt_no_auto = torch.ones(2, 2).to(xm.xla_device()) + xt_no_auto = torch.ones(2, 2).to(torch_xla.device()) cls.hash_no_auto = torch_xla._XLAC._get_graph_hash([xt_no_auto + 0]) def test_auto_sharding_hashing(self): - xt = torch.ones(2, 2).to(xm.xla_device()) + xt = torch.ones(2, 2).to(torch_xla.device()) assert torch_xla._XLAC._xla_get_auto_sharding() hash_auto_spmd = torch_xla._XLAC._get_graph_hash([xt + 0]) self.assertNotEqual(hash_auto_spmd, self.hash_no_auto) @@ -60,8 +60,8 @@ def test_matmul(self): t2 = torch.ones(128, 256) t3 = (t1 @ t2).sum() - xt1 = t1.to(xm.xla_device()) - xt2 = t2.to(xm.xla_device()) + xt1 = t1.to(torch_xla.device()) + xt2 = t2.to(torch_xla.device()) xt3 = (xt1 @ xt2).sum() torch_xla.sync() self.assertEqual(met.counter_value("CompileWithAutoSharding"), 1) @@ -72,11 +72,11 @@ def test_matmul(self): def test_simple_linear_training(self): met.clear_counters() - model = self.SimpleLinear().to(xm.xla_device()) + model = self.SimpleLinear().to(torch_xla.device()) model.train() optimizer = optim.SGD(model.parameters(), lr=0.1) - data = torch.randn(128, 128).to(xm.xla_device()) - target = torch.zeros(128).to(xm.xla_device()) + data = torch.randn(128, 128).to(torch_xla.device()) + target = torch.zeros(128).to(torch_xla.device()) loss_fn = nn.CrossEntropyLoss() for i in range(5): optimizer.zero_grad() diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py index 3096d8b6d9dc..380470d5c8c1 100644 --- a/test/spmd/test_xla_distributed_checkpoint.py +++ b/test/spmd/test_xla_distributed_checkpoint.py @@ -50,7 +50,7 @@ def _get_sharded_model(self, mesh_shape=None): # Return a sharded SimpleLinear model with fc1.weight sharded and # fc2.weight explicitly replicated mesh_shape = mesh_shape or (1, self.n_devices) - model = self.SimpleLinear().to(xm.xla_device()) + model = 
self.SimpleLinear().to(torch_xla.device()) mesh = self._get_mesh(mesh_shape) xs.mark_sharding(model.fc1.weight, mesh, (0, 1)) xs.mark_sharding(model.fc2.weight, mesh, (None, None)) @@ -76,7 +76,7 @@ def _assert_same_state_dict(self, sd1, sd2, keypath=""): if isinstance(sd1, torch.Tensor): assert sd1.device == sd2.device, f"Tensors on different devices at {keypath}: {sd1} vs {sd2}" - if sd1.device == xm.xla_device(): + if sd1.device == torch_xla.device(): sharding1 = torch_xla._XLAC._get_xla_sharding_spec(sd1) sharding2 = torch_xla._XLAC._get_xla_sharding_spec(sd2) assert sharding1 == sharding2, f"Different sharding on tensors at {keypath}: {sharding1} vs {sharding2}" @@ -145,14 +145,14 @@ def _save_and_restore(self, def test_resharding_unsharded_to_sharded(self): # Save an unsharded model using the DefaultSavePlanner and load into a # sharded model using the SPMDLoadPlanner - model = self.SimpleLinear().to(xm.xla_device()) + model = self.SimpleLinear().to(torch_xla.device()) sharded_model = self._get_sharded_model() self._save_and_restore(model, sharded_model, load_planner=SPMDLoadPlanner()) def test_resharding_sharded_to_unsharded(self): for chkpt_on_cpu in [True, False]: with self.subTest(chkpt_on_cpu): - model = self.SimpleLinear().to(xm.xla_device()) + model = self.SimpleLinear().to(torch_xla.device()) sharded_model = self._get_sharded_model() self._save_and_restore( sharded_model, @@ -338,7 +338,7 @@ def test_save_state_dict_with_cpu_shards(self): def test_cpu_state_dict_flattening(self): # In the case of a nested state_dict with fully sharded parameters, # _CpuShards should be treated as terminal nodes. - t = torch.randn(128, 128).to(xm.xla_device()) + t = torch.randn(128, 128).to(torch_xla.device()) mesh = self._get_mesh((self.n_devices, 1)) xs.mark_sharding(t, mesh, (0, 1)) state_dict = _sharded_cpu_state_dict({'model': {'weight': t}}) @@ -395,7 +395,7 @@ def test_resolve_shard_data(self): class DistributedCheckpointHelpersTest(DistributedCheckpointTestBase): def test_sharded_cpu_state_dict(self): - model = self.SimpleLinear().to(xm.xla_device()) + model = self.SimpleLinear().to(torch_xla.device()) state_dict = model.state_dict() sharded_cpu_state_dict = _sharded_cpu_state_dict(state_dict) self.assertCountEqual(sharded_cpu_state_dict, diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py index 7fa438e2e420..eb9ee0c00341 100644 --- a/test/spmd/test_xla_sharding.py +++ b/test/spmd/test_xla_sharding.py @@ -36,14 +36,14 @@ def test_xla_sharded_tensor(self): partition_spec = (0, 1) xt1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.float, - device=xm.xla_device()) + device="xla") xst1 = xs.mark_sharding(xt1, self._get_mesh((1, self.n_devices)), partition_spec) self.assertTrue(isinstance(xst1, XLAShardedTensor)) def test_xla_sharded_tensor_repr(self): - xt = torch.randn(128, 128).to(xm.xla_device()) - model = self.SimpleLinear().to(xm.xla_device()) + xt = torch.randn(128, 128).to(torch_xla.device()) + model = self.SimpleLinear().to(torch_xla.device()) mesh = self._get_mesh((1, self.n_devices)) partition_spec = (0, 1) @@ -59,7 +59,7 @@ def test_sharded_tensor_debug_info(self): partition_spec = (0, 1) xt1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.float, - device=xm.xla_device()) + device="xla") xst1 = xs.mark_sharding(xt1, self._get_mesh((1, self.n_devices)), partition_spec) @@ -73,7 +73,7 @@ def test_xla_shards(self): num_element = self.n_devices mesh = self._get_mesh((self.n_devices,)) t = torch.arange(num_element, dtype=torch.float32) - xt = 
xs.mark_sharding(t.to(xm.xla_device()), mesh, (0,)) + xt = xs.mark_sharding(t.to(torch_xla.device()), mesh, (0,)) shards = xt.local_shards self.assertEqual(len(shards), self.n_devices) @@ -97,7 +97,7 @@ def test_padded_xla_shards(self): num_element = self.n_devices + 1 # Ensure padding with two or more devices mesh = self._get_mesh((self.n_devices,)) t = torch.arange(num_element, dtype=torch.float32) - xt = xs.mark_sharding(t.to(xm.xla_device()), mesh, (0,)) + xt = xs.mark_sharding(t.to(torch_xla.device()), mesh, (0,)) shards = xt.local_shards self.assertEqual(len(shards), self.n_devices) shard_len = math.ceil(num_element / self.n_devices) @@ -127,7 +127,7 @@ def test_replicated_xla_shards(self): num_element = self.n_devices mesh = self._get_mesh((self.n_devices,)) t = torch.arange(num_element, dtype=torch.float32) - xt = xs.mark_sharding(t.to(xm.xla_device()), mesh, (None,)) + xt = xs.mark_sharding(t.to(torch_xla.device()), mesh, (None,)) shards = xt.local_shards self.assertEqual(len(shards), self.n_devices) for i, shard in enumerate(shards): @@ -147,7 +147,7 @@ def test_partially_replicated_xla_shards(self): mesh = self._get_mesh((self.n_devices // 2, 2)) t = torch.arange(num_element, dtype=torch.float32).reshape((16, 16)) # Partial replication along the 0th tensor axis, shard 2-way on the 1st - xt = xs.mark_sharding(t.to(xm.xla_device()), mesh, (None, 1)) + xt = xs.mark_sharding(t.to(torch_xla.device()), mesh, (None, 1)) shard_len = t.shape[1] // 2 shards = xt.local_shards @@ -172,7 +172,7 @@ def test_load_local_shards(self): num_element = self.n_devices mesh = self._get_mesh((self.n_devices,)) t = torch.arange(num_element, dtype=torch.float32) + 1 - xt = xs.mark_sharding(t.to(xm.xla_device()), mesh, (0,)) + xt = xs.mark_sharding(t.to(torch_xla.device()), mesh, (0,)) local_shards = xt.local_shards self.assertTrue(len(local_shards) == self.n_devices) @@ -197,13 +197,13 @@ def test_load_local_shards(self): xt.load_local_shards_(local_shards) # Replicated shards should fail - rt = xs.mark_sharding(t.to(xm.xla_device()), mesh, (None,)) + rt = xs.mark_sharding(t.to(torch_xla.device()), mesh, (None,)) local_shards = rt.local_shards with self.assertRaises(RuntimeError): rt.load_local_shards_(local_shards) def test_xla_sharding_type(self): - t = torch.randn(10, 20).to(xm.xla_device()) + t = torch.randn(10, 20).to(torch_xla.device()) self.assertEqual(torch_xla._XLAC._get_xla_sharding_type(t), None) x_dim = 2 if self.n_devices >= 2 else 1 @@ -229,7 +229,7 @@ def test_xla_sharding_type(self): self.assertEqual(xt.sharding_type, xs.ShardingType.REPLICATED) def test_custom_tile_assignment(self): - xt = torch.randn(10, 20).to(device=xm.xla_device()) + xt = torch.randn(10, 20).to(device="xla") mesh_shape = (1, self.n_devices) device_ids = np.flip(self.device_ids) mesh = self._get_mesh(mesh_shape, device_ids) @@ -245,8 +245,8 @@ def test_mark_sharding_2d(self): t2 = torch.randn(1, 128, device='cpu') expected = t1 + t2 - xt1 = t1.to(xm.xla_device()) - xt2 = t2.to(xm.xla_device()) + xt1 = t1.to(torch_xla.device()) + xt2 = t2.to(torch_xla.device()) xs.mark_sharding(xt1, self._get_mesh((1, self.n_devices)), (0, 1)) if self.n_devices > 1: @@ -261,7 +261,7 @@ def test_mark_sharding_4d(self): t = torch.randn(2, 4, 8, 16, device='cpu') expected = t + t - xt = t.to(xm.xla_device()) + xt = t.to(torch_xla.device()) # Shard along two axes if four or more devices are available z_dim = 2 if self.n_devices >= 4 else 1 xs.mark_sharding(xt, self._get_mesh((1, 1, z_dim, self.n_devices // z_dim)), @@ -277,7 +277,7 @@ def 
test_mark_sharding_4d(self): self.assertTrue(torch.allclose(expected, actual)) def test_mark_sharding_not_ordered_sharding_spec_2d(self): - device = xm.xla_device() + device = torch_xla.device() t1 = torch.randn(8, 16, device='cpu') expected = t1 + t1 @@ -290,7 +290,7 @@ def test_mark_sharding_not_ordered_sharding_spec_2d(self): self.assertTrue(torch.allclose(expected, (xt1 + xt1).cpu())) def test_mark_sharding_not_ordered_sharding_spec_3d(self): - device = xm.xla_device() + device = torch_xla.device() t1 = torch.randn(4, 8, 16, device='cpu') expected = t1 + t1 @@ -307,7 +307,7 @@ def test_mark_sharding_not_ordered_sharding_spec_3d(self): self.assertTrue(torch.allclose(expected, (xt1 + xt1).cpu())) def test_mark_sharding_not_ordered_sharding_spec_4d(self): - device = xm.xla_device() + device = torch_xla.device() t1 = torch.randn(32, 4, 8, 16, device='cpu') expected = t1 + t1 @@ -326,7 +326,7 @@ def test_mark_sharding_not_ordered_sharding_spec_4d(self): self.assertTrue(torch.allclose(expected, (xt1 + xt1).cpu())) def test_mark_sharding_partial(self): - device = xm.xla_device() + device = torch_xla.device() t1 = torch.randn(4, 4).to(device) t2 = torch.randn(4, 4).to(device) # Somehow the eager cpu result is different from the xla result. @@ -356,7 +356,7 @@ def test_mark_sharding_partial(self): self.assertTrue(torch.allclose(expected, actual)) def test_propagate_replicated_sharding(self): - device = xm.xla_device() + device = torch_xla.device() t1 = torch.randn(4, 4).to(device) t2 = torch.randn(4, 4).to(device) t3 = t1 @ t2 @@ -368,7 +368,7 @@ def test_propagate_replicated_sharding(self): self.assertIn("replicated", torch_xla._XLAC._get_xla_sharding_spec(t3)) def test_mark_sharding_partial_unordered(self): - device = xm.xla_device() + device = torch_xla.device() t1 = torch.randn(4, 3, 4).to(device) t2 = torch.randn(4, 3, 4).to(device) expected = t1 + t2 @@ -401,7 +401,7 @@ def test_mark_sharding_partial_unordered(self): "Multiple devices required for tupled partition spec") def test_tupled_partition_spec(self): mesh = self._get_mesh((2, self.n_devices // 2)) - t = torch.randn(16).to(xm.xla_device()) + t = torch.randn(16).to(torch_xla.device()) xs.mark_sharding(t, mesh, ((0, 1),)) self.assertEqual( torch_xla._XLAC._get_xla_sharding_spec(t), "{devices=[%d]%s}" % @@ -413,7 +413,7 @@ def test_named_partial_tupled_partition_spec(self): mesh = xs.Mesh( range(self.n_devices), (1, 2, self.n_devices // 2), ('r', 'b', 'm')) # Shard the first dimension on `r` and `b`, replicate the second dimension - t = torch.randn(16, 16).to(xm.xla_device()) + t = torch.randn(16, 16).to(torch_xla.device()) xs.mark_sharding(t, mesh, (('r', 'b'), None)) self.assertEqual( torch_xla._XLAC._get_xla_sharding_spec(t), @@ -421,14 +421,14 @@ def test_named_partial_tupled_partition_spec(self): (self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))) # Replicate the first dimension, shard the second on `b` and `m` - u = torch.randn(16, 16).to(xm.xla_device()) + u = torch.randn(16, 16).to(torch_xla.device()) xs.mark_sharding(u, mesh, (None, ('b', 'm'))) self.assertEqual( torch_xla._XLAC._get_xla_sharding_spec(u), "{devices=[1,%d]%s}" % (self.n_devices, ','.join(str(x) for x in range(self.n_devices)))) # Replicate the first dimension, shard the second on `r` and `m` - v = torch.randn(16, 16).to(xm.xla_device()) + v = torch.randn(16, 16).to(torch_xla.device()) xs.mark_sharding(v, mesh, (None, ('r', 'm'))) device_order = mesh.get_logical_mesh().transpose((0, 2, 1)).flatten() self.assertEqual( @@ -437,7 +437,7 @@ def 
test_named_partial_tupled_partition_spec(self): (self.n_devices // 2, ','.join(str(x) for x in device_order))) # Replicate the first dimension, shard the second on `m` and `b` - v = torch.randn(16, 16).to(xm.xla_device()) + v = torch.randn(16, 16).to(torch_xla.device()) xs.mark_sharding(v, mesh, (None, ('m', 'b'))) device_order = mesh.get_logical_mesh().transpose((2, 1, 0)).flatten() self.assertEqual( @@ -450,7 +450,7 @@ def test_multiple_tuples_in_spec(self): mesh = xs.Mesh( range(self.n_devices), (1, 2, self.n_devices // 2, 1), ('a', 'b', 'c', 'd')) - t = torch.randn(2, 2).to(xm.xla_device()) + t = torch.randn(2, 2).to(torch_xla.device()) xs.mark_sharding(t, mesh, (('a', 'b'), ('c', 'd'))) self.assertEqual( torch_xla._XLAC._get_xla_sharding_spec(t), "{devices=[2,%d]%s}" % @@ -460,14 +460,14 @@ def test_multiple_tuples_in_spec(self): 'At least 2 devices needed for 2D mesh') def test_3d_tensor_2d_mesh(self): mesh = self._get_mesh((2, self.n_devices // 2)) - t = torch.randn(16, 16, 16).to(xm.xla_device()) + t = torch.randn(16, 16, 16).to(torch_xla.device()) xs.mark_sharding(t, mesh, (None, 0, 1)) self.assertEqual( torch_xla._XLAC._get_xla_sharding_spec(t), '{devices=[1,2,%d]%s}' % (self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))) def test_partial_replication_addmm(self): - device = xm.xla_device() + device = torch_xla.device() z_dim = 2 if self.n_devices >= 4 else 1 mesh = self._get_mesh((z_dim, self.n_devices // z_dim)) @@ -495,7 +495,7 @@ def test_partial_replication_addmm(self): self.assertTrue(torch.allclose(expected, actual, atol=1e-5)) def test_clear_sharding(self): - xt = torch.randn(2, 4, 8, 16).to(xm.xla_device()) + xt = torch.randn(2, 4, 8, 16).to(torch_xla.device()) xs.mark_sharding(xt, self._get_mesh((1, 1, 1, self.n_devices)), (0, 1, 2, 3)) self.assertTrue(torch_xla._XLAC._get_xla_sharding_spec(xt)) @@ -503,7 +503,7 @@ def test_clear_sharding(self): self.assertFalse(torch_xla._XLAC._get_xla_sharding_spec(xt)) def test_replication_with_no_clear_sharding(self): - xt = torch.randn(2, 4).to(xm.xla_device()) + xt = torch.randn(2, 4).to(torch_xla.device()) # replication xs.mark_sharding(xt, self._get_mesh((1, self.n_devices)), (None, None)) # sharding annotation over an existing replication sharding is permitted. 
@@ -513,7 +513,7 @@ def test_replication_with_no_clear_sharding(self): "replicated" in torch_xla._XLAC._get_xla_sharding_spec(xt)) def test_deep_copy(self): - xt = torch.randn(2, 4, 8, 16).to(xm.xla_device()) + xt = torch.randn(2, 4, 8, 16).to(torch_xla.device()) xs.mark_sharding(xt, self._get_mesh((1, 1, 1, self.n_devices)), (0, 1, 2, 3)) xt2 = copy.deepcopy(xt) @@ -522,7 +522,7 @@ def test_deep_copy(self): torch_xla._XLAC._get_xla_sharding_spec(xt2)) def test_clone(self): - xt = torch.randn(2, 4, 8, 16).to(xm.xla_device()) + xt = torch.randn(2, 4, 8, 16).to(torch_xla.device()) xs.mark_sharding(xt, self._get_mesh((1, 1, 1, self.n_devices)), (0, 1, 2, 3)) sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(xt) @@ -537,7 +537,7 @@ def test_clone(self): torch_xla._XLAC._get_xla_sharding_spec(xt2)) def test_sync_with_sharding(self): - xt = torch.ones(2, 2).to(xm.xla_device()) + xt = torch.ones(2, 2).to(torch_xla.device()) xs.mark_sharding(xt, self._get_mesh((1, self.n_devices)), (0, 1)) sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(xt) torch_xla.sync() # `torch_xla.sync()` should preserve the sharding @@ -545,7 +545,7 @@ def test_sync_with_sharding(self): def test_execute_replicated_metrics(self): met.clear_all() - xt = torch.ones(2, 2).to(xm.xla_device()) + xt = torch.ones(2, 2).to(torch_xla.device()) xs.mark_sharding(xt, self._get_mesh((1, self.n_devices)), (0, 1)) xt += 2 torch_xla.sync() @@ -554,15 +554,15 @@ def test_execute_replicated_metrics(self): def test_optimizer_step_with_sharding(self): # Use simple linear model to test model parameter sharding - model = self.SimpleLinear().to(xm.xla_device()) + model = self.SimpleLinear().to(torch_xla.device()) xs.mark_sharding(model.fc1.weight, self._get_mesh((1, self.n_devices)), (0, 1)) sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight) model.train() optimizer = optim.SGD(model.parameters(), lr=0.1) - data = torch.randn(128, 128).to(xm.xla_device()) - target = torch.zeros(128).to(xm.xla_device()) + data = torch.randn(128, 128).to(torch_xla.device()) + target = torch.zeros(128).to(torch_xla.device()) loss_fn = nn.CrossEntropyLoss() for i in range(3): optimizer.zero_grad() @@ -581,7 +581,7 @@ def test_sharding_propagation(self): self.assertFalse(met.counter_value("ReplicateShardedData")) # Linear model with two linear layers and only one is annotated. 
- model = self.SimpleLinear().to(xm.xla_device()) + model = self.SimpleLinear().to(torch_xla.device()) xs.mark_sharding(model.fc1.weight, self._get_mesh((1, self.n_devices)), (0, 1)) self.assertTrue(torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight)) @@ -589,8 +589,8 @@ def test_sharding_propagation(self): model.train() optimizer = optim.SGD(model.parameters(), lr=0.1) - data = torch.randn(128, 128).to(xm.xla_device()) - target = torch.zeros(128).to(xm.xla_device()) + data = torch.randn(128, 128).to(torch_xla.device()) + target = torch.zeros(128).to(torch_xla.device()) loss_fn = nn.CrossEntropyLoss() for i in range(3): optimizer.zero_grad() @@ -606,7 +606,7 @@ def test_sharding_propagation(self): self.assertEqual(met.counter_value("ReplicateShardedData"), 2) def test_inplace_add_with_sharding(self): - xt = torch.ones(2, 2).to(xm.xla_device()) + xt = torch.ones(2, 2).to(torch_xla.device()) xs.mark_sharding(xt, self._get_mesh((1, self.n_devices)), (0, 1)) sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(xt) xt.add_(1) # inplace update should preserve the sharding @@ -622,8 +622,8 @@ def test_inplace_add_with_sharding(self): xr.device_type() == 'CPU', "sharding will be the same for both tensors on single device") def test_shard_hashing(self): - xt1 = torch.ones(2, 2).to(xm.xla_device()) - xt2 = torch.ones(2, 2).to(xm.xla_device()) + xt1 = torch.ones(2, 2).to(torch_xla.device()) + xt2 = torch.ones(2, 2).to(torch_xla.device()) # Add sharding to xt1, this should result in the hashes being different for # xt1 and xt2 @@ -639,7 +639,7 @@ def test_shard_hashing(self): self.assertNotEqual(hash1, hash2) def test_transfer_sharded_data_to_host(self): - xt1 = torch.ones(16, 16).to(xm.xla_device()) + xt1 = torch.ones(16, 16).to(torch_xla.device()) xs.mark_sharding(xt1, self._get_mesh((1, self.n_devices)), (0, 1)) t1 = xt1.cpu() self.assertTrue(torch.allclose(t1, torch.ones(16, 16))) @@ -657,7 +657,7 @@ def test_send_cpu_data_to_device_with_sharding(self): sharding_spec = xs.ShardingSpec(mesh, (0, 1)) self.assertTrue(sharding_spec.can_apply(tensor)) xtensors = xm.send_cpu_data_to_device([tensor], - xm.xla_device(), + torch_xla.device(), input_sharding=sharding_spec) self.assertEqual(len(xtensors), 1) outbound = met.metric_data("OutboundData")[1] @@ -666,7 +666,7 @@ def test_send_cpu_data_to_device_with_sharding(self): # Verify the resulting sharding annotation matches an explicit # `mark_sharding` call. 
xt = xtensors[0] - explicit_xt = tensor.to(xm.xla_device()) + explicit_xt = tensor.to(torch_xla.device()) xs.mark_sharding(explicit_xt, mesh, (0, 1)) self.assertEqual( torch_xla._XLAC._get_xla_sharding_spec(xt), @@ -676,8 +676,8 @@ def test_multiple_operations(self): t1 = torch.randn(2, 2) t2 = torch.randn(2, 2) expected_1 = t1 + t2 - xt1 = t1.to(xm.xla_device()) - xt2 = t2.to(xm.xla_device()) + xt1 = t1.to(torch_xla.device()) + xt2 = t2.to(torch_xla.device()) xs.mark_sharding(xt1, self._get_mesh((1, self.n_devices)), (0, 1)) xt3 = xt1 + xt2 self.assertTrue(torch.allclose(expected_1, xt3.cpu())) @@ -685,8 +685,8 @@ def test_multiple_operations(self): t4 = torch.randn(2, 2) t5 = torch.randn(2, 2) expected_2 = t4 + t5 - xt4 = t4.to(xm.xla_device()) - xt5 = t5.to(xm.xla_device()) + xt4 = t4.to(torch_xla.device()) + xt5 = t5.to(torch_xla.device()) xs.mark_sharding(xt4, self._get_mesh((1, self.n_devices)), (0, 1)) xs.mark_sharding(xt5, self._get_mesh((1, self.n_devices)), (0, 1)) xt6 = xt4 + xt5 @@ -696,10 +696,10 @@ def test_no_sharding(self): partition_spec = (0, 1) t1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.float, - device=xm.xla_device()) + device="xla") t2 = torch.tensor([[8, 7, 6, 5, 4, 3, 2, 1]], dtype=torch.float, - device=xm.xla_device()) + device="xla") t3 = t1 + t2 t3_expected = [9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0] self.assertEqual(t3.tolist()[0], t3_expected) @@ -708,7 +708,7 @@ def test_xla_sharded_hlo_dump(self): partition_spec = (0, 1) xt1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.float, - device=xm.xla_device()) + device="xla") xst1 = xs.mark_sharding(xt1, self._get_mesh((1, self.n_devices)), partition_spec) xst2 = xst1 + 5 @@ -724,8 +724,8 @@ def test_2d_tensor_3d_mesh(self): ct2 = torch.randn(16, 16, device='cpu') expected = ct1 + ct2 - t1 = ct1.to(xm.xla_device()) - t2 = ct2.to(xm.xla_device()) + t1 = ct1.to(torch_xla.device()) + t2 = ct2.to(torch_xla.device()) # Meaningful test for higher-order mesh with extra replication # requires multiple devices. Otherwise, this should defaults back to @@ -821,8 +821,8 @@ def test_mark_sharding_ir(self): t2 = torch.randn(1, 128, device='cpu') expected = t1 + t2 - xt1 = t1.to(xm.xla_device()) - xt2 = t2.to(xm.xla_device()) + xt1 = t1.to(torch_xla.device()) + xt2 = t2.to(torch_xla.device()) actual = xt1 + xt2 actual = xs.mark_sharding(actual, self._get_mesh((1, self.n_devices)), (0, 1)) @@ -912,7 +912,7 @@ def test_sharded_tensor_aliasing(self): partition_spec = (0, 1) xt1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.float, - device=xm.xla_device()) + device="xla") xst1 = xs.mark_sharding(xt1, self._get_mesh((1, self.n_devices)), partition_spec) xst1 += 1 @@ -921,7 +921,7 @@ def test_sharded_tensor_aliasing(self): def test_mark_sharding_ir_with_multiple_output(self): partition_spec = (0,) - xt1 = torch.randn(8, 8).to(xm.xla_device()) + xt1 = torch.randn(8, 8).to(torch_xla.device()) # max return 2 tensors `value` and `indices`. 
They are the output # of the same IR Node `MaxInDim` (xt_val, xt_index) = torch.max(xt1, 1) @@ -937,13 +937,13 @@ def test_mark_sharding_ir_with_multiple_output(self): def test_sharded_tensor_to_cpu_int_type(self): partition_spec = (0, 1) t1 = torch.arange(64).reshape(8, 8) - xt1 = t1.clone().to(xm.xla_device()) + xt1 = t1.clone().to(torch_xla.device()) xst1 = xs.mark_sharding(xt1, self._get_mesh((self.n_devices, 1)), partition_spec) self.assertTrue(torch.allclose(t1, xst1.cpu())) def test_named_partition_spec(self): - xt1 = torch.arange(64).reshape(8, 8).to(xm.xla_device()) + xt1 = torch.arange(64).reshape(8, 8).to(torch_xla.device()) mesh = xs.Mesh( list(range(self.n_devices)), (1, self.n_devices), ('data', 'model')) partition_spec = ('model', 'data') @@ -955,7 +955,7 @@ def test_named_partition_spec(self): self.assertTrue("replicated" in sharding_spec) def test_shard_device_data_ir(self): - device = xm.xla_device() + device = torch_xla.device() xla_x = torch.randn(8, 128, device=device) # xla_x now becomes a device data IR xla_y = xla_x * 5 @@ -967,7 +967,7 @@ def test_shard_device_data_ir(self): self.assertTrue(torch.allclose(xla_y.cpu(), xla_x.cpu() * 5)) def test_shard_device_data_ir_after_sync(self): - device = xm.xla_device() + device = torch_xla.device() xla_x = torch.randn(8, 128, device=device) x = xla_x.cpu() # xla_x now becomes a device data IR without XLAData @@ -981,18 +981,18 @@ def test_op_sharding_cache(self): met.clear_all() mesh = self._get_mesh((1, self.n_devices)) - t = torch.randn(1, self.n_devices).to(xm.xla_device()) + t = torch.randn(1, self.n_devices).to(torch_xla.device()) xs.mark_sharding(t, mesh, (0, 1)) self.assertIn("CreateOpSharding", met.counter_names()) self.assertEqual(met.counter_value("CreateOpSharding"), 1) # Sharding with the same partition spec should not result in another call - u = torch.randn(1, self.n_devices).to(xm.xla_device()) + u = torch.randn(1, self.n_devices).to(torch_xla.device()) xs.mark_sharding(u, mesh, (0, 1)) self.assertEqual(met.counter_value("CreateOpSharding"), 1) # Changing the partition spec will result in another CreateOpSharding - v = torch.randn(1, self.n_devices).to(xm.xla_device()) + v = torch.randn(1, self.n_devices).to(torch_xla.device()) xs.mark_sharding(v, mesh, (0, None)) self.assertEqual(met.counter_value("CreateOpSharding"), 2) @@ -1130,11 +1130,11 @@ def test_from_cpu_shards_global_shape(self): from_cpu_shards(shards, op_sharding, torch.Size((1,))) def test_backward_optimization_barrier(self): - model = self.SimpleLinear().to(xm.xla_device()) + model = self.SimpleLinear().to(torch_xla.device()) # The first layer won't have gradients in the hook. Not sure why. 
xs.xla_sharding.apply_backward_optimization_barrier(model.fc2) - x = torch.randn(2, 128).to(xm.xla_device()) + x = torch.randn(2, 128).to(torch_xla.device()) y = model(x) loss = y.sum() loss.backward() @@ -1145,7 +1145,7 @@ def test_backward_optimization_barrier(self): hlo) def test_mark_shard_scalar(self): - x = torch.tensor(1.0).to(xm.xla_device()) + x = torch.tensor(1.0).to(torch_xla.device()) self.assertEqual(len(x.shape), 0) xt = xs.mark_sharding(x, self._get_mesh((1, self.n_devices)), ()) @@ -1174,7 +1174,7 @@ def test_global_mesh(self): self.assertEqual(id(mesh), id(expected_mesh)) def test_mark_manual_sharding(self): - x = torch.zeros(3, 2).to(xm.xla_device()) + x = torch.zeros(3, 2).to(torch_xla.device()) with self.assertRaises(RuntimeError): xt = xs._mark_manual_sharding(x) @@ -1192,7 +1192,7 @@ def test_mark_manual_sharding(self): # xt.global_tensor.cpu() def test_spmd_full_to_shard_shape(self): - x = torch.zeros(8, 8).to(xm.xla_device()) + x = torch.zeros(8, 8).to(torch_xla.device()) with self.assertRaises(RuntimeError): x = torch_xla._XLAC._spmd_full_to_shard_shape(x) @@ -1213,7 +1213,7 @@ def test_spmd_full_to_shard_shape(self): # xx.cpu() # Replicated shape - x = torch.zeros(8, 4).to(xm.xla_device()) + x = torch.zeros(8, 4).to(torch_xla.device()) xt = xs.mark_sharding(x, self._get_mesh((self.n_devices, 1)), (None, None)) xx = torch_xla._XLAC._spmd_full_to_shard_shape(xt.global_tensor) @@ -1225,7 +1225,7 @@ def test_spmd_full_to_shard_shape(self): self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(xx), "{manual}") def test_spmd_shard_to_full_shape(self): - x = torch.zeros(8, 8).to(xm.xla_device()) + x = torch.zeros(8, 8).to(torch_xla.device()) x += 1 # No sharding spec attached. with self.assertRaises(RuntimeError): @@ -1256,7 +1256,7 @@ def test_spmd_shard_to_full_shape(self): self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(xx), "{replicated}") def test_manual_sharding_e2e(self): - x = torch.zeros(8, 8).to(xm.xla_device()) + x = torch.zeros(8, 8).to(torch_xla.device()) mesh = self._get_mesh((1, self.n_devices)) partition_spec = (0, 1) xt = xs.mark_sharding(x, mesh, partition_spec) @@ -1275,7 +1275,7 @@ def test_manual_sharding_e2e(self): def test_manual_sharding_api_e2e(self): xs.set_global_mesh(self._get_mesh((1, self.n_devices))) - x = torch.zeros(8, 8).to(xm.xla_device()) + x = torch.zeros(8, 8).to(torch_xla.device()) partition_spec = (0, 1) xx = xs.enable_manual_sharding(x, partition_spec) @@ -1290,7 +1290,7 @@ def test_manual_sharding_api_e2e(self): "Only runs on TPUv4") def test_spmd_reduce_scatter(self): xs.set_global_mesh(self._get_mesh((1, self.n_devices))) - x = torch.ones(8, 8).to(xm.xla_device()) + x = torch.ones(8, 8).to(torch_xla.device()) # Reduce scatter x = xs.enable_manual_sharding(x, (None, None)).global_tensor @@ -1311,7 +1311,7 @@ def test_spmd_reduce_scatter(self): "Only runs on TPUv4") def test_spmd_reduce_scatter_canonical_index(self): xs.set_global_mesh(self._get_mesh((1, self.n_devices))) - x = torch.ones(8, 8).to(xm.xla_device()) + x = torch.ones(8, 8).to(torch_xla.device()) # Reduce scatter x = xs.enable_manual_sharding(x, (None, None)).global_tensor @@ -1332,7 +1332,7 @@ def test_spmd_reduce_scatter_canonical_index(self): "Only runs on TPUv4") def test_spmd_all_reduce(self): xs.set_global_mesh(self._get_mesh((1, self.n_devices))) - x = torch.ones(8, 8).to(xm.xla_device()) + x = torch.ones(8, 8).to(torch_xla.device()) # all reduce x = xs.enable_manual_sharding(x, (None, None)).global_tensor @@ -1352,7 +1352,7 @@ def 
test_spmd_all_reduce(self): "Only runs on TPUv4") def test_spmd_all_reduce_scale(self): xs.set_global_mesh(self._get_mesh((1, self.n_devices))) - x = torch.ones(8, 8).to(xm.xla_device()) + x = torch.ones(8, 8).to(torch_xla.device()) scale = 0.25 # all reduce @@ -1484,10 +1484,10 @@ def test_xla_patched_linear(self): """ from torch_xla.distributed.spmd.xla_sharding import XLAPatchedLinear - import torch_xla.runtime + import torch_xla.core.xla_model as xm import torch.nn.functional as F - with torch_xla.runtime.xla_device(): + with torch_xla.device(): torch_xla.manual_seed(42) x0 = torch.randn(2, 3, requires_grad=True) w0 = torch.randn(4, 3, requires_grad=True) @@ -1536,7 +1536,7 @@ def test_mark_sharding_with_gradients_basic(self): partition_spec = (0, 1) xt1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.float, - device=xm.xla_device(), + device="xla", requires_grad=True) mesh = self._get_mesh((1, self.n_devices)) xst1 = xs.mark_sharding_with_gradients(xt1, mesh, partition_spec) @@ -1550,7 +1550,7 @@ def test_mark_sharding_with_gradients_annotation(self): partition_spec = (0,) x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.float, - device=xm.xla_device(), + device="xla", requires_grad=True) # Notice that the function does not modify in-place. y = xs.mark_sharding_with_gradients(x, mesh, partition_spec) @@ -1669,13 +1669,9 @@ def test_get_logical_mesh(self): def test_shard_as(self): mesh = self._get_mesh((self.n_devices,)) partition_spec = (0,) - x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], - dtype=torch.float, - device=xm.xla_device()) + x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.float, device="xla") x = xs.mark_sharding_with_gradients(x, mesh, partition_spec) - y = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], - dtype=torch.float, - device=xm.xla_device()) + y = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.float, device="xla") x, y = xs.shard_as(x, y) torch_xla.sync() diff --git a/test/spmd/test_xla_sharding_hlo.py b/test/spmd/test_xla_sharding_hlo.py index a5a1159aa9e4..9a1653a7ef93 100644 --- a/test/spmd/test_xla_sharding_hlo.py +++ b/test/spmd/test_xla_sharding_hlo.py @@ -22,8 +22,8 @@ def setUpClass(cls): @patch.dict(os.environ, {"XLA_DUMP_POST_OPTIMIZATIONS": "1"}) def test_xla_sharded_hlo_dump_post_optimizations(self): - t1 = torch.randn(1, 128).to(xm.xla_device()) - t2 = torch.randn(128, 1).to(xm.xla_device()) + t1 = torch.randn(1, 128).to(torch_xla.device()) + t2 = torch.randn(128, 1).to(torch_xla.device()) xs.mark_sharding(t1, self._get_mesh((1, self.n_devices)), (0, 1)) t3 = t1 @ t2 diff --git a/test/spmd/test_xla_spmd_python_api_interaction.py b/test/spmd/test_xla_spmd_python_api_interaction.py index 8530ec3e7e4e..bc89c535608e 100644 --- a/test/spmd/test_xla_spmd_python_api_interaction.py +++ b/test/spmd/test_xla_spmd_python_api_interaction.py @@ -39,22 +39,22 @@ def test_is_master_ordinal(self): self.assertTrue(xm.is_master_ordinal()) def test_xla_device(self): - device = xm.xla_device() + device = torch_xla.device() self.assertEqual(device, torch.device('xla:0')) def test_xla_real_devices(self): - device = xm.xla_device() + device = torch_xla.device() device_type = os.environ['PJRT_DEVICE'] self.assertEqual(xm.xla_real_devices([device]), [device_type + ':0']) def test_xla_device_hw(self): - device = xm.xla_device() + device = torch_xla.device() device_type = os.environ['PJRT_DEVICE'] replication_devices = xm.xla_replication_devices([device]) self.assertEqual(xm.xla_device_hw(device), device_type) def test_xla_replication_devices(self): - device = 
xm.xla_device() + device = torch_xla.device() device_type = os.environ['PJRT_DEVICE'] replication_devices = xm.xla_replication_devices([device]) self.assertEqual(xm.xla_real_devices([device]), [device_type + ':0']) @@ -127,7 +127,7 @@ def test_runtime_spmd_api(self): # unittest process can persist XLA_USE_SPMD from other test suites, # so t may be on a SPMD or non-SPMD device. If this test is run independently # outside unittest, then it lives on a non-SPMD device. - t = torch.ones(2, 2).to(xm.xla_device()) + t = torch.ones(2, 2).to(torch_xla.device()) # Should enable SPMD without crashing. xr.use_spmd() @@ -149,7 +149,7 @@ def setUpClass(cls): @unittest.skipIf(xr.device_type() not in ['TPU', 'CUDA'], f"TPU/GPU autocast test.") def test_xla_autocast_api(self): - device = xm.xla_device() + device = torch_xla.device() t1 = torch.ones([2, 3], device=device, dtype=torch.float32) t2 = torch.ones([3, 2], device=device, dtype=torch.float32) with autocast(device, dtype=torch.bfloat16): diff --git a/test/spmd/test_xla_virtual_device.py b/test/spmd/test_xla_virtual_device.py index 38d04ca7a95d..63cc78c3fbdf 100644 --- a/test/spmd/test_xla_virtual_device.py +++ b/test/spmd/test_xla_virtual_device.py @@ -23,21 +23,21 @@ def test_mark_sharding(self): partition_spec = (0, 1) xt1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.float, - device=xm.xla_device()) + device="xla") xs.mark_sharding(xt1, self._get_mesh((1, self.n_devices)), partition_spec) self.assertTrue( torch.allclose( xt1 + 0, torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.float, - device=xm.xla_device()))) + device="xla"))) def test_metrics_recorded(self): met.clear_counters() partition_spec = (0, 1) xt1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.float, - device=xm.xla_device()) + device="xla") xs.mark_sharding(xt1, self._get_mesh((1, self.n_devices)), partition_spec) self.assertIn("VirtualDeviceUsage", met.counter_names()) self.assertNotEqual(met.counter_value("VirtualDeviceUsage"), 0) @@ -45,7 +45,7 @@ def test_metrics_recorded(self): def test_model_weight_metrics(self): met.clear_counters() partition_spec = (0, 1) - model = nn.Linear(128, 64).to(xm.xla_device()) + model = nn.Linear(128, 64).to(torch_xla.device()) xs.mark_sharding(model.weight, self._get_mesh((1, self.n_devices)), partition_spec) self.assertIn("VirtualDeviceUsage", met.counter_names()) @@ -54,17 +54,17 @@ def test_model_weight_metrics(self): def test_no_sharding(self): t1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.float, - device=xm.xla_device()) + device="xla") t2 = torch.tensor([[8, 7, 6, 5, 4, 3, 2, 1]], dtype=torch.float, - device=xm.xla_device()) + device="xla") t3 = t1 + t2 t3_expected = [9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0] self.assertEqual(t3.tolist()[0], t3_expected) def test_no_sharding_1d(self): - t1 = torch.arange(9, dtype=torch.float, device=xm.xla_device()) - t2 = torch.arange(9, dtype=torch.float, device=xm.xla_device()) + t1 = torch.arange(9, dtype=torch.float, device="xla") + t2 = torch.arange(9, dtype=torch.float, device="xla") t3 = t1 + t2 t3_expected = list(range(0, 18, 2)) self.assertEqual(t3.tolist(), t3_expected) @@ -75,7 +75,7 @@ def test_outbound_data_metrics(self): met.clear_all() xt1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.float, - device=xm.xla_device()) + device="xla") xs.mark_sharding(xt1, self._get_mesh((1, self.n_devices)), partition_spec) outbound_with_virtual_device = met.metric_data("OutboundData")[1] @@ -88,7 +88,7 @@ def test_non_tensor_scalar(self): sharding_spec = 
xs.ShardingSpec(self._get_mesh((1, self.n_devices)), (0, 1)) # tensor will have device as `SPMD:0` in c++ xt1 = xm.send_cpu_data_to_device([torch.randn(3, 3)], - xm.xla_device(), + torch_xla.device(), input_sharding=sharding_spec)[0] # we will transfer 0.5 as a device_data to the 'SPMD:0' device, need to make sure # that virtual device can handle this case. @@ -101,7 +101,7 @@ def test_sync_on_virtual_device(self): sharding_spec = xs.ShardingSpec(self._get_mesh((1, self.n_devices)), (0, 1)) # tensor will have device as `SPMD:0` in c++ xt1 = xm.send_cpu_data_to_device([torch.randn(3, 3)], - xm.xla_device(), + torch_xla.device(), input_sharding=sharding_spec)[0] xt2 = xt1 / 0.5 torch_xla.sync(wait=True) @@ -111,7 +111,7 @@ def test_sync_on_virtual_device(self): def test_virtual_device_no_upload(self): met.clear_all() - device = xm.xla_device() + device = torch_xla.device() t1 = torch.randn(5, 5).to(device) t1_debug_info = torch_xla._XLAC._get_xla_tensor_debug_info(t1) # t1's upload to device should be deferred @@ -125,7 +125,7 @@ def test_virtual_device_no_upload(self): def test_virtual_device_upload_after_mark_sharding(self): met.clear_all() partition_spec = (0, 1) - device = xm.xla_device() + device = torch_xla.device() t1 = torch.randn(8, 8).to(device) t1_debug_info = torch_xla._XLAC._get_xla_tensor_debug_info(t1) self.assertIn("Tensor on host: with size [8, 8]", t1_debug_info) @@ -139,7 +139,7 @@ def test_virtual_device_upload_after_mark_sharding(self): def test_virtual_device_upload_after_tracing(self): met.clear_all() - device = xm.xla_device() + device = torch_xla.device() t1 = torch.randn(8, 8).to(device) t1_debug_info = torch_xla._XLAC._get_xla_tensor_debug_info(t1) self.assertIn("Tensor on host: with size [8, 8]", t1_debug_info) @@ -152,7 +152,7 @@ def test_virtual_device_upload_after_tracing(self): def test_virtual_device_upload_for_sharded_dataloader(self): met.clear_counters() - device = xm.xla_device() + device = torch_xla.device() sharding_spec = xs.ShardingSpec(self._get_mesh((1, self.n_devices)), (0, 1)) # tensor will have device as `SPMD:0` in c++ t1 = xm.send_cpu_data_to_device([torch.randn(8, 8)], diff --git a/test/stablehlo/test_composite.py b/test/stablehlo/test_composite.py index 64d08e4879fb..6e9521c79d5c 100644 --- a/test/stablehlo/test_composite.py +++ b/test/stablehlo/test_composite.py @@ -70,7 +70,7 @@ class XlaMarkPatternTest(unittest.TestCase): def run_func_get_stablehlo(self, f, input_args): - device = xm.xla_device() + device = torch_xla.device() input_args = pytree.tree_map_only(torch.Tensor, lambda x: x.to(device=device), input_args) exported = torch.export.export(AsModule(f), input_args) diff --git a/test/stablehlo/test_implicit_broadcasting.py b/test/stablehlo/test_implicit_broadcasting.py index 24c8a80c77d7..10fbe5789981 100644 --- a/test/stablehlo/test_implicit_broadcasting.py +++ b/test/stablehlo/test_implicit_broadcasting.py @@ -10,7 +10,7 @@ # The following tests cover the implcit-broadcasting for static and bounded # dynamic shapes. 
-device = xm.xla_device() +device = torch_xla.device() class ImplicitBroadcasting(unittest.TestCase): diff --git a/test/stablehlo/test_pt2e_qdq.py b/test/stablehlo/test_pt2e_qdq.py index 1f6ad974f203..3fc1276ec612 100644 --- a/test/stablehlo/test_pt2e_qdq.py +++ b/test/stablehlo/test_pt2e_qdq.py @@ -54,7 +54,7 @@ def count_qdq_ops(g: torch.fx.Graph): class PT2EExportTest(unittest.TestCase): def test_per_tensor_qdq(self): - device = xm.xla_device() + device = torch_xla.device() x = torch.randn(2, 3, 4, 5).to(device) x = torch.ops.quantized_decomposed.quantize_per_tensor( x, 0.4, 2, -128, 127, torch.int8) @@ -68,7 +68,7 @@ def test_per_tensor_qdq(self): self.assertEqual(stablehlo_txt.count("stablehlo.uniform_dequantize"), 1) def test_per_channel_qdq(self): - device = xm.xla_device() + device = torch_xla.device() x = torch.randn(2, 3, 4, 5).to(device) scale = torch.tensor([3.2, 5.3, 0.1, 10]).to(device) zero_point = torch.tensor([1, 2, -1, -2], dtype=torch.int64).to(device) diff --git a/test/stablehlo/test_stablehlo_compile.py b/test/stablehlo/test_stablehlo_compile.py index 243f48fbff7c..a57faf7ff5f2 100644 --- a/test/stablehlo/test_stablehlo_compile.py +++ b/test/stablehlo/test_stablehlo_compile.py @@ -21,7 +21,7 @@ def test_resnet18_stablehlo_compile(self): torch_input = torch.tensor(np_input).float() cpu_output = resnet18(torch_input) # Run ResNet on XLA device. - device = xm.xla_device() + device = torch_xla.device() # materalize the fake data for test purpose torch_xla.sync() xm.wait_device_ops() diff --git a/test/stablehlo/test_stablehlo_custom_call.py b/test/stablehlo/test_stablehlo_custom_call.py index 112abb2135aa..a315bbc230db 100644 --- a/test/stablehlo/test_stablehlo_custom_call.py +++ b/test/stablehlo/test_stablehlo_custom_call.py @@ -118,7 +118,7 @@ def forward(self, x): # self.assertTrue("api_version = 1" in shlo_text) def test_place_to_host_device(self): - dev = xm.xla_device() + dev = torch_xla.device() a = torch.ones(10, device=dev) b = place_to_host(a) shlo_text = xm.get_stablehlo([b]) @@ -137,7 +137,7 @@ def test_place_to_host_device(self): def test_place_to_host_device_autograd(self): # Test that gradient can flow through place_to_host and place_to_device ops. - dev = xm.xla_device() + dev = torch_xla.device() a = torch.ones(10, device=dev, requires_grad=True) b = place_to_host(a) c = b.sum() @@ -155,7 +155,7 @@ def test_place_to_host_device_aot_autograd(self): # specifically `aot_function`. 
from functorch.compile import aot_function, make_boxed_func # type: ignore - dev = xm.xla_device() + dev = torch_xla.device() a = torch.ones(10, device=dev, requires_grad=True) def my_fn(x): diff --git a/test/stablehlo/test_stablehlo_inference.py b/test/stablehlo/test_stablehlo_inference.py index 311aed0c4b54..a29b66ebceaa 100644 --- a/test/stablehlo/test_stablehlo_inference.py +++ b/test/stablehlo/test_stablehlo_inference.py @@ -67,7 +67,7 @@ def forward(self, x, y): output = m(*data) exported = export_torch_model(m, data) - device = xm.xla_device() + device = torch_xla.device() data = pytree.tree_map_only(torch.Tensor, lambda x: x.to(device), data) output2 = exported(*data).cpu() @@ -91,7 +91,7 @@ def forward(self, inputs): output = m(*data) exported = export_torch_model(m, data) - device = xm.xla_device() + device = torch_xla.device() data = pytree.tree_map_only(torch.Tensor, lambda x: x.to(device), data) output2 = exported(*data) self.assertEqual(len(output2), 2) diff --git a/test/stablehlo/test_stablehlo_save_load.py b/test/stablehlo/test_stablehlo_save_load.py index 1e1e41c3513a..71ff463578cb 100644 --- a/test/stablehlo/test_stablehlo_save_load.py +++ b/test/stablehlo/test_stablehlo_save_load.py @@ -17,7 +17,7 @@ class StableHloDumpTest(unittest.TestCase): def test_simple(self): - device = xm.xla_device() + device = torch_xla.device() x = torch.tensor([3], device=device) y = torch.tensor([3], device=device) z = x + y @@ -26,7 +26,7 @@ def test_simple(self): self.assertEqual(stablehlo.count("stablehlo.add"), 1) def test_resnet18(self): - device = xm.xla_device() + device = torch_xla.device() xla_resnet18 = torchvision.models.resnet18() xla_resnet18.eval() xla_resnet18 = xla_resnet18.to(device) @@ -66,7 +66,7 @@ class SimpleExportTest(unittest.TestCase): def export_stable_hlo(self, model, args, kwargs=None): if kwargs is None: kwargs = {} - device = xm.xla_device() + device = torch_xla.device() model.eval() model = model.to(device) args = tuple(i.to(device) for i in args if hasattr(i, 'to')) diff --git a/test/stablehlo/test_unbounded_dynamism.py b/test/stablehlo/test_unbounded_dynamism.py index 3cd17a7fe340..aa33a6533437 100644 --- a/test/stablehlo/test_unbounded_dynamism.py +++ b/test/stablehlo/test_unbounded_dynamism.py @@ -19,7 +19,7 @@ compare_exported_program_and_saved_model_result, has_tf_package, load_save_model_and_inference, wrap_func_as_nn_module) -device = xm.xla_device() +device = torch_xla.device() os.environ['EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM'] = '1' diff --git a/test/stablehlo/test_xla_export_interpreter.py b/test/stablehlo/test_xla_export_interpreter.py index 85a4607d3a85..51a73a402703 100644 --- a/test/stablehlo/test_xla_export_interpreter.py +++ b/test/stablehlo/test_xla_export_interpreter.py @@ -7,7 +7,7 @@ import torch_xla.core.xla_model as xm from torch_xla.stablehlo import exported_program_to_stablehlo -device = xm.xla_device() +device = torch_xla.device() class XLAExportInterpreterTest(unittest.TestCase): diff --git a/test/test_as_stride_use_slice.py b/test/test_as_stride_use_slice.py index 48c65bb80f69..454ff78caeb2 100644 --- a/test/test_as_stride_use_slice.py +++ b/test/test_as_stride_use_slice.py @@ -100,8 +100,8 @@ def pure_strided_wrapper(self, use_xla, use_aten_slice): ss = StridedAndSlice().to("cpu") input = torch.randn((2, 4, 256, 256), device="cpu").requires_grad_() if use_xla: - ss.to(xm.xla_device()) - input = input.to(xm.xla_device()) + ss.to(torch_xla.device()) + input = input.to(torch_xla.device()) return ss(input, use_aten_slice) 
@parameterized.named_parameters( @@ -137,7 +137,7 @@ def compiler(gm, _): cpu_output = compiler_func(input_cpu, use_aten_slice=use_aten_slice) torch_xla.sync() - input_xla = input_xla.to(xm.xla_device()) + input_xla = input_xla.to(torch_xla.device()) xla_output = compiler_func(input_xla, use_aten_slice=use_aten_slice) torch_xla.sync() torch.testing.assert_close(cpu_output, xla_output.cpu()) diff --git a/test/test_autocast.py b/test/test_autocast.py index 1e0a82a2870c..7c743aab7cc1 100644 --- a/test/test_autocast.py +++ b/test/test_autocast.py @@ -282,7 +282,7 @@ def cast(val, to_type): add_kwargs = {} self.assertFalse(self.is_autocast_enabled()) - with autocast(xm.xla_device(), dtype=autocast_dtype): + with autocast(torch_xla.device(), dtype=autocast_dtype): self.assertTrue(self.is_autocast_enabled()) out_type = out_type if out_type is not None else run_as_type @@ -332,7 +332,7 @@ def compare(first, second): # Compare numerics to Python-side "autocasting" that (we expect) does the same thing # as the C++-side autocasting, and should be bitwise accurate. output_to_compare = output if output is not None else output_method - with autocast(xm.xla_device(), enabled=False): + with autocast(torch_xla.device(), enabled=False): self.assertFalse(self.is_autocast_enabled()) if module is not None and hasattr(module, op): @@ -355,9 +355,9 @@ class TestAutocastCuda(TestAutocastBase): @classmethod def setUpClass(cls): super().setUpClass() - cls.autocast_lists = AutocastTestLists(torch.device(xm.xla_device())) + cls.autocast_lists = AutocastTestLists(torch.device(torch_xla.device())) cls.autocast_lists_extra = AutocastCudaTestExtraLists( - torch.device(xm.xla_device())) + torch.device(torch_xla.device())) cls.autocast_unsupported_lists = AutocastCudaTestUnsupportedLists() def setUp(self): @@ -439,7 +439,7 @@ class TestAutocastTPU(TestAutocastBase): @classmethod def setUpClass(cls): super().setUpClass() - cls.autocast_lists = AutocastTPUTestLists(torch.device(xm.xla_device())) + cls.autocast_lists = AutocastTPUTestLists(torch.device(torch_xla.device())) def setUp(self): super(TestAutocastTPU, self).setUp() @@ -481,7 +481,7 @@ def test_autocast_methods_expect_builtin_promote(self): op, args, torch.float32, module=None, out_type=out_type) def test_autocast_tpu_check_dtype(self): - with autocast(xm.xla_device(), dtype=torch.float16): + with autocast(torch_xla.device(), dtype=torch.float16): assert not torch.is_autocast_xla_enabled() @@ -491,7 +491,7 @@ class TestOtherOps(unittest.TestCase): not xm.get_xla_supported_devices("GPU"), "the behavior of batch_norm autocast on GPU is different from others") def test_batch_norm_gpu(self): - device = xm.xla_device() + device = torch_xla.device() data = torch.randn(4, 16, 32, 32, device=device, dtype=torch.bfloat16) batch_norm = torch.nn.BatchNorm2d(16) with autocast(device, dtype=torch.bfloat16): @@ -504,7 +504,7 @@ def test_batch_norm_gpu(self): not xm.get_xla_supported_devices("TPU"), "the behavior of batch_norm autocast on TPU is different from others") def test_batch_norm_tpu(self): - device = xm.xla_device() + device = torch_xla.device() data = torch.randn(4, 16, 32, 32, device=device, dtype=torch.bfloat16) batch_norm = torch.nn.BatchNorm2d(16) with autocast(device, dtype=torch.bfloat16): diff --git a/test/test_autocast_xla.py b/test/test_autocast_xla.py index 529219a98c65..e287cb1bae55 100644 --- a/test/test_autocast_xla.py +++ b/test/test_autocast_xla.py @@ -6,7 +6,7 @@ import torch_xla.distributed.spmd.xla_sharding as xs -device = xm.xla_device() +device = 
torch_xla.device() class TestAutocastXla(unittest.TestCase): diff --git a/test/test_callback.py b/test/test_callback.py index 242fef6443ea..14ffbe4c8cd5 100644 --- a/test/test_callback.py +++ b/test/test_callback.py @@ -11,8 +11,8 @@ class TestExperimentalCallback(absltest.TestCase): @staticmethod @torch_xla.compile def executable(): - a, b = torch.randn((100, 100), device=torch_xla.device()), torch.randn( - (100, 100), device=torch_xla.device()) + a, b = torch.randn((100, 100), device="xla"), torch.randn((100, 100), + device="xla") return a @ b def test_callback(self): diff --git a/test/test_compilation_cache_utils.py b/test/test_compilation_cache_utils.py index 0fb12c32ae5b..0ac8a013d814 100644 --- a/test/test_compilation_cache_utils.py +++ b/test/test_compilation_cache_utils.py @@ -31,7 +31,7 @@ def _test_spawn(fn, args): class TestGraphHash(parameterized.TestCase): def _test_num_graph_hash(self, use_dynamo, use_persistent): - xla_dev = xm.xla_device() + xla_dev = torch_xla.device() model = M().to(device=xla_dev) input_shape = (10, 5) if use_dynamo: diff --git a/test/test_core_aten_ops.py b/test/test_core_aten_ops.py index 6e2ac67e4f52..4247071e9ddf 100644 --- a/test/test_core_aten_ops.py +++ b/test/test_core_aten_ops.py @@ -46,7 +46,7 @@ def run_export_and_compare(testcase, atol=1e-3, rtol=1e-5, equal_nan=True): - device = xm.xla_device() + device = torch_xla.device() with testcase.subTest('torch_eval'): res = func(*args, **kwargs) with testcase.subTest('torch_xla_eval'): @@ -2932,7 +2932,7 @@ def test_aten_randperm_0(self): kwargs = dict() pytorch = torch.randperm(20) - xla = torch.randperm(20, device=xm.xla_device()) + xla = torch.randperm(20, device="xla") xla_detached = xla.detach().cpu() # Check equal lengths and that the sorted sets are equal. 
Since these numbers are randomly diff --git a/test/test_data_type.py b/test/test_data_type.py index ecd554c187ec..db71cfc5199a 100644 --- a/test/test_data_type.py +++ b/test/test_data_type.py @@ -29,8 +29,8 @@ def _set_env(self, **kwargs): os.environ[key] = value def _test_datatype(self, dtype, expected_type, op): - t1 = torch.tensor([2, 3], dtype=dtype, device=xm.xla_device()) - t2 = torch.tensor([2, 3], dtype=dtype, device=xm.xla_device()) + t1 = torch.tensor([2, 3], dtype=dtype, device="xla") + t2 = torch.tensor([2, 3], dtype=dtype, device="xla") t3 = op(t1, t2) self.assertEqual(t3.dtype, dtype) diff --git a/test/test_dynamic_shapes_detector.py b/test/test_dynamic_shapes_detector.py index 99a8423e45f5..e0483c387c08 100644 --- a/test/test_dynamic_shapes_detector.py +++ b/test/test_dynamic_shapes_detector.py @@ -50,7 +50,7 @@ def test_single(self): def foo(x): return x + x - inp = torch.rand(10, device=torch_xla.device()) + inp = torch.rand(10, device="xla") self._run_and_compare(foo, args=(inp,), max_different_graphs=1) def test_many_graphs(self): @@ -70,7 +70,7 @@ def foo(x, step): return r * 4 return r0 - inp = torch.rand(10, device=torch_xla.device()) + inp = torch.rand(10, device="xla") for i in range(6): self._run_and_compare(foo, args=(inp, i), max_different_graphs=4) @@ -84,7 +84,7 @@ def test_graph_limit_exceeded_different_input_shape(self): def foo(x): return x + x - inp1 = torch.rand(10, device=torch_xla.device()) + inp1 = torch.rand(10, device="xla") self._run_and_compare( foo, args=(inp1,), max_different_graphs=max_different_graphs) @@ -95,7 +95,7 @@ def foo(x): """) with self.assertRaisesRegex(RuntimeError, expected_error_msg): - inp2 = torch.rand(5, device=torch_xla.device()) + inp2 = torch.rand(5, device="xla") self._run_and_compare( foo, args=(inp2,), max_different_graphs=max_different_graphs) @@ -118,7 +118,7 @@ def foo(x, step): else: return x * 5 - inp = torch.rand(10, device=torch_xla.device()) + inp = torch.rand(10, device="xla") self._run_and_compare( foo, args=(inp, 0), max_different_graphs=max_different_graphs) @@ -157,7 +157,7 @@ def foo(x, step): return r + x return r / 3 - inp = torch.rand(10, device=torch_xla.device()) + inp = torch.rand(10, device="xla") self._run_and_compare( foo, args=(inp, 0), max_different_graphs=max_different_graphs) self._run_and_compare( @@ -194,7 +194,7 @@ def foo(x, mul=False): else: return r - inp = torch.rand(10, device=torch_xla.device()) + inp = torch.rand(10, device="xla") self._run_and_compare( foo, args=(inp, True), max_different_graphs=max_different_graphs) @@ -231,7 +231,7 @@ def foo(x, step): return r + x return r - inp = torch.rand(10, device=torch_xla.device()) + inp = torch.rand(10, device="xla") self._run_and_compare( foo, args=(inp, 0), max_different_graphs=max_different_graphs) self._run_and_compare( diff --git a/test/test_env_var_mapper.py b/test/test_env_var_mapper.py index 26cb2f0870eb..e4dcef2ba8cb 100644 --- a/test/test_env_var_mapper.py +++ b/test/test_env_var_mapper.py @@ -15,7 +15,7 @@ def check_env_flag(name, default=''): class EnvVarMapperTest(unittest.TestCase): def test_xla_ir_debug_(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() with xp.Trace('test_xla_ir_debug'): t = torch.tensor([2.0, 3.0], dtype=torch.float, device=xla_device) diff --git a/test/test_fsdp_auto_wrap.py b/test/test_fsdp_auto_wrap.py index 3ed373f98dcc..019612899697 100644 --- a/test/test_fsdp_auto_wrap.py +++ b/test/test_fsdp_auto_wrap.py @@ -35,7 +35,7 @@ def forward(self, x): "This test fails only on GPU with 
03/30 TF-pin update (https://github.com/pytorch/xla/pull/4840)" ) def test(self): - dev = xm.xla_device() + dev = torch_xla.device() input = torch.zeros([16, 16], device=dev) model = self.MyModel(input_size=16, hidden_size=4) model = XlaFullyShardedDataParallel( @@ -48,7 +48,7 @@ def test(self): def _mp_fn(index): - device = xm.xla_device() + device = torch_xla.device() if xm.xla_device_hw(device) in ('TPU', 'CUDA'): test = unittest.main(exit=False) sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/test_grad_checkpoint.py b/test/test_grad_checkpoint.py index 57826bbc6748..e4e318ba8310 100644 --- a/test/test_grad_checkpoint.py +++ b/test/test_grad_checkpoint.py @@ -11,7 +11,7 @@ def run(): - device = xm.xla_device() + device = torch_xla.device() model = torch.nn.ModuleList([ torch.nn.Sequential( torch.nn.Conv2d(1024, 1024, 1), diff --git a/test/test_gradient_accumulation.py b/test/test_gradient_accumulation.py index 6e431a4237d9..62ecfc431132 100644 --- a/test/test_gradient_accumulation.py +++ b/test/test_gradient_accumulation.py @@ -23,7 +23,7 @@ def forward(self, x): class GradAccumulationTest(XlaTestCase): def setUp(self): - self.device = xm.xla_device() + self.device = torch_xla.device() torch.manual_seed(123) def test_basic(self): diff --git a/test/test_hlo_metadata.py b/test/test_hlo_metadata.py index 82eebd9f3ada..803315513d55 100644 --- a/test/test_hlo_metadata.py +++ b/test/test_hlo_metadata.py @@ -78,10 +78,10 @@ def test_metadata(self): model = torch.nn.Sequential(layer1, nl1, layer2, nl2) with CustomOpNameLowering() as c: - model = model.to(device=xm.xla_device()) - inp = torch.rand(4, 4, device=xm.xla_device()) + model = model.to(device="xla") + inp = torch.rand(4, 4, device="xla") #inp = torch.rand(4, 4) - #inp = inp.to(device=xm.xla_device()) + #inp = inp.to(device="xla") out = model(inp) # Get outer frames diff --git a/test/test_inplace_update.py b/test/test_inplace_update.py index f68811ecd0de..704888d4f6e7 100644 --- a/test/test_inplace_update.py +++ b/test/test_inplace_update.py @@ -11,7 +11,7 @@ class InplaceUpdateTest(unittest.TestCase): def test_aten_op_after_full_update(self): - device = xm.xla_device() + device = torch_xla.device() t = torch.ones(2, 1, device=device) w = torch.ones(1, 2, device=device) t.zero_() @@ -21,7 +21,7 @@ def test_aten_op_after_full_update(self): self.assertTrue(torch.all(torch.eq(y, expected))) def test_aten_op_after_partial_update(self): - device = xm.xla_device() + device = torch_xla.device() t = torch.ones(2, 1, device=device) w = torch.ones(1, 2, device=device) t[0][0] = 0 @@ -31,7 +31,7 @@ def test_aten_op_after_partial_update(self): self.assertTrue(torch.all(torch.eq(y, expected))) def test_non_aten_op_after_full_update(self): - device = xm.xla_device() + device = torch_xla.device() t = torch.ones(2, 1, device=device) w = torch.ones(1, 2, device=device) t.zero_() @@ -41,7 +41,7 @@ def test_non_aten_op_after_full_update(self): self.assertTrue(torch.all(torch.eq(y, expected))) def test_non_aten_op_after_partial_update(self): - device = xm.xla_device() + device = torch_xla.device() t = torch.ones(2, 1, device=device) w = torch.ones(1, 2, device=device) t[0][0] = 0 @@ -53,7 +53,7 @@ def test_non_aten_op_after_partial_update(self): def test_xm_save(self): with temporary_env( XLA_DISABLE_FUNCTIONALIZATION="0", XLA_ENABLE_PARAM_ALIASING="0"): - xla_device = xm.xla_device() + xla_device = torch_xla.device() t1 = torch.tensor([1], device=xla_device) t2 = t1.detach() torch_xla.sync() diff --git 
a/test/test_input_output_aliases.py b/test/test_input_output_aliases.py index 906ffc326834..3f20f9d25c97 100644 --- a/test/test_input_output_aliases.py +++ b/test/test_input_output_aliases.py @@ -38,7 +38,7 @@ def config_context(value): class InputOutputAliasesTest(parameterized.TestCase): def test_non_view(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() t1 = torch.randn(4, 2, 2).to(xla_device) t2 = torch.randn(4, 2, 2).to(xla_device) torch_xla.sync() @@ -53,7 +53,7 @@ def test_non_view(self): self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 2.0) def test_aliasing_with_cloned(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() met.clear_all() t1 = torch.randn(4, 2, 2).to(xla_device) # t1_cloned share the same storage as t1 @@ -66,7 +66,7 @@ def test_aliasing_with_cloned(self): self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 1.0) def test_aliasing_across_custom_inplace(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() met.clear_all() t1 = torch.randn(4, 5).to(xla_device) t1 *= t1 @@ -78,7 +78,7 @@ def test_aliasing_across_custom_inplace(self): self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 2.0) def test_aliasing_across_sync(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() met.clear_all() t1 = torch.randn(4, 5).to(xla_device) t1 += 1 @@ -96,7 +96,7 @@ def test_aliasing_with_multiple_inplace_update(self): BLOCK_SIZE = 16 DTYPE = torch.bfloat16 num_blocks = 1024 - device = xm.xla_device() + device = torch_xla.device() key = torch.randn( BATCH_SIZE * SEQ_LEN, NUM_KV_HEADS, @@ -145,7 +145,7 @@ def try_grad_accum(model, device, train_x, train_label, accum_steps): torch_xla.sync() return [p.grad.to('cpu').numpy() for p in model.parameters()] - dev = xm.xla_device() + dev = torch_xla.device() train_x_sample = torch.rand((1, 28 * 28)) train_label_sample = torch.tensor([5]) c_model = MLP().to('cpu') @@ -171,7 +171,7 @@ def test_separate_graphs(self): """ Test that paramater aliasing differences should produce different graphs. """ - xla_device = xm.xla_device() + xla_device = torch_xla.device() t0 = torch.tensor([1], device=xla_device) t1 = torch.tensor([2], device=xla_device) torch_xla.sync() @@ -190,7 +190,7 @@ def test_xm_save_no_aliasing(self): """ Test that xm.save() does not perform aliasing. """ - xla_device = xm.xla_device() + xla_device = torch_xla.device() t0 = torch.tensor([1], device=xla_device) t1 = torch.tensor([2], device=xla_device) torch_xla.sync() @@ -212,7 +212,7 @@ def test_device_data_cache_no_aliasing(self): """ Test that device data in DataCache are not aliased. 
""" - xla_device = xm.xla_device() + xla_device = torch_xla.device() t0 = torch.tensor(42, device=xla_device) # drops the read-only bit on t0's device_data @@ -235,7 +235,7 @@ def test_device_data_cache_no_aliasing(self): def test_user_config_donation_with_ltc_donation(self): met.clear_all() - xla_device = xm.xla_device() + xla_device = torch_xla.device() t0 = torch.randn(4, 2, 2).to(xla_device) t1 = torch.randn(4, 2, 2).to(xla_device) self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True)) @@ -255,7 +255,7 @@ def test_user_config_donation_with_ltc_donation_graph_sync( self, enable_buffer_donor_config): with alias_with_buffer_donor_config_context(enable_buffer_donor_config): met.clear_all() - xla_device = xm.xla_device() + xla_device = torch_xla.device() t0 = torch.randn(4, 2, 2).to(xla_device) t1 = torch.randn(4, 2, 2).to(xla_device) self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True)) @@ -279,7 +279,7 @@ def test_user_config_donation_with_ltc_donation_graph_sync( def test_user_config_donation_with_ltc_donation_overlap(self): met.clear_all() - xla_device = xm.xla_device() + xla_device = torch_xla.device() t0 = torch.randn(4, 2, 2).to(xla_device) self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True)) self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0)) @@ -291,7 +291,7 @@ def test_user_config_donation_with_ltc_donation_overlap(self): def test_user_config_donation(self): with alias_with_buffer_donor_config_context(True): met.clear_all() - xla_device = xm.xla_device() + xla_device = torch_xla.device() t0 = torch.randn(4, 2, 2).to(xla_device) self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True)) self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0)) @@ -308,7 +308,7 @@ def test_user_config_donation(self): def test_user_config_donation_inplace_aliasing(self): with alias_with_buffer_donor_config_context(True): met.clear_all() - xla_device = xm.xla_device() + xla_device = torch_xla.device() t0 = torch.randn(4, 2, 2).to(xla_device) self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True)) self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0)) @@ -322,7 +322,7 @@ def test_user_config_donation_inplace_aliasing(self): def test_user_config_donation_no_op_sync(self): with alias_with_buffer_donor_config_context(True): - xla_device = xm.xla_device() + xla_device = torch_xla.device() t0 = torch.randn(4, 2, 2).to(xla_device) self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True)) torch_xla.sync() @@ -331,7 +331,7 @@ def test_user_config_donation_no_op_sync(self): self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0)) def test_no_op_sync_keep_buffer_donation(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() input = torch.randn(5, 5).to(xla_device) self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True)) torch_xla.sync() @@ -346,7 +346,7 @@ def test_device_data_node_tracing_aliasing(self): for a given set of unmutated input tensor during its tracing. This helps ensure that aliasings can be retained if using the binding for tracing purposes. 
""" - xla_device = xm.xla_device() + xla_device = torch_xla.device() t0 = torch.tensor(10).to(xla_device) t1 = t0 + 5 diff --git a/test/test_jax_interop.py b/test/test_jax_interop.py index e69821cfe219..5016462b982e 100644 --- a/test/test_jax_interop.py +++ b/test/test_jax_interop.py @@ -14,7 +14,7 @@ def setUp(self): def test_call_jax(self): """Test that we can call a JAX function from PyTorch/XLA lazy tensor tracing.""" - dev = xm.xla_device() + dev = torch_xla.device() a = torch.ones((3, 3), device=dev) def f(a, b): @@ -29,7 +29,7 @@ def f(a, b): def test_call_jax_input_pytree(self): """Test that call_jax works with PyTree inputs.""" - dev = xm.xla_device() + dev = torch_xla.device() a = torch.ones((2, 2), device=dev) b = torch.ones((2, 2), device=dev) * 2 @@ -55,7 +55,7 @@ def f(inputs): def test_call_jax_output_pytree(self): """Test that call_jax works with PyTree outputs.""" - dev = xm.xla_device() + dev = torch_xla.device() a = torch.ones((2, 2), device=dev) def f(a): @@ -89,7 +89,7 @@ def f(a): def test_call_jax_some_arg_unused(self): """Test when the jax function doesn't use some input arguments.""" - dev = xm.xla_device() + dev = torch_xla.device() a = torch.randn((3, 3), device=dev) b = torch.randn((3, 3), device=dev) c = torch.randn((3, 3), device=dev) @@ -106,7 +106,7 @@ def f(a, b, c, d): def test_call_jax_grad(self): """Test calling a simple jax.grad transformed function.""" - dev = xm.xla_device() + dev = torch_xla.device() a = torch.randn((3, 3), device=dev, requires_grad=True) b = torch.randn((3, 3), device=dev, requires_grad=True) torch_xla.sync() @@ -143,7 +143,7 @@ def f_jax(a, b): def test_call_jax_non_tensor_args(self): """Test that call_jax works with non-tensor arguments.""" - dev = xm.xla_device() + dev = torch_xla.device() a = torch.ones((3, 3), device=dev) def f(a, num: float, string: str, dictionary: dict, none): @@ -173,7 +173,7 @@ def test_call_jax_cache_hlo(self): starting_cache_misses = xb._jax_to_xla_computation_cache_elements() # Let's trace two different jax functions a couple of times. - dev = xm.xla_device() + dev = torch_xla.device() a = torch.ones((3, 3), device=dev) def f(a, b): @@ -198,7 +198,7 @@ def test_call_jax_cache_by_shape(self): starting_cache_misses = xb._jax_to_xla_computation_cache_elements() # Let's trace the same jax function with different shapes. - dev = xm.xla_device() + dev = torch_xla.device() a = torch.ones((3, 3), device=dev) b = torch.ones((2, 2), device=dev) @@ -217,7 +217,7 @@ def test_call_jax_cache_by_tree_spec(self): starting_cache_misses = xb._jax_to_xla_computation_cache_elements() # Let's trace the same jax function with different tree specs. - dev = xm.xla_device() + dev = torch_xla.device() a = torch.ones((3, 3), device=dev) b = torch.ones((3, 2), device=dev) @@ -237,7 +237,7 @@ def test_call_jax_cache_by_static_args(self): starting_cache_misses = xb._jax_to_xla_computation_cache_elements() # Let's trace the same jax function with different static args. 
- dev = xm.xla_device() + dev = torch_xla.device() a = torch.ones((3, 3), device=dev) def f(a, num: float): @@ -255,7 +255,7 @@ def test_call_jax_with_different_jax_config(self): import jax starting_cache_misses = xb._jax_to_xla_computation_cache_elements() - dev = xm.xla_device() + dev = torch_xla.device() a = torch.ones((3, 3), device=dev) def f(a, b): diff --git a/test/test_metrics.py b/test/test_metrics.py index f124784fbfee..ba2ececca408 100644 --- a/test/test_metrics.py +++ b/test/test_metrics.py @@ -24,7 +24,7 @@ def check_metrics_file(): class MetricsTest(unittest.TestCase): def test_clear_counters(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() t1 = torch.tensor(100, device=xla_device) t1 += 2 self.assertIn("xla::add", met.metrics_report()) @@ -39,7 +39,7 @@ def test_clear_counters(self): assert (len(met.counter_names()) > 0) def test_clear_metrics(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() t1 = torch.tensor(156, device=xla_device) self.assertIn("TensorToData", met.metrics_report()) assert (len(met.metric_names()) > 0) @@ -52,7 +52,7 @@ def test_clear_metrics(self): assert (len(met.metric_names()) > 0) def test_tracing_time_metrics(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() met.clear_all() t1 = torch.tensor(156, device=xla_device) t2 = t1 + 100 @@ -61,7 +61,7 @@ def test_tracing_time_metrics(self): def test_eager_metrics(self): with torch_xla.experimental.eager_mode_context(True): - xla_device = xm.xla_device() + xla_device = torch_xla.device() met.clear_all() t1 = torch.tensor(156, device=xla_device) t2 = t1 + 100 @@ -78,7 +78,7 @@ def test_eager_metrics(self): self.assertNotIn('ExecuteTime', met.metric_names()) def test_short_metrics_report_default_list(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() t1 = torch.tensor(1456, device=xla_device) t2 = t1 * 2 torch_xla.sync() @@ -100,7 +100,7 @@ def test_short_metrics_report_default_list(self): assert check_metrics_file() def test_short_metrics_report_custom_list(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() t1 = torch.tensor(100, device=xla_device) t2 = t1 * 2 t1 += 2 @@ -120,7 +120,7 @@ def test_short_metrics_report_custom_list(self): self.assertIn('InputOutputAliasCount', short_report) def test_short_metrics_fallback_counter(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() t1 = torch.tensor(100, device=xla_device) t2 = t1 * 2 # this will trigger a aten::_local_scalar_dense which is the same as fallback counter @@ -135,7 +135,7 @@ def test_short_metrics_fallback_counter(self): def test_metrics_report(self): # TODO(jwtan): Add test to cover TrimIrGraph, SyncTensorsToData, TransferToDeviceAsync, IrValueTensorToXlaData - xla_device = xm.xla_device() + xla_device = torch_xla.device() t1 = torch.tensor(2077, device=xla_device) t2 = t1 * 2 torch_xla.sync() @@ -207,12 +207,12 @@ def test_metrics_report(self): @unittest.skipIf(xr.device_type() != "CPU", f"This test only works on CPU.") def test_execute_time_metric(self): # Initialize the client before starting the timer. 
- xm.xla_device() + torch_xla.device() begin = time.perf_counter_ns() value = torch.randn( - 10000, 10000, device=xm.xla_device()) * torch.randn( - 10000, 10000, device=xm.xla_device()) + 10000, 10000, device="xla") * torch.randn( + 10000, 10000, device="xla") value_mean = value.mean() torch_xla.sync() cpu_value = value_mean.cpu() @@ -226,7 +226,7 @@ def test_execute_time_metric(self): def test_pybind_increment_counter(self): met.clear_all() - xla_device = xm.xla_device() + xla_device = torch_xla.device() t1 = torch.tensor(2077, device=xla_device) self.assertEqual(met.counter_value('CreateXlaTensor'), 1) torch_xla._XLAC._xla_increment_counter('CreateXlaTensor', 3) @@ -254,10 +254,10 @@ def getAndAssertFallbackOpsLenEquals(count): # Create N boxes in the format XYXY. # This should not run any fallback ops. N = 10 - x = torch.rand(N, 1).to(xm.xla_device()) - y = torch.rand(N, 1).to(xm.xla_device()) - width = torch.rand(N, 1).to(xm.xla_device()) - height = torch.rand(N, 1).to(xm.xla_device()) + x = torch.rand(N, 1).to(torch_xla.device()) + y = torch.rand(N, 1).to(torch_xla.device()) + width = torch.rand(N, 1).to(torch_xla.device()) + height = torch.rand(N, 1).to(torch_xla.device()) xys = torch.cat((x, x + width, y, y - height), dim=1) getAndAssertFallbackOpsLenEquals(0) @@ -274,7 +274,7 @@ def getAndAssertFallbackOpsLenEquals(count): if not XLAExperimentalContains("nms"): # Run torchvision operations as fallback. import torchvision - scores = torch.rand(N).to(xm.xla_device()) + scores = torch.rand(N).to(torch_xla.device()) # NMS doesn't have a PyTorch/XLA implementation without dynamic shapes. torchvision.ops.nms(xys, scores, 0.5) # remove_small_boxes is not implemented in C++. It calls other PyTorch diff --git a/test/test_mp_all_gather.py b/test/test_mp_all_gather.py index 8cf8a7a92170..93d64f47ef3e 100644 --- a/test/test_mp_all_gather.py +++ b/test/test_mp_all_gather.py @@ -11,7 +11,7 @@ def all_gather(tensor, dim): def _mp_fn(index): - device = xm.xla_device() + device = torch_xla.device() world_size = xr.world_size() input_list_size = 5 if xm.xla_device_hw(device) in ('TPU', 'CUDA', 'NEURON'): diff --git a/test/test_mp_all_to_all.py b/test/test_mp_all_to_all.py index f7e4a2f0c084..9761507dea13 100644 --- a/test/test_mp_all_to_all.py +++ b/test/test_mp_all_to_all.py @@ -6,7 +6,7 @@ def _mp_fn(index): - device = xm.xla_device() + device = torch_xla.device() if xm.xla_device_hw(device) in ('TPU', 'NEURON'): slots_per_device = 4 size = slots_per_device * xr.world_size() diff --git a/test/test_mp_collective_matmul.py b/test/test_mp_collective_matmul.py index 7ebfd7d80f89..29f115c986cd 100644 --- a/test/test_mp_collective_matmul.py +++ b/test/test_mp_collective_matmul.py @@ -8,7 +8,7 @@ def _mp_fn(index): os.environ["ENABLE_COLLECTIVE_MATMUL_IN_MP"] = "1" - device = xm.xla_device() + device = torch_xla.device() world_size = xr.world_size() groups = [[i for i in range(world_size)]] scale = 1 / world_size diff --git a/test/test_mp_collective_permute.py b/test/test_mp_collective_permute.py index 79c7196ac5ab..81a1eb771bcd 100644 --- a/test/test_mp_collective_permute.py +++ b/test/test_mp_collective_permute.py @@ -6,7 +6,7 @@ def _mp_fn(index): - device = xm.xla_device() + device = torch_xla.device() if xm.xla_device_hw(device) in ['TPU', 'NEURON']: world_size = xr.world_size() ordinal = xr.global_ordinal() diff --git a/test/test_mp_distributed_mm.py b/test/test_mp_distributed_mm.py index fd90398a7158..7d6c7982cb2f 100644 --- a/test/test_mp_distributed_mm.py +++ b/test/test_mp_distributed_mm.py @@ 
-7,7 +7,7 @@ def _mp_fn(index): - device = xm.xla_device() + device = torch_xla.device() if xm.xla_device_hw(device) in ('TPU', 'CUDA'): world_size = xr.world_size() diff --git a/test/test_mp_early_exit.py b/test/test_mp_early_exit.py index e8f411b9abab..89e46722e232 100644 --- a/test/test_mp_early_exit.py +++ b/test/test_mp_early_exit.py @@ -12,7 +12,7 @@ def _mp_fn(): dist.init_process_group('xla', init_method='xla://') - device = xm.xla_device() + device = torch_xla.device() if xm.xla_device_hw(device) in ['TPU', 'CUDA']: train_loader = xu.SampleGenerator( data=torch.zeros(1, 12), sample_count=1024) diff --git a/test/test_mp_reduce_scatter.py b/test/test_mp_reduce_scatter.py index 2b6d55bab596..bba65cde1ee8 100644 --- a/test/test_mp_reduce_scatter.py +++ b/test/test_mp_reduce_scatter.py @@ -6,7 +6,7 @@ def _mp_fn(index): - device = xm.xla_device() + device = torch_xla.device() world_size = xr.world_size() scale = 1 / world_size scatter_dim = 1 diff --git a/test/test_mp_replication.py b/test/test_mp_replication.py index 5b3392f3c487..61a302a65784 100644 --- a/test/test_mp_replication.py +++ b/test/test_mp_replication.py @@ -10,7 +10,7 @@ def all_reduce(tensor): def _mp_fn(index): - device = xm.xla_device() + device = torch_xla.device() world_size = xr.world_size() if world_size > 1: ones = torch.ones((2, 3)) diff --git a/test/test_mp_save.py b/test/test_mp_save.py index 1a3696f9e76b..ae9f46df120a 100644 --- a/test/test_mp_save.py +++ b/test/test_mp_save.py @@ -35,7 +35,7 @@ def _get_data_str(data): def _mp_fn(index, temp_file): - device = xm.xla_device() + device = torch_xla.device() dd = _create_state_dict(device) xm.save(dd, temp_file) # User needs to manually rendezvous since only master process diff --git a/test/test_mp_sync_batch_norm.py b/test/test_mp_sync_batch_norm.py index 561b7976a83b..fa4f18ad00d2 100644 --- a/test/test_mp_sync_batch_norm.py +++ b/test/test_mp_sync_batch_norm.py @@ -47,7 +47,7 @@ def _sync_bn1d_no_channel(rank): t_global = torch.rand((xr.world_size() * bsz, length)) # XLA SyncBatchNorm - device = xm.xla_device() + device = torch_xla.device() t_xla = t_global[bsz * rank:bsz * (rank + 1), ...].to(device) sbn_xla = xf.SyncBatchNorm(length).to(device) result = run_step(sbn_xla, t_xla) @@ -72,7 +72,7 @@ def _sync_bn1d_multi_channel(rank): t_global = torch.rand((xr.world_size() * bsz, features, length)) # XLA SyncBatchNorm - device = xm.xla_device() + device = torch_xla.device() t_xla = t_global[bsz * rank:bsz * (rank + 1), ...].to(device) sbn_xla = xf.SyncBatchNorm(features).to(device) result = run_step(sbn_xla, t_xla) @@ -97,7 +97,7 @@ def _sync_bn2d(rank): t_global = torch.rand((xr.world_size() * bsz, features, h, w)) # XLA SyncBatchNorm - device = xm.xla_device() + device = torch_xla.device() t_xla = t_global[bsz * rank:bsz * (rank + 1), ...].to(device) sbn_xla = xf.SyncBatchNorm(features).to(device) result = run_step(sbn_xla, t_xla) @@ -122,7 +122,7 @@ def _sync_bn3d(rank): t_global = torch.rand((xr.world_size() * bsz, features, d, h, w)) # XLA SyncBatchNorm - device = xm.xla_device() + device = torch_xla.device() t_xla = t_global[bsz * rank:bsz * (rank + 1), ...].to(device) sbn_xla = xf.SyncBatchNorm(features).to(device) result = run_step(sbn_xla, t_xla) diff --git a/test/test_operations.py b/test/test_operations.py index b3f31e8a0f3b..3975f20c6dba 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -179,7 +179,7 @@ def onlyIfPJRTDeviceIsCUDA(fn): class TestToXlaTensorArena(test_utils.XlaTestCase): def test(self): - xla_device = 
xm.xla_device() + xla_device = torch_xla.device() kdata = [_gen_tensor(2, 3), _gen_tensor(3, 4)] kdata.append([_gen_tensor(2, 5), _gen_tensor(3, 6)]) @@ -307,7 +307,7 @@ def loop_fn(model, loader, device, context): class TestLongGraphChain(test_utils.XlaTestCase): def test(self): - device = xm.xla_device() + device = torch_xla.device() orig_x = torch.Tensor([[1, 2], [3, 4]]) orig_y = torch.Tensor([[0.1, 0.2], [0.3, 0.4]]) x = orig_x @@ -328,7 +328,7 @@ def test(self): class TestSelect(test_utils.XlaTestCase): def test_get_xla_tensor(self): - x = _gen_tensor(14, 24, 8, device=xm.xla_device()) + x = _gen_tensor(14, 24, 8, device="xla") t = x.data.cpu() sx = x.select(1, 12) tx = t.select(1, 12) @@ -343,7 +343,7 @@ def fn(tensor): # Call masked_fill. return tensor.masked_fill(mask, 10) - x = _gen_tensor(2, 2, device=xm.xla_device()) + x = _gen_tensor(2, 2, device="xla") x_cpu = x.cpu() self.assertEqual(fn(x_cpu), fn(x)) @@ -352,7 +352,7 @@ class TestRandom(test_utils.XlaTestCase): def test_random_from_to_bool(self): for from_val, to_val in [[0, 1], [0, 2], [1, 2]]: - x = _gen_tensor(10, device=xm.xla_device()) + x = _gen_tensor(10, device="xla") x.random_(from_val, to_val) delta = 1 self.assertTrue(from_val <= x.to(torch.int).min() < (from_val + delta)) @@ -416,20 +416,20 @@ def test_fn(x): class TestDynamicShape(test_utils.XlaTestCase): def test_nonzero_shape(self): - x = torch.tensor((0, 1, 2, 0, 3, 4), device=xm.xla_device()) + x = torch.tensor((0, 1, 2, 0, 3, 4), device="xla") x_dim0_shape = torch_xla._XLAC._get_xla_tensor_dimension_size( torch.nonzero(x, as_tuple=False), 0) self.assertEqual(x_dim0_shape.item(), 4) def test_masked_select_shape(self): - x = torch.tensor((0, 1, 2, 0, 3, 4), device=xm.xla_device()) + x = torch.tensor((0, 1, 2, 0, 3, 4), device="xla") mask = x.ge(2) x_dim0_shape = torch_xla._XLAC._get_xla_tensor_dimension_size( torch.masked_select(x, mask), 0) self.assertEqual(x_dim0_shape.item(), 3) def test_nonzero_cast(self): - t1 = torch.ones(5, 2, device=xm.xla_device()) + t1 = torch.ones(5, 2, device="xla") # Result of the nonzero should be the index type. Currently # index type is s64 on cpu and gpu, but s32 on TPU. We should be # able to cast it to any other type without error. 
@@ -440,7 +440,7 @@ def test_nonzero_cast(self): class TestOptimizationBarrier(test_utils.XlaTestCase): def test_optimization_barrier_correctness(self): - device = xm.xla_device() + device = torch_xla.device() # only test optimization_barrier on TPU if xm.xla_device_hw(device) != 'TPU': return @@ -459,7 +459,7 @@ def op_fn(a): return xb.Op.tuple((a, a.cast(xb.Type.BF16))) op = xor.register('test_mixed_dtype_tuple', op_fn) - xla_device = xm.xla_device() + xla_device = torch_xla.device() a_tensor = torch.randn([2, 3]).to(xla_device) a_result, a_cast = op(a_tensor) self.assertEqual(a_result.dtype, torch.float) @@ -476,14 +476,14 @@ def test_get_real_xla_devices(self): def test_negative_slice(self): t = _gen_tensor(32, 24, 32) - x = t.to(xm.xla_device()) + x = t.to(torch_xla.device()) t_slice = t[:, :, -1] x_slice = x[:, :, -1] self.assertEqual(t_slice.data, x_slice.data.cpu()) def test_negative_cat(self): t = _gen_tensor(2, 5, 3) - x = t.to(xm.xla_device()) + x = t.to(torch_xla.device()) t_cat = torch.cat([t, t], -1) x_cat = torch.cat([x, x], -1) self.assertEqual(t_cat.data, x_cat.data.cpu()) @@ -491,8 +491,8 @@ def test_negative_cat(self): def test_cat_empty_tensor(self): t = _gen_tensor(2, 5, 3) empty_tensor = torch.Tensor() - x = t.to(xm.xla_device()) - empty_tensor_xla = empty_tensor.to(xm.xla_device()) + x = t.to(torch_xla.device()) + empty_tensor_xla = empty_tensor.to(torch_xla.device()) t_cat = torch.cat([t, empty_tensor], 0) x_cat = torch.cat([x, empty_tensor_xla], 0) self.assertEqual(t_cat.data, x_cat.data.cpu()) @@ -530,7 +530,7 @@ def test_amp_foreach_non_finite_check_and_unscale_(self): found_inf_output0 = torch.tensor(0, dtype=torch.float32) found_inf_output1 = torch.tensor(1, dtype=torch.float32) - xla_device = xm.xla_device() + xla_device = torch_xla.device() xla_grads0 = grads0.to(xla_device) xla_inv_scale = inv_scale.to(xla_device) xla_found_inf = found_inf.to(xla_device) @@ -550,9 +550,9 @@ def test_masked_fill_with_tensor(self): input = _gen_tensor(2, 5, 4, 3) mask = _gen_mask(input.size()) value = torch.tensor(42) - xla_input = input.to(xm.xla_device()) - xla_mask = mask.to(xm.xla_device()) - xla_value = value.to(xm.xla_device()) + xla_input = input.to(torch_xla.device()) + xla_mask = mask.to(torch_xla.device()) + xla_value = value.to(torch_xla.device()) result = torch.masked_fill(input, mask, value) xla_result = torch.masked_fill(xla_input, xla_mask, xla_value) self.assertEqual(input.data, xla_input.data.cpu()) @@ -571,63 +571,63 @@ def test_fn(a, b, m): def test_add_mixed_device(self): input = _gen_tensor(3, 800, 1066) - xla_input = input.to(xm.xla_device()) + xla_input = input.to(torch_xla.device()) output = input + 2 xla_output = xla_input + 2 self.assertEqual(output.data, xla_output.data.cpu()) def test_mul_mixed_device(self): input = _gen_tensor(3, 800, 1066) - xla_input = input.to(xm.xla_device()) + xla_input = input.to(torch_xla.device()) output = input * 2 xla_output = xla_input * 2 self.assertEqual(output.data, xla_output.data.cpu()) def test_sub_mixed_device(self): input = _gen_tensor(3, 800, 1066) - xla_input = input.to(xm.xla_device()) + xla_input = input.to(torch_xla.device()) output = input - 2 xla_output = xla_input - 2 self.assertEqual(output.data, xla_output.data.cpu()) def test_div_mixed_device(self): input = _gen_tensor(3, 800, 1066) - xla_input = input.to(xm.xla_device()) + xla_input = input.to(torch_xla.device()) output = input / 2 xla_output = xla_input / 2 self.assertEqual(output.data, xla_output.data.cpu()) def test_rand(self): - x = torch.rand(3, 
5, device=xm.xla_device()) + x = torch.rand(3, 5, device="xla") self.assertEqual(x.device.type, 'xla') def test_randperm(self): - x = torch.randperm(3, device=xm.xla_device(), dtype=torch.int32) + x = torch.randperm(3, device="xla", dtype=torch.int32) self.assertEqual(x.device.type, 'xla') def test_randn_like(self): shape = (5, 1, 1) - x = torch.randn_like(torch.zeros(shape, device=xm.xla_device())) + x = torch.randn_like(torch.zeros(shape, device="xla")) self.assertEqual(x.device.type, 'xla') def test_rand_like(self): shape = (5, 1, 1) - x = torch.rand_like(torch.zeros(shape, device=xm.xla_device())) + x = torch.rand_like(torch.zeros(shape, device="xla")) self.assertEqual(x.device.type, 'xla') def test_randint_like(self): shape = (5, 1, 1) x = torch.randint_like( - torch.zeros(shape, device=xm.xla_device(), dtype=torch.uint8), 6, 10) + torch.zeros(shape, device="xla", dtype=torch.uint8), 6, 10) self.assertEqual(x.device.type, 'xla') def test_no_storage(self): - x = torch.randn(5, device=xm.xla_device()) + x = torch.randn(5, device="xla") self.assertRaises(Exception, x.device) def test_slice_copy(self): a = torch.rand(3, 3, 3) - xla_device = xm.xla_device() + xla_device = torch_xla.device() xla_a = a.to(xla_device) shape = (4, 4, 4) b = a.new(*shape).zero_() @@ -638,7 +638,7 @@ def test_slice_copy(self): def test_slice_assign(self): a = torch.rand(3, 3, 3) - xla_device = xm.xla_device() + xla_device = torch_xla.device() xla_a = a.to(xla_device) shape = (4, 4, 4) b = a.new(*shape).zero_() @@ -649,7 +649,7 @@ def test_slice_assign(self): def test_slice_stepped_assign(self): a = torch.ones((10, 4)) - xla_device = xm.xla_device() + xla_device = torch_xla.device() xla_a = a.to(xla_device) a[:, 0::2] = 2 xla_a[:, 0::2] = 2 @@ -657,14 +657,14 @@ def test_slice_stepped_assign(self): def test_slice_stepped_other_assign(self): a = torch.ones((10, 4)) - xla_device = xm.xla_device() + xla_device = torch_xla.device() xla_a = a.to(xla_device) a[:, 1::4] = 2 xla_a[:, 1::4] = 2 self.assertEqual(a.data, xla_a.data.cpu()) def test_ailing_slice(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() a = torch.ones((1000, 324)).to(xla_device) xla_a = a.to(xla_device) w = a[:, 2::4] @@ -674,7 +674,7 @@ def test_ailing_slice(self): self.assertEqual(w.data, xla_w.data.cpu()) def test_slice_rnd_stepped_assign(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() size = 10 for s in range(0, size - 1): for e in range(1, size - s): @@ -686,12 +686,12 @@ def test_slice_rnd_stepped_assign(self): def test_arange_nan(self): with self.assertRaisesRegex(RuntimeError, r'unsupported range'): - a = torch.arange(-5, float('nan'), device=xm.xla_device()) + a = torch.arange(-5, float('nan'), device="xla") with self.assertRaisesRegex(RuntimeError, r'unsupported range'): - a = torch.arange(float('nan'), 5, device=xm.xla_device()) + a = torch.arange(float('nan'), 5, device="xla") def test_empty_advanced_indexing(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() base = torch.randn(2, 3, 4, 5) xla_base = base.to(device=xla_device) result = base[:, torch.empty(0, 6, dtype=torch.int64)] @@ -702,7 +702,7 @@ def test_empty_advanced_indexing(self): "grad_input produces wrong results after functionalization. 
pytorch/pytorch#91199" ) def test_empty_strided(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() m = nn.Conv1d(4, 6, kernel_size=3, groups=2) a = torch.rand(2, 4, 6, requires_grad=True) xla_m = copy.deepcopy(m).to(xla_device) @@ -730,13 +730,13 @@ def test_empty_strided(self): def test_clamp(self): a = torch.randn(3, 3) - xla_a = a.to(xm.xla_device()) + xla_a = a.to(torch_xla.device()) b = torch.clamp(a, max=3.4) xla_b = torch.clamp(xla_a, max=3.4) self.assertEqual(b.data, xla_b.data.cpu()) def test_rrelu_module(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() a = torch.rand(1, 2, 2, requires_grad=True) xla_a = a.to(xla_device).detach() xla_a.requires_grad = True @@ -753,7 +753,7 @@ def test_rrelu_module(self): self.assertEqual(a.grad, xla_a.grad.cpu()) def test_max_broadcast(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() a = torch.rand(3, 1, 2) b = torch.rand(4, 2) c = torch.max(a, b) @@ -763,7 +763,7 @@ def test_max_broadcast(self): self.assertEqual(c.data, xla_c.data.cpu()) def test_sgn(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() t = torch.randn(2, 3, dtype=torch.cfloat) # Generate inf+infj t[0][0].real.div_(0) @@ -847,7 +847,7 @@ def test_view_as_complex_f64(self): torch_xla._XLAC._get_xla_tensors_text([complex]).split('\n')[-3]) def test_index_put(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() a = torch.tensor([1, 1, 1, 1]).to(xla_device).to(dtype=torch.float32) b = torch.rand(4) > 0.1 a[b] = 10 @@ -891,7 +891,7 @@ def test_baddmm_integer_types(self): def test_view_empty(self): # These used to throw floating point exception. - empty = torch.empty(0, device=xm.xla_device()) + empty = torch.empty(0, device="xla") with self.assertRaisesRegex( RuntimeError, r'unspecified dimension size -1 can be any value'): empty.view(-1, 0) @@ -912,12 +912,12 @@ def test_fn(device): return loss, linear.weight.grad cpu_loss, cpu_weight_grad = test_fn('cpu') - xla_loss, xla_weight_grad = test_fn(xm.xla_device()) + xla_loss, xla_weight_grad = test_fn(torch_xla.device()) self.assertEqual(cpu_loss, xla_loss) self.assertEqual(cpu_weight_grad, xla_weight_grad) def test_inplace_view_backprop_base(self): - root = torch.randn(2, 2, device=xm.xla_device(), requires_grad=True) + root = torch.randn(2, 2, device="xla", requires_grad=True) x = root.clone() v1 = x.narrow(0, 0, 1) v1.mul_(2) @@ -925,7 +925,7 @@ def test_inplace_view_backprop_base(self): self.assertEqual(root.grad.tolist(), [[2, 2], [1, 1]]) def test_inplace_view_backprop_view_of_view(self): - root = torch.randn(2, 2, device=xm.xla_device(), requires_grad=True) + root = torch.randn(2, 2, device="xla", requires_grad=True) x = root.clone() v1 = x.narrow(0, 0, 1) v2 = x.narrow(0, 0, 1) @@ -935,7 +935,7 @@ def test_inplace_view_backprop_view_of_view(self): def test_inplace_view_of_view(self): # modify view-of-view and backprop through base - root = torch.randn(2, 2, device=xm.xla_device(), requires_grad=True) + root = torch.randn(2, 2, device="xla", requires_grad=True) x = root.clone() v1 = x.narrow(0, 0, 1) v2 = v1.narrow(1, 1, 1) @@ -944,8 +944,7 @@ def test_inplace_view_of_view(self): self.assertEqual(root.grad.tolist(), [[1, 2], [1, 1]]) def test_inplace_view_multiple_outputs(self): - root = torch.arange( - 9., device=xm.xla_device()).reshape(3, 3).requires_grad_() + root = torch.arange(9., device="xla").reshape(3, 3).requires_grad_() x = root.clone() v1 = x.unbind() with self.assertRaises(RuntimeError): @@ -986,7 +985,7 @@ def 
func(root, b): def test_inplace_view_backprop_view(self): # modify view and backprop through view - xla_device = xm.xla_device() + xla_device = torch_xla.device() a = torch.tensor([2., 5.], device=xla_device, requires_grad=False) b = torch.tensor([3.], device=xla_device, requires_grad=True) res = a.narrow(0, 1, 1).mul_(b) @@ -1040,7 +1039,7 @@ def func(root, b): def test_inplace_view_non_contig(self): root = torch.ones( - 2, 3, 2, device=xm.xla_device()).select(2, 1).t().requires_grad_(True) + 2, 3, 2, device="xla").select(2, 1).t().requires_grad_(True) x = root.clone() v1 = x.narrow(0, 0, 1) v2 = v1.narrow(1, 1, 1) @@ -1079,12 +1078,12 @@ def func(x): def test_set(self): met.clear_all() - t1 = torch.zeros(50, device=xm.xla_device()) + t1 = torch.zeros(50, device="xla") t1 += 1 torch_xla.sync() self.assertEqual(met.counter_value('DestroyXlaTensor'), 3) - t2 = torch.zeros(10, device=xm.xla_device()) + t2 = torch.zeros(10, device="xla") self.assertEqual(met.counter_value('DestroyXlaTensor'), 4) t1.set_(t2) @@ -1097,12 +1096,12 @@ def test_set(self): def test_replace_xla_tensor(self): met.clear_all() - t1 = torch.zeros(50, device=xm.xla_device()) + t1 = torch.zeros(50, device="xla") t1 += 1 torch_xla.sync() self.assertEqual(met.counter_value('DestroyXlaTensor'), 3) - t2 = torch.zeros(10, device=xm.xla_device()) + t2 = torch.zeros(10, device="xla") self.assertEqual(met.counter_value('DestroyXlaTensor'), 4) torch_xla._XLAC._replace_xla_tensor(t1, t2) self.assertEqual(met.counter_value('DestroyXlaTensor'), 5) @@ -1111,7 +1110,7 @@ def test_replace_xla_tensor(self): self.assertTrue(torch.allclose(t2.cpu(), torch.zeros(10))) def test_pred_type(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() a = torch.rand(4) b = torch.rand(4) xla_a = a.to(xla_device) @@ -1133,7 +1132,7 @@ def test_pred_type(self): self.runAtenTest(c, lambda x: x ^ x.byte()) def test_bitwise_and_not(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() a = torch.randint(255, (4,), dtype=torch.long) xla_a = a.to(xla_device) @@ -1143,27 +1142,27 @@ def test_fn(a): self.runAtenTest(a, test_fn) def test_s_copy_dtype(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() a = torch.rand(10).to(xla_device).to(dtype=torch.uint8) b = torch.tensor([0, 1, 2, 3]).to(xla_device) self.assertEqual(a[b].dtype, torch.uint8) def test_slice_zero_sized_dim(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() v = torch.randn(2, 3, 4, 5).to(xla_device) y = v[:, :, :, 1] z = y[:, 1:1, :] self.assertEqual(z.size()[1], 0) def test_byte_dtype(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() x = torch.ByteTensor([0, 1]).to(xla_device) y = torch.ByteTensor([0, 1]).to(xla_device) z = x + y self.assertEqual(z.dtype, torch.uint8) def test_frac_negative(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() a = torch.tensor(-3.2) b = a.frac() xla_a = a.to(xla_device) @@ -1171,7 +1170,7 @@ def test_frac_negative(self): self.assertEqual(b, xla_b) def test_flip(self): - device = xm.xla_device() + device = torch_xla.device() data = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], device=device).view(2, 2, 2) self.assertEqual( torch.tensor([5, 6, 7, 8, 1, 2, 3, 4]).view(2, 2, 2), data.flip(0)) @@ -1194,7 +1193,7 @@ def test_flip(self): torch.tensor([6, 5, 8, 7, 2, 1, 4, 3]).view(2, 2, 2), data.flip(2, 0)) def test_flip_check_throws(self): - device = xm.xla_device() + device = torch_xla.device() data = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], device=device).view(2, 
2, 2) # not allow flip on the same dim more than once self.assertRaises(RuntimeError, lambda: data.flip(0, 1, 1)) @@ -1206,7 +1205,7 @@ def test_flip_check_throws(self): self.assertRaises(RuntimeError, lambda: data.flip(3)) def test_flip_expand(self): - device = xm.xla_device() + device = torch_xla.device() data = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], device=device).view(2, 2, 2) expanded_data = torch.arange(1, 4, device=device).view(3, 1).expand(3, 2) transposed_data = torch.arange( @@ -1218,7 +1217,7 @@ def test_flip_expand(self): transposed_data.flip(0, 1, 2)) def test_flip_shape(self): - device = xm.xla_device() + device = torch_xla.device() data = torch.randn(2, 3, 4, device=device) size = [2, 3, 4] test_dims = [] @@ -1228,7 +1227,7 @@ def test_flip_shape(self): self.assertEqual(size, list(data.flip(ds).size())) def test_flip_rectangular(self): - device = xm.xla_device() + device = torch_xla.device() data = torch.tensor([1, 2, 3, 4, 5, 6]).view(2, 3).to(device) flip0_result = torch.tensor([[4, 5, 6], [1, 2, 3]]).to(device) flip1_result = torch.tensor([[3, 2, 1], [6, 5, 4]]).to(device) @@ -1237,13 +1236,13 @@ def test_flip_rectangular(self): self.assertEqual(flip1_result, data.flip(1)) def test_flip_empty_tensor(self): - device = xm.xla_device() + device = torch_xla.device() data = torch.tensor([]) self.assertEqual(data, data.flip(0)) def test_norm_p0(self): # p = 0 is equivalent to nonzero - xla_device = xm.xla_device() + xla_device = torch_xla.device() a = torch.randn(3, 2) xla_a = a.to(xla_device) norm = a.norm(p=0) @@ -1289,7 +1288,7 @@ def test_fn(input, src): self.runAtenTest([torch.zeros(3, 3), torch.ones(3)], test_fn) def test_scatter_add_bool(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() a = torch.tensor([[True, True, True, True, True], [True, True, True, True, True]]) b = torch.zeros(3, 5, dtype=torch.bool) @@ -1334,7 +1333,7 @@ def test_reduction_0dim(self): self.runAtenTest(torch.rand(2, 0, 4), lambda x: torch.mean(x)) self.runAtenTest(torch.rand(2, 0, 4), lambda x: torch.prod(x)) # min & max throws - xla_device = xm.xla_device() + xla_device = torch_xla.device() a = torch.rand(2, 0, 4) xla_a = a.to(xla_device) self.assertRaises(IndexError, lambda: torch.max(a, dim=1)) @@ -1470,11 +1469,11 @@ def check(device): d = a xm.check_view_sharing([a, d]) - check(xm.xla_device()) + check(torch_xla.device()) check(torch.device('cpu')) def test_save(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() x = torch.randn(5, device=xla_device) with tempfile.NamedTemporaryFile() as tf: torch.save(x, tf) @@ -1482,7 +1481,7 @@ def test_save(self): self.assertEqual(x, x_loaded) def test_save_bf16(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() x = torch.randn(5, dtype=torch.bfloat16, device=xla_device) with tempfile.NamedTemporaryFile() as tf: torch.save(x, tf) @@ -1490,7 +1489,7 @@ def test_save_bf16(self): self.assertEqual(x, x_loaded) def test_save_tuple(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() x = torch.randn(5, device=xla_device) number = 3 with tempfile.NamedTemporaryFile() as tf: @@ -1500,7 +1499,7 @@ def test_save_tuple(self): self.assertEqual(number, number_loaded) def test_save_api(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() model = XlaMNIST().to(xla_device) with tempfile.NamedTemporaryFile() as tf: xm.save(model.state_dict(), tf) @@ -1513,7 +1512,7 @@ def test_save_api(self): def test_serialization_api(self): with tempfile.TemporaryDirectory() as tmpdir: 
path = os.path.join(tmpdir, 'data.pt') - xla_device = xm.xla_device() + xla_device = torch_xla.device() model = XlaMNIST().to(xla_device) xser.save(model.state_dict(), path) state_dict = xser.load(path) @@ -1523,7 +1522,7 @@ def test_serialization_api(self): self.assertEqual(model.state_dict(), loaded_model.state_dict()) def test_deepcopy(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() x = torch.rand(5, device=xla_device) x0 = x[0] y = copy.deepcopy(x) @@ -1533,7 +1532,7 @@ def test_deepcopy(self): self.assertEqual(x[0], x0) def test_print(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() x = torch.tensor([5], device=xla_device) expected_str = 'tensor([5], device=\'' + str(xla_device) + '\')' self.assertEqual(str(x), expected_str) @@ -1728,14 +1727,14 @@ def test_fn(t): self.runAtenTest([torch.tensor(20.0)], test_fn) def test_view_and_copy_(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() x = torch.tensor([1.5, 2.5, 3.5, 4.5, 5.5, 6.5], device='cpu') y = torch.tensor([0, 0, 0, 0, 0, 0], device=xla_device) y[::2].copy_(x[::2]) self.assertEqual(y, [1, 0, 3, 0, 5, 0]) def test_view_and_multi_sync(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() t1 = torch.zeros(100, device=xla_device) t1[10] = 113 torch_xla.sync() @@ -1745,7 +1744,7 @@ def test_view_and_multi_sync(self): torch_xla._XLAC._get_xla_tensors_text([t1])) def test_binaryop_order(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() x = torch.rand(5, device=xla_device) y = torch.rand(5) self.assertEqual(x + y, y + x) @@ -1753,14 +1752,14 @@ def test_binaryop_order(self): # Since in eager mode the tensor would be materialized and hence _get_xla_tensors_text would not show the prim::Constant node. 
@skipOnEagerDebug def test_pow_constant(self): - t1 = torch.pow(torch.tensor([2.0, 3.0], device=xm.xla_device()), 5) + t1 = torch.pow(torch.tensor([2.0, 3.0], device="xla"), 5) hlo_text = torch_xla._XLAC._get_xla_tensors_text([t1]) const_hlo = hlo_text.split('\n')[1] assert 'prim::Constant' in const_hlo assert 'xla::device_data' not in const_hlo def test_emb_bf16(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() index = torch.ones(1, dtype=torch.long, device=xla_device) emb = torch.nn.Embedding(1024, 128, device=xla_device) emb = emb.to(torch.bfloat16) @@ -1780,7 +1779,7 @@ def test_on_device(device): return m(index) out = test_on_device("cpu") - out_x = test_on_device(xm.xla_device()) + out_x = test_on_device(torch_xla.device()) self.assertEqual(out, out_x.cpu()) def test_transpose_1d(self): @@ -1799,7 +1798,7 @@ def test_fn(t1): def test_sigmoid_bounds(self): torch.manual_seed(0) - xla_device = xm.xla_device() + xla_device = torch_xla.device() for _ in range(100): x = torch.rand(1000).to(xla_device) lower_bound = torch.sigmoid(x * (-100.0)) @@ -1816,7 +1815,7 @@ def test_manual_seed(self): self.assertTrue(torch.allclose(t1.cpu(), t2.cpu())) def test_cached_addcdiv(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() met.clear_all() t1 = torch.randn(1, 3).to(xla_device) @@ -1834,7 +1833,7 @@ def test_cached_addcdiv(self): @skipOnEagerDebug def test_print_execution(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() torch_xla.sync() xm.wait_device_ops() met.clear_all() @@ -1888,7 +1887,7 @@ def test_fn(input): return dropped[1].cpu(), input.grad.cpu() met.clear_all() - xla_device = xm.xla_device() + xla_device = torch_xla.device() input_cpu = torch.randn(7, 7, requires_grad=True) input_xla = torch.randn(7, 7, device=xla_device, requires_grad=True) mask_cpu, grad_cpu = test_fn(input_cpu) @@ -2046,7 +2045,7 @@ def foo(x): x = torch.arange(10).to(dtype) r = foo(x) - device = xm.xla_device() + device = torch_xla.device() Xx = x.to(device) Xr = foo(Xx) @@ -2089,8 +2088,8 @@ def foo(grad, inp): grad = torch.rand(10, 10, dtype=torch.bfloat16) inp = torch.rand(10, 10) - Xgrad = grad.to(xm.xla_device()) - Xinp = inp.to(xm.xla_device()) + Xgrad = grad.to(torch_xla.device()) + Xinp = inp.to(torch_xla.device()) r = foo(grad, inp) Xr = foo(Xgrad, Xinp) @@ -2105,8 +2104,8 @@ def foo(t): t = torch.rand(10, 10, requires_grad=True, dtype=torch.bfloat16) t.retain_grad() t.grad = torch.rand(10, 10, dtype=torch.bfloat16) - xt = t.to(xm.xla_device()) - xt.grad = t.grad.to(xm.xla_device(), dtype=torch.bfloat16) + xt = t.to(torch_xla.device()) + xt.grad = t.grad.to(torch_xla.device(), dtype=torch.bfloat16) foo(t) foo(xt) @@ -2116,7 +2115,7 @@ def foo(t): def test_clip_grad_norm_zero(self): t = torch.rand(10, 10, dtype=torch.bfloat16) - xt = t.to(xm.xla_device()) + xt = t.to(torch_xla.device()) result = torch.nn.utils.clip_grad_norm_(xt, 1.0) self.assertEqual(result.device.type, 'xla') self.assertTrue(torch.allclose(result.cpu(), torch.tensor(0.))) @@ -2129,8 +2128,8 @@ def foo(t0, t1): t0 = torch.rand(10, 10, dtype=torch.bfloat16) t1 = torch.rand(10, 10) - Xt0 = t0.to(xm.xla_device()) - Xt1 = t1.to(xm.xla_device()) + Xt0 = t0.to(torch_xla.device()) + Xt1 = t1.to(torch_xla.device()) r = foo(t0, t1) Xr = foo(Xt0, Xt1) @@ -2171,8 +2170,8 @@ def test(f, xshape, ishapes): x = make_tensor(xshape) ilist = [make_index(s) for s in ishapes] - Xx = x.to(xm.xla_device()) - Xilist = [i.to(xm.xla_device()) for i in ilist] + Xx = x.to(torch_xla.device()) + Xilist = 
[i.to(torch_xla.device()) for i in ilist] out = f(x, *ilist) Xout = f(Xx, *Xilist) @@ -2212,8 +2211,8 @@ def fn(inp, s): inp = torch.rand(10, dtype=torch.half) s = torch.tensor(7, dtype=torch.double) - Xinp = inp.to(xm.xla_device()) - Xs = s.to(xm.xla_device()) + Xinp = inp.to(torch_xla.device()) + Xs = s.to(torch_xla.device()) out = fn(inp, s) Xout = fn(Xinp, Xs) @@ -2267,7 +2266,7 @@ def foo(x, is_xla=False): return r + 5 inp = torch.rand(1, 3, 10, 10, dtype=torch.double) - Xinp = inp.to(xm.xla_device()) + Xinp = inp.to(torch_xla.device()) out = foo(inp) Xout = foo(Xinp, is_xla=True) @@ -2332,8 +2331,7 @@ def clone_and_maybe_move(tensor, device=None): with self.subTest(sparse=sparse, mode=mode): kwargs_ = {k: clone_and_maybe_move(v) for k, v in kwargs.items()} xla_kwargs = { - k: clone_and_maybe_move(v, device=xm.xla_device()) - for k, v in kwargs.items() + k: clone_and_maybe_move(v, device="xla") for k, v in kwargs.items() } expected_out, expected_grad = fn(**kwargs_, **extra_kwargs) @@ -2365,7 +2363,7 @@ def foo(x: torch.Tensor) -> torch.Tensor: input = torch.rand((10, 10), dtype=torch.float16) out = foo(input) - in_xla = input.to(xm.xla_device()) + in_xla = input.to(torch_xla.device()) out_xla = foo(in_xla) self.assertEqual(out.dtype, out_xla.dtype) @@ -2381,7 +2379,7 @@ def test_cummax_0_sized_dimension(self): a = torch.rand(5, 5, 0, 5) expected = torch.cummax(a, dim) - actual = torch.cummax(a.to(xm.xla_device()), dim) + actual = torch.cummax(a.to(torch_xla.device()), dim) self.assertEqual(actual, expected) @@ -2395,7 +2393,7 @@ def run(device): return runf(*args_) actual = run("cpu") - expected = run(xm.xla_device()) + expected = run(torch_xla.device()) self.assertFalse( met.executed_fallback_ops(), msg="expected no fallback operations.") @@ -2454,7 +2452,7 @@ class TestModelComparator(test_utils.XlaTestCase): def test(self): SEED = 42 - xla_device = xm.xla_device() + xla_device = torch_xla.device() x = _gen_tensor(8, 1, 28, 28) xla_x = x.to(xla_device) @@ -2479,13 +2477,13 @@ def test(self): class TestWaitDeviceOps(test_utils.XlaTestCase): def test_wait_device_ops(self): - xm.xla_device() - value = torch.randn(10000, 10000, device=xm.xla_device()) + torch_xla.device() + value = torch.randn(10000, 10000, device="xla") val_list = [] val_mean_list = [] met.clear_all() for _ in range(5): - new_val = value * torch.randn(10000, 10000, device=xm.xla_device()) + new_val = value * torch.randn(10000, 10000, device="xla") val_list.append(new_val) val_mean_list.append(new_val.mean()) torch_xla.sync() @@ -2498,7 +2496,7 @@ class TestDebuggingUtil(test_utils.XlaTestCase): @skipOnEagerDebug def test_get_xla_tensor_debug_info(self): - device = xm.xla_device() + device = torch_xla.device() # test non xla tensor cpu_t1 = torch.randn(5) cpu_t1_info = torch_xla._XLAC._get_xla_tensor_debug_info(cpu_t1) @@ -2533,7 +2531,7 @@ def runOpBuilderTest(self, kwargs=dict()): op = xor.register(name, opfn) if device is None: - device = xm.xla_device() + device = torch_xla.device() if aten_fn is None: aten_fn = opfn tensors = xu.as_list(tensors) @@ -2655,7 +2653,7 @@ class MpDecoratorTest(test_utils.XlaTestCase): @xtu.mp_test def test_mp_decorator(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() self.assertTrue(xla_device.type == 'xla') @@ -2694,7 +2692,7 @@ class TestLoweringContext(test_utils.XlaTestCase): def test_api(self): met.clear_all() - device = xm.xla_device() + device = torch_xla.device() a = torch.tensor([1.0, 2.0, 3.0], device=device) b = torch.tensor([4.0, 5.0, 6.0], 
device=device) @@ -2755,13 +2753,13 @@ def test_git_revisons(self): self.assertTrue('torch' in revs) def test_send_to_device_grad(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() t = _gen_tensor(2, 2, requires_grad=True) dt = xm.send_cpu_data_to_device([t], xla_device) self.assertTrue(dt[0].requires_grad) def test_send_to_device_single(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() t = _gen_tensor(2, 2) dt = xm.send_cpu_data_to_device(t, xla_device) self.assertEqual(dt[0].device, xla_device) @@ -2861,7 +2859,7 @@ def from_tensors(self, tensors): wpack = PackWrapper(pack) - xla_device = xm.xla_device() + xla_device = torch_xla.device() xdata = xm.send_cpu_data_to_device(wpack, xla_device) self.assertTrue(isinstance(xdata, nn.utils.rnn.PackedSequence)) self.assertEqual(xdata.batch_sizes.device, torch.device('cpu')) @@ -2871,7 +2869,7 @@ def from_tensors(self, tensors): "https://github.com/pytorch/xla/pull/7864#issuecomment-2294034008") def test_as_strided_input_larger(self): size = (5, 5) - device = xm.xla_device() + device = torch_xla.device() a = torch.ones(size, device=device) small_a = a[:, ::2] @@ -2883,7 +2881,7 @@ def _test_move_tensor_cuda_to_xla(self, cpu_tensor): # Assumes CPU-XLA data movement works. cuda_tensor = cpu_tensor.to("cuda") # Move tensor CUDA -> XLA. - xla_tensor = cuda_tensor.to(xm.xla_device()) + xla_tensor = cuda_tensor.to(torch_xla.device()) # Move the XLA tensor back to CPU, and check that it is the same as # the original CPU tensor. self.assertTrue(torch.equal(cpu_tensor, xla_tensor.cpu())) @@ -2901,7 +2899,7 @@ def test_aten_move_scalar_cuda_to_xla(self): self._test_move_tensor_cuda_to_xla(torch.tensor(42)) def test_unsafe_buffer_pointer(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() xla_tensor_0 = torch.tensor(42).to(xla_device) # `torch_xla.sync()` ensures xtensor->CurrentDataHandle() != nullptr torch_xla.sync() @@ -2909,7 +2907,7 @@ def test_unsafe_buffer_pointer(self): self.assertGreaterEqual(buf_ptr_0, 0) # xtensor->CurrentDataHandle() == nullptr but xtensor->CurrentIrValue().node != nullptr and device_data != nullptr - xla_tensor_1 = torch.tensor(42, device=xm.xla_device()) + xla_tensor_1 = torch.tensor(42, device="xla") buf_ptr_1 = torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor_1) self.assertGreaterEqual(buf_ptr_1, 0) @@ -2918,7 +2916,7 @@ def test_unsafe_buffer_pointer(self): buf_ptr_2 = torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor_2) self.assertGreaterEqual(buf_ptr_2, 0) - xla_tensor_3 = torch.arange(5, device=xm.xla_device()) + xla_tensor_3 = torch.arange(5, device="xla") torch_xla.sync() # Without the `wait_device_ops()`, the pjrt buffer (pjrt_data->buffer) at https://github.com/pytorch/xla/blob/e3fc03314dab5f44e3ed9ccbba6c15fbca3285cd/torch_xla/csrc/runtime/pjrt_computation_client.cc#L467 will be nullptr. xm.wait_device_ops() @@ -2946,14 +2944,14 @@ def _test_dlpack_capsule_conversion_helper(self, xla_tensor): @onlyIfPJRTDeviceIsCUDA @parameterized.parameters(*all_types_and(torch.half, torch.bfloat16)) def test_dlpack_roundtrip_tensor(self, dtype): - xla_device = xm.xla_device() + xla_device = torch_xla.device() # xtensor->CurrentDataHandle() == nullptr but xtensor->CurrentIrValue().node != nullptr and device_data != nullptr # xla_tensor_2 uses XLANativeFunctions::_to_copy xla_tensor_2 = torch.arange(5, dtype=dtype).to(xla_device) self._test_dlpack_capsule_conversion_helper(xla_tensor_2) # xla_tensor_3 uses arange_out IR node. 
- xla_tensor_3 = torch.arange(5, dtype=dtype, device=xm.xla_device()) + xla_tensor_3 = torch.arange(5, dtype=dtype, device="xla") torch_xla.sync() self._test_dlpack_capsule_conversion_helper(xla_tensor_3) @@ -2963,7 +2961,7 @@ def test_dlpack_roundtrip_tensor(self, dtype): *all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64)) def test_dlpack_roundtrip_scalar(self, dtype): - xla_device = xm.xla_device() + xla_device = torch_xla.device() xla_tensor_0 = torch.tensor(42, dtype=dtype).to(xla_device) # `torch_xla.sync()` ensures xtensor->CurrentDataHandle() != nullptr torch_xla.sync() @@ -2976,7 +2974,7 @@ def test_dlpack_roundtrip_scalar(self, dtype): @onlyIfTorchSupportsCUDA @onlyIfPJRTDeviceIsCUDA def test_dlpack_roundtrip_bool(self): - xla_tensor = torch.ones(1, dtype=torch.bool).to(xm.xla_device()) + xla_tensor = torch.ones(1, dtype=torch.bool).to(torch_xla.device()) self._test_dlpack_capsule_conversion_helper(xla_tensor) @onlyIfTorchSupportsCUDA @@ -3044,7 +3042,7 @@ def test_dlpack_pytorch_cuda_to_xla_protocol_conversion(self): @onlyIfTorchSupportsCUDA @onlyIfPJRTDeviceIsCUDA def test_dlpack_xla_to_pytorch_cuda(self): - xla_t1 = torch.arange(5).to(xm.xla_device()) + xla_t1 = torch.arange(5).to(torch_xla.device()) dlt1 = xdlpack.to_dlpack(xla_t1) cuda_t1 = torch.utils.dlpack.from_dlpack(dlt1) self.assertEqual(cuda_t1.device.type, 'cuda') @@ -3055,7 +3053,7 @@ def test_dlpack_xla_to_pytorch_cuda(self): @onlyIfTorchSupportsCUDA @onlyIfPJRTDeviceIsCUDA def test_dlpack_xla_to_pytorch_cuda_protocol_conversion(self): - xla_t1 = torch.arange(5).to(xm.xla_device()) + xla_t1 = torch.arange(5).to(torch_xla.device()) cuda_t1 = torch.utils.dlpack.from_dlpack(xla_t1) self.assertEqual(cuda_t1.device.type, 'cuda') self.assertEqual(cuda_t1.device.index, xla_t1.device.index) @@ -3120,7 +3118,7 @@ def forward(self, inp): class TestActivationCheckpoint(test_utils.XlaTestCase): def test_dropout(self): - device = xm.xla_device() + device = torch_xla.device() model = SimpleModelWithDropout().to(device) model = checkpoint_module(model) _input = torch.randn(128, 128, requires_grad=True) @@ -3134,7 +3132,7 @@ def test_dropout(self): f"in fwd {model.to_save[0]}, in bwd {model.to_save[1]}") def test_opt_barrier(self): - device = xm.xla_device() + device = torch_xla.device() model = SimpleModelWithDropout().to(device) model = checkpoint_module(model) _input = torch.randn(128, 128, requires_grad=True) @@ -3169,7 +3167,7 @@ def _reference_nms(self, boxes, scores, iou_threshold): def _nms(self, boxes, scores, iou_threshold): import torchvision - device = xm.xla_device() + device = torch_xla.device() return torchvision.ops.nms( boxes.to(device), scores.to(device), iou_threshold).cpu() diff --git a/test/test_operations_hlo.py b/test/test_operations_hlo.py index 25e17b7c265d..3001865b98b9 100644 --- a/test/test_operations_hlo.py +++ b/test/test_operations_hlo.py @@ -30,15 +30,15 @@ def tearDown(self): super(TestOperationsHlo, self).tearDown() def test_expand(self): - a = torch.rand(1, 5, device=xm.xla_device()) + a = torch.rand(1, 5, device="xla") b = a.expand(5, 5) hlo_text = torch_xla._XLAC._get_xla_tensors_text([b]) assert 'aten::expand' in hlo_text def test_special_scalars_addcdiv_addcmul(self): - a = torch.rand(5, 5).to(xm.xla_device()) - b = torch.rand(5, 5).to(xm.xla_device()) - c = torch.rand(5, 5).to(xm.xla_device()) + a = torch.rand(5, 5).to(torch_xla.device()) + b = torch.rand(5, 5).to(torch_xla.device()) + c = torch.rand(5, 5).to(torch_xla.device()) for op 
in [torch.addcdiv, torch.addcmul]: out = op(a, b, c, value=1.0) hlo_text = torch_xla._XLAC._get_xla_tensors_text([out]) @@ -52,8 +52,8 @@ def test_special_scalars_addcdiv_addcmul(self): def test_div_by_f64(self): mod = torch.nn.MultiheadAttention(768, 12, batch_first=True) - mod.to(xm.xla_device()) - a = torch.rand(1, 512, 768).to(xm.xla_device()) + mod.to(torch_xla.device()) + a = torch.rand(1, 512, 768).to(torch_xla.device()) b, _ = mod(a, a, a, need_weights=False) b.sum().backward() hlo_text = torch_xla._XLAC._get_xla_tensors_text( @@ -61,8 +61,8 @@ def test_div_by_f64(self): assert 'f64' not in hlo_text def test_dropout_by_u8_mask(self): - mod = torch.nn.Dropout().to(xm.xla_device()) - a = torch.rand(20, 16, dtype=torch.bfloat16).to(xm.xla_device()) + mod = torch.nn.Dropout().to(torch_xla.device()) + a = torch.rand(20, 16, dtype=torch.bfloat16).to(torch_xla.device()) b = mod(a) hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([b]) assert 'u8' in hlo_text diff --git a/test/test_persistent_cache.py b/test/test_persistent_cache.py index 75a739e64638..ccc5c8a568c7 100644 --- a/test/test_persistent_cache.py +++ b/test/test_persistent_cache.py @@ -51,14 +51,14 @@ def _mp_test(rank, tmpdir, metrics): xr.initialize_cache(os.path.join(tmpdir, str(rank))) t = torch.randn(16) - xt = t.to(xm.xla_device()) + xt = t.to(torch_xla.device()) _assert_correctness_and_metrics(t, xt, metrics) def _single_device_test(tmpdir, metrics): xr.initialize_cache(tmpdir) t = torch.randn(16) - xt = t.to(xm.xla_device()) + xt = t.to(torch_xla.device()) _assert_correctness_and_metrics(t, xt, metrics) @@ -66,7 +66,7 @@ def _spmd_replicated_test(tmpdir, metrics): xr.initialize_cache(tmpdir) xr.use_spmd() t = torch.randn(16) - xt = t.to(xm.xla_device()) + xt = t.to(torch_xla.device()) _assert_correctness_and_metrics(t, xt, metrics) @@ -74,7 +74,7 @@ def _spmd_explicitly_replicated_test(tmpdir, metrics): xr.initialize_cache(tmpdir) xr.use_spmd() t = torch.randn(16) - xt = t.to(xm.xla_device()) + xt = t.to(torch_xla.device()) n_dev = xr.global_runtime_device_count() mesh = xs.Mesh(range(n_dev), (n_dev,)) @@ -87,7 +87,7 @@ def _spmd_sharded_test(tmpdir, metrics): xr.use_spmd() t = torch.randn(16) - xt = t.to(xm.xla_device()) + xt = t.to(torch_xla.device()) n_dev = xr.global_runtime_device_count() mesh = xs.Mesh(range(n_dev), (n_dev,)) xs.mark_sharding(xt, mesh, (0,)) diff --git a/test/test_profile_mp_mnist.py b/test/test_profile_mp_mnist.py index 266a7bc5634f..e23c2f59c223 100644 --- a/test/test_profile_mp_mnist.py +++ b/test/test_profile_mp_mnist.py @@ -144,7 +144,7 @@ def train_mnist(flags, # Scale learning rate to num cores lr = flags.lr * xr.world_size() - device = xm.xla_device() + device = torch_xla.device() model = MNIST().to(device) writer = None if xm.is_master_ordinal(): diff --git a/test/test_python_ops.py b/test/test_python_ops.py index 24e48d9a1664..9dc145947f62 100644 --- a/test/test_python_ops.py +++ b/test/test_python_ops.py @@ -29,8 +29,8 @@ def test_put(self, dtype): raise unittest.SkipTest("Dtype {0} is unsupported by XLA".format( str(dtype))) - device = xm.xla_device() - real_device_type = xm.xla_device_hw(str(xm.xla_device())) + device = torch_xla.device() + real_device_type = xm.xla_device_hw(str(torch_xla.device())) if real_device_type == "TPU": raise unittest.SkipTest("TestPut is too slow on TPU. 
Skipped") @@ -108,7 +108,7 @@ def test_index_copy(self, dtype): raise unittest.SkipTest("Dtype {0} is unsupported by XLA".format( str(dtype))) - device = xm.xla_device() + device = torch_xla.device() # We just test for num_copy <= num_dest, as otherwise there are repeated indices # and the behavior is undefined diff --git a/test/test_syncfree_optimizers.py b/test/test_syncfree_optimizers.py index 991bf8e5c936..8807271440c6 100644 --- a/test/test_syncfree_optimizers.py +++ b/test/test_syncfree_optimizers.py @@ -53,7 +53,7 @@ def _test_optimizer(self, syncfree_optim_cls, ref_optim_cls, optim_kwargs={'lr': 1e-2}): - device = xm.xla_device() + device = torch_xla.device() loss_fn = nn.NLLLoss() # syncfree model torch.manual_seed(0) diff --git a/test/test_torch_distributed_fsdp_frozen_weight.py b/test/test_torch_distributed_fsdp_frozen_weight.py index f492fcef3334..b7b37e31e004 100644 --- a/test/test_torch_distributed_fsdp_frozen_weight.py +++ b/test/test_torch_distributed_fsdp_frozen_weight.py @@ -7,7 +7,7 @@ def _mp_fn(index): - dev = xm.xla_device() + dev = torch_xla.device() if xm.xla_device_hw(dev) not in ('TPU', 'CUDA'): print( 'Default device {} is not a TPU or CUDA device'.format(dev), @@ -19,7 +19,7 @@ def _mp_fn(index): model = FSDP(model) # wrapping the linear module with FSDP - input = torch.rand((2, 1024), device=xm.xla_device()) + input = torch.rand((2, 1024), device="xla") output = model(input) loss = torch.sum(output) diff --git a/test/test_torch_distributed_xla_backend.py b/test/test_torch_distributed_xla_backend.py index bf4573713ec3..a3069a6637ec 100644 --- a/test/test_torch_distributed_xla_backend.py +++ b/test/test_torch_distributed_xla_backend.py @@ -63,7 +63,7 @@ def test_xla_backend_exists(self): self.assertIsNotNone(pg_xla_creator) def test_allreduce(self): - device = xm.xla_device() + device = torch_xla.device() tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() all_reduce_pattern = r'%all\-reduce\.\d+ = .+ all\-reduce\(' dist.all_reduce(tensor) @@ -72,7 +72,7 @@ def test_allreduce(self): @patch_world(rank=3, size=6) def test_allreduce_with_mesh(self): - device = xm.xla_device() + device = torch_xla.device() tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() pg_options = {'xla_pg_options': {'spmd': True}} @@ -89,7 +89,7 @@ def test_allreduce_with_mesh(self): @patch_world(rank=3, size=8) def test_allgather(self): - device = xm.xla_device() + device = torch_xla.device() tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() output_tensors = [torch.zeros_like(tensor, device=device) for _ in range(8)] all_gather_pattern = r'%all\-gather\.\d+ = .+ all\-gather\(' @@ -99,7 +99,7 @@ def test_allgather(self): @patch_world(rank=3, size=8) def test_all_scalar_allgather(self): - device = xm.xla_device() + device = torch_xla.device() tensor = torch.zeros((), device=device) + 1 + 2 * dist.get_rank() output_tensors = [torch.zeros_like(tensor, device=device) for _ in range(8)] all_gather_pattern = r'%all\-gather\.\d+ = .+ all\-gather\(' @@ -109,7 +109,7 @@ def test_all_scalar_allgather(self): @patch_world(rank=3, size=8) def test_allgather_coalesced(self): - device = xm.xla_device() + device = torch_xla.device() tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() tensor2 = torch.arange(5, device=device) + 1 + 2 * dist.get_rank() pg_xla = get_process_group_xla(rank=3, size=8) @@ -127,7 +127,7 @@ def test_allgather_coalesced(self): hlo_matches(hlo, all_gather_pattern) def test_broadcast(self): - device = xm.xla_device() + device = 
torch_xla.device() tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() all_reduce_pattern = r'%all\-reduce\.\d+ = .+ all\-reduce\(' dist.broadcast(tensor, 0) @@ -136,7 +136,7 @@ def test_broadcast(self): # Needed for ZeRO stage 1 def test_reduce_scatter(self): - device = xm.xla_device() + device = torch_xla.device() tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() input_list = [tensor] output = torch.zeros_like(tensor) @@ -148,7 +148,7 @@ def test_reduce_scatter(self): @skipIf(xr.device_type() == 'CPU', "UNIMPLEMENTED: ReduceScatter is not implemented on CPU.") def test_reduce_scatter_coalesced(self): - device = xm.xla_device() + device = torch_xla.device() tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() tensor2 = torch.arange(5, device=device) + 1 + 2 * dist.get_rank() input_tensors_list = [[tensor, tensor], [tensor2, tensor2]] @@ -168,7 +168,7 @@ def test_reduce_scatter_coalesced(self): @patch_world(0, 6) def test_send(self): - device = xm.xla_device() + device = torch_xla.device() tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() input_list = [tensor] @@ -185,11 +185,11 @@ def test_send(self): hlo_matches(hlo, senddone_pattern) # Don't try to run Send on CPU because it's not implemented - torch_xla._XLAC._clear_pending_irs(str(xm.xla_device())) + torch_xla._XLAC._clear_pending_irs(str(torch_xla.device())) @patch_world(0, 6) def test_recv(self): - device = xm.xla_device() + device = torch_xla.device() tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() with mock.patch.object( @@ -205,7 +205,7 @@ def test_recv(self): hlo_matches(hlo, recvdone_pattern) # Don't try to run Recv on CPU because it's not implemented - torch_xla._XLAC._clear_pending_irs(str(xm.xla_device())) + torch_xla._XLAC._clear_pending_irs(str(torch_xla.device())) @patch_world(rank=0, size=12) def test_new_group_no_ranks(self): @@ -365,7 +365,7 @@ def test_barrier(self): 'monitored_barrier', ) def test_unimplemented_op(self, op): - device = xm.xla_device() + device = torch_xla.device() tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() pg_xla = dist.group.WORLD self.assertIsInstance(pg_xla, diff --git a/test/test_train_mp_imagenet.py b/test/test_train_mp_imagenet.py index bec580c3831e..efb34a2cc3af 100644 --- a/test/test_train_mp_imagenet.py +++ b/test/test_train_mp_imagenet.py @@ -250,7 +250,7 @@ def train_imagenet(): torch.manual_seed(42) - device = xm.xla_device() + device = torch_xla.device() model = get_model_property('model_fn')().to(device) # Initialization is nondeterministic with multiple threads in PjRt. 
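The training-script hunks above all apply the same call-site migration: `xm.xla_device()` becomes `torch_xla.device()`, or the `device="xla"` string form when tensors are created directly on the device. A minimal sketch of the resulting pattern, using a toy `nn.Linear` in place of the MNIST/ImageNet models exercised by these tests:

import torch
import torch.nn as nn
import torch_xla

device = torch_xla.device()             # replaces the deprecated xm.xla_device()
model = nn.Linear(8, 2).to(device)      # move a small module to the XLA device
data = torch.randn(4, 8, device="xla")  # string form targets the same XLA device
out = model(data)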
diff --git a/test/test_train_mp_imagenet_amp.py b/test/test_train_mp_imagenet_amp.py index 0ab5e1fd8007..290857281fd7 100644 --- a/test/test_train_mp_imagenet_amp.py +++ b/test/test_train_mp_imagenet_amp.py @@ -194,7 +194,7 @@ def train_imagenet(): torch.manual_seed(42) - device = xm.xla_device() + device = torch_xla.device() device_hw = xm.xla_device_hw(device) model = get_model_property('model_fn')().to(device) writer = None @@ -229,7 +229,7 @@ def train_loop_fn(loader, epoch): for step, (data, target) in enumerate(loader): optimizer.zero_grad() if FLAGS.amp: - with autocast(xm.xla_device()): + with autocast(torch_xla.device()): output = model(data) loss = loss_fn(output, target) if scaler: diff --git a/test/test_train_mp_imagenet_fsdp.py b/test/test_train_mp_imagenet_fsdp.py index 8c9be15ac2e2..1d939d8385b3 100644 --- a/test/test_train_mp_imagenet_fsdp.py +++ b/test/test_train_mp_imagenet_fsdp.py @@ -241,7 +241,7 @@ def train_imagenet(): torch.manual_seed(42) - device = xm.xla_device() + device = torch_xla.device() model = get_model_property('model_fn')() # Automatic wrapping sub-modules with inner FSDP auto_wrap_policy = None diff --git a/test/test_train_mp_mnist.py b/test/test_train_mp_mnist.py index 9e470719f27b..0a5e46fdcd1f 100644 --- a/test/test_train_mp_mnist.py +++ b/test/test_train_mp_mnist.py @@ -130,7 +130,7 @@ def train_mnist(flags, **kwargs): # Scale learning rate to num cores lr = flags.lr * xr.world_size() - device = xm.xla_device() + device = torch_xla.device() model = MNIST().to(device) # Initialization is nondeterministic with multiple threads in PjRt. diff --git a/test/test_train_mp_mnist_amp.py b/test/test_train_mp_mnist_amp.py index 3fa8770f1a89..0bd393b21f2e 100644 --- a/test/test_train_mp_mnist_amp.py +++ b/test/test_train_mp_mnist_amp.py @@ -130,7 +130,7 @@ def train_mnist(flags, **kwargs): # Scale learning rate to num cores lr = flags.lr * xr.world_size() - device = xm.xla_device() + device = torch_xla.device() device_hw = xm.xla_device_hw(device) model = MNIST().to(device) diff --git a/test/test_train_mp_mnist_fsdp_with_ckpt.py b/test/test_train_mp_mnist_fsdp_with_ckpt.py index 169bfe264a3c..833612a2be49 100644 --- a/test/test_train_mp_mnist_fsdp_with_ckpt.py +++ b/test/test_train_mp_mnist_fsdp_with_ckpt.py @@ -164,7 +164,7 @@ def train_mnist(flags, **kwargs): # Scale learning rate to num cores lr = flags.lr * xr.world_size() - device = xm.xla_device() + device = torch_xla.device() model = MNIST() # Automatic wrapping sub-modules with inner FSDP auto_wrap_policy = None diff --git a/test/test_train_mp_mnist_zero1.py b/test/test_train_mp_mnist_zero1.py index 77e284c98d76..523bf5fc0a19 100644 --- a/test/test_train_mp_mnist_zero1.py +++ b/test/test_train_mp_mnist_zero1.py @@ -114,7 +114,7 @@ def train_mnist(flags, **kwargs): # Scale learning rate to num cores lr = flags.lr * xr.world_size() - device = xm.xla_device() + device = torch_xla.device() model = MNIST().to(device) writer = None diff --git a/test/test_user_computation_debug_cache.py b/test/test_user_computation_debug_cache.py index 88467de3c51b..f83f856c2cfd 100644 --- a/test/test_user_computation_debug_cache.py +++ b/test/test_user_computation_debug_cache.py @@ -40,7 +40,7 @@ def input_scope_0(tensor): def input_scope_1(tensor): return [torch.sin(tensor), torch.cos(tensor)] - device = xm.xla_device() + device = torch_xla.device() init_tensor = torch.tensor(10).to(device) def create_user_computation(fn): diff --git a/test/test_utils.py b/test/test_utils.py index f238f4c82540..6a913f932e4d 100644 --- 
a/test/test_utils.py +++ b/test/test_utils.py @@ -384,7 +384,7 @@ def compareResults(self, results, xla_results, rel_err=1e-2, abs_err=1e-5): def runAtenTest(self, tensors, fn, device=None, rel_err=1e-2, abs_err=1e-5): if device is None: - device = xm.xla_device() + device = torch_xla.device() tensors = xu.as_list(tensors) xla_tensors = [ x.to(device).detach().requires_grad_(x.requires_grad) for x in tensors diff --git a/test/test_while_loop.py b/test/test_while_loop.py index e8ea617b0f96..4dc0a17a96ea 100644 --- a/test/test_while_loop.py +++ b/test/test_while_loop.py @@ -26,7 +26,7 @@ def _fake_while_loop(cond_fn, body_fn, operands): class WhileLoopTest(unittest.TestCase): def test_while_loop_addition(self): - device = xm.xla_device() + device = torch_xla.device() def cond_fn(iteri, x): return iteri > 0 @@ -41,7 +41,7 @@ def body_fn(iteri, x): self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop))) def test_while_loop_addition_nested(self): - device = xm.xla_device() + device = torch_xla.device() def cond_fn(iteri, x): return iteri > 0 @@ -56,7 +56,7 @@ def body_fn(iteri, x): self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop))) def test_while_loop_simple_linear_inside_loop(self): - device = xm.xla_device() + device = torch_xla.device() torch.set_grad_enabled(False) class SimpleLinear(torch.nn.Module): @@ -94,7 +94,7 @@ def forward_without_while_loop_op(self, iteri, x): # ====== fori_loop ====== @unittest.skip("Fori_loop is not supported now due to unstable result.") def test_fori_loop_addition(self): - device = xm.xla_device() + device = torch_xla.device() lower = torch.tensor(0, device=device) upper = torch.tensor(50, device=device) diff --git a/test/test_xla_graph_execution.py b/test/test_xla_graph_execution.py index d8aa33f39aee..53e0ff46e36c 100644 --- a/test/test_xla_graph_execution.py +++ b/test/test_xla_graph_execution.py @@ -21,7 +21,7 @@ class TestXlaGraphExecution(test_utils.XlaTestCase): def test_graph_execution_allowed(self): torch_xla._XLAC._set_allow_execution(True) - x = torch.ones(2, device=xm.xla_device()) + x = torch.ones(2, device="xla") self.assertEqual(x[0], 1.0) # This should trigger the checking del x @@ -30,7 +30,7 @@ def test_graph_execution_disallowed_with_error(self): # Trigger runtime error for unexpected graph execution torch_xla._XLAC._set_allow_execution( False) # this flag disallows graph execution - x = torch.ones(2, device=xm.xla_device()) + x = torch.ones(2, device="xla") with self.assertRaises(RuntimeError) as e: self.assertEqual(x[0], 1.0) # This should trigger the checking self.assertIn( diff --git a/test/test_zero1.py b/test/test_zero1.py index e3dc5738f846..8bb2fbc3d822 100644 --- a/test/test_zero1.py +++ b/test/test_zero1.py @@ -34,7 +34,7 @@ class XlaZeRO1Test(test_utils.XlaTestCase): @unittest.skipIf(xr.device_type() == 'TPU', "Crash on TPU") def test_zero1(self): - device = xm.xla_device() + device = torch_xla.device() model = nn.Linear(32, 32) x = torch.ones((32, 32)) @@ -89,7 +89,7 @@ def test_zero1(self): torch_xla.sync() def test_zero1_load(self): - device = xm.xla_device() + device = torch_xla.device() model = nn.Linear(32, 32) x = torch.ones((32, 32)) @@ -153,7 +153,7 @@ def test_zero1_load(self): def _mp_fn(index): - device = xm.xla_device() + device = torch_xla.device() if xm.xla_device_hw(device) in ('TPU', 'CUDA'): test = unittest.main(exit=False) sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/torch_distributed/test_ddp.py b/test/torch_distributed/test_ddp.py index 
d4e3fc77c7f2..1d91f520d5aa 100644 --- a/test/torch_distributed/test_ddp.py +++ b/test/torch_distributed/test_ddp.py @@ -24,7 +24,7 @@ def _ddp_correctness(rank, gradient_as_bucket_view: bool = False): # We cannot run this guard before XMP, # see API_GUIDE.md#running-on-multiple-xla-devices-with-multi-processing. - device = xm.xla_device() + device = torch_xla.device() if xm.xla_device_hw(device) not in ('TPU', 'CUDA'): print( 'Default device {} is not a TPU device'.format(device), diff --git a/test/torch_distributed/test_torch_distributed_all_gather_xla_backend.py b/test/torch_distributed/test_torch_distributed_all_gather_xla_backend.py index 84b950becde6..7c30b211ad49 100644 --- a/test/torch_distributed/test_torch_distributed_all_gather_xla_backend.py +++ b/test/torch_distributed/test_torch_distributed_all_gather_xla_backend.py @@ -9,7 +9,7 @@ def _mp_fn(index): - device = xm.xla_device() + device = torch_xla.device() if xm.xla_device_hw(device) in ('TPU', 'CUDA', 'NEURON'): world_size = xr.world_size() rank = xr.global_ordinal() diff --git a/test/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py b/test/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py index b354ec3d57a0..2fd71d2ed84e 100644 --- a/test/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py +++ b/test/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py @@ -9,7 +9,7 @@ def _mp_fn(index): - device = xm.xla_device() + device = torch_xla.device() if xm.xla_device_hw(device) in ('TPU', 'CUDA', 'NEURON'): world_size = xr.world_size() dist.init_process_group('xla', init_method='xla://') diff --git a/test/torch_distributed/test_torch_distributed_bucketed_all_reduce_xla_backend.py b/test/torch_distributed/test_torch_distributed_bucketed_all_reduce_xla_backend.py index 3d5736b0ec43..c462f7552800 100644 --- a/test/torch_distributed/test_torch_distributed_bucketed_all_reduce_xla_backend.py +++ b/test/torch_distributed/test_torch_distributed_bucketed_all_reduce_xla_backend.py @@ -9,7 +9,7 @@ def _mp_fn(index): - device = xm.xla_device() + device = torch_xla.device() if xm.xla_device_hw(device) in ('TPU', 'CUDA', 'NEURON'): world_size = xr.world_size() rank = xr.global_ordinal() diff --git a/test/torch_distributed/test_torch_distributed_fsdp_meta.py b/test/torch_distributed/test_torch_distributed_fsdp_meta.py index 182a7818ecbb..08db424608d9 100644 --- a/test/torch_distributed/test_torch_distributed_fsdp_meta.py +++ b/test/torch_distributed/test_torch_distributed_fsdp_meta.py @@ -60,7 +60,7 @@ def _init_with_reset_params(module): """ is_meta = any(t.is_meta for t in module.parameters()) if is_meta: - module.to_empty(device=xm.xla_device()) + module.to_empty(device="xla") with torch.no_grad(): module.reset_parameters() @@ -87,7 +87,7 @@ def _compare_fsdp(self, fsdp1, fsdp2): def _test_simple_model_with_meta_device(self, meta_module_fn, init_fn=None): # Create model on meta device and wrap with FSDP. 
model = meta_module_fn() - inp = torch.randn(10, 2, device=xm.xla_device()) + inp = torch.randn(10, 2, device="xla") fsdp_meta = XlaFullyShardedDataParallel( model, @@ -99,7 +99,7 @@ def _test_simple_model_with_meta_device(self, meta_module_fn, init_fn=None): meta_opt.step() torch_xla.sync() - regular = MyModel(device=xm.xla_device()) + regular = MyModel(device="xla") fsdp_regular = XlaFullyShardedDataParallel( regular, auto_wrap_policy=always_wrap) regular_opt = torch.optim.SGD(fsdp_regular.parameters(), lr=1e-3) @@ -127,7 +127,7 @@ def meta_module_fn(): def test_simple_model_with_torchdistX_init_fn(self): def meta_module_fn(): - return deferred_init.deferred_init(MyModel, device=xm.xla_device()) + return deferred_init.deferred_init(MyModel, device="xla") self._test_simple_model_with_meta_device( meta_module_fn, init_fn=_init_with_torchdistX) @@ -135,13 +135,13 @@ def meta_module_fn(): def test_simple_model_with_default_torchdistX(self): def meta_module_fn(): - return deferred_init.deferred_init(MyModel, device=xm.xla_device()) + return deferred_init.deferred_init(MyModel, device="xla") self._test_simple_model_with_meta_device(meta_module_fn) def _mp_fn(index): - device = xm.xla_device() + device = torch_xla.device() # This test fails on GPU with 03/30 TF-pin update (https://github.com/pytorch/xla/pull/4840) if xm.xla_device_hw(device) in ('TPU', 'NEURON'): dist.init_process_group('xla', init_method='xla://') diff --git a/test/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py b/test/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py index 8ca45141350e..9089f9d799ff 100644 --- a/test/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py +++ b/test/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py @@ -9,7 +9,7 @@ def _mp_fn(index): - device = xm.xla_device() + device = torch_xla.device() if xm.xla_device_hw(device) in ('TPU', 'CUDA', 'NEURON'): world_size = xr.world_size() rank = xr.global_ordinal() diff --git a/test/torch_distributed/test_torch_distributed_reduce_scatter_xla_backend.py b/test/torch_distributed/test_torch_distributed_reduce_scatter_xla_backend.py index 36e6420dce10..006d3fd33a95 100644 --- a/test/torch_distributed/test_torch_distributed_reduce_scatter_xla_backend.py +++ b/test/torch_distributed/test_torch_distributed_reduce_scatter_xla_backend.py @@ -9,7 +9,7 @@ def _mp_fn(index): - device = xm.xla_device() + device = torch_xla.device() if xm.xla_device_hw(device) in ('TPU', 'CUDA'): world_size = xr.world_size() rank = xr.global_ordinal() diff --git a/test/utils/train_spmd_linear_model.py b/test/utils/train_spmd_linear_model.py index 53ca0c6cc6dd..e2bcb6124f87 100644 --- a/test/utils/train_spmd_linear_model.py +++ b/test/utils/train_spmd_linear_model.py @@ -69,7 +69,7 @@ def forward(self, x): def train(): - device = xm.xla_device() + device = torch_xla.device() torch.manual_seed(42) model = SimpleLinear().to(device) print('===> Preparing data..') @@ -148,5 +148,5 @@ def train_and_evaluate(): xr.use_spmd(auto=FLAGS.auto_spmd) print('Start training loop...') losses, m = train() - t = torch.randn(10, FLAGS.input_dim).to(xm.xla_device()) + t = torch.randn(10, FLAGS.input_dim).to(torch_xla.device()) return [loss.cpu() for loss in losses], m(t).cpu() diff --git a/test/utils/train_spmd_linear_model_grad_acc.py b/test/utils/train_spmd_linear_model_grad_acc.py index b3c107770ae8..294309d62ed6 100644 --- a/test/utils/train_spmd_linear_model_grad_acc.py +++ 
b/test/utils/train_spmd_linear_model_grad_acc.py @@ -77,7 +77,7 @@ def forward(self, x): def train(): - device = xm.xla_device() + device = torch_xla.device() num_devices = xr.global_runtime_device_count() print(f'num_devices: {num_devices}') # Define a mesh with all devices along one axis @@ -182,6 +182,6 @@ def train_and_evaluate_grad_acc(): xr.use_spmd(auto=FLAGS.auto_spmd) print('Start training loop...') losses, m = train() - t = torch.randn(10, FLAGS.input_dim).to(xm.xla_device()) + t = torch.randn(10, FLAGS.input_dim).to(torch_xla.device()) m(t).cpu() return [loss.cpu() for loss in losses] diff --git a/torch_xla/_dynamo/dynamo_bridge.py b/torch_xla/_dynamo/dynamo_bridge.py index ac7d9d906ff9..ce11cc07b2cf 100644 --- a/torch_xla/_dynamo/dynamo_bridge.py +++ b/torch_xla/_dynamo/dynamo_bridge.py @@ -495,7 +495,7 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule, # 2. All of the pending IRs are result of our warm up cache tracing and they # should be removed to avoid extra computation executed and in place updates op # mistakenlly update the input tensors. - torch_xla._XLAC._clear_pending_irs(str(xm.xla_device())) + torch_xla._XLAC._clear_pending_irs(str(torch_xla.device())) vars_to_return = (xla_args_sharding_spec, args_and_out, graph_hash, arg_index_to_need_update_index, none_remover, @@ -564,7 +564,7 @@ def optimized_mod(*args: tuple): is_cuda_args = original_device.type == "cuda" if is_cuda_args: - args = _maybe_move_tensors_to_device(args, xm.xla_device()) + args = _maybe_move_tensors_to_device(args, torch_xla.device()) if not config.skip_input_data_check: # `torch_xla.sync()` needs to be blocking since we want to access args's @@ -761,7 +761,7 @@ def partition_fx_graph_for_cpu_fallback(xla_model, xla_args, all_xla_args, # UnsupportedNodesCollector might trigger in place ops, need to clear them here. _clear_pending_irs_on_args(all_xla_args_tensor_only, cloned_args) - torch_xla._XLAC._clear_pending_irs(str(xm.xla_device())) + torch_xla._XLAC._clear_pending_irs(str(torch_xla.device())) class XlaOperatorSupport(torch.fx.passes.operator_support.OperatorSupport): @@ -805,7 +805,8 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: def extract_compiled_graph_helper(xla_model: torch.fx.GraphModule, xla_args): if _args_on_cuda(xla_args): - xla_args = tuple(_maybe_move_tensors_to_device(xla_args, xm.xla_device())) + xla_args = tuple( + _maybe_move_tensors_to_device(xla_args, torch_xla.device())) # Synchronize xla_args, so that each FunctionalTensorWrapper argument updates its # value reference before actually computing it. 
diff --git a/torch_xla/_internal/pjrt.py b/torch_xla/_internal/pjrt.py index 6c1c4c26a392..25a0ee36c36e 100644 --- a/torch_xla/_internal/pjrt.py +++ b/torch_xla/_internal/pjrt.py @@ -104,7 +104,7 @@ def initialize_singleprocess(): plugins.default().configure_single_process() elif runtime.device_type() == 'TPU': tpu.configure_one_chip_topology() - xm.set_replication(xm.xla_device(), []) + xm.set_replication(torch_xla.device(), []) def initialize_multiprocess(local_rank: int, local_world_size: int): @@ -119,7 +119,7 @@ def initialize_multiprocess(local_rank: int, local_world_size: int): neuron.initialize_env(local_rank, local_world_size) devices = xm.get_xla_supported_devices() - xm.set_replication(xm.xla_device(), devices) + xm.set_replication(torch_xla.device(), devices) def run_multiprocess(fn: Callable[..., R], diff --git a/torch_xla/_internal/tpu.py b/torch_xla/_internal/tpu.py index f7118dd7b3f3..2f17875be454 100644 --- a/torch_xla/_internal/tpu.py +++ b/torch_xla/_internal/tpu.py @@ -312,7 +312,7 @@ def discover_master_worker_ip(use_localhost: bool = True) -> str: if xr.is_spmd(): return _spmd_find_master_ip(worker_ips[current_worker_id]) - t = torch.tensor([current_worker_id], device=xm.xla_device()) + t = torch.tensor([current_worker_id], device="xla") xm.collective_broadcast([t]) torch_xla.sync() diff --git a/torch_xla/amp/syncfree/adam.py b/torch_xla/amp/syncfree/adam.py index 4201933ca590..cc37b7013f57 100644 --- a/torch_xla/amp/syncfree/adam.py +++ b/torch_xla/amp/syncfree/adam.py @@ -94,7 +94,7 @@ def step(self, closure=None, found_inf: Tensor = None): p, memory_format=torch.preserve_format) else: state['max_exp_avg_sq'] = torch.empty( - 0, dtype=torch.float, device=xm.xla_device()) + 0, dtype=torch.float, device="xla") exp_avgs.append(state['exp_avg']) exp_avg_sqs.append(state['exp_avg_sq']) diff --git a/torch_xla/amp/syncfree/adamw.py b/torch_xla/amp/syncfree/adamw.py index 83e11d46fad9..237345f52780 100644 --- a/torch_xla/amp/syncfree/adamw.py +++ b/torch_xla/amp/syncfree/adamw.py @@ -92,7 +92,7 @@ def step(self, closure=None, found_inf: Tensor = None): p, memory_format=torch.preserve_format) else: state['max_exp_avg_sq'] = torch.empty( - 0, dtype=torch.float, device=xm.xla_device()) + 0, dtype=torch.float, device="xla") exp_avgs.append(state['exp_avg']) exp_avg_sqs.append(state['exp_avg_sq']) diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index 536b9c4115b6..6dcc5afe7bc7 100644 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -135,6 +135,7 @@ def master_print(*args: Any, print(*args, file=fd, flush=flush) +@deprecated("Use torch_xla.device instead") def xla_device(n: Optional[int] = None, devkind: Optional[str] = None) -> torch.device: """Returns a given instance of an XLA device. @@ -142,21 +143,15 @@ def xla_device(n: Optional[int] = None, Args: n (int, optional): The specific instance (ordinal) to be returned. If specified, the specific XLA device instance will be returned. Otherwise - the first device of `devkind` will be returned. + the first device (default 0) will be returned. devkind (string..., optional): If specified, device type such as `TPU`, `CUDA`, `CPU`, or custom PJRT device. Deprecated. Returns: - A `torch.device` with the requested instance. + A `torch.device` with the requested instance of an XLA device. 
""" - # When SPMD is enabled, we always return `xla:0` to the user, and - # under the hood we use virtual device logic for every xla tensor - if xu.check_env_flag('XLA_USE_SPMD'): - device = 'xla:0' - torch_xla._XLAC._xla_set_default_device(device) - return torch.device(device) - - return runtime.xla_device(n, devkind) + del devkind + return torch_xla.device(n) def _xla_real_device(device: torch.device) -> Any: diff --git a/torch_xla/core/xla_op_registry.py b/torch_xla/core/xla_op_registry.py index aba1c7076c39..62943f4c70c5 100644 --- a/torch_xla/core/xla_op_registry.py +++ b/torch_xla/core/xla_op_registry.py @@ -68,7 +68,7 @@ def slice_and_add(a, b, dimno=0): SLICE_AND_ADD = xor.register('slice_and_add', slice_and_add) def user_computation_test(): - device = xm.xla_device() + device = torch_xla.device() x = torch.randn(2, 2).to(device) y = torch.randn(2, 2).to(device) z = SLICE_AND_ADD(x, y, dimno=0) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index edf96b4538a3..deff0c9b0afc 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1232,7 +1232,7 @@ void BuildLoweringContextSubmodule(py::module* m) { * import torch_xla * import torch_xla.core.xla_model as xm * - * device = xm.xla_device() + * device = torch_xla.device() * example = torch.tensor([1.0, 2.0, 3.0, 4.0], device=device) * * def network(x): diff --git a/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py b/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py index af2b1246baf3..16b629728649 100644 --- a/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py +++ b/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py @@ -139,7 +139,7 @@ class XlaFullyShardedDataParallel(nn.Module): module (nn.Module): module to be wrapped with FSDP. If the input module's parameters and buffers are not already on XLA device, they will be cast to - ``xm.xla_device()`` (after sharding) during FSDP initialization. + ``torch_xla.device()`` (after sharding) during FSDP initialization. reshard_after_forward (bool, Optional): if ``True``, reshard parameters after the forward pass. This saves memory but slows training. This is only relevant when resharding @@ -527,7 +527,7 @@ def __init__( List[Parameter], self._fsdp_wrapped_module.flat_params) + non_flatten_params - self.xla_device = xm.xla_device() + self.xla_device = torch_xla.device() # Shard module parameters in place self._shard_parameters_(params_to_shard) # Cast the module buffers to the specified buffer_dtype @@ -1014,7 +1014,7 @@ def _dummy_forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: A dummy forward pass with minimal computation that sums all inputs and full parameters, e.g. to debug parameter memory consumption. 
""" - outputs = torch.zeros(1, device=xm.xla_device()) + outputs = torch.zeros(1, device="xla") for t in chain(args, kwargs.values(), self.full_params): if isinstance(t, torch.Tensor) and t.dtype == torch.float32: outputs = outputs + t.mean() @@ -1646,7 +1646,7 @@ def _print_r0(self, msg: str, restart: bool = False) -> None: if restart: self._tstart = time.time() if self.rank == 0: - memory_info = xm.get_memory_info(xm.xla_device()) + memory_info = xm.get_memory_info(torch_xla.device()) gb_free = memory_info["kb_free"] / 1024 / 1024 gb_total = memory_info["kb_total"] / 1024 / 1024 logging.info( diff --git a/torch_xla/distributed/spmd/api.py b/torch_xla/distributed/spmd/api.py index 3c6dcff14e05..77ff9e9ac6ee 100644 --- a/torch_xla/distributed/spmd/api.py +++ b/torch_xla/distributed/spmd/api.py @@ -215,10 +215,10 @@ def xla_distribute_module( if partition_fn: if getattr(partition_fn, '__name__', 'unknown') == "auto_policy": # TODO(yeounoh) allow pre-loading to xla device in the future. - assert next(module.parameters()).device != xm.xla_device(), \ + assert next(module.parameters()).device != torch_xla.device(), \ f"Currently requires module to be on cpu, before xla_distribute_module." xr.use_spmd(auto=True) - module = module.to(xm.xla_device()) + module = module.to(torch_xla.device()) else: # apply partition_fun to submodules for name, submod in module.named_modules(): diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py index cb61158df903..b7089ffc498e 100644 --- a/torch_xla/distributed/spmd/xla_sharding.py +++ b/torch_xla/distributed/spmd/xla_sharding.py @@ -334,7 +334,7 @@ def _get_physical_tpu_mesh(self, devices: np.ndarray) -> np.ndarray: A np.ndarray of device logical ordinals with shape [global_x, global_y, global_z]. On v2 and v3, global_z is instead cores_per_chip (i.e., 2). """ - assert xm.xla_device_hw(xm.xla_device()) == 'TPU' + assert xm.xla_device_hw(torch_xla.device()) == 'TPU' # coords is a 3-dims tuple representing the device in physical mesh device_coords = [self.device_attributes[d]['coords'] for d in devices] dims = tuple(d + 1 for d in max(device_coords)) @@ -595,9 +595,9 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh, >>> num_devices = xr.global_runtime_device_count() >>> device_ids = np.array(range(num_devices)) >>> mesh = Mesh(device_ids, mesh_shape, ('x', 'y')) - >>> input = torch.randn(8, 32).to(xm.xla_device()) + >>> input = torch.randn(8, 32).to(torch_xla.device()) >>> xs.mark_sharding(input, mesh, (0, None)) # 4-way data parallel - >>> linear = nn.Linear(32, 10).to(xm.xla_device()) + >>> linear = nn.Linear(32, 10).to(torch_xla.device()) >>> xs.mark_sharding(linear.weight, mesh, (None, 1)) # 2-way model parallel """ # We only allow fully specified `partition_spec` to be applicable, as opposed @@ -793,7 +793,7 @@ def can_apply(self, t: torch.Tensor) -> bool: def apply(self, t: torch.Tensor): # TODO(yeounoh) use virtual device interface when available. 
- assert (t.device == xm.xla_device()) + assert (t.device == torch_xla.device()) mark_sharding(t, self.mesh, self.partition_spec) diff --git a/torch_xla/distributed/xla_multiprocessing.py b/torch_xla/distributed/xla_multiprocessing.py index d699abaebafb..e3b349a4b7fb 100644 --- a/torch_xla/distributed/xla_multiprocessing.py +++ b/torch_xla/distributed/xla_multiprocessing.py @@ -56,7 +56,7 @@ class MpModelWrapper(object): WRAPPED_MODEL = xmp.MpModelWrapper(MyNetwork()) def _mp_fn(index, ...): - device = xm.xla_device() + device = torch_xla.device() model = WRAPPED_MODEL.to(device) ... diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index ca54521c8bc0..ea4c8d54c1a2 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -16,7 +16,7 @@ def fori_loop(lower, upper, body_fun, *input_value): - device = xm.xla_device() + device = torch_xla.device() if (upper < lower): print("ERROR: upper should be a larger number than lower") iteri = upper - lower diff --git a/torch_xla/experimental/scan.py b/torch_xla/experimental/scan.py index 894ed1baa92d..df7b7fc05515 100644 --- a/torch_xla/experimental/scan.py +++ b/torch_xla/experimental/scan.py @@ -120,7 +120,7 @@ def scan(fn, init, xs): Example: >>> # Example of using `scan` to implement `torch.cumsum`. - >>> import torch_xla.runtime + >>> import torch_xla.core.xla_model as xm >>> import torch >>> from torch_xla.experimental.scan import scan >>> @@ -129,7 +129,7 @@ def scan(fn, init, xs): >>> y = new_carry >>> return new_carry, y >>> - >>> with torch_xla.runtime.xla_device(): + >>> with torch_xla.device(): >>> init = torch.tensor([0.0, 0.0], requires_grad=True) >>> xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], >>> requires_grad=True) @@ -727,7 +727,7 @@ def defeat_device_data(v: torch.Tensor) -> torch.Tensor: seed_tensor = hoisted_vars[seed_parameter_id] assert seed_tensor.dtype == torch.int64 hoisted_vars[seed_parameter_id] = torch.randint( - 0, 2**62, (num_iters,), dtype=torch.int64, device=torch_xla.device()) + 0, 2**62, (num_iters,), dtype=torch.int64, device="xla") # Add hoisted variables as While computation params as well, # including the potentially updated seed tensor. diff --git a/torch_xla/experimental/scan_layers.py b/torch_xla/experimental/scan_layers.py index 140a95312c03..0be37b363909 100644 --- a/torch_xla/experimental/scan_layers.py +++ b/torch_xla/experimental/scan_layers.py @@ -47,11 +47,11 @@ def scan_layers(layers: Iterable[torch.nn.Module], Example: - >>> import torch_xla.runtime + >>> import torch_xla.core.xla_model as xm >>> import torch >>> import torch.nn as nn >>> from torch_xla.experimental.scan_layers import scan_layers - >>> with torch_xla.runtime.xla_device(): + >>> with torch_xla.device(): >>> layers = [nn.Linear(16, 16) for i in range(10)] >>> input = torch.randn(16) >>> output = scan_layers(layers, input) diff --git a/torch_xla/experimental/spmd_fully_sharded_data_parallel.py b/torch_xla/experimental/spmd_fully_sharded_data_parallel.py index 0b198785d76b..e0e7f5ba6027 100644 --- a/torch_xla/experimental/spmd_fully_sharded_data_parallel.py +++ b/torch_xla/experimental/spmd_fully_sharded_data_parallel.py @@ -109,7 +109,7 @@ def __init__( # Let's move the module to xla device in case it's not moved # by the caller already. - self._orig_module = module.to(xm.xla_device()) + self._orig_module = module.to(torch_xla.device()) self._mesh = mesh # Only handle params which are not already sharded. 
This enables diff --git a/torch_xla/runtime.py b/torch_xla/runtime.py index b1285c268e82..014994bc3c79 100644 --- a/torch_xla/runtime.py +++ b/torch_xla/runtime.py @@ -97,36 +97,12 @@ def is_bf16_supported(): """Returns whether torch.bfloat16 is supported on this environment. """ try: - torch.tensor([1.], dtype=torch.bfloat16, device=xm.xla_device()) + torch.tensor([1.], dtype=torch.bfloat16, device="xla") return True except Exception as e: return False -def xla_device(n: Optional[int] = None, - devkind: Optional[str] = None) -> torch.device: - """Returns an XLA device. - - Args: - n: Index of XLA device within visibible devices. If not set, use local - ordinal (default 0) to select an addressable device. - devkind: Type of device to return. Should match `device_type()`. - - Returns: - A `torch.device` representing an XLA device. - """ - if n is None: - return torch.device(torch_xla._XLAC._xla_get_default_device()) - - devices = xm.get_xla_supported_devices(devkind=devkind) - if n > len(devices): - raise IndexError('Device index {} out of range in {}'.format(n, devices)) - - device = devices[n] - torch_xla._XLAC._xla_set_default_device(device) - return torch.device(device) - - def local_process_count() -> int: """Returns the number of processes running on this host.""" return xu.getenv_as(xenv.PJRT_LOCAL_PROCESS_COUNT, int, defval=1) @@ -180,7 +156,7 @@ def local_ordinal() -> int: Local ordinal is in range [0, local_device_count).""" local_rank = xu.getenv_as(xenv.PJRT_LOCAL_PROCESS_RANK, int, 0) devices_per_process = addressable_device_count() - return local_rank * devices_per_process + xla_device().index + return local_rank * devices_per_process + torch_xla.device().index def process_index() -> int: diff --git a/torch_xla/stablehlo.py b/torch_xla/stablehlo.py index 6b3e25584b4c..b88a8131b2d8 100644 --- a/torch_xla/stablehlo.py +++ b/torch_xla/stablehlo.py @@ -341,7 +341,7 @@ def _exported_program_to_stablehlo_bundle(exported_model, assert len(kwargs) == 0, "Export to stablehlo doesnt support kwargs yet." - device = xm.xla_device() + device = torch_xla.device() _flat_input_args = exported_model._graph_module_flat_inputs(args, {}) _flat_input_args = pytree.tree_map_only(torch.Tensor, @@ -352,7 +352,7 @@ def _exported_program_to_stablehlo_bundle(exported_model, torch_xla.sync() xm.wait_device_ops() metrics.clear_counters() - device = xm.xla_device() + device = torch_xla.device() # Run the fx graph tracing using lazy tensor if options.inline_all_constant: diff --git a/torch_xla/torch_xla.py b/torch_xla/torch_xla.py index e4486a8cd0b5..3b2b327ff5c9 100644 --- a/torch_xla/torch_xla.py +++ b/torch_xla/torch_xla.py @@ -19,18 +19,33 @@ def device(index: int = None) -> torch.device: """Returns a given instance of an XLA device. - If SPMD enables, returns a virtual device that wraps all devices available + If SPMD is enabled, returns a virtual device that wraps all devices available to this process. Args: index: index of the XLA device to be returned. Corresponds to index in - `torch_xla.devices()`. + `torch_xla.devices()`. By default, get the first device. Returns: An XLA `torch.device`. 
""" - - return xm.xla_device(index) + # When SPMD is enabled, we always return `xla:0` to the user, and + # under the hood we use virtual device logic for every xla tensor + if xu.check_env_flag('XLA_USE_SPMD'): + device = 'xla:0' + torch_xla._XLAC._xla_set_default_device(device) + return torch.device(device) + + if n is None: + return torch.device(torch_xla._XLAC._xla_get_default_device()) + + devices = xm.get_xla_supported_devices() + if n > len(devices): + raise IndexError('Device index {} out of range in {}'.format(n, devices)) + + device = devices[n] + torch_xla._XLAC._xla_set_default_device(device) + return torch.device(device) def devices() -> List[torch.device]: