[Relay][Frontend][Tensorflow]Add conv2d_transpose (#4300)
authoroptima2005 <56945758+optima2005@users.noreply.github.com>
Mon, 18 Nov 2019 01:24:44 +0000 (09:24 +0800)
committerYizhi Liu <liuyizhi@apache.org>
Mon, 18 Nov 2019 01:24:44 +0000 (17:24 -0800)
* [Relay][Frontend][Tensorflow]Add conv2d_transpose

* add transformation from NHWC to NCHW to compatible with TVM conv2d_transpose implementation

* remove 'dilations' paramater to compitable with TF1.3

python/tvm/relay/frontend/tensorflow.py
tests/python/frontend/tensorflow/test_forward.py

index e44653f..6085062 100644 (file)
@@ -195,10 +195,24 @@ def _conv(opname):
         attr['data_format'] = attr['data_format'].decode("utf-8")
         flip_layout = False
 
+        if opname == 'conv_transpose' and attr['data_format'] == 'NHWC':
+            # transform to NCHW for TVM backend compatible and set 'flip_layout'
+            # to have output flip back to NHWC
+            tmp_shape = attr['_input_shapes'][inputs[2]]
+            tmp_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)]
+            inputs[2] = _op.transpose(inputs[2], axes=(0, 3, 1, 2))
+            attr['_input_shapes'][inputs[2]] = tmp_shape
+            attr['strides'][1], attr['strides'][2], attr['strides'][3] = \
+                attr['strides'][3], attr['strides'][1], attr['strides'][2]
+            attr['data_format'] = 'NCHW'
+            flip_layout = True
+
+        inputs_data = inputs[0] if opname != 'conv_transpose' else inputs[2]
+
         # NCHW Layout require weights transpose
         if attr['data_format'] == 'NCHW':
             tmp_shape = attr['_input_shapes'][inputs[1]]
-            if opname == 'conv':
+            if opname in ['conv', 'conv_transpose']:
                 tmp_shape = [tmp_shape[ii] for ii in (3, 2, 0, 1)]
                 inputs[1] = _op.transpose(inputs[1], axes=(3, 2, 0, 1))
             else:
@@ -206,13 +220,13 @@ def _conv(opname):
                 inputs[1] = _op.transpose(inputs[1], axes=(2, 3, 0, 1))
             attr['_input_shapes'][inputs[1]] = tmp_shape
 
-        input_shape = attr['_input_shapes'][inputs[0]]
+        input_shape = attr['_input_shapes'][inputs_data]
         weights_shape = attr['_input_shapes'][inputs[1]]
 
         if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
             input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
-            inputs[0] = _op.transpose(inputs[0], axes=(0, 3, 1, 2))
-            if opname == 'conv':
+            inputs_data = _op.transpose(inputs_data, axes=(0, 3, 1, 2))
+            if opname in ['conv', 'conv_transpose']:
                 weights_shape = [weights_shape[ii] for ii in (3, 2, 0, 1)]
                 inputs[1] = _op.transpose(inputs[1], axes=(3, 2, 0, 1))
             else:
@@ -228,6 +242,8 @@ def _conv(opname):
             attr['kernel_shape'] = (weights_shape[0], weights_shape[1])
             if opname == 'conv':
                 attr['channels'] = weights_shape[3]
+            elif opname == 'conv_transpose':
+                attr['channels'] = weights_shape[2]
             else:
                 attr['channels'] = input_shape[3] * depth_mult
 
@@ -239,6 +255,8 @@ def _conv(opname):
             attr['kernel_shape'] = (weights_shape[2], weights_shape[3])
             if opname == 'conv':
                 attr['channels'] = weights_shape[0]
+            elif opname == 'conv_transpose':
+                attr['channels'] = weights_shape[1]
             else:
                 attr['channels'] = input_shape[1] * depth_mult
                 if attr['channels'] < 0:
@@ -279,17 +297,17 @@ def _conv(opname):
 
 
             if attr['data_format'] == 'NHWC':
-                inputs[0] = _op.nn.pad(data=inputs[0],
-                                       pad_width=((0, 0),
-                                                  (pad_v[0], pad_v[1]),
-                                                  (pad_h[0], pad_h[1]),
-                                                  (0, 0)))
+                inputs_data = _op.nn.pad(data=inputs_data,
+                                         pad_width=((0, 0),
+                                                    (pad_v[0], pad_v[1]),
+                                                    (pad_h[0], pad_h[1]),
+                                                    (0, 0)))
             else:
-                inputs[0] = _op.nn.pad(data=inputs[0],
-                                       pad_width=((0, 0),
-                                                  (0, 0),
-                                                  (pad_v[0], pad_v[1]),
-                                                  (pad_h[0], pad_h[1])))
+                inputs_data = _op.nn.pad(data=inputs_data,
+                                         pad_width=((0, 0),
+                                                    (0, 0),
+                                                    (pad_v[0], pad_v[1]),
+                                                    (pad_h[0], pad_h[1])))
 
             attr['padding'] = [0, 0]
 
@@ -299,27 +317,30 @@ def _conv(opname):
             raise tvm.error.OpAttributeInvalid(msg.format(attr['padding']))
 
         if 'kernel_layout' not in attr:
-            if opname == 'conv':
+            if opname in ['conv', 'conv_transpose']:
                 attr['kernel_layout'] = 'HWIO' if attr['data_format'] == 'NHWC' else 'OIHW'
             else:
                 attr['kernel_layout'] = 'HWOI' if attr['data_format'] == 'NHWC' else 'OIHW'
 
-        use_bias = len(inputs) == 3
+        use_bias = len(inputs) == (3 if opname != 'conv_transpose' else 4)
         channel_axis = 1 if attr['data_format'] == "NCHW" else 3
 
         # Ignore the new attributes from TF2.0, for now.
         out = AttrCvt(
-            op_name=_dimension_picker('conv'),
+            op_name=_dimension_picker('conv', \
+                surfix="_transpose" if opname == 'conv_transpose' else ""),
             ignores=['explicit_paddings'],
             transforms={
                 'kernel_shape': 'kernel_size',
                 'data_format': 'data_layout',
                 'dilations': ('dilation', (0, 0)),
                 'group': ('groups', 1)},
-            custom_check=_dimension_constraint())([inputs[0], inputs[1]], attr)
+            custom_check=_dimension_constraint())([inputs_data, inputs[1]], attr)
 
         if use_bias:
-            out = _op.nn.bias_add(out, inputs[2], axis=channel_axis)
+            out = _op.nn.bias_add(out,
+                                  inputs[2] if opname != 'conv_transpose' else inputs[3],
+                                  axis=channel_axis)
 
         if flip_layout:
             out = _op.transpose(out, axes=(0, 2, 3, 1))
@@ -1403,6 +1424,7 @@ _convert_map = {
     'Concat'                            : _concat(),
     'ConcatV2'                          : _concatV2(),
     'Conv2D'                            : _conv('conv'),
+    'Conv2DBackpropInput'               : _conv('conv_transpose'),
     'CropAndResize'                     : _crop_and_resize(),
     'DecodeJpeg'                        : _decode_image(),
     'DepthwiseConv2dNative'             : _conv('depthwise'),
index e02532f..b9d84dc 100644 (file)
@@ -295,7 +295,8 @@ def test_forward_pooling():
 
 
 def _test_convolution(opname, tensor_in_sizes, filter_in_sizes,
-                      dilations, strides, padding, data_format):
+                      dilations, strides, padding, data_format,
+                      deconv_output_shape=[]):
     """ One iteration of convolution with given shapes and attributes """
 
     total_size_1 = np.prod(tensor_in_sizes)
@@ -326,6 +327,16 @@ def _test_convolution(opname, tensor_in_sizes, filter_in_sizes,
 
             compare_tf_with_tvm(np.reshape(data_array, tensor_in_sizes).astype('float32'),
                                 'Placeholder:0', 'Conv2D:0')
+        elif opname == 'conv_transpose':
+            nn_ops.conv2d_transpose(in_data,
+                                    in_filter,
+                                    output_shape=deconv_output_shape,
+                                    strides=strides,
+                                    padding=padding,
+                                    data_format=data_format)
+
+            compare_tf_with_tvm(np.reshape(data_array, tensor_in_sizes).astype('float32'),
+                                'Placeholder:0', 'conv2d_transpose:0')
         else:
             nn_ops.depthwise_conv2d_native(in_data,
                                            in_filter,
@@ -349,6 +360,14 @@ def test_forward_convolution():
         _test_convolution('depthwise', [4, 124, 17, 17], [1, 1, 124, 1], [1, 1], [1, 1], 'SAME', 'NCHW')
         _test_convolution('depthwise', [4, 12, 17, 17], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID', 'NCHW')
         _test_convolution('depthwise', [4, 12, 17, 17], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NCHW')
+        _test_convolution('conv_transpose', [4, 32, 8, 8], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME',
+                          'NCHW', [4, 176, 8, 8])
+        _test_convolution('conv_transpose', [4, 19, 8, 8], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID',
+                          'NCHW', [4, 19, 17, 17])
+        _test_convolution('conv_transpose', [4, 19, 17, 17], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME',
+                          'NCHW', [4, 124, 17, 17])
+        _test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID',
+                          'NCHW', [4, 12, 17, 17])
 
     _test_convolution('conv', [4, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NHWC')
     _test_convolution('conv', [4, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NHWC')
@@ -359,6 +378,15 @@ def test_forward_convolution():
     _test_convolution('depthwise', [4, 17, 17, 124], [1, 1, 124, 1], [1, 1], [1, 1], 'SAME', 'NHWC')
     _test_convolution('depthwise', [4, 17, 17, 12], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID', 'NHWC')
     _test_convolution('depthwise', [4, 17, 17, 12], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NHWC')
+    _test_convolution('conv_transpose', [4, 8, 8, 32], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME',
+                      'NHWC', [4, 8, 8, 176])
+    _test_convolution('conv_transpose', [4, 8, 8, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID',
+                      'NHWC', [4, 17, 17, 19])
+    _test_convolution('conv_transpose', [4, 17, 17, 19], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME',
+                      'NHWC', [4, 17, 17, 124])
+    _test_convolution('conv_transpose', [4, 8, 8, 32], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID',
+                      'NHWC', [4, 17, 17, 12])
+
 
 #######################################################################
 # BiasAdd