[TFLite] Implemented ONE_HOT Operator for TFLite (#6223)
author Rishabh Jain <56974688+jainris@users.noreply.github.com>
Mon, 10 Aug 2020 12:21:20 +0000 (17:51 +0530)
committer GitHub <noreply@github.com>
Mon, 10 Aug 2020 12:21:20 +0000 (20:21 +0800)
python/tvm/relay/frontend/tflite.py
tests/python/frontend/tflite/test_forward.py

index fe28741..f168f1b 100644
@@ -114,6 +114,7 @@ class OperatorConverter(object):
             'MUL': self.convert_mul,
             'NEG': self.convert_neg,
             'NOT_EQUAL': self.convert_not_equal,
+            'ONE_HOT': self.convert_one_hot,
             'PACK': self.convert_pack,
             'PAD': self.convert_pad,
             'PADV2': self.convert_pad,
@@ -2903,6 +2904,56 @@ class OperatorConverter(object):
         ret = _expr.TupleWrapper(_expr.Tuple([boxes, cls_ids, scores, valid_count]), size=4)
         return ret
 
+    def convert_one_hot(self, op):
+        """Convert TFLite ONE_HOT"""
+        try:
+            from tflite.BuiltinOptions import BuiltinOptions
+            from tflite.OneHotOptions import OneHotOptions
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 4, "one_hot requires 4 input tensors"
+
+        # Ensuring the inputs are not quantized
+        assert all(not i.qnn_params for i in input_tensors), \
+            "Quantized input is not expected."
+
+        # TFLite ONE_HOT requires both on_value and
+        # off_value, making the dtype attribute redundant.
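+        # For example, indices=[1, 2], depth=3, on_value=1, off_value=0
+        # yields [[0, 1, 0], [0, 0, 1]] along the innermost axis.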
+        indices = input_tensors[0]
+        depth = input_tensors[1]
+        on_value = input_tensors[2]
+        off_value = input_tensors[3]
+
+        assert on_value.tensor.Type() == off_value.tensor.Type(), \
+            "on_value and off_value should be the same type"
+
+        # Getting the Relay exprs
+        indices_expr = self.get_expr(indices.tensor_idx)
+        on_value_expr = self.get_expr(on_value.tensor_idx)
+        off_value_expr = self.get_expr(off_value.tensor_idx)
+
+        # Getting depth value
+        depth = self.get_tensor_value(depth)
+        if isinstance(depth, np.ndarray):
+            depth = int(depth)
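+        # get_tensor_value may return a 0-d ndarray here; Relay's one_hot
+        # expects depth as a python int, hence the cast above.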
+
+        # Getting axis from the operator options (attributes)
+        assert op.BuiltinOptionsType() == BuiltinOptions.OneHotOptions
+        op_options = op.BuiltinOptions()
+        one_hot_options = OneHotOptions()
+        one_hot_options.Init(op_options.Bytes, op_options.Pos)
+        axis = one_hot_options.Axis()
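+        # An axis of -1 places the new one-hot dimension innermost;
+        # Relay's one_hot follows the same convention, so the value
+        # carries over unchanged.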
+
+        # Setting dtype
+        dtype = self.get_tensor_type_str(on_value.tensor.Type())
+
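+        # Relay's one_hot takes indices and the on/off values as exprs,
+        # and depth, axis and dtype as attributes.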
+        out = _op.one_hot(indices_expr, on_value_expr, off_value_expr, depth, axis, dtype)
+
+        return out
+
     def get_expr(self, input_tensor_idx):
         return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx))
 
index 30a6631..2e57175 100644
@@ -2031,6 +2031,35 @@ def test_forward_padv2():
 
 
 #######################################################################
+# ONE_HOT
+# -------
+
+def _test_one_hot(indices, depth, on_value, off_value, axis=None):
+    """ One iteration of One_Hot """
+    with tf.Graph().as_default():
+        in_indices = tf.placeholder(dtype=indices.dtype, shape=indices.shape, name="indices")
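+        # depth is materialized as a constant rather than a placeholder,
+        # since the frontend reads its value at conversion time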
+        in_depth = ops.convert_to_tensor(depth, dtype=depth.dtype)
+        in_on_value = tf.placeholder(dtype=on_value.dtype, shape=on_value.shape, name="on_value")
+        in_off_value = tf.placeholder(dtype=off_value.dtype, shape=off_value.shape, name="off_value")
+        if axis is not None:
+            out = array_ops.one_hot(in_indices, in_depth, in_on_value, in_off_value, axis=axis)
+        else:
+            out = array_ops.one_hot(in_indices, in_depth, in_on_value, in_off_value)
+        compare_tflite_with_tvm(
+            [indices, on_value, off_value],
+            ["indices", "on_value", "off_value"],
+            [in_indices, in_on_value, in_off_value],
+            [out])
+
+def test_forward_one_hot():
+    """ One_Hot """
+    _test_one_hot(np.int32(2), np.int32(8), np.int32(1), np.int32(0))
+    _test_one_hot(np.int32(4), np.int32(8), np.float32(1), np.float32(0))
+    _test_one_hot(np.array([1, 2, 3], dtype=np.int32), np.int32(8), np.int32(3), np.int32(-1))
+    _test_one_hot(np.array([1, 2, 3], dtype=np.int32), np.int32(8), np.int32(3), np.int32(-1), axis=0)
+
+
+#######################################################################
 # Pack
 # ----
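
For reference, a minimal sketch (not part of this commit) of exercising the new
conversion path end to end; `model_buf`, the tensor name "indices" and its shape
are illustrative assumptions:

    # A hedged sketch: assumes `model_buf` holds a TFLite flatbuffer whose
    # graph contains a ONE_HOT operator fed by an int32 "indices" input.
    import tflite.Model
    import tvm
    from tvm import relay

    tflite_model = tflite.Model.Model.GetRootAsModel(model_buf, 0)
    mod, params = relay.frontend.from_tflite(
        tflite_model,
        shape_dict={"indices": (3,)},
        dtype_dict={"indices": "int32"},
    )
    # ONE_HOT now lowers to relay.one_hot and compiles like any other op.
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target="llvm", params=params)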