From 674feba0dcef8e4d2d4118436f3387f19d8f62da Mon Sep 17 00:00:00 2001 From: Jon Soifer Date: Thu, 15 Aug 2019 11:41:54 -0700 Subject: [PATCH] [Relay][Frontend][ONNX] Add Sign and Equal operators to ONNX frontend (#3760) * [Relay][Frontend][ONNX] Add Sign and Equal operators to ONNX frontend * Dummy change to retrigger integration test --- python/tvm/relay/frontend/onnx.py | 14 ++++++++++++++ tests/python/frontend/onnx/test_forward.py | 11 +++++++++++ 2 files changed, 25 insertions(+) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 1b904bb..b7d668b 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -850,6 +850,18 @@ class ConstantFill(OnnxOpConverter): shape = shape + attr.pop('extra_shape') return _op.full(inputs[0], shape) +class Sign(OnnxOpConverter): + """ Operator converter for Sign. + """ + @classmethod + def _impl_v1(cls, inputs, attr, params): + return _op.sign(inputs[0]) + +class Equal(Elemwise): + """ Operator converter for Equal. + """ + name = 'equal' + # compatible operators that do NOT require any conversion. _identity_list = [] @@ -964,6 +976,8 @@ def _get_convert_map(opset): 'Unsqueeze': Unsqueeze.get_converter(opset), 'Pad': Pad.get_converter(opset), 'Shape': Shape.get_converter(opset), + 'Sign': Sign.get_converter(opset), + 'Equal': Equal.get_converter(opset) } diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index d409960..87d38e0 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -962,6 +962,7 @@ def test_binary_ops(): verify_binary_ops("Sum", x, y, x + y, broadcast=None) verify_binary_ops("Greater", x, y, x > y, broadcast=True) verify_binary_ops("Less", x, y, x < y, broadcast=True) + verify_binary_ops("Equal", x, y, x == y, broadcast=True) def test_single_ops(): in_shape = (1, 2, 3, 3) @@ -1116,6 +1117,15 @@ def test_inception(): # def test_shufflenetv2(): # check_torch_conversion(torchvision.models.shufflenetv2, (1,3,224,224)) +def test_sign(): + def Sign_x(x): + return np.sign(x) + _test_onnx_op_elementwise((3, 4, 5, 6), + Sign_x, + {}, + 'float32', + 'Sign', + {}) if __name__ == '__main__': test_flatten() @@ -1159,3 +1169,4 @@ if __name__ == '__main__': test_resnet() test_inception() test_densenet() + test_sign() -- 2.7.4