checkpoint.py
import builtins
import logging
import os

import dill as pickle
import flax
import tensorflow as tf
from jax._src.lib import xla_client

logger = logging.getLogger(__name__)

# Hack: pickled bfloat16 arrays record `builtins` as the type's module, so the
# type must be reachable as `builtins.bfloat16` for unpickling to succeed.
# https://github.com/google/jax/issues/8505
builtins.bfloat16 = xla_client.bfloat16
def pickle_dump(obj, filename):
    """Wrapper to dump an object to a file."""
    with tf.io.gfile.GFile(filename, 'wb') as f:
        f.write(pickle.dumps(obj))


def pickle_load(filename):
    """Wrapper to load an object from a file."""
    with tf.io.gfile.GFile(filename, 'rb') as f:
        pickled = pickle.loads(f.read())
    return pickled
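
# Usage sketch (hypothetical paths): because the helpers go through
# tf.io.gfile, they work with any filesystem TensorFlow understands,
# e.g. local paths or gs:// buckets.
#   pickle_dump({'step': 0}, '/tmp/ckpt_0.pickle')
#   restored = pickle_load('/tmp/ckpt_0.pickle')
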
def save_checkpoint(ckpt_dir, state_G, state_D, params_ema_G, pl_mean, config,
                    step, epoch, fid_score=None, keep_best=2, is_best=False):
    """
    Saves a checkpoint.

    Args:
        ckpt_dir (str): Path to the directory where checkpoints are saved.
        state_G (train_state.TrainState): Generator state.
        state_D (train_state.TrainState): Discriminator state.
        params_ema_G (frozen_dict.FrozenDict): Parameters of the EMA generator.
        pl_mean (array): Moving average of the path length (generator regularization).
        config (argparse.Namespace): Configuration.
        step (int): Current step.
        epoch (int): Current epoch.
        fid_score (float): FID score corresponding to the checkpoint.
        keep_best (int): Number of best checkpoints to keep.
        is_best (bool): Whether this is a new best model.
    """
    state_dict = {'state_G': flax.jax_utils.unreplicate(state_G),
                  'state_D': flax.jax_utils.unreplicate(state_D),
                  'params_ema_G': params_ema_G,
                  'pl_mean': flax.jax_utils.unreplicate(pl_mean),
                  'config': config,
                  'fid_score': fid_score,
                  'step': step,
                  'epoch': epoch}

    if is_best:
        f_name = f'ckpt_{step}_best.pickle'
    else:
        f_name = f'ckpt_{step}.pickle'
    f_path = os.path.join(ckpt_dir, f_name)
    logger.info(f'Saving checkpoint for step {step:,} to {f_path}')
    pickle_dump(state_dict, f_path)

    if is_best:
        # Keep only the `keep_best` most recently written best checkpoints,
        # dropping the oldest one by modification time.
        ckpts = tf.io.gfile.glob(os.path.join(ckpt_dir, '*_best.pickle'))
        if len(ckpts) > keep_best:
            modified_times = {}
            for ckpt in ckpts:
                stats = tf.io.gfile.stat(ckpt)
                modified_times[ckpt] = stats.mtime_nsec
            oldest_ckpt = min(modified_times, key=modified_times.get)
            tf.io.gfile.remove(oldest_ckpt)
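
# Call-site sketch: training-loop variables such as `state_G`, `state_D`,
# `fid` and `best_fid`, and the config fields used here, are assumptions for
# illustration, not part of this module.
#   if step % config.save_every == 0:
#       save_checkpoint(config.ckpt_dir, state_G, state_D, params_ema_G,
#                       pl_mean, config, step, epoch,
#                       fid_score=fid, is_best=fid < best_fid)
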
def load_checkpoint(filename):
    """
    Loads a checkpoint.

    Args:
        filename (str): Path to the checkpoint file.

    Returns:
        (dict): Checkpoint.
    """
    return pickle_load(filename)
def get_latest_checkpoint(ckpt_dir):
    """
    Returns the path of the latest checkpoint.

    Args:
        ckpt_dir (str): Path to the directory where checkpoints are saved.

    Returns:
        (str): Path to the latest checkpoint, or None if the directory
        contains no checkpoints.
    """
    ckpts = tf.io.gfile.glob(os.path.join(ckpt_dir, '*.pickle'))
    if len(ckpts) == 0:
        return None
    modified_times = {}
    for ckpt in ckpts:
        stats = tf.io.gfile.stat(ckpt)
        modified_times[ckpt] = stats.mtime_nsec
    return max(modified_times, key=modified_times.get)
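

if __name__ == '__main__':
    # Minimal resume sketch, assuming a checkpoint directory is passed on the
    # command line. Everything below is illustrative and not part of the
    # original training pipeline.
    import sys

    ckpt_dir = sys.argv[1] if len(sys.argv) > 1 else '.'
    latest = get_latest_checkpoint(ckpt_dir)
    if latest is None:
        print(f'No checkpoints found in {ckpt_dir}')
    else:
        ckpt = load_checkpoint(latest)
        print(f"Loaded step {ckpt['step']:,} (epoch {ckpt['epoch']}, "
              f"FID {ckpt['fid_score']})")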