Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ dependencies = [
"numpy",
"omegaconf",
"pre-commit",
"pydantic<2.7", # silence pydantic v2 warning
"pydantic", # silence pydantic v2 warning
"pytest",
"pycortex",
"scikit-learn",
Expand Down
Empty file added src/__init__.py
Empty file.
119 changes: 119 additions & 0 deletions src/simclr/config/hcp_pretrain.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Name of the run. For output directory base name and wandb.
name: pretrain_simclr_vit_base

# Description of the run. Goes in wandb notes.
notes: "SimCLR pre-training with a ViT-Base backbone on the HCP dataset."

# Root output directory
output_dir: ./checkpoints

# How often to print logs to the console during training/evaluation.
print_freq: 100

# --- Data Config ---
# All parameters related to data shape and transformations.
data:
# The 2D spatial size of the input images [height, width].
# The model gets the time dimension from `num_frames`.
img_size: [224, 560]
in_chans: 1
patch_size: 16
num_frames: 16
t_patch_size: 16
clip_vmax: 3.0
normalize: frame
random_crop: false
crop_kwargs:
scale: [0.9, 1.0]
ratio: [2.5, 2.5]
interpolation: 3

# --- Model Config ---
model:
contrastive_mode: simclr
backbone_name: mae_vit_base
mask_ratio: 0.9
temperature: 0.1

# Arguments passed to the backbone model constructor.
# Architectural details like embed_dim are set by the `backbone_name` preset.
backbone_kwargs:
pos_embed: sep
class_token: true
drop_path_rate: 0.0

# Arguments passed to the projection/prediction head constructors.
head_kwargs:
hidden_dim: 2048
out_dim: 128

# --- Datasets ---
# Replace the placeholder paths with the actual locations of your datasets.
datasets:
hcp-train:
type: flat-wds
url: "path/to/your/hcp-flat/hcp-flat_{0000..1799}.tar"
clipping: random
clipping_kwargs: {oversample: 4.0}
shuffle: true
buffer_size: 1000
samples_per_epoch: 200000

hcp-train-subset:
type: flat-clips
root: "path/to/your/flat-clips/hcp-train-clips-16t"
shuffle: false

hcp-val:
type: flat-clips
root: "path/to/your/flat-clips/hcp-val-clips-16t"
shuffle: false

nsd-val:
type: flat-clips
root: "path/to/your/flat-clips/nsd-subj01-clips-16t"
shuffle: false

# Which datasets to use for training and evaluation.
train_dataset: hcp-train
eval_datasets:
- hcp-val
- nsd-val

# --- Data Loader ---
num_workers: 8

# --- Optimization ---
optim:
epochs: 100
batch_size: 32
accum_iter: 1
base_lr: 1e-3
min_lr: 0.0
warmup_epochs: 5
start_warmup_lr: 1e-6
weight_decay: 0.05
betas: [0.9, 0.95]
clip_grad: 1.0

# --- Training Settings ---
amp: true
amp_dtype: float16

# --- Checkpointing ---
ckpt: null
resume: true
auto_resume: true
start_epoch: 0
max_checkpoints: 1
checkpoint_period: 1

# --- Misc ---
device: cuda
seed: 7338
debug: false

# --- Logging ---
wandb: true
wandb_entity: null
wandb_project: fMRI-foundation-model
24 changes: 24 additions & 0 deletions src/simclr/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import torch
from torch.utils.data import default_collate

class SimCLRTransform:
    """Produce two independently augmented views of one raw sample.

    Wraps a (typically stochastic) base transform; calling the instance
    applies it twice to the same input, yielding the positive pair that
    SimCLR contrasts against the rest of the batch.
    """

    def __init__(self, base_transform):
        # Augmentation pipeline applied once per view.
        self.base_transform = base_transform

    def __call__(self, raw_sample):
        """Return a ``(view_1, view_2)`` pair of augmented copies of *raw_sample*."""
        return self.base_transform(raw_sample), self.base_transform(raw_sample)

def simclr_collate(batch):
    """Collate a batch of ``(view_1, view_2)`` pairs into two batched tensors.

    Transposes the list of pairs so that each view stream is collated
    independently with torch's default collate function.
    """
    first_views, second_views = zip(*batch)
    return default_collate(list(first_views)), default_collate(list(second_views))
90 changes: 90 additions & 0 deletions src/simclr/loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# SimSiam: https://github.com/facebookresearch/simsiam
# --------------------------------------------------------

import torch
import torch.nn as nn
import torch.nn.functional as F
import os

class SimCLRProjectionHead(nn.Module):
    """Two-layer SimCLR projection MLP: Linear -> ReLU -> Linear."""

    def __init__(self, in_dim: int, hidden_dim: int = 2048, out_dim: int = 128):
        super().__init__()
        layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim),
        ]
        self.head = nn.Sequential(*layers)

    def forward(self, x):
        """Project backbone features *x* into the contrastive embedding space."""
        return self.head(x)

class SimSiamProjectionHead(nn.Module):
    """Three-layer SimSiam projection MLP.

    Each hidden stage is a bias-free Linear -> BatchNorm -> ReLU; the output
    stage is a bias-free Linear followed by a BatchNorm without affine
    parameters, matching the reference SimSiam architecture.
    """

    def __init__(self, in_dim: int, hidden_dim: int = 2048, out_dim: int = 2048):
        super().__init__()

        stages = []
        for d_in, d_out in ((in_dim, hidden_dim), (hidden_dim, hidden_dim)):
            stages += [
                nn.Linear(d_in, d_out, bias=False),
                nn.BatchNorm1d(d_out),
                nn.ReLU(inplace=True),
            ]
        stages += [
            nn.Linear(hidden_dim, out_dim, bias=False),
            # Final BN has no learnable scale/shift in SimSiam.
            nn.BatchNorm1d(out_dim, affine=False),
        ]
        self.head = nn.Sequential(*stages)

    def forward(self, x):
        """Map backbone features *x* to the projection space."""
        return self.head(x)

class SimSiamPredictionHead(nn.Module):
    """Two-layer SimSiam prediction MLP (bottleneck Linear -> BN -> ReLU -> Linear)."""

    def __init__(self, in_dim: int = 2048, hidden_dim: int = 512, out_dim: int = 2048):
        super().__init__()
        layers = (
            nn.Linear(in_dim, hidden_dim, bias=False),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim),
        )
        self.head = nn.Sequential(*layers)

    def forward(self, x):
        """Predict the partner view's projection from *x*."""
        return self.head(x)


def nt_xent_loss(z1: torch.Tensor, z2: torch.Tensor, temperature: float = 0.5, distributed: bool = False):

z1 = F.normalize(z1, dim=-1)
z2 = F.normalize(z2, dim=-1)

all_z = torch.cat([z1, z2], dim=0)

if distributed:
all_z_dist = torch.cat(torch.distributed.nn.all_gather(all_z), dim=0)
rank = int(os.getenv("LOCAL_RANK", 0))
else:
all_z_dist = all_z
rank = 0

logits = torch.matmul(all_z, all_z_dist.T) / temperature

batch_size = z1.shape[0]
labels_v1 = torch.arange(batch_size, device=z1.device) + batch_size
labels_v2 = torch.arange(batch_size, device=z1.device)
labels = torch.cat([labels_v1, labels_v2], dim=0)

labels = labels + (rank * 2 * batch_size)

return F.cross_entropy(logits, labels)

def simsiam_loss(p1: torch.Tensor, z2: torch.Tensor, p2: torch.Tensor, z1: torch.Tensor):
    """Symmetrized negative cosine similarity from SimSiam.

    Predictions ``p1``/``p2`` are compared against stop-gradient projections
    ``z2``/``z1``; gradients never flow through the projection targets.
    """

    def _branch(pred, target):
        # Stop-gradient on the target branch, per the SimSiam recipe.
        return -F.cosine_similarity(pred, target.detach(), dim=-1).mean()

    return (_branch(p1, z2) + _branch(p2, z1)) / 2
Loading