[Caffe Frontend] introduce caffe frontend for tvm (#6206)
author FernChen <fernchen@qq.com>
Thu, 27 Aug 2020 03:26:58 +0000 (11:26 +0800)
committer GitHub <noreply@github.com>
Thu, 27 Aug 2020 03:26:58 +0000 (11:26 +0800)
* [Caffe Frontend] introduce caffe frontend for tvm.

* [Caffe Frontend] fix bugs in generating captions in the tutorial.

* [Caffe Frontend] delete Python 2 statements and modify the function name.

* [Caffe Frontend] change the directory which will hold the tmp files
when testing the caffe frontend.

* [Caffe Frontend] delete tutorial about caffe frontend.

* [Caffe Frontend] delete some print statements

Co-authored-by: fernchen <zifeng.cf@alibaba-inc.com>
python/tvm/relay/frontend/__init__.py
python/tvm/relay/frontend/caffe.py [new file with mode: 0644]
tests/python/frontend/caffe/test_forward.py [new file with mode: 0644]
tests/scripts/task_python_frontend_cpu.sh

index aba9eea..7154f5a 100644 (file)
@@ -33,3 +33,4 @@ from .caffe2 import from_caffe2
 from .tensorflow import from_tensorflow
 from .darknet import from_darknet
 from .pytorch import from_pytorch
+from .caffe import from_caffe
diff --git a/python/tvm/relay/frontend/caffe.py b/python/tvm/relay/frontend/caffe.py
new file mode 100644 (file)
index 0000000..b7bcbde
--- /dev/null
@@ -0,0 +1,848 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=invalid-name, unused-argument, too-many-lines, import-outside-toplevel
+# pylint: disable=no-else-return, no-else-continue
+"""Caffe frontend."""
+import numpy as np
+import tvm
+from tvm.ir import IRModule
+from .. import analysis
+from .. import expr as _expr
+from .. import function as _function
+from .. import op as _op
+from ... import nd as _nd
+from .common import ExprTable
+from .common import infer_shape as _infer_shape
+
+__all__ = ['from_caffe']
+
+
+class OperatorConverter(object):
+    """ Operator Converted for converting Caffe ops to Relay ops """
+    def __init__(self, init_layer_dict, predict_layer, exp_tab):
+        self.init_layer_dict = init_layer_dict
+        self.predict_layer = predict_layer
+        self.exp_tab = exp_tab
+        self.new_bn = {}
+        self.changed_layers = None
+
+        self.convert_map = {
+            'BatchNorm': self.convert_batch_norm,
+            'Concat': self.convert_concat,
+            'Convolution': self.convert_conv,
+            'Crop': self.convert_crop,
+            'Deconvolution': self.convert_deconv,
+            'Dropout': self.convert_dropout,
+            'Eltwise': self.convert_eltwise,
+            'Flatten': self.convert_flatten,
+            'InnerProduct': self.convert_innerproduct,
+            'Input': None,
+            'LRN': self.convert_lrn,
+            'Pooling': self.convert_pooling,
+            'PReLU': self.convert_prelu,
+            'ReLU': self.convert_relu,
+            'Reshape': self.convert_reshape,
+            'Scale': self.convert_scale,
+            'Sigmoid': self.convert_sigmoid,
+            'Slice': self.convert_slice,
+            'Softmax': self.convert_softmax,
+            'TanH': self.convert_tanh,
+        }
+
+    def convert_flatten(self, op):
+        """ Convert Flatten layer """
+        inputs = op.bottom
+        in_expr = self.exp_tab.get_expr(inputs[0])
+
+        flatten_params = op.flatten_param.axis
+        assert flatten_params == 1, "flatten axis should be 1"
+        out = _op.nn.batch_flatten(in_expr)
+
+        return out
+
+    def convert_eltwise(self, op):
+        """ Convert Eltwise layer """
+        inputs = op.bottom
+        assert len(inputs) == 2, "Eltwise layer expects exactly two input tensors"
+
+        lhs_expr = self.exp_tab.get_expr(inputs[0])
+        rhs_expr = self.exp_tab.get_expr(inputs[1])
+
+        lhs_shape = _infer_shape(lhs_expr)
+        rhs_shape = _infer_shape(rhs_expr)
+
+        assert lhs_shape == rhs_shape, "input tensor shapes should be equal"
+
+        eltwise_params = op.eltwise_param
+        eltwise_type_dict = ['PROD', 'SUM', 'MAX']
+        eltwise_type = eltwise_params.operation
+        coeff = list(eltwise_params.coeff)
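+        # for SUM with coeff, the output is coeff[0] * lhs + coeff[1] * rhs;
+        # without coeff it is a plain lhs + rhs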
+
+        if eltwise_type_dict[eltwise_type] == 'PROD':
+            out = _op.multiply(lhs_expr, rhs_expr)
+        elif eltwise_type_dict[eltwise_type] == 'SUM':
+            if coeff:
+                left_coeff_expr = self.exp_tab.new_const(
+                    np.asarray(coeff[0], np.float32))
+                right_coeff_expr = self.exp_tab.new_const(
+                    np.asarray(coeff[1], np.float32))
+                lhs_expr_scale = _op.multiply(lhs_expr, left_coeff_expr)
+                rhs_expr_scale = _op.multiply(rhs_expr, right_coeff_expr)
+                out = _op.add(lhs_expr_scale, rhs_expr_scale)
+            else:
+                out = _op.add(lhs_expr, rhs_expr)
+        elif eltwise_type_dict[eltwise_type] == 'MAX':
+            out = _op.maximum(lhs_expr, rhs_expr)
+        else:
+            raise tvm.error.OpNotImplemented(
+                "eltwise_type {} is not supported for frontend Caffe.".format(
+                    eltwise_type))
+
+        return out
+
+    def _parse_conv_params(self, op):
+        """ Parse the parameters of Convolution and Deconvolution layer """
+        nonzone = lambda val, pos, dflt: val[pos] if pos < len(val) else dflt
+
+        conv_params = op.convolution_param
+
+        params = dict()
+        # parse kernel size
+        if conv_params.kernel_h > 0 or conv_params.kernel_w > 0:
+            params['kernel_size'] = (conv_params.kernel_h,
+                                     conv_params.kernel_w)
+        else:
+            ksize_h = nonzone(conv_params.kernel_size, 0, 1)
+            ksize_w = nonzone(conv_params.kernel_size, 1, ksize_h)
+            params['kernel_size'] = (ksize_h, ksize_w)
+
+        # parse padding size
+        if conv_params.pad_h > 0 or conv_params.pad_w > 0:
+            params['padding'] = (conv_params.pad_h, conv_params.pad_w)
+        else:
+            pad_h = nonzone(conv_params.pad, 0, 0)
+            pad_w = nonzone(conv_params.pad, 1, pad_h)
+            params['padding'] = (pad_h, pad_w)
+
+        # parse stride size
+        if conv_params.stride_h > 0 or conv_params.stride_w > 0:
+            params['strides'] = (conv_params.stride_h, conv_params.stride_w)
+        else:
+            stride_h = nonzone(conv_params.stride, 0, 1)
+            stride_w = nonzone(conv_params.stride, 1, stride_h)
+            params['strides'] = (stride_h, stride_w)
+
+        # parse dilation size
+        if hasattr(conv_params, 'dilation') and len(conv_params.dilation) > 0:
+            dilation = tuple(map(int, conv_params.dilation))
+            if len(dilation) == 1:
+                dilation = (dilation[0], dilation[0])
+            params['dilation'] = dilation
+
+        params['kernel_layout'] = 'OIHW'
+        params['data_layout'] = 'NCHW'
+        params['groups'] = conv_params.group
+        params['channels'] = conv_params.num_output
+        return params
+
+    def convert_batch_norm(self, op):
+        """ Convert BatchNorm layer """
+        inputs = op.bottom
+        in_expr = self.exp_tab.get_expr(inputs[0])
+        n, c, h, w = _infer_shape(in_expr)
+
+        if op.name in self.new_bn:
+            mean, var, eps, gamma, beta = self.new_bn[op.name]
+            mean_expr = self.exp_tab.new_const(mean, dtype='float32')
+            var_expr = self.exp_tab.new_const(var, dtype='float32')
+            gamma_expr = self.exp_tab.new_const(gamma, dtype='float32')
+            beta_expr = self.exp_tab.new_const(beta, dtype='float32')
+            out = _op.nn.batch_norm(in_expr,
+                                    gamma_expr,
+                                    beta_expr,
+                                    mean_expr,
+                                    var_expr,
+                                    epsilon=eps,
+                                    scale=True)
+
+        else:
+            weight_bias_blobs = self.init_layer_dict[op.name].blobs
+            mean = np.asarray(weight_bias_blobs[0].data, np.float32)
+            var = np.asarray(weight_bias_blobs[1].data, np.float32)
+            if len(weight_bias_blobs) == 2:
+                mean = np.repeat(mean, h * w).reshape((c, h, w))
+                mean = np.expand_dims(mean, 0).repeat(n, axis=0)
+                mean_expr = self.exp_tab.new_const(mean, dtype='float32')
+
+                var = np.repeat(var, h * w).reshape((c, h, w))
+                var = np.expand_dims(var, 0).repeat(n, axis=0)
+                var_expr = self.exp_tab.new_const(var, dtype='float32')
+
+                tmp_out = _op.multiply(in_expr, mean_expr)
+                out = _op.add(tmp_out, var_expr)
+
+                return out
+            else:
+                scale = np.asarray(weight_bias_blobs[2].data, np.float32)
+                if scale:
+                    scale = 1 / scale
+            mean_expr = self.exp_tab.new_const(mean * scale, dtype='float32')
+            var_expr = self.exp_tab.new_const(var * scale, dtype='float32')
+
+            # the Caffe BatchNorm layer has no learnable scale/shift,
+            # so use identity gamma and beta
+            gamma_expr = self.exp_tab.new_const(np.ones(mean.shape,
+                                                        dtype=np.float32),
+                                                dtype='float32')
+            beta_expr = self.exp_tab.new_const(np.zeros(mean.shape,
+                                                        dtype=np.float32),
+                                               dtype='float32')
+
+            bn_params = op.batch_norm_param.eps
+            out = _op.nn.batch_norm(in_expr,
+                                    gamma_expr,
+                                    beta_expr,
+                                    mean_expr,
+                                    var_expr,
+                                    epsilon=bn_params,
+                                    scale=False)
+
+        return out[0]
+
+    def convert_scale(self, op):
+        """ Convert Scale layer """
+        inputs = op.bottom
+        in_expr = self.exp_tab.get_expr(inputs[0])
+        weight_bias_blobs = self.init_layer_dict[op.name].blobs
+
+        params = dict()
+        params['bias'] = op.scale_param.bias_term
+        params['axis'] = op.scale_param.axis
+
+        gamma = np.asarray(weight_bias_blobs[0].data, np.float32)
+        gamma_expr = self.exp_tab.new_const(gamma, dtype='float32')
+        if params['bias']:
+            beta = np.asarray(weight_bias_blobs[1].data, np.float32)
+            beta_expr = self.exp_tab.new_const(beta, dtype='float32')
+        else:
+            beta_expr = self.exp_tab.new_const(np.zeros(gamma.shape,
+                                                        dtype=np.float32),
+                                               dtype='float32')
+
+        _, c, _, _ = _infer_shape(in_expr)
+        gamma_expr = _op.reshape(gamma_expr, newshape=(1, c, 1, 1))
+        beta_expr = _op.reshape(beta_expr, newshape=(1, c, 1, 1))
+        out = _op.multiply(in_expr, gamma_expr)
+        out = _op.add(out, beta_expr)
+
+        return out
+
+    def convert_concat(self, op):
+        """ Convert Concat layer """
+        inputs = op.bottom
+        in_expr = [self.exp_tab.get_expr(inputs[i])
+                   for i in range(len(inputs))]
+
+        c_params = dict()
+        c_params['axis'] = op.concat_param.axis
+        out = _op.concatenate(in_expr, axis=c_params['axis'])
+
+        return out
+
+    def convert_reshape(self, op):
+        """ Convert Reshape layer """
+        inputs = op.bottom
+        input_name = inputs[0]
+
+        reshape_param = op.reshape_param
+        dims = list(reshape_param.shape.dim)
+
+        in_expr = self.exp_tab.get_expr(input_name)
+        input_shape = list(_infer_shape(in_expr))
+
+        start_axis = int(reshape_param.axis)
+        if start_axis < 0:
+            start_axis = len(input_shape) + start_axis + 1
+        num_axes = int(reshape_param.num_axes)
+        end_axis = len(input_shape)
+        if num_axes != -1:
+            end_axis = start_axis + num_axes
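+        # e.g. for an input of shape (1, 8, 6) with shape.dim = [2, 4],
+        # axis = 1 and num_axes = 1, only the middle axis is reshaped,
+        # yielding newshape = [1, 2, 4, 6]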
+
+        left_shape = input_shape[:start_axis]
+        if end_axis == len(input_shape):
+            center_shape = input_shape[start_axis:]
+            right_shape = []
+        else:
+            center_shape = input_shape[start_axis:end_axis]
+            right_shape = input_shape[end_axis:]
+
+        for idx, dim in enumerate(dims):
+            if dim == 0:
+                dims[idx] = center_shape[idx]
+
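+        # resolve any -1 entries by reshaping a dummy array of the
+        # original center shape and reading back the resulting shape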
+        tmp = np.random.rand(*center_shape)
+        tmp = np.reshape(tmp, dims)
+        center_shape = list(tmp.shape)
+
+        newshape = left_shape + center_shape + right_shape
+
+        out = _op.reshape(in_expr, newshape=newshape)
+        return out
+
+    def convert_softmax(self, op):
+        """ Convert Softmax layer """
+        inputs = op.bottom
+        assert len(inputs) == 1, "Softmax layer expects exactly one input tensor"
+
+        input_name = inputs[0]
+        in_expr = self.exp_tab.get_expr(input_name)
+
+        softmax_param = op.softmax_param
+        params = {'axis': softmax_param.axis}
+
+        out = _op.nn.softmax(in_expr, **params)
+
+        return out
+
+    def convert_conv(self, op):
+        """ Convert Convolution layer """
+        params = self._parse_conv_params(op)
+        weight_bias_blobs = self.init_layer_dict[op.name].blobs
+        conv_params = op.convolution_param
+        inputs = op.bottom
+        # process weight and bias blobs
+        weight, bias = None, None
+        if len(weight_bias_blobs) > 1:
+            weight = weight_bias_blobs[0]
+            bias = weight_bias_blobs[1]
+        else:
+            weight = weight_bias_blobs[0]
+        if weight:
+            kh, kw = params['kernel_size']
+            weight_shape = [conv_params.num_output, -1, kh, kw]
+            weight_value = np.asarray(weight.data, np.float32)
+            weight_value = np.reshape(weight_value, weight_shape)
+        else:
+            raise Exception('No weight value for layer {} in the caffemodel'.format(
+                op.name))
+
+        weight_expr = self.exp_tab.new_const(weight_value, dtype='float32')
+        in_expr = self.exp_tab.get_expr(inputs[0])
+        out = _op.nn.conv2d(data=in_expr, weight=weight_expr, **params)
+        if bias:
+            bias_value = np.asarray(bias.data, np.float32)
+            bias_expr = self.exp_tab.new_const(bias_value, dtype='float32')
+            out = _op.nn.bias_add(out, bias_expr)
+        return out
+
+    def convert_pooling(self, op):
+        """ Convert Pooling layer """
+        inputs = op.bottom
+        input_name = inputs[0]
+
+        pool_params = op.pooling_param
+        pool_type_dict = ['MAX', 'AVE', 'STOCHASTIC']
+
+        params = dict()
+        # parse pool type: 0: MAX, 1: AVE, 2: STOCHASTIC
+        pool_type = pool_params.pool
+        # parse kernel size
+        if pool_params.kernel_h > 0 or pool_params.kernel_w > 0:
+            params['pool_size'] = (pool_params.kernel_h, pool_params.kernel_w)
+        else:
+            params['pool_size'] = (pool_params.kernel_size,
+                                   pool_params.kernel_size)
+
+        # parse padding size
+        if pool_params.pad_h > 0 or pool_params.pad_w > 0:
+            params['padding'] = (pool_params.pad_h, pool_params.pad_w)
+        else:
+            params['padding'] = (pool_params.pad, pool_params.pad)
+
+        # parse stride size
+        if pool_params.stride_h > 0 or pool_params.stride_w > 0:
+            params['strides'] = (pool_params.stride_h, pool_params.stride_w)
+        else:
+            params['strides'] = (pool_params.stride, pool_params.stride)
+
+        params['ceil_mode'] = True
+        if hasattr(pool_params, 'ceil_mode'):
+            params['ceil_mode'] = pool_params.ceil_mode
+
+        in_expr = self.exp_tab.get_expr(input_name)
+
+        if pool_type_dict[pool_type] == 'MAX':
+            if pool_params.global_pooling:
+                out = _op.nn.global_max_pool2d(in_expr)
+            else:
+                if len(op.top) == 1:
+                    out = _op.nn.max_pool2d(in_expr, **params)
+                elif len(op.top) == 2:
+                    out1 = _op.nn.max_pool2d_with_argmax(in_expr, **params)
+                    out2 = _op.vision.max_pool2d_location(in_expr, **params)
+                    return _expr.Tuple((out1, out2))
+
+        elif pool_type_dict[pool_type] == 'AVE':
+            if pool_params.global_pooling:
+                out = _op.nn.global_avg_pool2d(in_expr)
+            else:
+                params['count_include_pad'] = True
+                out = _op.nn.avg_pool2d(in_expr, **params)
+        else:
+            raise tvm.error.OpNotImplemented(
+                "Operator {} is not supported for frontend Caffe.".format(
+                    pool_type_dict[pool_type] + ' pool'))
+
+        return out
+
+    def convert_lrn(self, op):
+        """ Convert LRN layer """
+        inputs = op.bottom
+        input_name = inputs[0]
+
+        params = dict()
+        lrn_params = op.lrn_param
+        params['size'] = lrn_params.local_size
+        params['bias'] = lrn_params.k
+        params['alpha'] = lrn_params.alpha
+        params['beta'] = lrn_params.beta
+
+        in_expr = self.exp_tab.get_expr(input_name)
+        out = _op.nn.lrn(in_expr, **params)
+        return out
+
+    def convert_innerproduct(self, op):
+        """ Convert InnerProduct layer """
+        inputs = op.bottom
+        weight_bias_blobs = self.init_layer_dict[op.name].blobs
+        dense_params = op.inner_product_param
+
+        params = dict()
+        params["num_output"] = dense_params.num_output
+        params["bias"] = dense_params.bias_term
+        params["axis"] = dense_params.axis
+        if params["axis"] != 1:
+            raise Exception("Only support 2D InnerProduct")
+
+        # process weight and bias blobs
+        weight, bias = None, None
+        if params["bias"]:
+            weight = weight_bias_blobs[0]
+            bias = weight_bias_blobs[1]
+        else:
+            weight = weight_bias_blobs[0]
+
+        if weight:
+            weight_value = np.asarray(weight.data, np.float32)
+            weight_value = np.reshape(weight_value, (params["num_output"], -1))
+            weight_shape = weight_value.shape
+        else:
+            raise Exception('No weight value for layer {} in the caffemodel'.format(
+                op.name))
+
+        weight_expr = self.exp_tab.new_const(weight_value, dtype='float32')
+
+        in_expr = self.exp_tab.get_expr(inputs[0])
+        in_reshape = _op.reshape(data=in_expr, newshape=(-1, weight_shape[-1]))
+
+        out = _op.nn.dense(data=in_reshape, weight=weight_expr)
+
+        if bias:
+            bias_value = np.asarray(bias.data, np.float32)
+            bias_expr = self.exp_tab.new_const(bias_value, dtype='float32')
+            out = _op.nn.bias_add(out, bias_expr, axis=params["axis"])
+        return out
+
+    def convert_dropout(self, op):
+        """ Convert Dropout layer """
+        inputs = op.bottom
+        input_name = inputs[0]
+
+        params = dict()
+        dropout_params = op.dropout_param
+
+        params['rate'] = dropout_params.dropout_ratio
+
+        in_expr = self.exp_tab.get_expr(input_name)
+        out = _op.nn.dropout(in_expr, **params)
+        return out
+
+    def convert_relu(self, op):
+        """ Convert ReLU layer """
+        inputs = op.bottom
+        in_expr = self.exp_tab.get_expr(inputs[0])
+        negative_slope = op.relu_param.negative_slope
+        if negative_slope:
+            out = _op.nn.leaky_relu(in_expr, negative_slope)
+            return out
+
+        out = _op.nn.relu(in_expr)
+        return out
+
+    def convert_prelu(self, op):
+        """ Convert PReLU layer """
+        inputs = op.bottom
+        in_expr = self.exp_tab.get_expr(inputs[0])
+
+        alpha = self.init_layer_dict[op.name].blobs[0].data
+        alpha = np.asarray(alpha, np.float32)
+        alpha = self.exp_tab.new_const(alpha, dtype='float32')
+        axis = 1
+        out = _op.nn.prelu(in_expr, alpha, axis=axis)
+        return out
+
+    def convert_deconv(self, op):
+        """ Convert Deconvolution layer """
+        params = self._parse_conv_params(op)
+        weight_bias_blobs = self.init_layer_dict[op.name].blobs
+        conv_params = op.convolution_param
+        inputs = op.bottom
+
+        # process weight and bias blobs
+        weight, bias = None, None
+        if len(weight_bias_blobs) > 1:
+            weight = weight_bias_blobs[0]
+            bias = weight_bias_blobs[1]
+        else:
+            weight = weight_bias_blobs[0]
+        if weight:
+            kh, kw = params['kernel_size']
+            weight_shape = [-1, conv_params.num_output, kh, kw]
+            weight_value = np.asarray(weight.data, np.float32)
+            weight_value = np.reshape(weight_value, weight_shape)
+        else:
+            raise Exception('No weight value for layer {} in the caffemodel'.format(
+                op.name))
+
+        weight_expr = self.exp_tab.new_const(weight_value, dtype='float32')
+        in_expr = self.exp_tab.get_expr(inputs[0])
+        out = _op.nn.conv2d_transpose(data=in_expr,
+                                      weight=weight_expr,
+                                      **params)
+        if bias:
+            bias_value = np.asarray(bias.data, np.float32)
+            bias_expr = self.exp_tab.new_const(bias_value, dtype='float32')
+            out = _op.nn.bias_add(out, bias_expr)
+        return out
+
+    def convert_slice(self, op):
+        """ Convert Slice layer """
+        inputs = op.bottom
+        in_expr = self.exp_tab.get_expr(inputs[0])
+
+        output_num = len(op.top)
+
+        slice_params = op.slice_param
+        axis = int(slice_params.axis)
+        indices_or_sections = [int(s) for s in slice_params.slice_point]
+        if len(indices_or_sections) == 0:
+            indices_or_sections = output_num
+        else:
+            indices_or_sections = sorted(indices_or_sections)
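+        # e.g. slice_point = [1, 6] on an axis of length 10 yields the
+        # sections [0:1], [1:6] and [6:10]; without slice_point the axis
+        # is split into `output_num` equal parts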
+
+        out = _op.split(in_expr,
+                        indices_or_sections=indices_or_sections,
+                        axis=axis)
+        return out
+
+    def convert_sigmoid(self, op):
+        """ Convert Sigmoid layer """
+        inputs = op.bottom
+        in_expr = self.exp_tab.get_expr(inputs[0])
+        out = _op.sigmoid(in_expr)
+        return out
+
+    def convert_tanh(self, op):
+        """ Convert TanH layer """
+        inputs = op.bottom
+        in_expr = self.exp_tab.get_expr(inputs[0])
+        out = _op.tanh(in_expr)
+        return out
+
+    def convert_crop(self, op):
+        """ Convert Crop layer """
+        inputs = op.bottom
+        assert len(inputs) == 2, "Need two inputs of Crop layer"
+        in_expr_a = self.exp_tab.get_expr(inputs[0])
+        in_expr_b = self.exp_tab.get_expr(inputs[1])
+
+        # parse crop params
+        crop_params = op.crop_param
+        axis = int(getattr(crop_params, 'axis', 2))
+        offset = list(getattr(crop_params, 'offset', 0))
+
+        # expand offset to (offset1, offset2, ...)
+        in_a_shape = _infer_shape(in_expr_a)
+        num_to_crop = len(in_a_shape) - axis
+        if not offset:
+            offset = [0] * num_to_crop
+        if len(offset) == 1:
+            offset = offset * num_to_crop
+        elif len(offset) != num_to_crop:
+            raise Exception("No matching the number between axis and offset!")
+
+        slice_end = in_a_shape
+        slice_start = [0] * len(in_a_shape)
+        for i in range(num_to_crop):
+            slice_start[i + axis] = offset[i]
+
+        to_crop_axis = list(range(len(in_a_shape)))
+        to_crop_axis = to_crop_axis[axis:]
+
+        # crop in_expr_a to in_expr_b's shape along the axes to crop
+        in_expr_a_stride = _op.strided_slice(in_expr_a, slice_start, slice_end)
+        out = _op.slice_like(in_expr_a_stride, in_expr_b, axes=to_crop_axis)
+        return out
+
+    def check_unsupported_ops(self):
+        """Check unsupported Caffe ops in our converter."""
+        unsupported_ops_set = set()
+
+        for pl in self.predict_layer:
+            op_name = pl.type
+            if op_name not in self.convert_map:
+                unsupported_ops_set.add(op_name)
+
+        if unsupported_ops_set:
+            msg = 'The following operators are not supported in frontend ' \
+                'Caffe: {}'
+            ops = str(list(unsupported_ops_set)).strip('[,]')
+            raise tvm.error.OpNotImplemented(msg.format(ops))
+
+    def fuse_op(self, layers):
+        """ Fusing the BatchNorm and Scale layer """
+        bn, scale = layers["bn"], layers["scale"]
+
+        # bn params
+        bn_weight_bias_blobs = self.init_layer_dict[bn.name].blobs
+        bn_scale = np.asarray(bn_weight_bias_blobs[2].data, np.float32)
+        if bn_scale:
+            bn_scale = 1 / bn_scale
+        bn_mean = np.asarray(bn_weight_bias_blobs[0].data,
+                             np.float32) * bn_scale
+        bn_var = np.asarray(bn_weight_bias_blobs[1].data,
+                            np.float32) * bn_scale
+        bn_eps = bn.batch_norm_param.eps
+
+        # scale params
+        scale_weight_bias_blobs = self.init_layer_dict[scale.name].blobs
+        scale_gamma = np.asarray(scale_weight_bias_blobs[0].data, np.float32)
+        scale_bias = scale.scale_param.bias_term
+        if scale_bias:
+            scale_beta = np.asarray(scale_weight_bias_blobs[1].data,
+                                    np.float32)
+        else:
+            scale_beta = np.zeros(scale_gamma.shape, dtype=np.float32)
+
+        # new params
+        self.new_bn[bn.name] = [
+            bn_mean, bn_var, bn_eps, scale_gamma, scale_beta
+        ]
+        return bn
+
+    def op_fuse(self):
+        """fuse bn and scale """
+        new_layers = []
+        temp_layers = {}
+        changed_layers = {}
+
+        for index, pl in enumerate(self.predict_layer):
+            op_type = pl.type
+            if op_type == "Input":
+                new_layers.append(pl)
+                continue
+            elif op_type == "BatchNorm":
+                if (index != len(self.predict_layer) - 1) and (
+                        self.predict_layer[index + 1].type == "Scale"):
+                    temp_layers["bn"] = pl
+                    continue
+                else:
+                    new_layers.append(pl)
+                    temp_layers.clear()
+            elif op_type == "Scale":
+                if self.predict_layer[index - 1].type == "BatchNorm":
+                    temp_layers["scale"] = pl
+                else:
+                    new_layers.append(pl)
+                    temp_layers.clear()
+            else:
+                temp_layers.clear()
+
+            if len(temp_layers) == 2:
+                layer = self.fuse_op(temp_layers)
+                new_layers.append(layer)
+                changed_layers[
+                    temp_layers["scale"].name] = temp_layers['bn'].name
+
+            for idx, plt in enumerate(pl.bottom):
+                if plt in changed_layers:
+                    pl.bottom[idx] = changed_layers[plt]
+
+            if op_type not in ['BatchNorm', 'Scale']:
+                new_layers.append(pl)
+
+        self.predict_layer = new_layers
+        self.changed_layers = changed_layers
+
+    def convert_op_to_relay(self):
+        """Convert Caffe ops to relay ops"""
+        for pl in self.predict_layer:
+            op_type = pl.type
+            if op_type == "Input":
+                continue
+            output_tensors = pl.top
+
+            ret = self.convert_map[op_type](pl)
+
+            if len(output_tensors) == 1:
+                self.exp_tab.set_expr(output_tensors[0], ret)
+            else:
+                for idx, output_tensor in enumerate(output_tensors):
+                    self.exp_tab.set_expr(output_tensor, ret[idx])
+
+
+def _rebuild_layers(predict_layer):
+    """Rebuild caffe layer. If the the caffe net include in-place layers, repalce its top
+    with its name and update the bottom of other layer that is related to it.
+    """
+    # dict of input name that will be changed to new name
+    changed_top_dict = dict()
+
+    for pl in predict_layer:
+        if pl.type == "Input":
+            continue
+        # a layer whose single input name equals its single output name
+        # operates "in-place"
+        if (len(pl.top) == 1 and len(pl.bottom) == 1):
+            if pl.top[0] == pl.bottom[0]:
+                # change current layer's input firstly
+                if pl.bottom[0] in changed_top_dict:
+                    pl.bottom[0] = changed_top_dict[pl.bottom[0]]
+                # update "change" dict
+                changed_top_dict[pl.top[0]] = pl.name
+                # change current layer's output to its name
+                pl.top[0] = pl.name
+            else:
+                if pl.bottom[0] in changed_top_dict:
+                    pl.bottom[0] = changed_top_dict[pl.bottom[0]]
+        # otherwise, only remap any bottoms that were renamed above
+        else:
+            for index, plt in enumerate(pl.bottom):
+                if plt in changed_top_dict:
+                    pl.bottom[index] = changed_top_dict[plt]
+
+
+def _get_inputs_outputs(predict_layer):
+    """Obtain Caffe model's inputs and outpus"""
+    # model inputs / outputs
+    model_inputs = list()
+    model_outputs = list()
+
+    # a tensor consumed as a bottom by any layer cannot be a model output
+    not_outputs = set()
+    for pl in predict_layer:
+        if pl.type == "Input":
+            assert len(pl.top) == 1, \
+                "An Input layer should have exactly one output."
+            model_inputs.append(pl.top[0])
+        for i in pl.bottom:
+            not_outputs.add(i)
+
+    for pl in predict_layer:
+        if len(pl.bottom) > 0:
+            for t in pl.top:
+                if t not in not_outputs:
+                    model_outputs.append(t)
+    return model_inputs, model_outputs
+
+
+def from_caffe(init_net, predict_net, shape_dict, dtype_dict):
+    """Convert from caffe model into compatible relay Function.
+
+    Parameters
+    ----------
+    init_net : caffe_pb2.NetParameter
+        The caffemodel holding the trained weights.
+    predict_net : caffe_pb2.NetParameter
+        The prototxt describing the network structure.
+    shape_dict : dict of str to int list/tuple
+        Input shapes of the model.
+    dtype_dict : dict of str to str
+        Input types of the model.
+
+    Returns
+    -------
+    mod : tvm.IRModule
+        The relay module for compilation.
+
+    params : dict of str to tvm.NDArray
+        The parameter dict to be used by relay
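+
+    Examples
+    --------
+    A minimal sketch, assuming ``caffe.proto.caffe_pb2`` (as ``pb``) and
+    ``google.protobuf.text_format`` are available; the file names are
+    placeholders::
+
+        init_net = pb.NetParameter()
+        predict_net = pb.NetParameter()
+        with open("deploy.prototxt", "r") as f:
+            text_format.Merge(f.read(), predict_net)
+        with open("net.caffemodel", "rb") as f:
+            init_net.ParseFromString(f.read())
+        mod, params = from_caffe(init_net, predict_net,
+                                 {"data": (1, 3, 224, 224)},
+                                 {"data": "float32"})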
+    """
+    old_caffe = False
+    if len(predict_net.input) != 0:  # old caffe version
+        old_caffe = True
+        model_inputs = list(predict_net.input)
+
+    predict_layer = predict_net.layer
+
+    # replace each in-place layer's top with its name and update other layers' bottoms
+    _rebuild_layers(predict_layer)
+    # obtain inputs and outputs of Net
+    if old_caffe:
+        _, model_outputs = _get_inputs_outputs(predict_layer)
+    else:
+        model_inputs, model_outputs = _get_inputs_outputs(predict_layer)
+
+    exp_tab = ExprTable()
+    for in_name in model_inputs:
+        shape = shape_dict.get(in_name)
+        dtype = dtype_dict.get(in_name, "float32")
+        exp_tab.set_expr(in_name, _expr.var(in_name, shape=shape, dtype=dtype))
+    if list(init_net.layer):
+        init_layer = init_net.layer
+    else:
+        init_layer = init_net.layers
+    init_layer_dict = {il.name: il for il in init_layer}
+    # op code in model
+    op_converter = OperatorConverter(init_layer_dict, predict_layer, exp_tab)
+    op_converter.check_unsupported_ops()
+    op_converter.op_fuse()
+    op_converter.convert_op_to_relay()
+
+    # params and outputs
+    params = {k: _nd.array(np.array(v)) for k, v in exp_tab.params.items()}
+    outputs = list()
+    for n in model_outputs:
+        if n in op_converter.changed_layers:
+            n = op_converter.changed_layers[n]
+        outputs.append(exp_tab.get_expr(n))
+    outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
+    func = _function.Function(analysis.free_vars(outputs), outputs)
+    mod = IRModule.from_expr(func)
+
+    return mod, params
diff --git a/tests/python/frontend/caffe/test_forward.py b/tests/python/frontend/caffe/test_forward.py
new file mode 100644 (file)
index 0000000..8567e4b
--- /dev/null
@@ -0,0 +1,968 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=import-self, invalid-name, unused-argument
+"""
+Caffe testcases
+====================
+This script tests Caffe operators with Relay.
+"""
+import os
+os.environ['GLOG_minloglevel'] = '2'
+import sys
+import logging
+logging.basicConfig(level=logging.ERROR)
+
+import numpy as np
+from google.protobuf import text_format
+import caffe
+from caffe import layers as L, params as P
+from caffe.proto import caffe_pb2 as pb
+
+import tvm
+from tvm import relay
+from tvm.contrib import util, graph_runtime
+from tvm.contrib.download import download_testdata
+
+CURRENT_DIR = os.path.join(os.path.expanduser('~'), '.tvm_test_data', 'caffe_test')
+
+#######################################################################
+# Generic functions for TVM & Caffe
+# ------------------------------------------
+
+
+def _create_dir(d_path):
+    """ If the directory is not existed, create it"""
+    if not (os.path.exists(d_path) and os.path.isdir(d_path)):
+        os.makedirs(d_path)
+
+
+def _list_to_str(ll):
+    """ Convert list or tuple to str, separated by underline. """
+    if isinstance(ll, (tuple, list)):
+        tmp = [str(i) for i in ll]
+        return '_'.join(tmp)
+
+
+def _gen_filename_str(op_name, data_shape, *args, **kwargs):
+    """ Combining the filename according to the op_name, shape and other args. """
+    file_dir = os.path.join(CURRENT_DIR, op_name)
+    _create_dir(file_dir)
+    res = op_name + "_"
+    shape_str = _list_to_str(list(data_shape))
+    res += shape_str
+    for arg in args:
+        if isinstance(arg, (tuple, list)):
+            res += ("_" + _list_to_str(arg))
+        elif isinstance(arg, (int, float, str)):
+            res += ("_" + str(arg))
+    for _, v in kwargs.items():
+        if isinstance(v, (tuple, list)):
+            res += ("_" + _list_to_str(v))
+        elif isinstance(v, (int, float, str)):
+            res += ("_" + str(v))
+    res = res.replace(".", "_")
+    res = res.replace("-", "_")
+    proto_file = os.path.join(file_dir, res + ".prototxt")
+    blob_file = os.path.join(file_dir, res + ".caffemodel")
+    solver_file = os.path.join(file_dir, res + "_solver.prototxt")
+
+    return (proto_file, blob_file, solver_file)
+
+
+def _save_prototxt(n_netspec, f_path):
+    """ Generate .prototxt file according to caffe.NetSpec"""
+    s = n_netspec.to_proto()
+    with open(f_path, 'w') as f:
+        f.write(str(s))
+
+
+def _save_solver(solver_file, proto_file, blob_file):
+    """ Define a solver proto, you can change the configs."""
+    blob_file_prefix = blob_file.split(".caffemodel")[0]
+    s = pb.SolverParameter()
+    s.train_net = proto_file
+    s.base_lr = 0.01
+    s.momentum = 0.9
+    s.weight_decay = 0.0005
+    s.lr_policy = "inv"
+    s.gamma = 0.0001
+    s.power = 0.75
+    s.display = 1
+    s.max_iter = 100000
+    s.snapshot = 100000
+    s.snapshot_prefix = blob_file_prefix
+
+    with open(solver_file, 'w') as f:
+        f.write(str(s))
+
+
+def _save_caffemodel(solver_file, blob_file):
+    """ Generate .caffemodel file."""
+    solver = caffe.SGDSolver(solver_file)
+    solver.net.save(blob_file)
+
+
+def _gen_model_files(n_netspec, proto_file, blob_file, solver_file):
+    _save_prototxt(n_netspec, proto_file)
+    _save_solver(solver_file, proto_file, blob_file)
+    _save_caffemodel(solver_file, blob_file)
+
+
+def _siso_op(data, func, *args, **kwargs):
+    """ Create single input and single output Caffe op """
+    n = caffe.NetSpec()
+    n.data = L.Input(input_param={'shape': {'dim': list(data.shape)}})
+    n.output = func(n.data, *args, **kwargs)
+    return n
+
+
+def _miso_op(data_list, func, *args, **kwargs):
+    """ Create multi input and single output Caffe op """
+    n = caffe.NetSpec()
+    if not isinstance(data_list, (tuple, list)):
+        raise TypeError("Need tuple or list but get {}".format(
+            type(data_list)))
+    input_list = list()
+    for idx, data in enumerate(data_list):
+        n['data' + str(idx)] = L.Input(
+            input_param={'shape': {'dim': list(data.shape)}})
+        input_list.append(n['data' + str(idx)])
+    n.output = func(*input_list, *args, **kwargs)
+    return n
+
+
+def _simo_op(data, func, *args, **kwargs):
+    """ Create single input and multi output Caffe op """
+    n = caffe.NetSpec()
+    n.data = L.Input(input_param={'shape': {'dim': list(data.shape)}})
+    output_list = func(n.data, *args, **kwargs)
+    for idx, out in enumerate(output_list):
+        n['output' + str(idx)] = out
+    return n
+
+
+def _run_caffe(data, proto_file, blob_file):
+    """ Run caffe model by Caffe according to .caffemodel and .prototxt"""
+    net = caffe.Net(proto_file, blob_file, caffe.TEST)
+    if isinstance(data, (list, tuple)):
+        for idx, d in enumerate(data):
+            net.blobs['data' + str(idx)].data[...] = d
+    else:
+        net.blobs['data'].data[...] = data
+    out = net.forward()
+
+    caffe_output = list()
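+    # collect outputs in 'output0', 'output1', ... order; if the net's
+    # outputs are not named this way, fall back to Caffe's own ordering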
+    for i in range(len(out.keys())):
+        if 'output'+str(i) not in out.keys():
+            caffe_output.clear()
+            return list(out.values())
+        caffe_output.append(out['output'+str(i)])
+    return caffe_output
+
+
+def _run_tvm(data, proto_file, blob_file):
+    """ Run caffe model by TVM according to .caffemodel and .prototxt"""
+    init_net = pb.NetParameter()
+    predict_net = pb.NetParameter()
+
+    # load model
+    with open(proto_file, 'r') as f:
+        text_format.Merge(f.read(), predict_net)
+    # load blob
+    with open(blob_file, 'rb') as f:
+        init_net.ParseFromString(f.read())
+
+    shape_dict = dict()
+    dtype_dict = dict()
+    if isinstance(data, (tuple, list)):
+        for idx, d in enumerate(data):
+            shape_dict['data' + str(idx)] = d.shape
+            dtype_dict['data' + str(idx)] = 'float32'
+    else:
+        shape_dict = {'data': data.shape}
+        dtype_dict = {'data': 'float32'}
+
+    mod, params = relay.frontend.from_caffe(
+        init_net, predict_net, shape_dict, dtype_dict)
+
+    target = 'llvm'
+    target_host = 'llvm'
+
+    ctx = tvm.cpu(0)
+    with tvm.transform.PassContext(opt_level=3):
+        lib = relay.build(mod,
+                          target=target,
+                          target_host=target_host,
+                          params=params)
+    dtype = 'float32'
+    m = graph_runtime.GraphModule(lib['default'](ctx))
+    if isinstance(data, (tuple, list)):
+        for idx, d in enumerate(data):
+            m.set_input('data' + str(idx), tvm.nd.array(d.astype(dtype)))
+    else:
+        m.set_input('data', tvm.nd.array(data.astype(dtype)))
+    # execute
+    m.run()
+    tvm_output = list()
+    # get outputs
+    for i in range(m.get_num_outputs()):
+        tvm_output.append(m.get_output(i).asnumpy())
+    return tvm_output
+
+
+def _compare_caffe_tvm(caffe_out, tvm_out, is_network=False):
+    for i in range(len(caffe_out)):
+        if is_network:
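+            # for full networks, only the first batch entry is compared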
+            caffe_out[i] = caffe_out[i][:1]
+        tvm.testing.assert_allclose(caffe_out[i],
+                                    tvm_out[i],
+                                    rtol=1e-5,
+                                    atol=1e-5)
+
+
+def _test_op(data, func_op, op_name, **kwargs):
+    """ Single op testing pipline. """
+    shape_list = list()
+    if isinstance(data, (list, tuple)):
+        n = _miso_op(data, func_op, **kwargs)
+        for d in data:
+            shape_list.extend(list(d.shape))
+    else:
+        output_num = 1
+        if 'ntop' in kwargs:
+            output_num = kwargs['ntop']
+        if output_num == 1:
+            n = _siso_op(data, func_op, **kwargs)
+        else:
+            n = _simo_op(data, func_op, **kwargs)
+        shape_list = list(data.shape)
+
+    # obtain the .caffemodel file and .prototxt file
+    (proto_file, blob_file,
+     solver_file) = _gen_filename_str(op_name, shape_list, **kwargs)
+    _gen_model_files(n, proto_file, blob_file, solver_file)
+    # run model in Caffe
+    caffe_out = _run_caffe(data, proto_file, blob_file)
+    # run model in TVM
+    tvm_out = _run_tvm(data, proto_file, blob_file)
+    _compare_caffe_tvm(caffe_out, tvm_out)
+
+
+def _test_network(data, proto_file, blob_file):
+    # run model in Caffe
+    caffe_out = _run_caffe(data, proto_file, blob_file)
+    # run model in TVM
+    tvm_out = _run_tvm(data, proto_file, blob_file)
+    _compare_caffe_tvm(caffe_out, tvm_out, is_network=True)
+
+
+#######################################################################
+# BatchNorm
+# -----------
+
+
+def _test_batchnorm(data, moving_average_fraction=0.999, eps=1e-5):
+    """ One iteration of BatchNorm """
+    _test_op(data,
+             L.BatchNorm,
+             "BatchNorm",
+             moving_average_fraction=moving_average_fraction,
+             eps=eps)
+
+
+def test_forward_BatchNorm():
+    """ BatchNorm """
+    data = np.random.rand(1, 3, 10, 10).astype(np.float32)
+    _test_batchnorm(data)
+    _test_batchnorm(data, moving_average_fraction=0.88, eps=1e-4)
+
+
+#######################################################################
+# Concat
+# -----------
+
+
+def _test_concat(data_list, axis=1):
+    """ One iteration of Concat """
+    _test_op(data_list, L.Concat, "Concat", axis=axis)
+
+
+def test_forward_Concat():
+    """ Concat """
+    _test_concat([np.random.rand(1, 3, 10, 10),
+                  np.random.rand(1, 2, 10, 10)],
+                 axis=1)
+    _test_concat([np.random.rand(3, 10, 10),
+                  np.random.rand(2, 10, 10)],
+                 axis=0)
+    _test_concat([np.random.rand(3, 10), np.random.rand(2, 10)], axis=0)
+
+
+#######################################################################
+# Convolution
+# -----------
+
+
+def _test_convolution(data, **kwargs):
+    """ One iteration of Convolution """
+    _test_op(data, L.Convolution, "Convolution", **kwargs)
+
+
+def test_forward_Convolution():
+    """ Convolution """
+    data = np.random.rand(1, 3, 10, 10).astype(np.float32)
+    _test_convolution(data,
+                      num_output=20,
+                      bias_term=True,
+                      pad=0,
+                      kernel_size=3,
+                      stride=2,
+                      dilation=1,
+                      weight_filler=dict(type="xavier"),
+                      bias_filler=dict(type="xavier"))
+    _test_convolution(data,
+                      num_output=20,
+                      bias_term=False,
+                      pad=[1, 2],
+                      kernel_size=3,
+                      stride=2,
+                      dilation=1,
+                      weight_filler=dict(type="xavier"),
+                      bias_filler=dict(type="xavier"))
+    _test_convolution(data,
+                      num_output=20,
+                      bias_term=True,
+                      pad=[1, 2],
+                      kernel_size=[3, 5],
+                      stride=[2, 1],
+                      dilation=[1, 2],
+                      weight_filler=dict(type="xavier"),
+                      bias_filler=dict(type="xavier"))
+    _test_convolution(np.random.rand(1, 2, 10, 10).astype(np.float32),
+                      num_output=20,
+                      bias_term=True,
+                      pad=[1, 2],
+                      kernel_size=[3, 5],
+                      stride=[2, 1],
+                      dilation=[1, 2],
+                      weight_filler=dict(type="xavier"),
+                      bias_filler=dict(type="xavier"),
+                      group=2)
+    _test_convolution(data,
+                      num_output=20,
+                      bias_term=True,
+                      pad_h=1,
+                      pad_w=2,
+                      kernel_h=3,
+                      kernel_w=5,
+                      stride_h=2,
+                      stride_w=1,
+                      dilation=[1, 2],
+                      weight_filler=dict(type="xavier"),
+                      bias_filler=dict(type="xavier"))
+
+
+#######################################################################
+# Crop
+# -----------
+
+
+def _test_crop(data, **kwargs):
+    """ One iteration of Crop """
+    _test_op(data, L.Crop, "Crop", **kwargs)
+
+
+def test_forward_Crop():
+    """ Crop """
+    _test_crop(
+        [np.random.rand(10, 10, 120, 120),
+         np.random.rand(10, 5, 50, 60)])
+    _test_crop(
+        [np.random.rand(10, 10, 120, 120),
+         np.random.rand(10, 5, 50, 60)],
+        axis=1)
+    _test_crop(
+        [np.random.rand(10, 10, 120, 120),
+         np.random.rand(10, 5, 50, 60)],
+        axis=1,
+        offset=2)
+    _test_crop(
+        [np.random.rand(10, 10, 120, 120),
+         np.random.rand(10, 5, 50, 60)],
+        axis=1,
+        offset=[1, 2, 4])
+    _test_crop(
+        [np.random.rand(10, 10, 120, 120),
+         np.random.rand(10, 5, 50, 60)],
+        axis=2,
+        offset=[2, 4])
+    _test_crop([np.random.rand(10, 120, 120),
+                np.random.rand(5, 50, 60)],
+               axis=1,
+               offset=[2, 4])
+    _test_crop([np.random.rand(120, 120),
+                np.random.rand(50, 60)],
+               axis=0,
+               offset=[2, 4])
+
+
+#######################################################################
+# Deconvolution
+# -----------
+
+
+def _test_deconvolution(data, **kwargs):
+    """ One iteration of Deconvolution """
+    _test_op(data, L.Deconvolution, "Deconvolution", **kwargs)
+
+
+def test_forward_Deconvolution():
+    """ Deconvolution """
+    data = np.random.rand(1, 16, 32, 32).astype(np.float32)
+    _test_deconvolution(data,
+                        convolution_param=dict(
+                            num_output=20,
+                            bias_term=True,
+                            pad=0,
+                            kernel_size=3,
+                            stride=2,
+                            dilation=1,
+                            weight_filler=dict(type="xavier"),
+                            bias_filler=dict(type="xavier")))
+    _test_deconvolution(data,
+                        convolution_param=dict(
+                            num_output=20,
+                            bias_term=False,
+                            pad=[1, 2],
+                            kernel_size=3,
+                            stride=2,
+                            dilation=1,
+                            weight_filler=dict(type="xavier"),
+                            bias_filler=dict(type="xavier")))
+    _test_deconvolution(data,
+                        convolution_param=dict(
+                            num_output=20,
+                            bias_term=True,
+                            pad_h=1,
+                            pad_w=2,
+                            kernel_h=3,
+                            kernel_w=5,
+                            stride_h=2,
+                            stride_w=1,
+                            dilation=1,
+                            weight_filler=dict(type="xavier"),
+                            bias_filler=dict(type="xavier")))
+
+
+#######################################################################
+# Dropout
+# -----------
+
+
+def _test_dropout(data, **kwargs):
+    """ One iteration of Dropout """
+    _test_op(data, L.Dropout, "Dropout", **kwargs)
+
+
+def test_forward_Dropout():
+    """ Dropout """
+    data = np.random.rand(1, 3, 10, 10).astype(np.float32)
+    _test_dropout(data)
+    _test_dropout(data, dropout_ratio=0.7)
+
+
+#######################################################################
+# Eltwise
+# -----------
+
+
+def _test_eltwise(data_list, **kwargs):
+    """ One iteration of Eltwise """
+    _test_op(data_list, L.Eltwise, "Eltwise", **kwargs)
+
+
+def test_forward_Eltwise():
+    """ Eltwise """
+    _test_eltwise([
+        np.random.rand(1, 3, 10, 11).astype(np.float32),
+        np.random.rand(1, 3, 10, 11).astype(np.float32)
+    ],
+                  operation=0)
+    _test_eltwise([
+        np.random.rand(1, 3, 10, 11).astype(np.float32),
+        np.random.rand(1, 3, 10, 11).astype(np.float32)
+    ],
+                  operation=1)
+    _test_eltwise([
+        np.random.rand(1, 3, 10, 11).astype(np.float32),
+        np.random.rand(1, 3, 10, 11).astype(np.float32)
+    ],
+                  operation=2)
+    _test_eltwise([
+        np.random.rand(1, 3, 10, 11).astype(np.float32),
+        np.random.rand(1, 3, 10, 11).astype(np.float32)
+    ],
+                  operation=1,
+                  coeff=[0.5, 1])
+
+
+#######################################################################
+# Flatten
+# -----------
+
+
+def _test_flatten(data, axis=1):
+    """ One iteration of Flatten """
+    _test_op(data, L.Flatten, 'Flatten', axis=axis)
+
+
+def test_forward_Flatten():
+    """ Flatten """
+    data = np.random.rand(1, 3, 10, 10).astype(np.float32)
+    _test_flatten(data)
+    _test_flatten(data, axis=1)
+
+
+#######################################################################
+# InnerProduct
+# -----------
+
+
+def _test_inner_product(data, **kwargs):
+    """ One iteration of InnerProduct"""
+    _test_op(data, L.InnerProduct, "InnerProduct", **kwargs)
+
+
+def test_forward_InnerProduct():
+    """ InnerProduct """
+    data = np.random.rand(1, 3, 10, 10).astype(np.float32)
+    _test_inner_product(data,
+                        num_output=20,
+                        bias_term=False,
+                        weight_filler=dict(type='xavier'))
+    _test_inner_product(data,
+                        num_output=20,
+                        bias_term=True,
+                        weight_filler=dict(type='xavier'),
+                        bias_filler=dict(type='xavier'))
+    _test_inner_product(np.random.rand(20, 10).astype(np.float32),
+                        num_output=30,
+                        bias_term=True,
+                        weight_filler=dict(type='xavier'),
+                        bias_filler=dict(type='xavier'))
+
+
+#######################################################################
+# LRN
+# -----------
+
+
+def _test_lrn(data, local_size=5, alpha=1., beta=0.75, k=1.):
+    """ One iteration of LRN """
+    _test_op(data,
+             L.LRN,
+             'LRN',
+             local_size=local_size,
+             alpha=alpha,
+             beta=beta,
+             k=k)
+
+
+def test_forward_LRN():
+    """ LRN """
+    data = np.random.rand(1, 3, 10, 10).astype(np.float32)
+    _test_lrn(data)
+    _test_lrn(data, local_size=3)
+    _test_lrn(data, local_size=3, alpha=2.)
+    _test_lrn(
+        data,
+        local_size=3,
+        alpha=2.,
+        beta=0.5,
+    )
+    _test_lrn(data, local_size=3, alpha=2., beta=0.5, k=2.)
+
+
+#######################################################################
+# Pooling
+# -----------
+
+
+def _test_pooling(data, **kwargs):
+    """ One iteration of Pooling. """
+    _test_op(data, L.Pooling, "Pooling", **kwargs)
+
+
+def test_forward_Pooling():
+    """ Pooing """
+    data = np.random.rand(1, 3, 10, 10).astype(np.float32)
+    # MAX Pooling
+    _test_pooling(data, kernel_size=2, stride=2, pad=0, pool=P.Pooling.MAX)
+    _test_pooling(data,
+                  kernel_h=2,
+                  kernel_w=3,
+                  stride_h=2,
+                  stride_w=1,
+                  pad_h=1,
+                  pad_w=2,
+                  pool=P.Pooling.MAX)
+    _test_pooling(data, pool=P.Pooling.MAX, global_pooling=True)
+
+    # AVE Pooling
+    _test_pooling(data, kernel_size=2, stride=2, pad=0, pool=P.Pooling.AVE)
+    _test_pooling(data,
+                  kernel_h=2,
+                  kernel_w=3,
+                  stride_h=2,
+                  stride_w=1,
+                  pad_h=1,
+                  pad_w=2,
+                  pool=P.Pooling.AVE)
+    _test_pooling(data, pool=P.Pooling.AVE, global_pooling=True)
+
+
+#######################################################################
+# PReLU
+# -----------
+
+
+def _test_prelu(data, **kwargs):
+    """ One iteration of PReLU. """
+    _test_op(data, L.PReLU, "PReLU", **kwargs)
+
+
+def test_forward_PReLU():
+    """ PReLU """
+    data = np.random.rand(1, 3, 10, 10).astype(np.float32)
+    _test_prelu(data, filler=dict(type='constant', value=0.5))
+    _test_prelu(data)
+    _test_prelu(np.random.rand(10, 20).astype(np.float32))
+
+
+#######################################################################
+# ReLU
+# -----------
+
+
+def _test_relu(data, **kwargs):
+    """ One iteration of ReLU. """
+    _test_op(data, L.ReLU, "ReLU", **kwargs)
+
+
+def test_forward_ReLU():
+    """ ReLU """
+    data = np.random.rand(1, 3, 10, 10).astype(np.float32)
+    _test_relu(data)
+    _test_relu(np.random.rand(10, 20).astype(np.float32))
+
+
+#######################################################################
+# Reshape
+# -----------
+
+
+def _test_reshape(data, **kwargs):
+    """ One iteration of Reshape. """
+    _test_op(data, L.Reshape, "Reshape", **kwargs)
+
+
+def test_forward_Reshape():
+    """ Reshape """
+    data = np.random.rand(1, 8, 6).astype(np.float32)
+    _test_reshape(data, reshape_param={'shape': {'dim': [4, 3, 4]}})
+    _test_reshape(data, reshape_param={'shape': {'dim': [2, 0, 3]}})
+    _test_reshape(data, reshape_param={'shape': {'dim': [2, 0, -1]}})
+    _test_reshape(data, reshape_param={'shape': {'dim': [0, -1]}})
+
+    _test_reshape(data, reshape_param={'shape': {'dim': [2, 3]}, 'axis': 2})
+    _test_reshape(data, reshape_param={'shape': {'dim': [4, 3, 4]}, 'axis': 1})
+    _test_reshape(data,
+                  reshape_param={
+                      'shape': {
+                          'dim': [4, 3, 4]
+                      },
+                      'axis': -3
+                  })
+
+    _test_reshape(data,
+                  reshape_param={
+                      'shape': {
+                          'dim': [2, 4]
+                      },
+                      'axis': 1,
+                      'num_axes': 1
+                  })
+    _test_reshape(data,
+                  reshape_param={
+                      'shape': {
+                          'dim': [3, 16]
+                      },
+                      'axis': 1,
+                      'num_axes': 2
+                  })
+
+
+#######################################################################
+# Scale
+# -----------
+
+
+def _test_scale(data, **kwargs):
+    """ One iteration of Scale. """
+    _test_op(data, L.Scale, "Scale", **kwargs)
+
+
+def test_forward_Scale():
+    """ Scale """
+    data = np.random.rand(1, 3, 10, 10).astype(np.float32)
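+    # Scale learns a per-channel multiplier; bias_term=True adds a
+    # learned per-channel offset as well.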
+    _test_scale(data, filler=dict(type="xavier"))
+    _test_scale(data,
+                filler=dict(type="xavier"),
+                bias_term=True,
+                bias_filler=dict(type="xavier"))
+
+
+#######################################################################
+# Sigmoid
+# -----------
+
+
+def _test_sigmoid(data, **kwargs):
+    """ One iteration of Sigmoid. """
+    _test_op(data, L.Sigmoid, "Sigmoid", **kwargs)
+
+
+def test_forward_Sigmoid():
+    """ Sigmoid """
+    data = np.random.rand(1, 3, 10, 10).astype(np.float32)
+    _test_sigmoid(data)
+
+
+#######################################################################
+# Slice
+# -----------
+
+
+def _test_slice(data, **kwargs):
+    """ One iteration of Slice """
+    _test_op(data, L.Slice, "Slice", **kwargs)
+
+
+def test_forward_Slice():
+    """ Slice """
+    data = np.random.rand(1, 3, 10, 10).astype(np.float32)
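+    # ntop is the number of output blobs; slice_point gives the split
+    # indices along axis (len(slice_point) == ntop - 1). Without
+    # slice_point, the axis is divided evenly among the outputs.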
+    _test_slice(data, ntop=2, slice_param=dict(axis=1, slice_point=[1]))
+    _test_slice(data, ntop=2, slice_param=dict(axis=-1, slice_point=[1]))
+    _test_slice(data, ntop=3, slice_param=dict(axis=2, slice_point=[1, 6]))
+    _test_slice(data, ntop=3)
+
+
+#######################################################################
+# Softmax
+# -----------
+
+
+def _test_softmax(data, **kwargs):
+    """ One iteration of Softmax """
+    _test_op(data, L.Softmax, "Softmax", **kwargs)
+
+
+def test_forward_Softmax():
+    """ Softmax"""
+    _test_softmax(np.random.rand(1, 3, 10, 10).astype(np.float32))
+    _test_softmax(np.random.rand(1, 3, 10, 10).astype(np.float32), axis=2)
+    _test_softmax(np.random.rand(10, 10).astype(np.float32), axis=0)
+    _test_softmax(np.random.rand(2, 10, 10).astype(np.float32), axis=1)
+
+
+#######################################################################
+# TanH
+# -----------
+
+
+def _test_tanh(data, **kwargs):
+    """ One iteration of TanH """
+    _test_op(data, L.TanH, "TanH", **kwargs)
+
+
+def test_forward_TanH():
+    """ TanH """
+    _test_tanh(np.random.rand(1, 3, 10, 10).astype(np.float32))
+    _test_tanh(np.random.rand(3, 10, 10).astype(np.float32))
+    _test_tanh(np.random.rand(10, 10).astype(np.float32))
+    _test_tanh(np.random.rand(10).astype(np.float32))
+
+
+#######################################################################
+# Mobilenetv2
+# -----------
+
+
+def _test_mobilenetv2(data):
+    """ One iteration of Mobilenetv2 """
+    mean_val = np.array([103.939, 116.779, 123.68], dtype=np.float32)
+    mean_val = np.reshape(mean_val, (1, 3, 1, 1))
+    mean_val = np.tile(mean_val, (1, 1, 224, 224))
+    data_process = data - mean_val
+    data_process = data_process / 58.8
+    data_process = data_process.astype(np.float32)
+
+    proto_file_url = ("https://github.com/shicai/MobileNet-Caffe/raw/"
+                        "master/mobilenet_v2_deploy.prototxt")
+    blob_file_url = ("https://github.com/shicai/MobileNet-Caffe/blob/"
+                        "master/mobilenet_v2.caffemodel?raw=true")
+    proto_file = download_testdata(proto_file_url, 'mobilenetv2.prototxt',
+                                     module='model')
+    blob_file = download_testdata(blob_file_url, 'mobilenetv2.caffemodel',
+                                     module='model')
+    _test_network(data_process, proto_file, blob_file)
+
+
+def test_forward_Mobilenetv2():
+    """ Mobilenetv2 """
+    data = np.random.randint(0, 256, size=(1, 3, 224, 224)).astype(np.float32)
+    _test_mobilenetv2(data)
+
+
+#######################################################################
+# Alexnet
+# -----------
+
+
+def _test_alexnet(data):
+    """ One iteration of Alexnet """
+    mean_val = np.array([103.939, 116.779, 123.68], dtype=np.float32)
+    mean_val = np.reshape(mean_val, (1, 3, 1, 1))
+    mean_val = np.tile(mean_val, (1, 1, 227, 227))
+    data_process = data - mean_val
+    data_process = data_process.astype(np.float32)
+
+    proto_file_url = ("https://github.com/BVLC/caffe/raw/master/models/"
+                        "bvlc_alexnet/deploy.prototxt")
+    blob_file_url = 'http://dl.caffe.berkeleyvision.org/bvlc_alexnet.caffemodel'
+    proto_file = download_testdata(proto_file_url, 'alexnet.prototxt',
+                                    module="model")
+    blob_file = download_testdata(blob_file_url, 'alexnet.caffemodel',
+                                    module='model')
+    _test_network(data_process, proto_file, blob_file)
+
+
+def test_forward_Alexnet():
+    """ Alexnet """
+    data = np.random.randint(0, 256, size=(1, 3, 227, 227)).astype(np.float32)
+    _test_alexnet(data)
+
+
+#######################################################################
+# Resnet50
+# -----------
+
+
+def _test_resnet50(data):
+    """ One iteration of Resnet50 """
+    mean_val = np.array([103.939, 116.779, 123.68], dtype=np.float32)
+    mean_val = np.reshape(mean_val, (1, 3, 1, 1))
+    mean_val = np.tile(mean_val, (1, 1, 224, 224))
+    data_process = data - mean_val
+    data_process = data_process.astype(np.float32)
+
+    proto_file_url = ("https://github.com/fernchen/CaffeModels/raw/"
+                      "master/resnet/ResNet-50-deploy.prototxt")
+    blob_file_url = ("https://github.com/fernchen/CaffeModels/raw/"
+                     "master/resnet/ResNet-50-model.caffemodel")
+
+    proto_file = download_testdata(proto_file_url, 'resnet50.prototxt',
+                                   module='model')
+    blob_file = download_testdata(blob_file_url, 'resnet50.caffemodel',
+                                  module='model')
+
+    _test_network(data_process, proto_file, blob_file)
+
+
+def test_forward_Resnet50():
+    """ Resnet50 """
+    data = np.random.randint(0, 256, size=(1, 3, 224, 224)).astype(np.float32)
+    _test_resnet50(data)
+
+
+#######################################################################
+# Inceptionv1
+# -----------
+
+
+def _test_inceptionv1(data):
+    """ One iteration of Inceptionv4 """
+    mean_val = np.array([103.939, 116.779, 123.68], dtype=np.float32)
+    mean_val = np.reshape(mean_val, (1, 3, 1, 1))
+    mean_val = np.tile(mean_val, (1, 1, 224, 224))
+    data_process = data - mean_val
+    data_process = data_process / 58.8
+    data_process = data_process.astype(np.float32)
+
+    proto_file_url = ("https://github.com/BVLC/caffe/raw/master/models"
+                        "/bvlc_googlenet/deploy.prototxt")
+    blob_file_url = 'http://dl.caffe.berkeleyvision.org/bvlc_googlenet.caffemodel'
+    proto_file = download_testdata(proto_file_url, 'inceptionv1.prototxt',
+                                    module="model")
+    blob_file = download_testdata(blob_file_url, 'inceptionv1.caffemodel',
+                                    module='model')
+    _test_network(data_process, proto_file, blob_file)
+
+
+def test_forward_Inceptionv1():
+    """ Inceptionv4 """
+    data = np.random.randint(0, 256, size=(1, 3, 224, 224)).astype(np.float32)
+    _test_inceptionv1(data)
+
+
+if __name__ == "__main__":
+    # NN
+    test_forward_Convolution()
+    test_forward_Deconvolution()
+    test_forward_Dropout()
+    test_forward_LRN()
+    test_forward_Pooling()
+    test_forward_Scale()
+    test_forward_InnerProduct()
+    test_forward_BatchNorm()
+
+    # Elemwise
+    test_forward_Eltwise()
+
+    # Activation
+    test_forward_PReLU()
+    test_forward_ReLU()
+    test_forward_Sigmoid()
+    test_forward_Softmax()
+    test_forward_TanH()
+
+    # Reshape
+    test_forward_Reshape()
+    test_forward_Flatten()
+
+    # Math
+    test_forward_Concat()
+    test_forward_Crop()
+    test_forward_Slice()
+
+    # End to End
+    test_forward_Mobilenetv2()
+    test_forward_Alexnet()
+    test_forward_Resnet50()
+    test_forward_Inceptionv1()
index 96c5ce6..10354e5 100755 (executable)
@@ -35,3 +35,6 @@ python3 -m pytest tests/python/frontend/tflite
 
 echo "Running relay Keras frontend test..."
 python3 -m pytest tests/python/frontend/keras
+
+echo "Running relay Caffe frontend test..."
+python3 -m pytest tests/python/frontend/caffe