From 90b08f5e4ef5a40e1efb9a9ba2882df87b0a9391 Mon Sep 17 00:00:00 2001 From: Bing Xu Date: Wed, 29 Apr 2020 21:07:18 -0700 Subject: [PATCH] [intrin] a few more math functions (#5468) --- include/tvm/tir/op.h | 6 +++ python/tvm/tir/__init__.py | 4 +- python/tvm/tir/op.py | 80 ++++++++++++++++++++++++++++++++ src/target/intrin_rule.cc | 21 +++++++-- tests/python/unittest/test_tir_intrin.py | 8 +++- 5 files changed, 114 insertions(+), 5 deletions(-) diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index b54aa9a..3fbdca5 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -570,7 +570,13 @@ TVM_DECLARE_INTRIN_UNARY(cos); TVM_DECLARE_INTRIN_UNARY(cosh); TVM_DECLARE_INTRIN_UNARY(sin); TVM_DECLARE_INTRIN_UNARY(sinh); +TVM_DECLARE_INTRIN_UNARY(asin); +TVM_DECLARE_INTRIN_UNARY(acos); TVM_DECLARE_INTRIN_UNARY(atan); +TVM_DECLARE_INTRIN_UNARY(acosh); +TVM_DECLARE_INTRIN_UNARY(asinh); +TVM_DECLARE_INTRIN_UNARY(atanh); + namespace tir { /*! diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 7d06eea..07e0c9c 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -37,7 +37,9 @@ from .function import PrimFunc from .op import call_packed, call_pure_intrin, call_intrin, call_pure_extern, call_extern from .op import call_llvm_intrin, all, any, min_value, max_value, trace from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp -from .op import cos, sin, cosh, sinh, tan, tanh, atan, atan2 +from .op import sin, sinh, asin, asinh +from .op import cos, cosh, acos, acosh +from .op import tan, tanh, atan, atan2, atanh from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot from .op import trunc, abs, round, nextafter, nearbyint, power, popcount, fmod, if_then_else from .op import isnan, isfinite, isinf, copysign diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index e783fe7..b87db19 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -522,6 +522,38 @@ def cosh(x): return call_pure_intrin(x.dtype, "cosh", x) +def acos(x): + """Take acos of input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return call_pure_intrin(x.dtype, "acos", x) + + +def acosh(x): + """Take acos of input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return call_pure_intrin(x.dtype, "acosh", x) + + def sin(x): """Take sin of input x. @@ -554,6 +586,38 @@ def sinh(x): return call_pure_intrin(x.dtype, "sinh", x) +def asin(x): + """Take asin of input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return call_pure_intrin(x.dtype, "asin", x) + + +def asinh(x): + """Take asinh of input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return call_pure_intrin(x.dtype, "asinh", x) + + def atan(x): """Take atan of input x. @@ -570,6 +634,22 @@ def atan(x): return call_pure_intrin(x.dtype, "atan", x) +def atanh(x): + """Take atanh of input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return call_pure_intrin(x.dtype, "atanh", x) + + def atan2(x1, x2): """Take arctan2(x1, x2). diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc index 3a226e1..b95974f 100644 --- a/src/target/intrin_rule.cc +++ b/src/target/intrin_rule.cc @@ -52,22 +52,37 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh") TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tan") .set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan") +.set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atanh") +.set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan2") +.set_body(DispatchExtern); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cos") .set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.acos") +.set_body(DispatchExtern); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cosh") .set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.acosh") +.set_body(DispatchExtern); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sin") .set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sinh") +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.asin") .set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan") +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sinh") .set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan2") +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.asinh") .set_body(DispatchExtern); TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.hypot") diff --git a/tests/python/unittest/test_tir_intrin.py b/tests/python/unittest/test_tir_intrin.py index 61a522c..26bf80f 100644 --- a/tests/python/unittest/test_tir_intrin.py +++ b/tests/python/unittest/test_tir_intrin.py @@ -63,6 +63,12 @@ def test_unary_intrin(): (tvm.tir.sinh, lambda x : np.sinh(x)), (tvm.tir.cosh, lambda x : np.cosh(x)), (tvm.tir.log1p, lambda x : np.log1p(x)), + (tvm.tir.asin, lambda x : np.arcsin(x)), + (tvm.tir.acos, lambda x : np.arccos(x)), + (tvm.tir.atan, lambda x : np.arctan(x)), + (tvm.tir.asinh, lambda x : np.arcsinh(x)), + (tvm.tir.acosh, lambda x : np.arccosh(x)), + (tvm.tir.atanh, lambda x : np.arctanh(x)), ] def run_test(tvm_intrin, np_func): m = te.var("m",) @@ -72,7 +78,7 @@ def test_unary_intrin(): f = tvm.build(s, [A, B], "llvm") ctx = tvm.cpu(0) n = 10 - a = tvm.nd.array(np.random.uniform(0, 1, size=n).astype(A.dtype), ctx) + a = tvm.nd.array(np.random.uniform(0.1, 0.5, size=n).astype(A.dtype), ctx) b = tvm.nd.array( \ np.random.uniform(size=n).astype(A.dtype), ctx) f(a, b) -- 2.7.4