[Relay][Frontend][TF] Fix slice when begin or size is not Const (#4372)
authorSiyuan Li <siyuanli.s.c@gmail.com>
Thu, 21 Nov 2019 18:53:37 +0000 (02:53 +0800)
committerYao Wang <kevinthesunwy@gmail.com>
Thu, 21 Nov 2019 18:53:37 +0000 (10:53 -0800)
* fix slice bug when input is param

* use _infer_value rather than _infer_value_simulated

python/tvm/relay/frontend/tensorflow.py
tests/python/frontend/tensorflow/test_forward.py

index 2decc21..7c1d34f 100644 (file)
@@ -626,8 +626,14 @@ def _tile():
 
 def _slice():
     def _impl(inputs, attr, params):
-        begin = _get_list_param(params, inputs[1])
-        size = _get_list_param(params, inputs[2])
+        try:
+            begin = _get_list_param(params, inputs[1])
+        except (IndexError, KeyError, AttributeError):
+            begin = _infer_value(inputs[1], params).asnumpy().tolist()[0]
+        try:
+            size = _get_list_param(params, inputs[2])
+        except (IndexError, KeyError, AttributeError):
+            size = _infer_value(inputs[2], params).asnumpy().tolist()[0]
         data_shape = attr['_input_shapes'][inputs[0]]
         data_dim = len(data_shape)
         end = size
index db19ed4..4ec8abd 100644 (file)
@@ -2188,6 +2188,20 @@ def test_forward_transpose():
     _test_forward_tranapose_axes_input((2, 3, 4, 5), (3, 0, 1, 2))
 
 
+def _test_forward_slice_operation_input(input_value, begin_value, size_value):
+    input_data = np.array(input_value, dtype=np.float32)
+    with tf.Graph().as_default():
+        input_tensor = tf.placeholder(
+            shape=input_data.shape, dtype=input_data.dtype, name="input")
+        begin_tensor = tf.expand_dims(begin_value, axis=0)
+        size_tensor = tf.expand_dims(size_value, axis=0)
+        slice_tensor = tf.slice(input_tensor, begin_tensor, size_tensor, name='slice_output')
+        compare_tf_with_tvm([input_data], ['input:0'], 'slice_output:0')
+
+
+def test_forward_slice():
+    _test_forward_slice_operation_input([1, 1], 0, 2)
+
 def test_forward_ceil():
     ishape = (1, 3, 10, 10)
     inp_array = np.random.uniform(size=ishape).astype(np.float32)
@@ -2760,8 +2774,8 @@ def test_forward_add_n():
 # Main
 # ----
 if __name__ == '__main__':
-
     # Transforms
+    test_forward_slice()
     test_forward_transpose()
     test_forward_reshape()
     test_forward_depthtospace()