diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1c6390db514..2681b31d466 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -34,7 +34,7 @@ repos: - id: pyupgrade args: [--py36-plus] - repo: https://github.com/PyCQA/docformatter - rev: v1.5.1 + rev: v1.5.0 hooks: - id: docformatter args: [--in-place, --wrap-summaries=115, --wrap-descriptions=120] diff --git a/ludwig/encoders/image/__init__.py b/ludwig/encoders/image/__init__.py index c5e597a3f66..3e4c5e5d730 100644 --- a/ludwig/encoders/image/__init__.py +++ b/ludwig/encoders/image/__init__.py @@ -1,2 +1,3 @@ import ludwig.encoders.image.base import ludwig.encoders.image.torchvision # noqa +import ludwig.encoders.image.metaformer # noqa diff --git a/ludwig/encoders/image/base.py b/ludwig/encoders/image/base.py index 1e3bc5ca580..66667a8287f 100644 --- a/ludwig/encoders/image/base.py +++ b/ludwig/encoders/image/base.py @@ -54,6 +54,8 @@ def __init__( conv_layers: Optional[List[Dict]] = None, num_conv_layers: Optional[int] = None, num_channels: int = None, + # Optional custom CAFormer backbone (e.g. caformer_s18, caformer_s36, caformer_m36, caformer_b36) + custom_model: Optional[str] = None, out_channels: int = 32, kernel_size: Union[int, Tuple[int]] = 3, stride: Union[int, Tuple[int]] = 1, @@ -87,6 +89,35 @@ def __init__( super().__init__() self.config = encoder_config + self._use_caformer = False + self._output_shape_override: Optional[torch.Size] = None + if custom_model and isinstance(custom_model, str) and custom_model.startswith("caformer_"): + try: + from caformer_setup_backup.caformer_stacked_cnn import CAFormerStackedCNN + # Instantiate CAFormer encoder (it internally handles resizing / channel adapting) + self.caformer_encoder = CAFormerStackedCNN( + height=height if height is not None else 224, + width=width if width is not None else 224, + num_channels=num_channels if num_channels is not None else 3, + output_size=output_size, + custom_model=custom_model, + use_pretrained=True, + trainable=True, + ) + self._use_caformer = True + # Override forward dynamically + self.forward = self._forward_caformer # type: ignore + # Store output shape + if hasattr(self.caformer_encoder, "output_shape"): + # CAFormerStackedCNN.output_shape returns a list + shape_list = self.caformer_encoder.output_shape + if isinstance(shape_list, (list, tuple)): + self._output_shape_override = torch.Size(shape_list) + logger.info(f"Using CAFormer backbone '{custom_model}' in place of stacked_cnn.") + except Exception as e: + logger.error(f"Failed to initialize CAFormer encoder '{custom_model}': {e}") + raise + logger.debug(f" {self.name}") # map parameter input feature config names to internal names @@ -144,6 +175,16 @@ def __init__( default_dropout=fc_dropout, ) + def _forward_caformer(self, inputs: torch.Tensor) -> EncoderOutputDict: + """ + Forward pass when a CAFormer backbone is used. + CAFormerStackedCNN.forward returns a dict with key 'encoder_output'. + """ + if not hasattr(self, "caformer_encoder"): + raise RuntimeError("CAFormer encoder not initialized despite _use_caformer=True.") + out = self.caformer_encoder(inputs) + return {ENCODER_OUTPUT: out.get("encoder_output")} + def forward(self, inputs: torch.Tensor) -> EncoderOutputDict: """ :param inputs: The inputs fed into the encoder. 
@@ -162,6 +203,8 @@ def get_schema_cls() -> Type[ImageEncoderConfig]: @property def output_shape(self) -> torch.Size: + if self._output_shape_override is not None: + return self._output_shape_override return self.fc_stack.output_shape @property diff --git a/ludwig/encoders/image/metaformer.py b/ludwig/encoders/image/metaformer.py new file mode 100644 index 00000000000..2d823d60232 --- /dev/null +++ b/ludwig/encoders/image/metaformer.py @@ -0,0 +1,143 @@ +#! /usr/bin/env python +# Copyright (c) 2025 +# +# New MetaFormer / CAFormer style image encoder for Ludwig. +# +# This integrates ConvFormer / CAFormer family backbones as a first-class +# Ludwig encoder, avoiding runtime monkey patching of existing encoders. +# +# The implementation wraps the existing CAFormerStackedCNN (renamed conceptually +# to MetaFormerStackedCNN) logic currently living in caformer_setup_backup. +# +# TODO (follow-up in PR): +# - Move / refactor caformer_setup_backup/ code into ludwig/modules or a +# dedicated ludwig/encoders/image/metaformer_backbones.py file. +# - Add a proper schema definition (see get_schema_cls TODO below). +# - Add unit tests under tests/ludwig/encoders/test_metaformer_encoder.py +# - Add release note and documentation snippet. + +from typing import Any, Dict, Optional, Type +import logging + +import torch +import torch.nn as nn + +from ludwig.api_annotations import DeveloperAPI +from ludwig.constants import ENCODER_OUTPUT, IMAGE +from ludwig.encoders.base import Encoder +from ludwig.encoders.registry import register_encoder +from ludwig.encoders.types import EncoderOutputDict + +logger = logging.getLogger(__name__) + + +@DeveloperAPI +@register_encoder("metaformer", IMAGE) +class MetaFormerEncoder(Encoder): + """MetaFormerEncoder + Provides access to MetaFormer / CAFormer style convolution-attention hybrids + (e.g., caformer_s18 / m36 / b36 etc.) as a Ludwig image encoder. + + Configuration (proposed): + type: metaformer + model_name: caformer_s18 # required + use_pretrained: true # optional (default True) + trainable: true # optional + output_size: 128 # dimensionality after projection head + + Behavior: + - Loads the specified backbone. + - Adapts input spatial size & channels as needed (handled by backbone wrapper). + - Emits a dense representation of shape (output_size,). + """ + + def __init__( + self, + height: int, + width: int, + num_channels: int = 3, + model_name: Optional[str] = None, + use_pretrained: bool = True, + trainable: bool = True, + output_size: int = 128, + encoder_config=None, + **kwargs: Any, + ): + super().__init__() + self.config = encoder_config + self.model_name = model_name or "caformer_s18" + self.use_pretrained = use_pretrained + self.trainable = trainable + self.output_size = output_size + + # Import existing implementation (currently in backup namespace). + # In a polished PR this code should be relocated inside core tree. + try: + # Updated to use integrated metaformer implementation + from metaformer_integration.metaformer_stacked_cnn import MetaFormerStackedCNN as _BackboneWrapper + except Exception as e: # pragma: no cover + raise ImportError( + "Failed to import CAFormer / MetaFormer backbone code. " + "Ensure integration code is migrated from caformer_setup_backup." 
+ ) from e + + logger.info( + "Initializing MetaFormerEncoder backbone=%s pretrained=%s trainable=%s output_size=%d", + self.model_name, + self.use_pretrained, + self.trainable, + self.output_size, + ) + + self.backbone_wrapper: nn.Module = _BackboneWrapper( + height=height if height is not None else 224, + width=width if width is not None else 224, + num_channels=num_channels if num_channels is not None else 3, + output_size=output_size, + custom_model=self.model_name, + use_pretrained=self.use_pretrained, + trainable=self.trainable, + ) + + # Expose shapes + self._input_shape = (num_channels, height, width) + # Backbone wrapper exposes output_shape as list -> convert to torch.Size + raw_out_shape = getattr(self.backbone_wrapper, "output_shape", [self.output_size]) + if isinstance(raw_out_shape, (list, tuple)): + self._output_shape = torch.Size(raw_out_shape) + else: + self._output_shape = torch.Size([self.output_size]) + + # Freeze if not trainable + if not self.trainable: + for p in self.backbone_wrapper.parameters(): + p.requires_grad = False + + def forward(self, inputs: torch.Tensor) -> EncoderOutputDict: + # Expect shape: [B, C, H, W] + if not isinstance(inputs, torch.Tensor): + raise TypeError("MetaFormerEncoder forward expects a torch.Tensor input.") + out_dict = self.backbone_wrapper(inputs) + if isinstance(out_dict, dict) and "encoder_output" in out_dict: + rep = out_dict["encoder_output"] + else: + # Fallback: treat raw module output as representation + rep = out_dict + return {ENCODER_OUTPUT: rep} + + @property + def input_shape(self) -> torch.Size: + return torch.Size(self._input_shape) + + @property + def output_shape(self) -> torch.Size: + return self._output_shape + + @staticmethod + def get_schema_cls() -> Type[Any]: + # Return dedicated MetaFormerConfig (added in schema tree) + try: + from ludwig.schema.encoders.image.metaformer import MetaFormerConfig + return MetaFormerConfig # type: ignore + except Exception as e: # pragma: no cover + raise RuntimeError("MetaFormerConfig schema import failed.") from e diff --git a/ludwig/schema/encoders/image/__init__.py b/ludwig/schema/encoders/image/__init__.py index e1767333ea0..1f4f9c03311 100644 --- a/ludwig/schema/encoders/image/__init__.py +++ b/ludwig/schema/encoders/image/__init__.py @@ -1,2 +1,3 @@ import ludwig.schema.encoders.image.base import ludwig.schema.encoders.image.torchvision # noqa +import ludwig.schema.encoders.image.metaformer # noqa diff --git a/ludwig/schema/encoders/image/base.py b/ludwig/schema/encoders/image/base.py index c0feeecb3b8..522f78e244c 100644 --- a/ludwig/schema/encoders/image/base.py +++ b/ludwig/schema/encoders/image/base.py @@ -33,6 +33,15 @@ def module_name(): description=ENCODER_METADATA["Stacked2DCNN"]["type"].long_description, ) + # Added to enable CAFormer integration via stacked_cnn encoder. + custom_model: Optional[str] = schema_utils.String( + default=None, + allow_none=True, + description="If set to one of ['caformer_s18','caformer_s36','caformer_m36','caformer_b36'], " + "the standard stacked_cnn encoder will be dynamically replaced at runtime by a CAFormer backbone " + "defined in caformer_setup_backup. 
Requires calling the patch utility prior to model construction.", + ) + conv_dropout: Optional[int] = schema_utils.FloatRange( default=0.0, min=0, diff --git a/ludwig/schema/encoders/image/metaformer.py b/ludwig/schema/encoders/image/metaformer.py new file mode 100644 index 00000000000..51a674f9318 --- /dev/null +++ b/ludwig/schema/encoders/image/metaformer.py @@ -0,0 +1,76 @@ +from typing import Optional + +from ludwig.api_annotations import DeveloperAPI +from ludwig.constants import IMAGE +from ludwig.schema import utils as schema_utils +from ludwig.schema.encoders.image.base import ImageEncoderConfig +from ludwig.schema.encoders.utils import register_encoder_config +from ludwig.schema.utils import ludwig_dataclass + +@DeveloperAPI +@register_encoder_config("metaformer", IMAGE) +@ludwig_dataclass +class MetaFormerConfig(ImageEncoderConfig): + """Configuration for the MetaFormer / CAFormer style image encoder. + + This schema intentionally avoids referencing ENCODER_METADATA (not yet extended) + to keep the initial integration minimal and self-contained. + """ + + @staticmethod + def module_name(): + return "MetaFormerEncoder" + + type: str = schema_utils.ProtectedString( + "metaformer", + description="MetaFormer / CAFormer image encoder integrating ConvFormer / CAFormer style backbones.", + ) + + model_name: str = schema_utils.String( + default="caformer_s18", + allow_none=False, + description="Backbone model name (e.g. caformer_s18, caformer_s36, caformer_m36, caformer_b36, etc.).", + ) + + use_pretrained: bool = schema_utils.Boolean( + default=True, + description="If true, load pretrained backbone weights (if available).", + ) + + trainable: bool = schema_utils.Boolean( + default=True, + description="If false, freezes backbone parameters.", + ) + + output_size: int = schema_utils.PositiveInteger( + default=128, + description="Projection head output dimensionality.", + ) + + height: Optional[int] = schema_utils.NonNegativeInteger( + default=None, + allow_none=True, + description="Input image height (optional; if None, provided by feature preprocessing).", + ) + + width: Optional[int] = schema_utils.NonNegativeInteger( + default=None, + allow_none=True, + description="Input image width (optional; if None, provided by feature preprocessing).", + ) + + num_channels: Optional[int] = schema_utils.NonNegativeInteger( + default=None, + allow_none=True, + description="Number of input image channels (e.g. 1 for grayscale, 3 for RGB).", + ) + + def set_fixed_preprocessing_params(self, model_type: str, preprocessing: "ImagePreprocessingConfig"): + # Allow variable sizes; internal wrapper adapts / pools to model expected size. + preprocessing.requires_equal_dimensions = False + # Leave height/width unset to allow dataset-driven or on-the-fly resizing. + # Channel adaptation is handled dynamically if needed. + if self.height is not None: + preprocessing.height = self.height + if self.width is not None: + preprocessing.width = self.width diff --git a/ludwig/utils/tokenizers.py b/ludwig/utils/tokenizers.py index 6fb2d2117a4..ae8357fef00 100644 --- a/ludwig/utils/tokenizers.py +++ b/ludwig/utils/tokenizers.py @@ -18,7 +18,23 @@ from typing import Any, Dict, List, Optional, Union import torch -import torchtext +# Make torchtext optional: guard import to allow pure vision use-cases (e.g., CAFormer image encoder) without binary wheels. 
+try: + import torchtext # type: ignore + TORCHTEXT_AVAILABLE = True +except Exception as e: + TORCHTEXT_AVAILABLE = False + logger = logging.getLogger(__name__) + logger.warning(f"torchtext import failed, disabling torchtext-based tokenizers: {e}") + class _TorchTextStub: # Minimal stub so attribute access does not immediately crash unless actually used. + __version__ = "0.0.0" + class transforms: # empty namespace + pass + class utils: + @staticmethod + def get_asset_local_path(path): + raise RuntimeError("torchtext unavailable (stubbed).") + torchtext = _TorchTextStub() # type: ignore from ludwig.constants import PADDING_SYMBOL, UNKNOWN_SYMBOL from ludwig.utils.data_utils import load_json diff --git a/metaformer_integration/__init__.py b/metaformer_integration/__init__.py new file mode 100644 index 00000000000..fd4813c43fd --- /dev/null +++ b/metaformer_integration/__init__.py @@ -0,0 +1,32 @@ +# MetaFormer integration package init +from .metaformer_models import get_registered_model, default_cfgs # noqa: F401 +from .metaformer_stacked_cnn import ( + MetaFormerStackedCNN, + CAFormerStackedCNN, + patch_ludwig_comprehensive, + patch_ludwig_direct, + patch_ludwig_robust, + list_metaformer_models, + get_metaformer_backbone_names, + metaformer_model_exists, + describe_metaformer_model, +) # noqa: F401 + +def patch(): + """Convenience one-call patch entrypoint (alias of patch_ludwig_comprehensive).""" + return patch_ludwig_comprehensive() + +__all__ = [ + "MetaFormerStackedCNN", + "CAFormerStackedCNN", + "patch_ludwig_comprehensive", + "patch_ludwig_direct", + "patch_ludwig_robust", + "list_metaformer_models", + "get_metaformer_backbone_names", + "metaformer_model_exists", + "describe_metaformer_model", + "get_registered_model", + "default_cfgs", + "patch", +] diff --git a/metaformer_integration/metaformer_models.py b/metaformer_integration/metaformer_models.py new file mode 100644 index 00000000000..189e2ce87e4 --- /dev/null +++ b/metaformer_integration/metaformer_models.py @@ -0,0 +1,911 @@ +# MetaFormer unified models module (migrated from caformer_setup_backup/caformer_models.py) +# Provides IdentityFormer, RandFormer, PoolFormerV2, ConvFormer, CAFormer minimal builders. + +from functools import partial +import math +from typing import Dict, Callable + +import torch +import torch.nn as nn +import torch.nn.functional as F + +# ----------------------------------------------------------------------------- +# Utility +# ----------------------------------------------------------------------------- +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + def norm_cdf(x): + return (1. + math.erf(x / math.sqrt(2.))) / 2. 
+ with torch.no_grad(): + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + tensor.uniform_(2 * l - 1, 2 * u - 1) + tensor.erfinv_() + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + tensor.clamp_(min=a, max=b) + return tensor + +def to_2tuple(x): + if isinstance(x, (list, tuple)): + return x + return (x, x) + +class DropPath(nn.Module): + def __init__(self, drop_prob: float = 0.0): + super().__init__() + self.drop_prob = drop_prob + def forward(self, x): + if self.drop_prob == 0.0 or not self.training: + return x + keep_prob = 1 - self.drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() + return x.div(keep_prob) * random_tensor + +IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) +IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) + +# ----------------------------------------------------------------------------- +# Registry +# ----------------------------------------------------------------------------- +_model_registry: Dict[str, Callable] = {} + +def register_model(fn): + _model_registry[fn.__name__] = fn + return fn + +def get_registered_model(name: str, pretrained=False, **kwargs): + if name not in _model_registry: + raise ValueError(f"Model '{name}' not found. Available: {list(_model_registry.keys())}") + return _model_registry[name](pretrained=pretrained, **kwargs) + +def _cfg(url: str = "", **kwargs): + return { + "url": url, + "num_classes": 1000, + "input_size": (3, 224, 224), + "pool_size": None, + "crop_pct": 1.0, + "interpolation": "bicubic", + "mean": IMAGENET_DEFAULT_MEAN, + "std": IMAGENET_DEFAULT_STD, + "classifier": "head", + **kwargs, + } + +default_cfgs = { + # IdentityFormer + "identityformer_s12": _cfg(url="https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s12.pth"), + "identityformer_s24": _cfg(url="https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s24.pth"), + "identityformer_s36": _cfg(url="https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s36.pth"), + "identityformer_m36": _cfg(url="https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_m36.pth"), + "identityformer_m48": _cfg(url="https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_m48.pth"), + # RandFormer + "randformer_s12": _cfg(url="https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s12.pth"), + "randformer_s24": _cfg(url="https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s24.pth"), + "randformer_s36": _cfg(url="https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s36.pth"), + "randformer_m36": _cfg(url="https://huggingface.co/sail/dl/resolve/main/randformer/randformer_m36.pth"), + "randformer_m48": _cfg(url="https://huggingface.co/sail/dl/resolve/main/randformer/randformer_m48.pth"), + # PoolFormerV2 + "poolformerv2_s12": _cfg(url="https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s12.pth"), + "poolformerv2_s24": _cfg(url="https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s24.pth"), + "poolformerv2_s36": _cfg(url="https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s36.pth"), + "poolformerv2_m36": _cfg(url="https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_m36.pth"), + "poolformerv2_m48": _cfg(url="https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_m48.pth"), + # ConvFormer + "convformer_s18": 
_cfg(url="https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18.pth"), + "convformer_s18_384": _cfg(url="https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384.pth", input_size=(3, 384, 384)), + "convformer_s18_in21ft1k": _cfg(url="https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21ft1k.pth"), + "convformer_s18_384_in21ft1k": _cfg(url="https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384_in21ft1k.pth", input_size=(3, 384, 384)), + "convformer_s18_in21k": _cfg(url="https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21k.pth", num_classes=21841), + "convformer_s36": _cfg(url="https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36.pth"), + "convformer_s36_384": _cfg(url="https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384.pth", input_size=(3, 384, 384)), + "convformer_s36_in21ft1k": _cfg(url="https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21ft1k.pth"), + "convformer_s36_384_in21ft1k": _cfg(url="https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384_in21ft1k.pth", input_size=(3, 384, 384)), + "convformer_s36_in21k": _cfg(url="https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21k.pth", num_classes=21841), + "convformer_m36": _cfg(url="https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36.pth"), + "convformer_m36_384": _cfg(url="https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384.pth", input_size=(3, 384, 384)), + "convformer_m36_in21ft1k": _cfg(url="https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21ft1k.pth"), + "convformer_m36_384_in21ft1k": _cfg(url="https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384_in21ft1k.pth", input_size=(3, 384, 384)), + "convformer_m36_in21k": _cfg(url="https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21k.pth", num_classes=21841), + "convformer_b36": _cfg(url="https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36.pth"), + "convformer_b36_384": _cfg(url="https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384.pth", input_size=(3, 384, 384)), + "convformer_b36_in21ft1k": _cfg(url="https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21ft1k.pth"), + "convformer_b36_384_in21ft1k": _cfg(url="https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384_in21ft1k.pth", input_size=(3, 384, 384)), + "convformer_b36_in21k": _cfg(url="https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21k.pth", num_classes=21841), + # CAFormer + "caformer_s18": _cfg(url="https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18.pth"), + "caformer_s18_384": _cfg(url="https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384.pth", input_size=(3, 384, 384)), + "caformer_s18_in21ft1k": _cfg(url="https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21ft1k.pth"), + "caformer_s18_384_in21ft1k": _cfg(url="https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384_in21ft1k.pth", input_size=(3, 384, 384)), + "caformer_s18_in21k": _cfg(url="https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21k.pth", num_classes=21841), + "caformer_s36": _cfg(url="https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36.pth"), + "caformer_s36_384": _cfg(url="https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384.pth", input_size=(3, 384, 384)), + 
"caformer_s36_in21ft1k": _cfg(url="https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21ft1k.pth"), + "caformer_s36_384_in21ft1k": _cfg(url="https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384_in21ft1k.pth", input_size=(3, 384, 384)), + "caformer_s36_in21k": _cfg(url="https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21k.pth", num_classes=21841), + "caformer_m36": _cfg(url="https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36.pth"), + "caformer_m36_384": _cfg(url="https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384.pth", input_size=(3, 384, 384)), + "caformer_m36_in21ft1k": _cfg(url="https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21ft1k.pth"), + "caformer_m36_384_in21ft1k": _cfg(url="https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384_in21ft1k.pth", input_size=(3, 384, 384)), + "caformer_m36_in21k": _cfg(url="https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21k.pth", num_classes=21841), + "caformer_b36": _cfg(url="https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36.pth"), + "caformer_b36_384": _cfg(url="https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384.pth", input_size=(3, 384, 384)), + "caformer_b36_in21ft1k": _cfg(url="https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21ft1k.pth"), + "caformer_b36_384_in21ft1k": _cfg(url="https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384_in21ft1k.pth", input_size=(3, 384, 384)), + "caformer_b36_in21k": _cfg(url="https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21k.pth", num_classes=21841), +} + +# ----------------------------------------------------------------------------- +# Primitives +# ----------------------------------------------------------------------------- +class Scale(nn.Module): + def __init__(self, dim, init_value=1.0, trainable=True): + super().__init__() + self.scale = nn.Parameter(init_value * torch.ones(dim), requires_grad=trainable) + def forward(self, x): + if x.dim() == 4: + return x * self.scale.view(1, -1, 1, 1) + return x * self.scale + +class SquaredReLU(nn.Module): + def __init__(self, inplace=False): + super().__init__() + self.relu = nn.ReLU(inplace=inplace) + def forward(self, x): + return torch.square(self.relu(x)) + +class StarReLU(nn.Module): + def __init__(self, scale_value=1.0, bias_value=0.0, + scale_learnable=True, bias_learnable=True, inplace=False): + super().__init__() + self.relu = nn.ReLU(inplace=inplace) + self.scale = nn.Parameter(scale_value * torch.ones(1), requires_grad=scale_learnable) + self.bias = nn.Parameter(bias_value * torch.ones(1), requires_grad=bias_learnable) + def forward(self, x): + return self.scale * self.relu(x) ** 2 + self.bias + +class Attention(nn.Module): + def __init__(self, dim, head_dim=32, num_heads=None, qkv_bias=False, + attn_drop=0., proj_drop=0., proj_bias=False, **kwargs): + super().__init__() + self.head_dim = head_dim + self.num_heads = num_heads if num_heads else max(1, dim // head_dim) + self.attention_dim = self.num_heads * self.head_dim + self.scale = self.head_dim ** -0.5 + self.qkv = nn.Linear(dim, self.attention_dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(self.attention_dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + def forward(self, x): + B, H, W, C = x.shape + N = H * W + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = 
qkv.unbind(0) + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = self.attn_drop(attn.softmax(dim=-1)) + x = (attn @ v).transpose(1, 2).reshape(B, H, W, self.attention_dim) + x = self.proj(x) + x = self.proj_drop(x) + return x + +class RandomMixing(nn.Module): + def __init__(self, num_tokens=196, **kwargs): + super().__init__() + self.register_buffer("random_matrix", + torch.softmax(torch.rand(num_tokens, num_tokens), dim=-1), + persistent=False) + def forward(self, x): + B, H, W, C = x.shape + N = H * W + if self.random_matrix.shape[0] != N: + rm = torch.eye(N, device=x.device, dtype=x.dtype) + else: + rm = self.random_matrix + x = x.reshape(B, N, C) + x = torch.einsum("mn,bnc->bmc", rm, x) + x = x.reshape(B, H, W, C) + return x + +class Pooling(nn.Module): + def __init__(self, pool_size=3, **kwargs): + super().__init__() + self.pool = nn.AvgPool2d(pool_size, stride=1, padding=pool_size // 2, count_include_pad=False) + def forward(self, x): + y = x.permute(0, 3, 1, 2) + y = self.pool(y) + y = y.permute(0, 2, 3, 1) + return y - x + +class LayerNormGeneral(nn.Module): + def __init__(self, affine_shape=None, normalized_dim=(-1,), scale=True, bias=True, eps=1e-5): + super().__init__() + self.normalized_dim = normalized_dim + self.use_scale = scale + self.use_bias = bias + self.eps = eps + if scale and affine_shape is not None: + self.weight = nn.Parameter(torch.ones(affine_shape)) + else: + self.weight = None + if bias and affine_shape is not None: + self.bias = nn.Parameter(torch.zeros(affine_shape)) + else: + self.bias = None + def forward(self, x): + c = x - x.mean(self.normalized_dim, keepdim=True) + v = (c ** 2).mean(self.normalized_dim, keepdim=True) + x = c / torch.sqrt(v + self.eps) + if self.use_scale and self.weight is not None: + w = self.weight + if x.dim() == 4 and w.dim() == 1: + w = w.view(1, -1, 1, 1) + x = x * w + if self.use_bias and self.bias is not None: + b = self.bias + if x.dim() == 4 and b.dim() == 1: + b = b.view(1, -1, 1, 1) + x = x + b + return x + +class LayerNormWithoutBias(nn.Module): + def __init__(self, normalized_shape, eps=1e-5, **kwargs): + super().__init__() + if isinstance(normalized_shape, int): + normalized_shape = (normalized_shape,) + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.eps = eps + self.normalized_shape = normalized_shape + def forward(self, x): + if x.dim() == 4: + x = x.permute(0, 2, 3, 1) + x = F.layer_norm(x, self.normalized_shape, self.weight, None, self.eps) + x = x.permute(0, 3, 1, 2) + else: + x = F.layer_norm(x, self.normalized_shape, self.weight, None, self.eps) + return x + +class SepConv(nn.Module): + def __init__(self, dim, expansion_ratio=2, act1_layer=StarReLU, act2_layer=nn.Identity, + bias=False, kernel_size=7, padding=3, **kwargs): + super().__init__() + hidden = int(expansion_ratio * dim) + self.pw1 = nn.Linear(dim, hidden, bias=bias) + self.act1 = act1_layer() + self.dw = nn.Conv2d(hidden, hidden, kernel_size=kernel_size, padding=padding, groups=hidden, bias=bias) + self.act2 = act2_layer() + self.pw2 = nn.Linear(hidden, dim, bias=bias) + def forward(self, x): + identity = x + x = self.pw1(x) + x = self.act1(x) + x = x.permute(0, 3, 1, 2) + x = self.dw(x) + x = x.permute(0, 2, 3, 1) + x = self.act2(x) + x = self.pw2(x) + if x.shape == identity.shape: + x = x + identity + return x + +class Mlp(nn.Module): + def __init__(self, dim, mlp_ratio=4, out_features=None, act_layer=StarReLU, drop=0., bias=False, **kwargs): + super().__init__() + out_features = out_features or dim + hidden = int(mlp_ratio * dim) 
+ drops = to_2tuple(drop) + self.fc1 = nn.Linear(dim, hidden, bias=bias) + self.act = act_layer() + self.drop1 = nn.Dropout(drops[0]) + self.fc2 = nn.Linear(hidden, out_features, bias=bias) + self.drop2 = nn.Dropout(drops[1]) + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + +class MlpHead(nn.Module): + def __init__(self, dim, num_classes=1000, mlp_ratio=4, act_layer=SquaredReLU, + norm_layer=nn.LayerNorm, head_dropout=0., bias=True): + super().__init__() + hidden = int(mlp_ratio * dim) + self.fc1 = nn.Linear(dim, hidden, bias=bias) + self.act = act_layer() + self.norm = norm_layer(hidden) + self.drop = nn.Dropout(head_dropout) + self.fc2 = nn.Linear(hidden, num_classes, bias=bias) + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.norm(x) + x = self.drop(x) + x = self.fc2(x) + return x + +class MetaFormerBlock(nn.Module): + def __init__(self, dim, token_mixer=nn.Identity, mlp=Mlp, + norm_layer=nn.LayerNorm, drop=0.0, drop_path=0.0, + layer_scale_init_value=None, res_scale_init_value=None): + super().__init__() + self.norm1 = norm_layer(dim) + self.token_mixer = token_mixer(dim=dim, drop=drop) if token_mixer not in (nn.Identity, None) else token_mixer() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.layer_scale1 = Scale(dim, init_value=layer_scale_init_value) if layer_scale_init_value else nn.Identity() + self.res_scale1 = Scale(dim, init_value=res_scale_init_value) if res_scale_init_value else nn.Identity() + self.norm2 = norm_layer(dim) + self.mlp = mlp(dim=dim, drop=drop) + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.layer_scale2 = Scale(dim, init_value=layer_scale_init_value) if layer_scale_init_value else nn.Identity() + self.res_scale2 = Scale(dim, init_value=res_scale_init_value) if res_scale_init_value else nn.Identity() + def forward(self, x): + shortcut = x + x1 = self.norm1(x) + x1p = x1.permute(0, 2, 3, 1) + tm = self.token_mixer(x1p) if not isinstance(self.token_mixer, nn.Identity) else x1p + tm = tm.permute(0, 3, 1, 2) + x = shortcut + self.drop_path1(self.layer_scale1(tm)) + x = self.res_scale1(x) + shortcut2 = x + x2 = self.norm2(x) + x2p = x2.permute(0, 2, 3, 1) + mlp_out = self.mlp(x2p).permute(0, 3, 1, 2) + x = shortcut2 + self.drop_path2(self.layer_scale2(mlp_out)) + x = self.res_scale2(x) + return x + +class MetaFormer(nn.Module): + def __init__(self, in_chans=3, num_classes=1000, + depths=(2, 2, 6, 2), dims=(64, 128, 320, 512), + token_mixers=nn.Identity, mlps=Mlp, + norm_layers=partial(LayerNormWithoutBias, eps=1e-6), + drop_path_rate=0.0, head_dropout=0.0, + layer_scale_init_values=None, + res_scale_init_values=(None, None, 1.0, 1.0), + output_norm=partial(nn.LayerNorm, eps=1e-6), + head_fn=nn.Linear, **kwargs): + super().__init__() + self.num_classes = num_classes + self.num_features = dims[-1] + self.downsample_layers = nn.ModuleList() + stem = nn.Sequential( + nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), + norm_layers(dims[0]) if not isinstance(norm_layers, (list, tuple)) else norm_layers[0](dims[0]) + ) + self.downsample_layers.append(stem) + for i in range(3): + norm_mod = norm_layers if not isinstance(norm_layers, (list, tuple)) else norm_layers[i + 1] + layer = nn.Sequential( + nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2), + norm_mod(dims[i + 1]) + ) + self.downsample_layers.append(layer) + if not isinstance(token_mixers, (list, tuple)): + token_mixers = 
[token_mixers] * len(depths) + if not isinstance(mlps, (list, tuple)): + mlps = [mlps] * len(depths) + if not isinstance(norm_layers, (list, tuple)): + norm_layers = [norm_layers] * len(depths) + if not isinstance(layer_scale_init_values, (list, tuple)): + layer_scale_init_values = [layer_scale_init_values] * len(depths) + if not isinstance(res_scale_init_values, (list, tuple)): + res_scale_init_values = [res_scale_init_values] * len(depths) + dp_rates = torch.linspace(0, drop_path_rate, sum(depths)).tolist() + cur = 0 + self.stages = nn.ModuleList() + for i, depth in enumerate(depths): + blocks = [] + for j in range(depth): + blocks.append( + MetaFormerBlock( + dim=dims[i], + token_mixer=token_mixers[i], + mlp=mlps[i], + norm_layer=norm_layers[i] if not isinstance(norm_layers[i], partial) else norm_layers[i], + drop_path=dp_rates[cur + j], + layer_scale_init_value=layer_scale_init_values[i], + res_scale_init_value=res_scale_init_values[i], + ) + ) + cur += depth + self.stages.append(nn.Sequential(*blocks)) + self.norm = output_norm(dims[-1]) if not isinstance(output_norm, (list, tuple)) else output_norm[-1] + if head_dropout > 0.0: + self.head = head_fn(dims[-1], num_classes, head_dropout=head_dropout) + else: + self.head = head_fn(dims[-1], num_classes) + self.apply(self._init_weights) + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + trunc_normal_(m.weight, std=0.02) + if getattr(m, "bias", None) is not None: + nn.init.constant_(m.bias, 0) + def forward_features(self, x): + for i, down in enumerate(self.downsample_layers): + x = down(x) + x = self.stages[i](x) + x = x.mean(dim=[2, 3]) + return x + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + +# ----------------------------------------------------------------------------- +# Model builders +# ----------------------------------------------------------------------------- +@register_model +def identityformer_s12(pretrained=False, **kwargs): + model = MetaFormer(depths=(2, 2, 6, 2), dims=(64,128,320,512), + token_mixers=nn.Identity, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1,2,3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs["identityformer_s12"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def identityformer_s24(pretrained=False, **kwargs): + model = MetaFormer(depths=(4,4,12,4), dims=(64,128,320,512), + token_mixers=nn.Identity, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1,2,3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs["identityformer_s24"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def identityformer_s36(pretrained=False, **kwargs): + model = MetaFormer(depths=(6,6,18,6), dims=(64,128,320,512), + token_mixers=nn.Identity, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1,2,3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs["identityformer_s36"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def identityformer_m36(pretrained=False, **kwargs): + model = MetaFormer(depths=(6,6,18,6), dims=(96,192,384,768), + token_mixers=nn.Identity, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1,2,3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs["identityformer_m36"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def identityformer_m48(pretrained=False, **kwargs): + model = MetaFormer(depths=(8,8,24,8), dims=(96,192,384,768), 
+ token_mixers=nn.Identity, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1,2,3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs["identityformer_m48"] + if pretrained: _load_pretrained(model) + return model + +def _rand_token_mixers(depths, tokens_last_stage=49): + return [nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=tokens_last_stage)] + +@register_model +def randformer_s12(pretrained=False, **kwargs): + model = MetaFormer(depths=(2,2,6,2), dims=(64,128,320,512), + token_mixers=_rand_token_mixers((2,2,6,2)), + norm_layers=partial(LayerNormGeneral, normalized_dim=(1,2,3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs["randformer_s12"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def randformer_s24(pretrained=False, **kwargs): + model = MetaFormer(depths=(4,4,12,4), dims=(64,128,320,512), + token_mixers=_rand_token_mixers((4,4,12,4)), + norm_layers=partial(LayerNormGeneral, normalized_dim=(1,2,3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs["randformer_s24"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def randformer_s36(pretrained=False, **kwargs): + model = MetaFormer(depths=(6,6,18,6), dims=(64,128,320,512), + token_mixers=_rand_token_mixers((6,6,18,6)), + norm_layers=partial(LayerNormGeneral, normalized_dim=(1,2,3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs["randformer_s36"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def randformer_m36(pretrained=False, **kwargs): + model = MetaFormer(depths=(6,6,18,6), dims=(96,192,384,768), + token_mixers=_rand_token_mixers((6,6,18,6)), + norm_layers=partial(LayerNormGeneral, normalized_dim=(1,2,3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs["randformer_m36"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def randformer_m48(pretrained=False, **kwargs): + model = MetaFormer(depths=(8,8,24,8), dims=(96,192,384,768), + token_mixers=_rand_token_mixers((8,8,24,8)), + norm_layers=partial(LayerNormGeneral, normalized_dim=(1,2,3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs["randformer_m48"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def poolformerv2_s12(pretrained=False, **kwargs): + model = MetaFormer(depths=(2,2,6,2), dims=(64,128,320,512), + token_mixers=Pooling, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1,2,3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs["poolformerv2_s12"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def poolformerv2_s24(pretrained=False, **kwargs): + model = MetaFormer(depths=(4,4,12,4), dims=(64,128,320,512), + token_mixers=Pooling, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1,2,3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs["poolformerv2_s24"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def poolformerv2_s36(pretrained=False, **kwargs): + model = MetaFormer(depths=(6,6,18,6), dims=(64,128,320,512), + token_mixers=Pooling, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1,2,3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs["poolformerv2_s36"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def poolformerv2_m36(pretrained=False, **kwargs): + model = MetaFormer(depths=(6,6,18,6), 
dims=(96,192,384,768), + token_mixers=Pooling, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1,2,3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs["poolformerv2_m36"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def poolformerv2_m48(pretrained=False, **kwargs): + model = MetaFormer(depths=(8,8,24,8), dims=(96,192,384,768), + token_mixers=Pooling, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1,2,3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs["poolformerv2_m48"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def convformer_s18(pretrained=False, **kwargs): + model = MetaFormer(depths=(3,3,9,3), dims=(64,128,320,512), + token_mixers=SepConv, head_fn=MlpHead, **kwargs) + model.default_cfg = default_cfgs["convformer_s18"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def convformer_s18_384(pretrained=False, **kwargs): + model = convformer_s18(pretrained=False, **kwargs) + model.default_cfg = default_cfgs["convformer_s18_384"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def convformer_s18_in21ft1k(pretrained=False, **kwargs): + model = convformer_s18(pretrained=False, **kwargs) + model.default_cfg = default_cfgs["convformer_s18_in21ft1k"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def convformer_s18_384_in21ft1k(pretrained=False, **kwargs): + model = convformer_s18(pretrained=False, **kwargs) + model.default_cfg = default_cfgs["convformer_s18_384_in21ft1k"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def convformer_s18_in21k(pretrained=False, **kwargs): + model = convformer_s18(pretrained=False, **kwargs) + model.default_cfg = default_cfgs["convformer_s18_in21k"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def convformer_s36(pretrained=False, **kwargs): + model = MetaFormer(depths=(3,12,18,3), dims=(64,128,320,512), + token_mixers=SepConv, head_fn=MlpHead, **kwargs) + model.default_cfg = default_cfgs["convformer_s36"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def convformer_s36_384(pretrained=False, **kwargs): + model = convformer_s36(pretrained=False, **kwargs) + model.default_cfg = default_cfgs["convformer_s36_384"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def convformer_s36_in21ft1k(pretrained=False, **kwargs): + model = convformer_s36(pretrained=False, **kwargs) + model.default_cfg = default_cfgs["convformer_s36_in21ft1k"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def convformer_s36_384_in21ft1k(pretrained=False, **kwargs): + model = convformer_s36(pretrained=False, **kwargs) + model.default_cfg = default_cfgs["convformer_s36_384_in21ft1k"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def convformer_s36_in21k(pretrained=False, **kwargs): + model = convformer_s36(pretrained=False, **kwargs) + model.default_cfg = default_cfgs["convformer_s36_in21k"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def convformer_m36(pretrained=False, **kwargs): + model = MetaFormer(depths=(3,12,18,3), dims=(96,192,384,576), + token_mixers=SepConv, head_fn=MlpHead, **kwargs) + model.default_cfg = default_cfgs["convformer_m36"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def convformer_m36_384(pretrained=False, **kwargs): + model = 
convformer_m36(pretrained=False, **kwargs) + model.default_cfg = default_cfgs["convformer_m36_384"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def convformer_m36_in21ft1k(pretrained=False, **kwargs): + model = convformer_m36(pretrained=False, **kwargs) + model.default_cfg = default_cfgs["convformer_m36_in21ft1k"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def convformer_m36_384_in21ft1k(pretrained=False, **kwargs): + model = convformer_m36(pretrained=False, **kwargs) + model.default_cfg = default_cfgs["convformer_m36_384_in21ft1k"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def convformer_m36_in21k(pretrained=False, **kwargs): + model = convformer_m36(pretrained=False, **kwargs) + model.default_cfg = default_cfgs["convformer_m36_in21k"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def convformer_b36(pretrained=False, **kwargs): + model = MetaFormer(depths=(3,12,18,3), dims=(128,256,512,768), + token_mixers=SepConv, head_fn=MlpHead, **kwargs) + model.default_cfg = default_cfgs["convformer_b36"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def convformer_b36_384(pretrained=False, **kwargs): + model = convformer_b36(pretrained=False, **kwargs) + model.default_cfg = default_cfgs["convformer_b36_384"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def convformer_b36_in21ft1k(pretrained=False, **kwargs): + model = convformer_b36(pretrained=False, **kwargs) + model.default_cfg = default_cfgs["convformer_b36_in21ft1k"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def convformer_b36_384_in21ft1k(pretrained=False, **kwargs): + model = convformer_b36(pretrained=False, **kwargs) + model.default_cfg = default_cfgs["convformer_b36_384_in21ft1k"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def convformer_b36_in21k(pretrained=False, **kwargs): + model = convformer_b36(pretrained=False, **kwargs) + model.default_cfg = default_cfgs["convformer_b36_in21k"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def caformer_s18(pretrained=False, **kwargs): + model = MetaFormer(depths=(3,3,9,3), dims=(64,128,320,512), + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, **kwargs) + model.default_cfg = default_cfgs["caformer_s18"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def caformer_s18_384(pretrained=False, **kwargs): + model = caformer_s18(pretrained=False, **kwargs) + model.default_cfg = default_cfgs["caformer_s18_384"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def caformer_s18_in21ft1k(pretrained=False, **kwargs): + model = caformer_s18(pretrained=False, **kwargs) + model.default_cfg = default_cfgs["caformer_s18_in21ft1k"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def caformer_s18_384_in21ft1k(pretrained=False, **kwargs): + model = caformer_s18(pretrained=False, **kwargs) + model.default_cfg = default_cfgs["caformer_s18_384_in21ft1k"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def caformer_s18_in21k(pretrained=False, **kwargs): + model = caformer_s18(pretrained=False, **kwargs) + model.default_cfg = default_cfgs["caformer_s18_in21k"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def caformer_s36(pretrained=False, **kwargs): + model = 
MetaFormer(depths=(3,12,18,3), dims=(64,128,320,512), + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, **kwargs) + model.default_cfg = default_cfgs["caformer_s36"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def caformer_s36_384(pretrained=False, **kwargs): + model = caformer_s36(pretrained=False, **kwargs) + model.default_cfg = default_cfgs["caformer_s36_384"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def caformer_s36_in21ft1k(pretrained=False, **kwargs): + model = caformer_s36(pretrained=False, **kwargs) + model.default_cfg = default_cfgs["caformer_s36_in21ft1k"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def caformer_s36_384_in21ft1k(pretrained=False, **kwargs): + model = caformer_s36(pretrained=False, **kwargs) + model.default_cfg = default_cfgs["caformer_s36_384_in21ft1k"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def caformer_s36_in21k(pretrained=False, **kwargs): + model = caformer_s36(pretrained=False, **kwargs) + model.default_cfg = default_cfgs["caformer_s36_in21k"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def caformer_m36(pretrained=False, **kwargs): + model = MetaFormer(depths=(3,12,18,3), dims=(96,192,384,576), + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, **kwargs) + model.default_cfg = default_cfgs["caformer_m36"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def caformer_m36_384(pretrained=False, **kwargs): + model = caformer_m36(pretrained=False, **kwargs) + model.default_cfg = default_cfgs["caformer_m36_384"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def caformer_m36_in21ft1k(pretrained=False, **kwargs): + model = caformer_m36(pretrained=False, **kwargs) + model.default_cfg = default_cfgs["caformer_m36_in21ft1k"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def caformer_m36_384_in21ft1k(pretrained=False, **kwargs): + model = caformer_m36(pretrained=False, **kwargs) + model.default_cfg = default_cfgs["caformer_m36_384_in21ft1k"] + if pretrained: _load_pretrained(model) + return model # (If needed you can correct unused variant naming later.) 
+ +@register_model +def caformer_m36_in21k(pretrained=False, **kwargs): + model = caformer_m36(pretrained=False, **kwargs) + model.default_cfg = default_cfgs["caformer_m36_in21k"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def caformer_b36(pretrained=False, **kwargs): + model = MetaFormer(depths=(3,12,18,3), dims=(128,256,512,768), + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, **kwargs) + model.default_cfg = default_cfgs["caformer_b36"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def caformer_b36_384(pretrained=False, **kwargs): + model = caformer_b36(pretrained=False, **kwargs) + model.default_cfg = default_cfgs["caformer_b36_384"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def caformer_b36_in21ft1k(pretrained=False, **kwargs): + model = caformer_b36(pretrained=False, **kwargs) + model.default_cfg = default_cfgs["caformer_b36_in21ft1k"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def caformer_b36_384_in21ft1k(pretrained=False, **kwargs): + model = caformer_b36(pretrained=False, **kwargs) + model.default_cfg = default_cfgs["caformer_b36_384_in21ft1k"] + if pretrained: _load_pretrained(model) + return model + +@register_model +def caformer_b36_in21k(pretrained=False, **kwargs): + model = caformer_b36(pretrained=False, **kwargs) + model.default_cfg = default_cfgs["caformer_b36_in21k"] + if pretrained: _load_pretrained(model) + return model + +def _load_pretrained(model: nn.Module): + url = getattr(model, "default_cfg", {}).get("url", "") + if not url: + return + try: + state_dict = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True) + missing, unexpected = model.load_state_dict(state_dict, strict=False) + if missing: + print(f"[MetaFormer] Missing keys (ignored): {missing[:5]} ...") + if unexpected: + print(f"[MetaFormer] Unexpected keys (ignored): {unexpected[:5]} ...") + except Exception as e: + print(f"[MetaFormer] Failed to load pretrained weights from {url}: {e}") + +def _quick_test(): + names = ["identityformer_s12", "randformer_s12", "poolformerv2_s12", "convformer_s18", "caformer_s18"] + for n in names: + try: + m = get_registered_model(n, pretrained=False, num_classes=10) + x = torch.randn(1, 3, 224, 224) + y = m(x) + print(f"{n}: output {y.shape}") + except Exception as ex: + print(f"{n}: FAILED ({ex})") + +if __name__ == "__main__": + _quick_test() diff --git a/metaformer_integration/metaformer_stacked_cnn.py b/metaformer_integration/metaformer_stacked_cnn.py new file mode 100644 index 00000000000..f2d29c3aa68 --- /dev/null +++ b/metaformer_integration/metaformer_stacked_cnn.py @@ -0,0 +1,434 @@ +import logging +import torch +import torch.nn as nn +from typing import Dict, Any, Optional, List +import sys +import os + +# Ensure local package path +sys.path.insert(0, os.path.dirname(__file__)) + +logger = logging.getLogger(__name__) + +try: + from metaformer_integration.metaformer_models import get_registered_model, default_cfgs + METAFORMER_AVAILABLE = True + logger.info(" MetaFormer family models imported successfully") +except ImportError as e: + logger.warning(f" MetaFormer models not available: {e}") + METAFORMER_AVAILABLE = False + +META_PREFIXES = ( + "caformer_", + "convformer_", + "identityformer_", + "randformer_", + "poolformerv2_", +) + +_PATCHED_LUDWIG_META = False + +class MetaFormerStackedCNN(nn.Module): + """Generic MetaFormer family encoder wrapper (ConvFormer, CAFormer, 
IdentityFormer, RandFormer, PoolFormerV2). + Backward compatible alias retained for legacy but prefer MetaFormerStackedCNN. + """ + def __init__(self, + height: int = 224, + width: int = 224, + num_channels: int = 3, + output_size: int = 128, + custom_model: Optional[str] = None, + use_pretrained: bool = True, + trainable: bool = True, + conv_layers: Optional[List[Dict]] = None, + num_conv_layers: Optional[int] = None, + conv_activation: str = "relu", + conv_dropout: float = 0.0, + conv_norm: Optional[str] = None, + conv_use_bias: bool = True, + fc_layers: Optional[List[Dict]] = None, + num_fc_layers: int = 1, + fc_activation: str = "relu", + fc_dropout: float = 0.0, + fc_norm: Optional[str] = None, + fc_use_bias: bool = True, + **kwargs): + print(f" MetaFormerStackedCNN encoder instantiated! ") + print(f" Using MetaFormer model: {custom_model} ") + super().__init__() + + self.height = height + self.width = width + self.num_channels = num_channels + self.output_size = output_size + self.custom_model = custom_model + if self.custom_model is None: + self.custom_model = sorted(default_cfgs.keys())[0] + self.use_pretrained = use_pretrained + self.trainable = trainable + + env_flag = os.getenv("METAFORMER_PRETRAINED") + if env_flag is not None and env_flag.lower() in ("0", "false", "no", "off"): + self.use_pretrained = False + + cfg_input = default_cfgs.get(self.custom_model, {}).get("input_size", (3, 224, 224)) + self.target_height, self.target_width = cfg_input[1], cfg_input[2] + logger.info(f"Target backbone input size: {self.target_height}x{self.target_width}") + + logger.info(f"Initializing MetaFormerStackedCNN with model: {self.custom_model}") + logger.info(f"Input: {num_channels}x{height}x{width} -> Output: {output_size}") + + self.channel_adapter = None + if num_channels != 3: + self.channel_adapter = nn.Conv2d(num_channels, 3, kernel_size=1, stride=1, padding=0) + logger.info(f"Added channel adapter: {num_channels} -> 3 channels") + + self.size_adapter = None + if height != self.target_height or width != self.target_width: + self.size_adapter = nn.AdaptiveAvgPool2d((self.target_height, self.target_width)) + logger.info(f"Added size adapter: {height}x{width} -> {self.target_height}x{self.target_width}") + + self.backbone = self._load_metaformer_backbone() + self.feature_dim = self._get_feature_dim() + + self.fc_layers = self._create_fc_layers( + input_dim=self.feature_dim, + output_dim=output_size, + num_layers=num_fc_layers, + activation=fc_activation, + dropout=fc_dropout, + norm=fc_norm, + use_bias=fc_use_bias, + fc_layers_config=fc_layers + ) + + if not trainable: + for param in self.backbone.parameters(): + param.requires_grad = False + logger.info("MetaFormer backbone frozen (trainable=False)") + + logger.info(f"MetaFormerStackedCNN initialized successfully") + + def _load_metaformer_backbone(self): + print(f" Loading MetaFormer backbone: {self.custom_model} ") + if not METAFORMER_AVAILABLE: + raise ImportError("MetaFormer models are not available") + if self.custom_model not in default_cfgs: + raise ValueError(f"Unknown MetaFormer model: {self.custom_model}. 
Available: {list(default_cfgs.keys())[:10]} ...") + model = get_registered_model(self.custom_model, pretrained=self.use_pretrained) + print(f"Successfully loaded weights (if requested) for {self.custom_model}") + logger.info(f"Loaded MetaFormer backbone: {self.custom_model} (pretrained={self.use_pretrained})") + return model + + def _get_feature_dim(self): + with torch.no_grad(): + dummy_input = torch.randn(1, 3, self.target_height, self.target_width) + features = self.backbone.forward_features(dummy_input) + feature_dim = features.shape[-1] + logger.info(f"MetaFormer feature dimension: {feature_dim}") + return feature_dim + + def _create_fc_layers(self, input_dim, output_dim, num_layers, activation, dropout, norm, use_bias, fc_layers_config): + layers = [] + if fc_layers_config: + current_dim = input_dim + for i, layer_config in enumerate(fc_layers_config): + layer_output_dim = layer_config.get('output_size', output_dim if i == len(fc_layers_config) - 1 else current_dim) + layers.append(nn.Linear(current_dim, layer_output_dim, bias=use_bias)) + if i < len(fc_layers_config) - 1: + if activation == "relu": + layers.append(nn.ReLU()) + elif activation == "tanh": + layers.append(nn.Tanh()) + elif activation == "sigmoid": + layers.append(nn.Sigmoid()) + elif activation == "leaky_relu": + layers.append(nn.LeakyReLU()) + if dropout > 0: + layers.append(nn.Dropout(dropout)) + if norm == "batch": + layers.append(nn.BatchNorm1d(layer_output_dim)) + elif norm == "layer": + layers.append(nn.LayerNorm(layer_output_dim)) + current_dim = layer_output_dim + else: + if num_layers == 1: + layers.append(nn.Linear(input_dim, output_dim, bias=use_bias)) + else: + intermediate_dims = [input_dim] + for i in range(num_layers - 1): + intermediate_dim = int(input_dim * (0.5 ** (i + 1))) + intermediate_dim = max(intermediate_dim, output_dim) + intermediate_dims.append(intermediate_dim) + intermediate_dims.append(output_dim) + for i in range(num_layers): + layers.append(nn.Linear(intermediate_dims[i], intermediate_dims[i+1], bias=use_bias)) + if i < num_layers - 1: + if activation == "relu": + layers.append(nn.ReLU()) + elif activation == "tanh": + layers.append(nn.Tanh()) + elif activation == "sigmoid": + layers.append(nn.Sigmoid()) + elif activation == "leaky_relu": + layers.append(nn.LeakyReLU()) + if dropout > 0: + layers.append(nn.Dropout(dropout)) + if norm == "batch": + layers.append(nn.BatchNorm1d(intermediate_dims[i+1])) + elif norm == "layer": + layers.append(nn.LayerNorm(intermediate_dims[i+1])) + return nn.Sequential(*layers) + + def forward(self, x): + if x.shape[1] != 3: + if self.channel_adapter is None: + self.channel_adapter = nn.Conv2d(x.shape[1], 3, kernel_size=1, stride=1, padding=0).to(x.device) + logger.info(f"Created dynamic channel adapter: {x.shape[1]} -> 3 channels") + x = self.channel_adapter(x) + if x.shape[2] != self.target_height or x.shape[3] != self.target_width: + if self.size_adapter is None: + self.size_adapter = nn.AdaptiveAvgPool2d((self.target_height, self.target_width)).to(x.device) + logger.info(f"Created dynamic size adapter: {x.shape[2]}x{x.shape[3]} -> {self.target_height}x{self.target_width}") + x = self.size_adapter(x) + features = self.backbone.forward_features(x) + output = self.fc_layers(features) + return {'encoder_output': output} + + @property + def output_shape(self): + return [self.output_size] + + # Minimal schema hook so Ludwig's schema-based encoder initialization does not fail. 
+    # Ludwig expects: encoder_cls.get_schema_cls().Schema() to return a marshmallow schema instance.
+    # We provide a lightweight stand-in that simply echoes the config (no validation).
+    @classmethod
+    def get_schema_cls(cls):
+        class _MetaFormerStackedCNNSchemaContainer:
+            class Schema:
+                def dump(self, obj, *args, **kwargs):
+                    # Return a plain dict representation (passthrough)
+                    if hasattr(obj, "to_dict"):
+                        return obj.to_dict()
+                    if isinstance(obj, dict):
+                        return obj
+                    return {}
+                def load(self, data, **kwargs):
+                    # Passthrough load (no validation)
+                    if hasattr(data, "to_dict"):
+                        return data.to_dict()
+                    return data
+        return _MetaFormerStackedCNNSchemaContainer
+
+# Legacy alias (retained for backward compatibility; deprecated)
+CAFormerStackedCNN = MetaFormerStackedCNN
+
+def create_metaformer_stacked_cnn(model_name: str, **kwargs) -> MetaFormerStackedCNN:
+    print(f"CREATE_METAFORMER_STACKED_CNN called with model_name: {model_name}")
+    print(f"Creating MetaFormer stacked CNN encoder: {model_name}")
+    # Drop any caller-supplied custom_model so that model_name always wins.
+    kwargs.pop('custom_model', None)
+    encoder = MetaFormerStackedCNN(custom_model=model_name, **kwargs)
+    print(f"MetaFormer encoder created successfully: {type(encoder)}")
+    return encoder
+
+def create_caformer_stacked_cnn(model_name: str, **kwargs):
+    # Backward compatibility
+    return create_metaformer_stacked_cnn(model_name, **kwargs)
+
+def list_metaformer_models():
+    return sorted(default_cfgs.keys())
+
+def get_metaformer_backbone_names(prefix: Optional[str] = None):
+    names = list_metaformer_models()
+    if prefix:
+        names = [n for n in names if n.startswith(prefix)]
+    return names
+
+def metaformer_model_exists(name: str) -> bool:
+    return name in default_cfgs
+
+def describe_metaformer_model(name: str) -> Dict[str, Any]:
+    cfg = default_cfgs.get(name, {}).copy()
+    cfg["exists"] = name in default_cfgs
+    return cfg
+
+def patch_ludwig_stacked_cnn():
+    return patch_ludwig_direct()
+
+def patch_ludwig_robust():
+    try:
+        from ludwig.encoders.registry import get_encoder_cls
+        original_get_encoder_cls = get_encoder_cls
+        def patched_get_encoder_cls(*args, **kwargs):
+            # Support both the legacy signature get_encoder_cls(encoder_type)
+            # and the current signature get_encoder_cls(feature_type, encoder_type).
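+            # e.g. get_encoder_cls("stacked_cnn") (legacy) vs. get_encoder_cls("image", "stacked_cnn") (current);
+            # in both cases the encoder type is the last positional argument.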
+            if args:
+                encoder_type = args[-1]
+            else:
+                encoder_type = kwargs.get("encoder_type")
+            if encoder_type == "stacked_cnn":
+                return MetaFormerStackedCNN
+            return original_get_encoder_cls(*args, **kwargs)
+        import ludwig.encoders.registry
+        ludwig.encoders.registry.get_encoder_cls = patched_get_encoder_cls
+        from ludwig.encoders.image.base import Stacked2DCNN
+        original_stacked_cnn_init = Stacked2DCNN.__init__
+        def patched_stacked_cnn_init(self, *args, **kwargs):
+            custom_model = None
+            if 'custom_model' in kwargs:
+                custom_model = kwargs['custom_model']
+            elif 'encoder_config' in kwargs:
+                enc_cfg = kwargs['encoder_config']
+                if hasattr(enc_cfg, 'to_dict'):
+                    enc_cfg = enc_cfg.to_dict()
+                if isinstance(enc_cfg, dict):
+                    custom_model = enc_cfg.get('custom_model', None)
+            if custom_model and any(str(custom_model).startswith(p) for p in META_PREFIXES):
+                original_stacked_cnn_init(self, *args, **kwargs)
+                print(f"DETECTED MetaFormer model: {custom_model}")
+                print("MetaFormer encoder is being loaded and used (robust patch).")
+                build_kwargs = dict(kwargs)
+                build_kwargs.pop('custom_model', None)
+                meta_encoder = create_metaformer_stacked_cnn(custom_model, **build_kwargs)
+                if hasattr(meta_encoder, 'backbone'):
+                    self.backbone = meta_encoder.backbone
+                if hasattr(meta_encoder, 'fc_layers'):
+                    self.fc_layers = meta_encoder.fc_layers
+                if hasattr(meta_encoder, 'feature_dim'):
+                    self.feature_dim = meta_encoder.feature_dim
+                if hasattr(meta_encoder, 'output_size'):
+                    self.output_size = meta_encoder.output_size
+                self.forward = meta_encoder.forward
+                if hasattr(meta_encoder, 'output_shape'):
+                    shape_val = meta_encoder.output_shape
+                    if isinstance(shape_val, torch.Size):
+                        self._output_shape_override = shape_val
+                    elif isinstance(shape_val, (list, tuple)):
+                        self._output_shape_override = torch.Size(shape_val)
+                return
+            original_stacked_cnn_init(self, *args, **kwargs)
+        Stacked2DCNN.__init__ = patched_stacked_cnn_init
+        try:
+            from ludwig.features.image_feature import ImageInputFeature
+            original_image_feature_init = ImageInputFeature.__init__
+            def patched_image_feature_init(self, *args, **kwargs):
+                # Call original init
+                original_image_feature_init(self, *args, **kwargs)
+                # If ImageInputFeature lacks an input_shape attribute (a property without a setter),
+                # we cannot assign it directly. Instead, override create_sample_input, used by the
+                # batch size tuner, to synthesize a random tensor of the appropriate shape.
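+                # Note: the synthesized shape falls back to (3, 224, 224) below when the encoder
+                # object does not expose num_channels / height / width attributes.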
+                if not hasattr(self, "input_shape"):
+                    ch = getattr(getattr(self, "encoder_obj", None), "num_channels", 3)
+                    h = getattr(getattr(self, "encoder_obj", None), "height", 224)
+                    w = getattr(getattr(self, "encoder_obj", None), "width", 224)
+                    def _mf_create_sample_input(batch_size=2, sequence_length=None):
+                        import torch
+                        return torch.rand([batch_size, ch, h, w])
+                    # Monkey patch only if the framework did not already supply one
+                    self.create_sample_input = _mf_create_sample_input
+                    try:
+                        logger.info(f"[MetaFormer Patch] Injected fallback create_sample_input with shape=({ch},{h},{w})")
+                    except Exception:
+                        pass
+            ImageInputFeature.__init__ = patched_image_feature_init
+        except Exception:
+            pass
+        return True
+    except Exception as e:
+        logger.error(f"Failed to apply robust patch: {e}")
+        return False
+
+def patch_ludwig_direct():
+    try:
+        from ludwig.encoders.registry import get_encoder_cls
+        original_get_encoder_cls = get_encoder_cls
+        def patched_get_encoder_cls(*args, **kwargs):
+            if args:
+                encoder_type = args[-1]
+            else:
+                encoder_type = kwargs.get("encoder_type")
+            if encoder_type == "stacked_cnn":
+                return MetaFormerStackedCNN
+            return original_get_encoder_cls(*args, **kwargs)
+        import ludwig.encoders.registry
+        ludwig.encoders.registry.get_encoder_cls = patched_get_encoder_cls
+        from ludwig.encoders.image.base import Stacked2DCNN
+        original_stacked_cnn_init = Stacked2DCNN.__init__
+        def patched_stacked_cnn_init(self, *args, **kwargs):
+            custom_model = kwargs.get('custom_model', None)
+            if custom_model is None:
+                custom_model = sorted(default_cfgs.keys())[0]
+            if any(custom_model.startswith(p) for p in META_PREFIXES):
+                print(f"DETECTED MetaFormer model: {custom_model}")
+                print("MetaFormer encoder is being loaded and used.")
+                original_stacked_cnn_init(self, *args, **kwargs)
+                meta_encoder = create_metaformer_stacked_cnn(custom_model, **kwargs)
+                self.forward = meta_encoder.forward
+                if hasattr(meta_encoder, 'backbone'):
+                    self.backbone = meta_encoder.backbone
+                if hasattr(meta_encoder, 'fc_layers'):
+                    self.fc_layers = meta_encoder.fc_layers
+                if hasattr(meta_encoder, 'custom_model'):
+                    self.custom_model = meta_encoder.custom_model
+            else:
+                original_stacked_cnn_init(self, *args, **kwargs)
+        Stacked2DCNN.__init__ = patched_stacked_cnn_init
+        return True
+    except Exception as e:
+        logger.error(f"Failed to apply direct patch: {e}")
+        return False
+
+def patch_ludwig_schema_validation():
+    print("PATCH_LUDWIG_SCHEMA_VALIDATION function called")
+    try:
+        from ludwig.schema.features.image import ImageInputFeatureConfig
+        original_validate = ImageInputFeatureConfig.validate
+        def patched_validate(self, data, **kwargs):
+            print("PATCHED SCHEMA VALIDATION called")
+            print(f"data: {data}")
+            if 'encoder' in data and 'custom_model' in data['encoder']:
+                custom_model = data['encoder']['custom_model']
+                print(f"DETECTED custom_model in schema validation: {custom_model}")
+            return original_validate(self, data, **kwargs)
+        ImageInputFeatureConfig.validate = patched_validate
+        print("Successfully patched schema validation")
+        return True
+    except Exception as e:
+        print(f"Could not patch schema validation: {e}")
+        return False
+
+def patch_ludwig_comprehensive():
+    global _PATCHED_LUDWIG_META
+    if _PATCHED_LUDWIG_META:
+        print("PATCH_LUDWIG_COMPREHENSIVE already applied (skipping)")
+        return True
+    print("PATCH_LUDWIG_COMPREHENSIVE function called")
+    patch_robust = patch_ludwig_robust()
+    patch_schema = patch_ludwig_schema_validation()
+    _PATCHED_LUDWIG_META = patch_robust or
patch_schema + print(f" Patch results: robust={patch_robust}, schema={patch_schema} ") + return _PATCHED_LUDWIG_META + +def _quick_metaformer_test(): + if not METAFORMER_AVAILABLE: + print("MetaFormer models not available, skipping test") + return + try: + encoder = MetaFormerStackedCNN( + custom_model=sorted(default_cfgs.keys())[0], + height=224, + width=224, + num_channels=3, + output_size=128, + use_pretrained=False, + trainable=False, + ) + dummy = torch.randn(1, 3, 224, 224) + out = encoder(dummy) + print("MetaFormer quick test OK:", out['encoder_output'].shape) + except Exception as e: + print(f"MetaFormer quick test failed: {e}") + +if __name__ == "__main__": + _quick_metaformer_test() diff --git a/test_data/mnist_subset.zip b/test_data/mnist_subset.zip new file mode 100644 index 00000000000..f6785598cd4 Binary files /dev/null and b/test_data/mnist_subset.zip differ diff --git a/test_data/test/0/833.jpg b/test_data/test/0/833.jpg new file mode 100644 index 00000000000..9fb72977edb Binary files /dev/null and b/test_data/test/0/833.jpg differ diff --git a/test_data/test/0/855.jpg b/test_data/test/0/855.jpg new file mode 100644 index 00000000000..94dd1112958 Binary files /dev/null and b/test_data/test/0/855.jpg differ diff --git a/test_data/test/1/1110.jpg b/test_data/test/1/1110.jpg new file mode 100644 index 00000000000..dd06cd54c99 Binary files /dev/null and b/test_data/test/1/1110.jpg differ diff --git a/test_data/test/1/969.jpg b/test_data/test/1/969.jpg new file mode 100644 index 00000000000..bcefd655258 Binary files /dev/null and b/test_data/test/1/969.jpg differ diff --git a/test_data/test/2/961.jpg b/test_data/test/2/961.jpg new file mode 100644 index 00000000000..8b6d0bc055c Binary files /dev/null and b/test_data/test/2/961.jpg differ diff --git a/test_data/test/2/971.jpg b/test_data/test/2/971.jpg new file mode 100644 index 00000000000..45ed8cad66c Binary files /dev/null and b/test_data/test/2/971.jpg differ diff --git a/test_data/test/3/1005.jpg b/test_data/test/3/1005.jpg new file mode 100644 index 00000000000..f58cc9b2726 Binary files /dev/null and b/test_data/test/3/1005.jpg differ diff --git a/test_data/test/3/895.jpg b/test_data/test/3/895.jpg new file mode 100644 index 00000000000..f3eb9dda0cc Binary files /dev/null and b/test_data/test/3/895.jpg differ diff --git a/test_data/test/4/940.jpg b/test_data/test/4/940.jpg new file mode 100644 index 00000000000..850a58d46f8 Binary files /dev/null and b/test_data/test/4/940.jpg differ diff --git a/test_data/test/4/975.jpg b/test_data/test/4/975.jpg new file mode 100644 index 00000000000..00268108b43 Binary files /dev/null and b/test_data/test/4/975.jpg differ diff --git a/test_data/test/5/780.jpg b/test_data/test/5/780.jpg new file mode 100644 index 00000000000..8468cfe3fa1 Binary files /dev/null and b/test_data/test/5/780.jpg differ diff --git a/test_data/test/5/834.jpg b/test_data/test/5/834.jpg new file mode 100644 index 00000000000..4a04ff5ae46 Binary files /dev/null and b/test_data/test/5/834.jpg differ diff --git a/test_data/test/6/796.jpg b/test_data/test/6/796.jpg new file mode 100644 index 00000000000..3011fdfd140 Binary files /dev/null and b/test_data/test/6/796.jpg differ diff --git a/test_data/test/6/932.jpg b/test_data/test/6/932.jpg new file mode 100644 index 00000000000..8fd0acc1948 Binary files /dev/null and b/test_data/test/6/932.jpg differ diff --git a/test_data/test/7/835.jpg b/test_data/test/7/835.jpg new file mode 100644 index 00000000000..7de85815b47 Binary files /dev/null and b/test_data/test/7/835.jpg 
differ diff --git a/test_data/test/7/863.jpg b/test_data/test/7/863.jpg new file mode 100644 index 00000000000..1d62ab452c8 Binary files /dev/null and b/test_data/test/7/863.jpg differ diff --git a/test_data/test/8/898.jpg b/test_data/test/8/898.jpg new file mode 100644 index 00000000000..a8c57aa1bf2 Binary files /dev/null and b/test_data/test/8/898.jpg differ diff --git a/test_data/test/8/899.jpg b/test_data/test/8/899.jpg new file mode 100644 index 00000000000..acd50f7edba Binary files /dev/null and b/test_data/test/8/899.jpg differ diff --git a/test_data/test/9/1007.jpg b/test_data/test/9/1007.jpg new file mode 100644 index 00000000000..3db75f382cb Binary files /dev/null and b/test_data/test/9/1007.jpg differ diff --git a/test_data/test/9/954.jpg b/test_data/test/9/954.jpg new file mode 100644 index 00000000000..2292b2f1b84 Binary files /dev/null and b/test_data/test/9/954.jpg differ diff --git a/test_data/training/0/5003.jpg b/test_data/training/0/5003.jpg new file mode 100644 index 00000000000..5fe44508d13 Binary files /dev/null and b/test_data/training/0/5003.jpg differ diff --git a/test_data/training/0/5010.jpg b/test_data/training/0/5010.jpg new file mode 100644 index 00000000000..663d44aab3f Binary files /dev/null and b/test_data/training/0/5010.jpg differ diff --git a/test_data/training/0/5359.jpg b/test_data/training/0/5359.jpg new file mode 100644 index 00000000000..ba740e30aac Binary files /dev/null and b/test_data/training/0/5359.jpg differ diff --git a/test_data/training/0/5405.jpg b/test_data/training/0/5405.jpg new file mode 100644 index 00000000000..b9e188b4854 Binary files /dev/null and b/test_data/training/0/5405.jpg differ diff --git a/test_data/training/0/5452.jpg b/test_data/training/0/5452.jpg new file mode 100644 index 00000000000..a4f92ba2477 Binary files /dev/null and b/test_data/training/0/5452.jpg differ diff --git a/test_data/training/0/5524.jpg b/test_data/training/0/5524.jpg new file mode 100644 index 00000000000..2c6e6c15f13 Binary files /dev/null and b/test_data/training/0/5524.jpg differ diff --git a/test_data/training/0/5527.jpg b/test_data/training/0/5527.jpg new file mode 100644 index 00000000000..484b1e41325 Binary files /dev/null and b/test_data/training/0/5527.jpg differ diff --git a/test_data/training/0/5680.jpg b/test_data/training/0/5680.jpg new file mode 100644 index 00000000000..122f02e0b15 Binary files /dev/null and b/test_data/training/0/5680.jpg differ diff --git a/test_data/training/0/5699.jpg b/test_data/training/0/5699.jpg new file mode 100644 index 00000000000..87b4234b360 Binary files /dev/null and b/test_data/training/0/5699.jpg differ diff --git a/test_data/training/0/5766.jpg b/test_data/training/0/5766.jpg new file mode 100644 index 00000000000..a0d2e51f753 Binary files /dev/null and b/test_data/training/0/5766.jpg differ diff --git a/test_data/training/1/5754.jpg b/test_data/training/1/5754.jpg new file mode 100644 index 00000000000..212d0dc933d Binary files /dev/null and b/test_data/training/1/5754.jpg differ diff --git a/test_data/training/1/6015.jpg b/test_data/training/1/6015.jpg new file mode 100644 index 00000000000..2446bdc336d Binary files /dev/null and b/test_data/training/1/6015.jpg differ diff --git a/test_data/training/1/6100.jpg b/test_data/training/1/6100.jpg new file mode 100644 index 00000000000..4c591378f19 Binary files /dev/null and b/test_data/training/1/6100.jpg differ diff --git a/test_data/training/1/6129.jpg b/test_data/training/1/6129.jpg new file mode 100644 index 00000000000..2b7aa9c4d63 Binary files 
/dev/null and b/test_data/training/1/6129.jpg differ diff --git a/test_data/training/1/6247.jpg b/test_data/training/1/6247.jpg new file mode 100644 index 00000000000..9995d17e8e6 Binary files /dev/null and b/test_data/training/1/6247.jpg differ diff --git a/test_data/training/1/6275.jpg b/test_data/training/1/6275.jpg new file mode 100644 index 00000000000..2da61b20476 Binary files /dev/null and b/test_data/training/1/6275.jpg differ diff --git a/test_data/training/1/6552.jpg b/test_data/training/1/6552.jpg new file mode 100644 index 00000000000..838daea59bc Binary files /dev/null and b/test_data/training/1/6552.jpg differ diff --git a/test_data/training/1/6590.jpg b/test_data/training/1/6590.jpg new file mode 100644 index 00000000000..5bac6101af6 Binary files /dev/null and b/test_data/training/1/6590.jpg differ diff --git a/test_data/training/1/6727.jpg b/test_data/training/1/6727.jpg new file mode 100644 index 00000000000..483953e989a Binary files /dev/null and b/test_data/training/1/6727.jpg differ diff --git a/test_data/training/1/6733.jpg b/test_data/training/1/6733.jpg new file mode 100644 index 00000000000..316493d7f6c Binary files /dev/null and b/test_data/training/1/6733.jpg differ diff --git a/test_data/training/2/4984.jpg b/test_data/training/2/4984.jpg new file mode 100644 index 00000000000..651be9485b4 Binary files /dev/null and b/test_data/training/2/4984.jpg differ diff --git a/test_data/training/2/4992.jpg b/test_data/training/2/4992.jpg new file mode 100644 index 00000000000..d3e0d88748c Binary files /dev/null and b/test_data/training/2/4992.jpg differ diff --git a/test_data/training/2/5008.jpg b/test_data/training/2/5008.jpg new file mode 100644 index 00000000000..75ecc06e859 Binary files /dev/null and b/test_data/training/2/5008.jpg differ diff --git a/test_data/training/2/5323.jpg b/test_data/training/2/5323.jpg new file mode 100644 index 00000000000..09969fca718 Binary files /dev/null and b/test_data/training/2/5323.jpg differ diff --git a/test_data/training/2/5325.jpg b/test_data/training/2/5325.jpg new file mode 100644 index 00000000000..38fc0c58554 Binary files /dev/null and b/test_data/training/2/5325.jpg differ diff --git a/test_data/training/2/5407.jpg b/test_data/training/2/5407.jpg new file mode 100644 index 00000000000..c4dddd5951e Binary files /dev/null and b/test_data/training/2/5407.jpg differ diff --git a/test_data/training/2/5438.jpg b/test_data/training/2/5438.jpg new file mode 100644 index 00000000000..937a10abcf0 Binary files /dev/null and b/test_data/training/2/5438.jpg differ diff --git a/test_data/training/2/5585.jpg b/test_data/training/2/5585.jpg new file mode 100644 index 00000000000..5a703177e21 Binary files /dev/null and b/test_data/training/2/5585.jpg differ diff --git a/test_data/training/2/5807.jpg b/test_data/training/2/5807.jpg new file mode 100644 index 00000000000..31a7ae8447a Binary files /dev/null and b/test_data/training/2/5807.jpg differ diff --git a/test_data/training/2/5865.jpg b/test_data/training/2/5865.jpg new file mode 100644 index 00000000000..c4b7b7d7b5e Binary files /dev/null and b/test_data/training/2/5865.jpg differ diff --git a/test_data/training/3/5333.jpg b/test_data/training/3/5333.jpg new file mode 100644 index 00000000000..270b9594583 Binary files /dev/null and b/test_data/training/3/5333.jpg differ diff --git a/test_data/training/3/5410.jpg b/test_data/training/3/5410.jpg new file mode 100644 index 00000000000..3b5fbd7e387 Binary files /dev/null and b/test_data/training/3/5410.jpg differ diff --git 
a/test_data/training/3/5519.jpg b/test_data/training/3/5519.jpg new file mode 100644 index 00000000000..c9afdc621ec Binary files /dev/null and b/test_data/training/3/5519.jpg differ diff --git a/test_data/training/3/5577.jpg b/test_data/training/3/5577.jpg new file mode 100644 index 00000000000..3b0f2194d39 Binary files /dev/null and b/test_data/training/3/5577.jpg differ diff --git a/test_data/training/3/5586.jpg b/test_data/training/3/5586.jpg new file mode 100644 index 00000000000..9fae24b740c Binary files /dev/null and b/test_data/training/3/5586.jpg differ diff --git a/test_data/training/3/5710.jpg b/test_data/training/3/5710.jpg new file mode 100644 index 00000000000..1c6186204e6 Binary files /dev/null and b/test_data/training/3/5710.jpg differ diff --git a/test_data/training/3/5714.jpg b/test_data/training/3/5714.jpg new file mode 100644 index 00000000000..40e3aac76b6 Binary files /dev/null and b/test_data/training/3/5714.jpg differ diff --git a/test_data/training/3/5813.jpg b/test_data/training/3/5813.jpg new file mode 100644 index 00000000000..077b66fad09 Binary files /dev/null and b/test_data/training/3/5813.jpg differ diff --git a/test_data/training/3/5869.jpg b/test_data/training/3/5869.jpg new file mode 100644 index 00000000000..38cd430c899 Binary files /dev/null and b/test_data/training/3/5869.jpg differ diff --git a/test_data/training/3/6093.jpg b/test_data/training/3/6093.jpg new file mode 100644 index 00000000000..891f99ee766 Binary files /dev/null and b/test_data/training/3/6093.jpg differ diff --git a/test_data/training/4/4887.jpg b/test_data/training/4/4887.jpg new file mode 100644 index 00000000000..c57cf664dda Binary files /dev/null and b/test_data/training/4/4887.jpg differ diff --git a/test_data/training/4/4972.jpg b/test_data/training/4/4972.jpg new file mode 100644 index 00000000000..6c918b7e56e Binary files /dev/null and b/test_data/training/4/4972.jpg differ diff --git a/test_data/training/4/5052.jpg b/test_data/training/4/5052.jpg new file mode 100644 index 00000000000..0fcac89fbdc Binary files /dev/null and b/test_data/training/4/5052.jpg differ diff --git a/test_data/training/4/5092.jpg b/test_data/training/4/5092.jpg new file mode 100644 index 00000000000..7a3d239589b Binary files /dev/null and b/test_data/training/4/5092.jpg differ diff --git a/test_data/training/4/5123.jpg b/test_data/training/4/5123.jpg new file mode 100644 index 00000000000..9ef0650fca3 Binary files /dev/null and b/test_data/training/4/5123.jpg differ diff --git a/test_data/training/4/5348.jpg b/test_data/training/4/5348.jpg new file mode 100644 index 00000000000..01628970e2f Binary files /dev/null and b/test_data/training/4/5348.jpg differ diff --git a/test_data/training/4/5368.jpg b/test_data/training/4/5368.jpg new file mode 100644 index 00000000000..ea821ebd169 Binary files /dev/null and b/test_data/training/4/5368.jpg differ diff --git a/test_data/training/4/5610.jpg b/test_data/training/4/5610.jpg new file mode 100644 index 00000000000..147bc02a013 Binary files /dev/null and b/test_data/training/4/5610.jpg differ diff --git a/test_data/training/4/5685.jpg b/test_data/training/4/5685.jpg new file mode 100644 index 00000000000..df7457c07f0 Binary files /dev/null and b/test_data/training/4/5685.jpg differ diff --git a/test_data/training/4/5793.jpg b/test_data/training/4/5793.jpg new file mode 100644 index 00000000000..06f361a9f70 Binary files /dev/null and b/test_data/training/4/5793.jpg differ diff --git a/test_data/training/5/4442.jpg b/test_data/training/5/4442.jpg new file mode 
100644 index 00000000000..f20120de83d Binary files /dev/null and b/test_data/training/5/4442.jpg differ diff --git a/test_data/training/5/4506.jpg b/test_data/training/5/4506.jpg new file mode 100644 index 00000000000..244033c84c6 Binary files /dev/null and b/test_data/training/5/4506.jpg differ diff --git a/test_data/training/5/4592.jpg b/test_data/training/5/4592.jpg new file mode 100644 index 00000000000..5e8e412204d Binary files /dev/null and b/test_data/training/5/4592.jpg differ diff --git a/test_data/training/5/4707.jpg b/test_data/training/5/4707.jpg new file mode 100644 index 00000000000..aa220cfd40c Binary files /dev/null and b/test_data/training/5/4707.jpg differ diff --git a/test_data/training/5/4745.jpg b/test_data/training/5/4745.jpg new file mode 100644 index 00000000000..0be35ce1a89 Binary files /dev/null and b/test_data/training/5/4745.jpg differ diff --git a/test_data/training/5/4888.jpg b/test_data/training/5/4888.jpg new file mode 100644 index 00000000000..690974b9243 Binary files /dev/null and b/test_data/training/5/4888.jpg differ diff --git a/test_data/training/5/5100.jpg b/test_data/training/5/5100.jpg new file mode 100644 index 00000000000..66e9791448f Binary files /dev/null and b/test_data/training/5/5100.jpg differ diff --git a/test_data/training/5/5118.jpg b/test_data/training/5/5118.jpg new file mode 100644 index 00000000000..fba8fc0bbd8 Binary files /dev/null and b/test_data/training/5/5118.jpg differ diff --git a/test_data/training/5/5282.jpg b/test_data/training/5/5282.jpg new file mode 100644 index 00000000000..d6f5fecdf38 Binary files /dev/null and b/test_data/training/5/5282.jpg differ diff --git a/test_data/training/5/5305.jpg b/test_data/training/5/5305.jpg new file mode 100644 index 00000000000..b45f20e979a Binary files /dev/null and b/test_data/training/5/5305.jpg differ diff --git a/test_data/training/6/5076.jpg b/test_data/training/6/5076.jpg new file mode 100644 index 00000000000..cd25c25e34f Binary files /dev/null and b/test_data/training/6/5076.jpg differ diff --git a/test_data/training/6/5231.jpg b/test_data/training/6/5231.jpg new file mode 100644 index 00000000000..b6bb0822bbe Binary files /dev/null and b/test_data/training/6/5231.jpg differ diff --git a/test_data/training/6/5260.jpg b/test_data/training/6/5260.jpg new file mode 100644 index 00000000000..7ea23d8309e Binary files /dev/null and b/test_data/training/6/5260.jpg differ diff --git a/test_data/training/6/5435.jpg b/test_data/training/6/5435.jpg new file mode 100644 index 00000000000..bde414eb346 Binary files /dev/null and b/test_data/training/6/5435.jpg differ diff --git a/test_data/training/6/5553.jpg b/test_data/training/6/5553.jpg new file mode 100644 index 00000000000..be201dd32c9 Binary files /dev/null and b/test_data/training/6/5553.jpg differ diff --git a/test_data/training/6/5567.jpg b/test_data/training/6/5567.jpg new file mode 100644 index 00000000000..d4c8f08a394 Binary files /dev/null and b/test_data/training/6/5567.jpg differ diff --git a/test_data/training/6/5743.jpg b/test_data/training/6/5743.jpg new file mode 100644 index 00000000000..9568520953c Binary files /dev/null and b/test_data/training/6/5743.jpg differ diff --git a/test_data/training/6/5823.jpg b/test_data/training/6/5823.jpg new file mode 100644 index 00000000000..38f3616a2bb Binary files /dev/null and b/test_data/training/6/5823.jpg differ diff --git a/test_data/training/6/5849.jpg b/test_data/training/6/5849.jpg new file mode 100644 index 00000000000..4ed4d40124e Binary files /dev/null and 
b/test_data/training/6/5849.jpg differ diff --git a/test_data/training/6/5899.jpg b/test_data/training/6/5899.jpg new file mode 100644 index 00000000000..c155e7fee5f Binary files /dev/null and b/test_data/training/6/5899.jpg differ diff --git a/test_data/training/7/5481.jpg b/test_data/training/7/5481.jpg new file mode 100644 index 00000000000..580a2232115 Binary files /dev/null and b/test_data/training/7/5481.jpg differ diff --git a/test_data/training/7/5488.jpg b/test_data/training/7/5488.jpg new file mode 100644 index 00000000000..8741818889e Binary files /dev/null and b/test_data/training/7/5488.jpg differ diff --git a/test_data/training/7/5506.jpg b/test_data/training/7/5506.jpg new file mode 100644 index 00000000000..0d369463fce Binary files /dev/null and b/test_data/training/7/5506.jpg differ diff --git a/test_data/training/7/5634.jpg b/test_data/training/7/5634.jpg new file mode 100644 index 00000000000..7220cd5bd6c Binary files /dev/null and b/test_data/training/7/5634.jpg differ diff --git a/test_data/training/7/5721.jpg b/test_data/training/7/5721.jpg new file mode 100644 index 00000000000..71754249322 Binary files /dev/null and b/test_data/training/7/5721.jpg differ diff --git a/test_data/training/7/5834.jpg b/test_data/training/7/5834.jpg new file mode 100644 index 00000000000..030b175e1e0 Binary files /dev/null and b/test_data/training/7/5834.jpg differ diff --git a/test_data/training/7/5934.jpg b/test_data/training/7/5934.jpg new file mode 100644 index 00000000000..c612ddb33cc Binary files /dev/null and b/test_data/training/7/5934.jpg differ diff --git a/test_data/training/7/6036.jpg b/test_data/training/7/6036.jpg new file mode 100644 index 00000000000..b2efe0f1166 Binary files /dev/null and b/test_data/training/7/6036.jpg differ diff --git a/test_data/training/7/6194.jpg b/test_data/training/7/6194.jpg new file mode 100644 index 00000000000..fe10cb952b2 Binary files /dev/null and b/test_data/training/7/6194.jpg differ diff --git a/test_data/training/7/6204.jpg b/test_data/training/7/6204.jpg new file mode 100644 index 00000000000..3c55aa80e27 Binary files /dev/null and b/test_data/training/7/6204.jpg differ diff --git a/test_data/training/8/4880.jpg b/test_data/training/8/4880.jpg new file mode 100644 index 00000000000..b7f6b1f021f Binary files /dev/null and b/test_data/training/8/4880.jpg differ diff --git a/test_data/training/8/4933.jpg b/test_data/training/8/4933.jpg new file mode 100644 index 00000000000..6c31f6428a0 Binary files /dev/null and b/test_data/training/8/4933.jpg differ diff --git a/test_data/training/8/4938.jpg b/test_data/training/8/4938.jpg new file mode 100644 index 00000000000..6afd3646baa Binary files /dev/null and b/test_data/training/8/4938.jpg differ diff --git a/test_data/training/8/5001.jpg b/test_data/training/8/5001.jpg new file mode 100644 index 00000000000..03ae28b3809 Binary files /dev/null and b/test_data/training/8/5001.jpg differ diff --git a/test_data/training/8/5039.jpg b/test_data/training/8/5039.jpg new file mode 100644 index 00000000000..e9047a66eab Binary files /dev/null and b/test_data/training/8/5039.jpg differ diff --git a/test_data/training/8/5057.jpg b/test_data/training/8/5057.jpg new file mode 100644 index 00000000000..a5596ae7a55 Binary files /dev/null and b/test_data/training/8/5057.jpg differ diff --git a/test_data/training/8/5341.jpg b/test_data/training/8/5341.jpg new file mode 100644 index 00000000000..8d3724846d9 Binary files /dev/null and b/test_data/training/8/5341.jpg differ diff --git 
a/test_data/training/8/5462.jpg b/test_data/training/8/5462.jpg new file mode 100644 index 00000000000..19e4e198517 Binary files /dev/null and b/test_data/training/8/5462.jpg differ diff --git a/test_data/training/8/5785.jpg b/test_data/training/8/5785.jpg new file mode 100644 index 00000000000..13168a113e5 Binary files /dev/null and b/test_data/training/8/5785.jpg differ diff --git a/test_data/training/8/5844.jpg b/test_data/training/8/5844.jpg new file mode 100644 index 00000000000..8eebfd4ecb1 Binary files /dev/null and b/test_data/training/8/5844.jpg differ diff --git a/test_data/training/9/5186.jpg b/test_data/training/9/5186.jpg new file mode 100644 index 00000000000..1946ddc421c Binary files /dev/null and b/test_data/training/9/5186.jpg differ diff --git a/test_data/training/9/5193.jpg b/test_data/training/9/5193.jpg new file mode 100644 index 00000000000..7d33ebc1545 Binary files /dev/null and b/test_data/training/9/5193.jpg differ diff --git a/test_data/training/9/5444.jpg b/test_data/training/9/5444.jpg new file mode 100644 index 00000000000..2ba38aa056e Binary files /dev/null and b/test_data/training/9/5444.jpg differ diff --git a/test_data/training/9/5541.jpg b/test_data/training/9/5541.jpg new file mode 100644 index 00000000000..b1b80775ccb Binary files /dev/null and b/test_data/training/9/5541.jpg differ diff --git a/test_data/training/9/5579.jpg b/test_data/training/9/5579.jpg new file mode 100644 index 00000000000..2e2fd466741 Binary files /dev/null and b/test_data/training/9/5579.jpg differ diff --git a/test_data/training/9/5688.jpg b/test_data/training/9/5688.jpg new file mode 100644 index 00000000000..2f1ca502a83 Binary files /dev/null and b/test_data/training/9/5688.jpg differ diff --git a/test_data/training/9/5756.jpg b/test_data/training/9/5756.jpg new file mode 100644 index 00000000000..5cf86212ab7 Binary files /dev/null and b/test_data/training/9/5756.jpg differ diff --git a/test_data/training/9/5786.jpg b/test_data/training/9/5786.jpg new file mode 100644 index 00000000000..9a2e09e730d Binary files /dev/null and b/test_data/training/9/5786.jpg differ diff --git a/test_data/training/9/5870.jpg b/test_data/training/9/5870.jpg new file mode 100644 index 00000000000..c26ff904c57 Binary files /dev/null and b/test_data/training/9/5870.jpg differ diff --git a/test_data/training/9/5931.jpg b/test_data/training/9/5931.jpg new file mode 100644 index 00000000000..a60c7183796 Binary files /dev/null and b/test_data/training/9/5931.jpg differ diff --git a/tests/ludwig/encoders/test_metaformer_encoder.py b/tests/ludwig/encoders/test_metaformer_encoder.py new file mode 100644 index 00000000000..5835d3d23d2 --- /dev/null +++ b/tests/ludwig/encoders/test_metaformer_encoder.py @@ -0,0 +1,68 @@ +import os +import pytest +import torch + +from ludwig.constants import ENCODER_OUTPUT +from ludwig.utils.misc_utils import set_random_seed + +# Ensure we do not attempt to download pretrained weights during tests. +os.environ.setdefault("METAFORMER_PRETRAINED", "0") + +try: + # Verify backbone registry availability early; skip if integration not present. 
+    from metaformer_integration.metaformer_models import default_cfgs as _mf_cfgs  # noqa: F401
+    META_INTEGRATION_AVAILABLE = True
+except Exception:
+    META_INTEGRATION_AVAILABLE = False
+
+pytestmark = pytest.mark.skipif(
+    not META_INTEGRATION_AVAILABLE,
+    reason="MetaFormer integration not available (metaformer_integration.metaformer_models import failed).",
+)
+
+def test_metaformer_encoder_basic_forward_and_shape():
+    from ludwig.encoders.image.metaformer import MetaFormerEncoder
+
+    set_random_seed(1234)
+
+    # Use small dimensions (will be internally adapted to the model's expected size).
+    encoder = MetaFormerEncoder(
+        height=28,
+        width=28,
+        num_channels=1,
+        model_name="caformer_s18",
+        use_pretrained=False,
+        trainable=True,
+        output_size=64,
+    )
+
+    batch_size = 2
+    x = torch.rand(batch_size, 1, 28, 28)
+    out = encoder(x)
+    assert ENCODER_OUTPUT in out, "Encoder output key missing."
+    rep = out[ENCODER_OUTPUT]
+    assert rep.shape[0] == batch_size, "Batch dimension mismatch."
+    assert tuple(rep.shape[1:]) == tuple(encoder.output_shape), "Representation shape mismatch."
+
+def test_metaformer_encoder_parameter_updates():
+    from ludwig.encoders.image.metaformer import MetaFormerEncoder
+    from tests.integration_tests.parameter_update_utils import check_module_parameters_updated
+
+    set_random_seed(5678)
+
+    encoder = MetaFormerEncoder(
+        height=32,
+        width=32,
+        num_channels=3,
+        model_name="caformer_s18",
+        use_pretrained=False,
+        trainable=True,
+        output_size=32,
+    )
+
+    inputs = torch.rand(2, 3, 32, 32)
+    outputs = encoder(inputs)
+    target = torch.randn_like(outputs[ENCODER_OUTPUT])
+
+    fpc, tpc, upc, not_updated = check_module_parameters_updated(encoder, (inputs,), target)
+    assert tpc == upc, f"Some parameters did not update: {not_updated}"
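+
+def test_metaformer_registry_helpers():
+    # Supplementary sketch (not part of the original test plan): exercises the model-registry
+    # helpers added alongside the encoder. Assumes they are exported from
+    # metaformer_integration.metaformer_stacked_cnn, the module that also provides
+    # MetaFormerStackedCNN; adjust the import path if the helpers live elsewhere.
+    from metaformer_integration.metaformer_stacked_cnn import (
+        get_metaformer_backbone_names,
+        list_metaformer_models,
+        metaformer_model_exists,
+    )
+
+    names = list_metaformer_models()
+    assert len(names) > 0, "Backbone registry should not be empty."
+    assert all(metaformer_model_exists(n) for n in names)
+    # Prefix filtering should return a subset of the full registry.
+    caformer_names = get_metaformer_backbone_names(prefix="caformer_")
+    assert set(caformer_names) <= set(names)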