Skip to content

Commit fa99f20

Browse files
committed
[torchax] Added support for bicubic and billinear resampling
1 parent edc1a88 commit fa99f20

File tree

4 files changed

+242
-11
lines changed

4 files changed

+242
-11
lines changed

torchax/test/test_image.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import unittest
2+
from typing import Tuple
3+
import itertools
4+
from functools import partial
5+
import jax
6+
import torch
7+
8+
import torchax
9+
import torchax.interop
10+
11+
def to_xla_tensor(tensorstree):
12+
return torchax.interop.torch_view(torchax.tensor.t2j(tensorstree))
13+
14+
def to_torch_tensor(tensorstree):
15+
return torchax.tensor.j2t(torchax.interop.jax_view(tensorstree))
16+
17+
18+
19+
@partial(jax.jit, static_argnums=(1, 2, 3, 4))
20+
def upsample_jit(tensor, output_size: Tuple[int, int], align_corners: bool, antialias: bool, method: str):
21+
tensor = torchax.interop.torch_view(tensor)
22+
tensor = torch.nn.functional.interpolate(tensor, size=output_size, mode=method, align_corners=align_corners, antialias=antialias)
23+
return torchax.interop.jax_view(tensor)
24+
25+
26+
def test_upsampling(align_corners: bool, antialias: bool, method: str):
27+
28+
if method == 'bilinear':
29+
if align_corners:
30+
return # bilinear upsampling does not support align_corners
31+
32+
input_tensor = torch.rand((1, 1, 256, 512), dtype=torch.float32)
33+
output_size = (128, 64)
34+
35+
upsampled_tensor = torch.nn.functional.interpolate(input_tensor, size=output_size, mode=method, align_corners=align_corners, antialias=antialias)
36+
37+
with torchax.default_env():
38+
input_tensor_xla = to_xla_tensor(input_tensor)
39+
input_tensor_xla = torchax.interop.jax_view(input_tensor_xla)
40+
upsampled_tensor_xla = upsample_jit(input_tensor_xla, output_size, align_corners, antialias=antialias, method=method)
41+
42+
upsampled_tensor_xla = to_torch_tensor(upsampled_tensor_xla)
43+
abs_err = torch.abs(upsampled_tensor - upsampled_tensor_xla)
44+
45+
assert torch.allclose(upsampled_tensor, upsampled_tensor_xla, atol=1e-4, rtol=1e-5), f"{method} upsampling failed with error {abs_err.max()}"
46+
47+
class TestResampling(unittest.TestCase):
48+
def test_resampling_combinations(self):
49+
methods = [
50+
'bicubic',
51+
'bilinear',
52+
]
53+
antialias_options = [
54+
True,
55+
False,
56+
]
57+
58+
aligncorners_options = [
59+
False,
60+
True,
61+
]
62+
63+
for method, antialias, align_corners in itertools.product(methods, antialias_options, aligncorners_options):
64+
with self.subTest(method=method, antialias=antialias, align_corners=align_corners):
65+
test_upsampling(align_corners=align_corners, antialias=antialias, method=method)
66+
67+
68+
69+
if __name__ == '__main__':
70+
unittest.main()

torchax/torchax/ops/jaten.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5181,17 +5181,16 @@ def _aten_max_unpoolxd(input, indices, output_size, stride=None, padding=0):
51815181
return output
51825182

51835183

5184-
@op(torch.ops.aten._upsample_bilinear2d_aa)
5185-
def _aten_upsample_bilinear2d_aa(input,
5186-
output_size,
5187-
align_corners,
5188-
scale_factors=None,
5189-
scales_h=None,
5190-
scales_w=None):
5184+
def _aten_upsample(input,
5185+
output_size,
5186+
align_corners,
5187+
antialias,
5188+
method,
5189+
scale_factors=None,
5190+
scales_h=None,
5191+
scales_w=None):
51915192
# input: is of type jaxlib.xla_extension.ArrayImpl
51925193
image = input
5193-
method = "bilinear"
5194-
antialias = True # ignored for upsampling
51955194

51965195
# https://jax.readthedocs.io/en/latest/_autosummary/jax.image.resize.html
51975196
# Resize does not distinguish batch, channel size.
@@ -5253,6 +5252,24 @@ def _aten_upsample_bilinear2d_aa(input,
52535252
)
52545253

52555254

5255+
@op(torch.ops.aten._upsample_bilinear2d_aa)
5256+
def _aten_upsample_billinear_aa(input,
5257+
output_size,
5258+
align_corners,
5259+
scale_factors=None,
5260+
scales_h=None,
5261+
scales_w=None):
5262+
return _aten_upsample(
5263+
input,
5264+
output_size,
5265+
align_corners,
5266+
True, # antialias
5267+
"bilinear", # method
5268+
scale_factors,
5269+
scales_h,
5270+
scales_w
5271+
)
5272+
52565273
@op(torch.ops.aten.polar)
52575274
def _aten_polar(abs, angle, *, out=None):
52585275
return jax.lax.complex(abs * jnp.cos(angle), abs * jnp.sin(angle))

torchax/torchax/ops/jimage.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
import jax
2+
import jax.numpy as jnp
3+
4+
def cubic_kernel(x, a=-0.75):
5+
"""Cubic kernel with a = -0.75 (PyTorch-like Keys kernel)"""
6+
absx = jnp.abs(x)
7+
x2 = absx * absx
8+
x3 = x2 * absx
9+
cond1 = (absx <= 1)
10+
cond2 = (absx > 1) & (absx < 2)
11+
f1 = (a + 2) * x3 - (a + 3) * x2 + 1
12+
f2 = a * x3 - 5 * a * x2 + 8 * a * absx - 4 * a
13+
return jnp.where(cond1, f1, jnp.where(cond2, f2, 0.0))
14+
15+
16+
def compute_contribs(in_size, out_size, scale, support=2.0, align_corners=False):
17+
if align_corners:
18+
if out_size == 1:
19+
in_coords = jnp.zeros((1,))
20+
else:
21+
in_coords = jnp.linspace(0, in_size - 1, out_size)
22+
else:
23+
out_coords = jnp.arange(out_size) + 0.5
24+
in_coords = out_coords / scale - 0.5
25+
26+
left_idx = jnp.floor(in_coords).astype(jnp.int32) - 1
27+
idxs = left_idx[:, None] + jnp.arange(4)
28+
29+
dx = in_coords[:, None] - idxs
30+
31+
weights = cubic_kernel(dx)
32+
33+
weights = weights / jnp.sum(weights, axis=1, keepdims=True)
34+
return idxs, weights
35+
36+
def gather_weights(img, idxs, axis):
37+
"""Safely gather with boundary handling"""
38+
idxs = jnp.clip(idxs, 0, img.shape[axis] - 1)
39+
return jnp.take(img, idxs, axis=axis)
40+
41+
def interpolate_along_axis_bchw(img, idxs, weights, axis):
42+
"""
43+
Interpolate along H (axis=2) or W (axis=3) for tensor (B, C, H, W).
44+
idxs: (out_size, 4) int32 indices
45+
weights: (out_size, 4) float32 weights
46+
"""
47+
assert axis in (2, 3), "Axis must be 2 (H) or 3 (W)"
48+
out_size = idxs.shape[0]
49+
k = idxs.shape[1] # Typically 4 for cubic
50+
51+
# Clip to input bounds
52+
idxs = jnp.clip(idxs, 0, img.shape[axis] - 1) # (out_size, 4)
53+
54+
def gather_and_weight(i):
55+
idx = idxs[i] # (4,)
56+
w = weights[i] # (4,)
57+
58+
def gather_one(offset):
59+
return jnp.take(img, idx[offset], axis=axis) # shape (B, C, H, W)
60+
61+
gathered = jnp.stack([gather_one(o) for o in range(k)], axis=0) # (4, B, C, H, W)
62+
weighted = jnp.tensordot(w, gathered, axes=(0, 0)) # (B, C, H, W)
63+
return weighted
64+
65+
out = jax.vmap(gather_and_weight)(jnp.arange(out_size)) # (out_size, B, C, H, W)
66+
67+
# Move the interpolated axis back into place
68+
if axis == 2: # interpolated over H
69+
return jnp.moveaxis(out, 0, 2) # (B, C, out_H, W)
70+
else: # axis == 3, interpolated over W
71+
return jnp.moveaxis(out, 0, 3) # (B, C, H, out_W)
72+
73+
74+
75+
def interpolate_bicubic_no_aa(img, out_h, out_w, align_corners=False):
76+
h, w = img.shape[-2:]
77+
if align_corners and out_h > 1:
78+
scale_y = (h - 1) / (out_h - 1)
79+
else:
80+
scale_y = out_h / h
81+
82+
if align_corners and out_w > 1:
83+
scale_x = (w - 1) / (out_w - 1)
84+
else:
85+
scale_x = out_w / w
86+
87+
idxs_y, weights_y = compute_contribs(
88+
h, out_h, scale_y, align_corners=align_corners,
89+
)
90+
tmp = interpolate_along_axis_bchw(img, idxs_y, weights_y, axis=2)
91+
92+
idxs_x, weights_x = compute_contribs(
93+
w, out_w, scale_x, align_corners=align_corners,
94+
)
95+
out = interpolate_along_axis_bchw(tmp, idxs_x, weights_x, axis=3)
96+
return out

torchax/torchax/ops/jtorch.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import math
44
import collections.abc
55
import functools
6-
from typing import Optional, Sequence
6+
from typing import Optional, Sequence, Tuple
77
import numpy as np
88

99
import jax
@@ -13,7 +13,7 @@
1313

1414
import torch
1515
from torchax.ops.ops_registry import register_torch_function_op
16-
from torchax.ops import op_base, mappings, jaten
16+
from torchax.ops import op_base, mappings, jaten, jimage
1717
import torchax.tensor
1818
from torchax.view import View, NarrowInfo
1919
import torch.utils._pytree as pytree
@@ -512,3 +512,51 @@ def functional_linear(self, weights, bias=None):
512512
if bias is not None:
513513
res += bias
514514
return res
515+
516+
517+
518+
@register_function(torch.nn.functional.interpolate)
519+
def functional_interpolate(
520+
input,
521+
size: Tuple[int, int],
522+
scale_factor: Optional[float],
523+
mode: str,
524+
align_corners: bool,
525+
recompute_scale_factor: bool,
526+
antialias: bool,
527+
):
528+
supported_methods = (
529+
"nearest",
530+
"linear",
531+
"bilinear",
532+
"trilinear",
533+
"cubic",
534+
"bicubic",
535+
"tricubic",
536+
"lanczos3",
537+
"lanczos5",
538+
)
539+
is_jax_supported = mode in supported_methods
540+
if not is_jax_supported:
541+
raise torchax.tensor.OperatorNotFound(
542+
f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}"
543+
)
544+
# None check
545+
antialias = antialias or False
546+
align_corners = align_corners or False
547+
548+
if mode in ('cubic', 'bicubic', 'tricubic') and not antialias:
549+
return jimage.interpolate_bicubic_no_aa(
550+
input,
551+
size[0],
552+
size[1],
553+
align_corners,
554+
)
555+
return jaten._aten_upsample(
556+
input,
557+
size,
558+
align_corners,
559+
antialias,
560+
mode,
561+
scale_factor,
562+
)

0 commit comments

Comments
 (0)