[Relay][Frontend][ONNX] New Operators and Opsets to Support BERT (#4197)
author Josh Fromm <jwfromm@uw.edu>
Wed, 30 Oct 2019 18:24:47 +0000 (11:24 -0700)
committer Jared Roesch <roeschinc@gmail.com>
Wed, 30 Oct 2019 18:24:47 +0000 (11:24 -0700)
* Added slice v10

* Added constantofshape operation and small refactor.

* Finished one_hot implementation.

* Reshape working across all bert layers.

* Fixed constantofshape and removed code duplication.

* ONNX model fully ingested.

* Working on improving onnx tests.

* Changed onnx testing to use onnxruntime instead of caffe2, also formatted.

* Add arbitrary output nodes to onnx frontend.

* Added v6 tiling for BERT SQuAD 8 support.

* Small syntax fixes

* Reduced code duplication in split opset versions.

* Added batch matmul test

* Added unstack split testing.

* Added onehot test, needs a little cleanup probably.

* Replaced deprecated constant fill with constantofshape and updated tests accordingly.

* Added tests for new opset version of slice and tile.

* lint clean up

* Lint fixes

* Changed onnx dependency

* Went back to caffe2 runtime for CI integration.

* Rebase and small typo/syntax changes.

* Added hard casting of onehot attributes to int.

python/tvm/relay/frontend/common.py
python/tvm/relay/frontend/onnx.py
python/tvm/relay/frontend/tensorflow.py
tests/python/frontend/onnx/test_forward.py

index d4b9162..25ba0ef 100644 python/tvm/relay/frontend/common.py
@@ -19,11 +19,13 @@ from __future__ import absolute_import as _abs
 import logging
 
 import tvm
+import numpy as np
 from topi.util import get_const_tuple
 from .. import expr as _expr
 from .. import module as _module
 from .. import transform as _transform
 from .. import op as _op
+from .. import analysis
 
 
 class RequiredAttr(object):
@@ -474,6 +476,50 @@ def infer_channels(inputs, transpose=False):
     return channels
 
 
+def infer_value(input_val, params):
+    """A hack for getting the value of an expression by evaluating a
+    portion of the relay graph. This is often needed for operators
+    whose output shape depends on the value of a tensor.
+    """
+    from tvm.contrib import graph_runtime
+    # Check that all free variables have associated parameters.
+    assert all(var.name_hint in params.keys() for var in analysis.free_vars(
+        input_val)), "All inputs to infer must be available in params."
+    func = _expr.Function(analysis.free_vars(input_val), input_val)
+    with tvm.relay.build_config(opt_level=0):
+        graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
+    ctx = tvm.cpu(0)
+    m = graph_runtime.create(graph, lib, ctx)
+    m.set_input(**params)
+    m.run()
+    return m.get_output(0)
+
+
+def infer_value_simulated(input_val, params):
+    """Extention to infer_value that can be used when some input
+    values are missing. This function creates dummy inputs with the same
+    shape and random values then calls infer_value. This is helpful when
+    implementing certain onnx operators where we need to evaluate the graph
+    to determine a static shape.
+    """
+    fake_params = []
+    # Add a fake copy of all missing params.
+    for free_param in analysis.free_vars(input_val):
+        if free_param.name_hint not in params:
+            fp_dtype = free_param.type_annotation.dtype
+            fp_shape = [s.value for s in free_param.type_annotation.shape]
+            fake_params.append(free_param)
+            params[free_param.name_hint] = tvm.nd.array(
+                np.random.rand(*fp_shape).astype(fp_dtype)
+            )
+    # Now infer the value.
+    output_value = infer_value(input_val, params)
+    # Clean fake params out of param dictionary.
+    for fake_p in fake_params:
+        params.pop(fake_p.name_hint, None)
+    return output_value
+
+
 def new_var(name_hint,
             type_annotation=None,
             shape=None,
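For context, infer_value builds and runs the sub-graph on the LLVM CPU target at
opt_level=0, so any expression whose free variables are bound in params can be
constant-evaluated ahead of time. A minimal usage sketch (hypothetical, not part
of this patch; assumes an LLVM-enabled TVM build of this era):

import numpy as np
import tvm
from tvm import relay
from tvm.relay.frontend.common import infer_value, infer_value_simulated

# An expression whose value depends only on the shape of its input.
x = relay.var("x", shape=(2, 3), dtype="float32")
expr = relay.shape_of(relay.transpose(x))

# infer_value requires every free variable to be bound in params.
params = {"x": tvm.nd.array(np.zeros((2, 3), dtype="float32"))}
print(infer_value(expr, params).asnumpy())        # -> [3 2]

# infer_value_simulated first binds any missing variables to random data.
print(infer_value_simulated(expr, {}).asnumpy())  # -> [3 2]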
index 1d74a01..41fafbc 100644 python/tvm/relay/frontend/onnx.py
 """ONNX: Open Neural Network Exchange frontend for Relay."""
 from __future__ import absolute_import as _abs
 
-import logging
 import numpy as np
 import tvm
 from ... import nd as _nd
 from .. import analysis
-from .. import transform as _transform
 from .. import expr as _expr
 from .. import module as _module
 from .. import op as _op
 from .common import AttrCvt, Renamer
-from .common import get_relay_op, new_var, infer_shape, infer_channels, get_name
+from .common import get_relay_op, new_var, infer_shape, infer_channels
+from .common import infer_type, infer_value, infer_value_simulated, get_name
 
 __all__ = ['from_onnx']
 
+
+def get_numpy(tensor_proto):
+    """Grab data in TensorProto and convert to numpy array."""
+    try:
+        from onnx.numpy_helper import to_array
+    except ImportError as e:
+        raise ImportError(
+            "Unable to import onnx which is required {}".format(e))
+    return to_array(tensor_proto)
+
+
 def dimension_picker(prefix, surfix=''):
     def _impl(attr):
         kernel = attr['kernel_shape']
@@ -43,6 +53,7 @@ def dimension_picker(prefix, surfix=''):
 
     return _impl
 
+
 def revert_caffe2_pad(pads):
     """Caffe2 requires two times the normal padding."""
     if len(pads) == 4:
@@ -279,6 +290,21 @@ class MatMul(OnnxOpConverter):
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         assert len(inputs) == 2, "MatMul op takes 2 inputs, {} given".format(len(inputs))
+        # Check the input shape, since inputs with more than two dimensions need a batch matmul.
+        a_shape = infer_shape(inputs[0])
+        # When performing a batch matmul, we need to properly handle N-dim shapes.
+        if len(a_shape) > 2:
+            b_shape = infer_shape(inputs[1])
+            # Convert a and b into 3 dimensional tensors.
+            a = _op.reshape(inputs[0], [-1, a_shape[-2], a_shape[-1]])
+            b = _op.reshape(inputs[1], [-1, b_shape[-2], b_shape[-1]])
+            # Transpose matrix dimensions of b.
+            b = _op.transpose(b, [0, 2, 1])
+            # Perform a batch matmul.
+            output = _op.nn.batch_matmul(a, b)
+            # Reshape output to original dimensions.
+            return _op.reshape(output, [*a_shape[:-2], a_shape[-2], b_shape[-1]])
+        # Otherwise a simple dense op will get the job done.
         input_1_t = _op.transpose(inputs[1], axes=(1, 0))
         return _op.nn.dense(inputs[0], input_1_t)
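A NumPy sketch of the lowering above may help: the batch dimensions are flattened
into one, the second operand is pre-transposed because nn.batch_matmul multiplies
against the transpose of its right operand, and the original batch dimensions are
restored afterwards (illustrative only, not part of the patch):

import numpy as np

a = np.random.rand(2, 3, 4, 5).astype("float32")
b = np.random.rand(2, 3, 5, 6).astype("float32")

a3 = a.reshape(-1, 4, 5)                      # flatten batch dims: (6, 4, 5)
b3 = b.reshape(-1, 5, 6).transpose(0, 2, 1)   # (6, 6, 5), pre-transposed
# nn.batch_matmul(a, b) computes a @ b.transpose(0, 2, 1) for each batch.
out = np.einsum("bik,bjk->bij", a3, b3)       # (6, 4, 6)
out = out.reshape(2, 3, 4, 6)                 # restore original batch dims

assert np.allclose(out, np.matmul(a, b), atol=1e-5)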
 
@@ -426,35 +452,18 @@ class Reshape(OnnxOpConverter):
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        if 'shape' in attr:
-            return _op.reshape(inputs[0], attr['shape'])
+        return _op.reshape(inputs[0], attr['shape'])
 
+    @classmethod
+    def _impl_v5(cls, inputs, attr, params):
         if get_name(inputs[1]) in params:
             shape = tuple(params[inputs[1].name_hint].asnumpy())
             out = _op.reshape(inputs[0], shape)
         else:
             data, shape = inputs
-            logging.warning("Constant evaluating Reshape's shape argument, may reduce performance")
-            shape_params = analysis.free_vars(shape)
-            func = _expr.Function(shape_params, shape)
-            mod = _module.Module.from_expr(func)
-            seq = _transform.Sequential([_transform.InferType(),
-                                         _transform.FoldConstant(),
-                                         _transform.FuseOps(0),
-                                         _transform.InferType()])
-            with tvm.relay.PassContext(opt_level=2):
-                mod = seq(mod)
-            with tvm.relay.build_config(opt_level=0):
-                ex = tvm.relay.create_executor("debug", mod=mod)
-                inputs = []
-                for sp in shape_params:
-                    if not sp.name_hint in params:
-                        sh = [int(i) for i in sp.type_annotation.shape]
-                        inputs.append(
-                            tvm.nd.array(np.random.rand(*sh).astype('float32')))
-                static_shape = ex.evaluate()(*inputs, **params)
-            out = _op.reshape(data, newshape=tuple(static_shape.asnumpy()))
-
+            static_shape = infer_value_simulated(shape, params)
+            out = _op.reshape(data, newshape=tuple(
+                static_shape.asnumpy().astype('int32')))
         return out
 
 class Concat(OnnxOpConverter):
@@ -640,11 +649,17 @@ class Split(OnnxOpConverter):
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        attr['indices_or_sections'] = []
-        index = 0
-        for i in attr['split'][:-1]:
-            index += i
-            attr['indices_or_sections'].append(index)
+        splits = attr.get('split', False)
+        if splits:
+            attr['indices_or_sections'] = []
+            index = 0
+            for i in splits[:-1]:
+                index += i
+                attr['indices_or_sections'].append(index)
+        # When 'split' isn't specified, divide evenly over the axis.
+        else:
+            in_shape = infer_shape(inputs[0])
+            attr['indices_or_sections'] = in_shape[attr['axis']]
         return AttrCvt(
             'split',
             ignores=['split'])(inputs, attr, params)
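To illustrate the conversion above: ONNX's per-output split sizes become the
cumulative cut points that relay's split expects, and when 'split' is absent the
converter instead passes the axis length, so the tensor is unstacked into
sections of size one. A hypothetical walk-through of the cut-point computation:

split = [2, 3, 5]        # ONNX attribute: output sizes along the axis
indices = []
index = 0
for size in split[:-1]:
    index += size
    indices.append(index)
print(indices)           # -> [2, 5]: relay splits at these offsets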
@@ -653,6 +668,25 @@ class Split(OnnxOpConverter):
 class Slice(OnnxOpConverter):
     """ Operator converter for Slice.
     """
+
+    @classmethod
+    def _common(cls, starts, ends, axes):
+        new_axes = []
+        new_starts = []
+        new_ends = []
+        pop_index = 0
+        for i in range(max(axes) + 1):
+            if i in axes:
+                new_axes.append(i)
+                new_starts.append(starts[pop_index])
+                new_ends.append(ends[pop_index])
+                pop_index += 1
+            else:
+                new_axes.append(i)
+                new_starts.append(0)
+                new_ends.append(np.iinfo(np.int32).max)
+        return new_starts, new_ends, new_axes
+
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         if isinstance(attr['starts'], int):
@@ -663,22 +697,9 @@ class Slice(OnnxOpConverter):
             # Update the starts and ends according to axes if required.
             if isinstance(attr['axes'], int):
                 attr['axes'] = (attr['axes'],)
-
             if (max(attr['axes']) + 1) != len(attr['axes']):
-                new_axes = []
-                new_starts = []
-                new_ends = []
-                pop_index = 0
-                for i in range(max(attr['axes']) + 1):
-                    if i in attr['axes']:
-                        new_axes.append(i)
-                        new_starts.append(attr['starts'][pop_index])
-                        new_ends.append(attr['ends'][pop_index])
-                        pop_index += 1
-                    else:
-                        new_axes.append(i)
-                        new_starts.append(0)
-                        new_ends.append(np.iinfo(np.int32).max)
+                new_starts, new_ends, new_axes = cls._common(
+                    attr['starts'], attr['ends'], attr['axes'])
                 attr['axes'] = new_axes
                 attr['starts'] = new_starts
                 attr['ends'] = new_ends
@@ -690,6 +711,23 @@ class Slice(OnnxOpConverter):
                                    'ends': 'end'},
                        ignores=['axes'])(inputs, attr)
 
+    @classmethod
+    def _impl_v10(cls, inputs, attr, params):
+        starts = params[get_name(inputs[1])].asnumpy()
+        ends = params[get_name(inputs[2])].asnumpy()
+
+        # Update the starts and ends according to axes if required.
+        if len(inputs) >= 4:
+            axes = params[get_name(inputs[3])].asnumpy()
+
+            if max(axes + 1) != len(axes):
+                new_starts, new_ends, _ = cls._common(
+                    starts, ends, axes)
+                starts = new_starts
+                ends = new_ends
+        return _op.strided_slice(inputs[0], begin=starts, end=ends)
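The _common helper can be traced by hand: axes that the model leaves out are
filled with the full range [0, INT32_MAX) so that strided_slice receives one
start/end pair per leading dimension. A hypothetical walk-through of the same
logic in plain Python:

import numpy as np

starts, ends, axes = [1, 2], [3, 9], [0, 2]   # axis 1 is unspecified
new_axes, new_starts, new_ends = [], [], []
pop_index = 0
for i in range(max(axes) + 1):
    new_axes.append(i)
    if i in axes:
        new_starts.append(starts[pop_index])
        new_ends.append(ends[pop_index])
        pop_index += 1
    else:                                     # pad the missing axis
        new_starts.append(0)
        new_ends.append(np.iinfo(np.int32).max)
print(new_starts, new_ends)                   # [1, 0, 2] [3, 2147483647, 9]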
+
+
 class Gather(OnnxOpConverter):
     """ Operator converter for Gather.
     """
@@ -698,7 +736,6 @@ class Gather(OnnxOpConverter):
         axis = attr.get('axis', 0)
         return AttrCvt('take',
                        extras={'axis':axis})(inputs, {})
-        #return _op.take(inputs[0], inputs[1], axis)
 
 
 class Greater(OnnxOpConverter):
@@ -848,33 +885,49 @@ class Softmax(OnnxOpConverter):
             attr['axis'] = 1
         return AttrCvt('softmax', transforms={'axis': ('axis', 1)})(inputs, attr, params)
 
-class ConstantFill(OnnxOpConverter):
-    """ Operator converter for ConstantFill.
+
+class OneHot(OnnxOpConverter):
+    """ Operator converter for OneHot.
     """
     @classmethod
-    def _impl_v1(cls, inputs, attr, params):
-        num_inputs = len(inputs)
-        if 'shape' in attr:
-            if num_inputs > 1:
-                raise ImportError(
-                    "Can't set shape and input tensor at a time")
-            shape = attr.pop('shape')
+    def _impl_v9(cls, inputs, attr, params):
+        # Extract relay one_hot inputs.
+        indices, depth, values = inputs
+        # Split the ONNX on/off values into two separate expressions.
+        off_value, on_value = _op.take(
+            values, _op.const(0)), _op.take(values, _op.const(1))
+        # Extract the datatype of the output from on_value.
+        dtype = infer_type(on_value).checked_type.dtype
+        # Convert depth into an integer.
+        depth = int(infer_value(depth, params).asnumpy()[0])
+        # Set the default axis when it is not specified in the model.
+        if 'axis' not in attr:
+            attr['axis'] = -1
+        return _op.one_hot(indices,
+                           on_value,
+                           off_value,
+                           depth,
+                           int(attr['axis']),
+                           dtype=dtype)
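A NumPy reference for the OneHot lowering: ONNX packs [off_value, on_value] into
a single two-element tensor, the output dtype follows on_value, and depth must be
constant-evaluated because relay's one_hot takes it as a Python integer. An
illustrative sketch (not part of the patch):

import numpy as np

indices = np.array([0, 2, 1])
depth = 3
values = np.array([0.0, 1.0], dtype="float32")   # [off_value, on_value]

off_value, on_value = values[0], values[1]
out = np.full((indices.size, depth), off_value, dtype=values.dtype)
out[np.arange(indices.size), indices] = on_value
print(out)   # one-hot rows along the default axis=-1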
+
+
+class ConstantOfShape(OnnxOpConverter):
+    """ Operator converter for ConstantOfShape.
+    """
+    @classmethod
+    def _impl_v9(cls, inputs, attr, params):
+        if 'value' in attr:
+            np_value = get_numpy(attr.pop('value'))[0]
+            value = _expr.const(np_value)
+            dtype = np_value.dtype.name
         else:
-            if num_inputs == 1:
-                raise ImportError(
-                    "Either shape attribute or input should be set")
-            if 'input_as_shape' in attr and attr['input_as_shape']:
-                shape = params[get_name(inputs[0])].asnumpy()
-            else:
-                if 'extra_shape' in attr:
-                    raise tvm.error.OpAttributeInvalid('Attribute "extra_shape" not '
-                                                       'supported with "fill_like" for '
-                                                       'operator ConstantFill.')
-                return _op.full_like(inputs[0], inputs[1])
+            value = _expr.const(0)
+            dtype = 'float32'
+        static_shape = infer_value_simulated(inputs[0], params)
+        output = _op.full(
+            value, shape=tuple(static_shape.asnumpy().astype('int32')), dtype=dtype)
+        return output
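The ConstantOfShape conversion is equivalent to numpy.full over a
constant-evaluated shape; when the 'value' attribute is absent, ONNX specifies a
float32 zero fill. An illustrative sketch:

import numpy as np

static_shape = np.array([2, 3], dtype="int64")   # what infer_value_simulated yields
value, dtype = 0, "float32"                      # defaults with no 'value' attr
print(np.full(tuple(static_shape.astype("int32")), value, dtype=dtype))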
 
-        if 'extra_shape' in attr:
-            shape = shape + attr.pop('extra_shape')
-        return _op.full(inputs[0], shape)
 
 class Sign(OnnxOpConverter):
     """ Operator converter for Sign.
@@ -916,6 +969,12 @@ class Tile(Elemwise):
         reps = attr.pop('repeats')  # The number of times repeating the tensor data.
         return _op.tile(inputs[0], reps)
 
+    @classmethod
+    def _impl_v6(cls, inputs, attr, params):
+        reps = tuple(infer_value_simulated(
+            inputs[1], params).asnumpy().astype('int32'))
+        return _op.tile(inputs[0], reps)
+
 class Erf(OnnxOpConverter):
     """Operator converter for Erf
     """
@@ -948,7 +1007,7 @@ def _get_convert_map(opset):
         'ThresholdedRelu': ThresholdedRelu.get_converter(opset),
         'ScaledTanh': ScaledTanh.get_converter(opset),
         'ParametricSoftplus': ParametricSoftPlus.get_converter(opset),
-        'ConstantFill': ConstantFill.get_converter(opset),
+        'ConstantOfShape': ConstantOfShape.get_converter(opset),
         # 'GivenTensorFill'
         'FC': AttrCvt('dense', ignores=['axis', 'axis_w']),
         'Scale': Scale.get_converter(opset),
@@ -958,7 +1017,7 @@ def _get_convert_map(opset):
         # 'MeanVarianceNormalization'
         # 'Crop'
         # 'Embedding'
-        'Upsample' : Upsample.get_converter(opset),
+        'Upsample': Upsample.get_converter(opset),
         'SpatialBN': BatchNorm.get_converter(opset),
 
         # defs/generator
@@ -1002,6 +1061,7 @@ def _get_convert_map(opset):
         # softmax default axis is different in onnx
         'Softmax': Softmax.get_converter(opset),
         'LogSoftmax': AttrCvt('log_softmax', {'axis': ('axis', 1)}),
+        'OneHot': OneHot.get_converter(opset),
         # 'Hardmax'
         'Softsign': Softsign.get_converter(opset),
         'SoftPlus': SoftPlus.get_converter(opset),
@@ -1164,14 +1224,6 @@ class GraphProto(object):
                     shape=list(t_proto.dims),
                     dtype=array.dtype)
             else:
-                if op_name == "ConstantFill":
-                    fill_value = attr.get('value', 0.0)
-                    dtype = attr.get('dtype', b'int32').decode("utf-8")
-                    i_name = node.output[0]
-                    self._params[i_name] = fill_value
-                    self._nodes[i_name] = new_var(node.output[0], shape=(), dtype=dtype)
-                    inputs.append(self._nodes[i_name])
-
                 i_name = self._parse_value_proto(node)
                 attr['tvm_custom'] = {}
                 attr['tvm_custom']['name'] = i_name
@@ -1214,13 +1266,7 @@ class GraphProto(object):
             return dtype
 
     def _parse_array(self, tensor_proto):
-        """Grab data in TensorProto and convert to numpy array."""
-        try:
-            from onnx.numpy_helper import to_array
-        except ImportError as e:
-            raise ImportError(
-                "Unable to import onnx which is required {}".format(e))
-        np_array = to_array(tensor_proto).reshape(tuple(tensor_proto.dims))
+        np_array = get_numpy(tensor_proto).reshape(tuple(tensor_proto.dims))
         return _nd.array(np_array)
 
     def _parse_attr(self, attr_proto):
@@ -1301,7 +1347,8 @@ class GraphProto(object):
 
 def from_onnx(model,
               shape=None,
-              dtype="float32"):
+              dtype="float32",
+              opset=None):
     """Convert a ONNX model into an equivalent Relay Function.
 
     ONNX graphs are represented as Python Protobuf objects.
@@ -1322,6 +1369,10 @@ def from_onnx(model,
     dtype : str or dict of str to str
         The input types to the graph
 
+    opset : int, optional
+        Override the autodetected opset.
+        This can be helpful for some testing.
+
     Returns
     -------
     mod : tvm.relay.Module
@@ -1344,9 +1395,10 @@ def from_onnx(model,
         pass
     g = GraphProto(shape, dtype)
     graph = model.graph
-    try:
-        opset = model.opset_import[0].version if model.opset_import else 1
-    except AttributeError:
-        opset = 1
+    if opset is None:
+        try:
+            opset = model.opset_import[0].version if model.opset_import else 1
+        except AttributeError:
+            opset = 1
     mod, params = g.from_onnx(graph, opset)
     return mod, params
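A hypothetical call showing the new override (the model path, input name, and
shape below are placeholders): forcing opset 10 makes the frontend pick, for
example, the Slice._impl_v10 converter even if the model declares an older opset.

import onnx
from tvm import relay

model = onnx.load("bert.onnx")   # placeholder path
mod, params = relay.frontend.from_onnx(
    model, shape={"data": (1, 128)}, opset=10)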
index bfa3431..2ef8d15 100644 python/tvm/relay/frontend/tensorflow.py
@@ -39,22 +39,10 @@ from .common import AttrCvt, get_relay_op
 from .common import infer_type as _infer_type
 from .common import infer_shape as _infer_shape
 from .common import infer_channels as _infer_channels
+from .common import infer_value as _infer_value
 
 __all__ = ['from_tensorflow']
 
-def _infer_value(input_val, params):
-    from tvm.contrib import graph_runtime
-    # Check that all free variables have associated parameters.
-    assert all(var.name_hint in params.keys() for var in analysis.free_vars(
-        input_val)), "All inputs to infer must be available in params."
-    func = _expr.Function(analysis.free_vars(input_val), input_val)
-    with tvm.relay.build_config(opt_level=0):
-        graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
-    ctx = tvm.context("llvm", 0)
-    m = graph_runtime.create(graph, lib, ctx)
-    m.set_input(**params)
-    m.run()
-    return m.get_output(0)
 
 def _get_pad_pair(input1d, kernel1d, stride1d):
     if input1d % stride1d == 0:
index 3d1262f..2d2265b 100644 tests/python/frontend/onnx/test_forward.py
@@ -14,7 +14,6 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import attr
 import numpy as np
 import math
 import torch
@@ -26,11 +25,11 @@ from tvm import relay
 from tvm.contrib import graph_runtime
 from nnvm.testing.config import ctx_list
 import onnx
-from onnx import helper, TensorProto
-import unittest
+from onnx import helper, TensorProto, mapping
 import scipy
 
-def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output_dtype='float32'):
+
+def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output_dtype='float32', opset=None):
     """ Generic function to execute and get tvm output"""
     target = 'llvm'
     if isinstance(input_data, list):
@@ -46,21 +45,22 @@ def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output
         shape_dict = {input_names: input_data.shape}
         dtype_dict = {input_names: input_data.dtype}
 
-    mod, params = relay.frontend.from_onnx(graph_def, shape_dict)
+    mod, params = relay.frontend.from_onnx(graph_def, shape_dict, opset=opset)
     with relay.build_config(opt_level=1):
         graph, lib, params = relay.build(mod,
                                          target,
                                          params=params)
 
     ctx = tvm.cpu(0)
-    from tvm.contrib import graph_runtime
     m = graph_runtime.create(graph, lib, ctx)
     # set inputs
     if isinstance(input_data, list):
         for i, e in enumerate(input_names):
-            m.set_input(input_names[i], tvm.nd.array(input_data[i].astype(input_data[i].dtype)))
+            m.set_input(input_names[i], tvm.nd.array(
+                input_data[i].astype(input_data[i].dtype)))
     else:
-        m.set_input(input_names, tvm.nd.array(input_data.astype(input_data.dtype)))
+        m.set_input(input_names, tvm.nd.array(
+            input_data.astype(input_data.dtype)))
 
     m.set_input(**params)
     # execute
@@ -76,6 +76,7 @@ def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output
         tvm_output = m.get_output(0)
         return tvm_output.asnumpy()
 
+
 def get_caffe2_output(model, x, dtype='float32'):
     import caffe2.python.onnx.backend
     prepared_backend = caffe2.python.onnx.backend.prepare(model)
@@ -93,15 +94,20 @@ def verify_onnx_forward_impl(graph_file, data_shape, out_shape):
         tvm_out = get_tvm_output(model, x, target, ctx, out_shape, dtype)
         tvm.testing.assert_allclose(c2_out, tvm_out, rtol=1e-5, atol=1e-5)
 
+
 def verify_super_resolution_example():
-    verify_onnx_forward_impl(super_resolution, (1, 1, 224, 224), (1, 1, 672, 672))
+    verify_onnx_forward_impl(
+        super_resolution, (1, 1, 224, 224), (1, 1, 672, 672))
+
 
 def verify_squeezenet1_1():
     verify_onnx_forward_impl(squeezenet1_1, (1, 3, 224, 224), (1, 1000))
 
+
 def verify_lenet():
     verify_onnx_forward_impl(lenet, (1, 1, 28, 28), (1, 10))
 
+
 def verify_resnet18():
     verify_onnx_forward_impl(resnet18_1_0, (1, 3, 224, 224), (1, 1000))
 
@@ -112,20 +118,20 @@ def test_reshape():
 
     ref_array = np.array(ref_shape)
     ref_node = onnx.helper.make_node('Constant',
-                                 inputs=[],
-                                 outputs=['ref_in'],
-                                 value=onnx.helper.make_tensor(name = 'const_tensor',
-                                                               data_type = onnx.TensorProto.INT32,
-                                                               dims = ref_array.shape,
-                                                               vals = ref_array.flatten().astype(int)))
+                                     inputs=[],
+                                     outputs=['ref_in'],
+                                     value=onnx.helper.make_tensor(name='const_tensor',
+                                                                   data_type=onnx.TensorProto.INT32,
+                                                                   dims=ref_array.shape,
+                                                                   vals=ref_array.flatten().astype(int)))
     reshape_node = helper.make_node("Reshape", ["in", "ref_in"], ["out"])
 
     graph = helper.make_graph([ref_node, reshape_node],
                               "reshape_test",
-                              inputs = [helper.make_tensor_value_info("in",
-                                            TensorProto.FLOAT, list(in_shape))],
-                              outputs = [helper.make_tensor_value_info("out",
-                                            TensorProto.FLOAT, list(ref_shape))])
+                              inputs=[helper.make_tensor_value_info("in",
+                                                                    TensorProto.FLOAT, list(in_shape))],
+                              outputs=[helper.make_tensor_value_info("out",
+                                                                     TensorProto.FLOAT, list(ref_shape))])
 
     model = helper.make_model(graph, producer_name='reshape_test')
 
@@ -135,28 +141,29 @@ def test_reshape():
 
     tvm.testing.assert_allclose(ref_shape, tvm_out.shape)
 
+
 def test_shape():
     in_shape = (4, 3, 3, 4)
     ref_shape = (6, 2, 4, 3)
 
     ref_array = np.array(ref_shape)
     ref_node = onnx.helper.make_node('Constant',
-                                 inputs=[],
-                                 outputs=['ref_in'],
-                                 value=onnx.helper.make_tensor(name = 'const_tensor',
-                                                               data_type = onnx.TensorProto.INT32,
-                                                               dims = ref_array.shape,
-                                                               vals = ref_array.flatten().astype(int)))
+                                     inputs=[],
+                                     outputs=['ref_in'],
+                                     value=onnx.helper.make_tensor(name='const_tensor',
+                                                                   data_type=onnx.TensorProto.INT32,
+                                                                   dims=ref_array.shape,
+                                                                   vals=ref_array.flatten().astype(int)))
     reshape_node = helper.make_node("Reshape", ["in", "ref_in"], ["out"])
 
     shape_node = helper.make_node("Shape", ['out'], ['final_out'])
 
     graph = helper.make_graph([ref_node, reshape_node, shape_node],
                               "shape_test",
-                              inputs = [helper.make_tensor_value_info("in",
-                                            TensorProto.FLOAT, list(in_shape))],
-                              outputs = [helper.make_tensor_value_info("final_out",
-                                            TensorProto.FLOAT, list(ref_shape))])
+                              inputs=[helper.make_tensor_value_info("in",
+                                                                    TensorProto.FLOAT, list(in_shape))],
+                              outputs=[helper.make_tensor_value_info("final_out",
+                                                                     TensorProto.FLOAT, list(ref_shape))])
 
     model = helper.make_model(graph, producer_name='shape_test')
 
@@ -166,6 +173,7 @@ def test_shape():
 
     tvm.testing.assert_allclose(ref_shape, tvm_out)
 
+
 def _test_power_iteration(x_shape, y_shape):
     if isinstance(y_shape, int):
         y_shape = [y_shape]
@@ -179,12 +187,12 @@ def _test_power_iteration(x_shape, y_shape):
 
     graph = helper.make_graph([res],
                               'power_test',
-                              inputs = [helper.make_tensor_value_info("x",
-                                            TensorProto.FLOAT, list(x_shape)),
-                                        helper.make_tensor_value_info("y",
-                                            TensorProto.FLOAT, list(y_shape))],
-                              outputs = [helper.make_tensor_value_info("out",
-                                            TensorProto.FLOAT, list(np_res.shape))])
+                              inputs=[helper.make_tensor_value_info("x",
+                                                                    TensorProto.FLOAT, list(x_shape)),
+                                      helper.make_tensor_value_info("y",
+                                                                    TensorProto.FLOAT, list(y_shape))],
+                              outputs=[helper.make_tensor_value_info("out",
+                                                                     TensorProto.FLOAT, list(np_res.shape))])
 
     model = helper.make_model(graph, producer_name='power_test')
 
@@ -192,11 +200,13 @@ def _test_power_iteration(x_shape, y_shape):
         tvm_out = get_tvm_output(model, [x, y], target, ctx, np_res.shape)
         tvm.testing.assert_allclose(np_res, tvm_out, rtol=1e-5, atol=1e-5)
 
+
 def test_power():
     _test_power_iteration((1, 3), (1))
     _test_power_iteration((2, 3), (2, 3))
     _test_power_iteration((2, 3), (1, 3))
 
+
 def test_squeeze():
     in_shape = (1, 3, 1, 3, 1, 1)
     out_shape = (3, 3)
@@ -204,10 +214,10 @@ def test_squeeze():
 
     graph = helper.make_graph([y],
                               'squeeze_test',
-                              inputs = [helper.make_tensor_value_info("in",
-                                            TensorProto.FLOAT, list(in_shape))],
-                              outputs = [helper.make_tensor_value_info("out",
-                                            TensorProto.FLOAT, list(out_shape))])
+                              inputs=[helper.make_tensor_value_info("in",
+                                                                    TensorProto.FLOAT, list(in_shape))],
+                              outputs=[helper.make_tensor_value_info("out",
+                                                                     TensorProto.FLOAT, list(out_shape))])
 
     model = helper.make_model(graph, producer_name='squeeze_test')
 
@@ -217,20 +227,21 @@ def test_squeeze():
 
     tvm.testing.assert_allclose(out_shape, tvm_out.shape)
 
+
 def test_flatten():
 
     in_shape = (1, 3, 4, 4)
     axis = 1
     ref_shape = (1, 48)
 
-    flatten_node = helper.make_node("Flatten", ["in"], ["out"], axis = axis)
+    flatten_node = helper.make_node("Flatten", ["in"], ["out"], axis=axis)
 
     graph = helper.make_graph([flatten_node],
                               "flatten_test",
-                              inputs = [helper.make_tensor_value_info("in",
-                                            TensorProto.FLOAT, list(in_shape))],
-                              outputs = [helper.make_tensor_value_info("out",
-                                            TensorProto.FLOAT, list(ref_shape))])
+                              inputs=[helper.make_tensor_value_info("in",
+                                                                    TensorProto.FLOAT, list(in_shape))],
+                              outputs=[helper.make_tensor_value_info("out",
+                                                                     TensorProto.FLOAT, list(ref_shape))])
 
     model = helper.make_model(graph, producer_name='flatten_test')
 
@@ -240,6 +251,7 @@ def test_flatten():
 
     tvm.testing.assert_allclose(ref_shape, tvm_out.shape)
 
+
 def test_unsqueeze():
     in_shape = (3, 3)
     axis = (0, 3, 4)
@@ -248,10 +260,10 @@ def test_unsqueeze():
 
     graph = helper.make_graph([y],
                               'squeeze_test',
-                              inputs = [helper.make_tensor_value_info("in",
-                                            TensorProto.FLOAT, list(in_shape))],
-                              outputs = [helper.make_tensor_value_info("out",
-                                            TensorProto.FLOAT, list(out_shape))])
+                              inputs=[helper.make_tensor_value_info("in",
+                                                                    TensorProto.FLOAT, list(in_shape))],
+                              outputs=[helper.make_tensor_value_info("out",
+                                                                     TensorProto.FLOAT, list(out_shape))])
 
     model = helper.make_model(graph, producer_name='squeeze_test')
 
@@ -261,6 +273,7 @@ def test_unsqueeze():
 
     tvm.testing.assert_allclose(out_shape, tvm_out.shape)
 
+
 def verify_gather(in_shape, indices, axis, dtype):
     x = np.random.uniform(size=in_shape).astype(dtype)
     indices = np.array(indices, dtype="int32")
@@ -270,52 +283,123 @@ def verify_gather(in_shape, indices, axis, dtype):
 
     graph = helper.make_graph([y],
                               'gather_test',
-                              inputs = [helper.make_tensor_value_info("in",
-                                            TensorProto.FLOAT, list(in_shape)),
-                                        helper.make_tensor_value_info("indices",
-                                            TensorProto.INT32, list(indices.shape))],
-                              outputs = [helper.make_tensor_value_info("out",
-                                            TensorProto.FLOAT, list(out_np.shape))])
+                              inputs=[helper.make_tensor_value_info("in",
+                                                                    TensorProto.FLOAT, list(in_shape)),
+                                      helper.make_tensor_value_info("indices",
+                                                                    TensorProto.INT32, list(indices.shape))],
+                              outputs=[helper.make_tensor_value_info("out",
+                                                                     TensorProto.FLOAT, list(out_np.shape))])
     model = helper.make_model(graph, producer_name='gather_test')
 
     for target, ctx in ctx_list():
-        tvm_out = get_tvm_output(model, [x, indices], target, ctx, out_np.shape)
+        tvm_out = get_tvm_output(
+            model, [x, indices], target, ctx, out_np.shape)
         tvm.testing.assert_allclose(out_np, tvm_out)
 
+
 def test_gather():
     verify_gather((4,), [1], 0, 'int32')
-    verify_gather((1,4), [0], 0, 'int32')
-    verify_gather((4,), [[[1,0],[0,1]]], 0, 'float32')
-    verify_gather((2,2), [[[1,0],[0,1]]], 1, 'int32')
-    verify_gather((3,3,3), [[[1,0]]], -1, 'int32')
-    verify_gather((4,3,5,6), [[2,1,0,0]], 0, 'float32')
+    verify_gather((1, 4), [0], 0, 'int32')
+    verify_gather((4,), [[[1, 0], [0, 1]]], 0, 'float32')
+    verify_gather((2, 2), [[[1, 0], [0, 1]]], 1, 'int32')
+    verify_gather((3, 3, 3), [[[1, 0]]], -1, 'int32')
+    verify_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, 'float32')
+
 
-def _test_slice_iteration(indata, outdata, starts, ends, axes=None):
+def _test_slice_iteration_v1(indata, outdata, starts, ends, axes=None):
     if axes:
-        y = helper.make_node("Slice", ['in'], ['out'], axes=axes, starts=starts, ends=ends)
+        y = helper.make_node(
+            "Slice", ['in'], ['out'], axes=axes, starts=starts, ends=ends)
     else:
-        y = helper.make_node("Slice", ['in'], ['out'], starts=starts, ends=ends)
+        y = helper.make_node(
+            "Slice", ['in'], ['out'], starts=starts, ends=ends)
 
     graph = helper.make_graph([y],
                               'slice_test',
-                              inputs = [helper.make_tensor_value_info("in",
-                                            TensorProto.FLOAT, list(indata.shape))],
-                              outputs = [helper.make_tensor_value_info("out",
-                                            TensorProto.FLOAT, list(outdata.shape))])
+                              inputs=[helper.make_tensor_value_info("in",
+                                                                    TensorProto.FLOAT, list(indata.shape))],
+                              outputs=[helper.make_tensor_value_info("out",
+                                                                     TensorProto.FLOAT, list(outdata.shape))])
 
     model = helper.make_model(graph, producer_name='slice_test')
 
     for target, ctx in ctx_list():
-        tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, 'float32')
+        tvm_out = get_tvm_output(
+            model, indata, target, ctx, outdata.shape, 'float32', opset=1)
 
     tvm.testing.assert_allclose(outdata, tvm_out)
 
+
+def _test_slice_iteration_v10(indata, outdata, starts, ends, axes=None):
+    if isinstance(starts, int):
+        starts = (starts, )
+    if isinstance(ends, int):
+        ends = (ends, )
+    if isinstance(axes, int):
+        axes = (axes, )
+    starts = np.asarray(starts)
+    ends = np.asarray(ends)
+    inputs = [
+        helper.make_tensor_value_info("data", TensorProto.FLOAT,
+                                      list(indata.shape)),
+        helper.make_tensor_value_info("starts", TensorProto.INT32,
+                                      list(starts.shape)),
+        helper.make_tensor_value_info("ends", TensorProto.INT32,
+                                      list(ends.shape))
+    ]
+    initializer = [
+        helper.make_tensor("starts", TensorProto.INT32, list(starts.shape),
+                           starts),
+        helper.make_tensor("ends", TensorProto.INT32, list(ends.shape), ends)
+    ]
+
+    if axes:
+        axes = np.asarray(axes)
+        y = helper.make_node("Slice", ["data", "starts", "ends", "axes"],
+                             ["out"])
+        inputs.append(
+            helper.make_tensor_value_info("axes", TensorProto.INT32,
+                                          list(axes.shape)))
+        initializer.append(
+            helper.make_tensor("axes", TensorProto.INT32, list(axes.shape),
+                               axes))
+    else:
+        y = helper.make_node("Slice", ["data", "starts", "ends"], ["out"])
+
+    graph = helper.make_graph([y],
+                              'slice_test',
+                              inputs=inputs,
+                              outputs=[
+                                  helper.make_tensor_value_info(
+                                      "out", TensorProto.FLOAT,
+                                      list(outdata.shape))
+                              ],
+                              initializer=initializer)
+    model = helper.make_model(graph, producer_name='slice_test')
+
+    for target, ctx in ctx_list():
+        tvm_out = get_tvm_output(model,
+                                 indata,
+                                 target,
+                                 ctx,
+                                 outdata.shape,
+                                 'float32',
+                                 opset=10)
+
+    tvm.testing.assert_allclose(outdata, tvm_out)
+
+
 def test_slice():
     x = np.random.randn(20, 10, 5).astype(np.float32)
-    _test_slice_iteration(x, x[0:3, 0:10], (0, 0), (3, 10), (0, 1))
-    _test_slice_iteration(x, x[:, :, 3:4], (0, 0, 3), (20, 10, 4))
-    _test_slice_iteration(x, x[:, 1:1000], (1), (1000), (1))
-    _test_slice_iteration(x, x[:, 0:-1], (0), (-1), (1))
+    _test_slice_iteration_v1(x, x[0:3, 0:10], (0, 0), (3, 10), (0, 1))
+    _test_slice_iteration_v1(x, x[:, :, 3:4], (0, 0, 3), (20, 10, 4))
+    _test_slice_iteration_v1(x, x[:, 1:1000], (1), (1000), (1))
+    _test_slice_iteration_v1(x, x[:, 0:-1], (0), (-1), (1))
+    _test_slice_iteration_v10(x, x[0:3, 0:10], (0, 0), (3, 10), (0, 1))
+    _test_slice_iteration_v10(x, x[:, :, 3:4], (0, 0, 3), (20, 10, 4))
+    _test_slice_iteration_v10(x, x[:, 1:1000], (1), (1000), (1))
+    _test_slice_iteration_v10(x, x[:, 0:-1], (0), (-1), (1))
+
 
 def _test_onnx_op_elementwise(inshape, outfunc, npargs, dtype, opname, kwargs):
     indata = np.random.uniform(-1, 1, size=inshape).astype(dtype)
@@ -325,24 +409,29 @@ def _test_onnx_op_elementwise(inshape, outfunc, npargs, dtype, opname, kwargs):
 
     graph = helper.make_graph([y],
                               opname+'_test',
-                              inputs = [helper.make_tensor_value_info("in",
-                                            TensorProto.FLOAT, list(indata.shape))],
-                              outputs = [helper.make_tensor_value_info("out",
-                                            TensorProto.FLOAT, list(outdata.shape))])
+                              inputs=[helper.make_tensor_value_info("in",
+                                                                    TensorProto.FLOAT, list(indata.shape))],
+                              outputs=[helper.make_tensor_value_info("out",
+                                                                     TensorProto.FLOAT, list(outdata.shape))])
 
     model = helper.make_model(graph, producer_name=opname+'_test')
 
     for target, ctx in ctx_list():
-        tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, dtype)
+        tvm_out = get_tvm_output(
+            model, indata, target, ctx, outdata.shape, dtype)
 
     tvm.testing.assert_allclose(outdata, tvm_out)
 
+
 def test_floor():
-    _test_onnx_op_elementwise((2, 4, 5, 6), np.floor, {}, 'float32', 'Floor', {})
+    _test_onnx_op_elementwise((2, 4, 5, 6), np.floor,
+                              {}, 'float32', 'Floor', {})
+
 
 def test_ceil():
     _test_onnx_op_elementwise((2, 4, 5, 6), np.ceil, {}, 'float32', 'Ceil', {})
 
+
 def test_clip():
     _test_onnx_op_elementwise((2, 4, 5, 6),
                               np.clip,
@@ -351,6 +440,38 @@ def test_clip():
                               'Clip',
                               {'min': -1.0, 'max': 1.0})
 
+
+def test_onehot():
+    indices_shape = [10]
+    indices_array = np.random.randint(
+        low=0, high=9, size=indices_shape, dtype='int32')
+    depth = 10
+    values = np.asarray([0, 1])
+    out_np = np.eye(depth)[indices_array.reshape(-1)]
+
+    onehot_node = helper.make_node(
+        "OneHot", ["indices", "depth", "values"], ["out"])
+
+    graph = helper.make_graph([onehot_node],
+                              "onehot_test",
+                              inputs=[helper.make_tensor_value_info("indices",
+                                                                    TensorProto.INT32, indices_shape),
+                                      helper.make_tensor_value_info("depth",
+                                                                    TensorProto.INT32, [1]),
+                                      helper.make_tensor_value_info("values",
+                                                                    TensorProto.INT32, values.shape)],
+                              initializer=[helper.make_tensor("depth", TensorProto.INT32, [1], [depth]),
+                                           helper.make_tensor("values", TensorProto.INT32, values.shape, values)],
+                              outputs=[helper.make_tensor_value_info("out", TensorProto.INT32, out_np.shape)])
+
+    model = helper.make_model(graph, producer_name="onehot_test")
+
+    for target, ctx in ctx_list():
+        tvm_out = get_tvm_output(
+            model, [indices_array], target, ctx, out_np.shape)
+        tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)
+
+
 def test_matmul():
     a_shape = (4, 3)
     b_shape = (3, 4)
@@ -363,52 +484,84 @@ def test_matmul():
 
     graph = helper.make_graph([mul_node],
                               "matmul_test",
-                              inputs = [helper.make_tensor_value_info("a",
-                                            TensorProto.FLOAT, list(a_shape)),
-                                        helper.make_tensor_value_info("b",
-                                            TensorProto.FLOAT, list(b_shape))],
-                              outputs = [helper.make_tensor_value_info("out",
-                                            TensorProto.FLOAT, list(out_np.shape))])
+                              inputs=[helper.make_tensor_value_info("a",
+                                                                    TensorProto.FLOAT, list(a_shape)),
+                                      helper.make_tensor_value_info("b",
+                                                                    TensorProto.FLOAT, list(b_shape))],
+                              outputs=[helper.make_tensor_value_info("out",
+                                                                     TensorProto.FLOAT, list(out_np.shape))])
+
+    model = helper.make_model(graph, producer_name='matmul_test')
+
+    for target, ctx in ctx_list():
+        tvm_out = get_tvm_output(
+            model, [a_array, b_array], target, ctx, out_np.shape)
+        tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)
+
+
+def test_batch_matmul():
+    a_shape = (2, 3, 4, 3)
+    b_shape = (2, 3, 3, 4)
+
+    a_array = np.random.uniform(size=a_shape).astype('float32')
+    b_array = np.random.uniform(size=b_shape).astype('float32')
+    out_np = np.matmul(a_array, b_array)
+
+    mul_node = helper.make_node("MatMul", ["a", "b"], ["out"])
+
+    graph = helper.make_graph([mul_node],
+                              "matmul_test",
+                              inputs=[helper.make_tensor_value_info("a",
+                                                                    TensorProto.FLOAT, list(a_shape)),
+                                      helper.make_tensor_value_info("b",
+                                                                    TensorProto.FLOAT, list(b_shape))],
+                              outputs=[helper.make_tensor_value_info("out",
+                                                                     TensorProto.FLOAT, list(out_np.shape))])
 
     model = helper.make_model(graph, producer_name='matmul_test')
 
     for target, ctx in ctx_list():
-        tvm_out = get_tvm_output(model, [a_array, b_array], target, ctx, out_np.shape)
+        tvm_out = get_tvm_output(
+            model, [a_array, b_array], target, ctx, out_np.shape)
         tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)
 
+
 def verify_lrn(shape, nsize, dtype, alpha=None, beta=None, bias=None):
     in_array = np.random.uniform(size=shape).astype(dtype)
 
-    if alpha == None and beta == None and bias==None:
+    if alpha == None and beta == None and bias == None:
         alpha = 0.0001
         beta = 0.75
         bias = 1.0
-        node = onnx.helper.make_node('LRN', inputs=['in'], outputs=['out'], size=nsize)
+        node = onnx.helper.make_node(
+            'LRN', inputs=['in'], outputs=['out'], size=nsize)
     else:
         node = onnx.helper.make_node('LRN', inputs=['in'], outputs=['out'], alpha=alpha,
                                      beta=beta, bias=bias, size=nsize)
 
     graph = helper.make_graph([node],
                               "lrn_test",
-                              inputs = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(shape))],
-                              outputs = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(shape))])
+                              inputs=[helper.make_tensor_value_info(
+                                  "in", TensorProto.FLOAT, list(shape))],
+                              outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(shape))])
     model = helper.make_model(graph, producer_name='lrn_test')
 
     def _get_python_lrn():
         square_sum = np.zeros(shape).astype(dtype)
         for n, c, h, w in np.ndindex(in_array.shape):
             square_sum[n, c, h, w] = sum(in_array[n,
-                                         max(0, c - int(math.floor((nsize - 1) / 2))): \
-                                             min(5, c + int(math.ceil((nsize - 1) / 2)) + 1),
-                                         h,
-                                         w] ** 2)
+                                                  max(0, c - int(math.floor((nsize - 1) / 2))):
+                                                  min(5, c + int(math.ceil((nsize - 1) / 2)) + 1),
+                                                  h,
+                                                  w] ** 2)
         py_out = in_array / ((bias + (alpha / nsize) * square_sum) ** beta)
         return py_out
 
     for target, ctx in ctx_list():
         input_name = model.graph.input[0].name
         py_out = _get_python_lrn()
-        tvm_out = get_tvm_output(model, in_array, target, ctx, py_out.shape, 'float32')
+        tvm_out = get_tvm_output(
+            model, in_array, target, ctx, py_out.shape, 'float32')
         tvm.testing.assert_allclose(py_out, tvm_out, rtol=1e-5, atol=1e-5)
 
 
@@ -436,20 +589,22 @@ def verify_instance_norm(shape, axis=1):
     y = _get_python_instance_norm(x, gamma, beta, epsilon).astype(np.float32)
 
     node = onnx.helper.make_node(
-            'InstanceNormalization',
-            inputs=['x', 'gamma', 'beta'],
-            outputs=['y'],
-            epsilon=epsilon,
-        )
+        'InstanceNormalization',
+        inputs=['x', 'gamma', 'beta'],
+        outputs=['y'],
+        epsilon=epsilon,
+    )
     graph = helper.make_graph([node],
                               "instance_norm_test",
                               inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(shape)),
-                                      helper.make_tensor_value_info("gamma", TensorProto.FLOAT, (shape[1],)),
+                                      helper.make_tensor_value_info(
+                                          "gamma", TensorProto.FLOAT, (shape[1],)),
                                       helper.make_tensor_value_info("beta", TensorProto.FLOAT, (shape[1],))],
                               outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(shape))])
     model = helper.make_model(graph, producer_name='instance_norm_test')
     for target, ctx in ctx_list():
-        tvm_out = get_tvm_output(model, [x, gamma, beta], target, ctx, shape, 'float32')
+        tvm_out = get_tvm_output(
+            model, [x, gamma, beta], target, ctx, shape, 'float32')
         tvm.testing.assert_allclose(y, tvm_out, rtol=1e-5, atol=1e-5)
 
 
@@ -464,103 +619,122 @@ def _test_upsample_nearest():
     scale = 2
     in_shape = (1, 1, 3, 3)
     out_shape = (1, 1, 3*scale, 3*scale)
-    y = helper.make_node("Upsample", ['in'], ['out'], mode='nearest', scales=[1.0, 1.0, 2.0, 2.0])
+    y = helper.make_node("Upsample", ['in'], [
+                         'out'], mode='nearest', scales=[1.0, 1.0, 2.0, 2.0])
 
     in_array = np.random.uniform(size=in_shape).astype(np.float32)
-    out_array = topi.testing.upsampling_python(in_array, (scale, scale), "NCHW")
+    out_array = topi.testing.upsampling_python(
+        in_array, (scale, scale), "NCHW")
 
     graph = helper.make_graph([y],
                               'upsample_nearest_test',
-                              inputs = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))],
-                              outputs = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))])
+                              inputs=[helper.make_tensor_value_info(
+                                  "in", TensorProto.FLOAT, list(in_shape))],
+                              outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))])
 
     model = helper.make_model(graph, producer_name='upsample_nearest_test')
 
     for target, ctx in ctx_list():
-        tvm_out = get_tvm_output(model, in_array, target, ctx, out_shape, 'float32')
+        tvm_out = get_tvm_output(
+            model, in_array, target, ctx, out_shape, 'float32')
         tvm.testing.assert_allclose(out_array, tvm_out)
 
+
 def _test_upsample_bilinear():
     scale = 2
     in_shape = (1, 1, 3, 3)
     out_shape = (1, 1, 3*scale, 3*scale)
-    y = helper.make_node("Upsample", ['in'], ['out'], mode='linear', scales=[1.0, 1.0, 2.0, 2.0])
+    y = helper.make_node("Upsample", ['in'], [
+                         'out'], mode='linear', scales=[1.0, 1.0, 2.0, 2.0])
 
     in_array = np.random.uniform(size=in_shape).astype(np.float32)
-    out_array = topi.testing.bilinear_resize_python(in_array, (3*scale, 3*scale), "NCHW")
+    out_array = topi.testing.bilinear_resize_python(
+        in_array, (3*scale, 3*scale), "NCHW")
 
     graph = helper.make_graph([y],
                               'upsample_bilinear_test',
-                              inputs = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))],
-                              outputs = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))])
+                              inputs=[helper.make_tensor_value_info(
+                                  "in", TensorProto.FLOAT, list(in_shape))],
+                              outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))])
 
     model = helper.make_model(graph, producer_name='upsample_bilinear_test')
 
     for target, ctx in ctx_list():
-        tvm_out = get_tvm_output(model, in_array, target, ctx, out_shape, 'float32')
+        tvm_out = get_tvm_output(
+            model, in_array, target, ctx, out_shape, 'float32')
         tvm.testing.assert_allclose(out_array, tvm_out, rtol=1e-5, atol=1e-5)
 
+
 def _test_upsample_bilinear_opset9():
     scale = 2
     in_shape = (1, 1, 3, 3)
     out_shape = (1, 1, 3*scale, 3*scale)
-    y = helper.make_node("Upsample", ['in','scales'], ['out'], mode='linear')
-    scales=[1.0, 1.0, 2.0, 2.0]
+    y = helper.make_node("Upsample", ['in', 'scales'], ['out'], mode='linear')
+    scales = [1.0, 1.0, 2.0, 2.0]
     in_array = np.random.uniform(size=in_shape).astype(np.float32)
-    out_array = topi.testing.bilinear_resize_python(in_array, (3*scale, 3*scale), "NCHW")
+    out_array = topi.testing.bilinear_resize_python(
+        in_array, (3*scale, 3*scale), "NCHW")
 
     ref_array = np.array(scales)
     ref_node = helper.make_node('Constant',
-                                 inputs=[],
-                                 outputs=['scales'],
-                                 value=onnx.helper.make_tensor(name = 'const_tensor',
-                                                               data_type = TensorProto.FLOAT,
-                                                               dims = ref_array.shape,
-                                                               vals = ref_array.flatten().astype(float)))
+                                inputs=[],
+                                outputs=['scales'],
+                                value=onnx.helper.make_tensor(name='const_tensor',
+                                                              data_type=TensorProto.FLOAT,
+                                                              dims=ref_array.shape,
+                                                              vals=ref_array.flatten().astype(float)))
 
     graph = helper.make_graph([ref_node, y],
                               'upsample_bilinear_opset9_test',
-                              inputs = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))],
-                              outputs = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))])
+                              inputs=[helper.make_tensor_value_info(
+                                  "in", TensorProto.FLOAT, list(in_shape))],
+                              outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))])
 
-    model = helper.make_model(graph, producer_name='upsample_bilinear_opset9_test')
+    model = helper.make_model(
+        graph, producer_name='upsample_bilinear_opset9_test')
 
     for target, ctx in ctx_list():
-        tvm_out = get_tvm_output(model, in_array, target, ctx, out_shape, 'float32')
+        tvm_out = get_tvm_output(
+            model, in_array, target, ctx, out_shape, 'float32')
         tvm.testing.assert_allclose(out_array, tvm_out, rtol=1e-5, atol=1e-5)
 
+
 def test_upsample():
     _test_upsample_nearest()
     _test_upsample_bilinear()
     _test_upsample_bilinear_opset9()
 
+
 def _test_softmax(inshape, axis):
     opname = 'Softmax'
     indata = np.random.uniform(size=inshape).astype(np.float32)
     outshape = inshape
     outdata = topi.testing.softmax_python(indata)
     if isinstance(axis, int):
-        y = helper.make_node(opname, ['in'], ['out'], axis = axis)
+        y = helper.make_node(opname, ['in'], ['out'], axis=axis)
     elif axis is None:
         y = helper.make_node(opname, ['in'], ['out'])
 
     graph = helper.make_graph([y],
                               opname+'_test',
-                              inputs = [helper.make_tensor_value_info("in",
-                                            TensorProto.FLOAT, list(indata.shape))],
-                              outputs = [helper.make_tensor_value_info("out",
-                                            TensorProto.FLOAT, list(outdata.shape))])
+                              inputs=[helper.make_tensor_value_info("in",
+                                                                    TensorProto.FLOAT, list(indata.shape))],
+                              outputs=[helper.make_tensor_value_info("out",
+                                                                     TensorProto.FLOAT, list(outdata.shape))])
 
     model = helper.make_model(graph, producer_name=opname+'_test')
 
     for target, ctx in ctx_list():
-        tvm_out = get_tvm_output(model, indata, target, ctx, outshape, 'float32')
+        tvm_out = get_tvm_output(
+            model, indata, target, ctx, outshape, 'float32')
         tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5)
 
+
 def test_softmax():
     _test_softmax((1, 10), None)
     _test_softmax((1, 10), 1)
 
+
 def verify_min(input_dim):
     dtype = 'float32'
 
@@ -574,25 +748,28 @@ def verify_min(input_dim):
 
     graph = helper.make_graph([min_node],
                               "Min_test",
-                              inputs = [helper.make_tensor_value_info("a_np1",
-                                            TensorProto.FLOAT, list(input_dim)),
-                                        helper.make_tensor_value_info("a_np2",
-                                            TensorProto.FLOAT, list(input_dim)),
-                                        helper.make_tensor_value_info("a_np3",
-                                            TensorProto.FLOAT, list(input_dim))],
-                              outputs = [helper.make_tensor_value_info("out",
-                                            TensorProto.FLOAT, list(b_np.shape))])
+                              inputs=[helper.make_tensor_value_info("a_np1",
+                                                                    TensorProto.FLOAT, list(input_dim)),
+                                      helper.make_tensor_value_info("a_np2",
+                                                                    TensorProto.FLOAT, list(input_dim)),
+                                      helper.make_tensor_value_info("a_np3",
+                                                                    TensorProto.FLOAT, list(input_dim))],
+                              outputs=[helper.make_tensor_value_info("out",
+                                                                     TensorProto.FLOAT, list(b_np.shape))])
 
     model = helper.make_model(graph, producer_name='Min_test')
 
     for target, ctx in ctx_list():
-        tvm_out = get_tvm_output(model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape)
+        tvm_out = get_tvm_output(
+            model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape)
         tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5)
 
+
 def test_forward_min():
     verify_min((1, 3, 20, 20))
     verify_min((20, 20))
 
+
 def verify_max(input_dim):
     dtype = 'float32'
 
@@ -606,25 +783,28 @@ def verify_max(input_dim):
 
     graph = helper.make_graph([max_node],
                               "Max_test",
-                              inputs = [helper.make_tensor_value_info("a_np1",
-                                            TensorProto.FLOAT, list(input_dim)),
-                                        helper.make_tensor_value_info("a_np2",
-                                            TensorProto.FLOAT, list(input_dim)),
-                                        helper.make_tensor_value_info("a_np3",
-                                            TensorProto.FLOAT, list(input_dim))],
-                              outputs = [helper.make_tensor_value_info("out",
-                                            TensorProto.FLOAT, list(b_np.shape))])
+                              inputs=[helper.make_tensor_value_info("a_np1",
+                                                                    TensorProto.FLOAT, list(input_dim)),
+                                      helper.make_tensor_value_info("a_np2",
+                                                                    TensorProto.FLOAT, list(input_dim)),
+                                      helper.make_tensor_value_info("a_np3",
+                                                                    TensorProto.FLOAT, list(input_dim))],
+                              outputs=[helper.make_tensor_value_info("out",
+                                                                     TensorProto.FLOAT, list(b_np.shape))])
 
     model = helper.make_model(graph, producer_name='Max_test')
 
     for target, ctx in ctx_list():
-        tvm_out = get_tvm_output(model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape)
+        tvm_out = get_tvm_output(
+            model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape)
         tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5)
 
+
 def test_forward_max():
     verify_max((1, 3, 20, 20))
     verify_max((20, 20))
 
+
 def verify_mean(input_dim):
     dtype = 'float32'
 
@@ -638,25 +818,28 @@ def verify_mean(input_dim):
 
     graph = helper.make_graph([mean_node],
                               "Mean_test",
-                              inputs = [helper.make_tensor_value_info("a_np1",
-                                            TensorProto.FLOAT, list(input_dim)),
-                                        helper.make_tensor_value_info("a_np2",
-                                            TensorProto.FLOAT, list(input_dim)),
-                                        helper.make_tensor_value_info("a_np3",
-                                            TensorProto.FLOAT, list(input_dim))],
-                              outputs = [helper.make_tensor_value_info("out",
-                                            TensorProto.FLOAT, list(b_np.shape))])
+                              inputs=[helper.make_tensor_value_info("a_np1",
+                                                                    TensorProto.FLOAT, list(input_dim)),
+                                      helper.make_tensor_value_info("a_np2",
+                                                                    TensorProto.FLOAT, list(input_dim)),
+                                      helper.make_tensor_value_info("a_np3",
+                                                                    TensorProto.FLOAT, list(input_dim))],
+                              outputs=[helper.make_tensor_value_info("out",
+                                                                     TensorProto.FLOAT, list(b_np.shape))])
 
     model = helper.make_model(graph, producer_name='Mean_test')
 
     for target, ctx in ctx_list():
-        tvm_out = get_tvm_output(model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape)
+        tvm_out = get_tvm_output(
+            model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape)
         tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5)
 
+
 def test_forward_mean():
     verify_mean((1, 3, 20, 20))
     verify_mean((20, 20))
 
+
 def verify_hardsigmoid(input_dim, alpha, beta):
     dtype = 'float32'
 
@@ -664,14 +847,15 @@ def verify_hardsigmoid(input_dim, alpha, beta):
 
     b_np = np.clip(a_np1 * alpha + beta, 0, 1)
 
-    hardsigmoid_node = helper.make_node("HardSigmoid", ["a_np1"], ["out"], alpha=alpha, beta=beta)
+    hardsigmoid_node = helper.make_node("HardSigmoid", ["a_np1"], ["out"],
+                                        alpha=alpha, beta=beta)
 
     graph = helper.make_graph([hardsigmoid_node],
                               "HardSigmoid_test",
-                              inputs = [helper.make_tensor_value_info("a_np1",
-                                            TensorProto.FLOAT, list(input_dim))],
-                              outputs = [helper.make_tensor_value_info("out",
-                                            TensorProto.FLOAT, list(b_np.shape))])
+                              inputs=[helper.make_tensor_value_info("a_np1",
+                                                                    TensorProto.FLOAT, list(input_dim))],
+                              outputs=[helper.make_tensor_value_info("out",
+                                                                     TensorProto.FLOAT, list(b_np.shape))])
 
     model = helper.make_model(graph, producer_name='HardSigmoid_test')
 
@@ -679,10 +863,12 @@ def verify_hardsigmoid(input_dim, alpha, beta):
         tvm_out = get_tvm_output(model, [a_np1], target, ctx, b_np.shape)
         tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5)
 
+
 def test_forward_hardsigmoid():
     verify_hardsigmoid((1, 3, 20, 20), 0.5, 0.6)
     verify_hardsigmoid((20, 20), 0.3, 0.4)
 
+
 def verify_argmin(input_dim, axis=None, keepdims=None):
     def _argmin_numpy(data, axis=0, keepdims=True):
         result = np.argmin(data, axis=axis)
@@ -717,17 +903,19 @@ def verify_argmin(input_dim, axis=None, keepdims=None):
                                      keepdims=keepdims)
     graph = helper.make_graph([node],
                               "argmin_test",
-                              inputs = [helper.make_tensor_value_info("a_np1",
-                                            TensorProto.INT32, list(a_np1.shape))],
-                              outputs = [helper.make_tensor_value_info("out",
-                                            TensorProto.INT32, list(b_np.shape))])
+                              inputs=[helper.make_tensor_value_info("a_np1",
+                                                                    TensorProto.INT32, list(a_np1.shape))],
+                              outputs=[helper.make_tensor_value_info("out",
+                                                                     TensorProto.INT32, list(b_np.shape))])
 
     model = helper.make_model(graph, producer_name='argmin_test')
 
     for target, ctx in ctx_list():
-        tvm_out = get_tvm_output(model, [a_np1], target, ctx, b_np.shape, b_np.dtype)
+        tvm_out = get_tvm_output(
+            model, [a_np1], target, ctx, b_np.shape, b_np.dtype)
         tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5)
 
+
 def verify_argmax(input_dim, axis=None, keepdims=None):
     def _argmax_numpy(data, axis=0, keepdims=True):
         result = np.argmax(data, axis=axis)
@@ -763,66 +951,72 @@ def verify_argmax(input_dim, axis=None, keepdims=None):
 
     graph = helper.make_graph([node],
                               "argmax_test",
-                              inputs = [helper.make_tensor_value_info("a_np1",
-                                            TensorProto.INT32, list(a_np1.shape))],
-                              outputs = [helper.make_tensor_value_info("out",
-                                            TensorProto.INT32, list(b_np.shape))])
+                              inputs=[helper.make_tensor_value_info("a_np1",
+                                                                    TensorProto.INT32, list(a_np1.shape))],
+                              outputs=[helper.make_tensor_value_info("out",
+                                                                     TensorProto.INT32, list(b_np.shape))])
 
     model = helper.make_model(graph, producer_name='argmax_test')
 
     for target, ctx in ctx_list():
-        tvm_out = get_tvm_output(model, [a_np1], target, ctx, b_np.shape, b_np.dtype)
+        tvm_out = get_tvm_output(
+            model, [a_np1], target, ctx, b_np.shape, b_np.dtype)
         tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5)
 
+
 def test_forward_arg_min_max():
     '''Verify argmin and argmax'''
-    verify_argmin([3,4,4])
-    verify_argmax([3,4,4])
-    verify_argmin([3,4,4], axis=1)
-    verify_argmax([3,4,4], axis=0)
-    verify_argmin([3,4,4], keepdims=0)
-    verify_argmax([3,4,4], keepdims=1)
-    for axis in [None, 0,1,2]:
-        for keepdims in [None, True,False]:
-            verify_argmin([3,4,4], axis, keepdims)
-            verify_argmax([3,4,4], axis, keepdims)
-
-def verify_constantfill(is_shape, input_dim, out_dim, value, dtype, **kwargs):
-    input_a = np.random.uniform(size=input_dim).astype(dtype)
-    out = np.empty(shape=out_dim, dtype=dtype)
+    verify_argmin([3, 4, 4])
+    verify_argmax([3, 4, 4])
+    verify_argmin([3, 4, 4], axis=1)
+    verify_argmax([3, 4, 4], axis=0)
+    verify_argmin([3, 4, 4], keepdims=0)
+    verify_argmax([3, 4, 4], keepdims=1)
+    for axis in [None, 0, 1, 2]:
+        for keepdims in [None, True, False]:
+            verify_argmin([3, 4, 4], axis, keepdims)
+            verify_argmax([3, 4, 4], axis, keepdims)
+
+
+def verify_constantofshape(input_dim, value, dtype):
+    out = np.empty(shape=input_dim, dtype=dtype)
     out.fill(value)
 
-    if is_shape == True:
-        fill_node = helper.make_node("ConstantFill", [], ["out"], shape=input_dim, value=value, **kwargs)
-    else:
-        fill_node = helper.make_node("ConstantFill", ["input_a"], ["out"], value=value, dtype=dtype, **kwargs)
-
-    if is_shape == True:
-        inputs = []
-    else:
-        inputs = [helper.make_tensor_value_info("input_a",
-                  TensorProto.FLOAT, list(input_dim))]
-
-    graph = helper.make_graph([fill_node],
-                              "fill_test",
-                              inputs,
-                              outputs = [helper.make_tensor_value_info("out",
-                                            TensorProto.FLOAT, list(out.shape))])
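+    # ConstantOfShape takes its target shape as an input tensor (bound below
+    # through a graph initializer) and fills it with the single-element
+    # 'value' tensor attribute.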
+    fill_node = helper.make_node("ConstantOfShape", ["input"], ["output"],
+                                 value=helper.make_tensor(
+                                     'value',
+                                     mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)],
+                                     (1, ), (value, )))
+
+    inputs = [
+        helper.make_tensor_value_info("input", TensorProto.INT64,
+                                      [len(input_dim)])
+    ]
+
+    graph = helper.make_graph(
+        [fill_node],
+        "fill_test",
+        inputs,
+        outputs=[
+            helper.make_tensor_value_info("output", TensorProto.FLOAT,
+                                          list(out.shape))
+        ],
+        initializer=[
+            helper.make_tensor("input", TensorProto.INT64, (len(input_dim), ),
+                               input_dim)
+        ])
 
     model = helper.make_model(graph, producer_name='fill_test')
 
     for target, ctx in ctx_list():
-        if is_shape == True:
-            tvm_out = get_tvm_output(model, [], target, ctx, out.shape)
-        else:
-            tvm_out = get_tvm_output(model, [input_a], target, ctx, out.shape)
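+        # The shape is supplied by the initializer, so no runtime inputs are
+        # needed.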
+        tvm_out = get_tvm_output(model, [], target, ctx, out.shape)
 
         tvm.testing.assert_allclose(out, tvm_out, rtol=1e-5, atol=1e-5)
 
-def test_constantfill():
-    verify_constantfill(True, (2, 3, 4, 5), (2, 3, 4, 5), 10, 'float32')
-    verify_constantfill(False, (2, 3, 4, 5), (2, 3, 4, 5), 10, 'float32')
-    verify_constantfill(True, (2, 3, 4, 5), (2, 3, 4, 5, 4, 5, 6), 10, 'float32', extra_shape=(4, 5, 6))
+
+def test_constantofshape():
+    verify_constantofshape((2, 3, 4, 5), 10, 'float32')
+    verify_constantofshape((3, 3), 0, 'int32')
+    verify_constantofshape((1, 2, 3), -1, 'float32')
 
 
 def verify_pad(indata, pads, mode='constant', value=0.0):
@@ -841,7 +1035,8 @@ def verify_pad(indata, pads, mode='constant', value=0.0):
             pads=pads,
         )
     else:
-        outdata = np.pad(indata, pad_width=np_pads, mode='constant', constant_values=value)
+        outdata = np.pad(indata, pad_width=np_pads,
+                         mode='constant', constant_values=value)
         node = helper.make_node(
             'Pad',
             inputs=['input'],
@@ -852,22 +1047,30 @@ def verify_pad(indata, pads, mode='constant', value=0.0):
         )
     graph = helper.make_graph([node],
                               'pad_test',
-                              inputs = [helper.make_tensor_value_info("input",
-                                            TensorProto.FLOAT, list(indata.shape))],
-                              outputs = [helper.make_tensor_value_info("output",
-                                            TensorProto.FLOAT, list(outdata.shape))])
+                              inputs=[helper.make_tensor_value_info("input",
+                                                                    TensorProto.FLOAT, list(indata.shape))],
+                              outputs=[helper.make_tensor_value_info("output",
+                                                                     TensorProto.FLOAT, list(outdata.shape))])
     model = helper.make_model(graph, producer_name='pad_test')
     #  tvm result
     for target, ctx in ctx_list():
-        tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, 'float32')
+        tvm_out = get_tvm_output(
+            model, indata, target, ctx, outdata.shape, 'float32')
     tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5)
 
+
 def test_pad():
-    verify_pad(np.random.randn(2, 2).astype(np.float32), [0, 1, 0, 0], 'constant', 0.0)
-    verify_pad(np.random.randn(2, 3).astype(np.float32), [1, 0, 0, 1], 'constant', 0.0)
-    verify_pad(np.random.randn(3, 2).astype(np.float32), [0, 0, 1, 0], 'constant', 5.0)
-    verify_pad(np.random.randn(1, 3, 4, 5).astype(np.float32), [0, 0, 1, 1, 0, 0, 1, 1], 'edge')
-    verify_pad(np.random.randn(1, 3, 4, 5).astype(np.float32), [0, 0, 1, 1, 0, 0, 1, 1], 'reflect')
+    verify_pad(np.random.randn(2, 2).astype(np.float32),
+               [0, 1, 0, 0], 'constant', 0.0)
+    verify_pad(np.random.randn(2, 3).astype(np.float32),
+               [1, 0, 0, 1], 'constant', 0.0)
+    verify_pad(np.random.randn(3, 2).astype(np.float32),
+               [0, 0, 1, 0], 'constant', 5.0)
+    verify_pad(np.random.randn(1, 3, 4, 5).astype(np.float32),
+               [0, 0, 1, 1, 0, 0, 1, 1], 'edge')
+    verify_pad(np.random.randn(1, 3, 4, 5).astype(np.float32),
+               [0, 0, 1, 1, 0, 0, 1, 1], 'reflect')
+
 
 def verify_reduce_x(name, indata, axis, keepdims):
     indata = np.array(indata).astype(np.float32)
@@ -893,16 +1096,18 @@ def verify_reduce_x(name, indata, axis, keepdims):
                                 axes=axis, keepdims=keepdims)
     graph = helper.make_graph([node],
                               '{}_test'.format(name),
-                              inputs = [helper.make_tensor_value_info("input",
-                                            TensorProto.FLOAT, list(indata.shape))],
-                              outputs = [helper.make_tensor_value_info("output",
-                                            TensorProto.FLOAT, list(outdata.shape))])
+                              inputs=[helper.make_tensor_value_info("input",
+                                                                    TensorProto.FLOAT, list(indata.shape))],
+                              outputs=[helper.make_tensor_value_info("output",
+                                                                     TensorProto.FLOAT, list(outdata.shape))])
     model = helper.make_model(graph, producer_name='{}_test'.format(name))
     #  tvm result
     for target, ctx in ctx_list():
-        tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, 'float32')
+        tvm_out = get_tvm_output(
+            model, indata, target, ctx, outdata.shape, 'float32')
     tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5)
 
+
 def test_reduce_max():
     verify_reduce_x("ReduceMax",
                     np.random.randn(3, 2, 2).astype(np.float32),
@@ -914,6 +1119,7 @@ def test_reduce_max():
                     np.random.randn(3, 3, 3).astype(np.float32),
                     axis=(1,), keepdims=1)
 
+
 def test_reduce_min():
     verify_reduce_x("ReduceMin",
                     np.random.randn(3, 2, 2).astype(np.float32),
@@ -925,6 +1131,7 @@ def test_reduce_min():
                     np.random.randn(3, 3, 3).astype(np.float32),
                     axis=(1,), keepdims=1)
 
+
 def test_reduce_sum():
     verify_reduce_x("ReduceSum",
                     np.random.randn(3, 2, 2).astype(np.float32),
@@ -936,6 +1143,7 @@ def test_reduce_sum():
                     np.random.randn(3, 3, 3).astype(np.float32),
                     axis=(1,), keepdims=1)
 
+
 def test_reduce_mean():
     verify_reduce_x("ReduceMean",
                     np.random.randn(3, 2, 2).astype(np.float32),
@@ -947,40 +1155,52 @@ def test_reduce_mean():
                     np.random.randn(3, 3, 3).astype(np.float32),
                     axis=(1,), keepdims=1)
 
+
 def verify_split(indata, outdatas, split, axis=0):
     indata = np.array(indata).astype(np.float32)
     outdatas = [np.array(o).astype(np.float32) for o in outdatas]
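+    # Without explicit split sizes ONNX splits the axis evenly, so infer the
+    # output count from the expected results instead.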
+    if split:
+        split_index = range(len(split))
+    else:
+        split_index = range(len(outdatas))
     node = helper.make_node(
         'Split',
         inputs=['input'],
-        outputs=['output_{}'.format(i) for i in range(len(split))],
+        outputs=['output_{}'.format(i) for i in range(len(split_index))],
         axis=axis,
         split=split
     )
     graph = helper.make_graph([node],
                               'split_test',
-                              inputs = [helper.make_tensor_value_info("input",
-                                            TensorProto.FLOAT, list(indata.shape))],
-                              outputs = [helper.make_tensor_value_info("output_{}".format(i),
-                                            TensorProto.FLOAT, list(outdatas[i].shape))
-                                            for i in range(len(split))
-                                         ])
+                              inputs=[helper.make_tensor_value_info("input",
+                                                                    TensorProto.FLOAT, list(indata.shape))],
+                              outputs=[helper.make_tensor_value_info("output_{}".format(i),
+                                                                     TensorProto.FLOAT, list(outdatas[i].shape))
+                                       for i in range(len(split_index))
+                                       ])
     model = helper.make_model(graph, producer_name='split_test')
 
     for target, ctx in ctx_list():
         output_shape = [o.shape for o in outdatas]
         output_type = ['float32', 'float32', 'float32']
-        tvm_out = get_tvm_output(model, indata, target, ctx, output_shape, output_type)
+        tvm_out = get_tvm_output(
+            model, indata, target, ctx, output_shape, output_type)
     for o, t in zip(outdatas, tvm_out):
         tvm.testing.assert_allclose(o, t)
 
+
 def test_split():
     # 1D
-    verify_split([1., 2., 3., 4., 5., 6.], [[1., 2.], [3., 4.], [5., 6.]], [2, 2, 2], 0)
-    verify_split([1., 2., 3., 4., 5., 6.], [[1., 2.], [3.], [4., 5., 6.]], [2, 1, 3], 0)
+    verify_split([1., 2., 3., 4., 5., 6.],
+                 [[1., 2.], [3., 4.], [5., 6.]], [2, 2, 2], 0)
+    verify_split([1., 2., 3., 4., 5., 6.],
+                 [[1., 2.], [3.], [4., 5., 6.]], [2, 1, 3], 0)
     # 2D
     verify_split([[1., 2., 3., 4.], [7., 8., 9., 10.]],
                  [[[1., 2.], [7., 8.]], [[3., 4.], [9., 10.]]], [2, 2], 1)
+    # Split evenly (unstack)
+    verify_split([1, 2, 3], [[1], [2], [3]], False)
+
 
 def test_binary_ops():
     in_shape = (1, 2, 3, 3)
@@ -993,13 +1213,13 @@ def test_binary_ops():
         else:
             z = helper.make_node(op, ['in1', 'in2'], ['out'], broadcast=1)
         graph = helper.make_graph([z],
-                                   '_test',
-                                  inputs = [helper.make_tensor_value_info("in1",
-                                                TensorProto.FLOAT, list(in_shape)),
-                                            helper.make_tensor_value_info("in2",
-                                                TensorProto.FLOAT, list(in_shape))],
-                                  outputs = [helper.make_tensor_value_info("out",
-                                                TensorProto.FLOAT, list(out_shape))])
+                                  '_test',
+                                  inputs=[helper.make_tensor_value_info("in1",
+                                                                        TensorProto.FLOAT, list(in_shape)),
+                                          helper.make_tensor_value_info("in2",
+                                                                        TensorProto.FLOAT, list(in_shape))],
+                                  outputs=[helper.make_tensor_value_info("out",
+                                                                         TensorProto.FLOAT, list(out_shape))])
         model = helper.make_model(graph, producer_name='_test')
         for target, ctx in ctx_list():
             tvm_out = get_tvm_output(model, [x, y], target, ctx)
@@ -1008,11 +1228,11 @@ def test_binary_ops():
     x = np.random.uniform(size=in_shape).astype(dtype)
     y = np.random.uniform(size=in_shape).astype(dtype)
     z = np.random.uniform(size=(3,)).astype(dtype)
-    verify_binary_ops("Add",x, y, x + y, broadcast=None)
+    verify_binary_ops("Add", x, y, x + y, broadcast=None)
     verify_binary_ops("Add", x, z,  x + z, broadcast=True)
     verify_binary_ops("Sub", x, y, x - y, broadcast=None)
     verify_binary_ops("Sub", x, z, x - z, broadcast=True)
-    verify_binary_ops("Mul",x, y, x * y, broadcast=None)
+    verify_binary_ops("Mul", x, y, x * y, broadcast=None)
     verify_binary_ops("Mul", x, z,  x * z, broadcast=True)
     verify_binary_ops("Div", x, y, x / y, broadcast=None)
     verify_binary_ops("Div", x, z, x / z, broadcast=True)
@@ -1021,6 +1241,7 @@ def test_binary_ops():
     verify_binary_ops("Less", x, y, x < y, broadcast=True)
     verify_binary_ops("Equal", x, y, x == y, broadcast=True)
 
+
 def test_single_ops():
     in_shape = (1, 2, 3, 3)
     dtype = "float32"
@@ -1029,29 +1250,30 @@ def test_single_ops():
     def verify_single_ops(op, x, out_np, rtol=1e-5, atol=1e-5):
         z = helper.make_node(op, ['in1'], ['out'])
         graph = helper.make_graph([z],
-                                   '_test',
-                                  inputs = [helper.make_tensor_value_info("in1",
-                                                TensorProto.FLOAT, list(in_shape)),],
-                                  outputs = [helper.make_tensor_value_info("out",
-                                                TensorProto.FLOAT, list(out_shape))])
+                                  '_test',
+                                  inputs=[helper.make_tensor_value_info("in1",
+                                                                        TensorProto.FLOAT, list(in_shape))],
+                                  outputs=[helper.make_tensor_value_info("out",
+                                                                         TensorProto.FLOAT, list(out_shape))])
         model = helper.make_model(graph, producer_name='_test')
         for target, ctx in ctx_list():
             tvm_out = get_tvm_output(model, [x], target, ctx)
             tvm.testing.assert_allclose(out_np, tvm_out, rtol=rtol, atol=atol)
 
     x = np.random.uniform(size=in_shape).astype(dtype)
-    verify_single_ops("Neg",x, -x)
-    verify_single_ops("Abs",x, np.abs(x))
-    verify_single_ops("Reciprocal",x, 1/x)
-    verify_single_ops("Sqrt",x, np.sqrt(x))
-    verify_single_ops("Relu",x, np.maximum(x, 0))
-    verify_single_ops("Exp",x, np.exp(x))
-    verify_single_ops("Log",x, np.log(x))
-    verify_single_ops("Log",x, np.log(x))
-    verify_single_ops("Tanh",x, np.tanh(x))
-    verify_single_ops("Sigmoid",x, 1 / (1 + np.exp(-x)))
-    verify_single_ops("Softsign",x, x / (1 + np.abs(x)))
-    verify_single_ops("SoftPlus",x, np.log(1 + np.exp(x)))
+    verify_single_ops("Neg", x, -x)
+    verify_single_ops("Abs", x, np.abs(x))
+    verify_single_ops("Reciprocal", x, 1/x)
+    verify_single_ops("Sqrt", x, np.sqrt(x))
+    verify_single_ops("Relu", x, np.maximum(x, 0))
+    verify_single_ops("Exp", x, np.exp(x))
+    verify_single_ops("Log", x, np.log(x))
+    verify_single_ops("Tanh", x, np.tanh(x))
+    verify_single_ops("Sigmoid", x, 1 / (1 + np.exp(-x)))
+    verify_single_ops("Softsign", x, x / (1 + np.abs(x)))
+    verify_single_ops("SoftPlus", x, np.log(1 + np.exp(x)))
+
 
 def test_leaky_relu():
     def leaky_relu_x(x, alpha):
@@ -1063,6 +1285,7 @@ def test_leaky_relu():
                               'LeakyRelu',
                               {'alpha': 0.25})
 
+
 def test_elu():
     def elu_x(x, alpha):
         return np.where(x > 0, x, alpha * (np.exp(x) - 1.0))
@@ -1073,6 +1296,7 @@ def test_elu():
                               'Elu',
                               {'alpha': 0.25})
 
+
 def test_selu():
     def selu_x(x, alpha, gamma):
         return gamma * np.where(x > 0, x, alpha * (np.exp(x) - 1.0))
@@ -1083,6 +1307,7 @@ def test_selu():
                               'Selu',
                               {'alpha': 0.25, 'gamma': 0.3})
 
+
 def test_ThresholdedRelu():
     def ThresholdedRelu_x(x, alpha):
         out_np = np.clip(x, alpha, np.inf)
@@ -1095,6 +1320,7 @@ def test_ThresholdedRelu():
                               'ThresholdedRelu',
                               {'alpha': 0.25})
 
+
 def test_ScaledTanh():
     def ScaledTanh_x(x, alpha, beta):
         return alpha * np.tanh(beta * x)
@@ -1105,6 +1331,7 @@ def test_ScaledTanh():
                               'ScaledTanh',
                               {'alpha': 0.25, 'beta': 0.3})
 
+
 def test_ParametricSoftplus():
     def ParametricSoftplus_x(x, alpha, beta):
         return alpha * np.log(np.exp(beta * x) + 1)
@@ -1115,6 +1342,7 @@ def test_ParametricSoftplus():
                               'ParametricSoftplus',
                               {'alpha': 0.25, 'beta': 0.3})
 
+
 def test_Scale():
     def Scale_x(x, scale):
         return scale * x
@@ -1125,6 +1353,7 @@ def test_Scale():
                               'Scale',
                               {'scale': 0.25})
 
+
 def test_LogSoftmax():
     _test_onnx_op_elementwise((1, 4),
                               topi.testing.log_softmax_python,
@@ -1138,7 +1367,8 @@ def check_torch_conversion(model, input_size):
     dummy_input = torch.randn(*input_size)
     file_name = '{}.onnx'.format(model.__name__)
     # Set verbose=True for more output
-    torch.onnx.export(model(), dummy_input, file_name, export_params=True, verbose=False)
+    torch.onnx.export(model(), dummy_input, file_name,
+                      export_params=True, verbose=False)
     onnx_model = onnx.load(file_name)
     for target, ctx in ctx_list():
         input_data = np.random.uniform(size=input_size).astype('int32')
@@ -1146,13 +1376,14 @@ def check_torch_conversion(model, input_size):
         tvm_out = get_tvm_output(onnx_model, input_data, target, ctx)
         tvm.testing.assert_allclose(c2_out, tvm_out)
 
+
 def test_resnet():
-    check_torch_conversion(torchvision.models.resnet18, (1,3,224,224))
+    check_torch_conversion(torchvision.models.resnet18, (1, 3, 224, 224))
     # check_torch_conversion(torchvision.models.resnet101, (1,3,224,224))
 
 # def test_alexnet():
-    # Torch's ONNX export does not support the adaptive pooling used by AlexNet?
-    # check_torch_conversion(torchvision.models.alexnet, (1,3,224,224))
+# Torch's ONNX export does not support the adaptive pooling used by AlexNet?
+# check_torch_conversion(torchvision.models.alexnet, (1,3,224,224))
 
 # Torch's ONNX export does not support the adaptive pooling used by vgg16?
 # def test_vgg16():
@@ -1163,11 +1394,13 @@ def test_resnet():
 #     # Torch's ONNX export does not support the max pooling used by Squeezenet
 #     check_torch_conversion(torchvision.models.squeezenet1_0, (1,3,224,224))
 
+
 def test_densenet():
-    check_torch_conversion(torchvision.models.densenet161, (1,3,224,224))
+    check_torch_conversion(torchvision.models.densenet161, (1, 3, 224, 224))
+
 
 def test_inception():
-    check_torch_conversion(torchvision.models.inception_v3, (1,3,224,224))
+    check_torch_conversion(torchvision.models.inception_v3, (1, 3, 224, 224))
 
 # TODO(@jroesch): Update Torch + ONNX to support this import.
 # def test_googlenet():
@@ -1177,6 +1410,7 @@ def test_inception():
 # def test_shufflenetv2():
 #     check_torch_conversion(torchvision.models.shufflenetv2, (1,3,224,224))
 
+
 def test_sign():
     def Sign_x(x):
         return np.sign(x)
@@ -1196,7 +1430,8 @@ def verify_not(indata, dtype):
 
     graph = helper.make_graph([node],
                               'not_test',
-                              inputs=[helper.make_tensor_value_info("in", TensorProto.BOOL, list(x.shape))],
+                              inputs=[helper.make_tensor_value_info(
+                                  "in", TensorProto.BOOL, list(x.shape))],
                               outputs=[helper.make_tensor_value_info("out", TensorProto.BOOL, list(outdata.shape))])
 
     model = helper.make_model(graph, producer_name='not_test')
@@ -1262,31 +1497,70 @@ def test_and():
     verify_and(indata=[x, y], dtype=bool)
 
 
-def verify_tile(indata, outdata, **kwargs):
+def verify_tile_v1(indata, outdata, **kwargs):
     node = helper.make_node('Tile', inputs=['in'], outputs=['out'], **kwargs)
     graph = helper.make_graph([node],
                               'tile_test',
-                              inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(indata.shape))],
+                              inputs=[helper.make_tensor_value_info(
+                                  "in", TensorProto.FLOAT, list(indata.shape))],
                               outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outdata.shape))])
 
     model = helper.make_model(graph, producer_name='tile_test')
 
     for target, ctx in ctx_list():
-        tvm_out = get_tvm_output(model, [indata], target, ctx, outdata.shape)
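+        # Pin opset 1 so the attribute form of Tile ('repeats' as a node
+        # attribute) is exercised.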
+        tvm_out = get_tvm_output(
+            model, [indata], target, ctx, outdata.shape, opset=1)
+        tvm.testing.assert_allclose(outdata, tvm_out)
+
+
+def verify_tile_v6(indata, repeats, outdata):
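+    # From opset 6 onward Tile takes 'repeats' as a second input tensor;
+    # here it is bound through a graph initializer and the opset is pinned
+    # to 6 below.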
+    node = helper.make_node('Tile',
+                            inputs=['input', 'repeats'],
+                            outputs=['out'])
+    graph = helper.make_graph(
+        [node],
+        'tile_test',
+        inputs=[
+            helper.make_tensor_value_info("input", TensorProto.FLOAT,
+                                          list(indata.shape)),
+            helper.make_tensor_value_info("repeats", TensorProto.INT64,
+                                          list(repeats.shape))
+        ],
+        outputs=[
+            helper.make_tensor_value_info("out", TensorProto.FLOAT,
+                                          list(outdata.shape))
+        ],
+        initializer=[
+            helper.make_tensor("repeats", TensorProto.INT64,
+                               list(repeats.shape), repeats)
+        ])
+
+    model = helper.make_model(graph, producer_name='tile_test')
+
+    for target, ctx in ctx_list():
+        tvm_out = get_tvm_output(model, [indata],
+                                 target,
+                                 ctx,
+                                 outdata.shape,
+                                 opset=6)
         tvm.testing.assert_allclose(outdata, tvm_out)
 
 
 def test_tile():
     x = np.random.rand(2, 3, 4, 5).astype(np.float32)
-    repeats = np.random.randint(low=1, high=10, size=(np.ndim(x),)).astype(np.int64)
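+    # One repeat count per input dimension, as Tile requires.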
+    repeats = np.random.randint(
+        low=1, high=10, size=(np.ndim(x),)).astype(np.int64)
     z = np.tile(x, repeats)
-    verify_tile(x, z, repeats=repeats)
+    verify_tile_v1(x, z, repeats=repeats)
+    verify_tile_v6(x, repeats, z)
+
 
 def verify_erf(indata, outdata):
     node = helper.make_node('Erf', inputs=['in'], outputs=['out'])
     graph = helper.make_graph([node],
                               'erf_test',
-                              inputs=[helper.make_tensor_value_info('in', TensorProto.FLOAT, list(indata.shape))],
+                              inputs=[helper.make_tensor_value_info(
+                                  'in', TensorProto.FLOAT, list(indata.shape))],
                               outputs=[helper.make_tensor_value_info('out', TensorProto.FLOAT, list(outdata.shape))])
     model = helper.make_model(graph, producer_name='erf_test')
 
@@ -1294,6 +1568,7 @@ def verify_erf(indata, outdata):
         tvm_out = get_tvm_output(model, [indata], target, ctx, outdata.shape)
         tvm.testing.assert_allclose(outdata, tvm_out)
 
+
 def test_erf():
     x = np.random.rand(2, 3, 4, 6).astype(np.float32)
     z = scipy.special.erf(x)
@@ -1337,7 +1612,9 @@ if __name__ == '__main__':
     test_floor()
     test_ceil()
     test_clip()
+    test_onehot()
     test_matmul()
+    test_batch_matmul()
     test_gather()
     test_lrn()
     test_instance_norm()
@@ -1348,7 +1625,7 @@ if __name__ == '__main__':
     test_forward_hardsigmoid()
     test_forward_arg_min_max()
     test_softmax()
-    test_constantfill()
+    test_constantofshape()
     test_reduce_max()
     test_reduce_min()
     test_reduce_sum()