[Relay-TFLite] FP32 and Quantized Object Detection Model (#5479)
authorAnimesh Jain <anijain@umich.edu>
Fri, 8 May 2020 03:18:18 +0000 (20:18 -0700)
committerGitHub <noreply@github.com>
Fri, 8 May 2020 03:18:18 +0000 (11:18 +0800)
* TFlite e2e FP32 Object detection model

* Fix test

* [Relay-TFLite] Quantized activations

* Flexbuffer parsing

* Lint

* Relaxing checks.

* Github reviews

* comments

Co-authored-by: Ubuntu <ubuntu@ip-172-31-34-212.us-west-2.compute.internal>
python/tvm/relay/frontend/tflite.py
python/tvm/relay/frontend/tflite_flexbuffer.py [new file with mode: 0644]
python/tvm/relay/testing/tf.py
tests/python/frontend/tflite/test_forward.py

index ab0eabc..5a645c6 100644 (file)
@@ -31,6 +31,7 @@ from .. import qnn as _qnn
 from ... import nd as _nd
 from .common import ExprTable
 from .common import infer_shape as _infer_shape
+from .tflite_flexbuffer import FlexBufferDecoder
 
 __all__ = ['from_tflite']
 
@@ -323,6 +324,45 @@ class OperatorConverter(object):
                                          input_zero_point=tensor.qnn_params['zero_point'])
         return dequantized
 
+
+    def convert_qnn_fused_activation_function(self, expr, fused_activation_fn,
+                                              scale, zero_point, dtype):
+        """Convert TFLite fused activation function. The expr is an input quantized tensor with
+        scale and zero point """
+        try:
+            from tflite.ActivationFunctionType import ActivationFunctionType
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        # Quantize a float value to a quantized integer value
+        quantize = lambda x: float(int(round(x / scale)) + zero_point)
+
+        # Get min/max of the output dtype. This is used to ensure that the clip a_min/a_max values
+        # do not fall outside the dtype range.
+        qmin = float(tvm.tir.op.min_value(dtype).value)
+        qmax = float(tvm.tir.op.max_value(dtype).value)
+
+        # The input expr is a quantized tensor with its scale and zero point. We calculate
+        # suitable clipping thresholds based on that scale and zero point.
+        if fused_activation_fn == ActivationFunctionType.NONE:
+            return expr
+        if fused_activation_fn == ActivationFunctionType.RELU6:
+            return _op.clip(expr,
+                            a_min=max(qmin, quantize(0)),
+                            a_max=min(qmax, quantize(6.0)))
+        if fused_activation_fn == ActivationFunctionType.RELU_N1_TO_1:
+            return _op.clip(expr,
+                            a_min=max(qmin, quantize(-1.0)),
+                            a_max=min(qmax, quantize(1.0)))
+        if fused_activation_fn == ActivationFunctionType.RELU:
+            return _op.clip(expr,
+                            a_min=max(qmin, quantize(0.0)),
+                            a_max=qmax)
+
+        fused_activation_fn_str = self.activation_fn_type[fused_activation_fn]
+        raise tvm.error.OpNotImplemented(
+            'Quantized activation {} is not supported yet.'.format(fused_activation_fn_str))
+
     def convert_conv2d(self, op):
         """Convert TFLite conv2d"""
         return self.convert_conv(op, "conv2d")
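To make the clipping arithmetic above concrete, here is a standalone sketch with made-up quantization parameters: for a uint8 tensor with scale 0.1 and zero point 0, RELU6 clips the quantized values to [0, 60].

    # Standalone sketch of the RELU6 clip bounds computed by
    # convert_qnn_fused_activation_function; scale/zero_point are hypothetical.
    scale, zero_point = 0.1, 0
    quantize = lambda x: float(int(round(x / scale)) + zero_point)
    qmin, qmax = 0.0, 255.0                  # uint8 dtype range
    a_min = max(qmin, quantize(0.0))         # lower clip bound -> 0.0
    a_max = min(qmax, quantize(6.0))         # upper clip bound -> 60.0
    assert (a_min, a_max) == (0.0, 60.0)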
@@ -431,7 +471,6 @@ class OperatorConverter(object):
         try:
             from tflite.BuiltinOptions import BuiltinOptions
             from tflite.L2NormOptions import L2NormOptions
-            from tflite.ActivationFunctionType import ActivationFunctionType
         except ImportError:
             raise ImportError("The tflite package must be installed")
 
@@ -456,17 +495,15 @@ class OperatorConverter(object):
         if self.is_quantized(op):
             raise tvm.error.OpNotImplemented(
                 'TFLite quantized L2_NORMALIZATION operator is not supported yet.')
+
         # TFL uses only the default epsilon value
         out = _op.nn.l2_normalize(in_expr, eps=1e-12, axis=[input_tensor_rank - 1])
 
         # if we have fused activation fn
-        if fused_activation_fn != ActivationFunctionType.NONE:
-            if not output_tensor.qnn_params:
-                out = self.convert_fused_activation_function(out, fused_activation_fn)
-            else:
-                raise tvm.error.OpNotImplemented(
-                    'TFLite quantized L2_NORMALIZATION operator\
-                    with fused activation function is not supported yet.')
+        if output_tensor.qnn_params:
+            raise tvm.error.OpNotImplemented(
+                'TFLite quantized L2_NORMALIZATION operator is not supported yet.')
+        out = self.convert_fused_activation_function(out, fused_activation_fn)
 
         return out
 
@@ -611,7 +648,6 @@ class OperatorConverter(object):
         try:
             from tflite.ConcatenationOptions import ConcatenationOptions
             from tflite.BuiltinOptions import BuiltinOptions
-            from tflite.ActivationFunctionType import ActivationFunctionType
         except ImportError:
             raise ImportError("The tflite package must be installed")
 
@@ -643,14 +679,20 @@ class OperatorConverter(object):
                                       output_zero_point=output_tensor.qnn_params['zero_point'],
                                       axis=concatenation_axis)
 
-        # if we have activation fn
-        if fused_activation_fn != ActivationFunctionType.NONE:
-            if not output_tensor.qnn_params:
-                out = self.convert_fused_activation_function(out, fused_activation_fn)
-            else:
-                raise tvm.error.OpNotImplemented(
-                    'Operator {} with fused activation is not supported yet.'
-                    .format('qnn.op.concatenate'))
+        # Handle fused activations
+        if output_tensor.qnn_params:
+            scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale'])
+            zero_point_val = get_scalar_from_constant(output_tensor.qnn_params['zero_point'])
+            output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())
+            out = self.convert_qnn_fused_activation_function(\
+                    expr=out,
+                    fused_activation_fn=fused_activation_fn,
+                    scale=scale_val,
+                    zero_point=zero_point_val,
+                    dtype=output_tensor_type_str)
+        else:
+            out = self.convert_fused_activation_function(out, fused_activation_fn)
+
         return out
 
     def _convert_unary_elemwise(self, relay_op, op):
@@ -793,7 +835,6 @@ class OperatorConverter(object):
             from tflite.MulOptions import MulOptions
             from tflite.DivOptions import DivOptions
             from tflite.BuiltinOptions import BuiltinOptions
-            from tflite.ActivationFunctionType import ActivationFunctionType
         except ImportError:
             raise ImportError("The tflite package must be installed")
 
@@ -839,13 +880,20 @@ class OperatorConverter(object):
             op_options = op.BuiltinOptions()
             options.Init(op_options.Bytes, op_options.Pos)
             fused_activation_fn = options.FusedActivationFunction()
-            # if we have activation fn
-            if fused_activation_fn != ActivationFunctionType.NONE:
-                if output_tensor.qnn_params:
-                    raise tvm.error.OpNotImplemented(
-                        'Elemwise operators with fused activation are not supported yet.')
-                out = self.convert_fused_activation_function(out, fused_activation_fn)
 
+            # Handle fused activations
+            if output_tensor.qnn_params:
+                scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale'])
+                zero_point_val = get_scalar_from_constant(output_tensor.qnn_params['zero_point'])
+                output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())
+                out = self.convert_qnn_fused_activation_function(\
+                        expr=out,
+                        fused_activation_fn=fused_activation_fn,
+                        scale=scale_val,
+                        zero_point=zero_point_val,
+                        dtype=output_tensor_type_str)
+            else:
+                out = self.convert_fused_activation_function(out, fused_activation_fn)
         return out
 
     def convert_add(self, op):
@@ -1307,7 +1355,6 @@ class OperatorConverter(object):
             from tflite.FullyConnectedOptions import FullyConnectedOptions
             from tflite.BuiltinOptions import BuiltinOptions
             from tflite.TensorType import TensorType
-            from tflite.ActivationFunctionType import ActivationFunctionType
         except ImportError:
             raise ImportError("The tflite package must be installed")
 
@@ -1389,15 +1436,6 @@ class OperatorConverter(object):
                                                dtype=bias_tensor_type_str)
             out = _op.nn.bias_add(out, bias_expr)
 
-        # If we have fused activations
-        if fused_activation_fn != ActivationFunctionType.NONE:
-            if not output_tensor.qnn_params:
-                out = self.convert_fused_activation_function(out, fused_activation_fn)
-            else:
-                raise tvm.error.OpNotImplemented(
-                    'Operator {} with fused activation is not supported yet.'
-                    .format('qnn.op.dense'))
-
         # Finally if the dense is quantized. Add a requantize at the end.
         if output_tensor.qnn_params:
             data_scale = input_tensor.qnn_params['scale']
@@ -1407,6 +1445,8 @@ class OperatorConverter(object):
             new_input_scale_val = data_scale_val * weight_scale_val
             new_input_scale = relay.const(new_input_scale_val, 'float32')
             new_input_zero_point = relay.const(0, 'int32')
+
+            # Requantize
             out = _qnn.op.requantize(out,
                                      input_scale=new_input_scale,
                                      input_zero_point=new_input_zero_point,
@@ -1414,6 +1454,19 @@ class OperatorConverter(object):
                                      output_zero_point=output_tensor.qnn_params['zero_point'],
                                      out_dtype=output_tensor_type_str)
 
+            # Call activation function
+            output_scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale'])
+            output_zero_point_val = get_scalar_from_constant(output_tensor.qnn_params['zero_point'])
+            out = self.convert_qnn_fused_activation_function(\
+                    expr=out,
+                    fused_activation_fn=fused_activation_fn,
+                    scale=output_scale_val,
+                    zero_point=output_zero_point_val,
+                    dtype=output_tensor_type_str)
+
+        else:
+            out = self.convert_fused_activation_function(out, fused_activation_fn)
+
         return out
 
     def convert_squeeze(self, op):
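The scale algebra behind this requantize: a quantized dense accumulates (q_data - zp_data) * (q_weight - zp_weight) products into int32, so the accumulator's effective scale is data_scale * weight_scale and its zero point is 0, which is exactly what is passed to qnn.op.requantize above. A sketch with illustrative numbers (the convolution converter below follows the same pattern):

    # Illustrative per-tensor scales; mirrors the new_input_scale computation above.
    data_scale_val, weight_scale_val = 0.5, 0.25
    new_input_scale_val = data_scale_val * weight_scale_val  # int32 accumulator scale
    assert new_input_scale_val == 0.125
    # The accumulator zero point is 0 because the input and weight zero points are
    # already subtracted inside the quantized dense, hence relay.const(0, 'int32').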
@@ -1448,7 +1501,9 @@ class OperatorConverter(object):
             from tflite.ActivationFunctionType import ActivationFunctionType
         except ImportError:
             raise ImportError("The tflite package must be installed")
-        assert fused_activation_fn != ActivationFunctionType.NONE
+
+        if fused_activation_fn == ActivationFunctionType.NONE:
+            return in_expr
         if fused_activation_fn == ActivationFunctionType.RELU6:
             return _op.clip(in_expr, a_min=0, a_max=6)
         if fused_activation_fn == ActivationFunctionType.RELU:
@@ -1459,13 +1514,12 @@ class OperatorConverter(object):
             return _op.tanh(in_expr)
         fused_activation_fn_str = self.activation_fn_type[fused_activation_fn]
         raise tvm.error.OpNotImplemented(
-            'Operator {} is not supported for frontend TFLite.'.format(fused_activation_fn_str))
+            'Fused activation {} is not supported yet.'.format(fused_activation_fn_str))
 
     def convert_conv(self, op, conv_type):
         """convolution implementation."""
         try:
             from tflite.BuiltinOptions import BuiltinOptions
-            from tflite.ActivationFunctionType import ActivationFunctionType
             from tflite.TensorType import TensorType
             from tflite.Conv2DOptions import Conv2DOptions
             from tflite.DepthwiseConv2DOptions import DepthwiseConv2DOptions
@@ -1596,17 +1650,9 @@ class OperatorConverter(object):
             channel_axis = 3
             out = _op.nn.bias_add(out, bias_expr, axis=channel_axis)
 
-        # If we have fused activations
-        if fused_activation_fn != ActivationFunctionType.NONE:
-            if not output_tensor.qnn_params:
-                out = self.convert_fused_activation_function(out, fused_activation_fn)
-            else:
-                raise tvm.error.OpNotImplemented(
-                    'Operator {} with fused activation is not supported yet.'
-                    .format('qnn.op.conv2d'))
-
-        # Finally if the conv is quantized. Add a requantize at the end.
+        # Handle fused activation.
         if output_tensor.qnn_params:
+            # Calculate the intermediate scale and zero point of the int32 output.
             data_scale = input_tensor.qnn_params['scale']
             weight_scale = weight_tensor.qnn_params['scale']
             data_scale_val = get_scalar_from_constant(data_scale)
@@ -1614,6 +1660,8 @@ class OperatorConverter(object):
             new_input_scale_val = data_scale_val * weight_scale_val
             new_input_scale = relay.const(new_input_scale_val, 'float32')
             new_input_zero_point = relay.const(0, 'int32')
+
+            # Finally requantize
             out = _qnn.op.requantize(out,
                                      input_scale=new_input_scale,
                                      input_zero_point=new_input_zero_point,
@@ -1621,6 +1669,18 @@ class OperatorConverter(object):
                                      output_zero_point=output_tensor.qnn_params['zero_point'],
                                      out_dtype=output_tensor_type_str)
 
+            # Call activation function
+            output_scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale'])
+            output_zero_point_val = get_scalar_from_constant(output_tensor.qnn_params['zero_point'])
+            out = self.convert_qnn_fused_activation_function(\
+                    expr=out,
+                    fused_activation_fn=fused_activation_fn,
+                    scale=output_scale_val,
+                    zero_point=output_zero_point_val,
+                    dtype=output_tensor_type_str)
+        else:
+            out = self.convert_fused_activation_function(out, fused_activation_fn)
+
         return out
 
     def convert_split(self, op):
@@ -1796,7 +1856,6 @@ class OperatorConverter(object):
         """pool2d implementation."""
         try:
             from tflite.BuiltinOptions import BuiltinOptions
-            from tflite.ActivationFunctionType import ActivationFunctionType
             from tflite.Pool2DOptions import Pool2DOptions
             from tflite.Padding import Padding
         except ImportError:
@@ -1871,13 +1930,19 @@ class OperatorConverter(object):
             raise tvm.error.OpNotImplemented(
                 'Operator {} is not supported for frontend TFLite.'.format(pool_type + ' pool'))
 
-        # If we have fused activations
-        if fused_activation_fn != ActivationFunctionType.NONE:
-            if input_tensor.qnn_params:
-                raise tvm.error.OpNotImplemented(
-                    'Operator {} with fused activation is not supported yet.'
-                    .format('qnn.op.pool2d'))
+        # Handle fused activations
+        if output_tensor.qnn_params:
+            scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale'])
+            zero_point_val = get_scalar_from_constant(output_tensor.qnn_params['zero_point'])
+            out = self.convert_qnn_fused_activation_function(\
+                    expr=out,
+                    fused_activation_fn=fused_activation_fn,
+                    scale=scale_val,
+                    zero_point=zero_point_val,
+                    dtype=output_tensor_type_str)
+        else:
             out = self.convert_fused_activation_function(out, fused_activation_fn)
+
         return out
 
     def convert_pad(self, op):
@@ -2266,28 +2331,15 @@ class OperatorConverter(object):
 
     def convert_detection_postprocess(self, op):
         """Convert TFLite_Detection_PostProcess"""
-        _option_names = [
-            "w_scale",
-            "max_detections",
-            "_output_quantized",
-            "detections_per_class",
-            "x_scale",
-            "nms_score_threshold",
-            "num_classes",
-            "max_classes_per_detection",
-            "use_regular_nms",
-            "y_scale",
-            "h_scale",
-            "_support_output_type_float_in_quantized_op",
-            "nms_iou_threshold"
-        ]
-
-        custom_options = get_custom_options(op, _option_names)
-        if custom_options["use_regular_nms"]:
-            raise tvm.error.OpAttributeUnImplemented(
-                "use_regular_nms=True is not yet supported for operator {}."
-                .format("TFLite_Detection_PostProcess")
-            )
+        flexbuffer = op.CustomOptionsAsNumpy().tobytes()
+        custom_options = FlexBufferDecoder(flexbuffer).decode()
+
+        if "use_regular_nms" in custom_options:
+            if custom_options["use_regular_nms"]:
+                raise tvm.error.OpAttributeUnImplemented(
+                    "use_regular_nms=True is not yet supported for operator {}."
+                    .format("TFLite_Detection_PostProcess")
+                )
 
         inputs = self.get_input_tensors(op)
         assert len(inputs) == 3, "inputs length should be 3"
@@ -2472,91 +2524,6 @@ def get_tensor_name(subgraph, tensor_idx):
     return subgraph.Tensors(tensor_idx).Name().decode("utf-8")
 
 
-def get_custom_options(op, option_names):
-    """Get the options of a custom operator.
-
-    This implements partial flexbuffer deserialization to be able
-    to read custom options. It is not intended to be a general
-    purpose flexbuffer deserializer and as such only supports a
-    limited number of types and assumes the data is a flat map.
-
-    Parameters
-    ----------
-    op:
-        A custom TFlite operator.
-    option_names: list
-        A complete list of the custom option names.
-
-    Returns
-    -------
-    options: dict
-        A dictionary of the custom options.
-
-    """
-    import struct
-    from enum import IntEnum
-
-    class _FlexBufferType(IntEnum):
-        """Flexbuffer type schema from flexbuffers.h"""
-        FBT_NULL = 0
-        FBT_INT = 1
-        FBT_UINT = 2
-        FBT_FLOAT = 3
-        # Types above stored inline, types below store an offset.
-        FBT_KEY = 4
-        FBT_STRING = 5
-        FBT_INDIRECT_INT = 6
-        FBT_INDIRECT_UINT = 7
-        FBT_INDIRECT_FLOAT = 8
-        FBT_MAP = 9
-        FBT_VECTOR = 10 # Untyped.
-        FBT_VECTOR_INT = 11 # Typed any size (stores no type table).
-        FBT_VECTOR_UINT = 12
-        FBT_VECTOR_FLOAT = 13
-        FBT_VECTOR_KEY = 14
-        FBT_VECTOR_STRING = 15
-        FBT_VECTOR_INT2 = 16 # Typed tuple (no type table, no size field).
-        FBT_VECTOR_UINT2 = 17
-        FBT_VECTOR_FLOAT2 = 18
-        FBT_VECTOR_INT3 = 19 # Typed triple (no type table, no size field).
-        FBT_VECTOR_UINT3 = 20
-        FBT_VECTOR_FLOAT3 = 21
-        FBT_VECTOR_INT4 = 22 # Typed quad (no type table, no size field).
-        FBT_VECTOR_UINT4 = 23
-        FBT_VECTOR_FLOAT4 = 24
-        FBT_BLOB = 25
-        FBT_BOOL = 26
-        FBT_VECTOR_BOOL = 36 # To Allow the same type of conversion of type to vector type
-
-    buffer = op.CustomOptionsAsNumpy().tobytes()
-    value_vector_offset = buffer[-3]
-    buffer = buffer[:-3]
-    num_bytes = 4 # Assume all values are stored in 32 bit width
-    value_vector_size = struct.unpack(
-        "<i", buffer[-value_vector_offset - num_bytes:-value_vector_offset]
-    )[0]
-    type_offset = value_vector_size
-    types = buffer[-type_offset:]
-    values = []
-    for i, t in enumerate(types):
-        flex_type = _FlexBufferType(t >> 2)
-        value_offset = -value_vector_offset + i*num_bytes
-        value_bytes = buffer[value_offset:value_offset+num_bytes]
-        if flex_type == _FlexBufferType.FBT_BOOL:
-            value = bool(value_bytes[0])
-        if flex_type == _FlexBufferType.FBT_INT:
-            value = struct.unpack("<i", value_bytes)[0]
-        if flex_type == _FlexBufferType.FBT_UINT:
-            value = struct.unpack("<I", value_bytes)[0]
-        if flex_type == _FlexBufferType.FBT_FLOAT:
-            value = struct.unpack("<f", value_bytes)[0]
-
-        values.append(value)
-
-    custom_options = dict(zip(sorted(option_names), values))
-    return custom_options
-
-
 def from_tflite(model, shape_dict, dtype_dict):
     """Convert from tflite model into compatible relay Function.
 
diff --git a/python/tvm/relay/frontend/tflite_flexbuffer.py b/python/tvm/relay/frontend/tflite_flexbuffer.py
new file mode 100644 (file)
index 0000000..d08570b
--- /dev/null
@@ -0,0 +1,152 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument, too-many-lines, import-outside-toplevel
+"""Tensorflow lite frontend helper to parse custom options in Flexbuffer format."""
+
+import struct
+from enum import IntEnum
+
+class BitWidth(IntEnum):
+    """Flexbuffer bit width schema from flexbuffers.h"""
+    BIT_WIDTH_8 = 0
+    BIT_WIDTH_16 = 1
+    BIT_WIDTH_32 = 2
+    BIT_WIDTH_64 = 3
+
+class FlexBufferType(IntEnum):
+    """Flexbuffer type schema from flexbuffers.h"""
+    FBT_NULL = 0
+    FBT_INT = 1
+    FBT_UINT = 2
+    FBT_FLOAT = 3
+    # Types above stored inline, types below store an offset.
+    FBT_KEY = 4
+    FBT_STRING = 5
+    FBT_INDIRECT_INT = 6
+    FBT_INDIRECT_UINT = 7
+    FBT_INDIRECT_FLOAT = 8
+    FBT_MAP = 9
+    FBT_VECTOR = 10 # Untyped.
+    FBT_VECTOR_INT = 11 # Typed any size (stores no type table).
+    FBT_VECTOR_UINT = 12
+    FBT_VECTOR_FLOAT = 13
+    FBT_VECTOR_KEY = 14
+    FBT_VECTOR_STRING = 15
+    FBT_VECTOR_INT2 = 16 # Typed tuple (no type table, no size field).
+    FBT_VECTOR_UINT2 = 17
+    FBT_VECTOR_FLOAT2 = 18
+    FBT_VECTOR_INT3 = 19 # Typed triple (no type table, no size field).
+    FBT_VECTOR_UINT3 = 20
+    FBT_VECTOR_FLOAT3 = 21
+    FBT_VECTOR_INT4 = 22 # Typed quad (no type table, no size field).
+    FBT_VECTOR_UINT4 = 23
+    FBT_VECTOR_FLOAT4 = 24
+    FBT_BLOB = 25
+    FBT_BOOL = 26
+    FBT_VECTOR_BOOL = 36 # To allow the same scalar-to-vector type conversion as for other types
+
+
+class FlexBufferDecoder(object):
+    """
+    This implements partial flexbuffer deserialization to be able
+    to read custom options. It is not intended to be a general
+    purpose flexbuffer deserializer and as such only supports a
+    limited number of types and assumes the data is a flat map.
+    """
+
+    def __init__(self, buffer):
+        self.buffer = buffer
+
+    def indirect_jump(self, offset, byte_width):
+        """ Helper function to read the offset value and jump """
+        unpack_str = ""
+        if byte_width == 1:
+            unpack_str = "<B"
+        elif byte_width == 4:
+            unpack_str = "<i"
+        assert unpack_str != ""
+        back_jump = struct.unpack(unpack_str,
+                                  self.buffer[offset: offset + byte_width])[0]
+        return offset - back_jump
+
+    def decode_keys(self, end, size, byte_width):
+        """ Decodes the flexbuffer type vector. Map keys are stored in this form """
+        # Keys are strings here. The format is all strings seperated by null, followed by back
+        # offsets for each of the string. For example, (str1)\0(str1)\0(offset1)(offset2) The end
+        # pointer is pointing at the end of all strings
+        keys = list()
+        for i in range(0, size):
+            offset_pos = end + i * byte_width
+            start_index = self.indirect_jump(offset_pos, byte_width)
+            str_size = self.buffer[start_index:].find(b"\0")
+            assert str_size != -1
+            s = self.buffer[start_index: start_index + str_size].decode("utf-8")
+            keys.append(s)
+        return keys
+
+    def decode_vector(self, end, size, byte_width):
+        """ Decodes the flexbuffer vector """
+        # Each entry in the vector can have a different datatype, but every entry has the same
+        # fixed width. The format is a sequence of all values followed by a sequence of their
+        # datatypes. For example, (4)(3.56)(int)(float). The end here points to the start of the
+        # values.
+        values = list()
+        for i in range(0, size):
+            value_type_pos = end + size * byte_width + i
+            value_type = FlexBufferType(self.buffer[value_type_pos] >> 2)
+            value_bytes = self.buffer[end + i * byte_width: end + (i + 1) * byte_width]
+            if value_type == FlexBufferType.FBT_BOOL:
+                value = bool(value_bytes[0])
+            elif value_type == FlexBufferType.FBT_INT:
+                value = struct.unpack("<i", value_bytes)[0]
+            elif value_type == FlexBufferType.FBT_UINT:
+                value = struct.unpack("<I", value_bytes)[0]
+            elif value_type == FlexBufferType.FBT_FLOAT:
+                value = struct.unpack("<f", value_bytes)[0]
+            else:
+                raise NotImplementedError("Flexbuffer type %s is not supported" % str(value_type))
+            values.append(value)
+        return values
+
+    def decode_map(self, end, byte_width, parent_byte_width):
+        """ Decodes the flexbuffer map and returns a dict """
+        mid_loc = self.indirect_jump(end, parent_byte_width)
+        map_size = struct.unpack("<i", self.buffer[mid_loc - byte_width:mid_loc])[0]
+
+        # Find keys
+        keys_offset = mid_loc - byte_width * 3
+        keys_end = self.indirect_jump(keys_offset, byte_width)
+        keys = self.decode_keys(keys_end, map_size, 1)
+
+        # Find values
+        values_end = self.indirect_jump(end, parent_byte_width)
+        values = self.decode_vector(values_end, map_size, byte_width)
+        return dict(zip(keys, values))
+
+    def decode(self):
+        """ Decode the buffer. Decoding is partially implemented """
+        root_end = len(self.buffer) - 1
+        root_byte_width = self.buffer[root_end]
+        root_end -= 1
+        root_packed_type = self.buffer[root_end]
+        root_end -= root_byte_width
+
+        root_type = FlexBufferType(root_packed_type >> 2)
+        byte_width = 1 << BitWidth(root_packed_type & 3)
+
+        if root_type == FlexBufferType.FBT_MAP:
+            return self.decode_map(root_end, byte_width, root_byte_width)
+        raise NotImplementedError("Flexbuffer decoding is partially implemented.")
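For orientation, the root decoding above unpacks a single packed-type byte: the upper six bits select the FlexBufferType and the lower two bits select the bit width. A small worked example with an illustrative byte value:

    # Illustrative packed-type byte (not taken from a real model): 0x25 == 0b100101.
    packed_type = 0x25
    root_type = FlexBufferType(packed_type >> 2)  # upper 6 bits -> FBT_MAP (9)
    byte_width = 1 << BitWidth(packed_type & 3)   # lower 2 bits -> 2 (BIT_WIDTH_16)
    assert root_type == FlexBufferType.FBT_MAP and byte_width == 2

The decoder is driven from convert_detection_postprocess above via FlexBufferDecoder(op.CustomOptionsAsNumpy().tobytes()).decode(), which returns the flat options map as a plain dict.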
index 1a231eb..dc7937c 100644 (file)
@@ -183,11 +183,16 @@ def get_workload_official(model_url, model_sub_path):
     model_path = download_testdata(model_url, model_tar_name, module=['tf', 'official'])
     dir_path = os.path.dirname(model_path)
 
-    import tarfile
     if model_path.endswith("tgz") or model_path.endswith("gz"):
+        import tarfile
         tar = tarfile.open(model_path)
         tar.extractall(path=dir_path)
         tar.close()
+    elif model_path.endswith("zip"):
+        import zipfile
+        zip_object = zipfile.ZipFile(model_path)
+        zip_object.extractall(path=dir_path)
+        zip_object.close()
     else:
         raise RuntimeError('Could not decompress the file: ' + model_path)
     return os.path.join(dir_path, model_sub_path)
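The new zip branch is exercised by the quantized SSD test added below, which downloads a .zip archive and extracts the .tflite file from it:

    # From test_forward_qnn_coco_ssd_mobilenet_v1 in test_forward.py.
    tflite_model_file = tf_testing.get_workload_official(
        "https://storage.googleapis.com/download.tensorflow.org/models/tflite/"
        "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip",
        "detect.tflite")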
index 957f622..da89a13 100644 (file)
@@ -73,6 +73,16 @@ def get_real_image(im_height, im_width):
     data = np.reshape(x, (1, im_height, im_width, 3))
     return data
 
+def get_real_image_object_detection(im_height, im_width):
+    repo_base = 'https://github.com/dmlc/web-data/raw/master/gluoncv/detection/'
+    img_name = 'street_small.jpg'
+    image_url = os.path.join(repo_base, img_name)
+    img_path = download_testdata(image_url, img_name, module='data')
+    image = Image.open(img_path).resize((im_height, im_width))
+    x = np.array(image).astype('uint8')
+    data = np.reshape(x, (1, im_height, im_width, 3))
+    return data
+
 def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target='llvm',
                   out_names=None):
     """ Generic function to compile on relay and execute on tvm """
@@ -98,6 +108,7 @@ def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target
     mod, params = relay.frontend.from_tflite(tflite_model,
                                              shape_dict=shape_dict,
                                              dtype_dict=dtype_dict)
+
     with relay.build_config(opt_level=3):
         graph, lib, params = relay.build(mod, target, params=params)
 
@@ -1822,23 +1833,30 @@ def test_detection_postprocess():
     tflite_output = run_tflite_graph(tflite_model, [box_encodings, class_predictions])
     tvm_output = run_tvm_graph(tflite_model, [box_encodings, class_predictions],
                                ["raw_outputs/box_encodings", "raw_outputs/class_predictions"], num_output=4)
-    # check valid count is the same
+
+    # Check all output shapes are equal
+    assert all([tvm_tensor.shape == tflite_tensor.shape \
+                for (tvm_tensor, tflite_tensor) in zip(tvm_output, tflite_output)])
+
+    # Check valid count is the same
     assert tvm_output[3] == tflite_output[3]
-    # check all the output shapes are the same
-    assert tvm_output[0].shape == tflite_output[0].shape
-    assert tvm_output[1].shape == tflite_output[1].shape
-    assert tvm_output[2].shape == tflite_output[2].shape
     valid_count = tvm_output[3][0]
-    # only check the valid detections are the same
-    # tvm has a different convention to tflite for invalid detections, it uses all -1s whereas
-    # tflite appears to put in nonsense data instead
-    tvm_boxes = tvm_output[0][0][:valid_count]
-    tvm_classes = tvm_output[1][0][:valid_count]
-    tvm_scores = tvm_output[2][0][:valid_count]
-    # check the output data is correct
-    tvm.testing.assert_allclose(np.squeeze(tvm_boxes), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5)
-    tvm.testing.assert_allclose(np.squeeze(tvm_classes), np.squeeze(tflite_output[1]), rtol=1e-5, atol=1e-5)
-    tvm.testing.assert_allclose(np.squeeze(tvm_scores), np.squeeze(tflite_output[2]), rtol=1e-5, atol=1e-5)
+
+    # For slots that do not hold a valid detection, TFLite fills in arbitrary values. Therefore,
+    # we compare the TFLite and TVM tensors only for the valid boxes.
+    for i in range(0, valid_count):
+        # Check bounding box co-ords
+        tvm.testing.assert_allclose(np.squeeze(tvm_output[0][0][i]), np.squeeze(tflite_output[0][0][i]),
+                                    rtol=1e-5, atol=1e-5)
+
+        # Check the class
+        # Stricter check to ensure the class remains the same
+        np.testing.assert_equal(np.squeeze(tvm_output[1][0][i]),
+                                np.squeeze(tflite_output[1][0][i]))
+
+        # Check the score
+        tvm.testing.assert_allclose(np.squeeze(tvm_output[2][0][i]), np.squeeze(tflite_output[2][0][i]),
+                                    rtol=1e-5, atol=1e-5)
 
 
 #######################################################################
@@ -2024,6 +2042,100 @@ def test_forward_qnn_mobilenet_v3_net():
 
 
 #######################################################################
+# Quantized SSD Mobilenet
+# -----------------------
+
+def test_forward_qnn_coco_ssd_mobilenet_v1():
+    """Test the quantized Coco SSD Mobilenet V1 TF Lite model."""
+    pytest.skip("LLVM bug - getExtendedVectorNumElements - "
+                + "https://discuss.tvm.ai/t/segfault-in-llvm/3567. The workaround is to use a "
+                + "specific target, for example, llvm -mcpu=core-avx2")
+
+    tflite_model_file = tf_testing.get_workload_official(
+        "https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip",
+        "detect.tflite")
+
+    with open(tflite_model_file, "rb") as f:
+        tflite_model_buf = f.read()
+
+    data = get_real_image_object_detection(300, 300)
+    tflite_output = run_tflite_graph(tflite_model_buf, data)
+    tvm_output = run_tvm_graph(tflite_model_buf, data, 'normalized_input_image_tensor', num_output=4)
+
+    # Check all output shapes are equal
+    assert all([tvm_tensor.shape == tflite_tensor.shape \
+                for (tvm_tensor, tflite_tensor) in zip(tvm_output, tflite_output)])
+
+    # Check valid count is the same
+    assert tvm_output[3] == tflite_output[3]
+    valid_count = tvm_output[3][0]
+
+    # For slots that do not hold a valid detection, TFLite fills in arbitrary values. Therefore,
+    # we compare the TFLite and TVM tensors only for the valid boxes.
+    for i in range(0, valid_count):
+        # We compare only the bounding boxes whose prediction score is above 60%. This is typical
+        # in end-to-end applications, where low-score predictions are discarded. It is also needed
+        # because multiple low-score bounding boxes can share the same score, and TFLite and TVM
+        # can order equal-score boxes differently. Another source of minor differences in
+        # low-score boxes is the difference between the TVM and TFLite requantize operators.
+        if tvm_output[2][0][i] > 0.6:
+            # Check bounding box co-ords. The tolerances have to be relaxed, from 1e-5 to 1e-2,
+            # because of differences between the requantize implementations in TFLite and TVM.
+            tvm.testing.assert_allclose(np.squeeze(tvm_output[0][0][i]),
+                                        np.squeeze(tflite_output[0][0][i]),
+                                        rtol=1e-2, atol=1e-2)
+
+            # Check the class
+            # Stricter check to ensure the class remains the same
+            np.testing.assert_equal(np.squeeze(tvm_output[1][0][i]),
+                                    np.squeeze(tflite_output[1][0][i]))
+
+            # Check the score
+            tvm.testing.assert_allclose(np.squeeze(tvm_output[2][0][i]),
+                                        np.squeeze(tflite_output[2][0][i]),
+                                        rtol=1e-5, atol=1e-5)
+
+
+#######################################################################
+# SSD Mobilenet
+# -------------
+
+def test_forward_coco_ssd_mobilenet_v1():
+    """Test the FP32 Coco SSD Mobilenet V1 TF Lite model."""
+    tflite_model_file = tf_testing.get_workload_official(
+        "https://raw.githubusercontent.com/dmlc/web-data/master/tensorflow/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28.tgz",
+        "ssd_mobilenet_v1_coco_2018_01_28.tflite")
+
+    with open(tflite_model_file, "rb") as f:
+        tflite_model_buf = f.read()
+
+    np.random.seed(0)
+    data = np.random.uniform(size=(1, 300, 300, 3)).astype('float32')
+    tflite_output = run_tflite_graph(tflite_model_buf, data)
+    tvm_output = run_tvm_graph(tflite_model_buf, data, 'normalized_input_image_tensor', num_output=4)
+
+    # Check all output shapes are equal
+    assert all([tvm_tensor.shape == tflite_tensor.shape \
+                for (tvm_tensor, tflite_tensor) in zip(tvm_output, tflite_output)])
+
+    # Check valid count is the same
+    assert tvm_output[3] == tflite_output[3]
+    valid_count = tvm_output[3][0]
+
+    # For slots that do not hold a valid detection, TFLite fills in arbitrary values. Therefore,
+    # we compare the TFLite and TVM tensors only for the valid boxes.
+    for i in range(0, valid_count):
+        # Check bounding box co-ords
+        tvm.testing.assert_allclose(np.squeeze(tvm_output[0][0][i]), np.squeeze(tflite_output[0][0][i]),
+                                    rtol=1e-5, atol=1e-5)
+        # Check the class
+        np.testing.assert_equal(np.squeeze(tvm_output[1][0][i]), np.squeeze(tflite_output[1][0][i]))
+
+        # Check the score
+        tvm.testing.assert_allclose(np.squeeze(tvm_output[2][0][i]), np.squeeze(tflite_output[2][0][i]),
+                                    rtol=1e-5, atol=1e-5)
+
+#######################################################################
 # MediaPipe
 # -------------
 
@@ -2125,6 +2237,7 @@ if __name__ == '__main__':
     test_forward_mobilenet_v3()
     test_forward_inception_v3_net()
     test_forward_inception_v4_net()
+    test_forward_coco_ssd_mobilenet_v1()
     test_forward_mediapipe_hand_landmark()
 
     # End to End quantized
@@ -2134,3 +2247,4 @@ if __name__ == '__main__':
     #This also fails with a segmentation fault in my run
     #with Tflite 1.15.2
     test_forward_qnn_mobilenet_v3_net()
+    test_forward_qnn_coco_ssd_mobilenet_v1()
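Both new tests are wired into the module's __main__ block above, so running the test file directly covers them; a standard pytest filter such as "python -m pytest tests/python/frontend/tflite/test_forward.py -k coco_ssd_mobilenet" selects just the two SSD MobileNet tests.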