[FRONTEND][TENSORFLOW]Add Split and realdiv op support (#2123)
authorZhebin Jin <zhebin.jzb@alibaba-inc.com>
Thu, 6 Dec 2018 17:51:05 +0000 (01:51 +0800)
committerYizhi Liu <liuyizhi@apache.org>
Thu, 6 Dec 2018 17:51:05 +0000 (09:51 -0800)
* Add Split and realdiv op support

* Fix the pad calculation in the case of dilated convolution

nnvm/python/nnvm/frontend/tensorflow.py
nnvm/tests/python/frontend/tensorflow/test_forward.py

index 26e59dc7e83010ed3ae07f4cf856a8c28064284d..c8db662152e92df98f05eb77b9095792c9314734 100644 (file)
@@ -215,7 +215,7 @@ def _conv(opname):
                 attr['channels'] = input_shape[3] * depth_mult
 
             if 'dilations' in attr:
-                attr['dilations'] = (attr['dilations'][0], attr['dilations'][1])
+                attr['dilations'] = (attr['dilations'][1], attr['dilations'][2])
             attr['strides'] = (attr['strides'][1], attr['strides'][2])
         elif attr['data_format'] == 'NCHW':
             depth_mult, _, kernel_h, kernel_w = weights_shape
@@ -252,8 +252,12 @@ def _conv(opname):
                 in_h = input_shape[2]
                 in_w = input_shape[3]
 
-            pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
-            pad_h = _get_pad_pair(in_w, kernel_w, stride_w)
+            dilation_h = attr['dilations'][0]
+            dilation_w = attr['dilations'][1]
+            dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
+            dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
+            pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h)
+            pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w)
 
             if attr['data_format'] == 'NHWC':
                 inputs[0] = _sym.pad(data=inputs[0],
@@ -783,6 +787,15 @@ def _broadcast(name):
         )(inputs, attr)
     return _impl
 
+def _split():
+    def _impl(inputs, attr, params):
+        axis = params.pop(inputs[0].list_output_names()[0])
+        return AttrCvt(
+            op_name="split", ignores=['T'],
+            transforms={'num_split': 'indices_or_sections'},
+            extras={'axis': axis.asnumpy()[0]})(inputs[1], attr)
+    return _impl
+
 # compatible operators that do NOT require any conversion.
 _identity_list = []
 
@@ -813,6 +826,7 @@ _convert_map = {
     'Add'                               : _elemwise('add'),
     'Sub'                               : _elemwise('sub'),
     'Mul'                               : _elemwise('mul'),
+    'RealDiv'                           : _elemwise('div'),
     'Maximum'                           : _elemwise('max'),
     'Minimum'                           : _elemwise('min'),
     'Sum'                               : _sum(),
@@ -849,6 +863,7 @@ _convert_map = {
     'GreaterEqual'                      : _broadcast('greater_equal'),
     'Equal'                             : _broadcast('equal'),
     'NotEqual'                          : _broadcast('not_equal'),
+    'Split'                             : _split(),
 }
 
 # _convert_map_rnn defines maps of rnn operator name to
@@ -1144,21 +1159,26 @@ class GraphProto(object):
                 # Pass the target layout
                 attr["_target_layout"] = layout
 
-                #ToDo: Some of the tensorflow operators internaly maintain
-                #execution layers and its output name will the layer number along with
-                #graph node name.eg: Node name:- 'Model/RNN/cell_0/RnnCell', but the
-                #output name will be 'Model/RNN/cell_0/RnnCell:0'. In this case,
-                #the digit has to be ignored.
-                if ":" in node.input[0]:
-                    in_name, _ = node.input[0].split(':')
-                    node.input[0] = in_name
-
                 # Fill shapes for all inputs in a list
                 inputs = []
                 for i in node.input:
-                    if i in self._nodes:
-                        inputs.append(self._nodes[i])
-                        input_shapes[self._nodes[i]] = self._output_shapes[i]
+                    #ToDo: Some of the tensorflow operators internaly maintain
+                    #execution layers and its output name will the layer number along with
+                    #graph node name.eg: Node name:- 'Model/RNN/cell_0/RnnCell', but the
+                    #output name will be 'Model/RNN/cell_0/RnnCell:0'. In this case,
+                    #the digit has to be ignored.
+                    tensor_name = i.split(':')
+                    node_name = tensor_name[0]
+                    if node_name in self._nodes:
+                        in_sym = self._nodes[node_name]
+                        if len(in_sym.list_output_names()) > 1:
+                            tensor_slot = int(tensor_name[1]) if len(tensor_name) > 1 else 0
+                            in_sym = in_sym[tensor_slot]
+                            input_shape = (self._output_shapes[node_name])[tensor_slot]
+                        else:
+                            input_shape = self._output_shapes[node_name][0]
+                        inputs.append(in_sym)
+                        input_shapes[in_sym] = [input_shape]
                 attr['_input_shapes'] = input_shapes
 
                 inputs = self._fix_extranodes(node.op, attr, inputs)
index c98748c0fc03303258ab09559f93760040fd8812..219ceb5bd379f80585e7f33d661638c8c765e213 100644 (file)
@@ -502,6 +502,83 @@ def test_forward_gather():
     _test_gather((4,3,5,6), (1,4), [[2,1,0,0]], 0, 'float32')
 
 
+#######################################################################
+# Split
+# -----
+
+def _test_split(in_shape, axis, num_split, dtype):
+    """ One iteration of a Split """
+
+    with tf.Graph().as_default():
+        in_data = tf.placeholder(dtype, in_shape, name="in_data")
+        tf.split(in_data, num_split, axis)
+        np_data = np.random.uniform(size=in_shape).astype(dtype)
+        compare_tf_with_tvm(np_data, 'in_data:0', 'split:0')
+
+def test_forward_split():
+    '''test split layer'''
+    # rank 1
+    _test_split((3,), 0, 1, 'float32')
+    _test_split((3,), 0, 3, 'float32')
+    _test_split((6,), 0, 3, 'float32')
+    # rank 2
+    _test_split((6, 2), 0, 3, 'float32')
+    _test_split((2, 6), 1, 3, 'float32')
+    # rank 3
+    _test_split((6, 2, 4), 0, 3, 'float32')
+    _test_split((2, 6, 4), 1, 3, 'float32')
+    _test_split((2, 4, 6), 2, 3, 'float32')
+    # rank 4
+    _test_split((6, 1, 3, 5), 0, 3, 'float32')
+    _test_split((1, 6, 3, 5), 1, 3, 'float32')
+    _test_split((1, 3, 6, 5), 2, 3, 'float32')
+    _test_split((1, 3, 5, 6), 3, 3, 'float32')
+    # split along negative axis
+    _test_split((6, 1, 3, 5), -4, 3, 'float32')
+    _test_split((1, 6, 3, 5), -3, 3, 'float32')
+    _test_split((1, 3, 6, 5), -2, 3, 'float32')
+    _test_split((1, 3, 5, 6), -1, 3, 'float32')
+
+
+#######################################################################
+# Split followed by concat
+# ------------------------
+
+def _test_split_concat(in_shape, axis, num_split, dtype):
+    """ One iteration of a split_concat pair"""
+
+    with tf.Graph().as_default():
+        in_data = tf.placeholder(dtype, in_shape, name="in_data")
+        splitted = tf.split(in_data, num_split, axis)
+        tf.concat(splitted, axis)
+        np_data = np.random.uniform(size=in_shape).astype(dtype)
+        compare_tf_with_tvm(np_data, 'in_data:0', 'concat:0')
+
+def test_forward_split_concat():
+    '''test split followed by concat layers'''
+    # rank 1
+    _test_split_concat((3,), 0, 1, 'float32')
+    _test_split_concat((3,), 0, 3, 'float32')
+    _test_split_concat((6,), 0, 3, 'float32')
+    # rank 2
+    _test_split_concat((6, 2), 0, 3, 'float32')
+    _test_split_concat((2, 6), 1, 3, 'float32')
+    # rank 3
+    _test_split_concat((6, 2, 4), 0, 3, 'float32')
+    _test_split_concat((2, 6, 4), 1, 3, 'float32')
+    _test_split_concat((2, 4, 6), 2, 3, 'float32')
+    # rank 4
+    _test_split((6, 1, 3, 5), 0, 3, 'float32')
+    _test_split((1, 6, 3, 5), 1, 3, 'float32')
+    _test_split((1, 3, 6, 5), 2, 3, 'float32')
+    _test_split((1, 3, 5, 6), 3, 3, 'float32')
+    # split along negative axis
+    _test_split((6, 1, 3, 5), -4, 3, 'float32')
+    _test_split((1, 6, 3, 5), -3, 3, 'float32')
+    _test_split((1, 3, 6, 5), -2, 3, 'float32')
+    _test_split((1, 3, 5, 6), -1, 3, 'float32')
+
+
 #######################################################################
 # Multi Input to graph
 # --------------------
@@ -1061,6 +1138,8 @@ if __name__ == '__main__':
     test_forward_pad()
     test_forward_gather()
     test_forward_stridedslice()
+    test_forward_split()
+    test_forward_split_concat()
 
     # Activations
     test_forward_sigmoid()