[Relay][Frontend][TFLite] Add constant input support for elemwise ops (#4666)
author	Wang Yucheng <wyc91543@163.com>
	Wed, 15 Jan 2020 16:48:08 +0000 (00:48 +0800)
committer	Tianqi Chen <tqchen@users.noreply.github.com>
	Wed, 15 Jan 2020 16:48:08 +0000 (08:48 -0800)
* [Relay][Frontend][TFLite] Add constant input support for elemwise ops

* Modify tflite.py

python/tvm/relay/frontend/tflite.py
tests/python/frontend/tflite/test_forward.py

index cb6dbea..02b8ed9 100644 (file)
@@ -611,7 +611,16 @@ class OperatorConverter(object):
         assert len(input_tensors) == 2, "input tensors length should be 2"
 
         lhs_tensor = input_tensors[0]
-        lhs_expr = self.get_expr(lhs_tensor.tensor_idx)
+        if self.has_expr(lhs_tensor.tensor_idx):
+            # In most cases we can assume that TOCO fuses elemwise operators
+            # with constants, so both inputs already have expressions.
+            lhs_expr = self.get_expr(lhs_tensor.tensor_idx)
+        else:
+            # However, in some corner cases the elemwise operator is not fused,
+            # and this operand arrives as a constant instead.
+            lhs_type_str = self.get_tensor_type_str(lhs_tensor.tensor.Type())
+            lhs_expr = self.exp_tab.new_const(self.get_tensor_value(lhs_tensor),
+                                              dtype=lhs_type_str)
 
         rhs_tensor = input_tensors[1]
         if self.has_expr(rhs_tensor.tensor_idx):
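For reference, a minimal sketch (not the patch itself) of the operand-resolution pattern this hunk applies to the left-hand side: an elementwise input is either an expression already produced by a preceding operator, or a constant buffer that must be wrapped as a new Relay constant. The helper names mirror those used in the hunk above; the standalone function is purely illustrative.

    # Illustrative sketch only -- mirrors the lhs/rhs handling shown in the hunk above.
    def _resolve_elemwise_operand(converter, tensor_wrapper):
        if converter.has_expr(tensor_wrapper.tensor_idx):
            # The operand was produced by another operator in the graph.
            return converter.get_expr(tensor_wrapper.tensor_idx)
        # Otherwise the operand is a constant buffer embedded in the TFLite model,
        # so wrap its value as a new Relay constant with the matching dtype.
        dtype = converter.get_tensor_type_str(tensor_wrapper.tensor.Type())
        return converter.exp_tab.new_const(converter.get_tensor_value(tensor_wrapper),
                                           dtype=dtype)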
index 1478b25..837f0f6 100644 (file)
@@ -787,6 +787,24 @@ def _test_elemwise(math_op, data, fused_activation_function=None, quantized=Fals
             out = with_fused_activation_function(out, fused_activation_function)
             compare_tflite_with_tvm(data[0], ['in_0:0'], in_data, [out])
 
+    # Test with constant and tensor
+    with tf.Graph().as_default():
+        in_data = [array_ops.placeholder(shape=data[1].shape, dtype='float32', name='in_1')]
+
+        if quantized:
+            inq_const = tf.quantization.fake_quant_with_min_max_args(data[0], min=-100, max=100, name="const_tensor")
+            inq_data = [tf.quantization.fake_quant_with_min_max_args(in_data[0], min=-100, max=100, name="inq_1")]
+            # the 1st input is treated as a constant and folded directly into the operation
+            out = math_op(ops.convert_to_tensor(inq_const, dtype='float32', name='inq_const'), inq_data[0])
+            out = with_fused_activation_function(out, fused_activation_function)
+            out_min, out_max = _test_elemwise_qnn_out_range(qnn_op)
+            out = tf.quantization.fake_quant_with_min_max_args(out, min=out_min, max=out_max, name="out")
+            compare_tflite_with_tvm(data[1], ['inq_1:0'], inq_data, [out], quantized=True)
+        else:
+            out = math_op(ops.convert_to_tensor(data[0], dtype=data[0].dtype), in_data[0])
+            out = with_fused_activation_function(out, fused_activation_function)
+            compare_tflite_with_tvm(data[1], ['in_1:0'], in_data, [out])
+
 #######################################################################
 # Add
 # ---
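For context, a hedged usage example of the test helper above: the data values, shapes, and the call through math_ops.add are assumptions for illustration, not part of this patch. With two plain numpy arrays, the new block makes data[0] an embedded constant and data[1] a placeholder, exercising the constant-plus-tensor path.

    # Illustrative only: values/shapes are assumed, not taken from this change.
    import numpy as np
    from tensorflow.python.ops import math_ops

    data = [np.arange(6.0, dtype=np.float32).reshape((2, 3)),       # becomes the embedded constant
            np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 3))]  # fed through the 'in_1' placeholder
    _test_elemwise(math_ops.add, data)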