From 134a2f25d6305036cd2bd7d63caa555c7dbcbabf Mon Sep 17 00:00:00 2001 From: "Joshua Z. Zhang" Date: Mon, 13 May 2019 14:17:11 -0700 Subject: [PATCH] add onnx elemwise greater/less (#3186) --- python/tvm/relay/frontend/onnx.py | 19 +++++++++++++++++++ tests/python/frontend/onnx/test_forward.py | 2 ++ 2 files changed, 21 insertions(+) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index eba02e70c..08a64c37d 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -622,6 +622,23 @@ class Gather(OnnxOpConverter): extras={'axis':axis})(inputs, {}) #return _op.take(inputs[0], inputs[1], axis) + +class Greater(OnnxOpConverter): + """ Operator logical greater. + """ + @classmethod + def _impl_v7(cls, inputs, attr, params): + return _op.greater(inputs[0], inputs[1]) + + +class Less(OnnxOpConverter): + """ Operator logical less than. + """ + @classmethod + def _impl_v7(cls, inputs, attr, params): + return _op.less(inputs[0], inputs[1]) + + class LRN(OnnxOpConverter): """ Operator converter for Local Response Normalization. """ @@ -836,6 +853,8 @@ def _get_convert_map(opset): 'Selu': Selu.get_converter(opset), 'Elu': Elu.get_converter(opset), 'Exp': Renamer('exp'), + 'Greater': Greater.get_converter(opset), + 'Less': Less.get_converter(opset), 'Log': Renamer('log'), 'Tanh': Renamer('tanh'), 'Pow': Renamer('power'), diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index f867e73e8..77f045aa0 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -955,6 +955,8 @@ def test_binary_ops(): verify_binary_ops("Div", x, y, x / y, broadcast=None) verify_binary_ops("Div", x, z, x / z, broadcast=True) verify_binary_ops("Sum", x, y, x + y, broadcast=None) + verify_binary_ops("Greater", x, y, x > y, broadcast=True) + verify_binary_ops("Less", x, y, x < y, broadcast=True) def test_single_ops(): in_shape = (1, 2, 3, 3) -- 2.34.1