From fa99f20fab77ff691093afd7ef56acccec76b261 Mon Sep 17 00:00:00 2001 From: zmelumian Date: Tue, 20 May 2025 08:55:50 +0300 Subject: [PATCH] [torchax] Added support for bicubic and billinear resampling --- torchax/test/test_image.py | 70 +++++++++++++++++++++++++ torchax/torchax/ops/jaten.py | 35 +++++++++---- torchax/torchax/ops/jimage.py | 96 +++++++++++++++++++++++++++++++++++ torchax/torchax/ops/jtorch.py | 52 ++++++++++++++++++- 4 files changed, 242 insertions(+), 11 deletions(-) create mode 100644 torchax/test/test_image.py create mode 100644 torchax/torchax/ops/jimage.py diff --git a/torchax/test/test_image.py b/torchax/test/test_image.py new file mode 100644 index 00000000000..845505a0790 --- /dev/null +++ b/torchax/test/test_image.py @@ -0,0 +1,70 @@ +import unittest +from typing import Tuple +import itertools +from functools import partial +import jax +import torch + +import torchax +import torchax.interop + +def to_xla_tensor(tensorstree): + return torchax.interop.torch_view(torchax.tensor.t2j(tensorstree)) + +def to_torch_tensor(tensorstree): + return torchax.tensor.j2t(torchax.interop.jax_view(tensorstree)) + + + +@partial(jax.jit, static_argnums=(1, 2, 3, 4)) +def upsample_jit(tensor, output_size: Tuple[int, int], align_corners: bool, antialias: bool, method: str): + tensor = torchax.interop.torch_view(tensor) + tensor = torch.nn.functional.interpolate(tensor, size=output_size, mode=method, align_corners=align_corners, antialias=antialias) + return torchax.interop.jax_view(tensor) + + +def test_upsampling(align_corners: bool, antialias: bool, method: str): + + if method == 'bilinear': + if align_corners: + return # bilinear upsampling does not support align_corners + + input_tensor = torch.rand((1, 1, 256, 512), dtype=torch.float32) + output_size = (128, 64) + + upsampled_tensor = torch.nn.functional.interpolate(input_tensor, size=output_size, mode=method, align_corners=align_corners, antialias=antialias) + + with torchax.default_env(): + input_tensor_xla = to_xla_tensor(input_tensor) + input_tensor_xla = torchax.interop.jax_view(input_tensor_xla) + upsampled_tensor_xla = upsample_jit(input_tensor_xla, output_size, align_corners, antialias=antialias, method=method) + + upsampled_tensor_xla = to_torch_tensor(upsampled_tensor_xla) + abs_err = torch.abs(upsampled_tensor - upsampled_tensor_xla) + + assert torch.allclose(upsampled_tensor, upsampled_tensor_xla, atol=1e-4, rtol=1e-5), f"{method} upsampling failed with error {abs_err.max()}" + +class TestResampling(unittest.TestCase): + def test_resampling_combinations(self): + methods = [ + 'bicubic', + 'bilinear', + ] + antialias_options = [ + True, + False, + ] + + aligncorners_options = [ + False, + True, + ] + + for method, antialias, align_corners in itertools.product(methods, antialias_options, aligncorners_options): + with self.subTest(method=method, antialias=antialias, align_corners=align_corners): + test_upsampling(align_corners=align_corners, antialias=antialias, method=method) + + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/torchax/torchax/ops/jaten.py b/torchax/torchax/ops/jaten.py index fc8dcc71e46..0605ae3a66a 100644 --- a/torchax/torchax/ops/jaten.py +++ b/torchax/torchax/ops/jaten.py @@ -5181,17 +5181,16 @@ def _aten_max_unpoolxd(input, indices, output_size, stride=None, padding=0): return output -@op(torch.ops.aten._upsample_bilinear2d_aa) -def _aten_upsample_bilinear2d_aa(input, - output_size, - align_corners, - scale_factors=None, - scales_h=None, - scales_w=None): +def _aten_upsample(input, + output_size, + align_corners, + antialias, + method, + scale_factors=None, + scales_h=None, + scales_w=None): # input: is of type jaxlib.xla_extension.ArrayImpl image = input - method = "bilinear" - antialias = True # ignored for upsampling # https://jax.readthedocs.io/en/latest/_autosummary/jax.image.resize.html # Resize does not distinguish batch, channel size. @@ -5253,6 +5252,24 @@ def _aten_upsample_bilinear2d_aa(input, ) +@op(torch.ops.aten._upsample_bilinear2d_aa) +def _aten_upsample_billinear_aa(input, + output_size, + align_corners, + scale_factors=None, + scales_h=None, + scales_w=None): + return _aten_upsample( + input, + output_size, + align_corners, + True, # antialias + "bilinear", # method + scale_factors, + scales_h, + scales_w + ) + @op(torch.ops.aten.polar) def _aten_polar(abs, angle, *, out=None): return jax.lax.complex(abs * jnp.cos(angle), abs * jnp.sin(angle)) diff --git a/torchax/torchax/ops/jimage.py b/torchax/torchax/ops/jimage.py new file mode 100644 index 00000000000..248b43cf2a2 --- /dev/null +++ b/torchax/torchax/ops/jimage.py @@ -0,0 +1,96 @@ +import jax +import jax.numpy as jnp + +def cubic_kernel(x, a=-0.75): + """Cubic kernel with a = -0.75 (PyTorch-like Keys kernel)""" + absx = jnp.abs(x) + x2 = absx * absx + x3 = x2 * absx + cond1 = (absx <= 1) + cond2 = (absx > 1) & (absx < 2) + f1 = (a + 2) * x3 - (a + 3) * x2 + 1 + f2 = a * x3 - 5 * a * x2 + 8 * a * absx - 4 * a + return jnp.where(cond1, f1, jnp.where(cond2, f2, 0.0)) + + +def compute_contribs(in_size, out_size, scale, support=2.0, align_corners=False): + if align_corners: + if out_size == 1: + in_coords = jnp.zeros((1,)) + else: + in_coords = jnp.linspace(0, in_size - 1, out_size) + else: + out_coords = jnp.arange(out_size) + 0.5 + in_coords = out_coords / scale - 0.5 + + left_idx = jnp.floor(in_coords).astype(jnp.int32) - 1 + idxs = left_idx[:, None] + jnp.arange(4) + + dx = in_coords[:, None] - idxs + + weights = cubic_kernel(dx) + + weights = weights / jnp.sum(weights, axis=1, keepdims=True) + return idxs, weights + +def gather_weights(img, idxs, axis): + """Safely gather with boundary handling""" + idxs = jnp.clip(idxs, 0, img.shape[axis] - 1) + return jnp.take(img, idxs, axis=axis) + +def interpolate_along_axis_bchw(img, idxs, weights, axis): + """ + Interpolate along H (axis=2) or W (axis=3) for tensor (B, C, H, W). + idxs: (out_size, 4) int32 indices + weights: (out_size, 4) float32 weights + """ + assert axis in (2, 3), "Axis must be 2 (H) or 3 (W)" + out_size = idxs.shape[0] + k = idxs.shape[1] # Typically 4 for cubic + + # Clip to input bounds + idxs = jnp.clip(idxs, 0, img.shape[axis] - 1) # (out_size, 4) + + def gather_and_weight(i): + idx = idxs[i] # (4,) + w = weights[i] # (4,) + + def gather_one(offset): + return jnp.take(img, idx[offset], axis=axis) # shape (B, C, H, W) + + gathered = jnp.stack([gather_one(o) for o in range(k)], axis=0) # (4, B, C, H, W) + weighted = jnp.tensordot(w, gathered, axes=(0, 0)) # (B, C, H, W) + return weighted + + out = jax.vmap(gather_and_weight)(jnp.arange(out_size)) # (out_size, B, C, H, W) + + # Move the interpolated axis back into place + if axis == 2: # interpolated over H + return jnp.moveaxis(out, 0, 2) # (B, C, out_H, W) + else: # axis == 3, interpolated over W + return jnp.moveaxis(out, 0, 3) # (B, C, H, out_W) + + + +def interpolate_bicubic_no_aa(img, out_h, out_w, align_corners=False): + h, w = img.shape[-2:] + if align_corners and out_h > 1: + scale_y = (h - 1) / (out_h - 1) + else: + scale_y = out_h / h + + if align_corners and out_w > 1: + scale_x = (w - 1) / (out_w - 1) + else: + scale_x = out_w / w + + idxs_y, weights_y = compute_contribs( + h, out_h, scale_y, align_corners=align_corners, + ) + tmp = interpolate_along_axis_bchw(img, idxs_y, weights_y, axis=2) + + idxs_x, weights_x = compute_contribs( + w, out_w, scale_x, align_corners=align_corners, + ) + out = interpolate_along_axis_bchw(tmp, idxs_x, weights_x, axis=3) + return out diff --git a/torchax/torchax/ops/jtorch.py b/torchax/torchax/ops/jtorch.py index f03e5cbf7a0..5b79a089535 100644 --- a/torchax/torchax/ops/jtorch.py +++ b/torchax/torchax/ops/jtorch.py @@ -3,7 +3,7 @@ import math import collections.abc import functools -from typing import Optional, Sequence +from typing import Optional, Sequence, Tuple import numpy as np import jax @@ -13,7 +13,7 @@ import torch from torchax.ops.ops_registry import register_torch_function_op -from torchax.ops import op_base, mappings, jaten +from torchax.ops import op_base, mappings, jaten, jimage import torchax.tensor from torchax.view import View, NarrowInfo import torch.utils._pytree as pytree @@ -512,3 +512,51 @@ def functional_linear(self, weights, bias=None): if bias is not None: res += bias return res + + + +@register_function(torch.nn.functional.interpolate) +def functional_interpolate( + input, + size: Tuple[int, int], + scale_factor: Optional[float], + mode: str, + align_corners: bool, + recompute_scale_factor: bool, + antialias: bool, +): + supported_methods = ( + "nearest", + "linear", + "bilinear", + "trilinear", + "cubic", + "bicubic", + "tricubic", + "lanczos3", + "lanczos5", + ) + is_jax_supported = mode in supported_methods + if not is_jax_supported: + raise torchax.tensor.OperatorNotFound( + f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}" + ) + # None check + antialias = antialias or False + align_corners = align_corners or False + + if mode in ('cubic', 'bicubic', 'tricubic') and not antialias: + return jimage.interpolate_bicubic_no_aa( + input, + size[0], + size[1], + align_corners, + ) + return jaten._aten_upsample( + input, + size, + align_corners, + antialias, + mode, + scale_factor, + ) \ No newline at end of file