Solve custom model of prelu (#4326)
authorZhao Wu <wuzhaozju@gmail.com>
Fri, 15 Nov 2019 03:43:38 +0000 (11:43 +0800)
committerTianqi Chen <tqchen@users.noreply.github.com>
Fri, 15 Nov 2019 03:43:38 +0000 (19:43 -0800)
python/tvm/relay/frontend/tflite.py
tests/python/frontend/tflite/test_forward.py

index 30209ff..415f04e 100644 (file)
@@ -1340,14 +1340,13 @@ class OperatorConverter(object):
         alpha_tensor = input_tensors[1]
         alpha_tensor_type = alpha_tensor.tensor.Type()
         alpha_tensor_type_str = self.get_tensor_type_str(alpha_tensor_type)
-        alpha_expr = self.exp_tab.new_const(self.get_tensor_value(alpha_tensor),
+        alpha_expr = self.exp_tab.new_const(self.get_tensor_value(alpha_tensor).flatten(),
                                             dtype=alpha_tensor_type_str)
         in_expr = self.get_expr(input_tensor.tensor_idx)
         out = _op.nn.prelu(in_expr, alpha_expr, axis=3)
 
         return out
 
-
     def get_expr(self, input_tensor_idx):
         return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx))
 
index 83a0730..8d19026 100644 (file)
@@ -934,18 +934,18 @@ def test_forward_relu():
     """ ReLU """
     _test_relu(np.arange(6.0, dtype=np.float32).reshape((1, 6)))
 
-def _test_prelu(data):
+def _test_prelu(data, alpha):
     """ One iteration of PReLU """
     with tf.Graph().as_default():
         in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
-        alpha = np.full((data.shape[-1],), 0.2, dtype=data.dtype)
         # This specific pattern will be replaced into PRelu by tflite
         out = nn_ops.relu(in_data) + (-alpha * nn_ops.relu(-in_data))
         compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
 
 def test_forward_prelu():
     """ PReLU """
-    _test_prelu(np.random.uniform(-5, 5, size=(1, 32, 32, 3)).astype("float32"))
+    _test_prelu(np.random.uniform(-5, 5, size=(1, 32, 32, 3)).astype("float32"), np.full((3,), 0.2, dtype="float32"))
+    _test_prelu(np.random.uniform(-5, 5, size=(1, 32, 32, 3)).astype("float32"), np.full((1, 1, 3), 0.2, dtype="float32"))
 
 #######################################################################
 # Fully Connected