[FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess (#4543)
authormbarrett97 <55580676+mbarrett97@users.noreply.github.com>
Thu, 13 Feb 2020 02:07:58 +0000 (02:07 +0000)
committerGitHub <noreply@github.com>
Thu, 13 Feb 2020 02:07:58 +0000 (10:07 +0800)
* [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess

This adds support for the custom operator
TFLite_Detection_PostProcess which is commonly used in
object detection networks such as SSD Mobilenet. It
only adds support for the case where use_regular_nms = False.

Change-Id: I819b253c0eb6f0fa55da65d2634e09359b888828

* Added a test for the tflite custom op

Change-Id: Ie5baa092deae9a8bcffd2ebd9f6d346b90e58afd

* Removed trailing comma

Change-Id: Ib08f02b5f1a59a883048bfb36e4321152cd2e7f2

* Added spaces between divide

Change-Id: If1171fc03d211a809cedeb800804394972af4060

* Formatted comment

Change-Id: I3ce7e69b8d2c73aec57369c1c64ea1eec07f087b

* Reduced line length in test

Change-Id: I49eaafc3369070f8f3e85fbb965ad20972096c68

* Set random seed for test

Change-Id: I542a787d11422ea83c52147b2cb1144fcef0dd77

* Fixes to style

Change-Id: I2971b8ecebe08c882b2481a99f67cfbe515e0b1f

* Assert for incorrect number of inputs

Change-Id: I393f3b3b62be73e427498d98456fb1d5a214e0af

* Change comparison to pass linting

The linter was updated, so I needed to fix
a small style issue as a result.

Change-Id: Ia3c954565a00de92e7fb1912eae9ed9875d60c7c

python/tvm/relay/frontend/tflite.py
tests/python/frontend/tflite/test_forward.py

index a0b0c0f..d889631 100644 (file)
@@ -121,6 +121,7 @@ class OperatorConverter(object):
             'SQUARED_DIFFERENCE': self.convert_squared_difference,
             'LOGICAL_AND': self.convert_logical_and,
             'LOGICAL_OR': self.convert_logical_or,
+            'DETECTION_POSTPROCESS': self.convert_detection_postprocess
         }
 
     def check_unsupported_ops(self):
@@ -168,6 +169,10 @@ class OperatorConverter(object):
         op_code_str = self.builtin_op_code[op_code_id]
         if op_code_id == BuiltinOperator.CUSTOM:
             # Custom operator
+            custom_op_code_str = self.model.OperatorCodes(op_code_list_idx).CustomCode()
+            if custom_op_code_str == b'TFLite_Detection_PostProcess':
+                return "DETECTION_POSTPROCESS"
+
             raise NotImplementedError("Custom operators are currently not supported")
         return op_code_str
 
@@ -1814,6 +1819,113 @@ class OperatorConverter(object):
 
         return out
 
    def convert_detection_postprocess(self, op):
        """Convert TFLite_Detection_PostProcess to Relay.

        Lowers the custom op to a small subgraph of
        multibox_transform_loc -> non_max_suppression -> get_valid_counts,
        then reshuffles the result back into TFLite's output layout
        (boxes, class_ids, scores, valid_count).

        Only use_regular_nms=False is supported.
        """
        # Complete set of flexbuffer option keys for this custom op.
        # get_custom_options zips decoded values against this list
        # *sorted alphabetically*, so the order here does not matter.
        _option_names = [
            "w_scale",
            "max_detections",
            "_output_quantized",
            "detections_per_class",
            "x_scale",
            "nms_score_threshold",
            "num_classes",
            "max_classes_per_detection",
            "use_regular_nms",
            "y_scale",
            "h_scale",
            "_support_output_type_float_in_quantized_op",
            "nms_iou_threshold"
        ]

        custom_options = get_custom_options(op, _option_names)
        # Regular (per-class) NMS is not modeled by this lowering.
        if custom_options["use_regular_nms"]:
            raise tvm.error.OpAttributeUnImplemented(
                "use_regular_nms=True is not yet supported for operator {}."
                .format("TFLite_Detection_PostProcess")
            )

        inputs = self.get_input_tensors(op)
        assert len(inputs) == 3, "inputs length should be 3"
        # inputs[0] = box encodings, inputs[1] = class predictions,
        # inputs[2] = anchors (a constant tensor).
        cls_pred = self.get_expr(inputs[1].tensor_idx)
        loc_prob = self.get_expr(inputs[0].tensor_idx)
        anchor_values = self.get_tensor_value(inputs[2])
        # Number of anchors; assumes anchor_values is (num_anchors, 4)
        # so len() gives the leading dimension — TODO confirm upstream.
        anchor_boxes = len(anchor_values)
        anchor_type = self.get_tensor_type_str(inputs[2].tensor.Type())
        anchor_expr = self.exp_tab.new_const(anchor_values, dtype=anchor_type)

        # The Relay vision ops operate on float data, so dequantize any
        # quantized inputs up front.
        if inputs[0].qnn_params:
            loc_prob = _qnn.op.dequantize(data=loc_prob,
                                          input_scale=inputs[0].qnn_params['scale'],
                                          input_zero_point=inputs[0].qnn_params['zero_point'])
        if inputs[1].qnn_params:
            cls_pred = _qnn.op.dequantize(data=cls_pred,
                                          input_scale=inputs[1].qnn_params['scale'],
                                          input_zero_point=inputs[1].qnn_params['zero_point'])
        if inputs[2].qnn_params:
            anchor_expr = _qnn.op.dequantize(data=anchor_expr,
                                             input_scale=inputs[2].qnn_params['scale'],
                                             input_zero_point=inputs[2].qnn_params['zero_point'])

        # reshape the cls_pred and loc_prob tensors so
        # they can be consumed by multibox_transform_loc
        cls_pred = _op.transpose(cls_pred, [0, 2, 1])
        # loc_prob coords are in yxhw format
        # need to convert to xywh
        loc_coords = _op.split(loc_prob, 4, axis=2)
        loc_prob = _op.concatenate(
            [loc_coords[1], loc_coords[0], loc_coords[3], loc_coords[2]], axis=2
        )
        # Flatten to the (1, num_anchors * 4) layout multibox_transform_loc expects.
        loc_prob = _op.reshape(loc_prob, [1, anchor_boxes*4])

        # anchor coords are in yxhw format
        # need to convert to ltrb
        anchor_coords = _op.split(anchor_expr, 4, axis=1)
        anchor_y = anchor_coords[0]
        anchor_x = anchor_coords[1]
        anchor_h = anchor_coords[2]
        anchor_w = anchor_coords[3]
        plus_half = _expr.const(0.5, dtype='float32')
        minus_half = _expr.const(-0.5, dtype='float32')
        # left/right = center_x -/+ w/2, top/bottom = center_y -/+ h/2
        anchor_l = _op.add(anchor_x, _op.multiply(anchor_w, minus_half))
        anchor_r = _op.add(anchor_x, _op.multiply(anchor_w, plus_half))
        anchor_t = _op.add(anchor_y, _op.multiply(anchor_h, minus_half))
        anchor_b = _op.add(anchor_y, _op.multiply(anchor_h, plus_half))
        anchor_expr = _op.concatenate([anchor_l, anchor_t, anchor_r, anchor_b], axis=1)
        anchor_expr = _op.expand_dims(anchor_expr, 0)

        # attributes for multibox_transform_loc
        multibox_transform_loc_attrs = {}
        multibox_transform_loc_attrs["clip"] = False
        multibox_transform_loc_attrs["threshold"] = custom_options["nms_score_threshold"]
        # TFLite scales divide the deltas; Relay variances multiply them,
        # hence the reciprocals here.
        multibox_transform_loc_attrs["variances"] = (
            1 / custom_options["x_scale"],
            1 / custom_options["y_scale"],
            1 / custom_options["w_scale"],
            1 / custom_options["h_scale"],
        )

        # attributes for non_max_suppression
        non_max_suppression_attrs = {}
        non_max_suppression_attrs["return_indices"] = False
        non_max_suppression_attrs["iou_threshold"] = custom_options["nms_iou_threshold"]
        non_max_suppression_attrs["force_suppress"] = False
        non_max_suppression_attrs["top_k"] = anchor_boxes
        non_max_suppression_attrs["max_output_size"] = custom_options["max_detections"]
        non_max_suppression_attrs["invalid_to_bottom"] = False

        ret = _op.vision.multibox_transform_loc(cls_pred, loc_prob,
                                                anchor_expr, **multibox_transform_loc_attrs)
        ret = _op.vision.non_max_suppression(ret[0], ret[1], **non_max_suppression_attrs)
        ret = _op.vision.get_valid_counts(ret, 0)
        valid_count = ret[0]
        # the output needs some reshaping to match tflite
        # Each NMS row is [class_id, score, t, l, b, r]; reorder the four
        # coordinates back into TFLite's box layout.
        ret = _op.split(ret[1], 6, axis=2)
        cls_ids = ret[0]
        scores = ret[1]
        boxes = _op.concatenate([ret[3], ret[2], ret[5], ret[4]], axis=2)
        ret = _expr.TupleWrapper(_expr.Tuple([boxes, cls_ids, scores, valid_count]), size=4)
        return ret
+
    def get_expr(self, input_tensor_idx):
        """Return the Relay expression registered for tensor *input_tensor_idx*,
        looked up in the expression table by the tensor's TFLite name."""
        return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx))
 
@@ -1885,6 +1997,91 @@ def get_tensor_name(subgraph, tensor_idx):
     return subgraph.Tensors(tensor_idx).Name().decode("utf-8")
 
 
def get_custom_options(op, option_names):
    """Get the options of a custom operator.

    This implements partial flexbuffer deserialization to be able
    to read custom options. It is not intended to be a general
    purpose flexbuffer deserializer and as such only supports a
    limited number of types and assumes the data is a flat map.

    Parameters
    ----------
    op:
        A custom TFlite operator.
    option_names: list
        A complete list of the custom option names.

    Returns
    -------
    options: dict
        A dictionary of the custom options.

    Raises
    ------
    ValueError
        If a stored value has a flexbuffer type other than
        bool/int/uint/float.
    """
    import struct
    from enum import IntEnum

    class _FlexBufferType(IntEnum):
        """Flexbuffer type schema from flexbuffers.h"""
        FBT_NULL = 0
        FBT_INT = 1
        FBT_UINT = 2
        FBT_FLOAT = 3
        # Types above stored inline, types below store an offset.
        FBT_KEY = 4
        FBT_STRING = 5
        FBT_INDIRECT_INT = 6
        FBT_INDIRECT_UINT = 7
        FBT_INDIRECT_FLOAT = 8
        FBT_MAP = 9
        FBT_VECTOR = 10 # Untyped.
        FBT_VECTOR_INT = 11 # Typed any size (stores no type table).
        FBT_VECTOR_UINT = 12
        FBT_VECTOR_FLOAT = 13
        FBT_VECTOR_KEY = 14
        FBT_VECTOR_STRING = 15
        FBT_VECTOR_INT2 = 16 # Typed tuple (no type table, no size field).
        FBT_VECTOR_UINT2 = 17
        FBT_VECTOR_FLOAT2 = 18
        FBT_VECTOR_INT3 = 19 # Typed triple (no type table, no size field).
        FBT_VECTOR_UINT3 = 20
        FBT_VECTOR_FLOAT3 = 21
        FBT_VECTOR_INT4 = 22 # Typed quad (no type table, no size field).
        FBT_VECTOR_UINT4 = 23
        FBT_VECTOR_FLOAT4 = 24
        FBT_BLOB = 25
        FBT_BOOL = 26
        FBT_VECTOR_BOOL = 36 # To allow the same type of conversion of type to vector type

    buffer = op.CustomOptionsAsNumpy().tobytes()
    # The 3rd-from-last byte is the offset (from the end of the trimmed
    # buffer) to the start of the value vector of the root map.
    value_vector_offset = buffer[-3]
    buffer = buffer[:-3]
    num_bytes = 4 # Assume all values are stored in 32 bit width
    # Number of entries, stored as an int32 immediately before the values.
    value_vector_size = struct.unpack(
        "<i", buffer[-value_vector_offset - num_bytes:-value_vector_offset]
    )[0]
    # One packed type byte per value sits at the very end of the buffer.
    types = buffer[-value_vector_size:]
    values = []
    for i, t in enumerate(types):
        # High 6 bits are the type; low 2 bits encode the bit width.
        flex_type = _FlexBufferType(t >> 2)
        value_offset = -value_vector_offset + i * num_bytes
        end = value_offset + num_bytes
        # A slice end of exactly 0 must mean "to the end of the buffer",
        # not an empty slice.
        value_bytes = buffer[value_offset:end if end < 0 else None]
        if flex_type == _FlexBufferType.FBT_BOOL:
            value = bool(value_bytes[0])
        elif flex_type == _FlexBufferType.FBT_INT:
            value = struct.unpack("<i", value_bytes)[0]
        elif flex_type == _FlexBufferType.FBT_UINT:
            value = struct.unpack("<I", value_bytes)[0]
        elif flex_type == _FlexBufferType.FBT_FLOAT:
            value = struct.unpack("<f", value_bytes)[0]
        else:
            # Previously an unhandled type silently reused the last decoded
            # value (or raised NameError); fail loudly instead.
            raise ValueError(
                "Unsupported flexbuffer type {} in custom options".format(flex_type.name)
            )
        values.append(value)

    # Flexbuffer maps store values in key-sorted order, so zipping the
    # sorted option names against the decoded values restores the mapping.
    return dict(zip(sorted(option_names), values))
+
+
 def from_tflite(model, shape_dict, dtype_dict):
     """Convert from tflite model into compatible relay Function.
 
index ccb8b87..e88226c 100644 (file)
@@ -1385,6 +1385,51 @@ def test_forward_fully_connected():
 
 
 #######################################################################
+# Custom Operators
+# ----------------
+
def test_detection_postprocess():
    """End-to-end check of the TFLite_Detection_PostProcess custom op
    against the TFLite interpreter on SSD MobileNet v2 outputs."""
    model_path = tf_testing.get_workload_official(
        "http://download.tensorflow.org/models/object_detection/"
        "ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03.tar.gz",
        "ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03/tflite_graph.pb"
    )
    input_names = ["raw_outputs/box_encodings", "raw_outputs/class_predictions"]
    converter = tf.lite.TFLiteConverter.from_frozen_graph(
        model_path,
        input_arrays=input_names,
        output_arrays=[
            "TFLite_Detection_PostProcess",
            "TFLite_Detection_PostProcess:1",
            "TFLite_Detection_PostProcess:2",
            "TFLite_Detection_PostProcess:3"
        ],
        input_shapes={
            "raw_outputs/box_encodings": (1, 1917, 4),
            "raw_outputs/class_predictions": (1, 1917, 91),
        },
    )
    converter.allow_custom_ops = True
    converter.inference_type = tf.lite.constants.FLOAT
    tflite_model = converter.convert()

    # Deterministic random activations fed identically to both runtimes.
    np.random.seed(0)
    feed = [
        np.random.uniform(size=(1, 1917, 4)).astype('float32'),
        np.random.uniform(size=(1, 1917, 91)).astype('float32'),
    ]
    tflite_out = run_tflite_graph(tflite_model, feed)
    tvm_out = run_tvm_graph(tflite_model, feed, input_names, num_output=4)
    # valid detection count must match exactly
    assert tvm_out[3] == tflite_out[3]
    valid_count = tvm_out[3][0]
    # TVM pads boxes/classes/scores; only the first valid_count rows count.
    for tvm_arr, ref_arr in zip(tvm_out[:3], tflite_out[:3]):
        tvm.testing.assert_allclose(np.squeeze(tvm_arr[0][:valid_count]),
                                    np.squeeze(ref_arr),
                                    rtol=1e-5, atol=1e-5)
+
+
+#######################################################################
 # Mobilenet
 # ---------
 
@@ -1611,6 +1656,9 @@ if __name__ == '__main__':
     # Logical
     test_all_logical()
 
+    # Detection_PostProcess
+    test_detection_postprocess()
+
     # End to End
     test_forward_mobilenet_v1()
     test_forward_mobilenet_v2()