From b729729b44c68b0d940ae26956626a3270444293 Mon Sep 17 00:00:00 2001 From: Simon Schoelly Date: Mon, 1 Sep 2025 08:14:48 +0200 Subject: [PATCH 1/6] Add option to export onnx with float16 precision --- src/lightly_train/_commands/export_task.py | 22 +++++++++++++-- .../task_model.py | 14 ++++++++-- tests/_commands/test_export_task.py | 28 ++++++++++++------- 3 files changed, 49 insertions(+), 15 deletions(-) diff --git a/src/lightly_train/_commands/export_task.py b/src/lightly_train/_commands/export_task.py index 1ff6cc169..6e2985043 100644 --- a/src/lightly_train/_commands/export_task.py +++ b/src/lightly_train/_commands/export_task.py @@ -63,6 +63,7 @@ def export_onnx( num_channels: int = 3, height: int = 224, width: int = 224, + half: bool = False, verify: bool = True, overwrite: bool = False, format_args: dict[str, Any] | None = None, @@ -79,6 +80,7 @@ def _export_task( num_channels: int = 3, height: int = 224, width: int = 224, + half: bool = False, verify: bool = True, overwrite: bool = False, format_args: dict[str, Any] | None = None, @@ -100,6 +102,8 @@ def _export_task( Height of the input tensor. width: Width of the input tensor. + half: + Export the model with float16 precision. verify: Check the exported model for errors. overwrite: @@ -130,12 +134,16 @@ def _export_task_from_config(config: ExportTaskConfig) -> None: checkpoint=checkpoint_path ) task_model.eval() + if config.half: + task_model.half() # Export the model to ONNX format # TODO(Yutong, 07/25): support more formats (may use ONNX as the intermediate format) if config.format == "onnx": # Get the device of the model to ensure dummy input is on the same device model_device = next(task_model.parameters()).device + onnx_dtype = torch.float16 if config.half else torch.float32 + dummy_input = torch.randn( config.batch_size, config.num_channels, @@ -143,6 +151,7 @@ def _export_task_from_config(config: ExportTaskConfig) -> None: config.width, requires_grad=False, device=model_device, + dtype=onnx_dtype, ) input_name = "input" output_names = ["masks", "logits"] @@ -165,13 +174,19 @@ def _export_task_from_config(config: ExportTaskConfig) -> None: onnx.checker.check_model(out_path, full_check=True) - x = torch.rand_like(dummy_input) + # Always run the reference input in float32 and on cpu for consistency + x_model = torch.rand_like(dummy_input, dtype=torch.float32, device="cpu") + x_onnx = x_model.half() if config.half else x_model + session = ort.InferenceSession(out_path) - input_feed = {input_name: x.cpu().numpy()} + input_feed = {input_name: x_onnx.numpy()} outputs_onnx = session.run(output_names=output_names, input_feed=input_feed) outputs_onnx = tuple(torch.from_numpy(y) for y in outputs_onnx) - outputs_model = task_model(x) + task_model = task_model_helpers.load_model_from_checkpoint( + checkpoint=checkpoint_path, device="cpu" + ) + outputs_model = task_model(x_model) if len(outputs_onnx) != len(outputs_model): raise AssertionError( @@ -210,6 +225,7 @@ class ExportTaskConfig(PydanticConfig): num_channels: int = 3 height: int = 224 width: int = 224 + half: bool = False verify: bool = True overwrite: bool = False format_args: dict[str, Any] | None = ( diff --git a/src/lightly_train/_task_models/dinov2_eomt_semantic_segmentation/task_model.py b/src/lightly_train/_task_models/dinov2_eomt_semantic_segmentation/task_model.py index 9665bd2f6..b92ccadfa 100644 --- a/src/lightly_train/_task_models/dinov2_eomt_semantic_segmentation/task_model.py +++ b/src/lightly_train/_task_models/dinov2_eomt_semantic_segmentation/task_model.py 
@@ -415,13 +415,23 @@ def untile( ) -> list[Tensor]: logit_sums, logit_counts = [], [] + dtype = crop_logits.dtype + # Initialize the tensors containing the final predictions. for size in image_sizes: logit_sums.append( - torch.zeros((crop_logits.shape[1], *size), device=crop_logits.device) + torch.zeros( + (crop_logits.shape[1], *size), + device=crop_logits.device, + dtype=dtype, + ) ) logit_counts.append( - torch.zeros((crop_logits.shape[1], *size), device=crop_logits.device) + torch.zeros( + (crop_logits.shape[1], *size), + device=crop_logits.device, + dtype=dtype, + ) ) for crop_index, (image_index, start, end, is_tall) in enumerate(origins): diff --git a/tests/_commands/test_export_task.py b/tests/_commands/test_export_task.py index 223108c2a..0750b6e1c 100644 --- a/tests/_commands/test_export_task.py +++ b/tests/_commands/test_export_task.py @@ -64,15 +64,15 @@ def dinov2_vits14_eomt_checkpoint(tmp_path_factory: pytest.TempPathFactory) -> P onnx_export_testset = [ - (1, 42, 154), - (1, 42, 154), - (2, 14, 14), - (3, 140, 280), - (4, 266, 28), + (1, 42, 154, False), + (1, 42, 154, False), + (2, 14, 14, False), + (3, 140, 280, True), + (4, 266, 28, True), ] -@pytest.mark.parametrize("batch_size,height,width", onnx_export_testset) +@pytest.mark.parametrize("batch_size,height,width, half", onnx_export_testset) @pytest.mark.skipif( sys.version_info < (3, 9), reason="Requires Python 3.9 or higher for image preprocessing.", @@ -85,6 +85,7 @@ def test_onnx_export( batch_size: int, height: int, width: int, + half: bool, dinov2_vits14_eomt_checkpoint: Path, tmp_path: Path, ) -> None: @@ -96,12 +97,13 @@ def test_onnx_export( dinov2_vits14_eomt_checkpoint, device="cpu" ) onnx_path = tmp_path / "model.onnx" - validation_input = torch.randn(batch_size, 3, height, width).cpu() + validation_input = torch.randn(batch_size, 3, height, width, device="cpu") expected_outputs = model(validation_input) + expected_output_dtypes = [torch.int64, torch.float16 if half else torch.float32] # We use torch.testing.assert_close to check if the model outputs the same as when we run the exported # onnx file with onnxruntime. Unfortunately the default tolerances are too strict so we specify our own. 
- rtol = 1e-3 - atol = 1e-5 + rtol = 1e-2 + atol = 1e-4 # act lightly_train.export_onnx( @@ -109,6 +111,7 @@ def test_onnx_export( checkpoint=dinov2_vits14_eomt_checkpoint, height=height, width=width, + half=half, batch_size=batch_size, overwrite=True, ) @@ -118,10 +121,15 @@ def test_onnx_export( onnx.checker.check_model(onnx_path, full_check=True) session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"]) + if half: + validation_input = validation_input.half() ort_in = {"input": validation_input.numpy()} ort_outputs = session.run(["masks", "logits"], ort_in) ort_outputs = [torch.from_numpy(y).cpu() for y in ort_outputs] + assert [y.dtype for y in ort_outputs] == expected_output_dtypes assert len(ort_outputs) == len(expected_outputs) for ort_y, expected_y in zip(ort_outputs, expected_outputs): - torch.testing.assert_close(ort_y, expected_y, rtol=rtol, atol=atol) + torch.testing.assert_close( + ort_y, expected_y, check_dtype=False, rtol=rtol, atol=atol + ) From 1525d7d10d7eb1485e2e206fc15afeaa045aca2f Mon Sep 17 00:00:00 2001 From: Simon Schoelly Date: Mon, 1 Sep 2025 08:18:23 +0200 Subject: [PATCH 2/6] Document onnx half flag in CHANGELOG.md --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 55a93b036..c490dd29f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- Add `half` flat to ONNX export task to export with float16 precision. + ### Deprecated ### Removed From b83e945b6ec2ffd0f4cae22b6b91a7a3d162fe2a Mon Sep 17 00:00:00 2001 From: Simon Schoelly Date: Mon, 1 Sep 2025 08:29:53 +0200 Subject: [PATCH 3/6] Typo --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c490dd29f..f3f85d034 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed -- Add `half` flat to ONNX export task to export with float16 precision. +- Add `half` flag to ONNX export task to export with float16 precision. 
 ### Deprecated

From a3085621cae8d96e0ebad126c75e9d1460d6a778 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Simon=20Sch=C3=B6lly?=
Date: Mon, 1 Sep 2025 09:47:21 +0200
Subject: [PATCH 4/6] Update tests/_commands/test_export_task.py

Co-authored-by: Lionel Peer
---
 tests/_commands/test_export_task.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/_commands/test_export_task.py b/tests/_commands/test_export_task.py
index 0750b6e1c..3e0919750 100644
--- a/tests/_commands/test_export_task.py
+++ b/tests/_commands/test_export_task.py
@@ -72,7 +72,7 @@ def dinov2_vits14_eomt_checkpoint(tmp_path_factory: pytest.TempPathFactory) -> P
 ]


-@pytest.mark.parametrize("batch_size,height,width, half", onnx_export_testset)
+@pytest.mark.parametrize("batch_size,height,width,half", onnx_export_testset)
 @pytest.mark.skipif(
     sys.version_info < (3, 9),
     reason="Requires Python 3.9 or higher for image preprocessing.",

From 025294f4f50c54a47e1871999995230b855e6ac8 Mon Sep 17 00:00:00 2001
From: Simon Schoelly
Date: Mon, 1 Sep 2025 17:58:18 +0200
Subject: [PATCH 5/6] Add a protocol to specify that a model can be exported to
 ONNX
---
 src/lightly_train/_commands/export_task.py    | 183 ++++++++++++++----
 .../task_model.py                             |  58 +++++-
 2 files changed, 203 insertions(+), 38 deletions(-)

diff --git a/src/lightly_train/_commands/export_task.py b/src/lightly_train/_commands/export_task.py
index 6e2985043..5943633b7 100644
--- a/src/lightly_train/_commands/export_task.py
+++ b/src/lightly_train/_commands/export_task.py
@@ -10,8 +10,10 @@
 import contextlib
 import contextvars
 import logging
-from collections.abc import Iterator
-from typing import Any, Literal
+from abc import abstractmethod
+from collections.abc import Iterator, Container, Mapping, Sequence
+from enum import Enum
+from typing import Any, Literal, Protocol, Tuple, runtime_checkable

 import torch
 from torch import distributed
@@ -24,7 +26,6 @@

 logger = logging.getLogger(__name__)

-
 _PRECALCULATE_FOR_ONNX_EXPORT = contextvars.ContextVar(
     "PRECALCULATE_FOR_ONNX_EXPORT", default=False
 )
@@ -115,6 +116,87 @@ def _export_task(
     _export_task_from_config(config=config)


+class ONNXPrecision(Enum):
+    FP16 = "float16"
+    FP32 = "float32"
+
+
+@runtime_checkable
+class ONNXExportable(Protocol):
+    """
+    A protocol to specify that a model can be exported to ONNX.
+
+    Some default implementations are provided for most methods; these can only be used if one inherits from the protocol.
+    Otherwise, one needs to call `ONNXExportable.somemethod(self, ...)` explicitly when implementing the protocol.
+    """
+
+    def onnx_opset_versions(self) -> Tuple[int, int | None]:
+        """
+        The range of ONNX opset versions supported by the model.
+
+        Return a tuple where the first element is the lower bound and the second element is the upper bound (inclusive).
+        The upper bound can also be None to indicate that there is no upper bound.
+        """
+        del self
+        return (7, None)
+
+    def onnx_precisions(self) -> Container[ONNXPrecision]:
+        """
+        The precisions that the ONNX model can be exported with.
+        """
+        del self
+        return {ONNXPrecision.FP16, ONNXPrecision.FP32}
+
+    def verify_torch_onnx_export_kwargs(self, **kwargs: Any) -> None:
+        """
+        Verify additional arguments passed to torch.onnx.export. Should raise an exception if some argument
+        is not supported.
+        """
+        del self
+        del kwargs
+        return
+
+    def setup_onnx_model(
+        self, *, checkpoint: PathLike, precision: ONNXPrecision
+    ) -> torch.nn.Module:
+        """
+        Set up the exact torch model that should be exported with torch.onnx.export.
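+
+        By default the model is loaded from the checkpoint and, for float16
+        precision, converted with `model.half()`. A minimal sketch of an
+        implementing class reusing this default explicitly (the same pattern
+        the DINOv2 EoMT task model uses below):
+
+            def setup_onnx_model(self, *, checkpoint, precision):
+                return ONNXExportable.setup_onnx_model(
+                    self, checkpoint=checkpoint, precision=precision
+                )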
+        """
+        del self
+        model = task_model_helpers.load_model_from_checkpoint(checkpoint=checkpoint)
+        if precision == ONNXPrecision.FP16:
+            model = model.half()
+        return model
+
+    def setup_validation_model(self, *, checkpoint: PathLike) -> torch.nn.Module:
+        """
+        Set up the exact torch model that is used as a reference model to verify the exported ONNX model.
+        """
+        del self
+        model = task_model_helpers.load_model_from_checkpoint(
+            checkpoint=checkpoint, device="cpu"
+        )
+        return model
+
+    @abstractmethod
+    def make_onnx_export_inputs(
+        self, *, precision: ONNXPrecision, device: torch.device, **kwargs: Any
+    ) -> Mapping[str, torch.Tensor]:
+        """
+        Create the dummy input tensors that are used during the ONNX export.
+
+        Should return a mapping from input names to tensors.
+        """
+        ...
+
+    @abstractmethod
+    def onnx_output_names(self) -> Sequence[str]:
+        """
+        Return the names of the ONNX output tensors.
+        """
+        ...
+
+
 def _export_task_from_config(config: ExportTaskConfig) -> None:
     # Only export on rank 0.
     if distributed.is_initialized() and distributed.get_rank() > 0:
@@ -133,39 +215,55 @@ def _export_task_from_config(config: ExportTaskConfig) -> None:
     task_model = task_model_helpers.load_model_from_checkpoint(
         checkpoint=checkpoint_path
     )
-    task_model.eval()
-    if config.half:
-        task_model.half()

     # Export the model to ONNX format
     # TODO(Yutong, 07/25): support more formats (may use ONNX as the intermediate format)
     if config.format == "onnx":
-        # Get the device of the model to ensure dummy input is on the same device
-        model_device = next(task_model.parameters()).device
-        onnx_dtype = torch.float16 if config.half else torch.float32
-
-        dummy_input = torch.randn(
-            config.batch_size,
-            config.num_channels,
-            config.height,
-            config.width,
-            requires_grad=False,
-            device=model_device,
-            dtype=onnx_dtype,
+        if not isinstance(task_model, ONNXExportable):
+            raise ValueError(
+                f"Model of class {task_model.__class__.__name__} cannot be exported to ONNX."
+            )
+        opset_version = (config.format_args or {}).get("opset_version", 18)
+        opset_lower, opset_upper = task_model.onnx_opset_versions()
+        # torch.onnx.export requires at least opset version 7
+        if opset_version < max(opset_lower, 7):
+            raise ValueError(f"Opset version must be at least {max(opset_lower, 7)}.")
+        if opset_upper is not None and opset_version > opset_upper:
+            raise ValueError(f"Opset version can be at most {opset_upper}.")
+        precision = ONNXPrecision.FP16 if config.half else ONNXPrecision.FP32
+        if precision not in task_model.onnx_precisions():
+            raise ValueError(f"Precision {precision.value} is not supported.")
+
+        export_model = task_model.setup_onnx_model(
+            checkpoint=checkpoint_path, precision=precision
+        )
+        export_model_device = next(export_model.parameters()).device
+
+        dummy_inputs = task_model.make_onnx_export_inputs(
+            precision=precision,
+            device=export_model_device,
+            batch_size=config.batch_size,
+            num_channels=config.num_channels,
+            height=config.height,
+            width=config.width,
         )
-        input_name = "input"
-        output_names = ["masks", "logits"]
+
+        input_names = list(dummy_inputs.keys())
+        dummy_inputs = tuple(dummy_inputs.values())
+        output_names = task_model.onnx_output_names()
+
         with precalculate_for_onnx_export():
-            task_model(dummy_input)
+            export_model(*dummy_inputs)

         logger.info(f"Exporting ONNX model to '{out_path}'")
         torch.onnx.export(
-            task_model,
-            (dummy_input,),
+            export_model,
+            dummy_inputs,
             out_path,
-            input_names=[input_name],
+            input_names=input_names,
             output_names=output_names,
             **config.format_args if config.format_args else {},
         )
+        del export_model

         if config.verify:
             logger.info("Verifying ONNX model")
@@ -174,26 +272,37 @@ def _export_task_from_config(config: ExportTaskConfig) -> None:

             onnx.checker.check_model(out_path, full_check=True)

-            # Always run the reference input in float32 and on cpu for consistency
-            x_model = torch.rand_like(dummy_input, dtype=torch.float32, device="cpu")
-            x_onnx = x_model.half() if config.half else x_model
+            onnx_inputs = task_model.make_onnx_export_inputs(
+                precision=precision,
+                device=torch.device("cpu"),
+                batch_size=config.batch_size,
+                num_channels=config.num_channels,
+                height=config.height,
+                width=config.width,
+            )
+            # Always run the validation input in float32 and on cpu for consistency
+            validation_inputs = [v.to(torch.float32) for v in onnx_inputs.values()]
+            onnx_inputs = {k: v.numpy() for (k, v) in onnx_inputs.items()}

-            session = ort.InferenceSession(out_path)
-            input_feed = {input_name: x_onnx.numpy()}
-            outputs_onnx = session.run(output_names=output_names, input_feed=input_feed)
-            outputs_onnx = tuple(torch.from_numpy(y) for y in outputs_onnx)
+            validation_model = task_model.setup_validation_model(
+                checkpoint=checkpoint_path
+            )
+            del task_model

-            task_model = task_model_helpers.load_model_from_checkpoint(
-                checkpoint=checkpoint_path, device="cpu"
+            session = ort.InferenceSession(out_path)
+            onnx_outputs = session.run(
+                output_names=output_names, input_feed=onnx_inputs
             )
-            outputs_model = task_model(x_model)
+            onnx_outputs = tuple(torch.from_numpy(y) for y in onnx_outputs)
+
+            validation_outputs = validation_model(*validation_inputs)

-            if len(outputs_onnx) != len(outputs_model):
+            if len(onnx_outputs) != len(validation_outputs):
                 raise AssertionError(
-                    f"Number of onnx outputs should be {len(outputs_model)} but is {len(outputs_onnx)}"
+                    f"Number of ONNX outputs should be {len(validation_outputs)} but is {len(onnx_outputs)}"
                 )
             for output_onnx, output_model, output_name in zip(
-                outputs_onnx, outputs_model, output_names
+                onnx_outputs, validation_outputs, output_names
             ):
                 # Absolute and relative tolerances are a bit arbitrary and taken from here:
                 # https://github.com/pytorch/pytorch/blob/main/torch/onnx/_internal/exporter/_core.py#L1611-L1618
diff --git a/src/lightly_train/_task_models/dinov2_eomt_semantic_segmentation/task_model.py b/src/lightly_train/_task_models/dinov2_eomt_semantic_segmentation/task_model.py
index b92ccadfa..0615f9084 100644
--- a/src/lightly_train/_task_models/dinov2_eomt_semantic_segmentation/task_model.py
+++ b/src/lightly_train/_task_models/dinov2_eomt_semantic_segmentation/task_model.py
@@ -10,7 +10,8 @@
 import logging
 import math
 import os
-from typing import Any
+from collections.abc import Sequence, Mapping, Container
+from typing import Any, Tuple

 import torch
 from PIL.Image import Image as PILImage
@@ -18,7 +19,9 @@
 from torch.nn import GELU, Embedding, Linear, Sequential
 from torch.nn import functional as F
 from torchvision.transforms.v2 import functional as transforms_functional
+from typing_extensions import override

+from lightly_train._commands.export_task import ONNXPrecision, ONNXExportable
 from lightly_train._data import file_helpers
 from lightly_train._models import package_helpers
 from lightly_train._models.dinov2_vit.dinov2_vit_package import DINOV2_VIT_PACKAGE
@@ -587,3 +590,56 @@ def load_train_state_dict(self, state_dict: dict[str, Any]) -> None:
             name = name[len("model.") :]
             new_state_dict[name] = param
         self.load_state_dict(new_state_dict, strict=True)
+
+    @override
+    def onnx_opset_versions(self) -> Tuple[int, int | None]:
+        # TODO verify if 12 is really the lower bound here
+        return (12, None)
+
+    @override
+    def onnx_precisions(self) -> Container[ONNXPrecision]:
+        return ONNXExportable.onnx_precisions(self)
+
+    @override
+    def verify_torch_onnx_export_kwargs(self, **kwargs: Any) -> None:
+        if kwargs.get("dynamo", True):
+            raise ValueError(
+                f"Dynamo is not supported for ONNX export with the {self.__class__.__name__} model."
+ ) + return + + @override + def setup_onnx_model( + self, *, checkpoint: PathLike, precision: ONNXPrecision + ) -> torch.nn.Module: + return ONNXExportable.setup_onnx_model( + self, checkpoint=checkpoint, precision=precision + ) + + @override + def setup_validation_model(self, *, checkpoint: PathLike) -> torch.nn.Module: + return ONNXExportable.setup_validation_model(self, checkpoint=checkpoint) + + @override + def make_onnx_export_inputs( + self, + *, + precision: ONNXPrecision, + device: torch.device, + batch_size: int, + num_channels: int, + height: int, + width: int, + ) -> dict[str, torch.Tensor]: + del self + dtype = torch.float16 if precision == ONNXPrecision.FP16 else torch.float32 + return { + "input": torch.rand( + batch_size, num_channels, height, width, dtype=dtype, device=device + ) + } + + @override + def onnx_output_names(self) -> Sequence[str]: + del self + return ["masks", "logits"] From 4fdd515107beb10c3a5bc584473375aa8fc417e6 Mon Sep 17 00:00:00 2001 From: Simon Schoelly Date: Mon, 1 Sep 2025 18:05:53 +0200 Subject: [PATCH 6/6] Formatting --- src/lightly_train/_commands/export_task.py | 2 +- .../dinov2_eomt_semantic_segmentation/task_model.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lightly_train/_commands/export_task.py b/src/lightly_train/_commands/export_task.py index 5943633b7..2e20286d9 100644 --- a/src/lightly_train/_commands/export_task.py +++ b/src/lightly_train/_commands/export_task.py @@ -11,7 +11,7 @@ import contextvars import logging from abc import abstractmethod -from collections.abc import Iterator, Container, Mapping, Sequence +from collections.abc import Container, Iterator, Mapping, Sequence from enum import Enum from typing import Any, Literal, Protocol, Tuple, runtime_checkable diff --git a/src/lightly_train/_task_models/dinov2_eomt_semantic_segmentation/task_model.py b/src/lightly_train/_task_models/dinov2_eomt_semantic_segmentation/task_model.py index 0615f9084..5c7630391 100644 --- a/src/lightly_train/_task_models/dinov2_eomt_semantic_segmentation/task_model.py +++ b/src/lightly_train/_task_models/dinov2_eomt_semantic_segmentation/task_model.py @@ -10,7 +10,7 @@ import logging import math import os -from collections.abc import Sequence, Mapping, Container +from collections.abc import Container, Sequence from typing import Any, Tuple import torch @@ -21,7 +21,7 @@ from torchvision.transforms.v2 import functional as transforms_functional from typing_extensions import override -from lightly_train._commands.export_task import ONNXPrecision, ONNXExportable +from lightly_train._commands.export_task import ONNXExportable, ONNXPrecision from lightly_train._data import file_helpers from lightly_train._models import package_helpers from lightly_train._models.dinov2_vit.dinov2_vit_package import DINOV2_VIT_PACKAGE
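
Usage sketch for the new flag (illustrative, not part of the patches above):
the `out` argument name and the checkpoint path below are assumptions based
on the test code, not verified API.

    import lightly_train

    # Export a trained EoMT checkpoint to ONNX with float16 weights and inputs.
    lightly_train.export_onnx(
        out="model_fp16.onnx",               # assumed output-path argument
        checkpoint="checkpoints/last.ckpt",  # hypothetical checkpoint path
        batch_size=1,
        height=224,
        width=224,
        half=True,    # new flag: export with float16 precision
        verify=True,  # compares ONNX outputs against the float32 torch model on CPU
    )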