models.py
import torch
import torch.nn as nn
from transformers import HubertConfig, HubertModel, HubertPreTrainedModel


class Hubert(HubertPreTrainedModel):
    """HuBERT encoder with two linear heads: a student head on an
    intermediate layer and a teacher head on the final layer."""

    def __init__(self, config: HubertConfig):
        super().__init__(config)
        self.s_layer = 8  # index of the hidden state fed to the student head
        self.hubert = HubertModel(config)
        # 29-way output heads over the encoder's hidden size
        self.classifier_t = nn.Linear(config.hidden_size, 29)  # teacher head
        self.classifier_s = nn.Linear(config.hidden_size, 29)  # student head
        # initialize weights after all submodules exist so the heads are
        # covered too (the original called post_init() before creating them)
        self.post_init()

    def freeze(self):
        """Freeze the HuBERT encoder; only the two heads stay trainable."""
        for param in self.hubert.parameters():
            param.requires_grad = False

    def unfreeze(self):
        for param in self.hubert.parameters():
            param.requires_grad = True

    def forward(self, audio_input):
        # hidden_states holds the embedding output plus one tensor per layer,
        # so out[self.s_layer] is the output of transformer layer s_layer
        out = self.hubert(
            audio_input,
            attention_mask=None,
            output_hidden_states=True,
        ).hidden_states
        i_logits = self.classifier_s(out[self.s_layer])  # intermediate logits
        logits = self.classifier_t(out[-1])              # final-layer logits
        return i_logits, logits
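

# --- Usage sketch (not part of the original file) ---------------------------
# A minimal, hedged example of driving the model above. The default
# HubertConfig, random waveform, and printed shapes are illustrative
# assumptions; the 29-way heads suggest a character-level CTC vocabulary,
# but the original file does not say so.
if __name__ == "__main__":
    config = HubertConfig()           # or HubertConfig.from_pretrained(...)
    model = Hubert(config)
    model.freeze()                    # train only the two classifier heads
    model.eval()                      # disable spec-augment masking here

    waveform = torch.randn(1, 16000)  # one second of 16 kHz mono audio
    i_logits, logits = model(waveform)
    # both heads emit (batch, frames, 29); frames is set by the conv stride
    print(i_logits.shape, logits.shape)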