11"""Generic DataModule for DNA masked language modeling."""
22
3- from typing import Optional
4-
53import numpy as np
64import torch
75from Bio .Seq import Seq
86from datasets import load_dataset
97from datasets .distributed import split_dataset_by_node
108from lightning import LightningDataModule
11- from torch .utils .data import DataLoader
9+ from torch .utils .data import DataLoader , default_collate
1210from transformers import AutoTokenizer
1311
14- from glm_experiments .data .components .mlm_collator import (
15- DataCollatorForLanguageModelingSimplified ,
16- )
12+
def apply_reverse_complement(sequences: list[str]) -> list[str]:
    """Randomly flip each sequence onto its reverse complement strand.

    Every input sequence independently stays on the forward strand or is
    replaced by its reverse complement, each with probability 0.5.

    Args:
        sequences: List of DNA sequences.

    Returns:
        List of the same length, each sequence on a randomly chosen strand.
    """
    strands = np.random.choice(["+", "-"], len(sequences))
    augmented = []
    for seq, strand in zip(sequences, strands):
        if strand == "+":
            augmented.append(seq)
        else:
            augmented.append(str(Seq(seq).reverse_complement()))
    return augmented
31+
32+
33+ def apply_mlm_masking (
34+ input_ids : torch .Tensor ,
35+ mask_token_id : int ,
36+ vocab_size : int ,
37+ mlm_probability : float = 0.15 ,
38+ ) -> tuple [torch .Tensor , torch .Tensor ]:
39+ """Apply masked language modeling to input tokens.
40+
41+ Uses standard BERT masking strategy:
42+ - 15% of tokens are selected for masking
43+ - Of those: 80% replaced with [MASK], 10% random token, 10% unchanged
44+
45+ Args:
46+ input_ids: Token IDs of shape (batch_size, seq_len)
47+ mask_token_id: Token ID for [MASK]
48+ vocab_size: Vocabulary size for random replacement
49+ mlm_probability: Probability of selecting a token for masking
50+
51+ Returns:
52+ Tuple of (masked_input_ids, labels) both as int8.
53+ Labels has -100 for non-masked positions (standard PyTorch ignore_index).
54+ """
55+ input_ids = input_ids .clone ().to (torch .int8 )
56+ labels = input_ids .clone ()
57+
58+ # Select tokens for masking
59+ probability_matrix = torch .full (labels .shape , mlm_probability )
60+ masked_indices = torch .bernoulli (probability_matrix ).bool ()
61+ labels [~ masked_indices ] = - 100 # standard PyTorch ignore_index
62+
63+ # 80% -> [MASK]
64+ indices_replaced = torch .bernoulli (torch .full (labels .shape , 0.8 )).bool () & masked_indices
65+ input_ids [indices_replaced ] = mask_token_id
66+
67+ # 10% -> random token (0.5 of remaining 20%)
68+ indices_random = (
69+ torch .bernoulli (torch .full (labels .shape , 0.5 )).bool () & masked_indices & ~ indices_replaced
70+ )
71+ random_words = torch .randint (vocab_size , labels .shape , dtype = torch .int8 )
72+ input_ids [indices_random ] = random_words [indices_random ]
73+
74+ # 10% -> unchanged (implicit)
75+
76+ return input_ids , labels
1777
1878
1979class DNADataModule (LightningDataModule ):
@@ -60,7 +120,6 @@ def __init__(
60120 self .tokenizer = None
61121 self .data_train = None
62122 self .data_val = None
63- self .data_collator = None
64123
65124 def prepare_data (self ) -> None :
66125 """Download data and tokenizer (runs on single GPU/process)."""
@@ -86,86 +145,82 @@ def setup(self, stage: str | None = None) -> None:
86145 # Load tokenizer
87146 self .tokenizer = AutoTokenizer .from_pretrained (self .hparams .tokenizer_name ) # nosec B615
88147
89- # Create data collator for MLM
90- self .data_collator = DataCollatorForLanguageModelingSimplified (
91- tokenizer = self .tokenizer ,
92- mlm = True ,
93- mlm_probability = self .hparams .mlm_probability ,
94- )
95-
96- # Load raw dataset with streaming
97- raw_datasets = load_dataset (self .hparams .dataset_name , streaming = True ) # nosec B615
148+ def tokenize (seq : list [str ]) -> list [list [int ]]:
149+ """Tokenize sequences to input_ids only."""
150+ return self .tokenizer (
151+ seq ,
152+ padding = False ,
153+ truncation = False ,
154+ return_token_type_ids = False ,
155+ return_attention_mask = False ,
156+ return_special_tokens_mask = False ,
157+ )["input_ids" ]
98158
99- # Tokenization function
100- def tokenize_function (examples , soft_masked_weight , data_aug = False ):
101- """Tokenize sequences with optional reverse complement augmentation.
159+ def transform_batch (examples : dict , soft_masked_weight : float , data_aug : bool ) -> dict :
160+ """Transform a batch of examples.
102161
103162 Args:
104163 examples: Batch of examples with 'seq' field
105164 soft_masked_weight: Loss weight for lowercase nucleotides
106165 data_aug: Whether to apply reverse complement augmentation
107166
108167 Returns:
109- Dictionary with input_ids (torch.uint8), special_tokens_mask, and
110- loss_weight (torch.float16)
168+ Dictionary with input_ids, labels, and loss_weight (all tensors)
111169 """
112170 seq = examples ["seq" ]
113171
114172 # Apply reverse complement augmentation
115173 if data_aug :
116- n = len (seq )
117- strand = np .random .choice (["+" , "-" ], n )
118- seq = [
119- seq [i ] if strand [i ] == "+" else str (Seq (seq [i ]).reverse_complement ())
120- for i in range (n )
121- ]
122-
123- # Tokenize (returns dict with 'input_ids' as list of lists)
124- tokenized = self .tokenizer (
125- seq ,
126- return_special_tokens_mask = True ,
127- padding = False ,
128- truncation = False ,
129- )
174+ seq = apply_reverse_complement (seq )
130175
131- # Convert to tensors
132- input_ids = torch .tensor (tokenized ["input_ids" ], dtype = torch .uint8 )
133- special_tokens_mask = torch .tensor (tokenized ["special_tokens_mask" ], dtype = torch .uint8 )
176+ # Tokenize
177+ input_ids = torch .tensor (tokenize (seq ), dtype = torch .int8 )
134178
135179 # Create loss weights (lower weight for soft-masked lowercase regions)
136- loss_weight = torch .ones_like (input_ids , dtype = torch .float16 )
180+ loss_weight = torch .ones (input_ids . shape , dtype = torch .float16 )
137181 for i , s in enumerate (seq ):
138182 lowercase_mask = np .array ([c .islower () for c in s ])
139183 loss_weight [i ][lowercase_mask ] = soft_masked_weight
140184
185+ # Apply MLM masking
186+ input_ids , labels = apply_mlm_masking (
187+ input_ids ,
188+ mask_token_id = self .tokenizer .mask_token_id ,
189+ vocab_size = self .tokenizer .vocab_size ,
190+ mlm_probability = self .hparams .mlm_probability ,
191+ )
192+
141193 return {
142194 "input_ids" : input_ids ,
143- "special_tokens_mask " : special_tokens_mask ,
195+ "labels " : labels ,
144196 "loss_weight" : loss_weight ,
145197 }
146198
199+ # Load raw dataset with streaming
200+ raw_datasets = load_dataset (self .hparams .dataset_name , streaming = True ) # nosec B615
201+
147202 # Process splits (train and val only)
148203 if stage == "fit" or stage is None :
149204 # Training dataset with augmentation and shuffling
150205 train_dataset = raw_datasets ["train" ].shuffle (seed = self .hparams .seed )
151206 train_dataset = train_dataset .map (
152- lambda ex : tokenize_function (
207+ lambda ex : transform_batch (
153208 ex ,
154- self .hparams .soft_masked_loss_weight_train ,
209+ soft_masked_weight = self .hparams .soft_masked_loss_weight_train ,
155210 data_aug = self .hparams .data_augmentation ,
156211 ),
157212 batched = True ,
158213 remove_columns = list (list (raw_datasets ["train" ].take (1 ))[0 ].keys ()),
159214 # drop_last_batch needed for torch.compile to avoid variable batch sizes
160215 drop_last_batch = True ,
161- batch_size = self .hparams . batch_size ,
216+ batch_size = self .batch_size_per_device ,
162217 )
163218
164219 # Validation dataset (no augmentation, no shuffling)
165220 val_dataset = raw_datasets ["validation" ].map (
166- lambda ex : tokenize_function (
221+ lambda ex : transform_batch (
167222 ex ,
168- self .hparams .soft_masked_loss_weight_eval ,
223+ soft_masked_weight = self .hparams .soft_masked_loss_weight_eval ,
169224 data_aug = False ,
170225 ),
171226 batched = True ,
@@ -196,7 +251,7 @@ def train_dataloader(self) -> DataLoader:
196251 num_workers = self .hparams .num_workers ,
197252 pin_memory = self .hparams .pin_memory ,
198253 shuffle = False , # Shuffling handled by dataset
199- collate_fn = self . data_collator ,
254+ collate_fn = default_collate ,
200255 )
201256
202257 def val_dataloader (self ) -> DataLoader :
@@ -207,5 +262,5 @@ def val_dataloader(self) -> DataLoader:
207262 num_workers = self .hparams .num_workers ,
208263 pin_memory = self .hparams .pin_memory ,
209264 shuffle = False ,
210- collate_fn = self . data_collator ,
265+ collate_fn = default_collate ,
211266 )
0 commit comments