diff --git a/docs/index.html b/docs/index.html index 327a69f..175d5cf 100644 --- a/docs/index.html +++ b/docs/index.html @@ -3,8 +3,8 @@ - - Module List – pdoc 14.6.1 + + Module List – pdoc 14.7.0 diff --git a/docs/sony_custom_layers/keras.html b/docs/sony_custom_layers/keras.html index 20e53bd..4170c5f 100644 --- a/docs/sony_custom_layers/keras.html +++ b/docs/sony_custom_layers/keras.html @@ -3,7 +3,7 @@ - + sony_custom_layers.keras API documentation diff --git a/docs/sony_custom_layers/pytorch.html b/docs/sony_custom_layers/pytorch.html index 7bcd0f3..c43389f 100644 --- a/docs/sony_custom_layers/pytorch.html +++ b/docs/sony_custom_layers/pytorch.html @@ -3,7 +3,7 @@ - + sony_custom_layers.pytorch API documentation @@ -98,6 +98,18 @@

API Documentation

+ +
  • + FasterRCNNBoxDecode + +
  • load_custom_ops @@ -146,49 +158,53 @@

    21if TYPE_CHECKING: 22 import onnxruntime as ort 23 -24__all__ = ['multiclass_nms', 'NMSResults', 'multiclass_nms_with_indices', 'NMSWithIndicesResults', 'load_custom_ops'] -25 -26validate_installed_libraries(required_libraries['torch']) -27 -28from .object_detection import multiclass_nms, NMSResults # noqa: E402 -29from .object_detection import multiclass_nms_with_indices, NMSWithIndicesResults # noqa: E402 -30 -31 -32def load_custom_ops(ort_session_ops: Optional['ort.SessionOptions'] = None) -> 'ort.SessionOptions': -33 """ -34 Registers the custom ops implementation for onnxruntime, and sets up the SessionOptions object for onnxruntime -35 session. -36 -37 Args: -38 ort_session_ops: SessionOptions object to register the custom ops library on. If None, creates a new object. +24__all__ = [ +25 'multiclass_nms', 'NMSResults', 'multiclass_nms_with_indices', 'NMSWithIndicesResults', 'FasterRCNNBoxDecode', +26 'load_custom_ops' +27] +28 +29validate_installed_libraries(required_libraries['torch']) +30from sony_custom_layers.pytorch.nms import ( # noqa: E402 +31 multiclass_nms, NMSResults, multiclass_nms_with_indices, NMSWithIndicesResults) +32from sony_custom_layers.pytorch.box_decode import FasterRCNNBoxDecode # noqa: E402 +33 +34 +35def load_custom_ops(ort_session_ops: Optional['ort.SessionOptions'] = None) -> 'ort.SessionOptions': +36 """ +37 Registers the custom ops implementation for onnxruntime, and sets up the SessionOptions object for onnxruntime +38 session. 39 -40 Returns: -41 SessionOptions object with registered custom ops. +40 Args: +41 ort_session_ops: SessionOptions object to register the custom ops library on. If None, creates a new object. 42 -43 Example: -44 ``` -45 import onnxruntime as ort -46 from sony_custom_layers.pytorch import load_custom_ops -47 -48 so = load_custom_ops() -49 session = ort.InferenceSession(model_path, sess_options=so) -50 session.run(...) 
-51 ``` -52 You can also pass your own SessionOptions object upon which to register the custom ops -53 ``` -54 load_custom_ops(ort_session_options=so) -55 ``` -56 """ -57 validate_installed_libraries(required_libraries['torch_ort']) -58 -59 # trigger onnxruntime op registration -60 from .object_detection import nms_ort +43 Returns: +44 SessionOptions object with registered custom ops. +45 +46 Example: +47 ``` +48 import onnxruntime as ort +49 from sony_custom_layers.pytorch import load_custom_ops +50 +51 so = load_custom_ops() +52 session = ort.InferenceSession(model_path, sess_options=so) +53 session.run(...) +54 ``` +55 You can also pass your own SessionOptions object upon which to register the custom ops +56 ``` +57 load_custom_ops(ort_session_ops=so) +58 ``` +59 """ +60 validate_installed_libraries(required_libraries['torch_ort']) 61 -62 from onnxruntime_extensions import get_library_path -63 from onnxruntime import SessionOptions -64 ort_session_ops = ort_session_ops or SessionOptions() -65 ort_session_ops.register_custom_ops_library(get_library_path()) -66 return ort_session_ops +62 # trigger onnxruntime op registration +63 from .nms import nms_ort +64 from .box_decode import box_decode_ort +65 +66 from onnxruntime_extensions import get_library_path +67 from onnxruntime import SessionOptions +68 ort_session_ops = ort_session_ops or SessionOptions() +69 ort_session_ops.register_custom_ops_library(get_library_path()) +70 return ort_session_ops @@ -779,6 +795,175 @@

    Example:
    + + +
    + +
    + + class + FasterRCNNBoxDecode(torch.nn.modules.module.Module): + + + +
    + +
    30class FasterRCNNBoxDecode(nn.Module):
    +31    """
    +32    Box decoding as per Faster R-CNN <https://arxiv.org/abs/1506.01497>.
    +33
    +34    Args:
    +35        anchors: Anchors with a shape of (n_boxes, 4) in corner coordinates (y_min, x_min, y_max, x_max).
    +36        scale_factors: Scaling factors in the format (y, x, height, width).
    +37        clip_window: Clipping window in the format (y_min, x_min, y_max, x_max).
    +38
    +39    Inputs:
    +40        **rel_codes** (Tensor): Relative codes (encoded offsets) with a shape of (batch, n_boxes, 4) in centroid
    +41                                coordinates (y_center, x_center, h, w).
    +42
    +43    Returns:
    +44        Decoded boxes with a shape of (batch, n_boxes, 4) in corner coordinates (y_min, x_min, y_max, x_max).
    +45
    +46    Raises:
    +47        ValueError: If provided with invalid arguments or an input tensor with unexpected shape
    +48
    +49    Example:
    +50        ```
    +51        from sony_custom_layers.pytorch import FasterRCNNBoxDecode
    +52
    +53        box_decode = FasterRCNNBoxDecode(anchors,
    +54                                         scale_factors=(10, 10, 5, 5),
    +55                                         clip_window=(0, 0, 1, 1))
    +56        decoded_boxes = box_decode(rel_codes)
    +57        ```
    +58    """
    +59
    +60    def __init__(self, anchors: torch.Tensor, scale_factors: Sequence[Union[float, int]],
    +61                 clip_window: Sequence[Union[float, int]]):
    +62        super().__init__()
    +63        if not (len(anchors.shape) == 2 and anchors.shape[-1] == 4):
    +64            raise ValueError(f'Invalid anchors shape {anchors.shape}. Expected shape (n_boxes, 4).')
    +65        self.register_buffer('anchors', anchors)
    +66
    +67        if len(scale_factors) != 4:
    +68            raise ValueError(f'Invalid scale factors {scale_factors}. Expected 4 values for (y, x, height, width).')
    +69        self.register_buffer('scale_factors', torch.tensor(scale_factors, dtype=torch.float32, device=anchors.device))
    +70
    +71        if len(clip_window) != 4:
    +72            raise ValueError(f'Invalid clip window {clip_window}. Expected 4 values for (y_min, x_min, y_max, x_max).')
    +73        self.register_buffer('clip_window', torch.tensor(clip_window, dtype=torch.float32, device=anchors.device))
    +74
    +75    def forward(self, rel_codes: torch.Tensor) -> torch.Tensor:
    +76        return torch.ops.sony.faster_rcnn_box_decode(rel_codes, self.anchors, self.scale_factors, self.clip_window)
    +
    + + +

    Box decoding as per Faster R-CNN https://arxiv.org/abs/1506.01497.

    + +
    Arguments:
    + +
      +
    • anchors: Anchors with a shape of (n_boxes, 4) in corner coordinates (y_min, x_min, y_max, x_max).
    • +
    • scale_factors: Scaling factors in the format (y, x, height, width).
    • +
    • clip_window: Clipping window in the format (y_min, x_min, y_max, x_max).
    • +
    + +
    Inputs:
    + +
    +

    rel_codes (Tensor): Relative codes (encoded offsets) with a shape of (batch, n_boxes, 4) in centroid + coordinates (y_center, x_center, h, w).

    +
    + +
    Returns:
    + +
    +

    Decoded boxes with a shape of (batch, n_boxes, 4) in corner coordinates (y_min, x_min, y_max, x_max).

    +
    + +
    Raises:
    + +
      +
    • ValueError: If provided with invalid arguments or an input tensor with unexpected shape
    • +
    + +
    Example:
    + +
    +
    from sony_custom_layers.pytorch import FasterRCNNBoxDecode
    +
    +box_decode = FasterRCNNBoxDecode(anchors,
    +                                 scale_factors=(10, 10, 5, 5),
    +                                 clip_window=(0, 0, 1, 1))
    +decoded_boxes = box_decode(rel_codes)
    +
    +
    +
    + + +
    + +
    + + FasterRCNNBoxDecode( anchors: torch.Tensor, scale_factors: Sequence[Union[float, int]], clip_window: Sequence[Union[float, int]]) + + + +
    + +
    60    def __init__(self, anchors: torch.Tensor, scale_factors: Sequence[Union[float, int]],
    +61                 clip_window: Sequence[Union[float, int]]):
    +62        super().__init__()
    +63        if not (len(anchors.shape) == 2 and anchors.shape[-1] == 4):
    +64            raise ValueError(f'Invalid anchors shape {anchors.shape}. Expected shape (n_boxes, 4).')
    +65        self.register_buffer('anchors', anchors)
    +66
    +67        if len(scale_factors) != 4:
    +68            raise ValueError(f'Invalid scale factors {scale_factors}. Expected 4 values for (y, x, height, width).')
    +69        self.register_buffer('scale_factors', torch.tensor(scale_factors, dtype=torch.float32, device=anchors.device))
    +70
    +71        if len(clip_window) != 4:
    +72            raise ValueError(f'Invalid clip window {clip_window}. Expected 4 values for (y_min, x_min, y_max, x_max).')
    +73        self.register_buffer('clip_window', torch.tensor(clip_window, dtype=torch.float32, device=anchors.device))
    +
    + + +

    Initialize internal Module state, shared by both nn.Module and ScriptModule.

    +
    + + +
    +
    + +
    + + def + forward(self, rel_codes: torch.Tensor) -> torch.Tensor: + + + +
    + +
    75    def forward(self, rel_codes: torch.Tensor) -> torch.Tensor:
    +76        return torch.ops.sony.faster_rcnn_box_decode(rel_codes, self.anchors, self.scale_factors, self.clip_window)
    +
    + + +

    Define the computation performed at every call.

    + +

    Should be overridden by all subclasses.

    + +
    + +

    Although the recipe for forward pass needs to be defined within +this function, one should call the Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

    + +
    +
    + +
    @@ -792,41 +977,42 @@
    Example:
    -
    33def load_custom_ops(ort_session_ops: Optional['ort.SessionOptions'] = None) -> 'ort.SessionOptions':
    -34    """
    -35    Registers the custom ops implementation for onnxruntime, and sets up the SessionOptions object for onnxruntime
    -36    session.
    -37
    -38    Args:
    -39        ort_session_ops: SessionOptions object to register the custom ops library on. If None, creates a new object.
    +            
    36def load_custom_ops(ort_session_ops: Optional['ort.SessionOptions'] = None) -> 'ort.SessionOptions':
    +37    """
    +38    Registers the custom ops implementation for onnxruntime, and sets up the SessionOptions object for onnxruntime
    +39    session.
     40
    -41    Returns:
    -42        SessionOptions object with registered custom ops.
    +41    Args:
    +42        ort_session_ops: SessionOptions object to register the custom ops library on. If None, creates a new object.
     43
    -44    Example:
    -45        ```
    -46        import onnxruntime as ort
    -47        from sony_custom_layers.pytorch import load_custom_ops
    -48
    -49        so = load_custom_ops()
    -50        session = ort.InferenceSession(model_path, sess_options=so)
    -51        session.run(...)
    -52        ```
    -53        You can also pass your own SessionOptions object upon which to register the custom ops
    -54        ```
    -55        load_custom_ops(ort_session_options=so)
    -56        ```
    -57    """
    -58    validate_installed_libraries(required_libraries['torch_ort'])
    -59
    -60    # trigger onnxruntime op registration
    -61    from .object_detection import nms_ort
    +44    Returns:
    +45        SessionOptions object with registered custom ops.
    +46
    +47    Example:
    +48        ```
    +49        import onnxruntime as ort
    +50        from sony_custom_layers.pytorch import load_custom_ops
    +51
    +52        so = load_custom_ops()
    +53        session = ort.InferenceSession(model_path, sess_options=so)
    +54        session.run(...)
    +55        ```
    +56        You can also pass your own SessionOptions object upon which to register the custom ops
    +57        ```
    +58        load_custom_ops(ort_session_ops=so)
    +59        ```
    +60    """
    +61    validate_installed_libraries(required_libraries['torch_ort'])
     62
    -63    from onnxruntime_extensions import get_library_path
    -64    from onnxruntime import SessionOptions
    -65    ort_session_ops = ort_session_ops or SessionOptions()
    -66    ort_session_ops.register_custom_ops_library(get_library_path())
    -67    return ort_session_ops
    +63    # trigger onnxruntime op registration
    +64    from .nms import nms_ort
    +65    from .box_decode import box_decode_ort
    +66
    +67    from onnxruntime_extensions import get_library_path
    +68    from onnxruntime import SessionOptions
    +69    ort_session_ops = ort_session_ops or SessionOptions()
    +70    ort_session_ops.register_custom_ops_library(get_library_path())
    +71    return ort_session_ops
     
    diff --git a/sony_custom_layers/pytorch/tests/object_detection/__init__.py b/sony_custom_layers/common/__init__.py similarity index 100% rename from sony_custom_layers/pytorch/tests/object_detection/__init__.py rename to sony_custom_layers/common/__init__.py diff --git a/sony_custom_layers/keras/object_detection/box_utils.py b/sony_custom_layers/common/box_util.py similarity index 95% rename from sony_custom_layers/keras/object_detection/box_utils.py rename to sony_custom_layers/common/box_util.py index 917c818..18553d0 100644 --- a/sony_custom_layers/keras/object_detection/box_utils.py +++ b/sony_custom_layers/common/box_util.py @@ -1,5 +1,5 @@ # ----------------------------------------------------------------------------- -# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ----------------------------------------------------------------------------- - from typing import Tuple diff --git a/sony_custom_layers/keras/object_detection/faster_rcnn_box_decode.py b/sony_custom_layers/keras/object_detection/faster_rcnn_box_decode.py index a4fbac0..652f37b 100644 --- a/sony_custom_layers/keras/object_detection/faster_rcnn_box_decode.py +++ b/sony_custom_layers/keras/object_detection/faster_rcnn_box_decode.py @@ -19,8 +19,8 @@ import tensorflow as tf import numpy as np +from sony_custom_layers.common.box_util import corners_to_centroids, centroids_to_corners from sony_custom_layers.keras.base_custom_layer import CustomLayer -from sony_custom_layers.keras.object_detection.box_utils import corners_to_centroids, centroids_to_corners from sony_custom_layers.keras.custom_objects import register_layer diff --git a/sony_custom_layers/pytorch/__init__.py b/sony_custom_layers/pytorch/__init__.py index 0c172f0..56067f3 100644 --- a/sony_custom_layers/pytorch/__init__.py +++ b/sony_custom_layers/pytorch/__init__.py @@ -21,12 +21,15 @@ if TYPE_CHECKING: import onnxruntime as ort -__all__ = ['multiclass_nms', 'NMSResults', 'multiclass_nms_with_indices', 'NMSWithIndicesResults', 'load_custom_ops'] +__all__ = [ + 'multiclass_nms', 'NMSResults', 'multiclass_nms_with_indices', 'NMSWithIndicesResults', 'FasterRCNNBoxDecode', + 'load_custom_ops' +] validate_installed_libraries(required_libraries['torch']) - -from .object_detection import multiclass_nms, NMSResults # noqa: E402 -from .object_detection import multiclass_nms_with_indices, NMSWithIndicesResults # noqa: E402 +from sony_custom_layers.pytorch.nms import ( # noqa: E402 + multiclass_nms, NMSResults, multiclass_nms_with_indices, NMSWithIndicesResults) +from sony_custom_layers.pytorch.box_decode import FasterRCNNBoxDecode # noqa: E402 def load_custom_ops(ort_session_ops: Optional['ort.SessionOptions'] = None) -> 'ort.SessionOptions': @@ -57,7 +60,8 @@ def load_custom_ops(ort_session_ops: 
Optional['ort.SessionOptions'] = None) -> ' validate_installed_libraries(required_libraries['torch_ort']) # trigger onnxruntime op registration - from .object_detection import nms_ort + from .nms import nms_ort + from .box_decode import box_decode_ort from onnxruntime_extensions import get_library_path from onnxruntime import SessionOptions diff --git a/sony_custom_layers/pytorch/box_decode/__init__.py b/sony_custom_layers/pytorch/box_decode/__init__.py new file mode 100644 index 0000000..df98b28 --- /dev/null +++ b/sony_custom_layers/pytorch/box_decode/__init__.py @@ -0,0 +1,21 @@ +# ----------------------------------------------------------------------------- +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ----------------------------------------------------------------------------- +from .box_decode import FasterRCNNBoxDecode + +# trigger onnx op registration +from . import box_decode_onnx + +__all__ = ['FasterRCNNBoxDecode'] diff --git a/sony_custom_layers/pytorch/box_decode/box_decode.py b/sony_custom_layers/pytorch/box_decode/box_decode.py new file mode 100644 index 0000000..3246c30 --- /dev/null +++ b/sony_custom_layers/pytorch/box_decode/box_decode.py @@ -0,0 +1,114 @@ +# ----------------------------------------------------------------------------- +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ----------------------------------------------------------------------------- +from typing import Union, Sequence + +import torch +from torch import nn + +from sony_custom_layers.common.box_util import corners_to_centroids, centroids_to_corners +from sony_custom_layers.pytorch.custom_lib import register_op + +BOX_DECODE_TORCH_OP = 'faster_rcnn_box_decode' + +__all__ = ['FasterRCNNBoxDecode'] + + +class FasterRCNNBoxDecode(nn.Module): + """ + Box decoding as per Faster R-CNN <https://arxiv.org/abs/1506.01497>. + + Args: + anchors: Anchors with a shape of (n_boxes, 4) in corner coordinates (y_min, x_min, y_max, x_max). + scale_factors: Scaling factors in the format (y, x, height, width). + clip_window: Clipping window in the format (y_min, x_min, y_max, x_max). + + Inputs: + **rel_codes** (Tensor): Relative codes (encoded offsets) with a shape of (batch, n_boxes, 4) in centroid + coordinates (y_center, x_center, h, w). + + Returns: + Decoded boxes with a shape of (batch, n_boxes, 4) in corner coordinates (y_min, x_min, y_max, x_max). 
+ + Raises: + ValueError: If provided with invalid arguments or an input tensor with unexpected shape + + Example: + ``` + from sony_custom_layers.pytorch import FasterRCNNBoxDecode + + box_decode = FasterRCNNBoxDecode(anchors, + scale_factors=(10, 10, 5, 5), + clip_window=(0, 0, 1, 1)) + decoded_boxes = box_decode(rel_codes) + ``` + """ + + def __init__(self, anchors: torch.Tensor, scale_factors: Sequence[Union[float, int]], + clip_window: Sequence[Union[float, int]]): + super().__init__() + if not (len(anchors.shape) == 2 and anchors.shape[-1] == 4): + raise ValueError(f'Invalid anchors shape {anchors.shape}. Expected shape (n_boxes, 4).') + self.register_buffer('anchors', anchors) + + if len(scale_factors) != 4: + raise ValueError(f'Invalid scale factors {scale_factors}. Expected 4 values for (y, x, height, width).') + self.register_buffer('scale_factors', torch.tensor(scale_factors, dtype=torch.float32, device=anchors.device)) + + if len(clip_window) != 4: + raise ValueError(f'Invalid clip window {clip_window}. Expected 4 values for (y_min, x_min, y_max, x_max).') + self.register_buffer('clip_window', torch.tensor(clip_window, dtype=torch.float32, device=anchors.device)) + + def forward(self, rel_codes: torch.Tensor) -> torch.Tensor: + return torch.ops.sony.faster_rcnn_box_decode(rel_codes, self.anchors, self.scale_factors, self.clip_window) + + +###################### +# Register custom op # +###################### + + +def _faster_rcnn_box_decode_impl(rel_codes: torch.Tensor, anchors: torch.Tensor, scale_factors: torch.Tensor, + clip_window: torch.Tensor) -> torch.Tensor: + """ This implementation is intended only to be registered as custom torch and onnxruntime op. """ + if len(rel_codes.shape) != 3 or rel_codes.shape[-1] != 4: + raise ValueError(f'Invalid input tensor shape {rel_codes.shape}. 
Expected shape (batch, n_boxes, 4).') + + if rel_codes.shape[-2] != anchors.shape[-2]: + raise ValueError(f'Mismatch in the number of boxes between input tensor ({rel_codes.shape[-2]}) ' + f'and anchors ({anchors.shape[-2]})') + + scaled_codes = rel_codes / scale_factors + + a_y_min, a_x_min, a_y_max, a_x_max = torch.unbind(anchors, dim=-1) + a_y_center, a_x_center, a_h, a_w = corners_to_centroids(a_y_min, a_x_min, a_y_max, a_x_max) + + box_y_center = scaled_codes[..., 0] * a_h + a_y_center + box_x_center = scaled_codes[..., 1] * a_w + a_x_center + box_h = torch.exp(scaled_codes[..., 2]) * a_h + box_w = torch.exp(scaled_codes[..., 3]) * a_w + box_y_min, box_x_min, box_y_max, box_x_max = centroids_to_corners(box_y_center, box_x_center, box_h, box_w) + boxes = torch.stack([box_y_min, box_x_min, box_y_max, box_x_max], dim=-1) + + y_low, x_low, y_high, x_high = clip_window + boxes = torch.clip(boxes, torch.tensor([y_low, x_low, y_low, x_low], device=rel_codes.device), + torch.tensor([y_high, x_high, y_high, x_high], device=rel_codes.device)) + return boxes + + +schema = (BOX_DECODE_TORCH_OP + + "(Tensor rel_codes, Tensor anchors, Tensor scale_factors, Tensor clip_window) -> Tensor") + +register_op(BOX_DECODE_TORCH_OP, schema, _faster_rcnn_box_decode_impl) diff --git a/sony_custom_layers/pytorch/box_decode/box_decode_onnx.py b/sony_custom_layers/pytorch/box_decode/box_decode_onnx.py new file mode 100644 index 0000000..dc0f269 --- /dev/null +++ b/sony_custom_layers/pytorch/box_decode/box_decode_onnx.py @@ -0,0 +1,32 @@ +# ----------------------------------------------------------------------------- +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ----------------------------------------------------------------------------- +import torch + +from sony_custom_layers.pytorch.box_decode.box_decode import BOX_DECODE_TORCH_OP +from sony_custom_layers.pytorch.custom_lib import get_op_qualname + +BOX_DECODE_ONNX_OP = "Sony::FasterRCNNBoxDecode" + + +@torch.onnx.symbolic_helper.parse_args('v', 'v', 'v', 'v') +def box_decode_onnx(g, rel_codes, anchors, scale_factors, clip_window): + outputs = g.op(BOX_DECODE_ONNX_OP, rel_codes, anchors, scale_factors, clip_window, outputs=1) + # Set output tensors shape and dtype + outputs.setType(rel_codes.type()) + return outputs + + +torch.onnx.register_custom_op_symbolic(get_op_qualname(BOX_DECODE_TORCH_OP), box_decode_onnx, opset_version=1) diff --git a/sony_custom_layers/pytorch/box_decode/box_decode_ort.py b/sony_custom_layers/pytorch/box_decode/box_decode_ort.py new file mode 100644 index 0000000..2298531 --- /dev/null +++ b/sony_custom_layers/pytorch/box_decode/box_decode_ort.py @@ -0,0 +1,28 @@ +# ----------------------------------------------------------------------------- +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ----------------------------------------------------------------------------- +import torch +from onnxruntime_extensions import onnx_op, PyCustomOpDef + +from .box_decode import _faster_rcnn_box_decode_impl +from .box_decode_onnx import BOX_DECODE_ONNX_OP + + +@onnx_op(op_type=BOX_DECODE_ONNX_OP, + inputs=[PyCustomOpDef.dt_float, PyCustomOpDef.dt_float, PyCustomOpDef.dt_float, PyCustomOpDef.dt_float], + outputs=[PyCustomOpDef.dt_float]) +def box_decode_ort(rel_codes, anchors, scale_factors, clip_window): + return _faster_rcnn_box_decode_impl(torch.as_tensor(rel_codes), torch.as_tensor(anchors), + torch.as_tensor(scale_factors), torch.as_tensor(clip_window)) diff --git a/sony_custom_layers/pytorch/object_detection/__init__.py b/sony_custom_layers/pytorch/nms/__init__.py similarity index 99% rename from sony_custom_layers/pytorch/object_detection/__init__.py rename to sony_custom_layers/pytorch/nms/__init__.py index f7af0c5..023dd7a 100644 --- a/sony_custom_layers/pytorch/object_detection/__init__.py +++ b/sony_custom_layers/pytorch/nms/__init__.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ----------------------------------------------------------------------------- - from .nms import multiclass_nms, NMSResults from .nms_with_indices import multiclass_nms_with_indices, NMSWithIndicesResults diff --git a/sony_custom_layers/pytorch/object_detection/nms.py b/sony_custom_layers/pytorch/nms/nms.py similarity index 98% rename from sony_custom_layers/pytorch/object_detection/nms.py rename to sony_custom_layers/pytorch/nms/nms.py index d4f186a..09cdcec 100644 --- a/sony_custom_layers/pytorch/object_detection/nms.py +++ b/sony_custom_layers/pytorch/nms/nms.py @@ -20,7 +20,7 @@ import torchvision # noqa: F401 # needed for torch.ops.torchvision from sony_custom_layers.pytorch.custom_lib import register_op -from sony_custom_layers.pytorch.object_detection.nms_common import _batch_multiclass_nms, SCORES, LABELS +from .nms_common import _batch_multiclass_nms, SCORES, LABELS MULTICLASS_NMS_TORCH_OP = 'multiclass_nms' diff --git a/sony_custom_layers/pytorch/object_detection/nms_common.py b/sony_custom_layers/pytorch/nms/nms_common.py similarity index 100% rename from sony_custom_layers/pytorch/object_detection/nms_common.py rename to sony_custom_layers/pytorch/nms/nms_common.py diff --git a/sony_custom_layers/pytorch/object_detection/nms_onnx.py b/sony_custom_layers/pytorch/nms/nms_onnx.py similarity index 98% rename from sony_custom_layers/pytorch/object_detection/nms_onnx.py rename to sony_custom_layers/pytorch/nms/nms_onnx.py index 25ad590..b52c5b4 100644 --- a/sony_custom_layers/pytorch/object_detection/nms_onnx.py +++ b/sony_custom_layers/pytorch/nms/nms_onnx.py @@ -15,9 +15,9 @@ # ----------------------------------------------------------------------------- import torch +from sony_custom_layers.pytorch.custom_lib import get_op_qualname from .nms import MULTICLASS_NMS_TORCH_OP from .nms_with_indices import MULTICLASS_NMS_WITH_INDICES_TORCH_OP -from ..custom_lib import get_op_qualname MULTICLASS_NMS_ONNX_OP = "Sony::MultiClassNMS" 
MULTICLASS_NMS_WITH_INDICES_ONNX_OP = "Sony::MultiClassNMSWithIndices" diff --git a/sony_custom_layers/pytorch/object_detection/nms_ort.py b/sony_custom_layers/pytorch/nms/nms_ort.py similarity index 100% rename from sony_custom_layers/pytorch/object_detection/nms_ort.py rename to sony_custom_layers/pytorch/nms/nms_ort.py diff --git a/sony_custom_layers/pytorch/object_detection/nms_with_indices.py b/sony_custom_layers/pytorch/nms/nms_with_indices.py similarity index 98% rename from sony_custom_layers/pytorch/object_detection/nms_with_indices.py rename to sony_custom_layers/pytorch/nms/nms_with_indices.py index 2b52b3f..9901fef 100644 --- a/sony_custom_layers/pytorch/object_detection/nms_with_indices.py +++ b/sony_custom_layers/pytorch/nms/nms_with_indices.py @@ -19,7 +19,7 @@ from torch import Tensor from sony_custom_layers.pytorch.custom_lib import register_op -from sony_custom_layers.pytorch.object_detection.nms_common import _batch_multiclass_nms, SCORES, LABELS, INDICES +from .nms_common import _batch_multiclass_nms, SCORES, LABELS, INDICES __all__ = ['multiclass_nms_with_indices', 'NMSWithIndicesResults'] diff --git a/sony_custom_layers/pytorch/tests/test_box_decode.py b/sony_custom_layers/pytorch/tests/test_box_decode.py new file mode 100644 index 0000000..4b7b3b3 --- /dev/null +++ b/sony_custom_layers/pytorch/tests/test_box_decode.py @@ -0,0 +1,224 @@ +# ----------------------------------------------------------------------------- +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ----------------------------------------------------------------------------- +from typing import Optional + +import numpy as np +import onnx.helper +import pytest +import torch +import onnxruntime as ort +from sony_custom_layers.util.test_util import exec_in_clean_process + +from sony_custom_layers.pytorch import FasterRCNNBoxDecode, load_custom_ops +from sony_custom_layers.pytorch.tests.util import load_and_validate_onnx_model + + +class TestBoxDecode: + + def test_zero_offsets(self): + n_boxes = 100 + anchors = self._generate_random_anchors(n_boxes, seed=1) + model = FasterRCNNBoxDecode(anchors=anchors, scale_factors=(1, 2, 3, 4), clip_window=(0, 0, 1, 1)) + out = model(torch.zeros((2, n_boxes, 4), dtype=torch.float32)) + assert torch.allclose(out, anchors) + + @pytest.mark.parametrize('scale_factors', [(1., 1., 1., 1.), (1, 2, 3, 4), (1.1, 2.2, 0.5, 3.3)]) + @pytest.mark.parametrize('cuda', [False, True]) + def test_box_decode(self, scale_factors, cuda): + if cuda and not torch.cuda.is_available(): + pytest.skip('cuda is not available') + + n_boxes = 100 + anchors = self._generate_random_anchors(n_boxes, img_size=(100, 200), seed=1) + + v0, v1, v2, v3 = .5, 1., .2, 1.2 + offsets = torch.empty((2, n_boxes, 4), dtype=torch.float32) + # we define encoded offsets that will yield boxes such that: + # np.log(boxes_h / anchors_h) = v2 + # np.log(boxes_w / anchors_w) = v3 + # (boxes_center_y - anchors_center_y) / anchors_h = v0 + # (boxes_center_x - anchors_center_x) / anchors_w = v1 + offsets[:, :, 0] = v0 * scale_factors[0] + offsets[:, :, 1] = v1 * scale_factors[1] + offsets[:, :, 2] = v2 * scale_factors[2] + offsets[:, :, 3] = v3 * scale_factors[3] + + # disable clipping + model = FasterRCNNBoxDecode(anchors, scale_factors, clip_window=(-1000, -1000, 1000, 1000)) + if cuda: + model = model.cuda() + offsets = offsets.cuda() + boxes = model(offsets) + + boxes = 
boxes.cpu() + boxes_hw = boxes[..., 2:] - boxes[..., :2] + anchors_hw = anchors[..., 2:] - anchors[..., :2] + assert torch.allclose(boxes_hw[..., 0] / anchors_hw[..., 0], torch.exp(torch.as_tensor(v2))) + assert torch.allclose(boxes_hw[..., 1] / anchors_hw[..., 1], torch.exp(torch.as_tensor(v3))) + boxes_center = boxes[..., :2] + 0.5 * boxes_hw + anchors_center = anchors[..., :2] + 0.5 * anchors_hw + t = (boxes_center - anchors_center) / anchors_hw + assert torch.allclose(t[..., 0], torch.as_tensor(v0), atol=1e-5) + assert torch.allclose(t[..., 1], torch.as_tensor(v1), atol=1e-5) + + @pytest.mark.parametrize('clip_window, normalize', [((-4, 1, 90, 110), False), ((-.04, .01, .9, 1.1), True)]) + def test_clipping(self, clip_window, normalize): + scale_factors = (1, 2, 3, 4) + n_boxes = 3 + anchors = self._generate_random_anchors(n_anchors=n_boxes, seed=1) + mul = 0.01 if normalize else 1 + # (2, n_boxes, 4) + boxes = mul * torch.as_tensor( + [ + [ + [-5, 5, 1, 12], # clip y_min + [85, -4, 90, 2], # clip x_min + [85, 95, 95, 100] # clip y_max + ], + [ + [0, 85, 2, 115], # clip x_max + [-10, 115, -5, 120], # y_min, y_max < 0, x_min, x_max > x_size + [95, -10, 100, -5] # y_min, y_max > y_size, x_min, x_max < 0 + ] + ], + dtype=torch.float32) + + rel_codes = self._encode_offsets(boxes, anchors, scale_factors=scale_factors) + model = FasterRCNNBoxDecode(anchors, scale_factors=scale_factors, clip_window=clip_window) + out = model(rel_codes) + exp_boxes = mul * torch.as_tensor( + [ + [ + [-4, 5, 1, 12], # clip y_min + [85, 1, 90, 2], # clip x_min + [85, 95, 90, 100] # clip y_max + ], + [ + [0, 85, 2, 110], # clip x_max + [-4, 110, -4, 110], # y_min, y_max < 0, x_min, x_max > x_size + [90, 1, 90, 1] # y_min, y_max > y_size, x_min, x_max < 0 + ] + ], + dtype=torch.float32) + assert torch.allclose(out, exp_boxes, atol=1e-6) + + @pytest.mark.parametrize('dynamic_batch', [True, False]) + @pytest.mark.parametrize('scale_factors, clip_window', [ + [(1, 2, 3, 4), (0.1, 0.2, 0.9, 
1.2)], + [(1.1, 2.2, 3.3, 0.5), (10, 20, 30, 40)], + ]) + def test_onnx_export(self, dynamic_batch, scale_factors, clip_window, tmp_path): + n_boxes = 1000 + anchors = self._generate_random_anchors(n_anchors=n_boxes, seed=1) + model = FasterRCNNBoxDecode(anchors, scale_factors=scale_factors, clip_window=clip_window) + path = str(tmp_path / 'box_decode.onnx') + self._export_onnx(model, n_boxes, path, dynamic_batch=dynamic_batch) + + onnx_model = load_and_validate_onnx_model(path, exp_opset=1) + + [box_decode_node] = list(onnx_model.graph.node) + assert box_decode_node.domain == 'Sony' + assert box_decode_node.op_type == 'FasterRCNNBoxDecode' + assert len(box_decode_node.input) == 4 + assert len(box_decode_node.output) == 1 + # sanity check that we extracted the input nodes correctly + anchors_input, scale_factors_input, clip_window_input = list(onnx_model.graph.initializer) + assert box_decode_node.input[1] == anchors_input.name + assert box_decode_node.input[2] == scale_factors_input.name + assert box_decode_node.input[3] == clip_window_input.name + + def check_input(t, exp_tensor): + assert tuple(t.dims) == exp_tensor.shape + assert np.allclose(onnx.numpy_helper.to_array(t), exp_tensor) + + check_input(anchors_input, anchors) + check_input(scale_factors_input, np.array(scale_factors)) + check_input(clip_window_input, np.array(clip_window)) + + @pytest.mark.parametrize('dynamic_batch', [True, False]) + @pytest.mark.parametrize('scale_factors, clip_window', [ + [(1, 2, 3, 4), (0.1, 0.2, 0.9, 1.2)], + [(1.1, 2.2, 3.3, 0.5), (10, 20, 30, 40)], + ]) + def test_ort(self, dynamic_batch, scale_factors, clip_window, tmp_path): + n_boxes = 1000 + anchors = self._generate_random_anchors(n_anchors=n_boxes, seed=1) + model = FasterRCNNBoxDecode(anchors, scale_factors=scale_factors, clip_window=clip_window) + path = str(tmp_path / 'box_decode.onnx') + self._export_onnx(model, n_boxes, path, dynamic_batch=dynamic_batch) + + batch = 5 if dynamic_batch else 1 + boxes = 
self._generate_random_boxes(batch, n_boxes, seed=1) + rel_codes = self._encode_offsets(boxes, anchors, scale_factors) + + torch_res = model(rel_codes) + so = load_custom_ops() + + session = ort.InferenceSession(path, sess_options=so) + ort_res = session.run(output_names=None, input_feed={'rel_codes': rel_codes.numpy()}) + assert np.allclose(torch_res, ort_res[0]) + + # run in a new process + code = f""" +import onnxruntime as ort +import numpy as np +from sony_custom_layers.pytorch import load_custom_ops +so = ort.SessionOptions() +so = load_custom_ops(so) +session = ort.InferenceSession('{path}', so) +rel_codes = np.random.rand({batch}, {n_boxes}, 4).astype(np.float32) +ort_res = session.run(output_names=None, input_feed={{'rel_codes': rel_codes}}) +assert ort_res[0].max() and ort_res[0].max() > ort_res[0].min() + """ + exec_in_clean_process(code, check=True) + + @staticmethod + def _generate_random_boxes(n_batches, n_boxes, seed=None): + if seed: + np.random.seed(seed) + boxes = np.empty((n_batches, n_boxes, 4)) + boxes[..., :2] = np.random.uniform(low=0, high=.9, size=(n_batches, n_boxes, 2)) + boxes[..., 2:] = np.random.uniform(low=boxes[..., :2], high=1., size=(n_batches, n_boxes, 2)) + return torch.as_tensor(boxes, dtype=torch.float32) + + @classmethod + def _generate_random_anchors(cls, n_anchors, img_size: Optional[tuple] = None, seed=None): + anchors = cls._generate_random_boxes(1, n_anchors, seed)[0] + if img_size: + anchors = anchors * torch.tensor(img_size + img_size, dtype=torch.float32) + return anchors + + @staticmethod + def _encode_offsets(boxes, anchors, scale_factors): + anchors_hw = anchors[..., 2:] - anchors[..., :2] + boxes_hw = boxes[..., 2:] - boxes[..., :2] + boxes_center = boxes[..., :2] + boxes_hw / 2 + anchors_center = anchors[..., :2] + anchors_hw / 2 + thw = torch.log(boxes_hw / anchors_hw) + tyx = (boxes_center - anchors_center) / anchors_hw + t = torch.concat([tyx, thw], dim=-1) + return t * torch.as_tensor(scale_factors) + + def 
_export_onnx(self, model, n_boxes, path, dynamic_batch: bool): + input_names = ['rel_codes'] + output_names = ['decoded'] + kwargs = {'dynamic_axes': {k: {0: 'batch'} for k in input_names + output_names}} if dynamic_batch else {} + torch.onnx.export(model, + args=(torch.ones((1, n_boxes, 4))), + f=path, + input_names=input_names, + output_names=output_names, + **kwargs) diff --git a/sony_custom_layers/pytorch/tests/object_detection/test_multiclass_nms.py b/sony_custom_layers/pytorch/tests/test_multiclass_nms.py similarity index 88% rename from sony_custom_layers/pytorch/tests/object_detection/test_multiclass_nms.py rename to sony_custom_layers/pytorch/tests/test_multiclass_nms.py index 0f02979..dbab493 100644 --- a/sony_custom_layers/pytorch/tests/object_detection/test_multiclass_nms.py +++ b/sony_custom_layers/pytorch/tests/test_multiclass_nms.py @@ -18,13 +18,13 @@ import pytest import numpy as np import torch -import onnx import onnxruntime as ort from sony_custom_layers.pytorch import multiclass_nms, multiclass_nms_with_indices, NMSResults, NMSWithIndicesResults from sony_custom_layers.pytorch import load_custom_ops -from sony_custom_layers.pytorch.object_detection.nms_common import LABELS, INDICES, SCORES -from sony_custom_layers.pytorch.tests.object_detection.test_nms_common import generate_random_inputs +from sony_custom_layers.pytorch.nms.nms_common import LABELS, INDICES, SCORES +from sony_custom_layers.pytorch.tests.test_nms_common import generate_random_inputs +from sony_custom_layers.pytorch.tests.util import load_and_validate_onnx_model, check_tensor from sony_custom_layers.util.test_util import exec_in_clean_process @@ -58,7 +58,7 @@ def _batch_multiclass_nms_mock(self, batch, n_dets, n_classes=20): @pytest.mark.parametrize('op, patch_pkg', [(torch.ops.sony.multiclass_nms, 'nms'), (torch.ops.sony.multiclass_nms_with_indices, 'nms_with_indices')]) def test_torch_op(self, mocker, op, patch_pkg): - mock = 
mocker.patch(f'sony_custom_layers.pytorch.object_detection.{patch_pkg}._batch_multiclass_nms', + mock = mocker.patch(f'sony_custom_layers.pytorch.nms.{patch_pkg}._batch_multiclass_nms', self._batch_multiclass_nms_mock(batch=3, n_dets=5)) boxes, scores = generate_random_inputs(batch=3, n_boxes=10, n_classes=5) ret = op(boxes, scores, score_threshold=0.1, iou_threshold=0.6, max_detections=5) @@ -89,7 +89,7 @@ def test_torch_op(self, mocker, op, patch_pkg): (multiclass_nms_with_indices, NMSWithIndicesResults, torch.ops.sony.multiclass_nms_with_indices, 'nms_with_indices')]) def test_torch_op_wrapper(self, mocker, op, res_cls, torch_op, patch_pkg): - mock = mocker.patch(f'sony_custom_layers.pytorch.object_detection.{patch_pkg}._batch_multiclass_nms', + mock = mocker.patch(f'sony_custom_layers.pytorch.nms.{patch_pkg}._batch_multiclass_nms', self._batch_multiclass_nms_mock(batch=3, n_dets=5)) boxes, scores = generate_random_inputs(batch=3, n_boxes=20, n_classes=10) ret = op(boxes, scores, score_threshold=0.1, iou_threshold=0.6, max_detections=5) @@ -114,6 +114,14 @@ def test_torch_op_wrapper(self, mocker, op, res_cls, torch_op, patch_pkg): assert torch.equal(ret.n_valid, ref_ret[4]) assert ret.n_valid.dtype == torch.int64 + @pytest.mark.parametrize('op', [multiclass_nms, multiclass_nms_with_indices]) + @pytest.mark.parametrize('cuda', [True, False]) + def test_full_op_sanity(self, op, cuda): + if cuda and not torch.cuda.is_available(): + pytest.skip('cuda is not available') + boxes, scores = generate_random_inputs(batch=3, n_boxes=20, n_classes=10) + op(boxes, scores, score_threshold=0.1, iou_threshold=0.6, max_detections=5) + @pytest.mark.parametrize('op', [multiclass_nms, multiclass_nms_with_indices]) def test_empty_tensors(self, op): # empty inputs @@ -137,10 +145,7 @@ def test_onnx_export(self, dynamic_batch, tmpdir_factory, with_indices): path = str(tmpdir_factory.mktemp('nms').join(f'nms{with_indices}.onnx')) self._export_onnx(onnx_model, n_boxes, n_classes, path, 
dynamic_batch=dynamic_batch, with_indices=with_indices) - onnx_model = onnx.load(path) - onnx.checker.check_model(onnx_model, full_check=True) - opset_info = list(onnx_model.opset_import)[1] - assert opset_info.domain == 'Sony' and opset_info.version == 1 + onnx_model = load_and_validate_onnx_model(path, exp_opset=1) nms_node = list(onnx_model.graph.node)[0] assert nms_node.domain == 'Sony' @@ -155,24 +160,17 @@ def test_onnx_export(self, dynamic_batch, tmpdir_factory, with_indices): assert len(nms_node.input) == 2 assert len(nms_node.output) == 4 + int(with_indices) - def check_tensor(onnx_tensor, exp_shape, exp_type): - tensor_type = onnx_tensor.type.tensor_type - shape = [d.dim_value if d.dim_value else d.dim_param for d in tensor_type.shape.dim] - exp_shape = ['batch' if dynamic_batch else 1] + exp_shape - assert shape == exp_shape - assert tensor_type.elem_type == exp_type - - check_tensor(onnx_model.graph.input[0], [10, 4], torch.onnx.TensorProtoDataType.FLOAT) - check_tensor(onnx_model.graph.input[1], [10, 5], torch.onnx.TensorProtoDataType.FLOAT) + check_tensor(onnx_model.graph.input[0], [10, 4], torch.onnx.TensorProtoDataType.FLOAT, dynamic_batch) + check_tensor(onnx_model.graph.input[1], [10, 5], torch.onnx.TensorProtoDataType.FLOAT, dynamic_batch) # test shape inference that is defined as part of onnx op - check_tensor(onnx_model.graph.output[0], [max_dets, 4], torch.onnx.TensorProtoDataType.FLOAT) - check_tensor(onnx_model.graph.output[1], [max_dets], torch.onnx.TensorProtoDataType.FLOAT) - check_tensor(onnx_model.graph.output[2], [max_dets], torch.onnx.TensorProtoDataType.INT32) + check_tensor(onnx_model.graph.output[0], [max_dets, 4], torch.onnx.TensorProtoDataType.FLOAT, dynamic_batch) + check_tensor(onnx_model.graph.output[1], [max_dets], torch.onnx.TensorProtoDataType.FLOAT, dynamic_batch) + check_tensor(onnx_model.graph.output[2], [max_dets], torch.onnx.TensorProtoDataType.INT32, dynamic_batch) if with_indices: - 
check_tensor(onnx_model.graph.output[3], [max_dets], torch.onnx.TensorProtoDataType.INT32) - check_tensor(onnx_model.graph.output[4], [1], torch.onnx.TensorProtoDataType.INT32) + check_tensor(onnx_model.graph.output[3], [max_dets], torch.onnx.TensorProtoDataType.INT32, dynamic_batch) + check_tensor(onnx_model.graph.output[4], [1], torch.onnx.TensorProtoDataType.INT32, dynamic_batch) else: - check_tensor(onnx_model.graph.output[3], [1], torch.onnx.TensorProtoDataType.INT32) + check_tensor(onnx_model.graph.output[3], [1], torch.onnx.TensorProtoDataType.INT32, dynamic_batch) @pytest.mark.parametrize('dynamic_batch', [True, False]) @pytest.mark.parametrize('with_indices', [True, False]) diff --git a/sony_custom_layers/pytorch/tests/object_detection/test_nms_common.py b/sony_custom_layers/pytorch/tests/test_nms_common.py similarity index 97% rename from sony_custom_layers/pytorch/tests/object_detection/test_nms_common.py rename to sony_custom_layers/pytorch/tests/test_nms_common.py index bf02da4..566a3df 100644 --- a/sony_custom_layers/pytorch/tests/object_detection/test_nms_common.py +++ b/sony_custom_layers/pytorch/tests/test_nms_common.py @@ -20,7 +20,7 @@ import torch from torch import Tensor -from sony_custom_layers.pytorch.object_detection import nms_common +from sony_custom_layers.pytorch.nms import nms_common def generate_random_inputs(batch: Optional[int], n_boxes, n_classes, seed=None): @@ -84,7 +84,7 @@ def test_image_multiclass_nms(self, mocker, max_detections, mock_tv_op): score_threshold = 0.11 iou_threshold = 0.61 if mock_tv_op: - nms_mock = mocker.patch('sony_custom_layers.pytorch.object_detection.nms_common._nms_with_class_offsets', + nms_mock = mocker.patch('sony_custom_layers.pytorch.nms.nms_common._nms_with_class_offsets', Mock(return_value=Tensor([4, 5, 1, 0, 2, 3]).to(torch.int64))) ret, ret_valid_dets = nms_common._image_multiclass_nms(boxes, scores, @@ -159,7 +159,7 @@ def test_batch_multiclass_nms(self, mocker): ret_valid_dets = Tensor([[5], 
[4], [3]]) # each time the function is called, next value in the list returned images_ret = [(img_nms_ret[i], ret_valid_dets[i]) for i in range(3)] - mock = mocker.patch('sony_custom_layers.pytorch.object_detection.nms_common._image_multiclass_nms', + mock = mocker.patch('sony_custom_layers.pytorch.nms.nms_common._image_multiclass_nms', Mock(side_effect=lambda *args, **kwargs: images_ret.pop(0))) res, n_valid = nms_common._batch_multiclass_nms(input_boxes, diff --git a/sony_custom_layers/pytorch/tests/util.py b/sony_custom_layers/pytorch/tests/util.py new file mode 100644 index 0000000..fc19a8b --- /dev/null +++ b/sony_custom_layers/pytorch/tests/util.py @@ -0,0 +1,32 @@ +# ----------------------------------------------------------------------------- +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ----------------------------------------------------------------------------- +import onnx + + +def load_and_validate_onnx_model(path, exp_opset): + onnx_model = onnx.load(path) + onnx.checker.check_model(onnx_model, full_check=True) + opset_info = list(onnx_model.opset_import)[1] + assert opset_info.domain == 'Sony' and opset_info.version == exp_opset + return onnx_model + + +def check_tensor(onnx_tensor, exp_shape, exp_type, dynamic_batch: bool): + tensor_type = onnx_tensor.type.tensor_type + shape = [d.dim_value if d.dim_value else d.dim_param for d in tensor_type.shape.dim] + exp_shape = ['batch' if dynamic_batch else 1] + exp_shape + assert shape == exp_shape + assert tensor_type.elem_type == exp_type