[TFLite] Implemented EXPAND_DIMS Operator for TFLite. (#6243)
authorRishabh Jain <56974688+jainris@users.noreply.github.com>
Tue, 11 Aug 2020 08:05:55 +0000 (13:35 +0530)
committerGitHub <noreply@github.com>
Tue, 11 Aug 2020 08:05:55 +0000 (16:05 +0800)
python/tvm/relay/frontend/tflite.py
tests/python/frontend/tflite/test_forward.py

index f168f1b..11d6576 100644 (file)
@@ -84,6 +84,7 @@ class OperatorConverter(object):
             'ELU': self.convert_elu,
             'EQUAL': self.convert_equal,
             'EXP': self.convert_exp,
+            'EXPAND_DIMS': self.convert_expand_dims,
             'FILL': self.convert_fill,
             'FLOOR_DIV': self.convert_floor_div,
             'FLOOR_MOD': self.convert_floor_mod,
@@ -2904,6 +2905,31 @@ class OperatorConverter(object):
         ret = _expr.TupleWrapper(_expr.Tuple([boxes, cls_ids, scores, valid_count]), size=4)
         return ret
 
+    def convert_expand_dims(self, op):
+        """Convert TFLite EXPAND_DIMS"""
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 2, "input tensors length should be 2"
+
+        if input_tensors[0].qnn_params:
+            # Check that input and output tensor have same qnn params.
+            output_tensors = self.get_output_tensors(op)
+            assert self.has_same_qnn_params(input_tensors[0], output_tensors[0]), \
+                "TFLite EXPAND_DIMS requires input and output tensors' \
+                    scale and zero points to be equal"
+
+        input_expr = self.get_tensor_expr(input_tensors[0])
+        axis = self.get_tensor_value(input_tensors[1])
+        if isinstance(axis, np.ndarray):
+            assert len(axis) == 1, "only one value is expected."
+            axis = int(axis)
+
+        ndims = len(input_tensors[0].tensor.ShapeAsNumpy())
+        assert (-1-ndims <= axis <= ndims), "axis out of range"
+
+        out = _op.expand_dims(input_expr, axis, 1)
+
+        return out
+
     def convert_one_hot(self, op):
         """Convert TFLite ONE_HOT"""
         try:
index 2e57175..33ac6d4 100644 (file)
@@ -2031,6 +2031,61 @@ def test_forward_padv2():
 
 
 #######################################################################
+# EXPAND_DIMS
+# -----------
+
+def _test_expand_dims(input_shape, input_type, axis, quantized=False):
+    """ One iteration of EXPAND_DIMS """
+    with tf.Graph().as_default():
+        axis= ops.convert_to_tensor(axis, dtype=axis.dtype)
+
+        if quantized:
+            # ignoring input_type as quantized requires uint8
+            input = np.random.uniform(0, 256, input_shape).astype('uint8')
+            in_input = tf.placeholder(dtype='float32', shape=input.shape, name="input")
+
+            input_range = {'q_input': (-100, 100)}
+            inq_input = tf.quantization.fake_quant_with_min_max_args(
+                in_input,
+                min=-100,
+                max=100,
+                name="q_input")
+
+            out = array_ops.expand_dims(inq_input, axis=axis)
+            out = tf.quantization.fake_quant_with_min_max_args(
+                out,
+                min=-100,
+                max=100,
+                name="out")
+
+            compare_tflite_with_tvm(
+                [input],
+                ["q_input"],
+                [inq_input],
+                [out],
+                quantized=True,
+                input_range=input_range)
+        else:
+            input = np.random.uniform(-100, 100, input_shape).astype(input_type)
+            in_input = tf.placeholder(dtype=input.dtype, shape=input.shape, name="input")
+
+            out = array_ops.expand_dims(in_input, axis=axis)
+
+            compare_tflite_with_tvm(
+                [input],
+                ["input"],
+                [in_input],
+                [out])
+
+def test_forward_expand_dims():
+    """ EXPAND_DIMS """
+    for quantized in [False, True]:
+        _test_expand_dims((6, 2, 7, 5), 'float32', np.int32(0), quantized=quantized)
+        _test_expand_dims((1, 2, 3), 'int32', np.int32(-2), quantized=quantized)
+        _test_expand_dims((2, 4, 5), 'float32', np.array([1], dtype=np.int32), quantized=quantized)
+
+
+#######################################################################
 # ONE_HOT
 # -------
 
@@ -3021,6 +3076,7 @@ if __name__ == '__main__':
     test_forward_select()
     test_forward_quantize_dequantize()
     test_forward_arg_min_max()
+    test_forward_expand_dims()
 
     # NN
     test_forward_convolution()