Merged
Changes from 11 commits
22 changes: 19 additions & 3 deletions monai/networks/blocks/patchembedding.py
@@ -12,21 +12,22 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LayerNorm

from monai.networks.blocks.pos_embed_utils import build_sincos_position_embedding
from monai.networks.blocks.pos_embed_utils import build_fourier_position_embedding, build_sincos_position_embedding
from monai.networks.layers import Conv, trunc_normal_
from monai.utils import ensure_tuple_rep, optional_import
from monai.utils.module import look_up_option

Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
SUPPORTED_PATCH_EMBEDDING_TYPES = {"conv", "perceptron"}
SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos"}
SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos", "fourier"}


class PatchEmbeddingBlock(nn.Module):
@@ -53,6 +54,7 @@ def __init__(
pos_embed_type: str = "learnable",
dropout_rate: float = 0.0,
spatial_dims: int = 3,
pos_embed_kwargs: Optional[dict] = None,
) -> None:
"""
Args:
@@ -65,6 +67,8 @@
pos_embed_type: position embedding layer type.
dropout_rate: fraction of the input units to drop.
spatial_dims: number of spatial dimensions.
pos_embed_kwargs: additional arguments for the position embedding. For `sincos`, it can contain
`temperature`; for `fourier`, it can contain `scales`.
"""

super().__init__()
@@ -105,6 +109,8 @@ def __init__(
self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size))
self.dropout = nn.Dropout(dropout_rate)

pos_embed_kwargs = {} if pos_embed_kwargs is None else pos_embed_kwargs

if self.pos_embed_type == "none":
pass
elif self.pos_embed_type == "learnable":
@@ -114,7 +120,17 @@
for in_size, pa_size in zip(img_size, patch_size):
grid_size.append(in_size // pa_size)

self.position_embeddings = build_sincos_position_embedding(grid_size, hidden_size, spatial_dims)
self.position_embeddings = build_sincos_position_embedding(
grid_size, hidden_size, spatial_dims, **pos_embed_kwargs
)
elif self.pos_embed_type == "fourier":
grid_size = []
for in_size, pa_size in zip(img_size, patch_size):
grid_size.append(in_size // pa_size)

self.position_embeddings = build_fourier_position_embedding(
grid_size, hidden_size, spatial_dims, **pos_embed_kwargs
)
else:
raise ValueError(f"pos_embed_type {self.pos_embed_type} not supported.")

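For context, here is a minimal usage sketch of the new `fourier` option (not part of the diff; the sizes and `scales` values are illustrative assumptions):

```python
import torch

from monai.networks.blocks.patchembedding import PatchEmbeddingBlock

# Hypothetical example: embed a 3D volume with fixed Fourier-feature positions.
# `pos_embed_kwargs` is forwarded to build_fourier_position_embedding.
block = PatchEmbeddingBlock(
    in_channels=1,
    img_size=(32, 32, 32),
    patch_size=(8, 8, 8),
    hidden_size=96,  # must be divisible by 2 * spatial_dims = 6
    num_heads=8,
    pos_embed_type="fourier",
    spatial_dims=3,
    pos_embed_kwargs={"scales": [1.0, 1.0, 2.0]},
)

x = torch.randn(2, 1, 32, 32, 32)
out = block(x)  # shape (2, 64, 96): (32 // 8) ** 3 = 64 patches, 96-dim tokens
```

The embedding is registered with `requires_grad=False`, so only the patch projection is trained.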
57 changes: 56 additions & 1 deletion monai/networks/blocks/pos_embed_utils.py
@@ -18,7 +18,7 @@
import torch
import torch.nn as nn

__all__ = ["build_sincos_position_embedding"]
__all__ = ["build_sincos_position_embedding", "build_fourier_position_embedding"]


# From PyTorch internals
@@ -32,6 +32,61 @@ def parse(x):
return parse


def build_fourier_position_embedding(
grid_size: Union[int, List[int]], embed_dim: int, spatial_dims: int = 3, scales: Union[float, List[float]] = 1.0
) -> torch.nn.Parameter:
"""
Builds an (anisotropic) Fourier-feature positional encoding from the given grid size, embedding dimension,
spatial dimensions, and scales. The scales control the variance of the Fourier features; higher values make
distant points more distinguishable.
Reference: https://arxiv.org/abs/2509.02488

Args:
grid_size (Union[int, List[int]]): The size of the grid in each spatial dimension. If a single int is
provided, it is used for all dimensions.
embed_dim (int): The dimension of the embedding.
spatial_dims (int): The number of spatial dimensions (2 for 2D, 3 for 3D).
scales (Union[float, List[float]]): The scale for every spatial dimension. If a single float is provided,
the same scale is used for all dimensions.

Returns:
pos_embed (nn.Parameter): The Fourier feature position embedding as a fixed parameter.
"""

to_tuple = _ntuple(spatial_dims)
grid_size_t = to_tuple(grid_size)
if len(grid_size_t) != spatial_dims:
raise ValueError(f"Length of grid_size ({len(grid_size_t)}) must be the same as spatial_dims.")

if embed_dim % (2 * spatial_dims) != 0:
raise AssertionError(
f"Embed dimension must be divisible by {2 * spatial_dims} for {spatial_dims}D Fourier feature position embedding"
)

# Ensure scales is a tensor of shape (spatial_dims,)
if isinstance(scales, float):
scales_tensor = torch.full((spatial_dims,), scales, dtype=torch.float)
elif isinstance(scales, (list, tuple)):
if len(scales) != spatial_dims:
raise ValueError(f"Length of scales {len(scales)} does not match spatial_dims {spatial_dims}")
scales_tensor = torch.tensor(scales, dtype=torch.float)
else:
raise TypeError(f"scales must be float or list of floats, got {type(scales)}")

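# Random Gaussian frequency matrix of shape (embed_dim // 2, spatial_dims); scaling a
# column increases the variance of the frequencies along that axis.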
gaussians = torch.normal(0.0, 1.0, (embed_dim // 2, spatial_dims))
gaussians = gaussians * scales_tensor

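# Patch-grid coordinates normalized to [0, 1] along every axis, flattened to (n_patches, spatial_dims).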
position_indices = [torch.linspace(0, 1, x) for x in grid_size_t]
positions = torch.stack(torch.meshgrid(*position_indices, indexing="ij"), dim=-1)
positions = positions.flatten(end_dim=-2)

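# Project the coordinates onto the random frequencies; the 2*pi factor expresses them in radians.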
x_proj = (2.0 * torch.pi * positions) @ gaussians.T

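# Concatenate sin and cos components and freeze them as a non-trainable parameter with a batch dim.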
pos_emb = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
pos_emb = nn.Parameter(pos_emb[None, :, :], requires_grad=False)

return pos_emb


def build_sincos_position_embedding(
grid_size: Union[int, List[int]], embed_dim: int, spatial_dims: int = 3, temperature: float = 10000.0
) -> torch.nn.Parameter:
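In short, the new utility implements random Fourier features: with per-axis scaled Gaussian frequencies `B` and normalized patch positions `p`, the embedding is `[sin(2π p Bᵀ), cos(2π p Bᵀ)]`. A quick shape check, with illustrative numbers (not from the diff):

```python
import torch

from monai.networks.blocks.pos_embed_utils import build_fourier_position_embedding

torch.manual_seed(0)  # the frequency matrix is randomly sampled
emb = build_fourier_position_embedding(
    grid_size=[4, 4, 4], embed_dim=96, spatial_dims=3, scales=[1.0, 1.0, 2.0]
)
print(emb.shape)          # torch.Size([1, 64, 96]) -> (1, n_patches, embed_dim)
print(emb.requires_grad)  # False: the embedding is fixed, not learned
```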
39 changes: 39 additions & 0 deletions tests/networks/blocks/test_patchembedding.py
@@ -87,6 +87,19 @@ def test_sincos_pos_embed(self):

self.assertEqual(net.position_embeddings.requires_grad, False)

def test_fourier_pos_embed(self):
net = PatchEmbeddingBlock(
in_channels=1,
img_size=(32, 32, 32),
patch_size=(8, 8, 8),
hidden_size=96,
num_heads=8,
pos_embed_type="fourier",
dropout_rate=0.5,
)

self.assertEqual(net.position_embeddings.requires_grad, False)

def test_learnable_pos_embed(self):
net = PatchEmbeddingBlock(
in_channels=1,
@@ -101,6 +114,32 @@ def test_learnable_pos_embed(self):
self.assertEqual(net.position_embeddings.requires_grad, True)

def test_ill_arg(self):
with self.assertRaises(ValueError):
PatchEmbeddingBlock(
in_channels=1,
img_size=(128, 128, 128),
patch_size=(16, 16, 16),
hidden_size=128,
num_heads=12,
proj_type="conv",
dropout_rate=0.1,
pos_embed_type="fourier",
pos_embed_kwargs=dict(scales=[1.0, 1.0]),
)

with self.assertRaises(ValueError):
PatchEmbeddingBlock(
in_channels=1,
img_size=(128, 128),
patch_size=(16, 16),
hidden_size=128,
num_heads=12,
proj_type="conv",
dropout_rate=0.1,
pos_embed_type="fourier",
pos_embed_kwargs=dict(scales=[1.0, 1.0, 1.0]),
)

with self.assertRaises(ValueError):
PatchEmbeddingBlock(
in_channels=1,
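The two new cases above exercise the length check on `scales`. A sketch of the same failure mode against the utility directly (values are illustrative; 96 satisfies the divisibility check, so it is the `scales` validation that fires):

```python
from monai.networks.blocks.pos_embed_utils import build_fourier_position_embedding

try:
    # Three spatial dims but only two scales -> ValueError from the length check.
    build_fourier_position_embedding([8, 8, 8], embed_dim=96, spatial_dims=3, scales=[1.0, 1.0])
except ValueError as err:
    print(err)  # Length of scales 2 does not match spatial_dims 3
```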