[Frontend][TFLite] Add parser support for shape and range (#5329)
authorDhruva Ray <dhruvaray@gmail.com>
Thu, 4 Jun 2020 17:55:38 +0000 (23:25 +0530)
committerGitHub <noreply@github.com>
Thu, 4 Jun 2020 17:55:38 +0000 (10:55 -0700)
* [Relay][Frontend][TFLite] Add parser support for shape and range

Signed-off-by: Dhruva Ray <dhruvaray@gmail.com>
* Incorporated review comments and used new functions

Signed-off-by: Dhruva Ray <dhruvaray@gmail.com>
* Few cosmetic changes

Signed-off-by: Dhruva Ray <dhruvaray@gmail.com>
* Removed an extra line added by rebase...

Signed-off-by: Dhruva Ray <dhruvaray@gmail.com>
python/tvm/relay/frontend/tflite.py
tests/python/frontend/tflite/test_forward.py

index 15d0253..08ea715 100644 (file)
@@ -114,6 +114,7 @@ class OperatorConverter(object):
             'PAD': self.convert_pad,
             'POW': self.convert_pow,
             'PRELU': self.convert_prelu,
+            'RANGE': self.convert_range,
             'QUANTIZE': self.convert_quantize,
             'REDUCE_ANY': self.convert_reduce_any,
             'REDUCE_MAX': self.convert_reduce_max,
@@ -126,6 +127,7 @@ class OperatorConverter(object):
             'ROUND': self.convert_round,
             'RSQRT': self.convert_rsqrt,
             'SELECT': self.convert_select,
+            'SHAPE': self.convert_shape,
             'SIN': self.convert_sin,
             'SLICE': self.convert_slice,
             'SOFTMAX': self.convert_softmax,
@@ -609,6 +611,39 @@ class OperatorConverter(object):
 
         return out
 
+    def convert_range(self, op):
+        """Convert TFLite Range"""
+        try:
+            from tflite.TensorType import TensorType
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 3, "input tensors length should be 3"
+
+        start, limit, delta = input_tensors[0], input_tensors[1], input_tensors[2]
+
+        expressions = [self.get_tensor_expr(t) for t in [start, limit, delta]]
+
+        # out type inference
+        if delta.tensor.Type() == TensorType.FLOAT32:
+            out_type = self.get_tensor_type_str(delta.tensor.Type())
+        else:
+            out_type = self.get_tensor_type_str(start.tensor.Type())
+
+        out = _op.arange(expressions[0], expressions[1], expressions[2], out_type)
+
+        return out
+
+    def convert_shape(self, op):
+        """Convert TFLite Shape"""
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 1, "input tensors length should be 1"
+
+        out = _op.shape_of(self.get_tensor_expr(input_tensors[0]))
+
+        return out
+
     def convert_relu(self, op):
         """Convert TFLite ReLU"""
         input_tensors = self.get_input_tensors(op)
index d5dafd8..2951540 100644 (file)
@@ -83,8 +83,34 @@ def get_real_image_object_detection(im_height, im_width):
     data = np.reshape(x, (1, im_height, im_width, 3))
     return data
 
+def vmobj_to_list(o):
+    if isinstance(o, tvm.nd.NDArray):
+        return [o.asnumpy().tolist()]
+    elif isinstance(o, tvm.runtime.container.ADT):
+        result = []
+        for f in o:
+            result.extend(vmobj_to_list(f))
+        return result
+    elif isinstance(o, tvm.relay.backend.interpreter.ConstructorValue):
+        if o.constructor.name_hint == 'Cons':
+            tl = vmobj_to_list(o.fields[1])
+            hd = vmobj_to_list(o.fields[0])
+            hd.extend(tl)
+            return hd
+        elif o.constructor.name_hint == 'Nil':
+            return []
+        elif 'tensor_nil' in o.constructor.name_hint:
+            return [0]
+        elif 'tensor' in o.constructor.name_hint:
+            return [o.fields[0].asnumpy()]
+        else:
+            raise RuntimeError("Unknown object type: %s" %
+                               o.constructor.name_hint)
+    else:
+        raise RuntimeError("Unknown object type: %s" % type(o))
+
 def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target='llvm',
-                  out_names=None):
+                  out_names=None, mode='graph_runtime'):
     """ Generic function to compile on relay and execute on tvm """
     # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1
     try:
@@ -109,27 +135,43 @@ def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target
                                              shape_dict=shape_dict,
                                              dtype_dict=dtype_dict)
 
-    with tvm.transform.PassContext(opt_level=3):
-        graph, lib, params = relay.build(mod, target, params=params)
-
-    ctx = tvm.context(target, 0)
-    from tvm.contrib import graph_runtime
-    m = graph_runtime.create(graph, lib, ctx)
-    # set inputs
-    for i, e in enumerate(input_node):
-        m.set_input(e, tvm.nd.array(input_data[i].astype(input_data[i].dtype)))
-
-    m.set_input(**params)
-    # execute
-    m.run()
-    # get outputs
-    assert out_names is None or num_output == len(out_names), "out_names: {} num_output: {}".format(
-        out_names, num_output)
-    tvm_output_list = []
-    for i in range(0, num_output):
-        tvm_output = m.get_output(i)
-        tvm_output_list.append(tvm_output.asnumpy())
-    return tvm_output_list
+    if mode in ['debug', 'vm']:
+        ex = relay.create_executor(mode, mod=mod, ctx=tvm.cpu(), target="llvm")
+        inputs = []
+        for param in mod['main'].params:
+            found = False
+            for i, n in enumerate(input_node):
+                if n == param.name_hint:
+                    found = True
+                    inputs.append(tvm.nd.array(input_data[i]))
+                    break
+            # Interpreter doesn't bind constants, so still need to find in params
+            if not found:
+                inputs.append(tvm.nd.array(params[param.name_hint]))
+        result = ex.evaluate()(*inputs)
+        return vmobj_to_list(result)
+    else:
+        with tvm.transform.PassContext(opt_level=3):
+            graph, lib, params = relay.build(mod, target, params=params)
+
+        ctx = tvm.context(target, 0)
+        from tvm.contrib import graph_runtime
+        m = graph_runtime.create(graph, lib, ctx)
+        # set inputs
+        for i, e in enumerate(input_node):
+            m.set_input(e, tvm.nd.array(input_data[i].astype(input_data[i].dtype)))
+
+        m.set_input(**params)
+        # execute
+        m.run()
+        # get outputs
+        assert out_names is None or num_output == len(out_names), "out_names: {} num_output: {}".format(
+            out_names, num_output)
+        tvm_output_list = []
+        for i in range(0, num_output):
+            tvm_output = m.get_output(i)
+            tvm_output_list.append(tvm_output.asnumpy())
+        return tvm_output_list
 
 
 def run_tflite_graph(tflite_model_buf, input_data):
@@ -160,7 +202,7 @@ def run_tflite_graph(tflite_model_buf, input_data):
 
 def compare_tflite_with_tvm(in_data, in_name, input_tensors,
                             output_tensors, init_global_variables=False,
-                            out_names=None, quantized=False, input_range=None):
+                            out_names=None, quantized=False, input_range=None, mode='graph_runtime'):
     """Generic function to generate and compare TFLite and TVM output"""
     in_data = convert_to_list(in_data)
     in_name = convert_to_list(in_name)
@@ -202,7 +244,7 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors,
                 continue
 
             tvm_output = run_tvm_graph(tflite_model_buffer, in_data, in_node, target=device,
-                                       num_output=len(out_names), out_names=out_names)
+                                       num_output=len(out_names), out_names=out_names, mode=mode)
 
             # WARNING: the results could well be random values clipped to 0 or 255 because of badly tuned output
             # range for the specific operator. While adding test ensure that we aren't getting only clipped values
@@ -860,6 +902,80 @@ def test_all_resize():
         _test_resize(tf.image.resize_nearest_neighbor, data, align_corners=False)
 
 #######################################################################
+# Range
+# -----
+def _test_range(start, limit, delta):
+    # tflite 1.13 convert method does not accept empty shapes
+    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
+        tf.reset_default_graph()
+        with tf.Graph().as_default():
+            start_scalar, limit_scalar, delta_scalar = \
+                tf.placeholder(dtype=start.dtype, shape=(), name="start"), \
+                tf.placeholder(dtype=limit.dtype, shape=(), name="limit"), \
+                tf.placeholder(dtype=delta.dtype, shape=(), name="delta")
+
+            out = tf.range(start_scalar, limit_scalar, delta_scalar, name="range")
+
+            compare_tflite_with_tvm(
+                [start, limit, delta],
+                ["start", "limit", "delta"],
+                [start_scalar, limit_scalar, delta_scalar],
+                [out],
+                mode="vm",
+                quantized=False
+        )
+
+def _test_range_default():
+    # tflite 1.13 convert method does not accept empty shapes
+    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
+        tf.reset_default_graph()
+        with tf.Graph().as_default():
+            inputs = [
+                tf.placeholder(dtype=tf.int32, shape=(), name="p1"),
+                tf.placeholder(dtype=tf.int32, shape=(), name="p2")
+            ]
+            outputs = [
+                tf.range(start = inputs[0], limit = inputs[1]), # use default delta
+                tf.range(start = inputs[1]) # use start as limit with 0 as the first item in the range
+            ]
+
+            compare_tflite_with_tvm(
+                [np.int32(1), np.int32(18)],
+                ["p1", "p2"],
+                inputs,
+                outputs,
+                mode="vm"
+        )
+
+def test_forward_range():
+   _test_range(np.int32(1), np.int32(18), np.int32(3))
+   _test_range(np.int32(1), np.int32(18), np.float32(3.1)) # increment is of type float
+   _test_range(np.float32(1.0), np.int32(18), np.int32(3.1)) # start is of type float
+   _test_range_default()
+
+#######################################################################
+# Shape
+# -----
+def test_forward_shape():
+    # tflite 1.13 convert method does not accept empty shapes
+    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
+        tf.reset_default_graph()
+        with tf.Graph().as_default():
+            data = np.array([1, 18, 3], dtype=np.int32)
+            start = tf.placeholder(dtype=tf.int32, shape=[], name="start")
+            limit = tf.placeholder(dtype=tf.int32, shape=[], name="limit")
+            delta = tf.placeholder(dtype=tf.int32, shape=[], name="delta")
+            r = tf.range(start, limit, delta, tf.int32, name="range")
+            out = tf.shape(r, out_type=tf.dtypes.int32)
+            compare_tflite_with_tvm(
+                [x for x in np.nditer(data)],
+                ["start", "limit", "delta"],
+                [start, limit, delta],
+                [out],
+                mode="vm"
+            )
+
+#######################################################################
 # Concatenation
 # -------------
 
@@ -2363,6 +2479,9 @@ if __name__ == '__main__':
     # Tile
     test_forward_tile()
 
+    # Query
+    test_forward_shape()
+
     # Transforms
     test_forward_concatenation()
     test_forward_pad()
@@ -2370,6 +2489,7 @@ if __name__ == '__main__':
     test_forward_unpack()
     test_forward_reshape()
     test_all_resize()
+    test_forward_range()
     test_forward_squeeze()
     test_forward_slice()
     test_forward_topk()