GlossingModel.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl

from Encoder import TransformerCharEncoder
from MorphemeSegmenter import MorphemeSegmenter
from GlossingDecoder import GlossingDecoder
from Utilities import aggregate_segments


#########################################
# 6. Integrated Glossing Pipeline as a LightningModule
#########################################
class GlossingPipeline(pl.LightningModule):
    """
    An integrated glossing pipeline that combines:
      - a Transformer-based character encoder,
      - an improved morpheme segmentation module with adaptive thresholding,
      - a translation encoder,
      - a glossing decoder with cross-attention over aggregated segment representations.

    The pipeline is a PyTorch LightningModule, so training, validation,
    and optimizer configuration are integrated in one class.
    """

    def __init__(self, char_vocab_size: int, gloss_vocab_size: int, trans_vocab_size: int,
                 embed_dim: int = 256, num_heads: int = 8, ff_dim: int = 512,
                 num_layers: int = 6, dropout: float = 0.1, use_gumbel: bool = False,
                 learning_rate: float = 0.001, gloss_pad_idx: int = None,
                 use_relative: bool = True, max_relative_position: int = 64):
        super().__init__()
        self.save_hyperparameters(ignore=["gloss_pad_idx"])

        # Build model components.
        self.encoder = TransformerCharEncoder(
            input_size=char_vocab_size,
            embed_dim=embed_dim,
            num_layers=num_layers,
            num_heads=num_heads,
            dropout=dropout,
            projection_dim=None,
            use_relative=use_relative,
            max_relative_position=max_relative_position
        )
        # For Track 1 (unsupervised segmentation), gold segmentation is not available,
        # so the segmenter is called with num_morphemes=None in forward().
        self.segmentation = MorphemeSegmenter(embed_dim, use_gumbel=use_gumbel)
        self.decoder = GlossingDecoder(
            gloss_vocab_size=gloss_vocab_size,
            embed_dim=embed_dim,
            num_heads=num_heads,
            ff_dim=ff_dim,
            num_layers=num_layers,
            dropout=dropout
        )
        self.translation_encoder = nn.Embedding(trans_vocab_size, embed_dim)

        # Loss function. Fall back to PyTorch's default ignore_index (-100)
        # if no gloss padding index is provided.
        ignore_index = gloss_pad_idx if gloss_pad_idx is not None else -100
        self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_index)
        self.learning_rate = learning_rate
    def forward(self, src_features, src_lengths, tgt, trans, learn_segmentation: bool = True, num_morphemes=None):
        """
        Forward pass through the glossing pipeline.

        Args:
            src_features: One-hot source character features (batch_size, src_seq_len, char_vocab_size).
            src_lengths: Valid lengths of the source sequences (batch_size,).
            tgt: Target gloss token indices (batch_size, tgt_seq_len).
            trans: Translation token indices (batch_size, trans_seq_len).
            learn_segmentation: Whether to learn segmentation (True for Track 1 data).
            num_morphemes: If available (Track 2), the target number of morphemes per word;
                None for unsupervised segmentation (Track 1).

        Returns:
            logits, morpheme_count, tau, seg_probs.
        """
        # Encode source characters.
        encoder_outputs = self.encoder(src_features, src_lengths)
        assert trans.max().item() < self.translation_encoder.num_embeddings, \
            f"Found token index {trans.max().item()} which exceeds vocab size {self.translation_encoder.num_embeddings}"

        # Compute segmentation. For Track 1 (unsupervised), num_morphemes is None.
        segmentation_mask, morpheme_count, tau, seg_probs = self.segmentation(
            encoder_outputs, src_lengths, num_morphemes, training=learn_segmentation
        )

        # Aggregate encoder outputs into per-morpheme segment representations.
        seg_tensor = aggregate_segments(encoder_outputs, segmentation_mask)

        # Encode the translation and reduce it to a single representation.
        trans_embedded = self.translation_encoder(trans)       # (batch_size, trans_seq_len, embed_dim)
        trans_repr = trans_embedded.mean(dim=1, keepdim=True)  # (batch_size, 1, embed_dim)

        # Prepend the translation representation to the segment memory.
        memory = torch.cat([trans_repr, seg_tensor], dim=1)

        # Decode gloss tokens with cross-attention over the memory.
        logits = self.decoder(tgt, memory)
        return logits, morpheme_count, tau, seg_probs
    def training_step(self, batch, batch_idx):
        src_batch, src_len_batch, tgt_batch, trans_batch = batch
        # Convert source indices into one-hot vectors.
        src_features = F.one_hot(src_batch, num_classes=self.encoder.input_size).float()

        # For Track 1, no gold segmentation is available, so pass None for num_morphemes.
        logits, morpheme_count, tau, seg_probs = self(src_features, src_len_batch, tgt_batch, trans_batch,
                                                      learn_segmentation=True, num_morphemes=None)

        # logits: (batch_size, tgt_seq_len, gloss_vocab_size)
        gloss_vocab_size = logits.size(-1)
        logits = logits.view(-1, gloss_vocab_size)
        tgt_flat = tgt_batch.view(-1)
        loss = self.criterion(logits, tgt_flat)

        self.log("train_loss", loss, on_step=True, on_epoch=True)
        return loss
    def validation_step(self, batch, batch_idx):
        src_batch, src_len_batch, tgt_batch, trans_batch = batch
        src_features = F.one_hot(src_batch, num_classes=self.encoder.input_size).float()
        logits, _, _, _ = self(src_features, src_len_batch, tgt_batch, trans_batch,
                               learn_segmentation=True, num_morphemes=None)

        gloss_vocab_size = logits.size(-1)
        logits = logits.view(-1, gloss_vocab_size)
        tgt_flat = tgt_batch.view(-1)
        loss = self.criterion(logits, tgt_flat)
        self.log("val_loss", loss, on_step=False, on_epoch=True)

        # Manually store the loss so it can be averaged at epoch end.
        if not hasattr(self, "val_outputs"):
            self.val_outputs = []
        self.val_outputs.append(loss)
        return loss

    def on_validation_epoch_end(self):
        if hasattr(self, "val_outputs") and self.val_outputs:
            avg_loss = torch.stack(self.val_outputs).mean()
            self.log("val_loss_epoch", avg_loss)
            self.val_outputs = []  # Clear for the next epoch.
    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        src_batch, src_len_batch, tgt_batch, trans_batch = batch
        # Convert source indices into one-hot vectors.
        src_features = F.one_hot(src_batch, num_classes=self.encoder.input_size).float()

        # Run the forward pass in inference mode (learn_segmentation=False uses soft
        # segmentation and avoids sampling noise). Note that tgt_batch is still fed to
        # the decoder, so predictions are teacher-forced rather than generated
        # autoregressively.
        logits, _, _, _ = self(src_features, src_len_batch, tgt_batch, trans_batch, learn_segmentation=False)

        # Obtain predictions by taking the argmax over the gloss vocabulary.
        predictions = torch.argmax(logits, dim=-1)  # (batch_size, tgt_seq_len)
        return predictions

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
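

# ----------------------------------------------------------------------
# Minimal smoke-test sketch (illustration only, not part of the original
# pipeline). It assumes the sibling modules imported above (Encoder,
# MorphemeSegmenter, GlossingDecoder, Utilities) behave according to the
# call signatures used in this file; the vocabulary sizes, sequence
# lengths, and hyperparameters below are arbitrary placeholders.
# ----------------------------------------------------------------------
if __name__ == "__main__":
    char_vocab_size, gloss_vocab_size, trans_vocab_size = 60, 40, 100

    model = GlossingPipeline(
        char_vocab_size=char_vocab_size,
        gloss_vocab_size=gloss_vocab_size,
        trans_vocab_size=trans_vocab_size,
        embed_dim=64, num_heads=4, ff_dim=128, num_layers=2,
        gloss_pad_idx=0,
    )

    # Dummy batch in the (src, src_lengths, tgt, trans) layout that
    # training_step expects.
    batch_size, src_len, tgt_len, trans_len = 2, 12, 6, 8
    src_batch = torch.randint(0, char_vocab_size, (batch_size, src_len))
    src_len_batch = torch.tensor([src_len, src_len - 3])
    tgt_batch = torch.randint(0, gloss_vocab_size, (batch_size, tgt_len))
    trans_batch = torch.randint(0, trans_vocab_size, (batch_size, trans_len))

    # Run the forward pass directly (Track 1 setting: no gold morpheme counts).
    src_features = F.one_hot(src_batch, num_classes=char_vocab_size).float()
    logits, morpheme_count, tau, seg_probs = model(
        src_features, src_len_batch, tgt_batch, trans_batch,
        learn_segmentation=True, num_morphemes=None,
    )
    print("gloss logits shape:", tuple(logits.shape))  # (batch_size, tgt_len, gloss_vocab_size)

    # Training with Lightning would then look roughly like:
    #   trainer = pl.Trainer(max_epochs=10)
    #   trainer.fit(model, train_dataloader, val_dataloader)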