From: Samuel Date: Wed, 3 Jun 2020 08:29:38 +0000 (+0530) Subject: [MXNET]Softmin, trunc op support added (#5715) X-Git-Tag: upstream/0.7.0~621 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=c1f3b2f496dc9d39578631b922a16bbf1e0c4a3f;p=platform%2Fupstream%2Ftvm.git [MXNET]Softmin, trunc op support added (#5715) --- diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index c75612d..7f9950b 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -846,6 +846,11 @@ def _mx_softsign(inputs, attrs): return inputs[0] / (_expr.const(1.0) + _op.abs(inputs[0])) +def _mx_softmin(inputs, attrs): + axis = attrs.get_int("axis", -1) + return _op.nn.softmax(_op.negative(inputs[0]), axis) + + def _mx_hard_sigmoid(inputs, attrs): x = (_expr.const(0.2) * inputs[0]) + _expr.const(0.5) return _op.clip(x, a_min=0.0, a_max=1.0) @@ -1829,6 +1834,7 @@ _identity_list = [ "floor", "ceil", "round", + "trunc", "sign", "sigmoid", "negative", @@ -1938,6 +1944,7 @@ _convert_map = { "log_softmax" : _softmax_op(_op.nn.log_softmax), "Softmax" : _softmax_op(_op.nn.softmax), "softsign" : _mx_softsign, + "softmin" : _mx_softmin, "hard_sigmoid" : _mx_hard_sigmoid, "reciprocal" : _mx_reciprocal, # per op specialization diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 5ed2fb8..463b50f 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -372,8 +372,17 @@ def test_forward_elemwise_ops(): tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) +def test_forward_softmin(): + data = mx.sym.var('data') + mx_sym = mx.sym.softmin(data) + verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 3, 100, 100)) + + mx_sym = mx.sym.softmin(data, axis=2) + verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 3, 100, 100)) + + def test_forward_unary_ops(): - for op in ["abs", "sqrt", "ceil", "floor", "round", "reciprocal", + for op in ["abs", "sqrt", "ceil", "floor", "round", "reciprocal", "trunc", "softsign", "hard_sigmoid", "cos", "sin", "tan", "cosh", "sinh", "tanh", @@ -1191,6 +1200,7 @@ if __name__ == '__main__': test_forward_rrelu() test_forward_prelu() test_forward_softrelu() + test_forward_softmin() test_forward_fc_flatten() test_forward_clip() test_forward_split()