Skip to content

[torchax] Added support for bicubic and billinear resampling #9196

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions torchax/test/test_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import unittest
from typing import Tuple
import itertools
from functools import partial
import jax
import torch

import torchax
import torchax.interop

def to_xla_tensor(tensorstree):
return torchax.interop.torch_view(torchax.tensor.t2j(tensorstree))

def to_torch_tensor(tensorstree):
return torchax.tensor.j2t(torchax.interop.jax_view(tensorstree))



@partial(jax.jit, static_argnums=(1, 2, 3, 4))
def upsample_jit(tensor, output_size: Tuple[int, int], align_corners: bool, antialias: bool, method: str):
tensor = torchax.interop.torch_view(tensor)
tensor = torch.nn.functional.interpolate(tensor, size=output_size, mode=method, align_corners=align_corners, antialias=antialias)
return torchax.interop.jax_view(tensor)


def test_upsampling(align_corners: bool, antialias: bool, method: str):

if method == 'bilinear':
if align_corners:
return # bilinear upsampling does not support align_corners

input_tensor = torch.rand((1, 1, 256, 512), dtype=torch.float32)
output_size = (128, 64)

upsampled_tensor = torch.nn.functional.interpolate(input_tensor, size=output_size, mode=method, align_corners=align_corners, antialias=antialias)

with torchax.default_env():
input_tensor_xla = to_xla_tensor(input_tensor)
input_tensor_xla = torchax.interop.jax_view(input_tensor_xla)
upsampled_tensor_xla = upsample_jit(input_tensor_xla, output_size, align_corners, antialias=antialias, method=method)

upsampled_tensor_xla = to_torch_tensor(upsampled_tensor_xla)
abs_err = torch.abs(upsampled_tensor - upsampled_tensor_xla)

assert torch.allclose(upsampled_tensor, upsampled_tensor_xla, atol=1e-4, rtol=1e-5), f"{method} upsampling failed with error {abs_err.max()}"

class TestResampling(unittest.TestCase):
def test_resampling_combinations(self):
methods = [
'bicubic',
'bilinear',
]
antialias_options = [
True,
False,
]

aligncorners_options = [
False,
True,
]

for method, antialias, align_corners in itertools.product(methods, antialias_options, aligncorners_options):
with self.subTest(method=method, antialias=antialias, align_corners=align_corners):
test_upsampling(align_corners=align_corners, antialias=antialias, method=method)



if __name__ == '__main__':
unittest.main()
35 changes: 26 additions & 9 deletions torchax/torchax/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -5181,17 +5181,16 @@ def _aten_max_unpoolxd(input, indices, output_size, stride=None, padding=0):
return output


@op(torch.ops.aten._upsample_bilinear2d_aa)
def _aten_upsample_bilinear2d_aa(input,
output_size,
align_corners,
scale_factors=None,
scales_h=None,
scales_w=None):
def _aten_upsample(input,
output_size,
align_corners,
antialias,
method,
scale_factors=None,
scales_h=None,
scales_w=None):
# input: is of type jaxlib.xla_extension.ArrayImpl
image = input
method = "bilinear"
antialias = True # ignored for upsampling

# https://jax.readthedocs.io/en/latest/_autosummary/jax.image.resize.html
# Resize does not distinguish batch, channel size.
Expand Down Expand Up @@ -5253,6 +5252,24 @@ def _aten_upsample_bilinear2d_aa(input,
)


@op(torch.ops.aten._upsample_bilinear2d_aa)
def _aten_upsample_billinear_aa(input,
output_size,
align_corners,
scale_factors=None,
scales_h=None,
scales_w=None):
return _aten_upsample(
input,
output_size,
align_corners,
True, # antialias
"bilinear", # method
scale_factors,
scales_h,
scales_w
)

@op(torch.ops.aten.polar)
def _aten_polar(abs, angle, *, out=None):
return jax.lax.complex(abs * jnp.cos(angle), abs * jnp.sin(angle))
Expand Down
96 changes: 96 additions & 0 deletions torchax/torchax/ops/jimage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import jax
import jax.numpy as jnp

def cubic_kernel(x, a=-0.75):
"""Cubic kernel with a = -0.75 (PyTorch-like Keys kernel)"""
absx = jnp.abs(x)
x2 = absx * absx
x3 = x2 * absx
cond1 = (absx <= 1)
cond2 = (absx > 1) & (absx < 2)
f1 = (a + 2) * x3 - (a + 3) * x2 + 1
f2 = a * x3 - 5 * a * x2 + 8 * a * absx - 4 * a
return jnp.where(cond1, f1, jnp.where(cond2, f2, 0.0))


def compute_contribs(in_size, out_size, scale, support=2.0, align_corners=False):
if align_corners:
if out_size == 1:
in_coords = jnp.zeros((1,))
else:
in_coords = jnp.linspace(0, in_size - 1, out_size)
else:
out_coords = jnp.arange(out_size) + 0.5
in_coords = out_coords / scale - 0.5

left_idx = jnp.floor(in_coords).astype(jnp.int32) - 1
idxs = left_idx[:, None] + jnp.arange(4)

dx = in_coords[:, None] - idxs

weights = cubic_kernel(dx)

weights = weights / jnp.sum(weights, axis=1, keepdims=True)
return idxs, weights

def gather_weights(img, idxs, axis):
"""Safely gather with boundary handling"""
idxs = jnp.clip(idxs, 0, img.shape[axis] - 1)
return jnp.take(img, idxs, axis=axis)

def interpolate_along_axis_bchw(img, idxs, weights, axis):
"""
Interpolate along H (axis=2) or W (axis=3) for tensor (B, C, H, W).
idxs: (out_size, 4) int32 indices
weights: (out_size, 4) float32 weights
"""
assert axis in (2, 3), "Axis must be 2 (H) or 3 (W)"
out_size = idxs.shape[0]
k = idxs.shape[1] # Typically 4 for cubic

# Clip to input bounds
idxs = jnp.clip(idxs, 0, img.shape[axis] - 1) # (out_size, 4)

def gather_and_weight(i):
idx = idxs[i] # (4,)
w = weights[i] # (4,)

def gather_one(offset):
return jnp.take(img, idx[offset], axis=axis) # shape (B, C, H, W)

gathered = jnp.stack([gather_one(o) for o in range(k)], axis=0) # (4, B, C, H, W)
weighted = jnp.tensordot(w, gathered, axes=(0, 0)) # (B, C, H, W)
return weighted

out = jax.vmap(gather_and_weight)(jnp.arange(out_size)) # (out_size, B, C, H, W)

# Move the interpolated axis back into place
if axis == 2: # interpolated over H
return jnp.moveaxis(out, 0, 2) # (B, C, out_H, W)
else: # axis == 3, interpolated over W
return jnp.moveaxis(out, 0, 3) # (B, C, H, out_W)



def interpolate_bicubic_no_aa(img, out_h, out_w, align_corners=False):
h, w = img.shape[-2:]
if align_corners and out_h > 1:
scale_y = (h - 1) / (out_h - 1)
else:
scale_y = out_h / h

if align_corners and out_w > 1:
scale_x = (w - 1) / (out_w - 1)
else:
scale_x = out_w / w

idxs_y, weights_y = compute_contribs(
h, out_h, scale_y, align_corners=align_corners,
)
tmp = interpolate_along_axis_bchw(img, idxs_y, weights_y, axis=2)

idxs_x, weights_x = compute_contribs(
w, out_w, scale_x, align_corners=align_corners,
)
out = interpolate_along_axis_bchw(tmp, idxs_x, weights_x, axis=3)
return out
52 changes: 50 additions & 2 deletions torchax/torchax/ops/jtorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import math
import collections.abc
import functools
from typing import Optional, Sequence
from typing import Optional, Sequence, Tuple
import numpy as np

import jax
Expand All @@ -13,7 +13,7 @@

import torch
from torchax.ops.ops_registry import register_torch_function_op
from torchax.ops import op_base, mappings, jaten
from torchax.ops import op_base, mappings, jaten, jimage
import torchax.tensor
from torchax.view import View, NarrowInfo
import torch.utils._pytree as pytree
Expand Down Expand Up @@ -512,3 +512,51 @@ def functional_linear(self, weights, bias=None):
if bias is not None:
res += bias
return res



@register_function(torch.nn.functional.interpolate)
def functional_interpolate(
input,
size: Tuple[int, int],
scale_factor: Optional[float],
mode: str,
align_corners: bool,
recompute_scale_factor: bool,
antialias: bool,
):
supported_methods = (
"nearest",
"linear",
"bilinear",
"trilinear",
"cubic",
"bicubic",
"tricubic",
"lanczos3",
"lanczos5",
)
is_jax_supported = mode in supported_methods
if not is_jax_supported:
raise torchax.tensor.OperatorNotFound(
f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}"
)
# None check
antialias = antialias or False
align_corners = align_corners or False

if mode in ('cubic', 'bicubic', 'tricubic') and not antialias:
return jimage.interpolate_bicubic_no_aa(
input,
size[0],
size[1],
align_corners,
)
return jaten._aten_upsample(
input,
size,
align_corners,
antialias,
mode,
scale_factor,
)
Loading