From cd0d52daa6942bdafa9363ff6cfa3d25fcd5b8d6 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 12 Apr 2020 09:32:23 -0700 Subject: [PATCH] [Intrinsic] Add log1p, ldexp, atan2, hypot, nextafter, copysign (#5312) * [Intrinsic] Add log1p, ldexp, atan2, hypot, nextafter, copysign * Lint --- python/tvm/tir/__init__.py | 10 +-- python/tvm/tir/op.py | 113 +++++++++++++++++++++++++++++++ src/target/intrin_rule.cc | 18 +++++ tests/python/unittest/test_tir_intrin.py | 50 +++++++++++++- 4 files changed, 185 insertions(+), 6 deletions(-) diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index b5d9fb1..a50c10d 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -35,11 +35,11 @@ from .function import PrimFunc from .op import call_packed, call_pure_intrin, call_intrin, call_pure_extern, call_extern from .op import call_llvm_intrin, all, any, min_value, max_value, trace -from .op import exp, exp2, exp10, log, log2, log10 -from .op import cos, sin, cosh, sinh, tan, tanh, atan -from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil -from .op import trunc, abs, round, nearbyint, power, popcount, fmod, if_then_else -from .op import isnan, isfinite, isinf +from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp +from .op import cos, sin, cosh, sinh, tan, tanh, atan, atan2 +from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot +from .op import trunc, abs, round, nextafter, nearbyint, power, popcount, fmod, if_then_else +from .op import isnan, isfinite, isinf, copysign from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod from .op import comm_reducer, min, max, sum diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 4b703f3..ce3edee 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -457,6 +457,23 @@ def log10(x): """ return call_pure_intrin(x.dtype, "log10", x) + +def log1p(x): + """Take log(x + 1) with respect to input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return call_pure_intrin(x.dtype, "log1p", x) + + def tan(x): """Take tan of input x. @@ -552,6 +569,26 @@ def atan(x): """ return call_pure_intrin(x.dtype, "atan", x) + +def atan2(x1, x2): + """Take arctan2(x1, x2). + + Parameters + ---------- + x1 : PrimExpr + Input argument. + + x2 : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return call_pure_intrin(x1.dtype, "atan2", x1, x2) + + def sqrt(x): """Take square root of input x. @@ -690,6 +727,82 @@ def nearbyint(x): return _ffi_api.nearbyint(x) +def nextafter(x1, x2): + """Return the next floating-point value after x1 towards x2. + + Parameters + ---------- + x1 : PrimExpr + Input argument. + + x2 : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return call_pure_intrin(x1.dtype, "nextafter", x1, x2) + + +def hypot(x1, x2): + """Equivalent to sqrt(x1**2 + x2**2), element-wise. + + Parameters + ---------- + x1 : PrimExpr + Input argument. + + x2 : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return call_pure_intrin(x1.dtype, "hypot", x1, x2) + + +def copysign(x1, x2): + """Change the sign of x1 to that of x2, element-wise. + + Parameters + ---------- + x1 : PrimExpr + Input argument. + + x2 : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return call_pure_intrin(x1.dtype, "copysign", x1, x2) + + +def ldexp(x1, x2): + """Returns x1 * (2 ** x2). + + Parameters + ---------- + x1 : PrimExpr + Input argument. + + x2 : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return call_pure_intrin(x1.dtype, "ldexp", x1, x2) + + def isnan(x): """Check if input value is Nan. diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc index 626498b..5d393ab 100644 --- a/src/target/intrin_rule.cc +++ b/src/target/intrin_rule.cc @@ -37,6 +37,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.erf") TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log") .set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log1p") +.set_body(DispatchExtern); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh") .set_body(DispatchExtern); @@ -52,6 +55,21 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sin") TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan") .set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan2") +.set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.hypot") +.set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.nextafter") +.set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.copysign") +.set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.ldexp") +.set_body(DispatchExtern); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt") .set_body(DispatchExtern); diff --git a/tests/python/unittest/test_tir_intrin.py b/tests/python/unittest/test_tir_intrin.py index 52ae440..61a522c 100644 --- a/tests/python/unittest/test_tir_intrin.py +++ b/tests/python/unittest/test_tir_intrin.py @@ -62,6 +62,7 @@ def test_unary_intrin(): (tvm.tir.log10, lambda x : np.log10(x)), (tvm.tir.sinh, lambda x : np.sinh(x)), (tvm.tir.cosh, lambda x : np.cosh(x)), + (tvm.tir.log1p, lambda x : np.log1p(x)), ] def run_test(tvm_intrin, np_func): m = te.var("m",) @@ -79,10 +80,57 @@ def test_unary_intrin(): b.asnumpy(), np_func(a.asnumpy()), atol=1e-5, rtol=1e-5) for func in test_funcs: - run_test(*func); + run_test(*func) + + +def test_binary_intrin(): + test_funcs = [ + (tvm.tir.atan2, lambda x1, x2 : np.arctan2(x1, x2)), + (tvm.tir.nextafter, lambda x1, x2 : np.nextafter(x1, x2)), + (tvm.tir.copysign, lambda x1, x2 : np.copysign(x1, x2)), + (tvm.tir.hypot, lambda x1, x2 : np.hypot(x1, x2)), + ] + def run_test(tvm_intrin, np_func): + m = te.var("m",) + A = te.placeholder((m,), name='A') + B = te.placeholder((m,), name='B') + C = te.compute((m,), lambda *i: tvm_intrin(A(*i), B(*i)), name='C') + s = te.create_schedule(C.op) + f = tvm.build(s, [A, B, C], "llvm") + ctx = tvm.cpu(0) + n = 10 + a = tvm.nd.array(np.random.uniform(0, 1, size=n).astype(A.dtype), ctx) + b = tvm.nd.array(np.random.uniform(0, 1, size=n).astype(B.dtype), ctx) + c = tvm.nd.array( \ + np.random.uniform(size=n).astype(A.dtype), ctx) + f(a, b, c) + tvm.testing.assert_allclose( + c.asnumpy(), np_func(a.asnumpy(), b.asnumpy()), atol=1e-5, rtol=1e-5) + + for func in test_funcs: + run_test(*func) + + +def test_ldexp(): + m = te.var("m",) + A = te.placeholder((m,), name='A') + B = te.placeholder((m,), name='B', dtype="int32") + C = te.compute((m,), lambda *i: tvm.tir.ldexp(A(*i), B(*i)), name='C') + s = te.create_schedule(C.op) + f = tvm.build(s, [A, B, C], "llvm") + ctx = tvm.cpu(0) + n = 10 + a = tvm.nd.array(np.random.uniform(0, 1, size=n).astype(A.dtype), ctx) + b = tvm.nd.array(np.random.randint(0, 5, size=n).astype(B.dtype), ctx) + c = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx) + f(a, b, c) + tvm.testing.assert_allclose( + c.asnumpy(), np.ldexp(a.asnumpy(), b.asnumpy()), atol=1e-5, rtol=1e-5) if __name__ == "__main__": test_nearbyint() test_unary_intrin() test_round_intrinsics_on_int() + test_binary_intrin() + test_ldexp() -- 2.7.4