2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

### Changed

- Add `half` flag to ONNX export task to export with float16 precision.

### Deprecated

### Removed
179 changes: 152 additions & 27 deletions src/lightly_train/_commands/export_task.py
@@ -10,8 +10,10 @@
import contextlib
import contextvars
import logging
from collections.abc import Iterator
from typing import Any, Literal
from abc import abstractmethod
from collections.abc import Container, Iterator, Mapping, Sequence
from enum import Enum
from typing import Any, Literal, Protocol, runtime_checkable

import torch
from torch import distributed
@@ -24,7 +26,6 @@

logger = logging.getLogger(__name__)


_PRECALCULATE_FOR_ONNX_EXPORT = contextvars.ContextVar(
"PRECALCULATE_FOR_ONNX_EXPORT", default=False
)
@@ -63,6 +64,7 @@ def export_onnx(
num_channels: int = 3,
height: int = 224,
width: int = 224,
half: bool = False,
verify: bool = True,
overwrite: bool = False,
format_args: dict[str, Any] | None = None,
@@ -79,6 +81,7 @@ def _export_task(
num_channels: int = 3,
height: int = 224,
width: int = 224,
half: bool = False,
verify: bool = True,
overwrite: bool = False,
format_args: dict[str, Any] | None = None,
@@ -100,6 +103,8 @@ def _export_task(
Height of the input tensor.
width:
Width of the input tensor.
half:
Export the model with float16 precision.
verify:
Check the exported model for errors.
overwrite:
@@ -111,6 +116,87 @@
_export_task_from_config(config=config)
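
A minimal usage sketch of the new flag (hypothetical call, assuming `export_onnx` is exposed at the package level; the output and checkpoint parameter names are placeholders not shown in this diff):

import lightly_train

# Export a task model to ONNX with float16 weights and float16 dummy inputs.
lightly_train.export_onnx(
    out="model_fp16.onnx",  # placeholder output path
    checkpoint="last.ckpt",  # placeholder checkpoint path
    height=224,
    width=224,
    half=True,  # new flag: export with float16 precision
    verify=True,  # compare ONNX outputs against the torch reference model
)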


class ONNXPrecision(Enum):
FP16 = "float16"
FP32 = "float32"


@runtime_checkable
class ONNXExportable(Protocol):
"""
A protocol to specify that a model can be exported to ONNX.

Default implementations are provided for most methods; these can only be used if a class inherits from the protocol.
Otherwise, one needs to call `ONNXExportable.somemethod(self, ...)` explicitly when implementing the protocol.
"""

def onnx_opset_versions(self) -> tuple[int, int | None]:
"""
The range of ONNX opset versions supported by the model.

Return a tuple where the first element is the lower bound and the second element is the upper bound (inclusive).
The upper bound can also be None to indicate that there is no upper bound.
"""
del self
return (7, None)

def onnx_precisions(self) -> Container[ONNXPrecision]:
"""
The precisions that the ONNX model can be exported with.
"""
del self
return {ONNXPrecision.FP16, ONNXPrecision.FP32}

def verify_torch_onnx_export_kwargs(self, **kwargs: Any) -> None:
"""
Verify additional arguments passed to torch.onnx.export. Should raise an exception
if an argument is not supported.
"""
del self
del kwargs
return

def setup_onnx_model(
self, *, checkpoint: PathLike, precision: ONNXPrecision
) -> torch.nn.Module:
"""
Set up the exact torch model that should be exported with torch.onnx.export.
"""
del self
model = task_model_helpers.load_model_from_checkpoint(checkpoint=checkpoint)
if precision == ONNXPrecision.FP16:
model = model.half()
return model

def setup_validation_model(self, *, checkpoint: PathLike) -> torch.nn.Module:
"""
Set up the exact torch model that is used as a reference to verify the exported ONNX model.
"""
del self
model = task_model_helpers.load_model_from_checkpoint(
checkpoint=checkpoint, device="cpu"
)
return model

@abstractmethod
def make_onnx_export_inputs(
self, *, precision: ONNXPrecision, device: torch.device, **kwargs: Any
) -> Mapping[str, torch.Tensor]:
"""
Create the dummy input tensors that are used during the ONNX export.

Should return a mapping from input names to tensors.
"""
...

@abstractmethod
def onnx_output_names(self) -> Sequence[str]:
"""
Return the names of the ONNX output tensors.
"""
...
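
To make the inheritance caveat in the docstring concrete, a hypothetical model that subclasses the protocol directly, inheriting the default implementations and overriding only the two abstract methods:

class MyExportableModel(torch.nn.Module, ONNXExportable):
    # Subclassing ONNXExportable directly makes the protocol's default
    # implementations available, so only the abstract methods are needed.
    def make_onnx_export_inputs(
        self, *, precision: ONNXPrecision, device: torch.device, **kwargs: Any
    ) -> Mapping[str, torch.Tensor]:
        dtype = torch.float16 if precision == ONNXPrecision.FP16 else torch.float32
        return {"input": torch.rand(1, 3, 224, 224, dtype=dtype, device=device)}

    def onnx_output_names(self) -> Sequence[str]:
        return ["logits"]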


def _export_task_from_config(config: ExportTaskConfig) -> None:
# Only export on rank 0.
if distributed.is_initialized() and distributed.get_rank() > 0:
@@ -129,34 +215,55 @@ def _export_task_from_config(config: ExportTaskConfig) -> None:
task_model = task_model_helpers.load_model_from_checkpoint(
checkpoint=checkpoint_path
)
task_model.eval()

# Export the model to ONNX format
# TODO(Yutong, 07/25): support more formats (may use ONNX as the intermediate format)
if config.format == "onnx":
# Get the device of the model to ensure dummy input is on the same device
model_device = next(task_model.parameters()).device
dummy_input = torch.randn(
config.batch_size,
config.num_channels,
config.height,
config.width,
requires_grad=False,
device=model_device,
if not isinstance(task_model, ONNXExportable):
raise ValueError(
f"Model of class {task_model.__class__.__name__} cannot be exported to ONNX."
)
opset_version = (config.format_args or {}).get("opset_version", 18)
opset_lower, opset_upper = task_model.onnx_opset_versions()
# torch.onnx.export requires at least opset version 7.
if opset_version < max(opset_lower, 7):
raise ValueError(f"Opset version must be at least {max(opset_lower, 7)}.")
if opset_upper is not None and opset_version > opset_upper:
raise ValueError(f"Opset version can be at most {opset_upper}.")
precision = ONNXPrecision.FP16 if config.half else ONNXPrecision.FP32
if precision not in task_model.onnx_precisions():
raise ValueError(f"Precision {precision.value} is not supported.")

export_model = task_model.setup_onnx_model(
checkpoint=checkpoint_path, precision=precision
)
input_name = "input"
output_names = ["masks", "logits"]
export_model_device = next(export_model.parameters()).device

dummy_inputs = task_model.make_onnx_export_inputs(
precision=precision,
device=export_model_device,
batch_size=config.batch_size,
num_channels=config.num_channels,
height=config.height,
width=config.width,
)

input_names = list(dummy_inputs.keys())
dummy_inputs = tuple(dummy_inputs.values())
output_names = task_model.onnx_output_names()

with precalculate_for_onnx_export():
task_model(dummy_input)
export_model(*dummy_inputs)
logger.info(f"Exporting ONNX model to '{out_path}'")
torch.onnx.export(
task_model,
(dummy_input,),
export_model,
dummy_inputs,
out_path,
input_names=[input_name],
input_names=input_names,
output_names=output_names,
**config.format_args if config.format_args else {},
)
del export_model

if config.verify:
logger.info("Verifying ONNX model")
@@ -165,20 +272,37 @@ def _export_task_from_config(config: ExportTaskConfig) -> None:

onnx.checker.check_model(out_path, full_check=True)

x = torch.rand_like(dummy_input)
onnx_inputs = task_model.make_onnx_export_inputs(
precision=precision,
device=torch.device("cpu"),
batch_size=config.batch_size,
num_channels=config.num_channels,
height=config.height,
width=config.width,
)
# Always run the validation inputs in float32 and on CPU for consistency.
validation_inputs = [v.to(torch.float32) for v in onnx_inputs.values()]
onnx_inputs = {k: v.numpy() for (k, v) in onnx_inputs.items()}

validation_model = task_model.setup_validation_model(
checkpoint=checkpoint_path
)
del task_model

session = ort.InferenceSession(out_path)
input_feed = {input_name: x.cpu().numpy()}
outputs_onnx = session.run(output_names=output_names, input_feed=input_feed)
outputs_onnx = tuple(torch.from_numpy(y) for y in outputs_onnx)
onnx_outputs = session.run(
output_names=output_names, input_feed=onnx_inputs
)
onnx_outputs = tuple(torch.from_numpy(y) for y in onnx_outputs)

outputs_model = task_model(x)
validation_outputs = validation_model(*validation_inputs)

if len(outputs_onnx) != len(outputs_model):
if len(onnx_outputs) != len(validation_outputs):
raise AssertionError(
f"Number of onnx outputs should be {len(outputs_model)} but is {len(outputs_onnx)}"
f"Number of onnx outputs should be {len(validation_outputs)} but is {len(onnx_outputs)}"
)
for output_onnx, output_model, output_name in zip(
outputs_onnx, outputs_model, output_names
onnx_outputs, validation_outputs, output_names
):
# Absolute and relative tolerances are a bit arbitrary and taken from here:
# https://github.com/pytorch/pytorch/blob/main/torch/onnx/_internal/exporter/_core.py#L1611-L1618
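# A plausible sketch of the per-output check that follows (the comparison body
# is not shown above; the exact tolerance values here are assumptions):
#
#     torch.testing.assert_close(
#         output_onnx,
#         output_model.to(output_onnx.dtype),
#         rtol=2e-3,
#         atol=2e-4,
#     )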
@@ -210,6 +334,7 @@ class ExportTaskConfig(PydanticConfig):
num_channels: int = 3
height: int = 224
width: int = 224
half: bool = False
verify: bool = True
overwrite: bool = False
format_args: dict[str, Any] | None = (
@@ -10,15 +10,18 @@
import logging
import math
import os
from typing import Any
from collections.abc import Container, Sequence
from typing import Any

import torch
from PIL.Image import Image as PILImage
from torch import Tensor
from torch.nn import GELU, Embedding, Linear, Sequential
from torch.nn import functional as F
from torchvision.transforms.v2 import functional as transforms_functional
from typing_extensions import override

from lightly_train._commands.export_task import ONNXExportable, ONNXPrecision
from lightly_train._data import file_helpers
from lightly_train._models import package_helpers
from lightly_train._models.dinov2_vit.dinov2_vit_package import DINOV2_VIT_PACKAGE
@@ -415,13 +418,23 @@ def untile(
) -> list[Tensor]:
logit_sums, logit_counts = [], []

dtype = crop_logits.dtype

# Initialize the tensors containing the final predictions.
for size in image_sizes:
logit_sums.append(
torch.zeros((crop_logits.shape[1], *size), device=crop_logits.device)
torch.zeros(
(crop_logits.shape[1], *size),
device=crop_logits.device,
dtype=dtype,
)
)
logit_counts.append(
torch.zeros((crop_logits.shape[1], *size), device=crop_logits.device)
torch.zeros(
(crop_logits.shape[1], *size),
device=crop_logits.device,
dtype=dtype,
)
)

for crop_index, (image_index, start, end, is_tall) in enumerate(origins):
@@ -577,3 +590,56 @@ def load_train_state_dict(self, state_dict: dict[str, Any]) -> None:
name = name[len("model.") :]
new_state_dict[name] = param
self.load_state_dict(new_state_dict, strict=True)

@override
def onnx_opset_versions(self) -> tuple[int, int | None]:
# TODO verify if 12 is really the lower bound here
return (12, None)

@override
def onnx_precisions(self) -> Container[ONNXPrecision]:
return ONNXExportable.onnx_precisions(self)

@override
def verify_torch_onnx_export_kwargs(self, **kwargs: Any) -> None:
if kwargs.get("dynamo", True):
raise ValueError(
f"Dynamo is not supported for ONNX export with{self.__class__.__name__} model."
)
return

@override
def setup_onnx_model(
self, *, checkpoint: PathLike, precision: ONNXPrecision
) -> torch.nn.Module:
return ONNXExportable.setup_onnx_model(
self, checkpoint=checkpoint, precision=precision
)

@override
def setup_validation_model(self, *, checkpoint: PathLike) -> torch.nn.Module:
return ONNXExportable.setup_validation_model(self, checkpoint=checkpoint)

@override
def make_onnx_export_inputs(
self,
*,
precision: ONNXPrecision,
device: torch.device,
batch_size: int,
num_channels: int,
height: int,
width: int,
) -> dict[str, torch.Tensor]:
del self
dtype = torch.float16 if precision == ONNXPrecision.FP16 else torch.float32
return {
"input": torch.rand(
batch_size, num_channels, height, width, dtype=dtype, device=device
)
}

@override
def onnx_output_names(self) -> Sequence[str]:
del self
return ["masks", "logits"]