import torch
import torch._C
from torch.utils import _pytree as pytree


def call_with_next_key(op, args, kwargs):
    # Helper that simply redispatches the op with the given arguments.
    return op(*args, **kwargs)


# Module-level default; the per-call dtype is looked up inside `inner` below.
target_precision = torch.bfloat16


def lower_precision_fp(op):
    # Autocast "lower_precision_fp" policy: cast every floating-point tensor
    # argument to the autocast dtype before redispatching to the backend.
    def inner(*args, **kwargs):
        target_precision = torch.get_autocast_dtype('privateuseone')
        autocast_keyset = torch._C.DispatchKeySet(torch._C.DispatchKey.AutocastPrivateUse1)
        # Exclude the autocast key so the redispatch below does not re-enter
        # this wrapper.
        with torch._C._ExcludeDispatchKeyGuard(autocast_keyset):
            is_float_tensor = lambda a: isinstance(a, torch.Tensor) and a.is_floating_point()
            args, kwargs = pytree.tree_map_only(
                is_float_tensor,
                lambda x: x.to(target_precision),
                (args, kwargs))
            return op(*args, **kwargs)
    return inner
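

# Illustrative sketch (not part of the registrations below and never called on
# import): what the wrapper does when applied directly to an aten overload.
# Assumes the autocast dtype for 'privateuseone' is a low-precision type such
# as bfloat16.
def _demo_lower_precision_fp():
    wrapped_mm = lower_precision_fp(torch.ops.aten.mm.default)
    a = torch.randn(4, 4)   # float32
    b = torch.randn(4, 4)   # float32
    out = wrapped_mm(a, b)  # inputs are cast before the op runs
    return out.dtype        # expected: the autocast dtype, e.g. torch.bfloat16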


lib = torch.library.Library('aten', 'FRAGMENT')
my_lib = torch.library.Library('_', 'IMPL', 'AutocastPrivateUse1')
# Default behaviour for the AutocastPrivateUse1 key: ops without an explicit
# autocast kernel registered below simply fall through and redispatch as if
# autocast were not enabled.
my_lib.fallback(torch.library.fallthrough_kernel)


# Register the lower_precision_fp policy for the matmul/conv-style ops
# (the same policy torch_xla applies to these ops; see the reference below).
for op in [torch.ops.aten.conv1d.default,
           torch.ops.aten.conv1d.padding,
           torch.ops.aten.conv2d.default,
           torch.ops.aten.conv2d.padding,
           torch.ops.aten.conv3d.default,
           torch.ops.aten.bmm.default,
           torch.ops.aten.mm.default,
           torch.ops.aten.baddbmm.default,
           torch.ops.aten.addmm.default,
           torch.ops.aten.addbmm.default,
           torch.ops.aten.linear.default,
           torch.ops.aten.matmul.default,
           torch.ops.aten.conv_tbc.default,
           torch.ops.aten.conv_transpose1d.default,
           torch.ops.aten.conv_transpose2d.input,
           torch.ops.aten.conv_transpose3d.input,
           torch.ops.aten.prelu.default,
           torch.ops.aten.relu.default,
           torch.ops.aten.max_pool2d.default,
           torch.ops.aten.einsum.default,
           ]:
    lib.impl(op.name(), lower_precision_fp(op), "AutocastPrivateUse1", with_keyset=False)
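

# Illustrative only (never called on import): how the kernels registered above
# are intended to be exercised. Assumes a PrivateUse1 backend is actually
# registered and supports autocast; the device string and dtype below are
# placeholders.
def _example_autocast_usage():
    x = torch.randn(8, 8, device='privateuseone')
    w = torch.randn(8, 8, device='privateuseone')
    # Entering autocast for the PrivateUse1 device enables the
    # AutocastPrivateUse1 dispatch key, so aten::mm is intercepted by
    # lower_precision_fp() before reaching the backend kernel.
    with torch.autocast(device_type='privateuseone', dtype=torch.bfloat16):
        y = torch.mm(x, w)  # runs in the lower-precision dtype
    return y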


# Reference: cast policies used by torch_xla's autocast registrations.
# https://github.com/pytorch/xla/blob/20899c7258680a36cd3bec1c820e8a52c16a4bbf/torch_xla/csrc/autocast_mode.cpp#L29
# enum class CastPolicy : uint8_t {
#   lower_precision_fp = 0,  // Cast all inputs to lower_precision_fp before
#                            // running the op. Currently, lower_precision_fp is
#                            // fp16 for AutocastCUDA, and is defined by user
#                            // (default bf16) for AutocastCPU or other device.
#   fp32,  // Cast all inputs to at::kFloat before running the op.
#   fp32_set_opt_dtype,  // Treats functions (like softmax) that
#                        // 1. we'd like to run in fp32 and
#                        // 2. have a std::optional<ScalarType> arg that controls
#                        // the output type.
#                        // fp32_set_opt_dtype wrappers' policy is: if the output
#                        // type is already set, don't touch it, otherwise, set
#                        // it to at::kFloat.
#   fp32_append_dtype,  // Treats functions (like norm) that
#                       // 1. we'd like to run in fp32 and
#                       // 2. have some overloads that accept an output type and
#                       // other overloads that don't.
#                       // fp32_append_dtype wrappers wrap the overloads that don't
#                       // have an output dtype.
#                       // The wrapper policy is: append at::kFloat to the args,
#                       // and redispatch to the type-aware overload.
#   promote,  // Run in the widest dtype among several args.
# };
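

# Illustrative sketches (not registered anywhere in this file) of two of the
# cast policies described above, written in the same style as
# lower_precision_fp(). The helper names are placeholders, and
# fp32_set_opt_dtype only handles a keyword `dtype` argument for brevity.
def fp32_cast(op):
    # CastPolicy::fp32 - cast all floating-point tensor inputs to float32
    # before running the op.
    def inner(*args, **kwargs):
        autocast_keyset = torch._C.DispatchKeySet(torch._C.DispatchKey.AutocastPrivateUse1)
        with torch._C._ExcludeDispatchKeyGuard(autocast_keyset):
            is_float_tensor = lambda a: isinstance(a, torch.Tensor) and a.is_floating_point()
            args, kwargs = pytree.tree_map_only(
                is_float_tensor,
                lambda x: x.to(torch.float32),
                (args, kwargs))
            return op(*args, **kwargs)
    return inner


def fp32_set_opt_dtype(op):
    # CastPolicy::fp32_set_opt_dtype - if the caller did not pin an output
    # dtype, force float32; if they did, leave their choice untouched.
    def inner(*args, **kwargs):
        autocast_keyset = torch._C.DispatchKeySet(torch._C.DispatchKey.AutocastPrivateUse1)
        with torch._C._ExcludeDispatchKeyGuard(autocast_keyset):
            if kwargs.get('dtype') is None:
                kwargs['dtype'] = torch.float32
            return op(*args, **kwargs)
    return inner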
# TORCH_LIBRARY_IMPL(aten, AutocastXLA, m) {
#   // lower_precision_fp cast policy
#   KERNEL_XLA(conv1d, lower_precision_fp)
#   KERNEL_XLA2(conv1d, padding, lower_precision_fp)
#   KERNEL_XLA(conv2d, lower_precision_fp)
#   KERNEL_XLA2(conv2d, padding, lower_precision_fp)
#   KERNEL_XLA(conv3d, lower_precision_fp)
#   KERNEL_XLA2(conv3d, padding, lower_precision_fp)
#   KERNEL_XLA(bmm, lower_precision_fp)
#   KERNEL_XLA(mm, lower_precision_fp)
#   KERNEL_XLA(baddbmm, lower_precision_fp)
#   KERNEL_XLA(addmm, lower_precision_fp)
#   KERNEL_XLA(addbmm, lower_precision_fp)
#   KERNEL_XLA(linear, lower_precision_fp)
#   KERNEL_XLA(matmul, lower_precision_fp)
#   KERNEL_XLA(conv_tbc, lower_precision_fp)
#   KERNEL_XLA(conv_transpose1d, lower_precision_fp)
#   KERNEL_XLA2(conv_transpose2d, input, lower_precision_fp)
#   KERNEL_XLA2(conv_transpose3d, input, lower_precision_fp)
#   KERNEL_XLA(prelu, lower_precision_fp)
#   KERNEL_XLA(relu, lower_precision_fp)
#   KERNEL_XLA(max_pool2d, lower_precision_fp)
#   KERNEL_XLA(einsum, lower_precision_fp)
#   // Disable `scaled_dot_product_attention` for now since it causes
#   // undefined symbol with official torch whl.
#   // KERNEL_XLA(scaled_dot_product_attention, lower_precision_fp)

#   // fp32 cast policy
#   // Commented out ops are included in the AutoCastCPU Policy,
#   // but not lowered. Enable if op is lowered.
#   KERNEL_XLA(batch_norm, fp32)
#   KERNEL_XLA(_softmax, fp32)
#   KERNEL_XLA2(softmax, int, fp32)
#   KERNEL_XLA2(softmax, Dimname, fp32)
#   KERNEL_XLA2(log_softmax, int, fp32)
#   KERNEL_XLA2(log_softmax, Dimname, fp32)
#   KERNEL_XLA(binary_cross_entropy, fp32)
#   // KERNEL_XLA(grid_sampler, fp32)
#   // KERNEL_XLA(polar, fp32)
#   KERNEL_XLA2(pow, Tensor_Scalar, fp32)
#   KERNEL_XLA(prod, fp32)
#   KERNEL_XLA2(prod, dim_int, fp32)
#   KERNEL_XLA2(prod, dim_Dimname, fp32)
#   // KERNEL_XLA(quantile, fp32)
#   // KERNEL_XLA2(quantile, scalar, fp32)
#   // KERNEL_XLA(nanquantile, fp32)
#   // KERNEL_XLA2(nanquantile, scalar, fp32)
#   // KERNEL_XLA(stft, fp32)
#   // KERNEL_XLA2(stft, center, fp32)
#   KERNEL_XLA(cdist, fp32)
#   // KERNEL_XLA(grid_sampler_2d, fp32)
#   // KERNEL_XLA(grid_sampler_3d, fp32)
#   KERNEL_XLA(trace, fp32)
#   // KERNEL_XLA(view_as_complex, fp32)
#   KERNEL_XLA(cholesky, fp32)
#   KERNEL_XLA(cholesky_inverse, fp32)
#   KERNEL_XLA(cholesky_solve, fp32)
#   KERNEL_XLA(inverse, fp32)
#   // KERNEL_XLA(lu_solve, fp32)
#   // KERNEL_XLA(orgqr, fp32)
#   // KERNEL_XLA(ormqr, fp32)
#   // KERNEL_XLA(pinverse, fp32)
#   KERNEL_XLA(reflection_pad1d, fp32)
#   KERNEL_XLA(reflection_pad2d, fp32)
#   KERNEL_XLA(replication_pad1d, fp32)
#   KERNEL_XLA(replication_pad2d, fp32)
#   KERNEL_XLA(replication_pad3d, fp32)
#   KERNEL_XLA(mse_loss, fp32)
#   KERNEL_XLA(cosine_embedding_loss, fp32)
#   KERNEL_XLA(nll_loss, fp32)
#   KERNEL_XLA(nll_loss2d, fp32)
#   KERNEL_XLA(hinge_embedding_loss, fp32)
#   // KERNEL_XLA(poisson_nll_loss, fp32)
#   KERNEL_XLA(smooth_l1_loss, fp32)
#   KERNEL_XLA(cross_entropy_loss, fp32)
#   KERNEL_XLA(l1_loss, fp32)
#   // KERNEL_XLA(huber_loss, fp32)
#   KERNEL_XLA(margin_ranking_loss, fp32)
#   KERNEL_XLA(soft_margin_loss, fp32)
#   KERNEL_XLA(triplet_margin_loss, fp32)
#   KERNEL_XLA(multi_margin_loss, fp32)
#   KERNEL_XLA2(ctc_loss, IntList, fp32)
#   KERNEL_XLA2(ctc_loss, Tensor, fp32)
#   KERNEL_XLA(kl_div, fp32)
#   KERNEL_XLA(multilabel_margin_loss, fp32)
#   KERNEL_XLA(binary_cross_entropy_with_logits, fp32)
#   // KERNEL_XLA(fft_fft, fp32)
#   // KERNEL_XLA(fft_ifft, fp32)
#   // KERNEL_XLA(fft_fft2, fp32)
#   // KERNEL_XLA(fft_ifft2, fp32)
#   // KERNEL_XLA(fft_fftn, fp32)
#   // KERNEL_XLA(fft_ifftn, fp32)
#   // KERNEL_XLA(fft_rfft, fp32)
#   // KERNEL_XLA(fft_irfft, fp32)
#   // KERNEL_XLA(fft_rfft2, fp32)
#   // KERNEL_XLA(fft_irfft2, fp32)
#   // KERNEL_XLA(fft_rfftn, fp32)
#   // KERNEL_XLA(fft_irfftn, fp32)
#   // KERNEL_XLA(fft_hfft, fp32)
#   // KERNEL_XLA(fft_ihfft, fp32)
#   // KERNEL_XLA(linalg_cond, fp32)
#   // KERNEL_XLA2(linalg_cond, p_str, fp32)
#   // KERNEL_XLA(linalg_matrix_rank, fp32)
#   // KERNEL_XLA2(linalg_matrix_rank, tol_tensor, fp32)
#   // KERNEL_XLA2(linalg_matrix_rank, atol_rtol_tensor, fp32)
#   // KERNEL_XLA2(linalg_matrix_rank, atol_rtol_float, fp32)
#   // KERNEL_XLA(linalg_solve, fp32)
#   // KERNEL_XLA(linalg_cholesky, fp32)
#   // KERNEL_XLA(linalg_svdvals, fp32)
#   // KERNEL_XLA(linalg_eigvals, fp32)
#   // KERNEL_XLA(linalg_eigvalsh, fp32)
#   // KERNEL_XLA(linalg_inv, fp32)
#   // KERNEL_XLA(linalg_householder_product, fp32)
#   // KERNEL_XLA(linalg_tensorinv, fp32)
#   // KERNEL_XLA(linalg_tensorsolve, fp32)
#   // KERNEL_XLA(fake_quantize_per_tensor_affine, fp32)
#   // KERNEL_XLA(geqrf, fp32)
#   // KERNEL_XLA(_lu_with_info, fp32)
#   KERNEL_XLA(qr, fp32)
#   KERNEL_XLA(svd, fp32)
#   KERNEL_XLA(triangular_solve, fp32)
#   KERNEL_XLA(multilabel_margin_loss_forward, fp32)
#   // KERNEL_XLA(linalg_qr, fp32)
#   // KERNEL_XLA(linalg_cholesky_ex, fp32)
#   KERNEL_XLA(linalg_svd, fp32)
#   // KERNEL_XLA(linalg_eig, fp32)
#   // KERNEL_XLA(linalg_eigh, fp32)
#   // KERNEL_XLA(linalg_lstsq, fp32)
#   KERNEL_XLA(linalg_inv_ex, fp32)

#   // promote
#   KERNEL_XLA(stack, promote)
#   KERNEL_XLA(cat, promote)
#   KERNEL_XLA(index_copy, promote)
#   KERNEL_XLA2(index_copy, dimname, promote)
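

# Illustrative sketch of the `promote` policy quoted above: run the op in the
# widest floating-point dtype found among its tensor arguments. Not registered
# anywhere in this file; the helper name is a placeholder.
def promote_cast(op):
    def inner(*args, **kwargs):
        autocast_keyset = torch._C.DispatchKeySet(torch._C.DispatchKey.AutocastPrivateUse1)
        with torch._C._ExcludeDispatchKeyGuard(autocast_keyset):
            is_float_tensor = lambda a: isinstance(a, torch.Tensor) and a.is_floating_point()
            leaves, _ = pytree.tree_flatten((args, kwargs))
            dtypes = [a.dtype for a in leaves if is_float_tensor(a)]
            if dtypes:
                widest = dtypes[0]
                for d in dtypes[1:]:
                    widest = torch.promote_types(widest, d)
                args, kwargs = pytree.tree_map_only(
                    is_float_tensor,
                    lambda x: x.to(widest),
                    (args, kwargs))
            return op(*args, **kwargs)
    return inner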