From 32f023fe8c0a0327f8f14b1c041536a7c6b1f4ec Mon Sep 17 00:00:00 2001
From: Long Jin
Date: Tue, 8 Aug 2017 10:40:27 -0700
Subject: [PATCH] Add conv layer and layer tests

Reviewed By: xianjiec

Differential Revision: D5569206

fbshipit-source-id: ed836315f3ee4d7983da94f2633a3085fe99194d
---
 caffe2/python/layers/conv.py | 135 +++++++++++++++++++++++++++++++++++
 caffe2/python/layers_test.py |  73 +++++++++++++++++++
 2 files changed, 208 insertions(+)
 create mode 100644 caffe2/python/layers/conv.py

diff --git a/caffe2/python/layers/conv.py b/caffe2/python/layers/conv.py
new file mode 100644
index 00000000000..bb22acf0caf
--- /dev/null
+++ b/caffe2/python/layers/conv.py
@@ -0,0 +1,135 @@
+## @package conv
+# Module caffe2.python.layers.conv
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from caffe2.python import schema
+from caffe2.python.layers.layers import (
+    ModelLayer,
+)
+import numpy as np
+
+
+class Conv(ModelLayer):
+    """
+    Convolutional layer
+    Input:
+    - input_record: must at least carry the shape info of C (num_channels)
+    - output_dim: number of convolutional filters
+    - kernel_h, kernel_w: kernel sizes for h and w
+    - stride_h, stride_w: strides for h and w
+    - pad_b, pad_l, pad_r, pad_t: padding sizes; when stride == 1, a
+      'None' value triggers auto padding
+    - order: either 'NHWC' or 'NCHW'
+    """
+
+    def __init__(self, model, input_record, output_dim, kernel_h, kernel_w,
+                 stride_h, stride_w, pad_b=None, pad_l=None, pad_r=None,
+                 pad_t=None, order='NHWC', kernel_init=None, bias_init=None,
+                 kernel_optim=None, bias_optim=None,
+                 name='conv', **kwargs):
+
+        super(Conv, self).__init__(model, name, input_record, **kwargs)
+        assert isinstance(input_record, schema.Scalar), "Incorrect input type"
+        # input num_channels (C) is needed
+        input_dims = input_record.field_type().shape
+
+        assert (kernel_h > 0 and isinstance(kernel_h, int)), (
+            "kernel_h should be a positive integer")
+        assert (kernel_w > 0 and isinstance(kernel_w, int)), (
+            "kernel_w should be a positive integer")
+        self.kernel_h = kernel_h
+        self.kernel_w = kernel_w
+
+        assert (stride_h > 0 and isinstance(stride_h, int)), (
+            "stride_h should be a positive integer")
+        assert (stride_w > 0 and isinstance(stride_w, int)), (
+            "stride_w should be a positive integer")
+        self.stride_h = stride_h
+        self.stride_w = stride_w
+
+        # output_dim calculation (http://cs231n.github.io/convolutional-networks/)
+        # output_dim_w = (input_dim_w - kernel_w + pad_r + pad_l) / stride_w + 1
+        # so auto padding (output size == input size) requires
+        # pad_r = pad_l = [(input_dim_w - 1) * stride_w - input_dim_w + kernel_w] / 2
+        # and similarly for pad_t and pad_b with kernel_h
+        # here we only do auto padding for the stride == 1 case
+        if stride_h == 1:
+            pad_t = int((kernel_h - 1) / 2) if pad_t is None else pad_t
+            pad_b = int((kernel_h - 1) / 2) if pad_b is None else pad_b
+        else:
+            pad_t = 0 if pad_t is None else pad_t
+            pad_b = 0 if pad_b is None else pad_b
+
+        if stride_w == 1:
+            pad_r = int((kernel_w - 1) / 2) if pad_r is None else pad_r
+            pad_l = int((kernel_w - 1) / 2) if pad_l is None else pad_l
+        else:
+            pad_r = 0 if pad_r is None else pad_r
+            pad_l = 0 if pad_l is None else pad_l
+
+        assert (pad_t >= 0 and isinstance(pad_t, int)), "pad_t should be an int >= 0"
+        assert (pad_b >= 0 and isinstance(pad_b, int)), "pad_b should be an int >= 0"
+        assert (pad_r >= 0 and isinstance(pad_r, int)), "pad_r should be an int >= 0"
int >= 0" + self.pad_t = pad_t + self.pad_b = pad_b + self.pad_r = pad_r + self.pad_l = pad_l + + assert order in ['NHWC', 'NCHW'], "order should either 'NHWC' or 'NCHW'" + self.order = order + + if order == 'NHWC': + input_c = input_dims[-1] + kernel_shape = [output_dim, kernel_h, kernel_w, input_c] + elif order == 'NCHW': + input_c = input_dims[0] + kernel_shape = [output_dim, input_c, kernel_h, kernel_w] + assert input_c > 0, ( + "Number of input channels in conv parameters should be positive") + + kernel_init = kernel_init if kernel_init else ( + 'XavierFill', {} + ) + bias_init = bias_init if bias_init else ( + 'ConstantFill', {'value': 0.0} + ) + + self.kernel = self.create_param( + param_name='conv_kernel', + shape=kernel_shape, + initializer=kernel_init, + optimizer=kernel_optim, + ) + + self.bias = self.create_param( + param_name='conv_bias', + shape=[output_dim], + initializer=bias_init, + optimizer=bias_optim, + ) + + # the output_schema only has the num of output channels + # output_h and output_w would be inferred internally + self.output_schema = schema.Scalar( + (np.float32, (output_dim,)), + self.get_next_blob_reference('output') + ) + + def add_ops(self, net): + net.Conv( + self.input_record.field_blobs() + [self.kernel, self.bias], + self.output_schema.field_blobs(), + kernel_h=self.kernel_h, + kernel_w=self.kernel_w, + stride_h=self.stride_h, + stride_w=self.stride_w, + pad_t=self.pad_t, + pad_l=self.pad_l, + pad_b=self.pad_b, + pad_r=self.pad_r, + order=self.order + ) diff --git a/caffe2/python/layers_test.py b/caffe2/python/layers_test.py index 78854054e71..7ac1853cac2 100644 --- a/caffe2/python/layers_test.py +++ b/caffe2/python/layers_test.py @@ -1178,3 +1178,76 @@ def _semi_random_hypothesis_test(srf_output, X_full, X_random, rand_w, self._test_net(predict_net, ops_list) _semi_random_hypothesis_test(srf_output.full(), X_full, X_random, rand_w, rand_b, s) + + def testConv(self): + batch_size = 50 + H = 1 + W = 10 + C = 50 + output_dims = 32 + kernel_h = 1 + kernel_w = 3 + stride_h = 1 + stride_w = 1 + pad_t = 0 + pad_b = 0 + pad_r = None + pad_l = None + + input_record = self.new_record(schema.Scalar((np.float32, (H, W, C)))) + X = np.random.random((batch_size, H, W, C)).astype(np.float32) + schema.FeedRecord(input_record, [X]) + conv = self.model.Conv( + input_record, + output_dims, + kernel_h=kernel_h, + kernel_w=kernel_w, + stride_h=stride_h, + stride_w=stride_w, + pad_t=pad_t, + pad_b=pad_b, + pad_r=pad_r, + pad_l=pad_l, + order='NHWC' + ) + + self.assertEqual( + schema.Scalar((np.float32, (output_dims,))), + conv + ) + + self.run_train_net_forward_only() + output_record = schema.FetchRecord(conv) + # check the number of output channels is the same as input in this example + assert output_record.field_types()[0].shape == (H, W, output_dims) + assert output_record().shape == (batch_size, H, W, output_dims) + + train_init_net, train_net = self.get_training_nets() + # Init net assertions + init_ops = self.assertNetContainOps( + train_init_net, + [ + OpSpec("XavierFill", None, None), + OpSpec("ConstantFill", None, None), + ] + ) + conv_spec = OpSpec( + "Conv", + [ + input_record.field_blobs()[0], + init_ops[0].output[0], + init_ops[1].output[0], + ], + conv.field_blobs() + ) + + # Train net assertions + self.assertNetContainOps(train_net, [conv_spec]) + + # Predict net assertions + predict_net = self.get_predict_net() + self.assertNetContainOps(predict_net, [conv_spec]) + + # Eval net assertions + eval_net = self.get_eval_net() + self.assertNetContainOps(eval_net, 
+        self.assertNetContainOps(eval_net, [conv_spec])
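
The auto-padding arithmetic in Conv.__init__ can be checked by hand. A minimal
standalone sketch (the helper names auto_pad and output_dim_w are hypothetical,
not part of the patch), using the dimensions from testConv:

    def auto_pad(kernel):
        # stride == 1 auto padding from Conv.__init__: pad = (kernel - 1) / 2
        return int((kernel - 1) / 2)

    def output_dim_w(input_dim_w, kernel_w, stride_w, pad_r, pad_l):
        # output size formula cited in the layer comments (cs231n)
        return (input_dim_w - kernel_w + pad_r + pad_l) // stride_w + 1

    # testConv dimensions: W = 10, kernel_w = 3, stride_w = 1
    pad = auto_pad(3)                               # -> 1 on each side
    assert output_dim_w(10, 3, 1, pad, pad) == 10   # width is preserved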
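
The same shape behavior can also be exercised at the operator level, outside
the layer-model test harness, since add_ops just emits a plain Conv op. A
minimal sketch, assuming a local Caffe2 build (the blob names X/W/b/Y are
arbitrary):

    import numpy as np
    from caffe2.python import core, workspace

    # NHWC input plus an NHWC kernel of shape (output_dim, kernel_h, kernel_w, C),
    # matching the kernel_shape computed in Conv.__init__
    workspace.FeedBlob("X", np.random.rand(50, 1, 10, 50).astype(np.float32))
    workspace.FeedBlob("W", np.random.rand(32, 1, 3, 50).astype(np.float32))
    workspace.FeedBlob("b", np.zeros(32, dtype=np.float32))

    op = core.CreateOperator(
        "Conv", ["X", "W", "b"], ["Y"],
        kernel_h=1, kernel_w=3, stride_h=1, stride_w=1,
        pad_t=0, pad_b=0, pad_l=1, pad_r=1,  # auto padding for kernel_w == 3
        order="NHWC",
    )
    workspace.RunOperatorOnce(op)
    assert workspace.FetchBlob("Y").shape == (50, 1, 10, 32)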