[FRONTEND]onnx, mxnet, pytorch mathops added (#5561)
authorSamuel <siju.samuel@huawei.com>
Mon, 11 May 2020 18:56:23 +0000 (00:26 +0530)
committerGitHub <noreply@github.com>
Mon, 11 May 2020 18:56:23 +0000 (03:56 +0900)
python/tvm/relay/frontend/mxnet.py
python/tvm/relay/frontend/onnx.py
python/tvm/relay/frontend/pytorch.py
tests/python/frontend/mxnet/test_forward.py
tests/python/frontend/onnx/test_forward.py
tests/python/frontend/pytorch/test_forward.py

index 7dbc788..4cb7a2a 100644 (file)
@@ -1749,16 +1749,18 @@ _identity_list = [
     "floor",
     "ceil",
     "sigmoid",
-    "tanh",
     "negative",
     "reshape_like",
     "zeros_like",
     "ones_like",
     "where",
     "gather_nd",
-    "tan",
     "cos",
-    "sin"
+    "cosh",
+    "sin",
+    "sinh",
+    "tan",
+    "tanh",
 ]
 
 _convert_map = {
@@ -1774,7 +1776,12 @@ _convert_map = {
     "broadcast_maximum"      : _rename(_op.maximum),
     "broadcast_minimum"      : _rename(_op.minimum),
     "broadcast_power"        : _rename(_op.power),
+    "arccos"                 : _rename(_op.acos),
+    "arcsin"                 : _rename(_op.asin),
     "arctan"                 : _rename(_op.atan),
+    "arccosh"                : _rename(_op.acosh),
+    "arcsinh"                : _rename(_op.asinh),
+    "arctanh"                : _rename(_op.atanh),
     "broadcast_equal"        : _mx_compare(_op.equal, _rename),
     "broadcast_not_equal"    : _mx_compare(_op.not_equal, _rename),
     "broadcast_greater"      : _mx_compare(_op.greater, _rename),
index 1a4aee0..58ec4ee 100644 (file)
@@ -1627,6 +1627,17 @@ def _get_convert_map(opset):
         'Greater': Greater.get_converter(opset),
         'Less': Less.get_converter(opset),
         'Log': Renamer('log'),
+        'ACos': Renamer('acos'),
+        'ACosh': Renamer('acosh'),
+        'ASin': Renamer('asin'),
+        'ASinh': Renamer('asinh'),
+        'ATan': Renamer('atan'),
+        'ATanh': Renamer('atanh'),
+        'Cos': Renamer('cos'),
+        'Cosh': Renamer('cosh'),
+        'Sin': Renamer('sin'),
+        'Sinh': Renamer('sinh'),
+        'Tan': Renamer('tan'),
         'Tanh': Renamer('tanh'),
         'Pow': Renamer('power'),
         'PRelu': Prelu.get_converter(opset),
index 64f30f3..3af1051 100644 (file)
@@ -1699,6 +1699,8 @@ def _get_convert_map(prelude):
         "aten::sinh"                            : _unary("sinh"),
         "aten::tan"                             : _unary("tan"),
         "aten::tanh"                            : _unary("tanh"),
+        "aten::acos"                            : _unary("acos"),
+        "aten::asin"                            : _unary("asin"),
         "aten::atan"                            : _unary("atan"),
         "aten::log"                             : _unary("log"),
         "aten::log2"                            : _unary("log2"),
index 84c8acf..3fb8e30 100644 (file)
@@ -363,6 +363,26 @@ def test_forward_elemwise_ops():
                 op_res = intrp.evaluate()(a_np, b_np)
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
 
+
+def test_forward_unary_ops():
+    for op in ["cos", "sin", "tan",
+               "cosh", "sinh", "tanh",
+               "arccos", "arcsin", "arctan",
+               "arccosh", "arcsinh", "arctanh"]:
+        shape = (1, 3, 4, 5)
+        dtype = 'float32'
+        a_np = np.random.uniform(size=shape).astype(dtype)
+        mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('a')])
+        ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(a_np)])
+        shapes = {'a': shape}
+        mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype)
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+                op_res = intrp.evaluate()(a_np)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5, atol=1e-5)
+
+
 def test_forward_scalar_ops():
     for op in [operator.add, operator.sub, operator.mul, operator.truediv,
                operator.pow, operator.lt, operator.le, operator.eq,
@@ -1113,6 +1133,7 @@ if __name__ == '__main__':
     test_forward_broadcast_to()
     test_forward_logical_not()
     test_forward_elemwise_ops()
+    test_forward_unary_ops()
     test_forward_scalar_ops()
     test_forward_slice_like()
     test_forward_slice_axis()
index a26c613..6140414 100644 (file)
@@ -1598,6 +1598,17 @@ def test_single_ops():
     verify_single_ops("Exp", x, np.exp(x))
     verify_single_ops("Log", x, np.log(x))
     verify_single_ops("Log", x, np.log(x))
+    verify_single_ops("ACos", x, np.arccos(x))
+    verify_single_ops("ACosh", x, np.arccosh(x))
+    verify_single_ops("ASin", x, np.arcsin(x))
+    verify_single_ops("ASinh", x, np.arcsinh(x))
+    verify_single_ops("ATan", x, np.arctan(x))
+    verify_single_ops("ATanh", x, np.arctanh(x))
+    verify_single_ops("Cos", x, np.cos(x))
+    verify_single_ops("Cosh", x, np.cosh(x))
+    verify_single_ops("Sin", x, np.sin(x))
+    verify_single_ops("Sinh", x, np.sinh(x))
+    verify_single_ops("Tan", x, np.tan(x))
     verify_single_ops("Tanh", x, np.tanh(x))
     verify_single_ops("Sigmoid", x, 1 / (1 + np.exp(-x)))
     verify_single_ops("Softsign", x, x / (1 + np.abs(x)))
index a53f354..e1c276b 100644 (file)
@@ -1895,7 +1895,15 @@ def test_forward_unary():
         def forward(self, *args):
             return torch.tanh(args[0])
 
-    class ATanh1(Module):
+    class Acos1(Module):
+        def forward(self, *args):
+            return torch.acos(args[0])
+
+    class Asin1(Module):
+        def forward(self, *args):
+            return torch.asin(args[0])
+
+    class Atan1(Module):
         def forward(self, *args):
             return torch.atan(args[0])
 
@@ -1956,7 +1964,9 @@ def test_forward_unary():
     verify_model(Sinh1().float().eval(), input_data=input_data)
     verify_model(Tan1().float().eval(), input_data=input_data)
     verify_model(Tanh1().float().eval(), input_data=input_data)
-    verify_model(ATanh1().float().eval(), input_data=input_data)
+    verify_model(Acos1().float().eval(), input_data=input_data)
+    verify_model(Asin1().float().eval(), input_data=input_data)
+    verify_model(Atan1().float().eval(), input_data=input_data)
     verify_model(Log1().float().eval(), input_data=input_data)
     verify_model(Log2_1().float().eval(), input_data=input_data)
     verify_model(Log10_1().float().eval(), input_data=input_data)