[MXNET]Softmin, trunc op support added (#5715)
authorSamuel <siju.samuel@huawei.com>
Wed, 3 Jun 2020 08:29:38 +0000 (13:59 +0530)
committerGitHub <noreply@github.com>
Wed, 3 Jun 2020 08:29:38 +0000 (17:29 +0900)
python/tvm/relay/frontend/mxnet.py
tests/python/frontend/mxnet/test_forward.py

index c75612d..7f9950b 100644 (file)
@@ -846,6 +846,11 @@ def _mx_softsign(inputs, attrs):
     return inputs[0] / (_expr.const(1.0) + _op.abs(inputs[0]))
 
 
+def _mx_softmin(inputs, attrs):
+    axis = attrs.get_int("axis", -1)
+    return _op.nn.softmax(_op.negative(inputs[0]), axis)
+
+
 def _mx_hard_sigmoid(inputs, attrs):
     x = (_expr.const(0.2) * inputs[0]) + _expr.const(0.5)
     return _op.clip(x, a_min=0.0, a_max=1.0)
@@ -1829,6 +1834,7 @@ _identity_list = [
     "floor",
     "ceil",
     "round",
+    "trunc",
     "sign",
     "sigmoid",
     "negative",
@@ -1938,6 +1944,7 @@ _convert_map = {
     "log_softmax"   : _softmax_op(_op.nn.log_softmax),
     "Softmax"       : _softmax_op(_op.nn.softmax),
     "softsign"      : _mx_softsign,
+    "softmin"       : _mx_softmin,
     "hard_sigmoid"  : _mx_hard_sigmoid,
     "reciprocal"    : _mx_reciprocal,
     # per op specialization
index 5ed2fb8..463b50f 100644 (file)
@@ -372,8 +372,17 @@ def test_forward_elemwise_ops():
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
 
 
+def test_forward_softmin():
+    data = mx.sym.var('data')
+    mx_sym = mx.sym.softmin(data)
+    verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 3, 100, 100))
+
+    mx_sym = mx.sym.softmin(data, axis=2)
+    verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 3, 100, 100))
+
+
 def test_forward_unary_ops():
-    for op in ["abs", "sqrt", "ceil", "floor", "round", "reciprocal",
+    for op in ["abs", "sqrt", "ceil", "floor", "round", "reciprocal", "trunc",
                "softsign", "hard_sigmoid",
                "cos", "sin", "tan",
                "cosh", "sinh", "tanh",
@@ -1191,6 +1200,7 @@ if __name__ == '__main__':
     test_forward_rrelu()
     test_forward_prelu()
     test_forward_softrelu()
+    test_forward_softmin()
     test_forward_fc_flatten()
     test_forward_clip()
     test_forward_split()