forked from eragonruan/refinenet-image-segmentation
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 5363308
Showing
52 changed files
with
3,284 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
*.ckpt | ||
*.tfrecords | ||
checkpoints/ | ||
logs/ | ||
|
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
import re | ||
import cv2 | ||
import time | ||
import os,shutil | ||
import numpy as np | ||
import tensorflow as tf | ||
slim = tf.contrib.slim | ||
|
||
import model as model | ||
from matplotlib import pyplot as plt | ||
from utils.pascal_voc import pascal_segmentation_lut | ||
from utils.visualization import visualize_segmentation_adaptive | ||
|
||
# Command-line flags for the demo script. Every flag now carries a help
# string so that `--help` describes it.
tf.app.flags.DEFINE_string('test_data_path', 'demo',
                           'Directory searched recursively for input images.')
tf.app.flags.DEFINE_string('gpu_list', '0',
                           'Value exported as CUDA_VISIBLE_DEVICES.')
# NOTE(review): num_classes is not read anywhere in this script.
tf.app.flags.DEFINE_integer('num_classes', 21,
                            'Number of segmentation classes.')
tf.app.flags.DEFINE_string('checkpoint_path', 'checkpoints/',
                           'Directory holding the checkpoint to restore.')
tf.app.flags.DEFINE_string('result_path', 'result/',
                           'Output directory; deleted and recreated on every run.')

FLAGS = tf.app.flags.FLAGS
|
||
|
||
def get_images():
    """Collect the image files found under FLAGS.test_data_path.

    The directory tree is walked recursively and every file whose name ends
    with one of the known image suffixes is kept.

    Returns:
        A list of file paths (parent directory joined with the file name).
    """
    # str.endswith accepts a tuple, so one call replaces the per-suffix loop.
    exts = ('jpg', 'png', 'jpeg', 'JPG')
    files = []
    for parent, _dirnames, filenames in os.walk(FLAGS.test_data_path):
        for filename in filenames:
            if filename.endswith(exts):
                files.append(os.path.join(parent, filename))
    # Function-call form of print: valid on both Python 2 and Python 3.
    print('Find {} images'.format(len(files)))
    return files
|
||
def resize_image(im, size=32, max_side_len=2400):
    """Resize an image so that both sides are multiples of `size`.

    The image is first scaled down (aspect ratio kept) when its longest side
    exceeds `max_side_len`, then each side is rounded down to a multiple of
    `size`. A side is never allowed to fall below `size`, so small inputs do
    not end up with a zero-sized dimension.

    Args:
        im: image array whose first two axes are height and width.
        size: both output sides are multiples of this value.
        max_side_len: upper bound for the longest side before rounding.

    Returns:
        A tuple `(resized_image, (ratio_h, ratio_w))` where the ratios are
        the output side divided by the corresponding input side.
    """
    # Only the first two axes are needed; this also accepts 2-D images.
    h, w = im.shape[:2]

    longest_side = max(h, w)
    if longest_side > max_side_len:
        ratio = float(max_side_len) / longest_side
    else:
        ratio = 1.

    resize_h = int(h * ratio)
    resize_w = int(w * ratio)

    # Round down to a multiple of `size`, but keep at least one full step:
    # rounding a side shorter than `size` would otherwise give 0 and make
    # cv2.resize fail.
    resize_h = max(size, (resize_h // size) * size)
    resize_w = max(size, (resize_w // size) * size)

    # cv2.resize takes the target size as (width, height).
    im = cv2.resize(im, (int(resize_w), int(resize_h)))

    ratio_h = resize_h / float(h)
    ratio_w = resize_w / float(w)
    return im, (ratio_h, ratio_w)
|
||
def main(argv=None):
    """Run segmentation on every image in FLAGS.test_data_path.

    The result directory is wiped and recreated, the latest checkpoint found
    in FLAGS.checkpoint_path is restored, and for each input image a
    visualised prediction is written to FLAGS.result_path under the same
    base file name.

    Args:
        argv: unused; present because tf.app.run() passes the argument list.

    Raises:
        IOError: if no checkpoint can be found in FLAGS.checkpoint_path.
    """
    # Start from an empty result directory on every run.
    if os.path.exists(FLAGS.result_path):
        shutil.rmtree(FLAGS.result_path)
    os.makedirs(FLAGS.result_path)

    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list
    pascal_voc_lut = pascal_segmentation_lut()

    with tf.get_default_graph().as_default():
        input_images = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='input_images')
        global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False)

        logits = model.model(input_images, is_training=False)
        pred = tf.argmax(logits, dimension=3)

        # Restore through the moving-average mapping of variable names.
        variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step)
        saver = tf.train.Saver(variable_averages.variables_to_restore())

        with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
            ckpt_state = tf.train.get_checkpoint_state(FLAGS.checkpoint_path)
            # get_checkpoint_state returns None when the directory holds no
            # checkpoint; fail with a clear message instead of an
            # AttributeError on the next line.
            if ckpt_state is None or not ckpt_state.model_checkpoint_path:
                raise IOError('No checkpoint found in {}'.format(FLAGS.checkpoint_path))
            model_path = os.path.join(FLAGS.checkpoint_path, os.path.basename(ckpt_state.model_checkpoint_path))
            print('Restore from {}'.format(model_path))
            saver.restore(sess, model_path)

            im_fn_list = get_images()
            for im_fn in im_fn_list:
                im = cv2.imread(im_fn)
                # cv2.imread returns None for unreadable files; skip them
                # rather than crash on the slice below.
                if im is None:
                    print('Skip unreadable image {}'.format(im_fn))
                    continue
                # Reverse the channel axis of the loaded image.
                im = im[:, :, ::-1]
                im_resized, _ = resize_image(im, size=32)

                start = time.time()
                pred_re = sess.run([pred], feed_dict={input_images: [im_resized]})
                pred_re = np.array(np.squeeze(pred_re))

                img = visualize_segmentation_adaptive(pred_re, pascal_voc_lut)
                _diff_time = time.time() - start
                cv2.imwrite(os.path.join(FLAGS.result_path, os.path.basename(im_fn)), img)

                # .format must be applied to the string, not to the return
                # value of print() (which is None on Python 3).
                print('{}: cost {:.0f}ms'.format(im_fn, _diff_time * 1000))


if __name__ == '__main__':
    tf.app.run()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
import tensorflow as tf | ||
from tensorflow.contrib import slim | ||
from nets import resnet_v1 | ||
from utils.training import get_valid_logits_and_labels | ||
FLAGS = tf.app.flags.FLAGS | ||
|
||
def unpool(inputs, scale):
    """Bilinearly resize `inputs` so that axes 1 and 2 grow by `scale`.

    The target size is computed from the dynamic shape, so inputs whose
    spatial size is unknown at graph-construction time are supported.
    """
    new_height = tf.shape(inputs)[1] * scale
    new_width = tf.shape(inputs)[2] * scale
    return tf.image.resize_bilinear(inputs, size=[new_height, new_width])
|
||
|
||
def ResidualConvUnit(inputs, features=256, kernel_size=3):
    """Apply two (ReLU -> conv) stages and add the result to `inputs`.

    Args:
        inputs: input feature tensor; it is also the residual shortcut.
        features: number of output channels of each convolution.
        kernel_size: kernel size of each convolution.

    Returns:
        The element-wise sum of the convolved branch and `inputs`.
    """
    branch = inputs
    # Two identical stages: ReLU followed by a convolution.
    for _ in range(2):
        branch = slim.conv2d(tf.nn.relu(branch), features, kernel_size)
    return tf.add(branch, inputs)
|
||
def MultiResolutionFusion(high_inputs=None, low_inputs=None, up0=2, up1=1, n_i=256):
    """Fuse one or two feature maps after a 3x3 conv and upsampling.

    Args:
        high_inputs: first feature map; convolved and resized by `up0`.
        low_inputs: optional second feature map; convolved and resized by
            `up1`. When None, only the first branch is returned.
        up0: resize factor for the first branch.
        up1: resize factor for the second branch.
        n_i: number of output channels of both convolutions.

    Returns:
        The first branch alone, or the element-wise sum of both branches.
    """
    high_branch = slim.conv2d(high_inputs, n_i, 3)
    high_branch = unpool(high_branch, scale=up0)

    if low_inputs is not None:
        low_branch = slim.conv2d(low_inputs, n_i, 3)
        low_branch = unpool(low_branch, scale=up1)
        return tf.add(high_branch, low_branch)

    return high_branch
|
||
def ChainedResidualPooling(inputs, n_i=256):
    """ReLU the input, then add a (5x5 max-pool -> 3x3 conv) branch to it.

    Args:
        inputs: input feature tensor.
        n_i: number of output channels of the convolution.

    Returns:
        The sum of the pooled-and-convolved branch and the ReLU'd input.
    """
    activated = tf.nn.relu(inputs)
    # Stride 1 with SAME padding keeps the spatial size unchanged.
    pooled = slim.max_pool2d(activated, [5, 5], stride=1, padding='SAME')
    return tf.add(slim.conv2d(pooled, n_i, 3), activated)
|
||
def RefineBlock(high_inputs=None, low_inputs=None):
    """Build one refinement block from one or two feature maps.

    Each given input goes through a ResidualConvUnit, the results are fused
    by MultiResolutionFusion, and the fused map passes through
    ChainedResidualPooling and a final ResidualConvUnit.

    Args:
        high_inputs: first feature map. It is upsampled by 2 during fusion
            when `low_inputs` is given and left at its size otherwise.
        low_inputs: optional second feature map, fused without upsampling.

    Returns:
        The output tensor of the final ResidualConvUnit.
    """
    if low_inputs is None:
        # Single-input block: nothing to fuse with, so no upsampling.
        refined = ResidualConvUnit(high_inputs, features=256)
        fused = MultiResolutionFusion(refined, low_inputs=None, up0=1, n_i=256)
    else:
        print(high_inputs.shape)
        refined_high = ResidualConvUnit(high_inputs, features=256)
        refined_low = ResidualConvUnit(low_inputs, features=256)
        fused = MultiResolutionFusion(refined_high, refined_low, up0=2, up1=1, n_i=256)

    # Both variants share the same tail.
    pooled = ChainedResidualPooling(fused, n_i=256)
    return ResidualConvUnit(pooled, features=256)
|
||
|
||
def model(images, weight_decay=1e-5, is_training=True, num_classes=21):
    """Build the segmentation network on top of a ResNet-101 backbone.

    The backbone end points 'pool5', 'pool4', 'pool3' and 'pool2' are each
    projected to 256 channels with a 1x1 convolution and then merged by a
    cascade of four RefineBlocks. The last block's output is upsampled by 4
    and mapped to `num_classes` channels with a 1x1 convolution.

    Args:
        images: input image batch; it is mean-subtracted before the backbone.
        weight_decay: L2 regularisation weight for the backbone and the
            fusion convolutions.
        is_training: forwarded to the backbone and to batch normalisation.
        num_classes: number of channels of the returned score map.
            Defaults to 21, the previously hard-coded value.

    Returns:
        The score tensor produced by the final 1x1 convolution.
    """
    images = mean_image_subtraction(images)

    with slim.arg_scope(resnet_v1.resnet_arg_scope(weight_decay=weight_decay)):
        # Only the intermediate end points are used; the backbone's own
        # output is discarded.
        _, end_points = resnet_v1.resnet_v1_101(images, is_training=is_training, scope='resnet_v1_101')

    # Pass the end-point tensors themselves (the bound method
    # `end_points.values` carried no tensors at all).
    with tf.variable_scope('feature_fusion', values=list(end_points.values())):
        batch_norm_params = {
            'decay': 0.997,
            'epsilon': 1e-5,
            'scale': True,
            'is_training': is_training
        }
        with slim.arg_scope([slim.conv2d],
                            activation_fn=tf.nn.relu,
                            normalizer_fn=slim.batch_norm,
                            normalizer_params=batch_norm_params,
                            weights_regularizer=slim.l2_regularizer(weight_decay)):
            f = [end_points['pool5'], end_points['pool4'],
                 end_points['pool3'], end_points['pool2']]
            for i in range(4):
                print('Shape of f_{} {}'.format(i, f[i].shape))

            g = [None, None, None, None]
            h = [None, None, None, None]

            # 1x1 projections of every backbone feature map to 256 channels.
            for i in range(4):
                h[i] = slim.conv2d(f[i], 256, 1)
            for i in range(4):
                print('Shape of h_{} {}'.format(i, h[i].shape))

            # Cascade: each block fuses the previous output with the next
            # projected feature map.
            g[0] = RefineBlock(h[0])
            g[1] = RefineBlock(g[0], h[1])
            g[2] = RefineBlock(g[1], h[2])
            g[3] = RefineBlock(g[2], h[3])
            g[3] = unpool(g[3], scale=4)
            # NOTE(review): the score map passes through a ReLU although it
            # is consumed as logits elsewhere; left unchanged because
            # existing checkpoints were presumably trained this way.
            F_score = slim.conv2d(g[3], num_classes, 1, activation_fn=tf.nn.relu, normalizer_fn=None)

    return F_score
|
||
|
||
def mean_image_subtraction(images, means=(123.68, 116.78, 103.94)):
    """Subtract a per-channel mean from a batch of images.

    Args:
        images: tensor whose last (fourth) axis is the channel axis and has
            a statically known size.
        means: one mean value per channel. The default is an immutable
            tuple so it cannot be altered between calls.

    Returns:
        A tensor of the same shape with `means[i]` subtracted from channel i.

    Raises:
        ValueError: if `len(means)` differs from the number of channels.
    """
    num_channels = images.get_shape().as_list()[-1]
    if len(means) != num_channels:
        raise ValueError('len(means) must match the number of channels')
    channels = tf.split(axis=3, num_or_size_splits=num_channels, value=images)
    centered = [channel - mean for channel, mean in zip(channels, means)]
    return tf.concat(axis=3, values=centered)
|
||
def loss(annotation_batch, upsampled_logits_batch, class_labels):
    """Compute the mean softmax cross-entropy and log it as a summary.

    Labels and logits are first filtered by get_valid_logits_and_labels
    (presumably dropping pixels that should not contribute -- confirm
    against utils.training).

    Args:
        annotation_batch: annotation tensor passed to the filter helper.
        upsampled_logits_batch: logits tensor passed to the filter helper.
        class_labels: class label list passed to the filter helper.

    Returns:
        A scalar tensor: the mean cross-entropy over the valid entries.
    """
    valid_labels, valid_logits = get_valid_logits_and_labels(
        annotation_batch_tensor=annotation_batch,
        logits_batch_tensor=upsampled_logits_batch,
        class_labels=class_labels)

    per_entry_loss = tf.nn.softmax_cross_entropy_with_logits(
        logits=valid_logits,
        labels=valid_labels)

    mean_cross_entropy = tf.reduce_mean(per_entry_loss)
    tf.summary.scalar('cross_entropy_loss', mean_cross_entropy)
    return mean_cross_entropy
|
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
import tensorflow as tf | ||
from tensorflow.contrib import slim | ||
from nets import resnet_v1 | ||
from utils.training import get_valid_logits_and_labels | ||
FLAGS = tf.app.flags.FLAGS | ||
|
||
def unpool(inputs, scale):
    """Bilinearly resize `inputs` so that axes 1 and 2 grow by `scale`.

    The target size is computed from the dynamic shape, so inputs whose
    spatial size is unknown at graph-construction time are supported.
    """
    new_height = tf.shape(inputs)[1] * scale
    new_width = tf.shape(inputs)[2] * scale
    return tf.image.resize_bilinear(inputs, size=[new_height, new_width])
|
||
|
||
def ResidualConvUnit(inputs, features=256, kernel_size=3):
    """Apply two (ReLU -> conv) stages and add the result to `inputs`.

    Args:
        inputs: input feature tensor; it is also the residual shortcut.
        features: number of output channels of each convolution.
        kernel_size: kernel size of each convolution.

    Returns:
        The element-wise sum of the convolved branch and `inputs`.
    """
    branch = inputs
    # Two identical stages: ReLU followed by a convolution.
    for _ in range(2):
        branch = slim.conv2d(tf.nn.relu(branch), features, kernel_size)
    return tf.add(branch, inputs)
|
||
def MultiResolutionFusion(high_inputs=None, low_inputs=None, up0=2, up1=1, n_i=256):
    """Fuse one or two feature maps after a 3x3 conv and upsampling.

    Args:
        high_inputs: first feature map; convolved and resized by `up0`.
        low_inputs: optional second feature map; convolved and resized by
            `up1`. When None, only the first branch is returned.
        up0: resize factor for the first branch.
        up1: resize factor for the second branch.
        n_i: number of output channels of both convolutions.

    Returns:
        The first branch alone, or the element-wise sum of both branches.
    """
    high_branch = slim.conv2d(high_inputs, n_i, 3)
    high_branch = unpool(high_branch, scale=up0)

    if low_inputs is not None:
        low_branch = slim.conv2d(low_inputs, n_i, 3)
        low_branch = unpool(low_branch, scale=up1)
        return tf.add(high_branch, low_branch)

    return high_branch
|
||
def ChainedResidualPooling(inputs, n_i=256):
    """ReLU the input, then add a (5x5 max-pool -> 3x3 conv) branch to it.

    Args:
        inputs: input feature tensor.
        n_i: number of output channels of the convolution.

    Returns:
        The sum of the pooled-and-convolved branch and the ReLU'd input.
    """
    activated = tf.nn.relu(inputs)
    # Stride 1 with SAME padding keeps the spatial size unchanged.
    pooled = slim.max_pool2d(activated, [5, 5], stride=1, padding='SAME')
    return tf.add(slim.conv2d(pooled, n_i, 3), activated)
|
||
def RefineBlock(high_inputs=None, low_inputs=None):
    """Build one refinement block from one or two feature maps.

    Each given input goes through a ResidualConvUnit, the results are fused
    by MultiResolutionFusion, and the fused map passes through
    ChainedResidualPooling and a final ResidualConvUnit.

    Args:
        high_inputs: first feature map. It is upsampled by 2 during fusion
            when `low_inputs` is given and left at its size otherwise.
        low_inputs: optional second feature map, fused without upsampling.

    Returns:
        The output tensor of the final ResidualConvUnit.
    """
    if low_inputs is None:
        # Single-input block: nothing to fuse with, so no upsampling.
        refined = ResidualConvUnit(high_inputs, features=256)
        fused = MultiResolutionFusion(refined, low_inputs=None, up0=1, n_i=256)
    else:
        print(high_inputs.shape)
        refined_high = ResidualConvUnit(high_inputs, features=256)
        refined_low = ResidualConvUnit(low_inputs, features=256)
        fused = MultiResolutionFusion(refined_high, refined_low, up0=2, up1=1, n_i=256)

    # Both variants share the same tail.
    pooled = ChainedResidualPooling(fused, n_i=256)
    return ResidualConvUnit(pooled, features=256)
|
||
|
||
def model(images, weight_decay=1e-5, is_training=True, num_classes=21):
    """Build the segmentation network on top of a ResNet-101 backbone.

    The backbone end points 'pool5', 'pool4', 'pool3' and 'pool2' are each
    projected to 256 channels with a 1x1 convolution and then merged by a
    cascade of four RefineBlocks. The last block's output is upsampled by 4
    and mapped to `num_classes` channels with a 1x1 convolution.

    Args:
        images: input image batch; it is mean-subtracted before the backbone.
        weight_decay: L2 regularisation weight for the backbone and the
            fusion convolutions.
        is_training: forwarded to the backbone and to batch normalisation.
        num_classes: number of channels of the returned score map.
            Defaults to 21, the previously hard-coded value.

    Returns:
        The score tensor produced by the final 1x1 convolution.
    """
    images = mean_image_subtraction(images)

    with slim.arg_scope(resnet_v1.resnet_arg_scope(weight_decay=weight_decay)):
        # Only the intermediate end points are used; the backbone's own
        # output is discarded.
        _, end_points = resnet_v1.resnet_v1_101(images, is_training=is_training, scope='resnet_v1_101')

    # Pass the end-point tensors themselves (the bound method
    # `end_points.values` carried no tensors at all).
    with tf.variable_scope('feature_fusion', values=list(end_points.values())):
        batch_norm_params = {
            'decay': 0.997,
            'epsilon': 1e-5,
            'scale': True,
            'is_training': is_training
        }
        with slim.arg_scope([slim.conv2d],
                            activation_fn=tf.nn.relu,
                            normalizer_fn=slim.batch_norm,
                            normalizer_params=batch_norm_params,
                            weights_regularizer=slim.l2_regularizer(weight_decay)):
            f = [end_points['pool5'], end_points['pool4'],
                 end_points['pool3'], end_points['pool2']]
            for i in range(4):
                print('Shape of f_{} {}'.format(i, f[i].shape))

            g = [None, None, None, None]
            h = [None, None, None, None]

            # 1x1 projections of every backbone feature map to 256 channels.
            for i in range(4):
                h[i] = slim.conv2d(f[i], 256, 1)
            for i in range(4):
                print('Shape of h_{} {}'.format(i, h[i].shape))

            # Cascade: each block fuses the previous output with the next
            # projected feature map.
            g[0] = RefineBlock(h[0])
            g[1] = RefineBlock(g[0], h[1])
            g[2] = RefineBlock(g[1], h[2])
            g[3] = RefineBlock(g[2], h[3])
            g[3] = unpool(g[3], scale=4)
            # NOTE(review): the score map passes through a ReLU although it
            # is consumed as logits elsewhere; left unchanged because
            # existing checkpoints were presumably trained this way.
            F_score = slim.conv2d(g[3], num_classes, 1, activation_fn=tf.nn.relu, normalizer_fn=None)

    return F_score
|
||
|
||
def mean_image_subtraction(images, means=(123.68, 116.78, 103.94)):
    """Subtract a per-channel mean from a batch of images.

    Args:
        images: tensor whose last (fourth) axis is the channel axis and has
            a statically known size.
        means: one mean value per channel. The default is an immutable
            tuple so it cannot be altered between calls.

    Returns:
        A tensor of the same shape with `means[i]` subtracted from channel i.

    Raises:
        ValueError: if `len(means)` differs from the number of channels.
    """
    num_channels = images.get_shape().as_list()[-1]
    if len(means) != num_channels:
        raise ValueError('len(means) must match the number of channels')
    channels = tf.split(axis=3, num_or_size_splits=num_channels, value=images)
    centered = [channel - mean for channel, mean in zip(channels, means)]
    return tf.concat(axis=3, values=centered)
|
||
def loss(annotation_batch, upsampled_logits_batch, class_labels):
    """Compute the mean softmax cross-entropy and log it as a summary.

    Labels and logits are first filtered by get_valid_logits_and_labels
    (presumably dropping pixels that should not contribute -- confirm
    against utils.training).

    Args:
        annotation_batch: annotation tensor passed to the filter helper.
        upsampled_logits_batch: logits tensor passed to the filter helper.
        class_labels: class label list passed to the filter helper.

    Returns:
        A scalar tensor: the mean cross-entropy over the valid entries.
    """
    valid_labels, valid_logits = get_valid_logits_and_labels(
        annotation_batch_tensor=annotation_batch,
        logits_batch_tensor=upsampled_logits_batch,
        class_labels=class_labels)

    per_entry_loss = tf.nn.softmax_cross_entropy_with_logits(
        logits=valid_logits,
        labels=valid_labels)

    mean_cross_entropy = tf.reduce_mean(per_entry_loss)
    tf.summary.scalar('cross_entropy_loss', mean_cross_entropy)
    return mean_cross_entropy
|
Oops, something went wrong.