
Commit 382c534 (parent: fa4215a)

Optional embedding sim term

2 files changed

Lines changed: 69 additions & 47 deletions
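In short: the commit makes the hidden-state similarity term of the distillation loss optional, gated on `alpha_sim > 0`, so embeddings are neither collected nor compared when the term is disabled. A minimal standalone sketch of the gating pattern follows; the tensor values and `alpha_*` defaults are illustrative, not taken from the repo:

```python
import torch
import torch.nn.functional as F

def combined_loss(
    soft_loss: torch.Tensor,
    hard_loss: torch.Tensor,
    student_emb: torch.Tensor | None,
    teacher_emb: torch.Tensor | None,
    alpha_soft: float = 0.5,
    alpha_hard: float = 0.5,
    alpha_sim: float = 0.0,
) -> torch.Tensor:
    total = alpha_soft * soft_loss + alpha_hard * hard_loss
    if alpha_sim > 0:
        # The similarity term only exists when explicitly enabled.
        assert student_emb is not None and teacher_emb is not None
        sim = F.cosine_embedding_loss(
            student_emb,
            teacher_emb,
            torch.ones(student_emb.size(0), device=student_emb.device),
            reduction="mean",
        )
        total = total + alpha_sim * sim
    return total

# With alpha_sim == 0, the embeddings may simply be None:
loss = combined_loss(torch.tensor(1.2), torch.tensor(0.8), None, None)
```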


caduceus_distill/distill.py

Lines changed: 67 additions & 47 deletions
@@ -37,10 +37,12 @@
 def _filter_non_specific_nucleotides_and_batch(
     student_logits: torch.Tensor,
     teacher_logits: torch.Tensor,
-    student_emb: torch.Tensor,
-    teacher_emb: torch.Tensor,
+    student_emb: torch.Tensor | None,
+    teacher_emb: torch.Tensor | None,
     input_ids: torch.Tensor,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> tuple[
+    torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None, torch.Tensor
+]:
     """
     This function filters out the examples where the label is the non-specific nucleotide (PAD token).
     """
@@ -49,13 +51,13 @@ def _filter_non_specific_nucleotides_and_batch(
     student_logits = student_logits[mask]
     teacher_logits = teacher_logits[mask]
     input_ids = input_ids[mask]
-    student_emb = student_emb[mask]
-    teacher_emb = teacher_emb[mask]
+    student_emb = student_emb[mask] if student_emb is not None else None
+    teacher_emb = teacher_emb[mask] if teacher_emb is not None else None
 
     assert student_logits.ndim == 2
     assert teacher_logits.ndim == 2
-    assert student_emb.ndim == 2
-    assert teacher_emb.ndim == 2
+    assert student_emb.ndim == 2 if student_emb is not None else True
+    assert teacher_emb.ndim == 2 if teacher_emb is not None else True
     assert input_ids.ndim == 1
 
     return student_logits, teacher_logits, student_emb, teacher_emb, input_ids
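The filter now lets optional embeddings pass straight through. A toy illustration of this mask-then-maybe pattern, with made-up shapes:

```python
import torch

logits = torch.randn(6, 4)
emb: torch.Tensor | None = None  # similarity term disabled, so no embeddings
mask = torch.tensor([1, 1, 0, 1, 0, 1], dtype=torch.bool)

logits = logits[mask]
emb = emb[mask] if emb is not None else None

assert logits.shape == (4, 4)  # 4 rows survive the mask
assert emb is None             # None flows through the filter untouched
```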
@@ -65,8 +67,8 @@ def distillation_loss(
     *,
     student_logits: torch.Tensor,
     teacher_logits: torch.Tensor,
-    student_emb: torch.Tensor,
-    teacher_emb: torch.Tensor,
+    student_emb: torch.Tensor | None,
+    teacher_emb: torch.Tensor | None,
     input_ids: torch.Tensor,
     temperature: float,
     alpha_soft: float,
@@ -101,12 +103,17 @@ def distillation_loss(
     assert input_ids.ndim == 2, f"Expected input_ids to be 2D, got {input_ids.ndim}D"
 
     # Expect B, T, D
-    assert (
-        student_emb.ndim == 3
-    ), f"Expected student_emb to be 3D, got {student_emb.ndim}D"
-    assert (
-        teacher_emb.ndim == 3
-    ), f"Expected teacher_emb to be 3D, got {teacher_emb.ndim}D"
+    assert (student_emb is None) == (
+        teacher_emb is None
+    ), "Both student_emb and teacher_emb must be either None or not None"
+    if student_emb is not None:
+        assert teacher_emb is not None
+        assert (
+            student_emb.ndim == 3
+        ), f"Expected student_emb to be 3D, got {student_emb.ndim}D"
+        assert (
+            teacher_emb.ndim == 3
+        ), f"Expected teacher_emb to be 3D, got {teacher_emb.ndim}D"
 
     student_logits, teacher_logits, student_emb, teacher_emb, input_ids = (
         _filter_non_specific_nucleotides_and_batch(
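The new precondition requires the two embeddings to be either both present or both absent. A toy restatement of that invariant:

```python
import torch

def check_paired(student_emb: torch.Tensor | None, teacher_emb: torch.Tensor | None) -> None:
    assert (student_emb is None) == (
        teacher_emb is None
    ), "Both student_emb and teacher_emb must be either None or not None"

check_paired(None, None)                                  # OK: term disabled
check_paired(torch.zeros(2, 3, 4), torch.ones(2, 3, 4))   # OK: both present
# check_paired(torch.zeros(2, 3, 4), None)                # would raise
```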
@@ -160,16 +167,20 @@ def distillation_loss(
     soft_loss_contrib = alpha_soft * soft_loss
     hard_loss_contrib = alpha_hard * hard_loss
 
-    hidden_state_sim = F.cosine_embedding_loss(
-        student_emb,
-        teacher_emb,
-        torch.ones(
-            student_emb.size(0),
-            device=student_emb.device,
-        ),
-        reduction="mean",
-    )
-    hidden_state_sim = alpha_sim * hidden_state_sim
+    if alpha_sim > 0:
+        assert student_emb is not None and teacher_emb is not None
+        hidden_state_sim = F.cosine_embedding_loss(
+            student_emb,
+            teacher_emb,
+            torch.ones(
+                student_emb.size(0),
+                device=student_emb.device,
+            ),
+            reduction="mean",
+        )
+        hidden_state_sim = alpha_sim * hidden_state_sim
+    else:
+        hidden_state_sim = torch.tensor(0.0, device=input_ids.device, dtype=torch.float)
 
     total_loss = soft_loss_contrib + hard_loss_contrib + hidden_state_sim
     return total_loss, soft_loss_contrib, hard_loss_contrib, hidden_state_sim
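For reference, with a target of +1 for every pair, `F.cosine_embedding_loss` reduces to the mean of `1 - cosine_similarity`, i.e. it pulls each student row toward its teacher counterpart. A quick check with toy tensors:

```python
import torch
import torch.nn.functional as F

s = torch.randn(8, 16)  # student embeddings, (N, D)
t = torch.randn(8, 16)  # teacher embeddings, (N, D)

loss = F.cosine_embedding_loss(s, t, torch.ones(8), reduction="mean")
manual = (1.0 - F.cosine_similarity(s, t, dim=1)).mean()
assert torch.allclose(loss, manual)
```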
@@ -412,7 +423,7 @@ def __init__(
             self.d_model <= self.teacher_d_model
         ), f"Expected student d_model <= {self.teacher_d_model=}, got {self.d_model=}"
 
-        if self.d_model < self.teacher_d_model:
+        if alpha_sim > 0 and self.d_model < self.teacher_d_model:
             self.d_model_proj: torch.nn.Module = torch.nn.Linear(
                 # NOTE: we double the d_model bi-directionally of the teacher model
                 in_features=self.d_model * 2,
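The projection is now built only when the similarity term is enabled. A sketch with hypothetical dimensions; the `out_features` value is inferred from the shape assert in `forward` (`teacher_d_model * 2`), since the hunk is truncated before it:

```python
import torch

d_model, teacher_d_model = 128, 256  # hypothetical sizes
alpha_sim = 0.1

if alpha_sim > 0 and d_model < teacher_d_model:
    # Caduceus runs bi-directionally, hence the doubled feature dims.
    proj = torch.nn.Linear(in_features=d_model * 2, out_features=teacher_d_model * 2)
    hidden = torch.randn(2, 10, d_model * 2)  # (B, S, 2 * d_model)
    assert proj(hidden).shape == (2, 10, teacher_d_model * 2)
```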
@@ -447,37 +458,46 @@ def forward(
         output_hidden_states: bool | None = None,
         **kwargs: Any,
     ) -> MaskedLMOutput:
-        outputs = self.student.caduceus(
-            input_ids=input_ids,
-            output_hidden_states=output_hidden_states,
-        )
+        if self.alpha_sim > 0:
+            outputs = self.student.caduceus(
+                input_ids=input_ids,
+                output_hidden_states=output_hidden_states,
+            )
 
-        hidden_states = outputs[0]
-        B, S, D = hidden_states.shape
-        hidden_states = self.d_model_proj(hidden_states)
-        if output_hidden_states:
-            outputs.hidden_states.append(hidden_states)
-        assert hidden_states.shape == (B, S, self.teacher_d_model * 2)
-        logits = self.student.lm_head(hidden_states)
-        logits = logits.float()
-
-        return MaskedLMOutput(
-            logits=logits,
-            hidden_states=outputs.hidden_states,
-        )
+            hidden_states = outputs[0]
+            B, S, _ = hidden_states.shape
+            hidden_states = self.d_model_proj(hidden_states)
+            if output_hidden_states:
+                outputs.hidden_states.append(hidden_states)
+            assert hidden_states.shape == (B, S, self.teacher_d_model * 2)
+            logits = self.student.lm_head(hidden_states)
+            logits = logits.float()
+
+            return MaskedLMOutput(
+                logits=logits,
+                hidden_states=outputs.hidden_states,
+            )
+        else:
+            return self.student(
+                input_ids=input_ids,
+                output_hidden_states=output_hidden_states,
+                **kwargs,
+            )
 
     def training_step(self, batch: HG38_EXAMPLE_T, batch_idx: int) -> torch.Tensor:
         input_ids, chr_names, starts, ends = batch
         logger.debug(f"Train {batch_idx=}, {chr_names=}, {starts=}, {ends=}")
 
         with torch.no_grad():
-            outputs = self.teacher(input_ids.to(self.device), output_hidden_states=True)
+            outputs = self.teacher(
+                input_ids.to(self.device), output_hidden_states=self.alpha_sim > 0
+            )
             teacher_logits = outputs.logits
-            teacher_emb = outputs.hidden_states[-1]
+            teacher_emb = outputs.hidden_states[-1] if self.alpha_sim > 0 else None
 
-        outputs = self(input_ids, output_hidden_states=True)
+        outputs = self(input_ids, output_hidden_states=self.alpha_sim > 0)
         student_logits = outputs.logits
-        student_emb = outputs.hidden_states[-1]
+        student_emb = outputs.hidden_states[-1] if self.alpha_sim > 0 else None
 
         # Calculate combined loss
         loss, soft_loss, hard_loss, hidden_state_sim_loss = distillation_loss(
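`training_step` now ties `output_hidden_states` to `alpha_sim > 0`, so neither teacher nor student materializes hidden states when the term is off. A self-contained sketch of that plumbing using a stub model; all names and shapes here are illustrative:

```python
from dataclasses import dataclass

import torch

@dataclass
class StubOutput:
    logits: torch.Tensor
    hidden_states: list[torch.Tensor] | None

def stub_model(input_ids: torch.Tensor, output_hidden_states: bool) -> StubOutput:
    B, S = input_ids.shape
    logits = torch.randn(B, S, 16)
    # Hidden states are only collected when the caller asks for them.
    hs = [torch.randn(B, S, 8)] if output_hidden_states else None
    return StubOutput(logits, hs)

alpha_sim = 0.0
out = stub_model(torch.zeros(2, 4, dtype=torch.long), output_hidden_states=alpha_sim > 0)
emb = out.hidden_states[-1] if alpha_sim > 0 else None
assert emb is None  # nothing extra is stored when the term is off
```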

caduceus_distill/tests/caduceus_distillation_test.py

Lines changed: 2 additions & 0 deletions
@@ -202,6 +202,8 @@ def test_filter_non_specific_nucleotides_filter_all(basic_inputs):
         )
     )
 
+    assert student_emb is not None
+    assert teacher_emb is not None
     assert targets.size(0) == 0
     assert student_logits.size(0) == 0
     assert teacher_logits.size(0) == 0
