[ Relay ][ Frontend ][ Tensorflow ]add op add_n to relay/frontend/tensorflow.py ...
authorKim <kimyangbaochen@vip.qq.com>
Fri, 1 Nov 2019 15:54:33 +0000 (23:54 +0800)
committerTianqi Chen <tqchen@users.noreply.github.com>
Fri, 1 Nov 2019 15:54:33 +0000 (08:54 -0700)
docs/frontend/tensorflow.rst
python/tvm/relay/frontend/tensorflow.py
tests/python/frontend/tensorflow/test_forward.py

index 8782888..436a888 100644 (file)
@@ -115,6 +115,7 @@ Supported Ops
 
 - Abs
 - Add
+- AddN
 - All
 - Any
 - ArgMax
index 648d7f4..38c51b9 100644 (file)
@@ -1318,6 +1318,18 @@ def _size():
         return AttrCvt('ndarray_size', transforms={'out_type' : 'dtype'})(inputs, new_attr)
     return _impl
 
+def _add_n():
+    def _impl(inputs, attr, params):
+        if not isinstance(inputs, tuple):
+            inputs = list(inputs)
+        assert len(inputs) > 0, "add_n take >=1 inputs, but 0 given."
+        _res = inputs[0]
+        for each in inputs[1:]:
+            _res = _op.add(_res, each)
+        return  _res
+    return _impl
+
+
 # compatible operators that do NOT require any conversion.
 _identity_list = []
 
@@ -1329,6 +1341,7 @@ _identity_list = []
 _convert_map = {
     'Abs'                               : AttrCvt('abs'),
     'Add'                               : _elemwise('add'),
+    'AddN'                              : _add_n(),
     'All'                               : _reduce('all'),
     'Any'                               : _reduce('any'),
     'ArgMax'                            : _argx(_op.argmax, 'argmax'),
index 88787ef..b17ec12 100644 (file)
@@ -41,11 +41,14 @@ import tvm.relay.testing.tf as tf_testing
 #######################################################################
 # Generic run functions for TVM & tensorflow
 # ------------------------------------------
+
+
 def convert_to_list(x):
     if not isinstance(x, list):
         x = [x]
     return x
 
+
 def vmobj_to_list(o):
     if isinstance(o, tvm.relay.backend.vmobj.Tensor):
         return [o.asnumpy().tolist()]
@@ -72,12 +75,14 @@ def vmobj_to_list(o):
         elif 'tensor' in o.constructor.name_hint:
             return [o.fields[0].asnumpy()]
         else:
-            raise RuntimeError("Unknown object type: %s" % o.constructor.name_hint)
+            raise RuntimeError("Unknown object type: %s" %
+                               o.constructor.name_hint)
     elif isinstance(o, tvm.relay.backend.interpreter.TensorValue):
         return [o.data.asnumpy()]
     else:
         raise RuntimeError("Unknown object type: %s" % type(o))
 
+
 def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
                   target='llvm', out_names=None, opt_level=3, mode='graph_runtime'):
     """ Generic function to compile on relay and execute on tvm """
@@ -116,16 +121,19 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
         # get outputs
         assert out_names is None or num_output == len(out_names), (
             "out_names: {} num_output: {}".format(out_names, num_output))
-        tvm_output_list = [m.get_output(i).asnumpy() for i in range(num_output)]
+        tvm_output_list = [m.get_output(i).asnumpy()
+                           for i in range(num_output)]
         return tvm_output_list
 
+
 def run_tf_graph(sess, input_data, input_node, output_node):
     """ Generic function to execute tensorflow """
     input_data = convert_to_list(input_data)
     input_node = convert_to_list(input_node)
     output_node = convert_to_list(output_node)
 
-    tensor = [sess.graph.get_tensor_by_name(output_name) for output_name in output_node]
+    tensor = [sess.graph.get_tensor_by_name(
+        output_name) for output_name in output_node]
 
     input_dict = {e: input_data[i] for i, e in enumerate(input_node)}
 
@@ -152,7 +160,7 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
             sess,
             sess.graph.as_graph_def(add_shapes=True),
             out_node,
-            )
+        )
         tf_output = run_tf_graph(sess, in_data, in_name, out_name)
 
         for device in ["llvm", "cuda"]:
@@ -169,10 +177,12 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
             # since the names from tensorflow and relay runs are not exactly same,
             # first len(tf_output) will be compared
             for i in range(len(tf_output)):
-                tvm.testing.assert_allclose(tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
+                tvm.testing.assert_allclose(
+                    tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
 
         sess.close()
 
+
 def is_gpu_available():
     from tensorflow.python.client import device_lib
     local_device_protos = device_lib.list_local_devices()
@@ -186,6 +196,8 @@ def is_gpu_available():
 #######################################################################
 # Pooling
 # -------
+
+
 def _test_pooling_iteration(input_shape, **kwargs):
     """ One iteration of pool operation with given shapes and attributes """
 
@@ -203,6 +215,7 @@ def _test_pooling_iteration(input_shape, **kwargs):
 
         compare_tf_with_tvm(x, 'Placeholder:0', out_name)
 
+
 def _test_pooling(input_shape, **kwargs):
     _test_pooling_iteration(input_shape, **kwargs)
 
@@ -211,6 +224,7 @@ def _test_pooling(input_shape, **kwargs):
         kwargs['data_format'] = 'NCHW'
         _test_pooling_iteration(input_shape, **kwargs)
 
+
 def test_forward_pooling():
     """ Pooling """
 
@@ -260,6 +274,7 @@ def test_forward_pooling():
 # Convolution
 # -----------
 
+
 def _test_convolution(opname, tensor_in_sizes, filter_in_sizes,
                       dilations, strides, padding, data_format):
     """ One iteration of convolution with given shapes and attributes """
@@ -273,7 +288,8 @@ def _test_convolution(opname, tensor_in_sizes, filter_in_sizes,
 
     with tf.Graph().as_default():
         in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32')
-        in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype='float32')
+        in_filter = constant_op.constant(
+            filter_array, shape=filter_in_sizes, dtype='float32')
         if data_format == 'NHWC':
             strides = [1] + strides + [1]
             dilations = [1] + dilations + [1]
@@ -293,15 +309,16 @@ def _test_convolution(opname, tensor_in_sizes, filter_in_sizes,
                                 'Placeholder:0', 'Conv2D:0')
         else:
             nn_ops.depthwise_conv2d_native(in_data,
-                          in_filter,
-                          strides=strides,
-                          dilations=dilations,
-                          padding=padding,
-                          data_format=data_format)
+                                           in_filter,
+                                           strides=strides,
+                                           dilations=dilations,
+                                           padding=padding,
+                                           data_format=data_format)
 
             compare_tf_with_tvm(np.reshape(data_array, tensor_in_sizes).astype('float32'),
                                 'Placeholder:0', 'DepthwiseConv2dNative:0')
 
+
 def test_forward_convolution():
     if is_gpu_available():
         _test_convolution('conv', [4, 176, 8, 8], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NCHW')
@@ -327,13 +344,16 @@ def test_forward_convolution():
 #######################################################################
 # BiasAdd
 # -----------
+
+
 def _test_biasadd(tensor_in_sizes, data_format):
     """ One iteration of biasadd with given shapes and attributes """
 
     total_size_1 = 1
     for s in tensor_in_sizes:
         total_size_1 *= s
-    tensor_bias_sizes = [tensor_in_sizes[1]] if data_format == 'NCHW' else [tensor_in_sizes[3]]
+    tensor_bias_sizes = [tensor_in_sizes[1]
+                         ] if data_format == 'NCHW' else [tensor_in_sizes[3]]
     total_size_2 = tensor_bias_sizes[0]
     # Initializes the input tensor with array containing incrementing
     # numbers from 1.
@@ -342,7 +362,8 @@ def _test_biasadd(tensor_in_sizes, data_format):
 
     with tf.Graph().as_default():
         in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32')
-        in_bias = constant_op.constant(bias_array, shape=tensor_bias_sizes, dtype='float32')
+        in_bias = constant_op.constant(
+            bias_array, shape=tensor_bias_sizes, dtype='float32')
         nn_ops.bias_add(in_data,
                         in_bias,
                         data_format=data_format)
@@ -350,6 +371,7 @@ def _test_biasadd(tensor_in_sizes, data_format):
         compare_tf_with_tvm(np.reshape(data_array, tensor_in_sizes).astype('float32'),
                             'Placeholder:0', 'BiasAdd:0')
 
+
 def test_forward_biasadd():
     if is_gpu_available():
         _test_biasadd([4, 176, 8, 8], 'NCHW')
@@ -362,15 +384,17 @@ def test_forward_biasadd():
     _test_biasadd([4, 17, 17, 19], 'NHWC')
     _test_biasadd([4, 3, 3, 124], 'NHWC')
 
+
 def _test_forward_where(input_shape):
     with tf.Graph().as_default():
-        dtype =  tf.float32
+        dtype = tf.float32
         t = tf.constant(np.random.choice([0, 1, -2, 3, -1, 0.1, -0.2],
                                          size=input_shape).astype(dtype.name))
         out = tf.where(t)
         compare_tf_with_tvm([], [], out.name, mode='debug')
         compare_tf_with_tvm([], [], out.name, mode='vm')
 
+
 def test_forward_argwhere():
     _test_forward_where((5,))
     _test_forward_where((5, 5))
@@ -381,6 +405,8 @@ def test_forward_argwhere():
 #######################################################################
 # SpaceToBatchND
 # --------------
+
+
 def _test_space_to_batch_nd(input_shape, block_shape, paddings, dtype='int32'):
     data = np.random.uniform(0, 5, size=input_shape).astype(dtype)
 
@@ -390,6 +416,7 @@ def _test_space_to_batch_nd(input_shape, block_shape, paddings, dtype='int32'):
 
         compare_tf_with_tvm(data, in_data.name, out.name)
 
+
 def test_forward_space_to_batch_nd():
     # test cases: https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/space-to-batch-n-d
     _test_space_to_batch_nd(
@@ -436,6 +463,8 @@ def test_forward_space_to_batch_nd():
 #######################################################################
 # BatchToSpaceND
 # --------------
+
+
 def _test_batch_to_space_nd(input_shape, block_shape, crops, dtype='int32'):
     data = np.random.uniform(0, 5, size=input_shape).astype(dtype)
 
@@ -445,6 +474,7 @@ def _test_batch_to_space_nd(input_shape, block_shape, crops, dtype='int32'):
 
         compare_tf_with_tvm(data, in_data.name, out.name)
 
+
 def test_forward_batch_to_space_nd():
     # test cases: https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d
     _test_batch_to_space_nd(
@@ -492,6 +522,7 @@ def test_forward_batch_to_space_nd():
 # Reshape
 # -------
 
+
 def _test_reshape(data, out_shape):
     """ One iteration of reshape operation with given data and out shape """
 
@@ -501,6 +532,7 @@ def _test_reshape(data, out_shape):
 
         compare_tf_with_tvm(data, 'Placeholder:0', 'Reshape:0')
 
+
 def test_forward_reshape():
     _test_reshape(np.arange(6.0), [2, 3])
     _test_reshape(np.arange(6), [-1, 2])
@@ -511,6 +543,7 @@ def test_forward_reshape():
 # DepthToSpace
 # ------------
 
+
 def _test_depthtospace(data, block_size):
     """ One iteration of depth_to_space operation with given data and block size """
 
@@ -520,6 +553,7 @@ def _test_depthtospace(data, block_size):
 
         compare_tf_with_tvm(data, 'Placeholder:0', 'DepthToSpace:0')
 
+
 def test_forward_depthtospace():
     _test_depthtospace(np.random.normal(size=[1, 32, 32, 4]), 2)
     _test_depthtospace(np.random.normal(size=[1, 16, 8, 32]), 4)
@@ -528,6 +562,7 @@ def test_forward_depthtospace():
 # SpaceToDepth
 # ------------
 
+
 def _test_spacetodepth(data, block_size):
     """ One iteration of space_to_depth operation with given data and block size """
 
@@ -537,6 +572,7 @@ def _test_spacetodepth(data, block_size):
 
         compare_tf_with_tvm(data, 'Placeholder:0', 'SpaceToDepth:0')
 
+
 def test_forward_spacetodepth():
     _test_spacetodepth(np.random.normal(size=[1, 32, 32, 4]), 2)
     _test_spacetodepth(np.random.normal(size=[1, 16, 8, 32]), 4)
@@ -545,6 +581,7 @@ def test_forward_spacetodepth():
 # Squeeze
 # -------
 
+
 def _test_squeeze(data, squeeze_dims=None):
     """ One iteration of squeeze """
 
@@ -561,6 +598,7 @@ def _test_squeeze(data, squeeze_dims=None):
 
         compare_tf_with_tvm(data, 'Placeholder:0', 'Squeeze:0')
 
+
 def test_forward_squeeze():
     """ Squeeze """
 
@@ -584,16 +622,20 @@ def test_forward_squeeze():
     _test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-3, -5])
     _test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-3, -5, -1])
 
+
 def test_tensor_array_constructor():
     def run(dtype_str):
         with tf.Graph().as_default():
-            dtype =  {
+            dtype = {
                 'float32': tf.float32,
-                'int32'  : tf.int32
+                'int32': tf.int32
             }[dtype_str]
-            t = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str), dtype=dtype)
-            t2 = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str), dtype=dtype)
-            ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=False, dynamic_size=False)
+            t = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]]).astype(
+                dtype_str), dtype=dtype)
+            t2 = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]]).astype(
+                dtype_str), dtype=dtype)
+            ta1 = tf.TensorArray(dtype=dtype, size=2,
+                                 infer_shape=False, dynamic_size=False)
             ta2 = ta1.write(0, t)
             ta3 = ta2.write(1, t2)
             out = ta3.read(0)
@@ -602,24 +644,29 @@ def test_tensor_array_constructor():
     run('float32')
     run('int32')
 
+
 def test_tensor_array_scatter():
     def run(dtype_str):
         with tf.Graph().as_default():
-            dtype =  {
+            dtype = {
                 'float32': tf.float32,
-                'int32'  : tf.int32
+                'int32': tf.int32
             }[dtype_str]
-            t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(dtype_str), dtype=dtype)
+            t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(
+                dtype_str), dtype=dtype)
             indices = tf.constant([2, 1, 0])
-            ta1 = tf.TensorArray(dtype=dtype, size=3, infer_shape=False, dynamic_size=False)
+            ta1 = tf.TensorArray(dtype=dtype, size=3,
+                                 infer_shape=False, dynamic_size=False)
             ta2 = ta1.scatter(indices, t)
             out0 = ta2.read(0)
             out1 = ta2.read(1)
             out2 = ta2.read(2)
             g = tf.get_default_graph()
             compare_tf_with_tvm([], [], ['TensorArrayReadV3:0'], mode='debug')
-            compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='debug')
-            compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='debug')
+            compare_tf_with_tvm(
+                [], [], ['TensorArrayReadV3_1:0'], mode='debug')
+            compare_tf_with_tvm(
+                [], [], ['TensorArrayReadV3_2:0'], mode='debug')
     run('float32')
     run('int32')
 
@@ -636,16 +683,19 @@ def test_tensor_array_scatter():
 #         g = tf.get_default_graph()
 #         compare_tf_with_tvm([], [], ['TensorArrayGatherV3:0'], mode='debug')
 
+
 def test_tensor_array_split():
     def run(dtype_str):
         with tf.Graph().as_default():
-            dtype =  {
+            dtype = {
                 'float32': tf.float32,
-                'int32'  : tf.int32
+                'int32': tf.int32
             }[dtype_str]
-            t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype)
+            t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [
+                            6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype)
             split_length = tf.constant([2, 2, 2, 2], dtype=tf.int32)
-            ta1 = tf.TensorArray(dtype=dtype, size=4, infer_shape=False, dynamic_size=False)
+            ta1 = tf.TensorArray(dtype=dtype, size=4,
+                                 infer_shape=False, dynamic_size=False)
             ta2 = ta1.split(t, split_length)
             out0 = ta2.read(0)
             out1 = ta2.read(1)
@@ -653,36 +703,45 @@ def test_tensor_array_split():
             out3 = ta2.read(3)
             g = tf.get_default_graph()
             compare_tf_with_tvm([], [], ['TensorArrayReadV3:0'], mode='debug')
-            compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='debug')
-            compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='debug')
-            compare_tf_with_tvm([], [], ['TensorArrayReadV3_3:0'], mode='debug')
+            compare_tf_with_tvm(
+                [], [], ['TensorArrayReadV3_1:0'], mode='debug')
+            compare_tf_with_tvm(
+                [], [], ['TensorArrayReadV3_2:0'], mode='debug')
+            compare_tf_with_tvm(
+                [], [], ['TensorArrayReadV3_3:0'], mode='debug')
     run('float32')
     run('int32')
 
+
 def test_tensor_array_concat():
     def run(dtype_str):
         with tf.Graph().as_default():
             dtype = {
                 'float32': tf.float32,
-                'int32'  : tf.int32
+                'int32': tf.int32
             }[dtype_str]
-            t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype)
+            t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [
+                            6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype)
             split_length = tf.constant([2, 2, 2, 2], dtype=tf.int32)
-            ta1 = tf.TensorArray(dtype=dtype, size=4, infer_shape=False, dynamic_size=False)
+            ta1 = tf.TensorArray(dtype=dtype, size=4,
+                                 infer_shape=False, dynamic_size=False)
             ta2 = ta1.split(t, split_length)
             t = ta2.concat()
-            compare_tf_with_tvm([], [], ['TensorArrayConcatV3:0'], mode='debug')
+            compare_tf_with_tvm(
+                [], [], ['TensorArrayConcatV3:0'], mode='debug')
     run('float32')
     run('int32')
 
+
 def test_tensor_array_size():
     def run(dtype_str):
         with tf.Graph().as_default():
-            dtype =  {
+            dtype = {
                 'float32': tf.float32,
-                'int32'  : tf.int32
+                'int32': tf.int32
             }[dtype_str]
-            ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=False, dynamic_size=False)
+            ta1 = tf.TensorArray(dtype=dtype, size=2,
+                                 infer_shape=False, dynamic_size=False)
             out = ta1.size()
             g = tf.get_default_graph()
             compare_tf_with_tvm([], [], 'TensorArraySizeV3:0', mode='debug')
@@ -693,6 +752,7 @@ def test_tensor_array_size():
 # ConcatV2
 # --------
 
+
 def _test_concat_v2(shape1, shape2, dim):
     """ One iteration of ConcatV2 """
 
@@ -705,7 +765,9 @@ def _test_concat_v2(shape1, shape2, dim):
         np_data1 = np.random.uniform(size=shape1).astype(dtype)
         np_data2 = np.random.uniform(size=shape2).astype(dtype)
 
-        compare_tf_with_tvm([np_data1, np_data2], ['in1:0', 'in2:0'], 'ConcatV2:0')
+        compare_tf_with_tvm([np_data1, np_data2], [
+                            'in1:0', 'in2:0'], 'ConcatV2:0')
+
 
 def test_forward_concat_v2():
     if tf.__version__ < LooseVersion('1.4.1'):
@@ -721,6 +783,7 @@ def test_forward_concat_v2():
 # Sigmoid
 # -------
 
+
 def _test_sigmoid(data):
     """ One iteration of sigmoid """
 
@@ -730,6 +793,7 @@ def _test_sigmoid(data):
 
         compare_tf_with_tvm(data, 'Placeholder:0', 'Sigmoid:0')
 
+
 def test_forward_sigmoid():
     """ Sigmoid """
 
@@ -739,14 +803,17 @@ def test_forward_sigmoid():
 # Argmin/Argmax
 # -------------
 
+
 def _test_argx(func, data, **kwargs):
 
     with tf.Graph().as_default():
-        inp = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name="c0")
+        inp = array_ops.placeholder(
+            shape=data.shape, dtype=data.dtype, name="c0")
         func(inp, name="argx0", output_type=tf.int32, **kwargs)
 
         compare_tf_with_tvm(data, 'c0:0', 'argx0:0')
 
+
 def test_forward_argminmax():
     for axis in [None, 0, 1, 2]:
         data = np.random.uniform(size=(8, 4, 9)).astype('float32')
@@ -757,15 +824,18 @@ def test_forward_argminmax():
 # Reduce
 # ------
 
+
 def _test_reduce(func, data, **kwargs):
     """ One iteration of a reduce operation"""
 
     with tf.Graph().as_default():
-        inp = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name="c0")
+        inp = array_ops.placeholder(
+            shape=data.shape, dtype=data.dtype, name="c0")
         func(inp, name="reducex0", **kwargs)
 
         compare_tf_with_tvm(data, 'c0:0', 'reducex0:0')
 
+
 def test_forward_reduce():
     data = np.random.uniform(size=(8, 4, 9)).astype('float32')
     _test_reduce(tf.reduce_sum, data=data)
@@ -790,7 +860,9 @@ def _test_variable(data):
             "w", shape=[size, size], dtype=input_tensor.dtype)
     math_ops.matmul(input_tensor, w)
 
-    compare_tf_with_tvm(data, 'Placeholder:0', 'MatMul:0', init_global_variables=True)
+    compare_tf_with_tvm(data, 'Placeholder:0', 'MatMul:0',
+                        init_global_variables=True)
+
 
 def test_forward_variable():
     """Variable type op test"""
@@ -810,23 +882,29 @@ def _test_matmul(i, j, k, dtype, outer=None):
     for transpose_a in [False, True]:
         for transpose_b in [False, True]:
             outer = outer or []
-            A_shape = outer + (A_shape_init[::-1] if transpose_a else A_shape_init)
-            B_shape = outer + (B_shape_init[::-1] if transpose_b else B_shape_init)
+            A_shape = outer + \
+                (A_shape_init[::-1] if transpose_a else A_shape_init)
+            B_shape = outer + \
+                (B_shape_init[::-1] if transpose_b else B_shape_init)
 
             with tf.Graph().as_default():
                 A = tf.placeholder(shape=A_shape, dtype=dtype, name='A')
                 B = tf.placeholder(shape=B_shape, dtype=dtype, name='B')
-                result = tf.matmul(A, B, transpose_a=transpose_a, transpose_b=transpose_b)
+                result = tf.matmul(
+                    A, B, transpose_a=transpose_a, transpose_b=transpose_b)
 
                 A_np = np.random.uniform(high=5.0, size=A_shape).astype(dtype)
                 B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype)
-                compare_tf_with_tvm([A_np, B_np], [A.name, B.name], result.name)
+                compare_tf_with_tvm(
+                    [A_np, B_np], [A.name, B.name], result.name)
+
 
 def test_forward_matmul():
     """ MatMul op test"""
     _test_matmul(1, 3, 6, 'int32')
     _test_matmul(5, 3, 1, 'float64')
 
+
 def _test_batch_matmul(A_shape, B_shape, dtype, adjoint_a=False, adjoint_b=False):
 
     with tf.Graph().as_default():
@@ -839,6 +917,7 @@ def _test_batch_matmul(A_shape, B_shape, dtype, adjoint_a=False, adjoint_b=False
         B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype)
         compare_tf_with_tvm([A_np, B_np], [A.name, B.name], result.name)
 
+
 def test_forward_batch_matmul():
     """ TF op BatchMatMul, BatchMatMulV2 test"""
     _test_batch_matmul((3, 5, 4), (3, 4, 5), 'int32')
@@ -846,9 +925,11 @@ def test_forward_batch_matmul():
     _test_batch_matmul((3, 5, 4), (3, 5, 4), 'int32', True, False)
     _test_batch_matmul((3, 5, 4), (3, 5, 4), 'float32', False, True)
     _test_batch_matmul((2, 3, 4, 5, 6), (2, 3, 4, 6, 5), 'int32')
-    _test_batch_matmul((1, 2, 3, 4, 5, 6), (1, 2, 3, 4, 6, 5), 'float32', True, True)
+    _test_batch_matmul((1, 2, 3, 4, 5, 6),
+                       (1, 2, 3, 4, 6, 5), 'float32', True, True)
     _test_batch_matmul((3, 4, 5, 6), (3, 4, 5, 6), 'int32', True, False)
-    _test_batch_matmul((2, 3, 4, 2, 3, 4, 5, 6), (2, 3, 4, 2, 3, 4, 5, 6), 'float32', False, True)
+    _test_batch_matmul((2, 3, 4, 2, 3, 4, 5, 6),
+                       (2, 3, 4, 2, 3, 4, 5, 6), 'float32', False, True)
 
 
 #######################################################################
@@ -870,16 +951,23 @@ def _test_stridedslice(ip_shape, begin, end, stride, dtype,
 
     compare_tf_with_tvm(np_data, 'in_data:0', 'strided_slice:0')
 
+
 def test_forward_stridedslice():
     '''test StridedSlice'''
 
     _test_stridedslice((2), [1], [1], [1], 'float32', shrink_axis_mask=1)
-    _test_stridedslice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], 'float32')
-    _test_stridedslice((3, 4, 3), [1, 0], [4, 3], [2, 1], 'float32', ellipsis_mask=8)
-    _test_stridedslice((3, 4, 3), [1, 0], [4, 2], [2, 1], 'float32', ellipsis_mask=2)
-    _test_stridedslice((3, 4, 5, 3), [1, 0], [4, 2], [2, 1], 'float32', ellipsis_mask=2)
-    _test_stridedslice((3, 4, 5, 3), [1, 0, 1], [4, 2, 2], [2, 1, 1], 'float32', ellipsis_mask=2)
-    _test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 2], [2, 1, 1], 'float32', new_axis_mask=5)
+    _test_stridedslice((3, 4, 3), [1, -1, 0],
+                       [4, -5, 3], [2, -1, 1], 'float32')
+    _test_stridedslice((3, 4, 3), [1, 0], [4, 3], [
+                       2, 1], 'float32', ellipsis_mask=8)
+    _test_stridedslice((3, 4, 3), [1, 0], [4, 2], [
+                       2, 1], 'float32', ellipsis_mask=2)
+    _test_stridedslice((3, 4, 5, 3), [1, 0], [4, 2], [
+                       2, 1], 'float32', ellipsis_mask=2)
+    _test_stridedslice((3, 4, 5, 3), [1, 0, 1], [4, 2, 2], [
+                       2, 1, 1], 'float32', ellipsis_mask=2)
+    _test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 2], [
+                       2, 1, 1], 'float32', new_axis_mask=5)
     _test_stridedslice((3, 4, 3), [1, 1, 1], [4, 4, 1], [2, 1, 1], 'float32', ellipsis_mask=2,
                        new_axis_mask=4)
     _test_stridedslice((6, 4, 5), [1, 1, 1], [6, 3, 4], [2, 1, 1], 'float32', ellipsis_mask=2,
@@ -892,7 +980,8 @@ def test_forward_stridedslice():
                        new_axis_mask=3)
     _test_stridedslice((3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], 'float32', ellipsis_mask=2,
                        new_axis_mask=2)
-    _test_stridedslice((3, 4), [1, 0], [4, 4], [1, 1], 'float32', shrink_axis_mask=2)
+    _test_stridedslice((3, 4), [1, 0], [4, 4], [
+                       1, 1], 'float32', shrink_axis_mask=2)
     _test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], 'float32', shrink_axis_mask=2,
                        new_axis_mask=2)
     _test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], 'float32', shrink_axis_mask=1,
@@ -918,6 +1007,7 @@ def test_forward_stridedslice():
 # FloorDiv, RealDiv
 # -----------------
 
+
 def _test_forward_divide(ip_shape, dtype):
     np_numer = np.random.uniform(-100, 100, size=ip_shape).astype(dtype)
     np_denomin = np.random.uniform(1, 100, size=ip_shape).astype(dtype)
@@ -925,7 +1015,9 @@ def _test_forward_divide(ip_shape, dtype):
     numerator = tf.placeholder(dtype, ip_shape, name="numer")
     denominator = tf.placeholder(dtype, ip_shape, name="denomin")
     tf.math.divide(numerator, denominator, name='RealDiv')
-    compare_tf_with_tvm([np_numer, np_denomin], ['numer:0', 'denomin:0'], 'RealDiv:0')
+    compare_tf_with_tvm([np_numer, np_denomin], [
+                        'numer:0', 'denomin:0'], 'RealDiv:0')
+
 
 def _test_forward_floordiv(ip_shape, dtype):
     np_numer = np.random.uniform(-100, 100, size=ip_shape).astype(dtype)
@@ -934,6 +1026,7 @@ def _test_forward_floordiv(ip_shape, dtype):
     tf.math.floordiv(numerator, tf.constant(5, dtype=dtype), name='FloorDiv')
     compare_tf_with_tvm([np_numer], ['numer:0'], 'FloorDiv:0')
 
+
 def test_forward_divide():
     '''test FloorDiv, RealDiv'''
     _test_forward_divide((4,), 'int32')
@@ -951,7 +1044,9 @@ def _test_forward_truncatemod(ip_shape, dtype):
     in_data_1 = tf.placeholder(dtype, ip_shape, name="in_data_1")
     in_data_2 = tf.placeholder(dtype, ip_shape, name="in_data_2")
     tf.truncatemod(in_data_1, in_data_2, name='truncatemod')
-    compare_tf_with_tvm([np_data_1, np_data_2], ['in_data_1:0', 'in_data_2:0'], 'truncatemod:0')
+    compare_tf_with_tvm([np_data_1, np_data_2], [
+                        'in_data_1:0', 'in_data_2:0'], 'truncatemod:0')
+
 
 def test_forward_truncatemod():
     '''test TruncateMod'''
@@ -980,7 +1075,9 @@ def _test_gather(ip_shape, indice_shape, indice_value, axis, dtype):
         return indices
     np_indices = _fill_indices(indice_value)
 
-    compare_tf_with_tvm([np_data, np_indices], ['in_data:0', 'indices:0'], out.name)
+    compare_tf_with_tvm([np_data, np_indices], [
+                        'in_data:0', 'indices:0'], out.name)
+
 
 def test_forward_gather():
     '''test Gather/GatherV2 layer'''
@@ -995,6 +1092,7 @@ def test_forward_gather():
     _test_gather((3, 3, 3), (1, 1, 2), [[[1, 0]]], 2, 'int32')
     _test_gather((4, 3, 5, 6), (1, 4), [[2, 1, 0, 0]], 0, 'float32')
 
+
 def test_forward_gather_nd():
     """test operator GatherNd"""
     np_data = np.random.uniform(1, 100, size=(2, 2)).astype(np.float32)
@@ -1016,7 +1114,8 @@ def test_forward_bias_add():
         lft_data = tf.placeholder(dtype, name="lft_data")
         rgt_data = tf.placeholder(dtype, name="rgt_data")
         tf.nn.bias_add(lft_data, rgt_data, name="BiasAdd")
-        compare_tf_with_tvm([lh_data, rh_data], ['lft_data:0', 'rgt_data:0'], 'BiasAdd:0')
+        compare_tf_with_tvm([lh_data, rh_data], [
+                            'lft_data:0', 'rgt_data:0'], 'BiasAdd:0')
 
     check_bias_add((10, 8, 16, 32), (32,), dtype="int32")
     check_bias_add((10, 20), (20,), dtype="float32")
@@ -1033,7 +1132,7 @@ def _test_split(in_shape, axis, num_or_size_splits, dtype):
     tf.reset_default_graph()
     in_data = tf.placeholder(dtype, in_shape, name="in_data")
     num_split = len(num_or_size_splits) if isinstance(num_or_size_splits, list)\
-                else num_or_size_splits
+        else num_or_size_splits
     split = tf.split(in_data, num_or_size_splits, axis=axis)
     relu = [tf.nn.relu(i) for i in split]
 
@@ -1047,6 +1146,7 @@ def _test_split(in_shape, axis, num_or_size_splits, dtype):
 
     compare_tf_with_tvm([np_data], 'in_data:0', 'concat:0')
 
+
 def test_forward_split():
     '''test split layer'''
     # rank 1
@@ -1086,6 +1186,7 @@ def _test_forward_top_k_v2(in_shape, k):
     tf.math.top_k(in_data, k, name='TopK')
     compare_tf_with_tvm([np_data], ['in_data:0'], 'TopK:0')
 
+
 def test_forward_top_k_v2():
     _test_forward_top_k_v2((3,), 1)
     _test_forward_top_k_v2((3,), 3)
@@ -1112,6 +1213,7 @@ def _test_unstack(ip_shape, axis, dtype):
 
     compare_tf_with_tvm([np_data], ['in_data:0'], 'stack:0')
 
+
 def test_forward_unstack():
     '''test unstack layer'''
     _test_unstack((6,), 0, 'int32')
@@ -1132,6 +1234,7 @@ def _test_tile(in_shape, multiples, dtype):
     tf.tile(in_data, multiples=multiples, name="tile")
     compare_tf_with_tvm([np_data], ['in_data:0'], 'tile:0')
 
+
 def test_forward_tile():
     '''test Tile'''
     _test_tile((2, ), (3, ), "int32")
@@ -1146,10 +1249,12 @@ def test_forward_tile():
 def _test_forward_clip_by_value(ip_shape, clip_value_min, clip_value_max, dtype):
     tf.reset_default_graph()
     in_data = tf.placeholder(dtype, ip_shape, name="in_data")
-    tf.clip_by_value(in_data, clip_value_min, clip_value_max, name="ClipByValue")
+    tf.clip_by_value(in_data, clip_value_min,
+                     clip_value_max, name="ClipByValue")
     np_data = np.random.uniform(-100, 100, size=ip_shape).astype(dtype)
     compare_tf_with_tvm([np_data], ['in_data:0'], 'ClipByValue:0')
 
+
 def test_forward_clip_by_value():
     '''test ClipByValue op'''
     if tf.__version__ < LooseVersion('1.9'):
@@ -1160,6 +1265,7 @@ def test_forward_clip_by_value():
 # Multi Input to graph
 # --------------------
 
+
 def test_forward_multi_input():
     with tf.Graph().as_default():
         in1 = tf.placeholder(tf.int32, shape=[3, 3], name='in1')
@@ -1179,6 +1285,7 @@ def test_forward_multi_input():
 # Multi Output to Graph
 # ---------------------
 
+
 def test_forward_multi_output():
     with tf.Graph().as_default():
         in1 = tf.placeholder(tf.int32, shape=[3, 3], name='in1')
@@ -1202,12 +1309,14 @@ def test_forward_multi_output():
             tvm_output = run_tvm_graph(final_graph_def, in_data, in_node, target='llvm',
                                        out_names=out_node, num_output=2)
             for i in range(len(tf_output)):
-                tvm.testing.assert_allclose(tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
+                tvm.testing.assert_allclose(
+                    tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
 
 #######################################################################
 # Resize Bilinear, Nearest_Neighbor
 # ---------------------------------
 
+
 def _test_resize_bilinear(in_shape, to_shape, align_corners):
     """ One iteration of resize bilinear """
 
@@ -1218,10 +1327,12 @@ def _test_resize_bilinear(in_shape, to_shape, align_corners):
         in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
         shape_data = constant_op.constant(
             shape_data, shape=shape_data.shape, dtype=shape_data.dtype)
-        tf.image.resize_bilinear(in_data, shape_data, align_corners=align_corners)
+        tf.image.resize_bilinear(
+            in_data, shape_data, align_corners=align_corners)
 
         compare_tf_with_tvm(data, 'Placeholder:0', 'ResizeBilinear:0')
 
+
 def _test_resize_bilinear_from_tensor(in_shape, align_corners):
     """ One iteration of resize bilinear with non-constant output shape, requires
         value inference to get proper output shape."""
@@ -1232,7 +1343,8 @@ def _test_resize_bilinear_from_tensor(in_shape, align_corners):
         in_data = array_ops.placeholder(
             shape=[in_shape[0], in_shape[1], None, None], dtype=data.dtype)
         to_shape = tf.shape(in_data)[2:]
-        tf.image.resize_bilinear(in_data, to_shape, align_corners=align_corners)
+        tf.image.resize_bilinear(
+            in_data, to_shape, align_corners=align_corners)
 
         compare_tf_with_tvm(data, 'Placeholder:0', 'ResizeBilinear:0')
 
@@ -1247,7 +1359,8 @@ def _test_resize_nearest_neighbor(in_shape, to_shape):
         in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
         shape_data = constant_op.constant(
             shape_data, shape=shape_data.shape, dtype=shape_data.dtype)
-        tf.image.resize_nearest_neighbor(in_data, shape_data, name='resize_nearest_neighbor')
+        tf.image.resize_nearest_neighbor(
+            in_data, shape_data, name='resize_nearest_neighbor')
 
         compare_tf_with_tvm(data, 'Placeholder:0', 'resize_nearest_neighbor:0')
 
@@ -1278,7 +1391,8 @@ def _test_broadcast_to(in_shape, to_shape):
             shape_data, shape=shape_data.shape, dtype=shape_data.dtype)
         tf.broadcast_to(in_data, shape_data)
 
-        compare_tf_with_tvm(data, 'Placeholder:0', 'BroadcastTo:0', opt_level=0)
+        compare_tf_with_tvm(data, 'Placeholder:0',
+                            'BroadcastTo:0', opt_level=0)
 
 
 def _test_broadcast_to_from_tensor(in_shape):
@@ -1315,6 +1429,7 @@ def _test_fill(in_shape):
         tf.ones(shape=in_shape, dtype='float32')
         compare_tf_with_tvm(in_shape, [], 'ones:0', opt_level=1)
 
+
 def _test_fill_from_tensor(in_shape):
     """ Use the fill op to create a tensor of ones with non-constant shape.
         Some extra ops need to be added here to prevent the graph from
@@ -1330,6 +1445,7 @@ def _test_fill_from_tensor(in_shape):
         y = tf.math.add(in_data, tf.reduce_mean(x), name='out1')
         compare_tf_with_tvm(data, 'Placeholder:0', 'out1:0')
 
+
 def test_forward_fill():
     """ Resize Bilinear """
 
@@ -1341,13 +1457,16 @@ def test_forward_fill():
 # Crop to bounding box
 # --------------------
 
+
 def _test_crop(in_shape, off_h, off_w, tar_h, tar_w):
     """ Crop to bounding box """
     data = np.random.uniform(size=in_shape).astype('float32')
     with tf.Graph().as_default():
         in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
         tf.image.crop_to_bounding_box(in_data, off_h, off_w, tar_h, tar_w)
-        compare_tf_with_tvm(data, 'Placeholder:0', 'crop_to_bounding_box/Slice:0')
+        compare_tf_with_tvm(data, 'Placeholder:0',
+                            'crop_to_bounding_box/Slice:0')
+
 
 def test_forward_crop():
     """ Crop to bounding box """
@@ -1366,19 +1485,25 @@ def _test_forward_crop_and_resize(img_shape, boxes, box_idx, crop_size, method='
                              method=method, name="crop_and_resize")
     compare_tf_with_tvm([image], ['in_data:0'], 'crop_and_resize:0')
 
+
 def test_forward_crop_and_resize():
     """ CropAndResize """
     _test_forward_crop_and_resize([1, 11, 11, 3], [[0, 0, 1, 1]], [0], [5, 5])
-    _test_forward_crop_and_resize([1, 11, 11, 3], [[0, 0, .9, .9]], [0], [5, 5])
-    _test_forward_crop_and_resize([1, 11, 11, 3], [[.1, .2, 1, 1]], [0], [5, 5])
-    _test_forward_crop_and_resize([1, 21, 21, 3], [[.2, .3, .7, .9]], [0], [3, 4])
-    _test_forward_crop_and_resize([1, 41, 41, 3], [[0.2, 0.4, 0.8, 0.8]], [0], [3, 3])
+    _test_forward_crop_and_resize(
+        [1, 11, 11, 3], [[0, 0, .9, .9]], [0], [5, 5])
+    _test_forward_crop_and_resize(
+        [1, 11, 11, 3], [[.1, .2, 1, 1]], [0], [5, 5])
+    _test_forward_crop_and_resize(
+        [1, 21, 21, 3], [[.2, .3, .7, .9]], [0], [3, 4])
+    _test_forward_crop_and_resize(
+        [1, 41, 41, 3], [[0.2, 0.4, 0.8, 0.8]], [0], [3, 3])
     _test_forward_crop_and_resize([10, 11, 11, 3],
                                   [[0, 0, 0.9, 0.9], [0.2, 0.2, 0.8, 0.8]],
                                   [0, 1],
                                   [5, 5])
     _test_forward_crop_and_resize([3, 11, 11, 3],
-                                  [[0, 0, 0.9, 0.9], [0.2, 0.2, 0.8, 0.8],[0, 0, 1, 1]],
+                                  [[0, 0, 0.9, 0.9], [
+                                      0.2, 0.2, 0.8, 0.8], [0, 0, 1, 1]],
                                   [0, 1, 2],
                                   [3, 3])
     _test_forward_crop_and_resize([3, 11, 11, 3],
@@ -1397,8 +1522,10 @@ def _test_lstm_cell(batch_size, num_hidden, num_layers, forget_bias, dtype):
     tf.reset_default_graph()
     input_size = num_hidden
     input_data = np.full((batch_size, input_size), 1., dtype=dtype)
-    in_state_c = np.full((num_layers, batch_size, num_hidden), 0.1, dtype=dtype)
-    in_state_h = np.full((num_layers, batch_size, num_hidden), 0.1, dtype=dtype)
+    in_state_c = np.full(
+        (num_layers, batch_size, num_hidden), 0.1, dtype=dtype)
+    in_state_h = np.full(
+        (num_layers, batch_size, num_hidden), 0.1, dtype=dtype)
 
     def _get_tensorflow_output():
         with tf.Session() as sess:
@@ -1408,8 +1535,8 @@ def _test_lstm_cell(batch_size, num_hidden, num_layers, forget_bias, dtype):
                 m1 = array_ops.zeros([batch_size, num_hidden])
                 x = tf.placeholder(shape=(batch_size, input_size), dtype=dtype)
                 g, ((out_m0, out_m1)) = \
-                     tf.contrib.rnn.LSTMBlockCell(num_hidden,
-                                                  forget_bias=forget_bias)(x, ((m0, m1)))
+                    tf.contrib.rnn.LSTMBlockCell(num_hidden,
+                                                 forget_bias=forget_bias)(x, ((m0, m1)))
                 sess.run([variables.global_variables_initializer()])
                 res = sess.run([g, out_m0, out_m1], {
                     x.name: np.array([[1., 1.]]),
@@ -1437,12 +1564,12 @@ def _test_lstm_cell(batch_size, num_hidden, num_layers, forget_bias, dtype):
     tvm_out = [out, out_state_c, out_state_h]
     tvm.testing.assert_allclose(tf_out[0], tvm_out[0], rtol=1e-3, atol=1e-3)
 
+
 def test_forward_lstm():
     '''test LSTM block cell'''
     _test_lstm_cell(1, 2, 1, 0.5, 'float32')
 
 
-
 #######################################################################
 # Pack
 # ---
@@ -1459,6 +1586,7 @@ def _test_pack(axis, shape, **kwargs):
 
         compare_tf_with_tvm([a, b], ['pl_a:0', 'pl_b:0'], 'stack:0')
 
+
 def test_forward_pack():
     for axis in range(-3, 3):
         _test_pack(axis, [3, 2, 1])
@@ -1478,6 +1606,7 @@ def _test_forward_unpack(in_shape, axis, dtype):
     tf.unstack(in_data, axis=axis, name="Unpack")
     compare_tf_with_tvm([np_data], ['in_data:0'], 'Unpack:0')
 
+
 def test_forward_unpack():
     _test_forward_unpack((3,), 0, 'int32')
     _test_forward_unpack((3,), -1, 'int16')
@@ -1486,6 +1615,8 @@ def test_forward_unpack():
 #######################################################################
 # Range
 # -----
+
+
 def test_forward_range():
     """test operator Range"""
     tf.reset_default_graph()
@@ -1495,6 +1626,8 @@ def test_forward_range():
 #######################################################################
 # Pad
 # ---
+
+
 def _test_pad(input_shape, paddings, mode, **kwargs):
     """ One iteration of pad operation with given shape"""
 
@@ -1515,6 +1648,7 @@ def _test_pad(input_shape, paddings, mode, **kwargs):
 
         compare_tf_with_tvm(x, 'Placeholder:0', out_name)
 
+
 def test_forward_pad():
     """ Pad """
     _test_pad((2, 3), [[1, 1], [2, 2]], mode="CONSTANT")
@@ -1525,40 +1659,53 @@ def test_forward_pad():
 #######################################################################
 # Logical operators
 # --------------------
+
+
 def test_logical_and():
     with tf.Graph().as_default():
         in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1')
         in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in2')
         out = tf.logical_and(in1, in2, name='out')
-        in_data1 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype('bool')
-        in_data2 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype('bool')
+        in_data1 = np.random.choice(
+            a=[False, True], size=(1, 4, 4, 3)).astype('bool')
+        in_data2 = np.random.choice(
+            a=[False, True], size=(1, 4, 4, 3)).astype('bool')
         compare_tf_with_tvm([in_data1, in_data2], ['in1:0', 'in2:0'], 'out:0')
 
+
 def test_logical_or():
     with tf.Graph().as_default():
         in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1')
         in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in2')
         out = tf.logical_or(in1, in2, name='out')
-        in_data1 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype('bool')
-        in_data2 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype('bool')
+        in_data1 = np.random.choice(
+            a=[False, True], size=(1, 4, 4, 3)).astype('bool')
+        in_data2 = np.random.choice(
+            a=[False, True], size=(1, 4, 4, 3)).astype('bool')
         compare_tf_with_tvm([in_data1, in_data2], ['in1:0', 'in2:0'], 'out:0')
 
+
 def test_logical_xor():
     with tf.Graph().as_default():
         in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1')
         in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in2')
         out = tf.logical_xor(in1, in2, name='out')
-        in_data1 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype('bool')
-        in_data2 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype('bool')
+        in_data1 = np.random.choice(
+            a=[False, True], size=(1, 4, 4, 3)).astype('bool')
+        in_data2 = np.random.choice(
+            a=[False, True], size=(1, 4, 4, 3)).astype('bool')
         compare_tf_with_tvm([in_data1, in_data2], ['in1:0', 'in2:0'], 'out:0')
 
+
 def test_logical_not():
     with tf.Graph().as_default():
         in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1')
         out = tf.logical_not(in1, name='out')
-        in_data1 = np.random.choice(a=[False, True], size=(1, 4, 4, 3)).astype('bool')
+        in_data1 = np.random.choice(
+            a=[False, True], size=(1, 4, 4, 3)).astype('bool')
         compare_tf_with_tvm(in_data1, 'in1:0', 'out:0')
 
+
 def test_forward_logical():
     test_logical_and()
     test_logical_or()
@@ -1573,13 +1720,18 @@ def test_forward_where():
     ''' Where: return elements depending on conditions'''
     with tf.Graph().as_default():
         with tf.Session() as sess:
-            input1 = tf.placeholder(tf.int32, shape=[1, 4, 4, 3], name='input1')
-            input2 = tf.placeholder(tf.int32, shape=[1, 4, 4, 3], name='input2')
+            input1 = tf.placeholder(
+                tf.int32, shape=[1, 4, 4, 3], name='input1')
+            input2 = tf.placeholder(
+                tf.int32, shape=[1, 4, 4, 3], name='input2')
             mask = input1 > input2
             tf.where(mask, input1 + 1, input2 * 2)
-            in_data1 = np.random.uniform(0, 10, size=(1, 4, 4, 3)).astype("uint32")
-            in_data2 = np.random.uniform(0, 10, size=(1, 4, 4, 3)).astype("uint32")
-            compare_tf_with_tvm([in_data1, in_data2], ['input1:0', 'input2:0'], 'Select:0')
+            in_data1 = np.random.uniform(
+                0, 10, size=(1, 4, 4, 3)).astype("uint32")
+            in_data2 = np.random.uniform(
+                0, 10, size=(1, 4, 4, 3)).astype("uint32")
+            compare_tf_with_tvm([in_data1, in_data2], [
+                                'input1:0', 'input2:0'], 'Select:0')
 
 
 #######################################################################
@@ -1596,17 +1748,22 @@ def test_forward_inception_v3():
         data = np.random.uniform(size=(1, 299, 299, 3)).astype('float32')
 
         with tf.Session() as sess:
-            tf_output = run_tf_graph(sess, data, 'input:0', 'InceptionV3/Predictions/Reshape_1:0')
+            tf_output = run_tf_graph(
+                sess, data, 'input:0', 'InceptionV3/Predictions/Reshape_1:0')
             tvm_output = run_tvm_graph(graph_def, data, 'input')
-            tvm.testing.assert_allclose(tf_output[0], tvm_output[0], rtol=1e-5, atol=1e-5)
+            tvm.testing.assert_allclose(
+                tf_output[0], tvm_output[0], rtol=1e-5, atol=1e-5)
 
 #######################################################################
 # Inception V1
 # ------------
+
+
 def test_forward_inception_v1():
     '''test inception V1 model'''
     with tf.Graph().as_default():
-        graph_def = tf_testing.get_workload("InceptionV1/classify_image_graph_def-with_shapes.pb")
+        graph_def = tf_testing.get_workload(
+            "InceptionV1/classify_image_graph_def-with_shapes.pb")
         # Call the utility to import the graph definition into default graph.
         graph_def = tf_testing.ProcessGraphDefParam(graph_def)
 
@@ -1615,7 +1772,8 @@ def test_forward_inception_v1():
         from tvm.contrib import util
 
         img_array = np.random.uniform(size=(1, 600, 600, 3)).astype("uint8")
-        img = Image.frombuffer('RGB', (600, 600), img_array.tostring(), 'raw', 'RGB', 0, 1)
+        img = Image.frombuffer(
+            'RGB', (600, 600), img_array.tostring(), 'raw', 'RGB', 0, 1)
         temp = util.tempdir()
         img_path = temp.relpath("tf-test.jpg")
         img.save(img_path)
@@ -1629,16 +1787,22 @@ def test_forward_inception_v1():
 
         # Extract tensorflow decoded image frame for tvm input
         with tf.Session() as sess:
-            tvm_data = run_tf_graph(sess, data, 'DecodeJpeg/contents:0', 'DecodeJpeg:0')
+            tvm_data = run_tf_graph(
+                sess, data, 'DecodeJpeg/contents:0', 'DecodeJpeg:0')
 
         with tf.Session() as sess:
-            tf_output = run_tf_graph(sess, data, 'DecodeJpeg/contents:0', 'softmax:0')
-            tvm_output = run_tvm_graph(graph_def, tvm_data, 'DecodeJpeg/contents')
-            tvm.testing.assert_allclose(tf_output[0], tvm_output[0], rtol=1e-5, atol=1e-5)
+            tf_output = run_tf_graph(
+                sess, data, 'DecodeJpeg/contents:0', 'softmax:0')
+            tvm_output = run_tvm_graph(
+                graph_def, tvm_data, 'DecodeJpeg/contents')
+            tvm.testing.assert_allclose(
+                tf_output[0], tvm_output[0], rtol=1e-5, atol=1e-5)
 
 #######################################################################
 # Mobilenet
 # ---------
+
+
 def test_forward_mobilenet():
     '''test mobilenet model'''
     # MobilenetV2
@@ -1663,6 +1827,8 @@ def test_forward_mobilenet():
 #######################################################################
 # ResnetV2
 # --------
+
+
 def test_forward_resnetv2():
     '''test resnet model'''
     if is_gpu_available():
@@ -1676,7 +1842,8 @@ def test_forward_resnetv2():
             out_node = 'ArgMax'
 
             with tf.Session() as sess:
-                tf_output = run_tf_graph(sess, data, 'input_tensor:0', out_node + ':0')
+                tf_output = run_tf_graph(
+                    sess, data, 'input_tensor:0', out_node + ':0')
                 for device in ["llvm", "cuda"]:
                     ctx = tvm.context(device, 0)
                     if not ctx.exist:
@@ -1690,6 +1857,8 @@ def test_forward_resnetv2():
 #######################################################################
 # Placeholder
 # -----------
+
+
 def test_forward_placeholder():
     '''test a simple pb with Placeholder node in the end of GraphDef'''
     with tf.Graph().as_default():
@@ -1703,15 +1872,19 @@ def test_forward_placeholder():
         with tf.Session() as sess:
             # Add shapes to the graph.
             graph_def = tf_testing.AddShapesToGraphDef(sess, out_node)
-            tf_output = run_tf_graph(sess, data, 'Placeholder:0', out_node + ':0')
+            tf_output = run_tf_graph(
+                sess, data, 'Placeholder:0', out_node + ':0')
             tvm_output = run_tvm_graph(graph_def, data, 'Placeholder')
             tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]),
                                         rtol=1e-5, atol=1e-5)
 
+
 #######################################################################
 # PTB
 # ---
 dir(tf.contrib)
+
+
 def test_forward_ptb():
     '''test ptb model'''
     config = tf_testing.get_config()
@@ -1722,7 +1895,7 @@ def test_forward_ptb():
     vocab_size = config.vocab_size
     out_sample_shape = (batch_size, vocab_size)
     out_state_shape = (num_layers, 2, batch_size, num_hidden)
-    #Sample input
+    # Sample input
     inpt = "we have no useful information on"
     cnt_sample = 20
 
@@ -1733,18 +1906,19 @@ def test_forward_ptb():
             return ''.join([id2word[x] for x in items]).replace('_', ' ')
 
     def _get_tvm_graph_module(graph_def):
-        #Cell inputs 'c and 'h' consist of all layers values
+        # Cell inputs 'c and 'h' consist of all layers values
         shape_dict = {'Model/Placeholder': (batch_size, num_steps),
                       'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c':
                       (num_layers, batch_size, num_hidden),
                       'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h':
                       (num_layers, batch_size, num_hidden)}
 
-        mod, params = relay.frontend.from_tensorflow(graph_def, shape=shape_dict)
+        mod, params = relay.frontend.from_tensorflow(
+            graph_def, shape=shape_dict)
 
         dtype_dict = {'Model/Placeholder': 'int32',
-                      'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c':'float32',
-                      'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h':'float32'}
+                      'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c': 'float32',
+                      'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h': 'float32'}
         target = 'llvm'
         with relay.build_config(opt_level=0):
             graph, lib, params = relay.build(mod,
@@ -1759,13 +1933,17 @@ def test_forward_ptb():
         samples = []
         state = in_states
         sample = None
+
         def _get_sample(data, state):
             input_data = np.full((batch_size, num_steps), data, dtype="int32")
             in_state_tup = np.split(state, indices_or_sections=2, axis=1)
-            in_state_c = np.reshape(in_state_tup[0], (num_layers, batch_size, num_hidden))
-            in_state_h = np.reshape(in_state_tup[1], (num_layers, batch_size, num_hidden))
+            in_state_c = np.reshape(
+                in_state_tup[0], (num_layers, batch_size, num_hidden))
+            in_state_h = np.reshape(
+                in_state_tup[1], (num_layers, batch_size, num_hidden))
 
-            model.set_input('Model/Placeholder', tvm.nd.array(input_data.astype("int32")))
+            model.set_input('Model/Placeholder',
+                            tvm.nd.array(input_data.astype("int32")))
             model.set_input('Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c',
                             tvm.nd.array(in_state_c.astype("float32")))
             model.set_input('Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h',
@@ -1802,16 +1980,17 @@ def test_forward_ptb():
         graph_def = tf_testing.ProcessGraphDefParam(graph_def)
         sess = tf.Session()
 
-    #TVM graph module creation
+    # TVM graph module creation
     params, m = _get_tvm_graph_module(graph_def)
 
     # Create 10 predicted statments of 20 words
     cnt_stm = 0
     while cnt_stm < 10:
         cnt_stm += 1
-        in_state = np.full((num_layers, 2, batch_size, num_hidden), 0, dtype="float32")
+        in_state = np.full(
+            (num_layers, 2, batch_size, num_hidden), 0, dtype="float32")
         seed_for_sample = inpt.split()
-        tvm_samples, tvm_state = _do_tvm_sample(m, [word_to_id[word] \
+        tvm_samples, tvm_state = _do_tvm_sample(m, [word_to_id[word]
                                                     for word in seed_for_sample],
                                                 in_state, params, cnt_sample)
         tvm_sample_str = _pretty_print(tvm_samples, False, id_to_word)
@@ -1821,13 +2000,15 @@ def test_forward_ptb():
             in_state, cnt_sample)
         tf_sample_str = _pretty_print(tf_samples, False, id_to_word)
         inpt = tvm_sample_str
-        tvm.testing.assert_allclose(tf_samples, tvm_samples, rtol=1e-5, atol=1e-5)
+        tvm.testing.assert_allclose(
+            tf_samples, tvm_samples, rtol=1e-5, atol=1e-5)
         assert tvm_sample_str == tf_sample_str
 
 #######################################################################
 # LRN (Local Response Normalization)
 # ----------------------------------
 
+
 def _test_lrn(ishape, size, axis, bias, alpha, beta):
     """ testing local response normalization """
     lrn_depth_radius = size / 2
@@ -1835,7 +2016,8 @@ def _test_lrn(ishape, size, axis, bias, alpha, beta):
     inp_array = np.random.uniform(size=ishape).astype(np.float32)
 
     with tf.Graph().as_default():
-        in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype, name="lrn0_data")
+        in1 = tf.placeholder(shape=inp_array.shape,
+                             dtype=inp_array.dtype, name="lrn0_data")
         nn_ops.local_response_normalization(in1,
                                             name="lrn",
                                             depth_radius=lrn_depth_radius,
@@ -1845,6 +2027,7 @@ def _test_lrn(ishape, size, axis, bias, alpha, beta):
 
         compare_tf_with_tvm(inp_array, 'lrn0_data:0', 'lrn:0')
 
+
 def test_forward_lrn():
     _test_lrn((1, 3, 20, 20), 3, 1, 1.0, 1.0, 0.5)
 
@@ -1852,6 +2035,7 @@ def test_forward_lrn():
 # l2_normalize
 # ------------
 
+
 def _test_l2_normalize(ishape, eps, axis):
     """ testing l2 normalize (uses max, sum, square, sqrt frontend operators)"""
 
@@ -1867,17 +2051,21 @@ def _test_l2_normalize(ishape, eps, axis):
 
         compare_tf_with_tvm(inp_array, 'Placeholder:0', 'l2_normalize:0')
 
+
 def test_forward_l2_normalize():
     _test_l2_normalize((1, 3, 20, 20), 0.001, (0,))
 
 #######################################################################
 # transpose
 # ---------
+
+
 def _test_forward_transpose(ishape, axes=None):
     data = np.random.uniform(size=ishape).astype(np.float32)
 
     with tf.Graph().as_default():
-        in1 = tf.placeholder(shape=data.shape, dtype=data.dtype, name="transpose_data")
+        in1 = tf.placeholder(
+            shape=data.shape, dtype=data.dtype, name="transpose_data")
 
         if axes is None:
             tf.transpose(in1)
@@ -1886,6 +2074,7 @@ def _test_forward_transpose(ishape, axes=None):
 
         compare_tf_with_tvm(data, 'transpose_data:0', 'transpose:0')
 
+
 def test_forward_transpose():
     _test_forward_transpose((2, 3, 4), (1, 2, 0))
     _test_forward_transpose((2, 3, 4))
@@ -1903,6 +2092,7 @@ def test_forward_ceil():
         tf.ceil(in1)
         compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Ceil:0')
 
+
 def test_forward_floor():
     ishape = (1, 3, 10, 10)
     inp_array = np.random.uniform(size=ishape).astype(np.float32)
@@ -1911,6 +2101,7 @@ def test_forward_floor():
         tf.floor(in1)
         compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Floor:0')
 
+
 def test_forward_relu():
     ishape = (1, 3, 10, 10)
     inp_array = np.random.uniform(-5, 5, size=ishape).astype(np.float32)
@@ -1919,6 +2110,7 @@ def test_forward_relu():
         tf.nn.relu(in1)
         compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Relu:0')
 
+
 def test_forward_leaky_relu():
     ishape = (1, 3, 10, 10)
     inp_array = np.random.uniform(-5, 5, size=ishape).astype(np.float32)
@@ -1927,6 +2119,7 @@ def test_forward_leaky_relu():
         tf.nn.leaky_relu(in1, alpha=0.4)
         compare_tf_with_tvm(inp_array, 'Placeholder:0', 'LeakyRelu:0')
 
+
 def test_forward_elu():
     ishape = (1, 3, 10, 10)
     inp_array = np.random.uniform(-5, 5, size=ishape).astype(np.float32)
@@ -1935,6 +2128,7 @@ def test_forward_elu():
         tf.nn.elu(in1)
         compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Elu:0')
 
+
 def test_forward_selu():
     ishape = (1, 3, 10, 10)
     inp_array = np.random.uniform(-5, 5, size=ishape).astype(np.float32)
@@ -1943,6 +2137,7 @@ def test_forward_selu():
         tf.nn.selu(in1)
         compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Selu:0')
 
+
 def test_forward_tanh():
     ishape = (1, 3, 10, 10)
     inp_array = np.random.uniform(-5, 5, size=ishape).astype(np.float32)
@@ -1979,6 +2174,7 @@ def test_forward_round():
     tf.round(in_data, name="round")
     compare_tf_with_tvm([np_data], ['in_data:0'], 'round:0')
 
+
 def test_forward_abs():
     """test operator Abs"""
     np_data = np.random.uniform(1, 100, size=(9, 11)).astype(np.float32)
@@ -1987,6 +2183,7 @@ def test_forward_abs():
     tf.math.abs(in_data, name="abs")
     compare_tf_with_tvm([np_data], ['in_data:0'], 'abs:0')
 
+
 def _test_forward_zeros_like(in_shape, dtype):
     np_data = np.random.uniform(-10, 10, size=in_shape).astype(dtype)
     tf.reset_default_graph()
@@ -1994,6 +2191,7 @@ def _test_forward_zeros_like(in_shape, dtype):
     tf.zeros_like(in_data, name="zeros_like")
     compare_tf_with_tvm([np_data], ['in_data:0'], 'zeros_like:0')
 
+
 def test_forward_zeros_like():
     if tf.__version__ < LooseVersion('1.2'):
         _test_forward_zeros_like((2, 3), "int32")
@@ -2002,6 +2200,7 @@ def test_forward_zeros_like():
         _test_forward_zeros_like((2, 3, 11), "float32")
         _test_forward_zeros_like((2, 3, 11), "float64")
 
+
 def test_forward_erf():
     ishape = (1, 3, 10, 10)
     inp_array = np.random.uniform(-5, 5, size=ishape).astype(np.float32)
@@ -2010,15 +2209,20 @@ def test_forward_erf():
         tf.math.erf(in1)
         compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Erf:0')
 
+
 def test_forward_squared_difference():
     ishape = (1, 3, 10, 14)
     inp_array_a = np.random.uniform(-5, 5, size=ishape).astype(np.float32)
     inp_array_b = np.random.uniform(-5, 5, size=ishape).astype(np.float32)
     with tf.Graph().as_default():
-        in1 = tf.placeholder(shape=inp_array_a.shape, dtype=inp_array_a.dtype, name="in1")
-        in2 = tf.placeholder(shape=inp_array_b.shape, dtype=inp_array_b.dtype, name="in2")
+        in1 = tf.placeholder(shape=inp_array_a.shape,
+                             dtype=inp_array_a.dtype, name="in1")
+        in2 = tf.placeholder(shape=inp_array_b.shape,
+                             dtype=inp_array_b.dtype, name="in2")
         out = tf.math.squared_difference(in1, in2)
-        compare_tf_with_tvm([inp_array_a, inp_array_b], [in1.name, in2.name], out.name)
+        compare_tf_with_tvm([inp_array_a, inp_array_b], [
+                            in1.name, in2.name], out.name)
+
 
 def _test_forward_reverse_v2(in_shape, axis, dtype):
     np_data = np.random.uniform(-10, 10, size=in_shape).astype(dtype)
@@ -2027,6 +2231,7 @@ def _test_forward_reverse_v2(in_shape, axis, dtype):
     tf.reverse(in_data, axis=[axis], name="reverse")
     compare_tf_with_tvm([np_data], ['in_data:0'], 'reverse:0')
 
+
 def test_forward_reverse_v2():
     """test ReverseV2"""
     _test_forward_reverse_v2((2, 3), 0, "int32")
@@ -2035,6 +2240,7 @@ def test_forward_reverse_v2():
     _test_forward_reverse_v2((2, 3, 5), -1, "float64")
     _test_forward_reverse_v2((2, 3, 5), -3, "float64")
 
+
 def test_forward_sign():
     """test Sign"""
     np_data = np.random.uniform(-10, 10, size=(5, 7, 11)).astype(np.float32)
@@ -2043,6 +2249,7 @@ def test_forward_sign():
     tf.sign(in_data, name="sign")
     compare_tf_with_tvm([np_data], ['in_data:0'], 'sign:0')
 
+
 def test_forward_square():
     """test operator Square """
     np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
@@ -2051,6 +2258,7 @@ def test_forward_square():
     tf.square(in_data, name="square")
     compare_tf_with_tvm([np_data], ['in_data:0'], 'square:0')
 
+
 def test_forward_pow_exp():
     """test Pow and Exp """
     np_in1 = np.random.uniform(-2, 2, size=(5, 7, 11)).astype(np.float32)
@@ -2063,6 +2271,7 @@ def test_forward_pow_exp():
     compare_tf_with_tvm([np_in1, np_in2], ['in1:0', 'in2:0'], 'pow:0')
     compare_tf_with_tvm([np_in1], ['in1:0'], 'exp:0')
 
+
 def test_forward_log():
     """test operator Log """
     np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
@@ -2071,6 +2280,7 @@ def test_forward_log():
     tf.log(in_data, name="log")
     compare_tf_with_tvm([np_data], ['in_data:0'], 'log:0')
 
+
 def test_forward_log1p():
     """test operator Log1p """
     np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
@@ -2079,6 +2289,7 @@ def test_forward_log1p():
     tf.log1p(in_data, name="log1p")
     compare_tf_with_tvm([np_data], ['in_data:0'], 'log1p:0')
 
+
 def test_forward_cos():
     """test operator cos """
     np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
@@ -2087,6 +2298,7 @@ def test_forward_cos():
     tf.cos(in_data, name="cos")
     compare_tf_with_tvm([np_data], ['in_data:0'], 'cos:0')
 
+
 def test_forward_sin():
     """test operator sin """
     np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
@@ -2095,14 +2307,17 @@ def test_forward_sin():
     tf.sin(in_data, name="sin")
     compare_tf_with_tvm([np_data], ['in_data:0'], 'sin:0')
 
+
 def test_forward_negative():
     """test tf operator Neg """
-    np_data = np.random.uniform(-100, 255, size=(224, 224, 3)).astype(np.float32)
+    np_data = np.random.uniform(-100, 255,
+                                size=(224, 224, 3)).astype(np.float32)
     tf.reset_default_graph()
     in_data = tf.placeholder(tf.float32, (224, 224, 3), name="in_data")
     tf.negative(in_data, name="negative")
     compare_tf_with_tvm([np_data], ['in_data:0'], 'negative:0')
 
+
 def test_forward_log_softmax():
     """test operator LogSoftmax"""
     np_data = np.random.uniform(1, 100, size=(9, 11)).astype(np.float32)
@@ -2111,6 +2326,7 @@ def test_forward_log_softmax():
     tf.math.log_softmax(in_data, name="LogSoftmax")
     compare_tf_with_tvm([np_data], ['in_data:0'], 'LogSoftmax:0')
 
+
 def test_forward_softplus():
     """test operator Softplus"""
     np_data = np.random.uniform(1, 10, size=(2, 3, 5)).astype(np.float32)
@@ -2119,6 +2335,7 @@ def test_forward_softplus():
     tf.nn.softplus(in_data, name="softplus")
     compare_tf_with_tvm([np_data], ['in_data:0'], 'softplus:0')
 
+
 def test_forward_rsqrt():
     """test Rsqrt """
     np_data = np.random.uniform(1, 100, size=(5, 7, 11)).astype(np.float32)
@@ -2127,6 +2344,7 @@ def test_forward_rsqrt():
     tf.rsqrt(in_data, name="rsqrt")
     compare_tf_with_tvm([np_data], ['in_data:0'], 'rsqrt:0')
 
+
 def test_forward_sqrt():
     """test Sqrt """
     np_data = np.random.uniform(1, 100, size=(5, 7, 11)).astype(np.float32)
@@ -2135,6 +2353,7 @@ def test_forward_sqrt():
     tf.sqrt(in_data, name="sqrt")
     compare_tf_with_tvm([np_data], ['in_data:0'], 'sqrt:0')
 
+
 def _test_forward_right_shift(in_shape, dtype):
     """test operator RightShift"""
     lh_data = np.random.randint(1, 3, size=in_shape).astype(dtype)
@@ -2143,12 +2362,15 @@ def _test_forward_right_shift(in_shape, dtype):
     lft_data = tf.placeholder(dtype, in_shape, name="lft_data")
     rgt_data = tf.placeholder(dtype, in_shape, name="rgt_data")
     tf.bitwise.right_shift(lft_data, rgt_data, name="RightShift")
-    compare_tf_with_tvm([lh_data, rh_data], ['lft_data:0', 'rgt_data:0'], 'RightShift:0')
+    compare_tf_with_tvm([lh_data, rh_data], [
+                        'lft_data:0', 'rgt_data:0'], 'RightShift:0')
+
 
 def test_forward_right_shift():
     _test_forward_right_shift((7,), 'int32')
     _test_forward_right_shift((3, 11), 'int16')
 
+
 def _test_forward_left_shift(in_shape, dtype):
     """test operator LeftShift"""
     lh_data = np.random.randint(100, 1000000, size=in_shape).astype(dtype)
@@ -2157,7 +2379,9 @@ def _test_forward_left_shift(in_shape, dtype):
     lft_data = tf.placeholder(dtype, in_shape, name="lft_data")
     rgt_data = tf.placeholder(dtype, in_shape, name="rgt_data")
     tf.bitwise.left_shift(lft_data, rgt_data, name="LeftShift")
-    compare_tf_with_tvm([lh_data, rh_data], ['lft_data:0', 'rgt_data:0'], 'LeftShift:0')
+    compare_tf_with_tvm([lh_data, rh_data], [
+                        'lft_data:0', 'rgt_data:0'], 'LeftShift:0')
+
 
 def test_forward_left_shift():
     _test_forward_left_shift((10,), 'int32')
@@ -2166,13 +2390,16 @@ def test_forward_left_shift():
 #######################################################################
 # Mean
 # ----
+
+
 def test_forward_mean():
     def check_mean(ishape, **kwargs):
         inp_array = np.random.uniform(size=ishape).astype(np.float32)
         with tf.Graph().as_default():
             in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype)
             tf.keras.backend.mean(in1, **kwargs)
-            compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Mean:0', no_gpu=True)
+            compare_tf_with_tvm(inp_array, 'Placeholder:0',
+                                'Mean:0', no_gpu=True)
 
     check_mean((10, 8, 16, 32))
     check_mean((10, 8, 16, 32), axis=(2, 3))
@@ -2181,6 +2408,8 @@ def test_forward_mean():
 #######################################################################
 # Size
 # ----
+
+
 def test_forward_size():
     def check_size(ishape):
         np_input = np.random.uniform(size=ishape).astype(np.float32)
@@ -2190,7 +2419,8 @@ def test_forward_size():
         tf_input_shape[0] = None
 
         with tf.Graph().as_default():
-            input = tf.placeholder(shape=tf_input_shape, dtype=np_input.dtype, name='input')
+            input = tf.placeholder(shape=tf_input_shape,
+                                   dtype=np_input.dtype, name='input')
             tf.size(input, name='size')
             compare_tf_with_tvm([np_input], ['input:0'], 'size:0')
 
@@ -2200,6 +2430,8 @@ def test_forward_size():
 #######################################################################
 # All, Any, Max, Min
 # -------------
+
+
 def test_forward_reduce_all():
     """Test the All operator."""
     np_data = np.random.choice([True, False], size=(5, 7, 11))
@@ -2208,32 +2440,28 @@ def test_forward_reduce_all():
     tf.reduce_all(in_data, name="all")
     compare_tf_with_tvm([np_data], ['in_data:0'], 'all:0')
 
-def test_forward_reduce_any():
-    """Test the Any operator."""
-    np_data = np.random.choice([True, False], size=(5, 7, 11))
-    tf.reset_default_graph()
-    in_data = tf.placeholder(tf.bool, (5, 7, 11), name="in_data")
-    tf.reduce_any(in_data, name="any")
-    compare_tf_with_tvm([np_data], ['in_data:0'], 'any:0')
 
 def test_forward_reduce_max():
     def check_max(ishape, axis, keepdims, dtype):
         tf.reset_default_graph()
         np_data = np.random.uniform(size=ishape).astype(dtype)
         in_data = tf.placeholder(dtype, name="in_data")
-        tf.math.reduce_max(in_data, axis=axis, keepdims=keepdims, name="reduce_max")
+        tf.math.reduce_max(in_data, axis=axis,
+                           keepdims=keepdims, name="reduce_max")
         compare_tf_with_tvm([np_data], ['in_data:0'], 'reduce_max:0')
 
     check_max((10, 8, 16, 32), axis=(-1), keepdims=True, dtype="int32")
     check_max((10, 8, 16, 32), axis=(2, 3), keepdims=True, dtype="float32")
     check_max((10, 8, 16, 32), axis=(1, 2), keepdims=True, dtype='float32')
 
+
 def test_forward_reduce_min():
     def check_min(ishape, axis, keepdims, dtype):
         tf.reset_default_graph()
         np_data = np.random.uniform(size=ishape).astype(dtype)
         in_data = tf.placeholder(dtype, name="in_data")
-        tf.math.reduce_min(in_data, axis=axis, keepdims=keepdims, name="reduce_max")
+        tf.math.reduce_min(in_data, axis=axis,
+                           keepdims=keepdims, name="reduce_max")
         compare_tf_with_tvm([np_data], ['in_data:0'], 'reduce_max:0')
 
     check_min((10, 8, 16, 32), axis=(-1), keepdims=True, dtype="int32")
@@ -2243,14 +2471,19 @@ def test_forward_reduce_min():
 #######################################################################
 # Relational operators
 # --------------------
+
+
 def _test_forward_rel_op(data, func):
     with tf.Graph().as_default():
-        in1 = tf.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in1')
-        in2 = tf.placeholder(shape=data[1].shape, dtype=data[1].dtype, name='in2')
+        in1 = tf.placeholder(
+            shape=data[0].shape, dtype=data[0].dtype, name='in1')
+        in2 = tf.placeholder(
+            shape=data[1].shape, dtype=data[1].dtype, name='in2')
         op = func(in1, in2, name='op')
         out = tf.cast(op, tf.int32, name='out1')
         compare_tf_with_tvm([data[0], data[1]], ['in1:0', 'in2:0'], 'out1:0')
 
+
 def test_forward_rel_ops():
     t1 = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
     t2 = np.array([[9, 8, 7], [6, 5, 4], [3, 2, 1]])
@@ -2264,11 +2497,14 @@ def test_forward_rel_ops():
 #######################################################################
 # ExpandDims
 # ----------
+
+
 def _test_forward_expand_dims(data, axis):
     in1 = tf.placeholder(shape=data.shape, dtype=data.dtype, name='in1')
     out = tf.expand_dims(in1, axis)
     compare_tf_with_tvm([data], [in1.name], out.name)
 
+
 def test_forward_expand_dims():
     _test_forward_expand_dims(np.int32(1), 0)
     _test_forward_expand_dims(np.array([1]), 0)
@@ -2288,6 +2524,7 @@ def _test_forward_reduce_prod(shape, axis, keepdims):
         out = tf.math.reduce_prod(in1, axis, keepdims)
         compare_tf_with_tvm(inp_array1, in1.name, out.name)
 
+
 def test_forward_reduce_prod():
     _test_forward_reduce_prod((5,), 0, False)
     _test_forward_reduce_prod((5, 5), 0, False)
@@ -2309,11 +2546,13 @@ def test_forward_maximum():
         lft_data = tf.placeholder(dtype, name="lft_data")
         rgt_data = tf.placeholder(dtype, name="rgt_data")
         tf.math.maximum(lft_data, rgt_data, name="maximum")
-        compare_tf_with_tvm([lh_data, rh_data], ['lft_data:0', 'rgt_data:0'], 'maximum:0')
+        compare_tf_with_tvm([lh_data, rh_data], [
+                            'lft_data:0', 'rgt_data:0'], 'maximum:0')
 
     check_maximum((10, 8, 16, 32), (1,), dtype="int32")
     check_maximum((10, 8, 16, 32), (10, 8, 16, 32), dtype="float32")
 
+
 def test_forward_minimum():
     """test Op Minimum"""
     def check_minimum(lh_shape, rh_shape, dtype):
@@ -2323,7 +2562,8 @@ def test_forward_minimum():
         lft_data = tf.placeholder(dtype, name="lft_data")
         rgt_data = tf.placeholder(dtype, name="rgt_data")
         tf.math.minimum(lft_data, rgt_data, name="minimum")
-        compare_tf_with_tvm([lh_data, rh_data], ['lft_data:0', 'rgt_data:0'], 'minimum:0')
+        compare_tf_with_tvm([lh_data, rh_data], [
+                            'lft_data:0', 'rgt_data:0'], 'minimum:0')
 
     check_minimum((10, 8, 16, 32), (1,), dtype="int32")
     check_minimum((10, 8, 16, 32), (10, 8, 16, 32), dtype="float32")
@@ -2339,7 +2579,8 @@ def test_placeholder():
         var2 = array_ops.placeholder_with_default(var1, None, name='place1')
 
         in_data2 = np.random.uniform(-5, 5, size=(3, 4, 5)).astype(np.float32)
-        place1 = array_ops.placeholder(shape=in_data1.shape, dtype=in_data1.dtype, name='in2')
+        place1 = array_ops.placeholder(
+            shape=in_data1.shape, dtype=in_data1.dtype, name='in2')
 
         out1 = tf.math.add(var1, var2, name='out1')
         out2 = tf.math.add(out1, place1, name='out2')
@@ -2350,13 +2591,17 @@ def test_placeholder():
 #######################################################################
 # OneHot
 # ----------------------
+
+
 def _test_forward_one_hot(indices_shape, depth, on_value, off_value, axis, out_dtype):
     inp_array1 = np.random.randint(0, 5, size=indices_shape)
     with tf.Graph().as_default():
         in1 = tf.placeholder(shape=inp_array1.shape, dtype=inp_array1.dtype)
-        out = tf.one_hot(in1, depth, on_value, off_value, axis, dtype=out_dtype)
+        out = tf.one_hot(in1, depth, on_value, off_value,
+                         axis, dtype=out_dtype)
         compare_tf_with_tvm(inp_array1, in1.name, out.name)
 
+
 def test_forward_one_hot():
     _test_forward_one_hot((3,), 3, 1, 0, -1, "int32")
     _test_forward_one_hot((3,), 3, 1.0, 0.0, -1, "float32")
@@ -2365,6 +2610,40 @@ def test_forward_one_hot():
     _test_forward_one_hot((3, 2, 4, 5), 6, 1, 0, 1, "int32")
     _test_forward_one_hot((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32")
 
+#######################################################################
+# AddN
+# ----------------------
+
+
+def _test_forward_add_n(inputs):
+    tf.reset_default_graph()
+    with tf.Graph().as_default():
+        temp = []
+        for each in inputs:
+            temp.append(tf.placeholder(shape=each.shape, dtype=each.dtype))
+        output = tf.add_n(temp)
+        compare_tf_with_tvm([each for each in inputs], [
+                            each.name for each in temp], output.name)
+
+
+def test_forward_add_n():
+    x = np.random.randint(1, 100, size=(3, 3, 3), dtype=np.int32)
+    y = np.random.randint(1, 100, size=(3, 3, 3), dtype=np.int32)
+    z = np.random.randint(1, 100, size=(3, 3, 3), dtype=np.int32)
+    m, n, o = x.astype(np.float32), y.astype(np.float32), z.astype(np.float32)
+    in0 = x
+    in1 = [x, y]
+    in2 = (x, y, z)
+    in3 = m
+    in4 = [m, n]
+    in5 = (m, n, o)
+    _test_forward_add_n(in0)
+    _test_forward_add_n(in1)
+    _test_forward_add_n(in2)
+    _test_forward_add_n(in3)
+    _test_forward_add_n(in4)
+    _test_forward_add_n(in5)
+
 
 #######################################################################
 # Main
@@ -2433,6 +2712,7 @@ if __name__ == '__main__':
     test_forward_zeros_like()
     test_forward_erf()
     test_forward_squared_difference()
+    test_forward_add_n()
 
     # Reductions
     test_forward_argminmax()