-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtrainer.py
154 lines (143 loc) · 5.77 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
#!/usr/bin/python
# -*- coding: utf-8 -*-
from argument_BERT.preprocessing import preparator
from argument_BERT.preprocessing import model_builder
from argument_BERT.preprocessing import data_builder, data_loader
from argument_BERT.utils import metrics
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split as sk_train_test_split
from bert_embedding import BertEmbedding
from keras.callbacks import EarlyStopping
from datetime import datetime
def load_data(directory,
              rst_files=True,
              attack=True,
              bidirect='abs',
              balance_ratio=0.5,
              ADU=False,
              ):
    """Read the annotation files needed for one kind of training.

    Input:
        directory: directory holding the stored annotation files
        rst_files: True when RST files are stored there and should be used
        attack: True for support-attack labels, False for related labels
        bidirect: how the relation direction is encoded in the label
            abs - absolute value,
            zero - negatives zero,
            skip - keeping original
        balance_ratio: balance between non-related and other samples
        ADU: True for proposition type classification
    Output:
        The loaded data. For relation data (ADU False) the labels are
        rewritten by preparator.change_labels and the result is passed
        through preparator.balance_dataset.
    """
    loaded = data_loader.load_from_directory(directory, rst_files, ADU)
    if ADU:
        # Proposition-type data needs no relation relabelling.
        # NOTE(review): balancing is treated as relation-only here, because
        # balance_ratio is documented as non-related vs others. Confirm it
        # is not also meant to run on ADU data.
        return loaded
    relabelled = preparator.change_labels(loaded,
                                          attack=attack,
                                          bidirect=bidirect)
    return preparator.balance_dataset(relabelled, balance_ratio)
def trainer(directory,
            ADU=False,
            save=True,
            save_dir='/content/models/',
            train_generable=False,
            support_attack=True,
            rst_files=True,
            verbose=0):
    """Train a classification model using the files in ``directory``.

    Input:
        directory: location of the annotation files
        ADU: True for proposition type classification,
             False for relation detection
        save: True to save the model
        save_dir: directory to save the model in (file name:
                  model_[datetime].h5); created if it does not exist
        train_generable: True to eliminate non-generable features (e.g. RST)
        support_attack: True for support-attack, False for related class
        rst_files: True for using existing RST files
        verbose: more than 0 for additional progress output
    Output:
        Test results printed (metrics.adu_report for ADU,
        metrics.related_unrelated_report otherwise).
        Model stored under save_dir if save is True.
        Returns the trained model built by model_builder.build_FFNN.
    """
    import os

    bert_embedding = BertEmbedding(model='bert_12_768_12',
                                   dataset_name='book_corpus_wiki_en_cased',
                                   max_seq_length=35)
    train_data = load_data(directory, ADU=ADU, attack=support_attack,
                           rst_files=rst_files)
    train_data = data_builder.add_features(train_data, has_2=not ADU,
                                           bert_emb=bert_embedding)
    if train_generable:
        train_data = data_builder.remove_nongenerable_features(train_data,
                                                               bert_embedding,
                                                               ADU)
    if verbose > 0:
        print('Feature list:')
        print(list(train_data.keys()))

    # Hold out 10% of the samples for the final test report.
    (train_data, test_data) = sk_train_test_split(train_data,
                                                  test_size=0.10)
    if verbose > 0:
        print('Train-test split:' +
              str(len(train_data)) + ' ' + str(len(test_data)))

    (x_data, y_data) = preparator.input_output_split(train_data)
    (x_test, y_test) = preparator.input_output_split(test_data)
    if verbose > 0:
        print('X-Y data ready: ' + str(len(x_data)) + ' ' + str(len(x_test)))

    # restore_best_weights=True leaves the model at the epoch with the
    # lowest validation loss, not at the last epoch trained.
    es = EarlyStopping('val_loss', patience=150,
                       restore_best_weights=True)

    # The ADU and relation pipelines differ only in this keyword. ADU passes
    # has_2=False (presumably one text per sample, not a pair). Relation
    # models pass nothing, so the helpers' own defaults apply exactly as in
    # the original calls.
    variant_kwargs = {'has_2': False} if ADU else {}

    features = model_builder.select_FFNN_features(
        x_data, shared_feature_list=None, original_bert=True,
        **variant_kwargs)
    # Positional hyper-parameters are passed through unchanged; their meaning
    # is defined by model_builder.build_FFNN.
    model = model_builder.build_FFNN(
        features[-1].shape[1:],
        y_data.shape[1],
        1, 300, 1, 600, 0.4, 0.10, True, True, 1, 300,
        optimizer='rmsprop',
        activation='sigmoid',
        **variant_kwargs)
    model.fit(
        features,
        y_data,
        validation_split=0.05,
        epochs=5000,
        batch_size=5000,
        verbose=0,
        callbacks=[es],
    )

    test_features = model_builder.select_FFNN_features(
        x_test, shared_feature_list=None, original_bert=True,
        **variant_kwargs)
    if ADU:
        metrics.adu_report(model, test_features, y_test)
    else:
        metrics.related_unrelated_report(model, test_features, y_test)

    if save:
        filename = 'model_' \
            + datetime.now().strftime('%Y-%m-%d-%H:%M:%S') + '.h5'
        # An empty save_dir means "current directory", as plain string
        # concatenation did before; only create a directory when one is named.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        # os.path.join tolerates save_dir with or without a trailing separator.
        model.save(os.path.join(save_dir, filename))
    return model