|
| 1 | +#!/usr/bin/env python |
| 2 | +# -*- encoding: utf-8 -*- |
| 3 | +""" |
| 4 | +@Author : Qingping Zheng |
| 5 | + |
| 6 | +@File : datasets.py |
| 7 | +@Time : 10/01/21 00:00 PM |
| 8 | +@Desc : |
| 9 | +@License : Licensed under the Apache License, Version 2.0 (the "License"); |
| 10 | +@Copyright : Copyright 2015 The Authors. All Rights Reserved. |
| 11 | +""" |
| 12 | +from __future__ import absolute_import |
| 13 | +from __future__ import division |
| 14 | +from __future__ import print_function |
| 15 | + |
| 16 | +import argparse |
| 17 | +import cv2 |
| 18 | +import numpy as np |
| 19 | +import matplotlib.pyplot as plt |
| 20 | +import os |
| 21 | +import torch |
| 22 | +import torch.backends.cudnn as cudnn |
| 23 | +import torchvision.transforms as transforms |
| 24 | + |
| 25 | +from copy import deepcopy |
| 26 | +from inplace_abn import InPlaceABN |
| 27 | +from dataset import datasets |
| 28 | +from networks import dml_csr |
| 29 | +from utils import miou |
| 30 | + |
| 31 | +torch.multiprocessing.set_start_method("spawn", force=True) |
| 32 | + |
| 33 | +DATA_DIRECTORY = './datasets/Helen' |
| 34 | +IGNORE_LABEL = 255 |
| 35 | +NUM_CLASSES = 20 |
| 36 | +SNAPSHOT_DIR = './snapshots/' |
| 37 | +INPUT_SIZE = (473,473) |
| 38 | + |
| 39 | + |
| 40 | +def get_arguments(): |
| 41 | + """Parse all the arguments provided from the CLI. |
| 42 | + |
| 43 | + Returns: |
| 44 | + A list of parsed arguments. |
| 45 | + """ |
| 46 | + parser = argparse.ArgumentParser(description="DML_CSR Network") |
| 47 | + parser.add_argument("--batch-size", type=int, default=1, |
| 48 | + help="Number of images sent to the network in one step.") |
| 49 | + parser.add_argument("--data-dir", type=str, default=DATA_DIRECTORY, |
| 50 | + help="Path to the directory containing the PASCAL VOC dataset.") |
| 51 | + parser.add_argument("--out-dir", type=str, default=DATA_DIRECTORY, |
| 52 | + help="Path to the directory containing the PASCAL VOC dataset.") |
| 53 | + parser.add_argument("--dataset", type=str, default='val', |
| 54 | + help="Path to the file listing the images in the dataset.") |
| 55 | + parser.add_argument("--ignore-label", type=int, default=IGNORE_LABEL, |
| 56 | + help="The index of the label to ignore during the training.") |
| 57 | + parser.add_argument("--num-classes", type=int, default=NUM_CLASSES, |
| 58 | + help="Number of classes to predict (including background).") |
| 59 | + parser.add_argument("--restore-from", type=str, |
| 60 | + help="Where restore model parameters from.") |
| 61 | + parser.add_argument("--gpu", type=str, default='7', |
| 62 | + help="choose gpu device.") |
| 63 | + parser.add_argument("--input-size", type=str, default=INPUT_SIZE, |
| 64 | + help="Comma-separated string with height and width of images.") |
| 65 | + parser.add_argument("--local_rank", type=int, default=0, |
| 66 | + help="choose gpu numbers") |
| 67 | + parser.add_argument('--dist-backend', default='nccl', type=str, |
| 68 | + help='distributed backend') |
| 69 | + parser.add_argument("--model_type", type=int, default=0, |
| 70 | + help="choose model type") |
| 71 | + return parser.parse_args() |
| 72 | + |
| 73 | + |
| 74 | +def valid(model, valloader, input_size, num_samples, dir=None, dir_edge=None, dir_img=None): |
| 75 | + |
| 76 | + height = input_size[0] |
| 77 | + width = input_size[1] |
| 78 | + with torch.autograd.profiler.profile(enabled=True, use_cuda=True, \ |
| 79 | + record_shapes=False, profile_memory=False) as prof: |
| 80 | + model.eval() |
| 81 | + parsing_preds = np.zeros((num_samples, height, width), dtype=np.uint8) |
| 82 | + scales = np.zeros((num_samples, 2), dtype=np.float32) |
| 83 | + centers = np.zeros((num_samples, 2), dtype=np.int32) |
| 84 | + |
| 85 | + idx = 0 |
| 86 | + interp = torch.nn.Upsample(size=(height, width), mode='bilinear', align_corners=True) |
| 87 | + |
| 88 | + with torch.no_grad(): |
| 89 | + for index, batch in enumerate(valloader): |
| 90 | + image, meta = batch |
| 91 | + num_images = image.size(0) |
| 92 | + if index % 10 == 0: |
| 93 | + print('%d processd' % (index * num_images)) |
| 94 | + |
| 95 | + c = meta['center'].numpy() |
| 96 | + s = meta['scale'].numpy() |
| 97 | + scales[idx:idx + num_images, :] = s[:, :] |
| 98 | + centers[idx:idx + num_images, :] = c[:, :] |
| 99 | + |
| 100 | + results = model(image.cuda()) |
| 101 | + outputs = results |
| 102 | + |
| 103 | + if isinstance(results, list): |
| 104 | + outputs = results[0] |
| 105 | + |
| 106 | + if isinstance(outputs, list): |
| 107 | + for k, output in enumerate(outputs): |
| 108 | + parsing = output |
| 109 | + nums = len(parsing) |
| 110 | + parsing = interp(parsing).data.cpu().numpy() |
| 111 | + parsing = parsing.transpose(0, 2, 3, 1) # NCHW NHWC |
| 112 | + parsing_preds[idx:idx + nums, :, :] = np.asarray(np.argmax(parsing, axis=3), dtype=np.uint8) |
| 113 | + idx += nums |
| 114 | + else: |
| 115 | + parsing = outputs |
| 116 | + parsing = interp(parsing).data.cpu().numpy() |
| 117 | + parsing = parsing.transpose(0, 2, 3, 1) # NCHW NHWC |
| 118 | + parsing_preds[idx:idx + num_images, :, :] = np.asarray(np.argmax(parsing, axis=3), dtype=np.uint8) |
| 119 | + |
| 120 | + if dir is not None: |
| 121 | + for i in range(len(meta['name'])): |
| 122 | + cv2.imwrite(os.path.join(dir, meta['name'][i] + '.png'), np.asarray(np.argmax(parsing, axis=3))[i]) |
| 123 | + idx += num_images |
| 124 | + parsing_preds = parsing_preds[:num_samples, :, :] |
| 125 | + |
| 126 | + return parsing_preds, scales, centers |
| 127 | + |
| 128 | + |
| 129 | +def main(): |
| 130 | + """Create the model and start the evaluation process.""" |
| 131 | + |
| 132 | + args = get_arguments() |
| 133 | + |
| 134 | + os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu |
| 135 | + gpus = [int(i) for i in args.gpu.split(',')] |
| 136 | + |
| 137 | + print(args.gpu) |
| 138 | + |
| 139 | + h, w = map(int, args.input_size.split(',')) |
| 140 | + |
| 141 | + input_size = (h, w) |
| 142 | + |
| 143 | + cudnn.benchmark = True |
| 144 | + cudnn.enabled = True |
| 145 | + |
| 146 | + model = dml_csr.DML_CSR(args.num_classes, InPlaceABN, False) |
| 147 | + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], |
| 148 | + std=[0.229, 0.224, 0.225]) |
| 149 | + |
| 150 | + transform = transforms.Compose([ |
| 151 | + transforms.ToTensor(), |
| 152 | + normalize, |
| 153 | + ]) |
| 154 | + |
| 155 | + dataset = datasets.FaceDataSet(args.data_dir, args.dataset, \ |
| 156 | + crop_size=input_size, transform=transform) |
| 157 | + num_samples = len(dataset) |
| 158 | + |
| 159 | + valloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, \ |
| 160 | + shuffle=False, pin_memory=True) |
| 161 | + |
| 162 | + restore_from = args.restore_from |
| 163 | + print(restore_from) |
| 164 | + state_dict = torch.load(restore_from,map_location='cuda:0') |
| 165 | + model.load_state_dict(state_dict) |
| 166 | + |
| 167 | + model.cuda() |
| 168 | + model.eval() |
| 169 | + |
| 170 | + save_path = os.path.join(args.out_dir, args.dataset, 'parsing') |
| 171 | + if not os.path.exists(save_path): |
| 172 | + os.makedirs(save_path) |
| 173 | + |
| 174 | + parsing_preds, scales, centers = valid(model, valloader, input_size, num_samples, save_path) |
| 175 | + mIoU, f1 = miou.compute_mean_ioU(parsing_preds, scales, centers, args.num_classes, args.data_dir, \ |
| 176 | + input_size, args.dataset, reverse=True) |
| 177 | + |
| 178 | + print(mIoU) |
| 179 | + print(f1) |
| 180 | + |
| 181 | +if __name__ == '__main__': |
| 182 | + main() |
0 commit comments