 def _filter_non_specific_nucleotides_and_batch(
     student_logits: torch.Tensor,
     teacher_logits: torch.Tensor,
-    student_emb: torch.Tensor,
-    teacher_emb: torch.Tensor,
+    student_emb: torch.Tensor | None,
+    teacher_emb: torch.Tensor | None,
     input_ids: torch.Tensor,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> tuple[
+    torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None, torch.Tensor
+]:
     """
     This function filters out the examples where the label is the non-specific nucleotide (PAD token).
     """
@@ -49,13 +51,13 @@ def _filter_non_specific_nucleotides_and_batch(
     student_logits = student_logits[mask]
     teacher_logits = teacher_logits[mask]
     input_ids = input_ids[mask]
-    student_emb = student_emb[mask]
-    teacher_emb = teacher_emb[mask]
+    student_emb = student_emb[mask] if student_emb is not None else None
+    teacher_emb = teacher_emb[mask] if teacher_emb is not None else None
 
     assert student_logits.ndim == 2
     assert teacher_logits.ndim == 2
-    assert student_emb.ndim == 2
-    assert teacher_emb.ndim == 2
+    assert student_emb is None or student_emb.ndim == 2
+    assert teacher_emb is None or teacher_emb.ndim == 2
     assert input_ids.ndim == 1
 
     return student_logits, teacher_logits, student_emb, teacher_emb, input_ids
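
For context, the filtering above is plain boolean indexing: selecting with a (batch, sequence) mask merges the batch and time dimensions, which is why the post-filter asserts expect 2D logits and 1D ids. A minimal self-contained sketch of that behavior (the PAD id, vocabulary size, and shapes below are made up; the real mask construction lives in the lines elided from this diff):

import torch

# Hypothetical values for illustration: B=2, T=5, V=16, PAD id 4.
B, T, V, PAD = 2, 5, 16, 4
input_ids = torch.randint(0, V, (B, T))
student_logits = torch.randn(B, T, V)

mask = input_ids != PAD               # (B, T) boolean mask over positions
kept_logits = student_logits[mask]    # (N, V): batch and time merged
kept_ids = input_ids[mask]            # (N,)
assert kept_logits.ndim == 2 and kept_ids.ndim == 1
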
@@ -65,8 +67,8 @@ def distillation_loss(
     *,
     student_logits: torch.Tensor,
     teacher_logits: torch.Tensor,
-    student_emb: torch.Tensor,
-    teacher_emb: torch.Tensor,
+    student_emb: torch.Tensor | None,
+    teacher_emb: torch.Tensor | None,
     input_ids: torch.Tensor,
     temperature: float,
     alpha_soft: float,
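
With the embeddings now optional, the loss can be invoked without hidden states at all. A hedged call sketch (the parameters past alpha_soft are elided in this hunk; alpha_hard and alpha_sim are inferred from their use in the function body, and all values and shapes are placeholders):

total_loss, soft, hard, sim = distillation_loss(
    student_logits=student_logits,    # assumed (B, T, V)
    teacher_logits=teacher_logits,    # assumed (B, T, V)
    student_emb=None,                 # allowed when the similarity term is off
    teacher_emb=None,
    input_ids=input_ids,              # (B, T)
    temperature=2.0,
    alpha_soft=0.5,
    alpha_hard=0.5,
    alpha_sim=0.0,
)
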
@@ -101,12 +103,17 @@ def distillation_loss(
     assert input_ids.ndim == 2, f"Expected input_ids to be 2D, got {input_ids.ndim}D"
 
     # Expect B, T, D
-    assert (
-        student_emb.ndim == 3
-    ), f"Expected student_emb to be 3D, got {student_emb.ndim}D"
-    assert (
-        teacher_emb.ndim == 3
-    ), f"Expected teacher_emb to be 3D, got {teacher_emb.ndim}D"
+    assert (student_emb is None) == (
+        teacher_emb is None
+    ), "student_emb and teacher_emb must both be None or both be tensors"
+    if student_emb is not None:
+        assert teacher_emb is not None
+        assert (
+            student_emb.ndim == 3
+        ), f"Expected student_emb to be 3D, got {student_emb.ndim}D"
+        assert (
+            teacher_emb.ndim == 3
+        ), f"Expected teacher_emb to be 3D, got {teacher_emb.ndim}D"
 
     student_logits, teacher_logits, student_emb, teacher_emb, input_ids = (
         _filter_non_specific_nucleotides_and_batch(
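
The inner `assert teacher_emb is not None` looks redundant given the equality assert just above it, but static type checkers do not propagate that pairing: narrowing `student_emb` inside the `if` says nothing about `teacher_emb`, so the extra assert is what lets mypy/pyright treat both as plain tensors in the branch. The same pattern in isolation (a hypothetical function, for illustration only):

import torch

def paired_norm(a: torch.Tensor | None, b: torch.Tensor | None) -> torch.Tensor:
    assert (a is None) == (b is None), "a and b must be None together"
    if a is not None:
        assert b is not None  # narrows b from Tensor | None to Tensor
        return (a - b).norm()
    return torch.tensor(0.0)
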
@@ -160,16 +167,20 @@ def distillation_loss(
     soft_loss_contrib = alpha_soft * soft_loss
     hard_loss_contrib = alpha_hard * hard_loss
 
-    hidden_state_sim = F.cosine_embedding_loss(
-        student_emb,
-        teacher_emb,
-        torch.ones(
-            student_emb.size(0),
-            device=student_emb.device,
-        ),
-        reduction="mean",
-    )
-    hidden_state_sim = alpha_sim * hidden_state_sim
+    if alpha_sim > 0:
+        assert student_emb is not None and teacher_emb is not None
+        hidden_state_sim = F.cosine_embedding_loss(
+            student_emb,
+            teacher_emb,
+            torch.ones(
+                student_emb.size(0),
+                device=student_emb.device,
+            ),
+            reduction="mean",
+        )
+        hidden_state_sim = alpha_sim * hidden_state_sim
+    else:
+        hidden_state_sim = torch.tensor(0.0, device=input_ids.device, dtype=torch.float)
 
     total_loss = soft_loss_contrib + hard_loss_contrib + hidden_state_sim
     return total_loss, soft_loss_contrib, hard_loss_contrib, hidden_state_sim
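
Two details worth noting here. First, with a target of all ones, `F.cosine_embedding_loss` is simply the mean of `1 - cos(student, teacher)` per row, i.e. it pushes each pair of hidden states toward alignment. Second, the `else` branch returns a real zero tensor rather than `None`, so the caller can still sum and log all four terms unconditionally. A self-contained check of the first point (shapes are arbitrary):

import torch
import torch.nn.functional as F

s = torch.randn(8, 32)  # stand-ins for the filtered (N, D) student states
t = torch.randn(8, 32)  # stand-ins for the filtered (N, D) teacher states

loss = F.cosine_embedding_loss(s, t, torch.ones(s.size(0)), reduction="mean")
manual = (1 - F.cosine_similarity(s, t, dim=1)).mean()
assert torch.allclose(loss, manual)
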
@@ -412,7 +423,7 @@ def __init__(
             self.d_model <= self.teacher_d_model
         ), f"Expected student d_model <= {self.teacher_d_model=}, got {self.d_model=}"
 
-        if self.d_model < self.teacher_d_model:
+        if alpha_sim > 0 and self.d_model < self.teacher_d_model:
             self.d_model_proj: torch.nn.Module = torch.nn.Linear(
                 # NOTE: d_model is doubled because the models are bi-directional
                 in_features=self.d_model * 2,
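
The projection layer is now only constructed when the similarity term is active. Its output size is cut off in this hunk, but the shape assert in `forward()` below implies `out_features=self.teacher_d_model * 2`. A sketch with made-up sizes:

import torch

d_model, teacher_d_model = 128, 256   # hypothetical student/teacher widths
proj = torch.nn.Linear(in_features=d_model * 2, out_features=teacher_d_model * 2)

h = torch.randn(4, 100, d_model * 2)  # (B, S, 2 * d_model) student states
assert proj(h).shape == (4, 100, teacher_d_model * 2)
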
@@ -447,37 +458,46 @@ def forward(
         output_hidden_states: bool | None = None,
         **kwargs: Any,
     ) -> MaskedLMOutput:
-        outputs = self.student.caduceus(
-            input_ids=input_ids,
-            output_hidden_states=output_hidden_states,
-        )
+        if self.alpha_sim > 0:
+            outputs = self.student.caduceus(
+                input_ids=input_ids,
+                output_hidden_states=output_hidden_states,
+            )
 
-        hidden_states = outputs[0]
-        B, S, D = hidden_states.shape
-        hidden_states = self.d_model_proj(hidden_states)
-        if output_hidden_states:
-            outputs.hidden_states.append(hidden_states)
-        assert hidden_states.shape == (B, S, self.teacher_d_model * 2)
-        logits = self.student.lm_head(hidden_states)
-        logits = logits.float()
-
-        return MaskedLMOutput(
-            logits=logits,
-            hidden_states=outputs.hidden_states,
-        )
+            hidden_states = outputs[0]
+            B, S, _ = hidden_states.shape
+            hidden_states = self.d_model_proj(hidden_states)
+            if output_hidden_states:
+                outputs.hidden_states.append(hidden_states)
+            assert hidden_states.shape == (B, S, self.teacher_d_model * 2)
+            logits = self.student.lm_head(hidden_states)
+            logits = logits.float()
+
+            return MaskedLMOutput(
+                logits=logits,
+                hidden_states=outputs.hidden_states,
+            )
+        else:
+            return self.student(
+                input_ids=input_ids,
+                output_hidden_states=output_hidden_states,
+                **kwargs,
+            )
 
     def training_step(self, batch: HG38_EXAMPLE_T, batch_idx: int) -> torch.Tensor:
         input_ids, chr_names, starts, ends = batch
         logger.debug(f"Train {batch_idx=}, {chr_names=}, {starts=}, {ends=}")
 
         with torch.no_grad():
-            outputs = self.teacher(input_ids.to(self.device), output_hidden_states=True)
+            outputs = self.teacher(
+                input_ids.to(self.device), output_hidden_states=self.alpha_sim > 0
+            )
             teacher_logits = outputs.logits
-            teacher_emb = outputs.hidden_states[-1]
+            teacher_emb = outputs.hidden_states[-1] if self.alpha_sim > 0 else None
 
-        outputs = self(input_ids, output_hidden_states=True)
+        outputs = self(input_ids, output_hidden_states=self.alpha_sim > 0)
         student_logits = outputs.logits
-        student_emb = outputs.hidden_states[-1]
+        student_emb = outputs.hidden_states[-1] if self.alpha_sim > 0 else None
 
         # Calculate combined loss
         loss, soft_loss, hard_loss, hidden_state_sim_loss = distillation_loss(
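
Both passes now request hidden states only when the similarity term will actually consume them, which avoids materializing (and projecting) per-token embeddings on the `alpha_sim == 0` path. The gating pattern in a self-contained toy (the stub below is illustrative, not the real student/teacher API):

import torch
from dataclasses import dataclass

@dataclass
class StubOutput:
    logits: torch.Tensor
    hidden_states: list[torch.Tensor] | None

def stub_model(input_ids: torch.Tensor, output_hidden_states: bool) -> StubOutput:
    # Hidden states are only computed (and returned) when actually requested.
    logits = torch.randn(*input_ids.shape, 16)
    hidden = [torch.randn(*input_ids.shape, 32)] if output_hidden_states else None
    return StubOutput(logits, hidden)

alpha_sim = 0.0
ids = torch.randint(0, 16, (2, 8))
out = stub_model(ids, output_hidden_states=alpha_sim > 0)
emb = out.hidden_states[-1] if alpha_sim > 0 else None  # None on the cheap path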