[TFLite] Add transpose_conv to TFLite parser (#4440)
authorAlexander Pivovarov <pivovaa@amazon.com>
Sun, 1 Dec 2019 15:41:50 +0000 (07:41 -0800)
committerTianqi Chen <tqchen@users.noreply.github.com>
Sun, 1 Dec 2019 15:41:50 +0000 (07:41 -0800)
python/tvm/relay/frontend/tflite.py
tests/python/frontend/tflite/test_forward.py

index e2dc0e7..7abbd1e 100644 (file)
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name, unused-argument
+# pylint: disable=invalid-name, unused-argument, too-many-lines
 """Tensorflow lite frontend."""
 from __future__ import absolute_import as _abs
 import math
@@ -96,6 +96,7 @@ class OperatorConverter(object):
             'BATCH_TO_SPACE_ND': self.convert_batch_to_space_nd,
             'SPACE_TO_BATCH_ND': self.convert_space_to_batch_nd,
             'PRELU': self.convert_prelu,
+            'TRANSPOSE_CONV': self.convert_transpose_conv,
         }
 
     def check_unsupported_ops(self):
@@ -1370,6 +1371,84 @@ class OperatorConverter(object):
 
         return out
 
+    def convert_transpose_conv(self, op):
+        """Convert TFLite TRANSPOSE_CONV"""
+        try:
+            from tflite.BuiltinOptions import BuiltinOptions
+            from tflite.TensorType import TensorType
+            from tflite.Operator import Operator
+            from tflite.TransposeConvOptions import TransposeConvOptions
+            from tflite.Padding import Padding
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        assert isinstance(op, Operator)
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 3, "input tensors length should be 3"
+
+        # Input (data) Tensor. NHWC layout
+        input_tensor = input_tensors[2]
+        _, _, _, input_c = input_tensor.tensor.ShapeAsNumpy()
+        # Weights tensor. TFLite uses OHWI layout
+        weights_tensor = input_tensors[1]
+        out_channels, kernel_h, kernel_w, in_channels = weights_tensor.tensor.ShapeAsNumpy()
+        assert input_c == in_channels, \
+            "Input channel in the filter should match to channel in the input"
+        # output_shape Tensor. NHWC layout
+        output_shape_tensor = input_tensors[0]
+
+        output_tensors = self.get_output_tensors(op)
+        assert len(output_tensors) == 1, "output tensors length should be 1"
+        output_tensor = output_tensors[0]
+        output_tensor_type = output_tensor.tensor.Type()
+        output_tensor_type_str = self.get_tensor_type_str(output_tensor_type)
+
+        assert op.BuiltinOptionsType() == BuiltinOptions.TransposeConvOptions
+        op_options = op.BuiltinOptions()
+        deconv_options = TransposeConvOptions()
+        deconv_options.Init(op_options.Bytes, op_options.Pos)
+
+        padding = deconv_options.Padding()
+        stride_h = deconv_options.StrideH()
+        stride_w = deconv_options.StrideW()
+        assert padding in (Padding.VALID, Padding.SAME), \
+            'Padding format {} is not supported for operator TRANSPOSE_CONV'.format(padding)
+
+        # Data
+        in_expr = self.get_expr(input_tensor.tensor_idx)
+
+        # Weights
+        weights_tensor_type = weights_tensor.tensor.Type()
+        # weights tensor type should be UINT8 (quantization) or FLOAT32
+        assert weights_tensor_type in (TensorType.UINT8, TensorType.FLOAT32)
+        weight_tensor_type_str = self.get_tensor_type_str(weights_tensor_type)
+        weight_value_ohwi = self.get_tensor_value(weights_tensor)
+        # Relay kernel_layout should be OIHW
+        # Relay weights layout should be different from kernel_layout - it should be IOHW
+        weight_value_iohw = np.transpose(weight_value_ohwi, (3, 0, 1, 2))
+        weight_expr_iohw = self.exp_tab.new_const(weight_value_iohw, dtype=weight_tensor_type_str)
+
+        # Output shape value
+        output_shape_value = self.get_tensor_value(output_shape_tensor)
+        # Relay expects filter output channel to match to output tensor channel.
+        assert out_channels == output_shape_value[3], \
+            "Output channel in the filter should match to channel in the output_shape"
+
+        # TF frontend supports 'SAME' padding for kernel 1x1 only. Lets do the same here
+        if padding == Padding.SAME:
+            assert (kernel_h, kernel_w) == (1, 1), \
+                "SAME padding is supported for kernel (1,1) only"
+
+        out = _op.nn.conv2d_transpose(in_expr, weight_expr_iohw,
+                                      strides=(stride_h, stride_w),
+                                      channels=int(out_channels),
+                                      kernel_size=(int(kernel_h), int(kernel_w)),
+                                      data_layout="NHWC",
+                                      kernel_layout="OIHW",
+                                      out_dtype=output_tensor_type_str)
+
+        return out
+
     def get_expr(self, input_tensor_idx):
         return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx))
 
index ad7989f..e1926f8 100644 (file)
@@ -478,6 +478,60 @@ def test_forward_convolution():
 
 
 #######################################################################
+# Transpose Convolution
+# ---------------------
+
+def _test_transpose_conv(tensor_in_sizes, filter_in_sizes, output_shape, strides, padding):
+    """ One iteration of transpose convolution with given shapes and attributes """
+
+    total_size_1 = 1
+    total_size_2 = 1
+    for s in tensor_in_sizes:
+        total_size_1 *= s
+    for s in filter_in_sizes:
+        total_size_2 *= s
+    # Initializes the input tensor with array containing incrementing
+    # numbers from 1.
+    data_array = [f * 1.0 for f in range(1, total_size_1 + 1)]
+    filter_array = [f * 1.0 for f in range(1, total_size_2 + 1)]
+
+    with tf.Graph().as_default():
+        in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32')
+        in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype='float32')
+        strides = [1] + strides + [1]
+        # in_filter layout is HWOI
+        out = nn_ops.conv2d_transpose(in_data,
+                                      in_filter,
+                                      output_shape=output_shape,
+                                      strides=strides,
+                                      padding=padding)
+        data_array = np.reshape(data_array, tensor_in_sizes).astype('float32')
+        compare_tflite_with_tvm(data_array, 'Placeholder:0', [in_data], [out])
+
+
+def test_forward_transpose_conv():
+    # kernel 3x3, padding VALID
+    _test_transpose_conv([4, 32, 32, 16], [3, 3, 5, 16], [4, 34, 34, 5], [1, 1], 'VALID')
+    _test_transpose_conv([1, 32, 32, 16], [3, 3, 5, 16], [1, 65, 65, 5], [2, 2], 'VALID')
+    _test_transpose_conv([1, 32, 32, 16], [3, 3, 5, 16], [1, 65, 34, 5], [2, 1], 'VALID')
+
+    # kernel 2x2, padding VALID
+    _test_transpose_conv([4, 32, 32, 16], [2, 2, 5, 16], [4, 33, 33, 5], [1, 1], 'VALID')
+    _test_transpose_conv([1, 32, 32, 16], [2, 2, 5, 16], [1, 64, 64, 5], [2, 2], 'VALID')
+    _test_transpose_conv([1, 32, 32, 16], [2, 2, 5, 16], [1, 64, 33, 5], [2, 1], 'VALID')
+
+    # kernel 1x1, padding VALID
+    _test_transpose_conv([4, 32, 32, 16], [1, 1, 5, 16], [4, 32, 32, 5], [1, 1], 'VALID')
+    _test_transpose_conv([1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 63, 5], [2, 2], 'VALID')
+    _test_transpose_conv([1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 32, 5], [2, 1], 'VALID')
+
+    # kernel 1x1, padding SAME
+    _test_transpose_conv([4, 32, 32, 16], [1, 1, 5, 16], [4, 32, 32, 5], [1, 1], 'SAME')
+    _test_transpose_conv([1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 63, 5], [2, 2], 'SAME')
+    _test_transpose_conv([1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 32, 5], [2, 1], 'SAME')
+
+
+#######################################################################
 # Reshape
 # -------
 
@@ -1232,6 +1286,7 @@ if __name__ == '__main__':
 
     # NN
     test_forward_convolution()
+    test_forward_transpose_conv()
     test_forward_logistic()
     test_forward_pooling()
     test_forward_softmax()