-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
28 lines (22 loc) · 991 Bytes
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from re import X
import torch
from transformers import AutoModel
from transformers import AutoTokenizer
from transformers import pipeline
class TextBackbone(torch.nn.Module):
    """Pretrained transformer encoder followed by a linear classification head.

    The encoder is loaded with ``AutoModel.from_pretrained`` and its
    ``pooler_output`` is fed to a single ``torch.nn.Linear`` layer that
    produces ``num_classes`` values per sample.
    """

    def __init__(self, pretrained_model_name='chinese-roberta-wwm-ext', num_classes=2):
        """Load the encoder and build the classification head.

        Args:
            pretrained_model_name: Name or path handed to
                ``AutoModel.from_pretrained``.
            num_classes: Output width of the linear head.
        """
        super().__init__()
        self.extractor = AutoModel.from_pretrained(pretrained_model_name)
        # Size the head from the loaded checkpoint's config rather than a
        # hard-coded 768, so encoders with another hidden width work too.
        hidden_size = self.extractor.config.hidden_size
        self.fc = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Encode a tokenized batch and return the head's output.

        Args:
            x: Mapping with ``'input_ids'`` and ``'attention_mask'`` entries
                and, optionally, ``'token_type_ids'`` (e.g. the object a
                tokenizer returns with ``return_tensors='pt'``).

        Returns:
            The result of applying the linear head to the encoder's
            ``pooler_output`` (presumably shaped ``(batch, num_classes)``).
        """
        inputs = {
            'input_ids': x['input_ids'],
            'attention_mask': x['attention_mask'],
        }
        # Not every tokenizer emits segment ids; forward them only if present
        # instead of failing with a KeyError.
        if 'token_type_ids' in x:
            inputs['token_type_ids'] = x['token_type_ids']
        outputs = self.extractor(**inputs)
        return self.fc(outputs.pooler_output)
if __name__ == '__main__':
    # Smoke test: push two short sentences through the model and print the
    # shape of the resulting output tensor.
    model = TextBackbone()
    model.eval()  # inference only: disable dropout
    x = ['我爱学习', '我不爱学习']
    tokenizer = AutoTokenizer.from_pretrained('chinese-roberta-wwm-ext')
    # Call the tokenizer directly so the list is encoded as a batch of two
    # samples; encode_plus handles a single sequence (or pair) and would not
    # treat the list as two separate inputs. truncation=True makes
    # max_length an enforced limit.
    x = tokenizer(
        x,
        max_length=512,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )
    with torch.no_grad():  # no gradients needed for a shape check
        y = model(x)
    print(y.shape)