
Commit a31d539

Simplify MLM masking
1 parent 31e4df2 commit a31d539

File tree

4 files changed: +196 / -148 lines changed


glm_experiments/data/components/mlm_collator.py

Lines changed: 0 additions & 49 deletions
This file was deleted.

glm_experiments/data/dna_datamodule.py

Lines changed: 102 additions & 47 deletions
@@ -1,19 +1,79 @@
 """Generic DataModule for DNA masked language modeling."""
 
-from typing import Optional
-
 import numpy as np
 import torch
 from Bio.Seq import Seq
 from datasets import load_dataset
 from datasets.distributed import split_dataset_by_node
 from lightning import LightningDataModule
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, default_collate
 from transformers import AutoTokenizer
 
-from glm_experiments.data.components.mlm_collator import (
-    DataCollatorForLanguageModelingSimplified,
-)
+
+def apply_reverse_complement(sequences: list[str]) -> list[str]:
+    """Apply random reverse complement augmentation to sequences.
+
+    Each sequence is independently randomly assigned to forward or reverse
+    complement strand with equal probability.
+
+    Args:
+        sequences: List of DNA sequences
+
+    Returns:
+        List of sequences, each randomly on forward or reverse complement strand
+    """
+    n = len(sequences)
+    strand = np.random.choice(["+", "-"], n)
+    return [
+        seq if strand[i] == "+" else str(Seq(seq).reverse_complement())
+        for i, seq in enumerate(sequences)
+    ]
+
+
+def apply_mlm_masking(
+    input_ids: torch.Tensor,
+    mask_token_id: int,
+    vocab_size: int,
+    mlm_probability: float = 0.15,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Apply masked language modeling to input tokens.
+
+    Uses standard BERT masking strategy:
+    - 15% of tokens are selected for masking
+    - Of those: 80% replaced with [MASK], 10% random token, 10% unchanged
+
+    Args:
+        input_ids: Token IDs of shape (batch_size, seq_len)
+        mask_token_id: Token ID for [MASK]
+        vocab_size: Vocabulary size for random replacement
+        mlm_probability: Probability of selecting a token for masking
+
+    Returns:
+        Tuple of (masked_input_ids, labels) both as int8.
+        Labels has -100 for non-masked positions (standard PyTorch ignore_index).
+    """
+    input_ids = input_ids.clone().to(torch.int8)
+    labels = input_ids.clone()
+
+    # Select tokens for masking
+    probability_matrix = torch.full(labels.shape, mlm_probability)
+    masked_indices = torch.bernoulli(probability_matrix).bool()
+    labels[~masked_indices] = -100  # standard PyTorch ignore_index
+
+    # 80% -> [MASK]
+    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
+    input_ids[indices_replaced] = mask_token_id
+
+    # 10% -> random token (0.5 of remaining 20%)
+    indices_random = (
+        torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
+    )
+    random_words = torch.randint(vocab_size, labels.shape, dtype=torch.int8)
+    input_ids[indices_random] = random_words[indices_random]
+
+    # 10% -> unchanged (implicit)
+
+    return input_ids, labels
 
 
 class DNADataModule(LightningDataModule):
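
Note (not part of the diff): a minimal sketch of how the two new helpers behave. It assumes the module imports as glm_experiments.data.dna_datamodule and uses toy values (mask_token_id=4, vocab_size=8, fake token IDs 0-3) rather than the project's real tokenizer; the empirical fractions should land near the documented 15% selection / 80% [MASK] split.

import torch

from glm_experiments.data.dna_datamodule import (
    apply_mlm_masking,
    apply_reverse_complement,
)

# Reverse complement: each sequence independently stays forward or is flipped.
print(apply_reverse_complement(["ACGT", "AAAACCC"]))  # e.g. ['ACGT', 'GGGTTTT']

# MLM masking: check the documented 15% selection and 80% [MASK] share.
torch.manual_seed(0)
input_ids = torch.randint(0, 4, (64, 512), dtype=torch.int8)  # toy DNA token IDs 0-3
masked, labels = apply_mlm_masking(input_ids, mask_token_id=4, vocab_size=8)

selected = labels != -100
print(selected.float().mean())                 # ~0.15 of all positions
print((masked[selected] == 4).float().mean())  # ~0.80 of the selected positions
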
@@ -60,7 +120,6 @@ def __init__(
         self.tokenizer = None
         self.data_train = None
         self.data_val = None
-        self.data_collator = None
 
     def prepare_data(self) -> None:
         """Download data and tokenizer (runs on single GPU/process)."""
@@ -86,86 +145,82 @@ def setup(self, stage: str | None = None) -> None:
         # Load tokenizer
         self.tokenizer = AutoTokenizer.from_pretrained(self.hparams.tokenizer_name)  # nosec B615
 
-        # Create data collator for MLM
-        self.data_collator = DataCollatorForLanguageModelingSimplified(
-            tokenizer=self.tokenizer,
-            mlm=True,
-            mlm_probability=self.hparams.mlm_probability,
-        )
-
-        # Load raw dataset with streaming
-        raw_datasets = load_dataset(self.hparams.dataset_name, streaming=True)  # nosec B615
+        def tokenize(seq: list[str]) -> list[list[int]]:
+            """Tokenize sequences to input_ids only."""
+            return self.tokenizer(
+                seq,
+                padding=False,
+                truncation=False,
+                return_token_type_ids=False,
+                return_attention_mask=False,
+                return_special_tokens_mask=False,
+            )["input_ids"]
 
-        # Tokenization function
-        def tokenize_function(examples, soft_masked_weight, data_aug=False):
-            """Tokenize sequences with optional reverse complement augmentation.
+        def transform_batch(examples: dict, soft_masked_weight: float, data_aug: bool) -> dict:
+            """Transform a batch of examples.
 
             Args:
                 examples: Batch of examples with 'seq' field
                 soft_masked_weight: Loss weight for lowercase nucleotides
                 data_aug: Whether to apply reverse complement augmentation
 
             Returns:
-                Dictionary with input_ids (torch.uint8), special_tokens_mask, and
-                loss_weight (torch.float16)
+                Dictionary with input_ids, labels, and loss_weight (all tensors)
             """
             seq = examples["seq"]
 
             # Apply reverse complement augmentation
             if data_aug:
-                n = len(seq)
-                strand = np.random.choice(["+", "-"], n)
-                seq = [
-                    seq[i] if strand[i] == "+" else str(Seq(seq[i]).reverse_complement())
-                    for i in range(n)
-                ]
-
-            # Tokenize (returns dict with 'input_ids' as list of lists)
-            tokenized = self.tokenizer(
-                seq,
-                return_special_tokens_mask=True,
-                padding=False,
-                truncation=False,
-            )
+                seq = apply_reverse_complement(seq)
 
-            # Convert to tensors
-            input_ids = torch.tensor(tokenized["input_ids"], dtype=torch.uint8)
-            special_tokens_mask = torch.tensor(tokenized["special_tokens_mask"], dtype=torch.uint8)
+            # Tokenize
+            input_ids = torch.tensor(tokenize(seq), dtype=torch.int8)
 
             # Create loss weights (lower weight for soft-masked lowercase regions)
-            loss_weight = torch.ones_like(input_ids, dtype=torch.float16)
+            loss_weight = torch.ones(input_ids.shape, dtype=torch.float16)
             for i, s in enumerate(seq):
                 lowercase_mask = np.array([c.islower() for c in s])
                 loss_weight[i][lowercase_mask] = soft_masked_weight
 
+            # Apply MLM masking
+            input_ids, labels = apply_mlm_masking(
+                input_ids,
+                mask_token_id=self.tokenizer.mask_token_id,
+                vocab_size=self.tokenizer.vocab_size,
+                mlm_probability=self.hparams.mlm_probability,
+            )
+
             return {
                 "input_ids": input_ids,
-                "special_tokens_mask": special_tokens_mask,
+                "labels": labels,
                 "loss_weight": loss_weight,
             }
 
+        # Load raw dataset with streaming
+        raw_datasets = load_dataset(self.hparams.dataset_name, streaming=True)  # nosec B615
+
         # Process splits (train and val only)
         if stage == "fit" or stage is None:
             # Training dataset with augmentation and shuffling
             train_dataset = raw_datasets["train"].shuffle(seed=self.hparams.seed)
             train_dataset = train_dataset.map(
-                lambda ex: tokenize_function(
+                lambda ex: transform_batch(
                     ex,
-                    self.hparams.soft_masked_loss_weight_train,
+                    soft_masked_weight=self.hparams.soft_masked_loss_weight_train,
                     data_aug=self.hparams.data_augmentation,
                 ),
                 batched=True,
                 remove_columns=list(list(raw_datasets["train"].take(1))[0].keys()),
                 # drop_last_batch needed for torch.compile to avoid variable batch sizes
                 drop_last_batch=True,
-                batch_size=self.hparams.batch_size,
+                batch_size=self.batch_size_per_device,
             )
 
             # Validation dataset (no augmentation, no shuffling)
             val_dataset = raw_datasets["validation"].map(
-                lambda ex: tokenize_function(
+                lambda ex: transform_batch(
                     ex,
-                    self.hparams.soft_masked_loss_weight_eval,
+                    soft_masked_weight=self.hparams.soft_masked_loss_weight_eval,
                     data_aug=False,
                 ),
                 batched=True,
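
Note (not part of the diff): the loss-weight step in transform_batch keys off lowercase letters, which mark soft-masked regions in the source sequences. A standalone sketch of just that step, with a made-up weight value and a toy sequence:

import numpy as np
import torch

seq = "ACGTacgtACGT"          # lowercase run marks a soft-masked region
soft_masked_weight = 0.1       # made-up value; the real one comes from hparams

loss_weight = torch.ones(len(seq), dtype=torch.float16)
lowercase_mask = np.array([c.islower() for c in seq])
loss_weight[lowercase_mask] = soft_masked_weight

print(loss_weight)  # weight 1.0 everywhere except indices 4-7, which get 0.1
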
@@ -196,7 +251,7 @@ def train_dataloader(self) -> DataLoader:
             num_workers=self.hparams.num_workers,
             pin_memory=self.hparams.pin_memory,
             shuffle=False,  # Shuffling handled by dataset
-            collate_fn=self.data_collator,
+            collate_fn=default_collate,
         )
 
     def val_dataloader(self) -> DataLoader:
@@ -207,5 +262,5 @@ def val_dataloader(self) -> DataLoader:
             num_workers=self.hparams.num_workers,
             pin_memory=self.hparams.pin_memory,
             shuffle=False,
-            collate_fn=self.data_collator,
+            collate_fn=default_collate,
         )
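
Note (not part of the diff): because masking, labeling, and weighting now happen inside the batched map, every example already carries equal-length tensors, so the dataloaders can use torch.utils.data.default_collate instead of a custom collator. A rough illustration with dummy examples:

import torch
from torch.utils.data import default_collate

examples = [
    {
        "input_ids": torch.zeros(8, dtype=torch.int8),
        "labels": torch.full((8,), -100, dtype=torch.int8),
        "loss_weight": torch.ones(8, dtype=torch.float16),
    }
    for _ in range(4)
]
batch = default_collate(examples)
print({k: v.shape for k, v in batch.items()})
# {'input_ids': torch.Size([4, 8]), 'labels': torch.Size([4, 8]), 'loss_weight': torch.Size([4, 8])}
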

glm_experiments/models/components/bert.py

Lines changed: 11 additions & 4 deletions
@@ -5,19 +5,23 @@
 import torch.nn.functional as F
 
 
-def loss_fn(logits: torch.Tensor, labels: torch.Tensor, loss_weight: torch.Tensor) -> torch.Tensor:
+def loss_fn(
+    logits: torch.Tensor,
+    labels: torch.Tensor,
+    loss_weight: torch.Tensor,
+) -> torch.Tensor:
     """Compute weighted cross-entropy loss.
 
     Args:
         logits: Logits of shape (batch, seq_len, vocab_size)
-        labels: Target labels of shape (batch, seq_len)
+        labels: Target labels of shape (batch, seq_len), -100 for ignored positions
         loss_weight: Loss weights of shape (batch, seq_len)
 
     Returns:
         Scalar loss value
     """
     logits = logits.view(-1, logits.size(-1))
-    labels = labels.view(-1)
+    labels = labels.view(-1).long()
     loss_weight = loss_weight.view(-1)
     # Subset to positions where labels != -100 (ignore index)
     mask = labels != -100
@@ -63,13 +67,16 @@ def forward(
         """Forward pass with loss calculation.
 
         Args:
-            input_ids: Input token IDs (with masks) of shape (batch, seq_len)
+            input_ids: Input token IDs (with masks) of shape (batch, seq_len), int8 or long
             labels: True token IDs of shape (batch, seq_len), -100 for non-masked
             loss_weight: Per-token loss weights of shape (batch, seq_len)
 
         Returns:
             Weighted cross-entropy loss (scalar)
         """
+        # Convert int8 to long for embedding lookup
+        input_ids = input_ids.long()
+
         # Embed
         x = self.embedder(input_ids)  # (batch, seq_len, hidden_dim)
 
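Note (not part of the diff): the hunk shows only the top of loss_fn, so the snippet below is an assumed reconstruction of the documented behavior (flatten, drop ignore_index positions, weight per token) for illustration only; the name weighted_masked_ce and the dummy shapes are invented here.

import torch
import torch.nn.functional as F


def weighted_masked_ce(
    logits: torch.Tensor, labels: torch.Tensor, loss_weight: torch.Tensor
) -> torch.Tensor:
    # Flatten and keep only positions that carry a real label.
    logits = logits.view(-1, logits.size(-1))
    labels = labels.view(-1).long()
    loss_weight = loss_weight.view(-1)
    mask = labels != -100
    per_token = F.cross_entropy(logits[mask], labels[mask], reduction="none")
    # Weighted mean over the predicted positions.
    return (per_token * loss_weight[mask]).sum() / loss_weight[mask].sum()


logits = torch.randn(2, 6, 8)          # (batch, seq_len, vocab_size)
labels = torch.full((2, 6), -100)      # -100 everywhere ...
labels[:, 2] = 3                       # ... except position 2, which is predicted
loss_weight = torch.ones(2, 6)
print(weighted_masked_ce(logits, labels, loss_weight))
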
0 commit comments
