
Commit 0342e14

Implement autocast for torchax
1 parent 55a7540

File tree: 5 files changed, +264 -6 lines

torchax/test/test_autocast.py

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
import unittest
import jax
import jax.numpy as jnp
import torchax
from torchax import interop
import torch


class AutocastTest(unittest.TestCase):

  def setUp(self):
    self.env = torchax.default_env()

  def test_auto_cast_ir(self):
    with self.env:
      with torch.autocast('jax', dtype=torch.bfloat16):
        a = jax.ShapeDtypeStruct((2,2), jnp.float32)
        b = jax.ShapeDtypeStruct((2,2), jnp.float32)
        ir_text = jax.jit(interop.jax_view(torch.matmul)).lower(a, b).as_text()
      self.assertIn('tensor<2x2xbf16>', ir_text)

  def test_auto_cast_matmul(self):
    with self.env:
      a = torch.randn(2, 2, device='jax')
      b = torch.randn(2, 2, device='jax')
      with torch.autocast('jax', dtype=torch.bfloat16):
        c = a @ b

      self.assertEqual(c.dtype, torch.bfloat16)

      with torch.autocast('cpu', dtype=torch.bfloat16):
        c_cpu = a.cpu() @ b.cpu()

      self.assertTrue(torch.allclose(c.cpu(), c_cpu))


if __name__ == '__main__':
  unittest.main()
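The tests above exercise autocast through torch.matmul. As a point of reference, here is a minimal usage sketch of the same mechanism through torch.nn.functional.linear; it is not part of the commit and assumes the same environment setup the tests use.

import torch
import torchax

env = torchax.default_env()

with env:
  x = torch.randn(4, 8, device='jax')   # float32 activations
  w = torch.randn(16, 8, device='jax')  # float32 weight
  b = torch.randn(16, device='jax')     # float32 bias

  with torch.autocast('jax', dtype=torch.bfloat16):
    # aten.linear is in the lower_precision_fp op list registered below, so
    # all floating-point inputs are cast to bfloat16 before the op runs.
    y = torch.nn.functional.linear(x, w, b)

  print(y.dtype)  # expected: torch.bfloat16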

torchax/torchax/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -89,7 +89,7 @@ def disable_temporarily():
 import jax
 import torchax.device_module
 
-torch._register_device_module('jax', torchax.device_module)
+torch._register_device_module('privateuseone', torchax.device_module)
 
 
 def enable_accuracy_mode():
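The device module is now registered under 'privateuseone', the canonical PrivateUse1 backend name, which is the key the new autocast kernels use when they look up the active autocast dtype. A small sketch of that round trip, reusing the test setup above (hypothetical usage, not part of the commit):

import torch
import torchax  # registers torchax.device_module under 'privateuseone'

with torchax.default_env():
  with torch.autocast('jax', dtype=torch.bfloat16):
    # The AutocastPrivateUse1 kernels added in this commit read the target
    # dtype back through the same key:
    print(torch.get_autocast_dtype('privateuseone'))  # expected: torch.bfloat16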

torchax/torchax/device_module.py

Lines changed: 5 additions & 0 deletions
@@ -24,3 +24,8 @@ def is_available():
 
 def current_device():
   return 0
+
+
+import torch
+def get_amp_supported_dtype():
+  return [torch.float16, torch.bfloat16]
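get_amp_supported_dtype is the hook that torch.autocast consults on a custom (PrivateUse1) backend's device module to validate the requested dtype; that upstream behavior is an assumption stated here for context, not something this diff changes. A trivial check against the values added above:

import torch
import torchax

# The hook simply advertises which dtypes autocast may select for this backend.
supported = torchax.device_module.get_amp_supported_dtype()
assert torch.float16 in supported and torch.bfloat16 in supported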
torchax/torchax/ops/autocast_policy.py (new file; path inferred from the autocast_policy import added to tensor.py below)

Lines changed: 210 additions & 0 deletions
@@ -0,0 +1,210 @@
import torch
import torch._C
from torch.utils import _pytree as pytree

def call_with_next_key(op, args, kwargs):
  return op(*args, **kwargs)

target_precision = torch.bfloat16

def lower_precision_fp(op):
  def inner(*args, **kwargs):
    target_precision = torch.get_autocast_dtype('privateuseone')
    autocast_keyset = torch._C.DispatchKeySet(torch._C.DispatchKey.AutocastPrivateUse1)
    with torch._C._ExcludeDispatchKeyGuard(autocast_keyset):
      is_float_tensor = lambda a: isinstance(a, torch.Tensor) and a.is_floating_point()
      args, kwargs = pytree.tree_map_only(
        is_float_tensor,
        lambda x: x.to(target_precision),
        (args, kwargs))
      return op(*args, **kwargs)
  return inner


lib = torch.library.Library('aten', 'FRAGMENT')
my_lib = torch.library.Library('_', 'IMPL', 'AutocastPrivateUse1')
my_lib.fallback(torch.library.fallthrough_kernel)


for op in [torch.ops.aten.conv1d.default,
           torch.ops.aten.conv1d.padding,
           torch.ops.aten.conv2d.default,
           torch.ops.aten.conv2d.padding,
           torch.ops.aten.conv3d.default,
           torch.ops.aten.bmm.default,
           torch.ops.aten.mm.default,
           torch.ops.aten.baddbmm.default,
           torch.ops.aten.addmm.default,
           torch.ops.aten.addbmm.default,
           torch.ops.aten.linear.default,
           torch.ops.aten.matmul.default,
           torch.ops.aten.conv_tbc.default,
           torch.ops.aten.conv_transpose1d.default,
           torch.ops.aten.conv_transpose2d.input,
           torch.ops.aten.conv_transpose3d.input,
           torch.ops.aten.prelu.default,
           torch.ops.aten.relu.default,
           torch.ops.aten.max_pool2d.default,
           torch.ops.aten.einsum.default,
          ]:
  lib.impl(op.name(), lower_precision_fp(op), "AutocastPrivateUse1", with_keyset=False)

# https://github.com/pytorch/xla/blob/20899c7258680a36cd3bec1c820e8a52c16a4bbf/torch_xla/csrc/autocast_mode.cpp#L29
# enum class CastPolicy : uint8_t {
#   lower_precision_fp = 0,  // Cast all inputs to lower_precision_fp before
#                            // running the op. Currently, lower_precision_fp is
#                            // fp16 for AutocastCUDA, and is defined by user
#                            // (default bf16) for AutocastCPU or other device.
#   fp32,  // Cast all inputs to at::kFloat before running the op.
#   fp32_set_opt_dtype,  // Treats functions (like softmax) that
#                        //   1. we'd like to run in fp32 and
#                        //   2. have a std::optional<ScalarType> arg that controls
#                        //      the output type.
#                        // fp32_set_opt_dtype wrappers' policy is: if the output
#                        // type is already set, don't touch it, otherwise, set
#                        // it to at::kFloat.
#   fp32_append_dtype,  // Treats functions (like norm) that
#                       //   1. we'd like to run in fp32 and
#                       //   2. have some overloads that accept an output type and
#                       //      other overloads that don't.
#                       // fp32_append_dtype wrappers wrap the overloads that don't
#                       // have an output dtype.
#                       // The wrapper policy is: append at::kFloat to the args,
#                       // and redispatch to the type-aware overload.
#   promote,  // Run in the widest dtype among several args.
# };
# TORCH_LIBRARY_IMPL(aten, AutocastXLA, m) {
#   // lower_precision_fp cast policy
#   KERNEL_XLA(conv1d, lower_precision_fp)
#   KERNEL_XLA2(conv1d, padding, lower_precision_fp)
#   KERNEL_XLA(conv2d, lower_precision_fp)
#   KERNEL_XLA2(conv2d, padding, lower_precision_fp)
#   KERNEL_XLA(conv3d, lower_precision_fp)
#   KERNEL_XLA2(conv3d, padding, lower_precision_fp)
#   KERNEL_XLA(bmm, lower_precision_fp)
#   KERNEL_XLA(mm, lower_precision_fp)
#   KERNEL_XLA(baddbmm, lower_precision_fp)
#   KERNEL_XLA(addmm, lower_precision_fp)
#   KERNEL_XLA(addbmm, lower_precision_fp)
#   KERNEL_XLA(linear, lower_precision_fp)
#   KERNEL_XLA(matmul, lower_precision_fp)
#   KERNEL_XLA(conv_tbc, lower_precision_fp)
#   KERNEL_XLA(conv_transpose1d, lower_precision_fp)
#   KERNEL_XLA2(conv_transpose2d, input, lower_precision_fp)
#   KERNEL_XLA2(conv_transpose3d, input, lower_precision_fp)
#   KERNEL_XLA(prelu, lower_precision_fp)
#   KERNEL_XLA(relu, lower_precision_fp)
#   KERNEL_XLA(max_pool2d, lower_precision_fp)
#   KERNEL_XLA(einsum, lower_precision_fp)
#   // Disable `scaled_dot_product_attention` for now since it causes
#   // undefined symbol with official torch whl.
#   // KERNEL_XLA(scaled_dot_product_attention, lower_precision_fp)

#   // fp32 cast policy
#   // Commented out ops are included in the AutoCastCPU Policy,
#   // but not lowered. Enable if op is lowered.
#   KERNEL_XLA(batch_norm, fp32)
#   KERNEL_XLA(_softmax, fp32)
#   KERNEL_XLA2(softmax, int, fp32)
#   KERNEL_XLA2(softmax, Dimname, fp32)
#   KERNEL_XLA2(log_softmax, int, fp32)
#   KERNEL_XLA2(log_softmax, Dimname, fp32)
#   KERNEL_XLA(binary_cross_entropy, fp32)
#   // KERNEL_XLA(grid_sampler, fp32)
#   // KERNEL_XLA(polar, fp32)
#   KERNEL_XLA2(pow, Tensor_Scalar, fp32)
#   KERNEL_XLA(prod, fp32)
#   KERNEL_XLA2(prod, dim_int, fp32)
#   KERNEL_XLA2(prod, dim_Dimname, fp32)
#   // KERNEL_XLA(quantile, fp32)
#   // KERNEL_XLA2(quantile, scalar, fp32)
#   // KERNEL_XLA(nanquantile, fp32)
#   // KERNEL_XLA2(nanquantile, scalar, fp32)
#   // KERNEL_XLA(stft, fp32)
#   // KERNEL_XLA2(stft, center, fp32)
#   KERNEL_XLA(cdist, fp32)
#   // KERNEL_XLA(grid_sampler_2d, fp32)
#   // KERNEL_XLA(grid_sampler_3d, fp32)
#   KERNEL_XLA(trace, fp32)
#   // KERNEL_XLA(view_as_complex, fp32)
#   KERNEL_XLA(cholesky, fp32)
#   KERNEL_XLA(cholesky_inverse, fp32)
#   KERNEL_XLA(cholesky_solve, fp32)
#   KERNEL_XLA(inverse, fp32)
#   // KERNEL_XLA(lu_solve, fp32)
#   // KERNEL_XLA(orgqr, fp32)
#   // KERNEL_XLA(ormqr, fp32)
#   // KERNEL_XLA(pinverse, fp32)
#   KERNEL_XLA(reflection_pad1d, fp32)
#   KERNEL_XLA(reflection_pad2d, fp32)
#   KERNEL_XLA(replication_pad1d, fp32)
#   KERNEL_XLA(replication_pad2d, fp32)
#   KERNEL_XLA(replication_pad3d, fp32)
#   KERNEL_XLA(mse_loss, fp32)
#   KERNEL_XLA(cosine_embedding_loss, fp32)
#   KERNEL_XLA(nll_loss, fp32)
#   KERNEL_XLA(nll_loss2d, fp32)
#   KERNEL_XLA(hinge_embedding_loss, fp32)
#   // KERNEL_XLA(poisson_nll_loss, fp32)
#   KERNEL_XLA(smooth_l1_loss, fp32)
#   KERNEL_XLA(cross_entropy_loss, fp32)
#   KERNEL_XLA(l1_loss, fp32)
#   // KERNEL_XLA(huber_loss, fp32)
#   KERNEL_XLA(margin_ranking_loss, fp32)
#   KERNEL_XLA(soft_margin_loss, fp32)
#   KERNEL_XLA(triplet_margin_loss, fp32)
#   KERNEL_XLA(multi_margin_loss, fp32)
#   KERNEL_XLA2(ctc_loss, IntList, fp32)
#   KERNEL_XLA2(ctc_loss, Tensor, fp32)
#   KERNEL_XLA(kl_div, fp32)
#   KERNEL_XLA(multilabel_margin_loss, fp32)
#   KERNEL_XLA(binary_cross_entropy_with_logits, fp32)
#   // KERNEL_XLA(fft_fft, fp32)
#   // KERNEL_XLA(fft_ifft, fp32)
#   // KERNEL_XLA(fft_fft2, fp32)
#   // KERNEL_XLA(fft_ifft2, fp32)
#   // KERNEL_XLA(fft_fftn, fp32)
#   // KERNEL_XLA(fft_ifftn, fp32)
#   // KERNEL_XLA(fft_rfft, fp32)
#   // KERNEL_XLA(fft_irfft, fp32)
#   // KERNEL_XLA(fft_rfft2, fp32)
#   // KERNEL_XLA(fft_irfft2, fp32)
#   // KERNEL_XLA(fft_rfftn, fp32)
#   // KERNEL_XLA(fft_irfftn, fp32)
#   // KERNEL_XLA(fft_hfft, fp32)
#   // KERNEL_XLA(fft_ihfft, fp32)
#   // KERNEL_XLA(linalg_cond, fp32)
#   // KERNEL_XLA2(linalg_cond, p_str, fp32)
#   // KERNEL_XLA(linalg_matrix_rank, fp32)
#   // KERNEL_XLA2(linalg_matrix_rank, tol_tensor, fp32)
#   // KERNEL_XLA2(linalg_matrix_rank, atol_rtol_tensor, fp32)
#   // KERNEL_XLA2(linalg_matrix_rank, atol_rtol_float, fp32)
#   // KERNEL_XLA(linalg_solve, fp32)
#   // KERNEL_XLA(linalg_cholesky, fp32)
#   // KERNEL_XLA(linalg_svdvals, fp32)
#   // KERNEL_XLA(linalg_eigvals, fp32)
#   // KERNEL_XLA(linalg_eigvalsh, fp32)
#   // KERNEL_XLA(linalg_inv, fp32)
#   // KERNEL_XLA(linalg_householder_product, fp32)
#   // KERNEL_XLA(linalg_tensorinv, fp32)
#   // KERNEL_XLA(linalg_tensorsolve, fp32)
#   // KERNEL_XLA(fake_quantize_per_tensor_affine, fp32)
#   // KERNEL_XLA(geqrf, fp32)
#   // KERNEL_XLA(_lu_with_info, fp32)
#   KERNEL_XLA(qr, fp32)
#   KERNEL_XLA(svd, fp32)
#   KERNEL_XLA(triangular_solve, fp32)
#   KERNEL_XLA(multilabel_margin_loss_forward, fp32)
#   // KERNEL_XLA(linalg_qr, fp32)
#   // KERNEL_XLA(linalg_cholesky_ex, fp32)
#   KERNEL_XLA(linalg_svd, fp32)
#   // KERNEL_XLA(linalg_eig, fp32)
#   // KERNEL_XLA(linalg_eigh, fp32)
#   // KERNEL_XLA(linalg_lstsq, fp32)
#   KERNEL_XLA(linalg_inv_ex, fp32)

#   // promote
#   KERNEL_XLA(stack, promote)
#   KERNEL_XLA(cat, promote)
#   KERNEL_XLA(index_copy, promote)
#   KERNEL_XLA2(index_copy, dimname, promote)
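Only the lower_precision_fp policy is implemented in Python above; the fp32 and promote rows of the commented XLA table are kept as reference. A hypothetical sketch of how an fp32 wrapper could be added with the same registration mechanism (the helper name and the example op are illustrative, not part of this commit):

import torch
from torch.utils import _pytree as pytree

def cast_to_fp32(op):
  # Hypothetical mirror of lower_precision_fp above, forcing float32 inputs.
  def inner(*args, **kwargs):
    autocast_keyset = torch._C.DispatchKeySet(torch._C.DispatchKey.AutocastPrivateUse1)
    with torch._C._ExcludeDispatchKeyGuard(autocast_keyset):
      is_float_tensor = lambda a: isinstance(a, torch.Tensor) and a.is_floating_point()
      args, kwargs = pytree.tree_map_only(
        is_float_tensor,
        lambda x: x.to(torch.float32),
        (args, kwargs))
      return op(*args, **kwargs)
  return inner

# e.g., covering the mse_loss row of the fp32 table:
# lib.impl(torch.ops.aten.mse_loss.default.name(),
#          cast_to_fp32(torch.ops.aten.mse_loss.default),
#          'AutocastPrivateUse1', with_keyset=False)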

torchax/torchax/tensor.py

Lines changed: 5 additions & 5 deletions
@@ -59,7 +59,7 @@ def __new__(cls, elem, env):
         cls,
         shape,
         dtype=dtype,
-        device="meta",
+        device="privateuseone:0",
         requires_grad=False,
     )
 
@@ -134,9 +134,9 @@ def dtype(self):
   def dim(self):
     return self.ndim
 
-  @property
-  def device(self):
-    return torch.device("jax:0")
+  # @property
+  # def device(self):
+  #   return torch.device("jax:0")
 
   @property
   def jax_device(self):
@@ -347,7 +347,7 @@ def get_as_jax_device(self, device: Any):
       return None # fallback to torch
 
   def load_ops(self):
-    from torchax.ops import jaten, jtorch, jc10d, jtorchvision_nms
+    from torchax.ops import jaten, jtorch, jc10d, jtorchvision_nms, autocast_policy
 
     for k, v in itertools.chain(ops_registry.all_aten_ops.items(),
                                 ops_registry.all_torch_functions.items()):
