'PACK': self.convert_pack,
'LOGISTIC': self.convert_logistic,
'TANH':self.convert_tanh,
+ 'RELU': self.convert_relu,
'SPLIT': self.convert_split,
'TRANSPOSE': self.convert_transpose,
'TILE': self.convert_tile,
return out
+ def convert_relu(self, op):
+ """Convert TFLite ReLU"""
+ try:
+ from tflite.Operator import Operator
+ except ImportError:
+ raise ImportError("The tflite package must be installed")
+
+ assert isinstance(op, Operator)
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 1, "input tensors length should be 1"
+
+ input_tensor = input_tensors[0]
+ in_expr = self.get_expr(input_tensor.tensor_idx)
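+ # RELU takes a single input and carries no operator options in TFLite, so it maps directly to Relay's nn.relu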
+ out = _op.nn.relu(in_expr)
+
+ return out
+
def convert_concatenation(self, op):
"""Convert TFLite concatenation"""
try:
_test_tanh(np.arange(6.0, dtype=np.float32).reshape((1, 6)))
#######################################################################
+# ReLU
+# --------
+
+def _test_relu(data):
+ """ One iteration of ReLU """
+ with tf.Graph().as_default():
+ in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
+ out = nn_ops.relu(in_data)
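+ # convert the graph to TFLite and check that TVM's output matches the TFLite result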
+ compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
+
+def test_forward_relu():
+ """ ReLU """
+ _test_relu(np.arange(6.0, dtype=np.float32).reshape((1, 6)))
+
+#######################################################################
# Fully Connected
# -------
test_forward_pooling()
test_forward_softmax()
test_forward_tanh()
+ test_forward_relu()
test_forward_fully_connected()
# Elemwise