Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ repos:
- id: pyupgrade
args: [--py36-plus]
- repo: https://github.com/PyCQA/docformatter
rev: v1.5.1
rev: v1.5.0
hooks:
- id: docformatter
args: [--in-place, --wrap-summaries=115, --wrap-descriptions=120]
Expand Down
1 change: 1 addition & 0 deletions ludwig/encoders/image/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
import ludwig.encoders.image.base
import ludwig.encoders.image.torchvision # noqa
import ludwig.encoders.image.metaformer # noqa
43 changes: 43 additions & 0 deletions ludwig/encoders/image/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ def __init__(
conv_layers: Optional[List[Dict]] = None,
num_conv_layers: Optional[int] = None,
num_channels: int = None,
# Optional custom CAFormer backbone (e.g. caformer_s18, caformer_s36, caformer_m36, caformer_b36)
custom_model: Optional[str] = None,
out_channels: int = 32,
kernel_size: Union[int, Tuple[int]] = 3,
stride: Union[int, Tuple[int]] = 1,
Expand Down Expand Up @@ -87,6 +89,35 @@ def __init__(
super().__init__()
self.config = encoder_config

self._use_caformer = False
self._output_shape_override: Optional[torch.Size] = None
if custom_model and isinstance(custom_model, str) and custom_model.startswith("caformer_"):
try:
from caformer_setup_backup.caformer_stacked_cnn import CAFormerStackedCNN
# Instantiate CAFormer encoder (it internally handles resizing / channel adapting)
self.caformer_encoder = CAFormerStackedCNN(
height=height if height is not None else 224,
width=width if width is not None else 224,
num_channels=num_channels if num_channels is not None else 3,
output_size=output_size,
custom_model=custom_model,
use_pretrained=True,
trainable=True,
)
self._use_caformer = True
# Override forward dynamically
self.forward = self._forward_caformer # type: ignore
# Store output shape
if hasattr(self.caformer_encoder, "output_shape"):
# CAFormerStackedCNN.output_shape returns a list
shape_list = self.caformer_encoder.output_shape
if isinstance(shape_list, (list, tuple)):
self._output_shape_override = torch.Size(shape_list)
logger.info(f"Using CAFormer backbone '{custom_model}' in place of stacked_cnn.")
except Exception as e:
logger.error(f"Failed to initialize CAFormer encoder '{custom_model}': {e}")
raise

logger.debug(f" {self.name}")

# map parameter input feature config names to internal names
Expand Down Expand Up @@ -144,6 +175,16 @@ def __init__(
default_dropout=fc_dropout,
)

def _forward_caformer(self, inputs: torch.Tensor) -> EncoderOutputDict:
    """Run the forward pass through the CAFormer backbone.

    This method stands in for the regular ``forward`` when a ``caformer_*`` custom model was requested.

    :param inputs: The inputs fed into the encoder.
    :return: A dict mapping ``ENCODER_OUTPUT`` to the representation produced by the backbone.
    :raises RuntimeError: If no CAFormer encoder has been attached to this instance.
    :raises KeyError: If the backbone returns a dict that has no 'encoder_output' entry.
    """
    if not hasattr(self, "caformer_encoder"):
        raise RuntimeError("CAFormer encoder not initialized despite _use_caformer=True.")
    out = self.caformer_encoder(inputs)
    if not isinstance(out, dict):
        # The backbone handed back the representation itself rather than a dict; pass it through unchanged
        # instead of failing on a missing ``.get`` attribute.
        return {ENCODER_OUTPUT: out}
    if "encoder_output" not in out:
        # Fail here with a clear message rather than silently emitting None and crashing further downstream.
        raise KeyError("CAFormer encoder output is missing the 'encoder_output' key.")
    return {ENCODER_OUTPUT: out["encoder_output"]}

def forward(self, inputs: torch.Tensor) -> EncoderOutputDict:
"""
:param inputs: The inputs fed into the encoder.
Expand All @@ -162,6 +203,8 @@ def get_schema_cls() -> Type[ImageEncoderConfig]:

@property
def output_shape(self) -> torch.Size:
    """Return the encoder output shape, preferring the explicit override when one has been set."""
    override = self._output_shape_override
    # Fall back to the fully connected stack's shape only when no override was recorded.
    return self.fc_stack.output_shape if override is None else override

@property
Expand Down
143 changes: 143 additions & 0 deletions ludwig/encoders/image/metaformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
#! /usr/bin/env python
# Copyright (c) 2025
#
# New MetaFormer / CAFormer style image encoder for Ludwig.
#
# This integrates ConvFormer / CAFormer family backbones as a first-class
# Ludwig encoder, avoiding runtime monkey patching of existing encoders.
#
# The implementation wraps the MetaFormerStackedCNN backbone wrapper (formerly
# CAFormerStackedCNN from caformer_setup_backup), which is imported from the
# metaformer_integration package.
#
# TODO (follow-up in PR):
# - Move / refactor caformer_setup_backup/ code into ludwig/modules or a
# dedicated ludwig/encoders/image/metaformer_backbones.py file.
# - Add a proper schema definition (see get_schema_cls TODO below).
# - Add unit tests under tests/ludwig/encoders/test_metaformer_encoder.py
# - Add release note and documentation snippet.

from typing import Any, Dict, Optional, Type
import logging

import torch
import torch.nn as nn

from ludwig.api_annotations import DeveloperAPI
from ludwig.constants import ENCODER_OUTPUT, IMAGE
from ludwig.encoders.base import Encoder
from ludwig.encoders.registry import register_encoder
from ludwig.encoders.types import EncoderOutputDict

logger = logging.getLogger(__name__)


@DeveloperAPI
@register_encoder("metaformer", IMAGE)
class MetaFormerEncoder(Encoder):
    """Ludwig image encoder backed by a MetaFormer / CAFormer style backbone.

    Provides access to MetaFormer / CAFormer style convolution-attention hybrids
    (e.g., caformer_s18 / m36 / b36 etc.) as a Ludwig image encoder.

    Configuration (proposed):
        type: metaformer
        model_name: caformer_s18      # defaults to caformer_s18 when omitted
        use_pretrained: true          # optional (default True)
        trainable: true               # optional
        output_size: 128              # dimensionality after projection head

    Behavior:
        - Builds the backbone wrapper for the requested model name.
        - Delegates input size / channel adaptation to the backbone wrapper.
        - Reports the output shape exposed by the wrapper, falling back to (output_size,).
    """

    def __init__(
        self,
        height: int,
        width: int,
        num_channels: int = 3,
        model_name: Optional[str] = None,
        use_pretrained: bool = True,
        trainable: bool = True,
        output_size: int = 128,
        encoder_config=None,
        **kwargs: Any,
    ):
        """Build the encoder around a MetaFormer backbone wrapper.

        :param height: Input image height; ``None`` is treated as 224.
        :param width: Input image width; ``None`` is treated as 224.
        :param num_channels: Number of input channels; ``None`` is treated as 3.
        :param model_name: Backbone name; falls back to ``caformer_s18`` when empty or ``None``.
        :param use_pretrained: Forwarded to the backbone wrapper.
        :param trainable: Forwarded to the backbone wrapper; when false, all wrapper parameters are frozen.
        :param output_size: Forwarded to the backbone wrapper and used as the fallback output shape.
        :param encoder_config: Stored on the instance as ``config``.
        :param kwargs: Accepted for interface compatibility and otherwise ignored.
        :raises ImportError: If the backbone wrapper implementation cannot be imported.
        """
        super().__init__()
        self.config = encoder_config
        self.model_name = model_name or "caformer_s18"
        self.use_pretrained = use_pretrained
        self.trainable = trainable
        self.output_size = output_size

        # Resolve missing dimensions once so the backbone and the reported input shape always agree.
        # Storing the raw (possibly None) values would make ``input_shape`` fail inside torch.Size.
        height = 224 if height is None else height
        width = 224 if width is None else width
        num_channels = 3 if num_channels is None else num_channels

        # The backbone wrapper lives outside the core ludwig tree, so it is imported lazily here.
        try:
            from metaformer_integration.metaformer_stacked_cnn import MetaFormerStackedCNN as _BackboneWrapper
        except Exception as e:  # pragma: no cover
            raise ImportError(
                "Failed to import CAFormer / MetaFormer backbone code. "
                "Ensure integration code is migrated from caformer_setup_backup."
            ) from e

        logger.info(
            "Initializing MetaFormerEncoder backbone=%s pretrained=%s trainable=%s output_size=%d",
            self.model_name,
            self.use_pretrained,
            self.trainable,
            self.output_size,
        )

        self.backbone_wrapper: nn.Module = _BackboneWrapper(
            height=height,
            width=width,
            num_channels=num_channels,
            output_size=output_size,
            custom_model=self.model_name,
            use_pretrained=self.use_pretrained,
            trainable=self.trainable,
        )

        # Expose shapes
        self._input_shape = (num_channels, height, width)
        # Use the wrapper's own output_shape when it is a list/tuple; otherwise fall back to (output_size,).
        raw_out_shape = getattr(self.backbone_wrapper, "output_shape", [self.output_size])
        if isinstance(raw_out_shape, (list, tuple)):
            self._output_shape = torch.Size(raw_out_shape)
        else:
            self._output_shape = torch.Size([self.output_size])

        # Freeze if not trainable
        if not self.trainable:
            for p in self.backbone_wrapper.parameters():
                p.requires_grad = False

    def forward(self, inputs: torch.Tensor) -> EncoderOutputDict:
        """Encode a batch of inputs with the backbone wrapper.

        :param inputs: Input tensor (expected shape [B, C, H, W]).
        :return: A dict mapping ``ENCODER_OUTPUT`` to the backbone representation.
        :raises TypeError: If ``inputs`` is not a ``torch.Tensor``.
        """
        if not isinstance(inputs, torch.Tensor):
            raise TypeError("MetaFormerEncoder forward expects a torch.Tensor input.")
        out_dict = self.backbone_wrapper(inputs)
        if isinstance(out_dict, dict) and "encoder_output" in out_dict:
            rep = out_dict["encoder_output"]
        else:
            # Fallback: treat raw module output as representation
            rep = out_dict
        return {ENCODER_OUTPUT: rep}

    @property
    def input_shape(self) -> torch.Size:
        """Return the (num_channels, height, width) shape the encoder was configured with."""
        return torch.Size(self._input_shape)

    @property
    def output_shape(self) -> torch.Size:
        """Return the output shape recorded at construction time."""
        return self._output_shape

    @staticmethod
    def get_schema_cls() -> Type[Any]:
        """Return the ``MetaFormerConfig`` schema class.

        :raises RuntimeError: If the schema module cannot be imported.
        """
        # Imported lazily so a schema import failure surfaces here with a dedicated message.
        try:
            from ludwig.schema.encoders.image.metaformer import MetaFormerConfig
            return MetaFormerConfig  # type: ignore
        except Exception as e:  # pragma: no cover
            raise RuntimeError("MetaFormerConfig schema import failed.") from e
1 change: 1 addition & 0 deletions ludwig/schema/encoders/image/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
import ludwig.schema.encoders.image.base
import ludwig.schema.encoders.image.torchvision # noqa
import ludwig.schema.encoders.image.metaformer # noqa
9 changes: 9 additions & 0 deletions ludwig/schema/encoders/image/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ def module_name():
description=ENCODER_METADATA["Stacked2DCNN"]["type"].long_description,
)

# Added to enable CAFormer integration via stacked_cnn encoder.
custom_model: Optional[str] = schema_utils.String(
default=None,
allow_none=True,
description="If set to one of ['caformer_s18','caformer_s36','caformer_m36','caformer_b36'], "
"the standard stacked_cnn encoder will be dynamically replaced at runtime by a CAFormer backbone "
"defined in caformer_setup_backup. Requires calling the patch utility prior to model construction.",
)

conv_dropout: Optional[int] = schema_utils.FloatRange(
default=0.0,
min=0,
Expand Down
76 changes: 76 additions & 0 deletions ludwig/schema/encoders/image/metaformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from typing import Optional

from ludwig.api_annotations import DeveloperAPI
from ludwig.constants import IMAGE
from ludwig.schema import utils as schema_utils
from ludwig.schema.encoders.image.base import ImageEncoderConfig
from ludwig.schema.encoders.utils import register_encoder_config
from ludwig.schema.utils import ludwig_dataclass

@DeveloperAPI
@register_encoder_config("metaformer", IMAGE)
@ludwig_dataclass
class MetaFormerConfig(ImageEncoderConfig):
    """Schema for the ``metaformer`` image encoder (MetaFormer / CAFormer backbones).

    ENCODER_METADATA is deliberately not referenced here (it has not been extended yet), which keeps this
    first integration small and self-contained.
    """

    @staticmethod
    def module_name():
        return "MetaFormerEncoder"

    type: str = schema_utils.ProtectedString(
        "metaformer",
        description="MetaFormer / CAFormer image encoder integrating ConvFormer / CAFormer style backbones.",
    )

    model_name: str = schema_utils.String(
        default="caformer_s18",
        allow_none=False,
        description="Backbone model name (e.g. caformer_s18, caformer_s36, caformer_m36, caformer_b36, etc.).",
    )

    use_pretrained: bool = schema_utils.Boolean(
        default=True,
        description="If true, load pretrained backbone weights (if available).",
    )

    trainable: bool = schema_utils.Boolean(
        default=True,
        description="If false, freezes backbone parameters.",
    )

    output_size: int = schema_utils.PositiveInteger(
        default=128,
        description="Projection head output dimensionality.",
    )

    height: Optional[int] = schema_utils.NonNegativeInteger(
        default=None,
        allow_none=True,
        description="Input image height (optional; if None, provided by feature preprocessing).",
    )

    width: Optional[int] = schema_utils.NonNegativeInteger(
        default=None,
        allow_none=True,
        description="Input image width (optional; if None, provided by feature preprocessing).",
    )

    num_channels: Optional[int] = schema_utils.NonNegativeInteger(
        default=None,
        allow_none=True,
        description="Number of input image channels (e.g. 1 for grayscale, 3 for RGB).",
    )

    def set_fixed_preprocessing_params(self, model_type: str, preprocessing: "ImagePreprocessingConfig"):
        """Push this encoder's size requirements onto the image preprocessing config.

        Equal image dimensions are never required. A height or width is copied onto ``preprocessing`` only
        when it has been set explicitly on this config; unset values leave the preprocessing untouched.
        """
        preprocessing.requires_equal_dimensions = False
        # Height is handled before width, and only explicitly configured dimensions are copied over.
        for dimension in ("height", "width"):
            configured = getattr(self, dimension)
            if configured is not None:
                setattr(preprocessing, dimension, configured)
18 changes: 17 additions & 1 deletion ludwig/utils/tokenizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,23 @@
from typing import Any, Dict, List, Optional, Union

import torch
import torchtext
# torchtext is optional: pure vision use-cases (e.g., the CAFormer image encoder) must still import this module
# when the torchtext binary wheels are missing or broken.
try:
    import torchtext  # type: ignore
except Exception as import_error:
    TORCHTEXT_AVAILABLE = False
    logger = logging.getLogger(__name__)
    logger.warning(f"torchtext import failed, disabling torchtext-based tokenizers: {import_error}")

    class _TorchTextStub:
        """Minimal stand-in so attribute access does not crash until torchtext is actually used."""

        __version__ = "0.0.0"

        class transforms:
            """Empty namespace."""

        class utils:
            @staticmethod
            def get_asset_local_path(path):
                raise RuntimeError("torchtext unavailable (stubbed).")

    torchtext = _TorchTextStub()  # type: ignore
else:
    TORCHTEXT_AVAILABLE = True

from ludwig.constants import PADDING_SYMBOL, UNKNOWN_SYMBOL
from ludwig.utils.data_utils import load_json
Expand Down
32 changes: 32 additions & 0 deletions metaformer_integration/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# MetaFormer integration package init
from .metaformer_models import get_registered_model, default_cfgs # noqa: F401
from .metaformer_stacked_cnn import (
MetaFormerStackedCNN,
CAFormerStackedCNN,
patch_ludwig_comprehensive,
patch_ludwig_direct,
patch_ludwig_robust,
list_metaformer_models,
get_metaformer_backbone_names,
metaformer_model_exists,
describe_metaformer_model,
) # noqa: F401

def patch():
    """Apply the comprehensive Ludwig patch in a single call.

    Thin alias that delegates to ``patch_ludwig_comprehensive`` and returns whatever it returns.
    """
    result = patch_ludwig_comprehensive()
    return result

# Public API of the package: every name imported at the top of this module, plus the local `patch` alias.
__all__ = [
"MetaFormerStackedCNN",
"CAFormerStackedCNN",
"patch_ludwig_comprehensive",
"patch_ludwig_direct",
"patch_ludwig_robust",
"list_metaformer_models",
"get_metaformer_backbone_names",
"metaformer_model_exists",
"describe_metaformer_model",
"get_registered_model",
"default_cfgs",
"patch",
]
Loading
Loading