            'CAST': self.convert_cast,
            'TILE': self.convert_tile,
            'BATCH_TO_SPACE_ND': self.convert_batch_to_space_nd,
-            'SPACE_TO_BATCH_ND': self.convert_space_to_batch_nd
+            'SPACE_TO_BATCH_ND': self.convert_space_to_batch_nd,
+            'PRELU': self.convert_prelu,
        }
    def check_unsupported_ops(self):
        return reshaped_permuted_reshaped_padded
+    def convert_prelu(self, op):
+        """Convert TFLite PReLU"""
+        try:
+            from tflite.Operator import Operator
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        assert isinstance(op, Operator)
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 2, "input tensors length should be 2"
+
+        input_tensor = input_tensors[0]
+        alpha_tensor = input_tensors[1]
+        alpha_tensor_type = alpha_tensor.tensor.Type()
+        alpha_tensor_type_str = self.get_tensor_type_str(alpha_tensor_type)
+        alpha_expr = self.exp_tab.new_const(self.get_tensor_value(alpha_tensor),
+                                            dtype=alpha_tensor_type_str)
+        in_expr = self.get_expr(input_tensor.tensor_idx)
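+        # TFLite tensors are laid out NHWC, so alpha broadcasts along the
+        # channel dimension (axis 3) of the input.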
+        out = _op.nn.prelu(in_expr, alpha_expr, axis=3)
+
+        return out
+
    def get_expr(self, input_tensor_idx):
        return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx))
""" ReLU """
_test_relu(np.arange(6.0, dtype=np.float32).reshape((1, 6)))
+def _test_prelu(data):
+    """ One iteration of PReLU """
+    with tf.Graph().as_default():
+        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
+        alpha = np.full((data.shape[-1],), 0.2, dtype=data.dtype)
+        # The TFLite converter fuses this specific pattern into a single PRELU op
+        out = nn_ops.relu(in_data) + (-alpha * nn_ops.relu(-in_data))
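+        # relu(x) - alpha * relu(-x) equals x for x >= 0 and alpha * x
+        # otherwise, which is exactly the PReLU definition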
+        compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
+
+def test_forward_prelu():
+ """ PReLU """
+ _test_prelu(np.random.uniform(-5, 5, size=(1, 32, 32, 3)).astype("float32"))
+
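+# For reference, a numpy sketch of what the converted op should compute per
+# element, assuming the NHWC input and per-channel alpha used above:
+#
+#     expected = np.where(data >= 0, data, alpha * data)
+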
#######################################################################
# Fully Connected
# ---------------
    test_forward_softmax()
    test_forward_tanh()
    test_forward_relu()
+    test_forward_prelu()
    test_forward_fully_connected()
    # Elemwise