import os, math
os.environ["KERAS_BACKEND"] = "torch"  # must be set before keras is imported
import keras, torch, json
from tqdm.auto import tqdm

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def log_normal_diag(x, mu, log_var):
    # Elementwise log-density of a diagonal Gaussian:
    # log N(x | mu, exp(log_var)) = -0.5*log(2*pi) - 0.5*log_var - 0.5*exp(-log_var)*(x - mu)**2
    log_p = -0.5 * keras.ops.log(2. * math.pi) - 0.5 * log_var - 0.5 * keras.ops.exp(-log_var) * (x - mu)**2.
    return log_p
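
# Optional sanity check (hypothetical helper, not called during training):
# with the torch backend, the density above should agree with
# torch.distributions.Normal when the standard deviation is exp(0.5 * log_var).
def _check_log_normal_diag(shape=(2, 3)):
    x, mu, log_var = torch.randn(shape), torch.randn(shape), torch.randn(shape)
    ref = torch.distributions.Normal(mu, torch.exp(0.5 * log_var)).log_prob(x)
    assert torch.allclose(log_normal_diag(x, mu, log_var), ref, atol=1e-5)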

# The only function of the code that requires backend-specific ops
def train_step(x_t, x_tplus1, forward_t, forward_tplus1, prior, posterior, decoder, opt):
    # Move to gpu
    x_t = x_t.to(DEVICE)
    x_tplus1 = x_tplus1.to(DEVICE)
    # Forward pass
    h_t = forward_t(x_t)
    h_tplus1 = forward_tplus1(x_tplus1)
    z, mu, logvar = posterior(h_t, h_tplus1)
    x_tplus1_hat = decoder(z, h_t)
    _, *mu_logvar = prior(h_t)
    kl_nll = keras.ops.mean(
        log_normal_diag(z, mu, logvar) - log_normal_diag(z, *mu_logvar),
        axis=[1, 2, 3]  # mean reduction
    )
    rec_nl = keras.ops.mean(  # MSE == \mathcal{N}(\eps|0,1)
        keras.ops.square(x_tplus1 - x_tplus1_hat),
        axis=[1, 2, 3]  # mean reduction
    )
    loss = keras.ops.mean(rec_nl + kl_nll)  # mean reduction
    # Prepare backward pass
    forward_t.zero_grad()
    forward_tplus1.zero_grad()
    prior.zero_grad()
    posterior.zero_grad()
    decoder.zero_grad()
    # Backward pass
    loss.backward()
    trainable_weights = forward_t.trainable_weights + forward_tplus1.trainable_weights \
        + prior.trainable_weights + posterior.trainable_weights + decoder.trainable_weights
    gradients = [t.value.grad for t in trainable_weights]
    with torch.no_grad():
        opt.apply_gradients(zip(gradients, trainable_weights))
    # Return loss interpretably
    return x_tplus1_hat, keras.ops.mean(kl_nll).item(), keras.ops.mean(rec_nl).item()
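
# Note on the objective: `kl_nll` is a single-sample Monte Carlo estimate of
# KL(q(z | h_t, h_{t+1}) || p(z | h_t)), and the MSE reconstruction term is the
# negative log-likelihood of a unit-variance Gaussian up to an additive
# constant, so `loss` is a per-sample negative ELBO.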

def val_step(x_t, x_tplus1, forward_t, prior, decoder):
    # Move to gpu
    x_t = x_t.to(DEVICE)
    x_tplus1 = x_tplus1.to(DEVICE)
    # Forward pass
    h_t = forward_t(x_t)
    z, *_ = prior(h_t)
    x_tplus1_hat = decoder(z, h_t)
    # Return reconstruction loss
    rec_nl = keras.ops.mean(keras.ops.square(x_tplus1 - x_tplus1_hat))  # mean (squared) reduction
    return rec_nl.item()
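
# Unlike train_step, validation samples z from the prior alone (no access to
# x_{t+1} through the posterior), so the reported MSE measures genuine
# one-step-ahead prediction quality.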

def run(train_loader, val_loader, forward_t, forward_tplus1, prior, posterior, decoder, optimizer, save_dir, max_epochs, max_patience=5):
    # Loop over epochs
    patience = 0
    loss_history = {"kl_loss": [], "rec_loss": [], "val_rec_loss": []}
    os.makedirs(save_dir, exist_ok=True)
    for i in tqdm(range(max_epochs)):
        # Initialize losses
        loss_history["kl_loss"].append(0)
        loss_history["rec_loss"].append(0)
        # Loop over batches
        for j, (x_t, x_tplus1) in enumerate(train_loader, 1):
            # Prepare pushforward training: after the first batch, feed the model's
            # own (detached) prediction back in wherever the mask allows it
            if j == 1:
                x_t_hat = x_t
            else:
                mask = keras.ops.reshape(train_loader.dataset.mask, (-1, 1, 1, 1))
                x_t_hat = keras.ops.where(mask, x_t_hat.detach(), x_t[..., :-2])  # detach used in favor of retain_graph
                x_t_hat = keras.ops.concatenate([x_t_hat, x_t[..., -2:]], axis=-1)
            # Train
            x_t_hat, kl_loss, rec_loss = train_step(x_t_hat, x_tplus1, forward_t, forward_tplus1, prior, posterior, decoder, optimizer)
            # Keep track of losses
            loss_history["kl_loss"][i] += kl_loss
            loss_history["rec_loss"][i] += rec_loss
        # Normalize losses
        loss_history["kl_loss"][i] /= j
        loss_history["rec_loss"][i] /= j
        # Validation
        val_loss = 0
        for k, (x_t, x_tplus1) in enumerate(val_loader, 1):
            val_loss += val_step(x_t, x_tplus1, forward_t, prior, decoder)
        loss_history["val_rec_loss"].append(val_loss / k)
        # Early stopping with patience
        if (i > 0) and ((val_loss / k) > min(loss_history["val_rec_loss"])):
            patience += 1
            if patience > max_patience:
                break
        else:
            # Reset patience
            patience = 0
            # Save models (keras.saving.save_model causes inconsistency)
            forward_t.save_weights(f"{save_dir}/forward_t.weights.h5")
            forward_tplus1.save_weights(f"{save_dir}/forward_tplus1.weights.h5")
            prior.save_weights(f"{save_dir}/prior.weights.h5")
            posterior.save_weights(f"{save_dir}/posterior.weights.h5")
            decoder.save_weights(f"{save_dir}/decoder.weights.h5")
    # Save history
    with open(f"{save_dir}/history.json", "w") as f:
        json.dump(loss_history, f)
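
# Example wiring (hypothetical names; the loaders, the five sub-models, and the
# optimizer are defined elsewhere in the repo):
#
#   optimizer = keras.optimizers.Adam(learning_rate=1e-4)
#   run(train_loader, val_loader, forward_t, forward_tplus1,
#       prior, posterior, decoder, optimizer,
#       save_dir="checkpoints", max_epochs=100, max_patience=5)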