diff --git a/neural_gpu/BUILD b/neural_gpu/BUILD new file mode 100644 index 00000000000..28079d8b34f --- /dev/null +++ b/neural_gpu/BUILD @@ -0,0 +1,41 @@ +py_library( + name = "data_utils", + srcs = [ + "data_utils.py", + ], + deps = [ + "//file/colossus/public:cns", + "//third_party/py/numpy", + "//third_party/py/tensorflow", + ], +) + +py_library( + name = "neural_gpu", + srcs = [ + "neural_gpu.py", + ], + deps = [ + ":data_utils", + "//third_party/py/numpy", + "//third_party/py/tensorflow", + ], +) + +py_binary( + name = "neural_gpu_trainer", + srcs = [ + "neural_gpu_trainer.py", + ], + launcher = "//devtools/python/launcher", + malloc = "//tcmalloc:tcmalloc_or_debug", + deps = [ + ":neural_gpu", + "//file/colossus/public:cns", + "//net/proto2/python/public:use_fast_cpp_protos", + "//third_party/py/Tkinter", + "//third_party/py/matplotlib", + "//third_party/py/numpy", + "//third_party/py/tensorflow", + ], +) diff --git a/neural_gpu/README.md b/neural_gpu/README.md new file mode 100644 index 00000000000..769548b4f09 --- /dev/null +++ b/neural_gpu/README.md @@ -0,0 +1,4 @@ +# NeuralGPU +Code for the Neural GPU model as described +in [arXiv:1511.08228](http://arxiv.org/abs/1511.08228). 
def init_data(task, length, nbr_cases, nclass):
  """Generate random training and test cases for one task at one length.

  Appends `nbr_cases` [input, target] pairs to the module-level
  `train_set[task]` and `test_set[task]`, each pair stored in the bucket
  indexed by the length of its input. Symbols are integers; 0 is only ever
  used as filler, real symbols start at 1.

  Args:
    task: name of the task, e.g. "rev", "add", "bmul", "dup".
    length: upper bound on the input length of every generated case.
    nbr_cases: how many cases to add to each of the two sets.
    nclass: number of symbol classes; random symbols are drawn from
      1..nclass-1.
  """
  def rand_pair(l, task):
    """Random data pair for a task. Total length should be <= l."""
    # Two operands of k digits each plus one separator symbol. Floor
    # division keeps the Python 2 result and also works on Python 3.
    k = (l - 1) // 2
    base = 10
    if task[0] == "b": base = 2
    if task[0] == "q": base = 4
    d1 = [np.random.randint(base) for _ in range(k)]
    d2 = [np.random.randint(base) for _ in range(k)]
    if task in ["add", "badd", "qadd"]:
      res = add(d1, d2, base)
    elif task in ["bmul"]:
      d1n = sum([d * (base ** i) for i, d in enumerate(d1)])
      d2n = sum([d * (base ** i) for i, d in enumerate(d2)])
      # Binary digits of the product, least-significant first ("0b" dropped).
      res = [int(x) for x in reversed(bin(d1n * d2n)[2:])]
    else:
      # Mirror spec(): say why we stop instead of exiting silently.
      print_out("Unknown pair task " + str(task))
      sys.exit()
    sep = [12]
    if task in ["add", "badd", "qadd"]: sep = [11]
    # Digits are shifted by one so that 0 stays free for padding.
    inp = [d + 1 for d in d1] + sep + [d + 1 for d in d2]
    return inp, [r + 1 for r in res]

  def rand_dup_pair(l):
    """Random data pair for duplication task. Total length should be <= l."""
    k = l // 2
    x = [np.random.randint(nclass - 1) + 1 for _ in range(k)]
    inp = x + [0 for _ in range(l - k)]
    res = x + x + [0 for _ in range(l - 2*k)]
    return inp, res

  def spec(inp):
    """Return the target given the input for some tasks."""
    if task == "sort":
      return sorted(inp)
    elif task == "id":
      return inp
    elif task == "rev":
      return [i for i in reversed(inp)]
    elif task == "incr":
      # Add one to a number whose digits live in 1..nclass-1, lowest first.
      carry = 1
      res = []
      for i in range(len(inp)):
        if inp[i] + carry < nclass:
          res.append(inp[i] + carry)
          carry = 0
        else:
          res.append(1)
          carry = 1
      return res
    elif task == "left":
      # Slices (not indexing) so that the empty length-0 input is accepted.
      return inp[:1]
    elif task == "right":
      return inp[-1:]
    elif task == "left-shift":
      # Cyclic: position i receives inp[i-1], position 0 receives inp[-1].
      return inp[-1:] + inp[:-1]
    elif task == "right-shift":
      # Cyclic counterpart of "left-shift". Plain inp[i+1] indexing ran off
      # the end of the list at the last position.
      return inp[1:] + inp[:1]
    else:
      print_out("Unknown spec for task " + str(task))
      sys.exit()

  cur_time = time.time()
  total_time = 0.0
  for case in range(nbr_cases):
    total_time += time.time() - cur_time
    cur_time = time.time()
    if length > 10000 and case % 100 == 1:
      print_out(" avg gen time %.4f s" % (total_time / float(case)))
    # A training case is always drawn before the matching test case, so a
    # seeded run consumes random numbers in a fixed order.
    if task in ["add", "badd", "qadd", "bmul"]:
      inp, target = rand_pair(length, task)
      train_set[task][len(inp)].append([inp, target])
      inp, target = rand_pair(length, task)
      test_set[task][len(inp)].append([inp, target])
    elif task == "dup":
      inp, target = rand_dup_pair(length)
      train_set[task][len(inp)].append([inp, target])
      inp, target = rand_dup_pair(length)
      test_set[task][len(inp)].append([inp, target])
    else:
      inp = [np.random.randint(nclass - 1) + 1 for _ in range(length)]
      target = spec(inp)
      train_set[task][length].append([inp, target])
      inp = [np.random.randint(nclass - 1) + 1 for _ in range(length)]
      target = spec(inp)
      test_set[task][length].append([inp, target])
def decode(output):
  """Return, for every array in `output`, the arg-max index of each row."""
  return [np.argmax(o, axis=1) for o in output]


def accuracy(inpt, output, target, batch_size, nprint):
  """Calculate output accuracy given target.

  Only positions whose target symbol is greater than 0 are scored.

  Args:
    inpt: list over positions of per-example input symbols.
    output: list over positions of 2-D score arrays (one row per example);
      the predicted symbol is the arg-max of each row.
    target: list over positions of per-example target symbols.
    batch_size: number of examples in the batch.
    nprint: number of examples to log with print_out; at most batch_size.
      Sequences containing an error are logged first, the remainder is
      filled with error-free sequences.

  Returns:
    A tuple (errors, total, seq_errors): wrongly predicted symbols, scored
    symbols, and the number of sequences with at least one wrong symbol.
  """
  assert nprint < batch_size + 1

  def task_print(inp, output, target):
    # Show output and target only up to the first filler (0) target symbol.
    print_len = 0
    while print_len < len(target) and target[print_len] > 0:
      print_len += 1
    print_out(" i: " + " ".join([str(i - 1) for i in inp if i > 0]))
    print_out(" o: " +
              " ".join([str(output[l] - 1) for l in range(print_len)]))
    print_out(" t: " +
              " ".join([str(target[l] - 1) for l in range(print_len)]))

  decoded_output = decode(output)
  total = 0
  errors = 0
  seq = [0 for _ in range(batch_size)]
  for l in range(len(decoded_output)):
    for b in range(batch_size):
      if target[l][b] > 0:
        total += 1
        if decoded_output[l][b] != target[l][b]:
          seq[b] = 1
          errors += 1

  # Failing sequences first, then correct ones, nprint examples in total.
  # The filler count must be based on how many sequences were printed, not
  # on `errors` (a symbol count): one sequence with several wrong symbols
  # used to suppress the filler examples entirely.
  failed = [b for b in range(batch_size) if seq[b]]
  correct = [b for b in range(batch_size) if not seq[b]]
  for b in (failed + correct)[:nprint]:
    task_print([inpt[l][b] for l in range(len(inpt))],
               [decoded_output[l][b] for l in range(len(target))],
               [target[l][b] for l in range(len(target))])
  return errors, total, sum(seq)
"""Convolutional linear map.""" + assert args + if not isinstance(args, (list, tuple)): + args = [args] + with tf.variable_scope(prefix): + k = tf.get_variable("CvK", [kw, kh, nin, nout]) + if len(args) == 1: + res = tf.nn.conv2d(args[0], k, [1, 1, 1, 1], "SAME") + else: + res = tf.nn.conv2d(tf.concat(3, args), k, [1, 1, 1, 1], "SAME") + if not do_bias: return res + bias_term = tf.get_variable("CvB", [nout], + initializer=tf.constant_initializer(0.0)) + return res + bias_term + bias_start + + +def sigmoid_cutoff(x, cutoff): + """Sigmoid with cutoff, e.g., 1.2sigmoid(x) - 0.1.""" + y = tf.sigmoid(x) + if cutoff < 1.01: return y + d = (cutoff - 1.0) / 2.0 + return tf.minimum(1.0, tf.maximum(0.0, cutoff * y - d)) + + +def conv_gru(inpts, mem, kw, kh, nmaps, cutoff, prefix): + """Convolutional GRU.""" + def conv_lin(args, suffix, bias_start): + return conv_linear(args, kw, kh, len(args) * nmaps, nmaps, True, bias_start, + prefix + "/" + suffix) + reset = sigmoid_cutoff(conv_lin(inpts + [mem], "r", 1.0), cutoff) + candidate = tf.tanh(conv_lin(inpts + [reset * mem], "c", 0.0)) + gate = sigmoid_cutoff(conv_lin(inpts + [mem], "g", 1.0), cutoff) + return gate * mem + (1 - gate) * candidate + + +def relaxed_average(var_name_suffix, rx_step): + """Calculate the average of relaxed variables having var_name_suffix.""" + relaxed_vars = [] + for l in xrange(rx_step): + with tf.variable_scope("RX%d" % l, reuse=True): + try: + relaxed_vars.append(tf.get_variable(var_name_suffix)) + except ValueError: + pass + dsum = tf.add_n(relaxed_vars) + avg = dsum / len(relaxed_vars) + diff = [v - avg for v in relaxed_vars] + davg = tf.add_n([d*d for d in diff]) + return avg, tf.reduce_sum(davg) + + +def relaxed_distance(rx_step): + """Distance between relaxed variables and their average.""" + res, ops, rx_done = [], [], {} + for v in tf.trainable_variables(): + if v.name[0:2] == "RX": + rx_name = v.op.name[v.name.find("/") + 1:] + if rx_name not in rx_done: + avg, dist_loss = 
class NeuralGPU(object):
  """Neural GPU Model.

  Builds one copy of the graph for every length in data_utils.bins plus
  data_utils.forward_max; the copies share their variables. Per-length ops
  are kept in parallel lists (outputs, losses, steps, and - for the lengths
  that get a backward pass - grad_norms and updates) indexed by the position
  of the length in that sorted list.
  """

  def __init__(self, nmaps, vec_size, niclass, noclass, dropout, rx_step,
               max_grad_norm, cutoff, nconvs, kw, kh, height, mode,
               learning_rate, pull, pull_incr, min_length):
    """Create the graph.

    Args:
      nmaps: number of maps (last dimension) of the state.
      vec_size: size of the input embedding. The embedded inputs are
        reshaped to nmaps below, so this presumably has to equal nmaps.
      niclass: number of input symbol classes (rows of the embedding).
      noclass: number of output classes (size of the logits).
      dropout: dropout rate, applied only when do_training is fed as 1.
      rx_step: number of consecutive steps with their own (relaxed) copy of
        the CGRU parameters.
      max_grad_norm: global norm the gradients are clipped to.
      cutoff: cutoff passed to the gates of conv_gru.
      nconvs: number of CGRU layers applied in one step.
      kw: kernel width.
      kh: kernel height.
      height: height of the state.
      mode: 0 builds backward passes for every bin; any other value only
        for the smallest bin.
      learning_rate: initial value of self.lr (Adam uses 0.01 times it).
      pull: initial weight of the relaxation distance in the loss.
      pull_incr: factor self.pull_incr_op multiplies the pull by.
      min_length: initial value of the cur_length variable.
    """
    # Feeds for parameters and ops to update them.
    self.global_step = tf.Variable(0, trainable=False)
    self.cur_length = tf.Variable(min_length, trainable=False)
    self.cur_length_incr_op = self.cur_length.assign_add(1)
    self.lr = tf.Variable(float(learning_rate), trainable=False)
    self.lr_decay_op = self.lr.assign(self.lr * 0.98)
    self.pull = tf.Variable(float(pull), trainable=False)
    self.pull_incr_op = self.pull.assign(self.pull * pull_incr)
    self.do_training = tf.placeholder(tf.float32, name="do_training")
    self.noise_param = tf.placeholder(tf.float32, name="noise_param")

    # Feeds for inputs, targets, outputs, losses, etc.
    self.input = []
    self.target = []
    for l in xrange(data_utils.forward_max + 1):
      self.input.append(tf.placeholder(tf.int32, name="inp{0}".format(l)))
      self.target.append(tf.placeholder(tf.int32, name="tgt{0}".format(l)))
    self.outputs = []
    self.losses = []
    self.grad_norms = []
    self.updates = []

    # Computation.
    inp0_shape = tf.shape(self.input[0])
    batch_size = inp0_shape[0]
    with tf.device("/cpu:0"):
      emb_weights = tf.get_variable(
          "embedding", [niclass, vec_size],
          initializer=tf.random_uniform_initializer(-1.7, 1.7))
      # Op that resets the embedding of symbol 0 to the zero vector.
      e0 = tf.scatter_update(emb_weights,
                             tf.constant(0, dtype=tf.int32, shape=[1]),
                             tf.zeros([1, vec_size]))

    adam = tf.train.AdamOptimizer(0.01*self.lr, epsilon=1e-5)

    # Main graph creation loop, for every bin in data_utils.
    self.steps = []
    for length in sorted(list(set(data_utils.bins + [data_utils.forward_max]))):
      data_utils.print_out("Creating model for bin of length %d." % length)
      start_time = time.time()
      if length > data_utils.bins[0]:
        tf.get_variable_scope().reuse_variables()

      # Embed inputs and calculate mask.
      with tf.device("/cpu:0"):
        with tf.control_dependencies([e0]):
          embedded = [tf.nn.embedding_lookup(emb_weights, self.input[l])
                      for l in xrange(length)]
        # Mask to 0-out padding space in each step.
        imask = [check_for_zero(self.input[l]) for l in xrange(length)]
        omask = [check_for_zero(self.target[l]) for l in xrange(length)]
        mask = [1.0 - (imask[i] * omask[i]) for i in xrange(length)]
        mask = [tf.reshape(m, [-1, 1]) for m in mask]
        # Use a shifted mask for step scaling and concatenated for weights.
        shifted_mask = mask + [tf.zeros_like(mask[0])]
        scales = [shifted_mask[i] * (1.0 - shifted_mask[i+1])
                  for i in xrange(length)]
        scales = [tf.reshape(s, [-1, 1, 1, 1]) for s in scales]
        mask = tf.concat(1, mask[0:length])  # batch x length
        weights = mask
        # Add a height dimension to mask to use later for masking.
        mask = tf.reshape(mask, [-1, length, 1, 1])
        mask = tf.concat(2, [mask for _ in xrange(height)]) + tf.zeros(
            tf.pack([batch_size, length, height, nmaps]), dtype=tf.float32)

      # Start is a length-list of batch-by-nmaps tensors, reshape and concat.
      start = [tf.tanh(embedded[l]) for l in xrange(length)]
      start = [tf.reshape(start[l], [-1, 1, nmaps]) for l in xrange(length)]
      start = tf.reshape(tf.concat(1, start), [-1, length, 1, nmaps])

      # First image comes from start by applying one convolution and adding 0s.
      first = conv_linear(start, 1, 1, vec_size, nmaps, True, 0.0, "input")
      first = [first] + [tf.zeros(tf.pack([batch_size, length, 1, nmaps]),
                                  dtype=tf.float32) for _ in xrange(height - 1)]
      first = tf.concat(2, first)

      # Computation steps.
      step = [tf.nn.dropout(first, 1.0 - self.do_training * dropout) * mask]
      outputs = []
      for it in xrange(length):
        with tf.variable_scope("RX%d" % (it % rx_step)) as vs:
          if it >= rx_step:
            vs.reuse_variables()
          cur = step[it]
          # Do nconvs-many CGRU steps.
          for layer in xrange(nconvs):
            cur = conv_gru([], cur, kw, kh, nmaps, cutoff, "cgru_%d" % layer)
          cur = tf.nn.dropout(cur, 1.0 - self.do_training * dropout)
          step.append(cur * mask)
          outputs.append(tf.slice(step[-1], [0, 0, 0, 0], [-1, -1, 1, -1]))

      self.steps.append([tf.reshape(s, [-1, length, height * nmaps])
                         for s in step])
      # Output is the n-th step output; n = current length, as in scales.
      output = tf.add_n([outputs[i] * scales[i] for i in xrange(length)])
      # Final convolution to get logits, list outputs.
      output = conv_linear(output, 1, 1, nmaps, noclass, True, 0.0, "output")
      output = tf.reshape(output, [-1, length, noclass])
      self.outputs.append([tf.reshape(o, [-1, noclass])
                           for o in list(tf.split(1, length, output))])

      # Calculate cross-entropy loss and normalize it.
      targets = tf.concat(1, [make_dense(self.target[l], noclass)
                              for l in xrange(length)])
      targets = tf.reshape(targets, [-1, noclass])
      xent = tf.reshape(tf.nn.softmax_cross_entropy_with_logits(
          tf.reshape(output, [-1, noclass]), targets), [-1, length])
      perp_loss = tf.reduce_sum(xent * weights)
      perp_loss /= tf.cast(batch_size, dtype=tf.float32)
      perp_loss /= length

      # Final loss: cross-entropy + shared parameter relaxation part.
      relax_dist, self.avg_op = relaxed_distance(rx_step)
      total_loss = perp_loss + relax_dist * self.pull
      self.losses.append(perp_loss)

      # Gradients and Adam update operation.
      if length == data_utils.bins[0] or (mode == 0 and
                                          length < data_utils.bins[-1] + 1):
        data_utils.print_out("Creating backward for bin of length %d." % length)
        params = tf.trainable_variables()
        grads = tf.gradients(total_loss, params)
        grads, norm = tf.clip_by_global_norm(grads, max_grad_norm)
        self.grad_norms.append(norm)
        # Add noise, scaled by the fed noise_param, to every dense gradient.
        # The noisy tensors have to be collected into a new list: a
        # `for grad in grads: grad += ...` loop only rebinds the loop
        # variable, so the noise never reached apply_gradients. Feeding
        # noise_param = 0 (the default in step()) disables the noise.
        grads = [
            grad + tf.truncated_normal(tf.shape(grad)) * self.noise_param
            if isinstance(grad, tf.Tensor) else grad
            for grad in grads]
        update = adam.apply_gradients(zip(grads, params),
                                      global_step=self.global_step)
        self.updates.append(update)
      data_utils.print_out("Created model for bin of length %d in"
                           " %.2f s." % (length, time.time() - start_time))
    self.saver = tf.train.Saver(tf.all_variables())

  def step(self, sess, inp, target, do_backward, noise_param=None):
    """Run a step of the network.

    Args:
      sess: session to run the graph in.
      inp: list over positions of input batches; its length selects the
        graph copy and must be one of data_utils.bins unless it is larger
        than the last bin, in which case the forward_max copy is used.
      target: list over positions of target batches, as long as inp.
      do_backward: whether to also run the update op (and enable dropout).
      noise_param: gradient noise scale to feed; None or 0 feeds 0.0.

    Returns:
      A tuple (loss, outputs, grad_norm, steps); grad_norm is None when
      do_backward is false.

    Raises:
      ValueError: if the length is at most the last bin but is not a bin.
    """
    assert len(inp) == len(target)
    length = len(target)
    feed_in = {}
    feed_in[self.noise_param.name] = noise_param if noise_param else 0.0
    feed_in[self.do_training.name] = 1.0 if do_backward else 0.0
    feed_out = []
    # The graph copy after all bins is the one built for forward_max.
    index = len(data_utils.bins)
    if length < data_utils.bins[-1] + 1:
      index = data_utils.bins.index(length)
    if do_backward:
      feed_out.append(self.updates[index])
      feed_out.append(self.grad_norms[index])
    feed_out.append(self.losses[index])
    for l in xrange(length):
      feed_in[self.input[l].name] = inp[l]
    for l in xrange(length):
      feed_in[self.target[l].name] = target[l]
      feed_out.append(self.outputs[index][l])
    for l in xrange(length+1):
      feed_out.append(self.steps[index][l])
    res = sess.run(feed_out, feed_in)
    # With a backward pass the update op and the norm precede the loss.
    offset = 0
    norm = None
    if do_backward:
      offset = 2
      norm = res[1]
    outputs = res[offset + 1:offset + 1 + length]
    steps = res[offset + 1 + length:]
    return res[offset], outputs, norm, steps
0.05, "Clip gradients to this norm.") +tf.app.flags.DEFINE_float("cutoff", 1.2, "Cutoff at the gates.") +tf.app.flags.DEFINE_float("pull", 0.0005, "Starting pull of the relaxations.") +tf.app.flags.DEFINE_float("pull_incr", 1.2, "Increase pull by that much.") +tf.app.flags.DEFINE_float("dropout", 0.2, "Dropout that much.") +tf.app.flags.DEFINE_float("grad_noise_scale", 1.0, "Gradient noise scale.") +tf.app.flags.DEFINE_integer("batch_size", 64, "Batch size.") +tf.app.flags.DEFINE_integer("low_batch_size", 16, "Low batch size.") +tf.app.flags.DEFINE_integer("steps_per_checkpoint", 100, "Steps per epoch.") +tf.app.flags.DEFINE_integer("nmaps", 24, "Number of floats in each cell.") +tf.app.flags.DEFINE_integer("niclass", 14, "Number of classes (0 is padding).") +tf.app.flags.DEFINE_integer("noclass", 14, "Number of classes (0 is padding).") +tf.app.flags.DEFINE_integer("train_data_size", 5000, "Training examples/len.") +tf.app.flags.DEFINE_integer("max_length", 41, "Maximum length.") +tf.app.flags.DEFINE_integer("rx_step", 6, "Relax that many recursive steps.") +tf.app.flags.DEFINE_integer("random_seed", 125459, "Random seed.") +tf.app.flags.DEFINE_integer("nconvs", 2, "How many convolutions / 1 step.") +tf.app.flags.DEFINE_integer("kw", 3, "Kernel width.") +tf.app.flags.DEFINE_integer("kh", 3, "Kernel height.") +tf.app.flags.DEFINE_integer("height", 4, "Height.") +tf.app.flags.DEFINE_integer("forward_max", 401, "Maximum forward length.") +tf.app.flags.DEFINE_integer("jobid", -1, "Task id when running on borg.") +tf.app.flags.DEFINE_integer("nprint", 0, "How many test examples to print out.") +tf.app.flags.DEFINE_integer("mode", 0, "Mode: 0-train other-decode.") +tf.app.flags.DEFINE_string("task", "rev", "Which task are we learning?") +tf.app.flags.DEFINE_string("train_dir", "/tmp/", "Directory to store models.") + +FLAGS = tf.app.flags.FLAGS + + +def initialize(sess): + """Initialize data and model.""" + if FLAGS.jobid >= 0: + data.log_filename = 
os.path.join(FLAGS.train_dir, "log%d" % FLAGS.jobid) + data.print_out("NN ", newline=False) + + # Set random seed. + seed = FLAGS.random_seed + max(0, FLAGS.jobid) + tf.set_random_seed(seed) + random.seed(seed) + np.random.seed(seed) + + # Check data sizes. + data.forward_max = max(FLAGS.forward_max, data.bins[-1]) + assert data.bins + min_length = 3 + max_length = min(FLAGS.max_length, data.bins[-1]) + assert max_length + 1 > min_length + while len(data.bins) > 1 and data.bins[-2] > max_length + 12: + data.bins = data.bins[:-1] + assert data.bins[0] > FLAGS.rx_step + nclass = min(FLAGS.niclass, FLAGS.noclass) + data_size = FLAGS.train_data_size if FLAGS.mode == 0 else 1000 + + # Initialize data for each task. + tasks = FLAGS.task.split("-") + for t in tasks: + for l in xrange(max_length + 11): + data.init_data(t, l, data_size, nclass) + data.init_data(t, data.bins[-2], data_size, nclass) + data.init_data(t, data.bins[-1], data_size, nclass) + end_size = 4 * 1024 if FLAGS.mode > 0 else 1024 + data.init_data(t, data.forward_max, end_size, nclass) + + # Print out parameters. + curriculum = 0.12 + fin = ("cv %d kw %d h %d kh %d rxr %d bs %d ns %.2f t %s" + % (FLAGS.nconvs, FLAGS.kw, FLAGS.height, FLAGS.kh, FLAGS.rx_step, + FLAGS.batch_size, FLAGS.grad_noise_scale, FLAGS.task)) + fin = "data %d %s" % (FLAGS.train_data_size, fin) + tag = ("df %.2f p %.3f lr %.2f iw %.2f cr %.2f nm %d d%.4f gn %.2f %s" % + (FLAGS.cutoff, FLAGS.pull_incr, FLAGS.lr, FLAGS.init_weight, + curriculum, FLAGS.nmaps, FLAGS.dropout, FLAGS.max_grad_norm, fin)) + data.print_out(tag) + + # Create checkpoint directory if it does not exist. + checkpoint_dir = os.path.join(FLAGS.train_dir, "neural_gpu%s" + % ("" if FLAGS.jobid < 0 else str(FLAGS.jobid))) + if not gfile.IsDirectory(checkpoint_dir): + data.print_out("Creating checkpoint directory %s." % checkpoint_dir) + gfile.MkDir(checkpoint_dir) + + # Create model and initialize it. 
+ tf.get_variable_scope().set_initializer( + tf.uniform_unit_scaling_initializer(factor=1.8 * FLAGS.init_weight)) + model = ngpu.NeuralGPU( + FLAGS.nmaps, FLAGS.nmaps, FLAGS.niclass, FLAGS.noclass, FLAGS.dropout, + FLAGS.rx_step, FLAGS.max_grad_norm, FLAGS.cutoff, FLAGS.nconvs, + FLAGS.kw, FLAGS.kh, FLAGS.height, FLAGS.mode, FLAGS.lr, + FLAGS.pull, FLAGS.pull_incr, min_length + 3) + data.print_out("Created model.") + sess.run(tf.initialize_all_variables()) + data.print_out("Initialized variables.") + + # Load model from parameters if a checkpoint exists. + ckpt = tf.train.get_checkpoint_state(checkpoint_dir) + if ckpt and gfile.Exists(ckpt.model_checkpoint_path): + data.print_out("Reading model parameters from %s" + % ckpt.model_checkpoint_path) + model.saver.restore(sess, ckpt.model_checkpoint_path) + + # Return the model and needed variables. + return (model, min_length, max_length, checkpoint_dir, curriculum) + + +def single_test(l, model, sess, task, nprint, batch_size, print_out=True, + offset=None): + """Test model on test data of length l using the given session.""" + inpt, target = data.get_batch(l, batch_size, False, task, offset) + _, res, _, steps = model.step(sess, inpt, target, False) + errors, total, seq = data.accuracy(inpt, res, target, batch_size, nprint) + seq = float(seq) / batch_size + if total > 0: + errors = float(errors) / total + if print_out: + data.print_out(" %s len %d errors %.2f sequence-errors %.2f" + % (task, l, 100*errors, 100*seq)) + return errors, seq, (steps, inpt, [np.argmax(o, axis=1) for o in res]) + + +def multi_test(l, model, sess, task, nprint, batch_size, offset=None): + """Run multiple tests at lower batch size to save memory.""" + errors = 0.0 + seq = 0.0 + to_print = nprint + low_batch = FLAGS.low_batch_size + low_batch = min(low_batch, batch_size) + for mstep in xrange(batch_size / low_batch): + cur_offset = None if offset is None else offset + mstep * low_batch + err, sq, _ = single_test(l, model, sess, task, to_print, 
def train():
  """Main training function.

  Loops forever: runs FLAGS.steps_per_checkpoint training steps on batches
  of random task and length, logs the averaged statistics, advances the
  curriculum (or decays the learning rate), saves a checkpoint and then
  evaluates every task on the test sets.
  """
  batch_size = FLAGS.batch_size
  tasks = FLAGS.task.split("-")
  with tf.Session() as sess:
    model, min_length, max_length, checkpoint_dir, curriculum = initialize(sess)
    max_cur_length = min(min_length + 3, max_length)
    prev_acc_perp = [1000000 for _ in xrange(3)]
    prev_sq = 1.0

    while True:
      global_step, pull, max_cur_length, learning_rate = sess.run(
          [model.global_step, model.pull, model.cur_length, model.lr])
      ep = global_step / FLAGS.steps_per_checkpoint
      acc_loss, acc_total, acc_errors, acc_seq = 0.0, 0, 0, 0
      acc_grad_norm, step_count, step_time = 0.0, 0, 0.0
      for _ in xrange(FLAGS.steps_per_checkpoint):
        global_step += 1
        task = random.choice(tasks)
        l1 = np.random.randint(max_cur_length - min_length + 1) + min_length
        l = l1
        if np.random.randint(10) > 3:  # Prefer longer stuff 60% of time.
          l = np.random.randint(max_cur_length - min_length+1) + min_length
          l = max(l, l1)
        if np.random.randint(4) < 1:  # Mixed learning: once in a while big.
          l = np.random.randint(max_length - min_length + 1) + min_length
          l = max(l, l1)
        start_time = time.time()
        inp, target = data.get_batch(l, batch_size, True, task)
        stepp = math.pow(global_step, -0.55)
        noise_param = math.sqrt(stepp * 20 * prev_sq) * FLAGS.grad_noise_scale
        loss, res, gnorm, _ = model.step(sess, inp, target, True, noise_param)
        step_time += time.time() - start_time
        acc_grad_norm += float(gnorm)
        # Only steps within the current curriculum length count towards
        # the statistics that decide whether to advance the curriculum.
        if l < max_cur_length + 1:
          step_count += 1
          acc_loss += loss
          errors, total, seq = data.accuracy(inp, res, target,
                                             batch_size, 0)
          acc_total += total
          acc_errors += errors
          acc_seq += seq
      acc_loss /= step_count
      step_time /= FLAGS.steps_per_checkpoint
      acc_seq = float(acc_seq) / (step_count * batch_size)
      prev_sq = acc_seq
      acc_errors = float(acc_errors) / acc_total if acc_total > 0 else 1.0
      msg1 = "ep %d st %.2f lr %.8f" % (ep, step_time, learning_rate)
      msg2 = "pl %.3f cme %.3f" % (pull, curriculum)
      msg = ("%s %s gn %.8f"
             % (msg1, msg2, acc_grad_norm / FLAGS.steps_per_checkpoint))
      data.print_out("%s len %d ppx %.8f errs %.2f sq %.2f" %
                     (msg, max_cur_length, data.safe_exp(acc_loss),
                      100*acc_errors, 100*acc_seq))
      if curriculum > acc_seq:
        prev_acc_perp.append(1000000)
        # Increase the current length until some task has training data of
        # exactly that length. The bucket for the new length has to be
        # tested: data.train_set[t] itself is a list of buckets and is
        # always truthy, which made this loop stop after a single increment
        # even on lengths without any data.
        do_incr = True
        while do_incr and max_cur_length < max_length:
          sess.run(model.cur_length_incr_op)
          max_cur_length += 1  # Mirror the variable just incremented.
          for t in tasks:
            if data.train_set[t][max_cur_length]: do_incr = False
        if pull < 1:
          sess.run(model.pull_incr_op)
        else:
          data.print_out(" Averaging parameters.")
          sess.run([model.avg_op, model.lr_decay_op])
      else:
        # Decay the learning rate when perplexity is worse than in each of
        # the last three checkpoints.
        acc_perp = data.safe_exp(acc_loss)
        if acc_perp > max(prev_acc_perp[-3:]):
          sess.run(model.lr_decay_op)
        prev_acc_perp.append(acc_perp)
      checkpoint_path = os.path.join(checkpoint_dir, "neural_gpu.ckpt")
      model.saver.save(sess, checkpoint_path,
                       global_step=model.global_step)
      # Run evaluation.
      should_exit = True
      bound = data.bins[-1] + 1
      for t in tasks:
        l = min_length
        while l < max_length + 12 and l < bound:
          _, sq, _ = single_test(l, model, sess, t, FLAGS.nprint, batch_size)
          l += 1
          # Skip lengths without test data.
          while l < bound + 1 and not data.test_set[t][l]:
            l += 1
        if sq < 0.5:
          _, sq = multi_test(data.forward_max, model, sess, t, FLAGS.nprint,
                             batch_size * 4)
        if sq > 0.001: should_exit = False
      # NOTE(review): despite its name, should_exit only triggers one more
      # large test below; the training loop itself never terminates.
      if should_exit:
        if data.forward_max > 4000 and len(tasks) == 1:
          multi_test(data.forward_max, model, sess, tasks[0], FLAGS.nprint,
                     batch_size * 16, 0)
+ out = [out_raw[i][batch] for i in xrange(len(text_fields))] + if 0 in out: + i = out.index(0) + out = out[0:i] + [0 for _ in xrange(len(out) - i)] + # Show the state after the first frames. + if index >= 2*xf: + im.set_array(steps[min(length - 1, index - 2*xf)][batch]) + for i, t in enumerate(text_fields): + if index - 2*xf < length: + t.set_text("") + else: + t.set_text(to_symbol(out[i])) + else: + for i, t in enumerate(text_fields): + t.set_text(to_symbol(inpt[i][batch]) if index < xf else "") + if index < xf: + im.set_array(np.zeros_like(steps[0][0])) + else: + im.set_array(steps[0][batch]) + return im, + animation = anim.FuncAnimation( + fig, animation_update, blit=True, frames=(l+4*xf)*anim_size*fps, + interval=500/fps, fargs=(test_data, xf, im, text_fields)) + animation.save("/tmp/neural_gpu.mp4", writer="mencoder", fps=4*fps, dpi=3*80) + + +def evaluate(): + """Evaluate an existing model.""" + batch_size = FLAGS.batch_size + tasks = FLAGS.task.split("-") + with tf.Session() as sess: + model, min_length, max_length, _, _ = initialize(sess) + bound = data.bins[-1] + 1 + for t in tasks: + l = min_length + while l < max_length + 12 and l < bound: + _, sq, _ = single_test(l, model, sess, t, FLAGS.nprint, batch_size) + l += 1 + while l < bound + 1 and not data.test_set[t][l]: + l += 1 + # Animate. + anim_size = 2 + _, _, test_data = single_test(l, model, sess, t, 0, anim_size) + animate(l, test_data, anim_size) + # More tests. + _, sq = multi_test(data.forward_max, model, sess, t, FLAGS.nprint, + batch_size * 4) + if sq < 0.01: # More tests. 
def interactive():
  """Interactively probe an existing model.

  Prompts on stdout and reads lines from stdin until end of input. Every
  character of a line is read as one input symbol id (a single digit); the
  arg-max output symbols of the model are printed space-separated. Lines
  containing anything but digits are rejected with a message instead of
  aborting the session.
  """
  with tf.Session() as sess:
    model, _, _, _, _ = initialize(sess)
    while True:
      sys.stdout.write("> ")
      sys.stdout.flush()
      line = sys.stdin.readline()
      if not line:  # readline() returns "" only at end of input.
        break
      try:
        ids = [int(c) for c in line.strip()]
      except ValueError:
        sys.stdout.write("Please enter a sequence of digits.\n")
        continue
      # The target is all zeros: only the forward pass is of interest.
      inpt, target = data.get_batch(len(ids), 1, False, "",
                                    preset=(ids, [0 for _ in ids]))
      _, res, _, _ = model.step(sess, inpt, target, False)
      res = [np.argmax(o, axis=1) for o in res]
      sys.stdout.write(" ".join([str(output[0]) for output in res]) + "\n")