add onnx elemwise greater/less (#3186)
authorJoshua Z. Zhang <cheungchih@gmail.com>
Mon, 13 May 2019 21:17:11 +0000 (14:17 -0700)
committerTianqi Chen <tqchen@users.noreply.github.com>
Mon, 13 May 2019 21:17:11 +0000 (14:17 -0700)
python/tvm/relay/frontend/onnx.py
tests/python/frontend/onnx/test_forward.py

index eba02e70c86541d5fe6075bfce1b8452be7b78ac..08a64c37d8dfd2e2f123b0feec5a851258e446d1 100644 (file)
@@ -622,6 +622,23 @@ class Gather(OnnxOpConverter):
                        extras={'axis':axis})(inputs, {})
         #return _op.take(inputs[0], inputs[1], axis)
 
+
+class Greater(OnnxOpConverter):
+    """ Operator logical greater.
+    """
+    @classmethod
+    def _impl_v7(cls, inputs, attr, params):
+        return _op.greater(inputs[0], inputs[1])
+
+
+class Less(OnnxOpConverter):
+    """ Operator logical less than.
+    """
+    @classmethod
+    def _impl_v7(cls, inputs, attr, params):
+        return _op.less(inputs[0], inputs[1])
+
+
 class LRN(OnnxOpConverter):
     """ Operator converter for Local Response Normalization.
     """
@@ -836,6 +853,8 @@ def _get_convert_map(opset):
         'Selu': Selu.get_converter(opset),
         'Elu': Elu.get_converter(opset),
         'Exp': Renamer('exp'),
+        'Greater': Greater.get_converter(opset),
+        'Less': Less.get_converter(opset),
         'Log': Renamer('log'),
         'Tanh': Renamer('tanh'),
         'Pow': Renamer('power'),
index f867e73e8c0846dfd64b6035d69d78aeaf0f705d..77f045aa06cc724b97502d1451536aecd86b72d3 100644 (file)
@@ -955,6 +955,8 @@ def test_binary_ops():
     verify_binary_ops("Div", x, y, x / y, broadcast=None)
     verify_binary_ops("Div", x, z, x / z, broadcast=True)
     verify_binary_ops("Sum", x, y, x + y, broadcast=None)
+    verify_binary_ops("Greater", x, y, x > y, broadcast=True)
+    verify_binary_ops("Less", x, y, x < y, broadcast=True)
 
 def test_single_ops():
     in_shape = (1, 2, 3, 3)