Skip to content

Commit af2d71f

Browse files
Create test.py
1 parent 9c78982 commit af2d71f

File tree

1 file changed

+182
-0
lines changed

1 file changed

+182
-0
lines changed

parsing/dml_csr/test.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
#!/usr/bin/env python
2+
# -*- encoding: utf-8 -*-
3+
"""
4+
@Author : Qingping Zheng
5+
6+
@File : datasets.py
7+
@Time : 10/01/21 00:00 PM
8+
@Desc :
9+
@License : Licensed under the Apache License, Version 2.0 (the "License");
10+
@Copyright : Copyright 2015 The Authors. All Rights Reserved.
11+
"""
12+
from __future__ import absolute_import
13+
from __future__ import division
14+
from __future__ import print_function
15+
16+
import argparse
17+
import cv2
18+
import numpy as np
19+
import matplotlib.pyplot as plt
20+
import os
21+
import torch
22+
import torch.backends.cudnn as cudnn
23+
import torchvision.transforms as transforms
24+
25+
from copy import deepcopy
26+
from inplace_abn import InPlaceABN
27+
from dataset import datasets
28+
from networks import dml_csr
29+
from utils import miou
30+
31+
torch.multiprocessing.set_start_method("spawn", force=True)
32+
33+
DATA_DIRECTORY = './datasets/Helen'
34+
IGNORE_LABEL = 255
35+
NUM_CLASSES = 20
36+
SNAPSHOT_DIR = './snapshots/'
37+
INPUT_SIZE = (473,473)
38+
39+
40+
def get_arguments():
41+
"""Parse all the arguments provided from the CLI.
42+
43+
Returns:
44+
A list of parsed arguments.
45+
"""
46+
parser = argparse.ArgumentParser(description="DML_CSR Network")
47+
parser.add_argument("--batch-size", type=int, default=1,
48+
help="Number of images sent to the network in one step.")
49+
parser.add_argument("--data-dir", type=str, default=DATA_DIRECTORY,
50+
help="Path to the directory containing the PASCAL VOC dataset.")
51+
parser.add_argument("--out-dir", type=str, default=DATA_DIRECTORY,
52+
help="Path to the directory containing the PASCAL VOC dataset.")
53+
parser.add_argument("--dataset", type=str, default='val',
54+
help="Path to the file listing the images in the dataset.")
55+
parser.add_argument("--ignore-label", type=int, default=IGNORE_LABEL,
56+
help="The index of the label to ignore during the training.")
57+
parser.add_argument("--num-classes", type=int, default=NUM_CLASSES,
58+
help="Number of classes to predict (including background).")
59+
parser.add_argument("--restore-from", type=str,
60+
help="Where restore model parameters from.")
61+
parser.add_argument("--gpu", type=str, default='7',
62+
help="choose gpu device.")
63+
parser.add_argument("--input-size", type=str, default=INPUT_SIZE,
64+
help="Comma-separated string with height and width of images.")
65+
parser.add_argument("--local_rank", type=int, default=0,
66+
help="choose gpu numbers")
67+
parser.add_argument('--dist-backend', default='nccl', type=str,
68+
help='distributed backend')
69+
parser.add_argument("--model_type", type=int, default=0,
70+
help="choose model type")
71+
return parser.parse_args()
72+
73+
74+
def valid(model, valloader, input_size, num_samples, dir=None, dir_edge=None, dir_img=None):
75+
76+
height = input_size[0]
77+
width = input_size[1]
78+
with torch.autograd.profiler.profile(enabled=True, use_cuda=True, \
79+
record_shapes=False, profile_memory=False) as prof:
80+
model.eval()
81+
parsing_preds = np.zeros((num_samples, height, width), dtype=np.uint8)
82+
scales = np.zeros((num_samples, 2), dtype=np.float32)
83+
centers = np.zeros((num_samples, 2), dtype=np.int32)
84+
85+
idx = 0
86+
interp = torch.nn.Upsample(size=(height, width), mode='bilinear', align_corners=True)
87+
88+
with torch.no_grad():
89+
for index, batch in enumerate(valloader):
90+
image, meta = batch
91+
num_images = image.size(0)
92+
if index % 10 == 0:
93+
print('%d processd' % (index * num_images))
94+
95+
c = meta['center'].numpy()
96+
s = meta['scale'].numpy()
97+
scales[idx:idx + num_images, :] = s[:, :]
98+
centers[idx:idx + num_images, :] = c[:, :]
99+
100+
results = model(image.cuda())
101+
outputs = results
102+
103+
if isinstance(results, list):
104+
outputs = results[0]
105+
106+
if isinstance(outputs, list):
107+
for k, output in enumerate(outputs):
108+
parsing = output
109+
nums = len(parsing)
110+
parsing = interp(parsing).data.cpu().numpy()
111+
parsing = parsing.transpose(0, 2, 3, 1) # NCHW NHWC
112+
parsing_preds[idx:idx + nums, :, :] = np.asarray(np.argmax(parsing, axis=3), dtype=np.uint8)
113+
idx += nums
114+
else:
115+
parsing = outputs
116+
parsing = interp(parsing).data.cpu().numpy()
117+
parsing = parsing.transpose(0, 2, 3, 1) # NCHW NHWC
118+
parsing_preds[idx:idx + num_images, :, :] = np.asarray(np.argmax(parsing, axis=3), dtype=np.uint8)
119+
120+
if dir is not None:
121+
for i in range(len(meta['name'])):
122+
cv2.imwrite(os.path.join(dir, meta['name'][i] + '.png'), np.asarray(np.argmax(parsing, axis=3))[i])
123+
idx += num_images
124+
parsing_preds = parsing_preds[:num_samples, :, :]
125+
126+
return parsing_preds, scales, centers
127+
128+
129+
def main():
130+
"""Create the model and start the evaluation process."""
131+
132+
args = get_arguments()
133+
134+
os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu
135+
gpus = [int(i) for i in args.gpu.split(',')]
136+
137+
print(args.gpu)
138+
139+
h, w = map(int, args.input_size.split(','))
140+
141+
input_size = (h, w)
142+
143+
cudnn.benchmark = True
144+
cudnn.enabled = True
145+
146+
model = dml_csr.DML_CSR(args.num_classes, InPlaceABN, False)
147+
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
148+
std=[0.229, 0.224, 0.225])
149+
150+
transform = transforms.Compose([
151+
transforms.ToTensor(),
152+
normalize,
153+
])
154+
155+
dataset = datasets.FaceDataSet(args.data_dir, args.dataset, \
156+
crop_size=input_size, transform=transform)
157+
num_samples = len(dataset)
158+
159+
valloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, \
160+
shuffle=False, pin_memory=True)
161+
162+
restore_from = args.restore_from
163+
print(restore_from)
164+
state_dict = torch.load(restore_from,map_location='cuda:0')
165+
model.load_state_dict(state_dict)
166+
167+
model.cuda()
168+
model.eval()
169+
170+
save_path = os.path.join(args.out_dir, args.dataset, 'parsing')
171+
if not os.path.exists(save_path):
172+
os.makedirs(save_path)
173+
174+
parsing_preds, scales, centers = valid(model, valloader, input_size, num_samples, save_path)
175+
mIoU, f1 = miou.compute_mean_ioU(parsing_preds, scales, centers, args.num_classes, args.data_dir, \
176+
input_size, args.dataset, reverse=True)
177+
178+
print(mIoU)
179+
print(f1)
180+
181+
if __name__ == '__main__':
182+
main()

0 commit comments

Comments
 (0)