
Commit: new branch dev
eragonruan committed Nov 2, 2017
1 parent ad38eca commit ef77407
Showing 15 changed files with 216 additions and 349 deletions.
6 changes: 6 additions & 0 deletions .idea/vcs.xml


131 changes: 34 additions & 97 deletions .idea/workspace.xml


4 changes: 2 additions & 2 deletions README.md
@@ -1,4 +1,4 @@
# refinenet(in work condition)
# refinenet
a TensorFlow implementation of RefineNet. RefineNet: Multi-Path Refinement Networks for High-Resolution Semantic Segmentation


@@ -19,7 +19,7 @@ this is a TensorFlow implementation of refinenet described in [arxiv:1611.06612](http
- put images in demo/ and run python RefineNet/demo.py

## roadmap
- [ ] python2/3 compatibility
- [x] python2/3 compatibility
- [ ] complete implementation of the refinenet model
- [ ] test on PASCAL VOC and report the IoU
- [ ] training on other datasets
4 changes: 3 additions & 1 deletion RefineNet/demo.py
@@ -2,11 +2,13 @@
import cv2
import time
import os,shutil
import sys
import numpy as np
import tensorflow as tf
slim = tf.contrib.slim

import model as model
sys.path.append(os.getcwd())
from nets import model as model
from matplotlib import pyplot as plt
from utils.pascal_voc import pascal_segmentation_lut
from utils.visualization import visualize_segmentation_adaptive
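The demo script now puts the repository root on the import path before loading the model from the nets package, so it is meant to be launched from the repository root (which matches the README's `python RefineNet/demo.py` instruction). A minimal sketch of that pattern, not part of the commit; the nets/model.py layout is inferred from the diff above:

# Minimal sketch (not part of the commit) of the import pattern introduced in
# demo.py and multi_gpu_train.py: run the script from the repository root so
# that os.getcwd() is the root and the nets/ package becomes importable.
import os
import sys

sys.path.append(os.getcwd())      # make the repo root importable

from nets import model as model   # resolves to nets/model.py in this repo

# e.g. from the repository root:
#   python RefineNet/demo.py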
33 changes: 12 additions & 21 deletions RefineNet/multi_gpu_train.py
@@ -5,36 +5,35 @@
import shutil
import datetime
import os
import pickle
import cv2
import pickle
import numpy as np
from tensorflow.contrib import slim
import sys
sys.path.append(os.getcwd())
import model as model
from utils.input import get_batch,generator
from nets import model as model
from utils.tf_records import read_tfrecord_and_decode_into_image_annotation_pair_tensors
from utils.pascal_voc import pascal_segmentation_lut
from utils.augmentation import (distort_randomly_image_color,flip_randomly_left_right_image_with_annotation,
scale_randomly_image_with_annotation_with_fixed_size_output)


tf.app.flags.DEFINE_integer('batch_size', 2, '')
tf.app.flags.DEFINE_integer('train_size', 384, '')
tf.app.flags.DEFINE_integer('batch_size', 3, '')
tf.app.flags.DEFINE_integer('train_size', 512, '')
tf.app.flags.DEFINE_float('learning_rate', 0.0001, '')
tf.app.flags.DEFINE_integer('max_steps', 100000, '')
tf.app.flags.DEFINE_integer('max_steps', 60000, '')
tf.app.flags.DEFINE_float('moving_average_decay', 0.997, '')
tf.app.flags.DEFINE_integer('num_classes', 21, '')
tf.app.flags.DEFINE_string('gpu_list', '0', '')
tf.app.flags.DEFINE_string('gpu_list', '0,1', '')
tf.app.flags.DEFINE_string('checkpoint_path', 'checkpoints/', '')
tf.app.flags.DEFINE_string('logs_path', 'logs/', '')
tf.app.flags.DEFINE_boolean('restore', False, 'whether to restore from checkpoint')
tf.app.flags.DEFINE_boolean('restore', True, 'whether to restore from checkpoint')
tf.app.flags.DEFINE_integer('save_checkpoint_steps', 2000, '')
tf.app.flags.DEFINE_integer('save_summary_steps', 10, '')
tf.app.flags.DEFINE_integer('save_image_steps', 10, '')
tf.app.flags.DEFINE_integer('save_image_steps', 100, '')
tf.app.flags.DEFINE_string('training_data_path', 'data/pascal_augmented_train.tfrecords', '')
tf.app.flags.DEFINE_string('pretrained_model_path', 'data/resnet_v1_101.ckpt', '')
tf.app.flags.DEFINE_integer('decay_steps',40000,'')
tf.app.flags.DEFINE_integer('decay_steps',20000,'')
tf.app.flags.DEFINE_integer('decay_rate',0.1,'')
FLAGS = tf.app.flags.FLAGS
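The flag changes above imply a new training schedule: batches of 3 crops at 512x512 on two GPUs, 60000 steps, and a learning rate of 1e-4 cut by 10x every 20000 steps (the loop below applies the cut with tf.assign). A small sketch, not part of the commit, of the piecewise-constant schedule these flags describe:

# Sketch (not part of the commit): the learning-rate schedule implied by the
# updated flags; the training loop realizes it by multiplying the
# learning_rate variable by decay_rate every decay_steps steps.
LEARNING_RATE = 0.0001  # FLAGS.learning_rate
DECAY_STEPS = 20000     # FLAGS.decay_steps
DECAY_RATE = 0.1        # FLAGS.decay_rate

def lr_at(step):
    """Learning rate in effect at a given global step."""
    return LEARNING_RATE * DECAY_RATE ** (step // DECAY_STEPS)

# lr_at(0)     -> 1e-4
# lr_at(20000) -> 1e-5
# lr_at(40000) -> 1e-6   (last decay before max_steps = 60000)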

@@ -94,10 +93,6 @@ def main(argv=None):
shutil.rmtree(FLAGS.checkpoint_path)
os.makedirs(FLAGS.checkpoint_path)


#input_images = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='input_images')
#input_segs = tf.placeholder(tf.float32, shape=[None, None,None, 1], name='input_segs')

filename_queue = tf.train.string_input_producer([FLAGS.training_data_path], num_epochs=1000)
image, annotation = read_tfrecord_and_decode_into_image_annotation_pair_tensors(filename_queue)
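This hunk replaces the commented-out placeholder tensors with a TF 1.x queue-based pipeline: a filename queue over the training TFRecord feeds the decoder from utils.tf_records, and (per the imports at the top of the file) the resulting image/annotation pair is then flipped, color-distorted, scaled to train_size, and batched. A rough sketch of such a decoder, not part of the commit; the feature keys are assumptions, the real ones live in utils/tf_records.py:

# Sketch (not part of the commit): a TF 1.x queue-based TFRecord reader of the
# kind used here. The feature keys below are assumptions; the actual decoding
# is implemented in utils/tf_records.py.
import tensorflow as tf

def image_annotation_pair(tfrecord_path, num_epochs=1000):
    filename_queue = tf.train.string_input_producer([tfrecord_path],
                                                    num_epochs=num_epochs)
    _, serialized = tf.TFRecordReader().read(filename_queue)
    features = tf.parse_single_example(serialized, features={
        'image_raw': tf.FixedLenFeature([], tf.string),       # assumed key
        'annotation_raw': tf.FixedLenFeature([], tf.string),  # assumed key
        'height': tf.FixedLenFeature([], tf.int64),
        'width': tf.FixedLenFeature([], tf.int64),
    })
    h = tf.cast(features['height'], tf.int32)
    w = tf.cast(features['width'], tf.int32)
    image = tf.reshape(tf.decode_raw(features['image_raw'], tf.uint8), [h, w, 3])
    annotation = tf.reshape(tf.decode_raw(features['annotation_raw'], tf.uint8),
                            [h, w, 1])
    return image, annotation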

@@ -177,17 +172,12 @@ def main(argv=None):
if FLAGS.pretrained_model_path is not None:
variable_restore_op(sess)

#data_generator = get_batch(num_workers=8, batch_size=FLAGS.batch_size * len(gpus))
#data_generator = generator(batch_size=FLAGS.batch_size * len(gpus))
start = time.time()
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
try:
while not coord.should_stop():
for step in range(restore_step,FLAGS.max_steps):

#data = next(data_generator)

if step != 0 and step % FLAGS.decay_steps == 0:
sess.run(tf.assign(learning_rate, learning_rate.eval() * FLAGS.decay_rate))
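The loop runs from restore_step to max_steps under a tf.train.Coordinator, multiplying the learning-rate variable in place every decay_steps, and training ends with an OutOfRangeError once the filename queue exhausts its epochs (the except branch further down). A self-contained toy of that control flow, not part of the commit, with a dummy op standing in for the real multi-GPU train op:

# Sketch (not part of the commit): the coordinator / queue-runner control flow
# of this training loop, reduced to a runnable toy. tf.no_op stands in for the
# real multi-GPU train op.
import tensorflow as tf

max_steps, decay_steps, decay_rate, restore_step = 60000, 20000, 0.1, 0
learning_rate = tf.Variable(0.0001, trainable=False, name='learning_rate')
train_op = tf.no_op(name='dummy_train_op')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        for step in range(restore_step, max_steps):
            if step != 0 and step % decay_steps == 0:
                # piecewise-constant decay, as in the hunk above
                sess.run(tf.assign(learning_rate,
                                   sess.run(learning_rate) * decay_rate))
            sess.run(train_op)
    except tf.errors.OutOfRangeError:
        print('finish')  # the input queue ran out of epochs
    finally:
        coord.request_stop()
        coord.join(threads)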

@@ -219,6 +209,8 @@ def main(argv=None):
seg_split=np.squeeze(seg_split)[0]
pred=np.squeeze(pred)[0]

#img_split=cv2.resize(img_split,(128,128))

color_seg = np.zeros((seg_split.shape[0], seg_split.shape[1], 3))
for i in range(seg_split.shape[0]):
for j in range(seg_split.shape[1]):
@@ -229,11 +221,10 @@
for j in range(pred.shape[1]):
color_pred[i, j, :] = color_map[str(pred[i][j])]

write_img=np.hstack((img_split,color_seg,color_pred))
write_img=np.hstack((color_seg,color_pred))
log_image_summary_op = sess.run(log_image,feed_dict={log_image_name: log_image_name_str, \
log_image_data: write_img})
summary_writer.add_summary(log_image_summary_op, global_step=step)

except tf.errors.OutOfRangeError:
print('finish')
finally:
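For the image summaries, the hunk above colorizes the ground-truth and predicted label maps with nested per-pixel Python loops before stacking them side by side (now without the raw image, since the crops are 512x512). The same mapping can be done in one vectorized indexing step; a sketch, not part of the commit, assuming color_map maps stringified class ids to RGB triples as the loops imply:

# Sketch (not part of the commit): a vectorized replacement for the nested
# per-pixel loops that build color_seg / color_pred above. Assumes color_map
# maps stringified class ids ('0'..'20') to RGB triples, as the loops imply.
import numpy as np

def colorize(label_map, color_map, num_classes=21):
    """Map an HxW integer label map to an HxWx3 color image."""
    palette = np.array([color_map[str(c)] for c in range(num_classes)],
                       dtype=np.float32)   # (num_classes, 3)
    return palette[label_map]              # fancy indexing replaces the loops

# color_seg  = colorize(seg_split, color_map)
# color_pred = colorize(pred, color_map)
# write_img  = np.hstack((color_seg, color_pred))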
