|
| 1 | +# Copyright (c) 2017, Udacity |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# Redistribution and use in source and binary forms, with or without |
| 5 | +# modification, are permitted provided that the following conditions are met: |
| 6 | +# |
| 7 | +# 1. Redistributions of source code must retain the above copyright notice, this |
| 8 | +# list of conditions and the following disclaimer. |
| 9 | +# 2. Redistributions in binary form must reproduce the above copyright notice, |
| 10 | +# this list of conditions and the following disclaimer in the documentation |
| 11 | +# and/or other materials provided with the distribution. |
| 12 | +# |
| 13 | +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND |
| 14 | +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED |
| 15 | +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE |
| 16 | +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR |
| 17 | +# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES |
| 18 | +# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; |
| 19 | +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND |
| 20 | +# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT |
| 21 | +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS |
| 22 | +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
| 23 | +# |
| 24 | +# The views and conclusions contained in the software and documentation are those |
| 25 | +# of the authors and should not be interpreted as representing official policies, |
| 26 | +# either expressed or implied, of the FreeBSD Project |
| 27 | + |
| 28 | +# Author: Devin Anzelmo |
| 29 | + |
| 30 | +import os |
| 31 | +import glob |
| 32 | +import numpy as np |
| 33 | +import matplotlib.patches as mpatches |
| 34 | +import matplotlib.pyplot as plt |
| 35 | +from tensorflow.contrib.keras.python import keras |
| 36 | +from scipy import misc |
| 37 | + |
| 38 | +def make_dir_if_not_exist(path): |
| 39 | + if not os.path.exists(path): |
| 40 | + os.makedirs(path) |
| 41 | + |
| 42 | +def show(im, x=5, y=5): |
| 43 | + plt.figure(figsize=(x,y)) |
| 44 | + plt.imshow(im) |
| 45 | + plt.show() |
| 46 | + |
| 47 | +def show_images(maybe_ims, x=4, y=4): |
| 48 | + if isinstance(maybe_ims, (list, tuple)): |
| 49 | + new_im = np.concatenate(maybe_ims, axis=1) |
| 50 | + show(new_im, len(maybe_ims)*x, y) |
| 51 | + else: |
| 52 | + show(maybe_ims) |
| 53 | + |
| 54 | +# helpers for loading a few images from the grading data |
| 55 | +def get_im_files(path, subset_name): |
| 56 | + return sorted(glob.glob(os.path.join(path, subset_name, 'images', '*.jpeg'))) |
| 57 | + |
| 58 | +def get_mask_files(path, subset_name): |
| 59 | + return sorted(glob.glob(os.path.join(path, subset_name, 'masks', '*.png'))) |
| 60 | + |
| 61 | +def get_pred_files(subset_name): |
| 62 | + return sorted(glob.glob(os.path.join('..','data', 'runs', subset_name, '*.png'))) |
| 63 | + |
| 64 | +def get_im_file_sample(grading_data_dir_name, subset_name, pred_dir_suffix=None, n_file_names=10): |
| 65 | + path = os.path.join('..', 'data', grading_data_dir_name) |
| 66 | + ims = np.array(get_im_files(path, subset_name)) |
| 67 | + masks = np.array(get_mask_files(path, subset_name)) |
| 68 | + |
| 69 | + shuffed_inds = np.random.permutation(np.arange(masks.shape[0])) |
| 70 | + ims_subset = ims[shuffed_inds[:n_file_names]] |
| 71 | + masks_subset = masks[shuffed_inds[:n_file_names]] |
| 72 | + if not pred_dir_suffix: |
| 73 | + return list(zip(ims_subset, masks_subset)) |
| 74 | + else: |
| 75 | + preds = np.array(get_pred_files(subset_name+'_'+pred_dir_suffix)) |
| 76 | + preds_subset = preds[shuffed_inds[:n_file_names]] |
| 77 | + return list(zip(ims_subset, masks_subset, preds_subset)) |
| 78 | + |
| 79 | +def load_images(file_tuple): |
| 80 | + im = misc.imread(file_tuple[0]) |
| 81 | + mask = misc.imread(file_tuple[1]) |
| 82 | + if len(file_tuple) == 2: |
| 83 | + return im, mask |
| 84 | + else: |
| 85 | + pred = misc.imread(file_tuple[2]) |
| 86 | + if pred.shape[0] != im.shape[0]: |
| 87 | + mask = misc.imresize(mask, pred.shape) |
| 88 | + im = misc.imresize(im, pred.shape) |
| 89 | + return im, mask, pred |
| 90 | + |
| 91 | + |
| 92 | +def plot_keras_model(model, fig_name): |
| 93 | + base_path = os.path.join('..', 'data', 'figures') |
| 94 | + make_dir_if_not_exist(base_path) |
| 95 | + keras.utils.vis_utils.plot_model(model, os.path.join(base_path, fig_name)) |
| 96 | + keras.utils.vis_utils.plot_model(model, os.path.join(base_path, fig_name +'_with_shapes'), show_shapes=True) |
| 97 | + |
| 98 | + |
| 99 | +def train_val_curve(train_loss, val_loss=None): |
| 100 | + train_line = plt.plot(train_loss, label='train_loss') |
| 101 | + train_patch = mpatches.Patch(color='blue',label='train_loss') |
| 102 | + handles = [train_patch] |
| 103 | + if val_loss: |
| 104 | + val_line = plt.plot(val_loss, label='val_loss') |
| 105 | + val_patch = mpatches.Patch(color='orange',label='val_loss') |
| 106 | + handles.append(val_patch) |
| 107 | + |
| 108 | + plt.legend(handles=handles, loc=2) |
| 109 | + plt.title('training curves') |
| 110 | + plt.ylabel('loss') |
| 111 | + plt.xlabel('epochs') |
| 112 | + plt.show() |
| 113 | + |
| 114 | +# modified from the BaseLogger in file linked below |
| 115 | +# https://github.com/fchollet/keras/blob/master/keras/callbacks.py |
| 116 | +class LoggerPlotter(keras.callbacks.Callback): |
| 117 | + """Callback that accumulates epoch averages of metrics. |
| 118 | + and plots train and validation curves on end of epoch |
| 119 | + """ |
| 120 | + def __init__(self): |
| 121 | + self.hist_dict = {'loss':[], 'val_loss':[]} |
| 122 | + |
| 123 | + def on_epoch_begin(self, epoch, logs=None): |
| 124 | + self.seen = 0 |
| 125 | + self.totals = {} |
| 126 | + |
| 127 | + def on_batch_end(self, batch, logs=None): |
| 128 | + logs = logs or {} |
| 129 | + batch_size = logs.get('size', 0) |
| 130 | + self.seen += batch_size |
| 131 | + |
| 132 | + for k, v in logs.items(): |
| 133 | + if k in self.totals: |
| 134 | + self.totals[k] += v * batch_size |
| 135 | + else: |
| 136 | + self.totals[k] = v * batch_size |
| 137 | + |
| 138 | + |
| 139 | + def on_epoch_end(self, epoch, logs=None): |
| 140 | + if logs is not None: |
| 141 | + for k in self.params['metrics']: |
| 142 | + if k in self.totals: |
| 143 | + # Make value available to next callbacks. |
| 144 | + logs[k] = self.totals[k] / self.seen |
| 145 | + |
| 146 | + self.hist_dict['loss'].append(logs['loss']) |
| 147 | + if 'val_loss' in self.params['metrics']: |
| 148 | + self.hist_dict['val_loss'].append(logs['val_loss']) |
| 149 | + train_val_curve(self.hist_dict['loss'], self.hist_dict['val_loss']) |
| 150 | + else: |
| 151 | + train_val_curve(self.hist_dict['loss']) |
0 commit comments