diff --git a/exec.py b/exec.py index 49e6d4f..a996055 100644 --- a/exec.py +++ b/exec.py @@ -100,7 +100,7 @@ def train(logger): model_selector.run_model_selection(net, optimizer, monitor_metrics, epoch) # update monitoring and prediction plots - TrainingPlot.update_and_save(monitor_metrics, epoch) + TrainingPlot.update_and_save(monitor_metrics, starting_epoch, epoch) epoch_time = time.time() - start_time logger.info('trained epoch {}: took {} sec. ({} train / {} val)'.format( epoch, epoch_time, train_time, epoch_time-train_time)) diff --git a/plotting.py b/plotting.py index 4e15c74..bf7408b 100644 --- a/plotting.py +++ b/plotting.py @@ -180,17 +180,17 @@ def __init__(self, cf): self.figure_list[0].ax1.set_ylim(0, 1.5) self.color_palette = ['b', 'c', 'r', 'purple', 'm', 'y', 'k', 'tab:gray'] - def update_and_save(self, metrics, epoch): + def update_and_save(self, metrics, starting_epoch, epoch): for figure_ix in range(len(self.figure_list)): fig = self.figure_list[figure_ix] - detection_monitoring_plot(fig.ax1, metrics, self.exp_name, self.color_palette, epoch, figure_ix, + detection_monitoring_plot(fig.ax1, metrics, starting_epoch, self.exp_name, self.color_palette, epoch, figure_ix, self.separate_values_dict, self.do_validation) fig.savefig(self.file_name + '_{}'.format(figure_ix)) -def detection_monitoring_plot(ax1, metrics, exp_name, color_palette, epoch, figure_ix, separate_values_dict, do_validation): +def detection_monitoring_plot(ax1, metrics, starting_epoch, exp_name, color_palette, epoch, figure_ix, separate_values_dict, do_validation): monitor_values_keys = metrics['train']['monitor_values'][1][0].keys() separate_values = [v for fig_ix in separate_values_dict.values() for v in fig_ix] @@ -217,7 +217,7 @@ def detection_monitoring_plot(ax1, metrics, exp_name, color_palette, epoch, figu if do_validation: ax1.plot(x, y_val, label='val_{}'.format(pk), linestyle='-', color=color_palette[kix]) - if epoch == 1: + if epoch == starting_epoch: box = ax1.get_position() ax1.set_position([box.x0, box.y0, box.width * 0.8, box.height]) ax1.legend(loc='center left', bbox_to_anchor=(1, 0.5))