-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_softmax2triplet.py
60 lines (52 loc) · 1.79 KB
/
train_softmax2triplet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # suppress TensorFlow C++ INFO and WARNING logs; errors are still shown
import sys
import math
import pathlib
import datetime
import tensorflow as tf
from model.parse_params import parse_params
from model.input_fn import dataset_pipeline_balance_label
from model.triplet_model_fn import transfer_model_fn
gpuNum = 1
if __name__ == "__main__":
# read params path
path = sys.argv[1]
params = parse_params(path)
params_path = pathlib.Path(path).parents[0]
with tf.device(f'/device:GPU:{gpuNum}'):
# dataset
train_ds, train_count = dataset_pipeline_balance_label(True, **params)
model = transfer_model_fn(True, **params)
model.summary()
log_dir = os.path.join(params_path, "logs/",
datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
tensorboard_callback = tf.keras.callbacks.TensorBoard(
log_dir,
update_freq='batch',
profile_batch=0,
histogram_freq=1
)
cp_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=os.path.join(params_path, "model"),
monitor='loss',
mode='min',
save_best_only=True,
save_weights_only=True,
verbose=1)
es_callback = tf.keras.callbacks.EarlyStopping(
monitor='loss',
patience=params['early_stopping']
)
# start training
model.compile(optimizer="adam")
model.fit(
train_ds,
steps_per_epoch=math.ceil(
train_count/(params['n_class_per_batch']*params['n_per_class'])
),
epochs=params['n_epochs'],
callbacks=[tensorboard_callback, cp_callback, es_callback]
)