train_transformer.py
import os
import json
import torch
import models as M
import config_v7 as config
from time import process_time
import data_processor as dp
datapath = os.path.join(config.DATADIR, "chairilanwar.txt")
chproc = dp.CharProcessor(datapath)
data = torch.tensor(chproc.encode(chproc.text), dtype=torch.long)
print(data.shape, data.dtype)
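# Quick round-trip sanity check (an illustrative addition; it relies only on
# the encode/decode methods already used elsewhere in this script):
sample = chproc.text[:20]
assert chproc.decode(chproc.encode(sample)) == sample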
# Split the data into train and validation.
# n = int(0.8 * len(data))
# train_data = data[:n]
# val_data = data[n:]
# Here the full corpus is used for training, so there is no held-out
# validation set.
n = len(data)
train_data = data[:n]
val_data = None
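# Next-token objective: within a block of BLOCK_SIZE tokens, position t
# predicts token t + 1, so one block yields BLOCK_SIZE (context, target) pairs.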
# Print pairs of input and target
x = train_data[:config.BLOCK_SIZE]
y = train_data[1:config.BLOCK_SIZE + 1]
for t in range(config.BLOCK_SIZE):
    context = x[:t + 1]
    target = y[t]
    print(f"when input is {context}, the target is {target}")
torch.manual_seed(1337)
@torch.no_grad()
def estimate_loss(
    model,
    data,
    eval_iters=10
):
    model.eval()
    losses = torch.zeros(eval_iters)
    for k in range(eval_iters):
        xb, yb = dp.get_batch(data, batch_size=config.BATCH_SIZE, block_size=config.BLOCK_SIZE)
        xb = xb.to(config.DEVICE)
        yb = yb.to(config.DEVICE)
        logits, loss = model(xb, yb)
        losses[k] = loss.item()
    avg_loss = losses.mean()
    model.train()
    return avg_loss
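# For reference, dp.get_batch is assumed to sample random contiguous windows
# from the 1-D token tensor, with targets shifted by one token (the actual
# implementation lives in data_processor.py). A minimal sketch under that
# assumption; _get_batch_sketch is illustrative only and is not called below.
def _get_batch_sketch(data, batch_size, block_size):
    # Random start offsets, then stack the windows and their shifted targets.
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i + block_size] for i in ix])
    y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])
    return x, y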
xb, yb = dp.get_batch(
    train_data,
    batch_size=config.BATCH_SIZE,
    block_size=config.BLOCK_SIZE
)
xb = xb.to(config.DEVICE)
yb = yb.to(config.DEVICE)
model = M.Transformer(
    chproc.vocab_size,
    config.BLOCK_SIZE,
    config.N_EMBED,
    config.N_LAYER,
    config.NUM_HEADS,
    device=config.DEVICE
)
model = model.to(config.DEVICE)
logits, loss = model(xb, yb)
print(f"Count parameters: {model.count_parameters()}")
# Seed index must live on the same device as the model, or generation fails
# with a device mismatch when DEVICE is a GPU.
idx = torch.zeros((1, 1), dtype=torch.long, device=config.DEVICE)
pred_idx = model.generate(idx, 1000)
pred_str = chproc.decode(pred_idx[0].tolist())
print(f"pred_idx: {pred_idx}")
print(f"pred_str: {pred_str}")
# Create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=config.LEARNING_RATE)
# Sentinel losses, overwritten at the first evaluation step
val_loss = 99999
train_loss = 99999
print(f"Using device: {config.DEVICE}")
history = {
'train_losses': [],
'train_times': [],
}
elapsed_times = []
for step in range(config.MAX_ITERS):
    # Sample a batch of data
    xb, yb = dp.get_batch(train_data, batch_size=config.BATCH_SIZE, block_size=config.BLOCK_SIZE)
    xb = xb.to(config.DEVICE)
    yb = yb.to(config.DEVICE)
    start_t = process_time()
    # Evaluate the loss
    logits, loss = model(xb, yb)
    # Backprop
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    elapsed_t = process_time() - start_t
    elapsed_times.append(elapsed_t)
    history['train_times'] = elapsed_times
    if step % config.EVAL_INTERVAL == 0:
        start_t = process_time()
        train_loss = estimate_loss(model, train_data, eval_iters=config.EVAL_ITERS)
        if val_data is not None:
            val_loss = estimate_loss(model, val_data, eval_iters=config.EVAL_ITERS)
        history['train_losses'].append(train_loss.item())
        val_elapsed_t = process_time() - start_t
        # Report "n/a" instead of the sentinel when there is no validation set.
        val_str = f"{val_loss:.4f}" if val_data is not None else "n/a"
        print(
            f"Step-{step + 1}/{config.MAX_ITERS} "
            f"[elapsed time: {elapsed_t:.5f}secs (train), {val_elapsed_t:.5f}secs (val)]: "
            f"train loss={train_loss:.4f}, validation loss={val_str}"
        )
        # Save the model
        model.save(config.CHECKPOINT_DIR, config.MODEL_NAME)
        # Save training history
        history_path = os.path.join(config.CHECKPOINT_DIR, f"{config.MODEL_NAME}_hist.json")
        json_object = json.dumps(history)  # serializing json
        with open(history_path, "w") as outfile:
            outfile.write(json_object)  # write to json file
        # Sample from the model to monitor progress.
        pred_idx = model.generate(idx, 100)
        pred_str = chproc.decode(pred_idx[0].tolist())
        print(f"\npred_str: {pred_str}\n")
pred_idx = model.generate(idx, 1000)
pred_str = chproc.decode(pred_idx[0].tolist())
print(f"pred_idx: {pred_idx}")
print(f"pred_str: {pred_str}")
print(f"Total elapsed time: {sum(elapsed_times):.5f} secs")