[FRONTEND][TFLITE]Logical not op support (#5475)
authorSamuel <siju.samuel@huawei.com>
Thu, 30 Apr 2020 08:07:49 +0000 (13:37 +0530)
committerGitHub <noreply@github.com>
Thu, 30 Apr 2020 08:07:49 +0000 (16:07 +0800)
python/tvm/relay/frontend/tflite.py
tests/python/frontend/tflite/test_forward.py

index 66d0ff3..5c8bbfb 100644 (file)
@@ -94,6 +94,7 @@ class OperatorConverter(object):
             'LOCAL_RESPONSE_NORMALIZATION': self.convert_lrn,
             'LOG': self.convert_log,
             'LOGICAL_AND': self.convert_logical_and,
+            'LOGICAL_NOT': self.convert_logical_not,
             'LOGICAL_OR': self.convert_logical_or,
             'LOGISTIC': self.convert_logistic,
             'MAX_POOL_2D': self.convert_max_pool2d,
@@ -992,6 +993,16 @@ class OperatorConverter(object):
         """Convert tflite LOGICAL_OR"""
         return self._convert_logical_binary(_op.logical_or, op)
 
+    def convert_logical_not(self, op):
+        """Convert tflite LOGICAL_NOT"""
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 1, "input tensors length should be 1"
+
+        data = self.get_expr(input_tensors[0].tensor_idx)
+        out = _op.logical_not(data)
+
+        return out
+
     def convert_gather(self, op):
         """Method to Convert TFLite GATHER operator"""
         try:
index bc3f32a..26bb86d 100644 (file)
@@ -1183,7 +1183,12 @@ def _test_logical_binary(logical_bin_op, data):
     with tf.Graph().as_default():
         in_data = [array_ops.placeholder(shape=data[0].shape, dtype='bool', name='in_0'),
                    array_ops.placeholder(shape=data[1].shape, dtype='bool', name='in_1')]
-        out = logical_bin_op(in_data[0], in_data[1], name='out')
+        if logical_bin_op == math_ops.logical_not:
+            out = math_ops.logical_or(in_data[0], in_data[1], name='out1')
+            out = logical_bin_op(out, name='out')
+        else:
+            out = logical_bin_op(in_data[0], in_data[1], name='out')
+
         compare_tflite_with_tvm(data, ['in_0:0', 'in_1:0'], in_data, [out])
 
 def _test_forward_logical_and(data):
@@ -1194,6 +1199,10 @@ def _test_forward_logical_or(data):
     """ One iteration of logical or """
     return _test_logical_binary(math_ops.logical_or, data)
 
+def _test_forward_logical_not(data):
+    """ One iteration of logical not """
+    return _test_logical_binary(math_ops.logical_not, data)
+
 def test_all_logical():
     data = [np.random.choice(a=[False, True], size=(2, 3, 4)).astype('bool'),
             np.random.choice(a=[False, True], size=(2, 3, 4)).astype('bool')]
@@ -1201,6 +1210,7 @@ def test_all_logical():
     if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'):
         _test_forward_logical_and(data)
         _test_forward_logical_or(data)
+        _test_forward_logical_not(data)
 
 #######################################################################
 # Zeros like