From: Haichen Shen Date: Tue, 7 Apr 2020 19:05:33 +0000 (-0700) Subject: [Relay][OP] Add fast_erf implementation (#5241) X-Git-Tag: upstream/0.7.0~967 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=f5b02fdb1b5a7b6be79df97035ec1c3b80e3c665;p=platform%2Fupstream%2Ftvm.git [Relay][OP] Add fast_erf implementation (#5241) * add fast erf * doc * lint * fix * fix indent --- diff --git a/include/tvm/target/generic_func.h b/include/tvm/target/generic_func.h index 89a7f57..f2a361b3 100644 --- a/include/tvm/target/generic_func.h +++ b/include/tvm/target/generic_func.h @@ -72,7 +72,7 @@ class GenericFunc : public ObjectRef { * * \code * // Example code on how to call generic function - * void CallGeneirc(GenericFunc f) { + * void CallGeneric(GenericFunc f) { * // call like normal functions by pass in arguments * // return value is automatically converted back * int rvalue = f(1, 2.0); diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index f24da05..a607a47 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -76,6 +76,7 @@ register_injective_schedule("shape_of") register_injective_schedule("ndarray_size") register_broadcast_schedule("fast_exp") register_broadcast_schedule("fast_tanh") +register_broadcast_schedule("fast_erf") # zeros @@ -222,3 +223,4 @@ register_shape_func("exp", False, elemwise_shape_func) register_shape_func("tan", False, elemwise_shape_func) register_shape_func("fast_exp", False, elemwise_shape_func) register_shape_func("fast_tanh", False, elemwise_shape_func) +register_shape_func("fast_erf", False, elemwise_shape_func) diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc index 3da77e9..4cca8b0 100644 --- a/src/relay/op/tensor/unary.cc +++ b/src/relay/op/tensor/unary.cc @@ -128,6 +128,17 @@ RELAY_REGISTER_UNARY_OP("erf") .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::erf)); +RELAY_REGISTER_UNARY_OP("fast_erf") +.describe(R"code(Returns the error function value for input array, computed element-wise. + +.. math:: + \fast_erf(x) + +)code" TVM_ADD_FILELINE) +.set_support_level(1) +.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_erf)); + + RELAY_REGISTER_UNARY_OP("sqrt") .describe(R"code(Returns the sqrt input array, computed element-wise. diff --git a/src/relay/transforms/fast_math.cc b/src/relay/transforms/fast_math.cc index 861566f..cf00a89 100644 --- a/src/relay/transforms/fast_math.cc +++ b/src/relay/transforms/fast_math.cc @@ -35,11 +35,14 @@ class FastMathMutator : public ExprRewriter { public: FastMathMutator() : exp_op_(Op::Get("exp")), + erf_op_(Op::Get("erf")), tanh_op_(Op::Get("tanh")) {} Expr Rewrite_(const CallNode* pre, const Expr& post) override { if (pre->op == exp_op_) { return FastExp(post.as()->args[0]); + } else if (pre->op == erf_op_) { + return FastErf(post.as()->args[0]); } else if (pre->op == tanh_op_) { return FastTanh(post.as()->args[0]); } @@ -51,6 +54,7 @@ class FastMathMutator : public ExprRewriter { // operator equivalence checking so that the registry lookup overhead can be // reduced. const Op& exp_op_; + const Op& erf_op_; const Op& tanh_op_; }; diff --git a/src/relay/transforms/pattern_util.h b/src/relay/transforms/pattern_util.h index 350d9e1..cd2af9f 100644 --- a/src/relay/transforms/pattern_util.h +++ b/src/relay/transforms/pattern_util.h @@ -322,6 +322,11 @@ inline Expr FastExp(Expr e) { return Call(op, {e}); } +inline Expr FastErf(Expr e) { + static const Op& op = Op::Get("fast_erf"); + return Call(op, {e}); +} + inline Expr FastTanh(Expr e) { static const Op& op = Op::Get("fast_tanh"); return Call(op, {e}); diff --git a/tests/python/relay/test_op_fast_math.py b/tests/python/relay/test_op_fast_math.py index 1d661c3..215b83e 100644 --- a/tests/python/relay/test_op_fast_math.py +++ b/tests/python/relay/test_op_fast_math.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. import numpy as np +import scipy +from scipy import special import tvm import tvm.relay as relay import topi @@ -52,6 +54,7 @@ def test_fastmath(): rtol=1e-5, atol=1e-5) test_apply(relay.exp, "fast_exp", np.exp, low=-88, high=88, step=0.01) + test_apply(relay.erf, "fast_erf", scipy.special.erf, low=-10, high=10, step=0.01) test_apply(relay.tanh, "fast_tanh", np.tanh, low=-10, high=10, step=0.01) diff --git a/topi/include/topi/elemwise.h b/topi/include/topi/elemwise.h index 88d5732..49eb088 100644 --- a/topi/include/topi/elemwise.h +++ b/topi/include/topi/elemwise.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include "broadcast.h" @@ -63,7 +64,7 @@ TOPI_DECLARE_UNARY_OP(tanh); TOPI_DECLARE_UNARY_OP(isfinite); TOPI_DECLARE_UNARY_OP(isinf); -/* +/*! * \brief Fast_tanh_float implementation from Eigen * https://github.com/eigenteam/eigen-git-mirror/blob/master/Eigen/src/Core/MathFunctionsImpl.h#L26 */ @@ -461,5 +462,75 @@ inline Tensor fast_exp(const Tensor& x, } } +/*! + * \brief Fast_tanh_float implementation from Eigen + * https://github.com/eigenteam/eigen-git-mirror/blob/master/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h#L290 + */ +inline Tensor fast_erf_float32(const Tensor& data, + std::string name, + std::string tag) { + auto plus_4 = make_const(DataType::Float(32), 4.f); + auto minus_4 = make_const(DataType::Float(32), -4.f); + + // The monomial coefficients of the numerator polynomial (odd). + auto alpha_1 = make_const(DataType::Float(32), -1.60960333262415e-02f); + auto alpha_3 = make_const(DataType::Float(32), -2.95459980854025e-03f); + auto alpha_5 = make_const(DataType::Float(32), -7.34990630326855e-04f); + auto alpha_7 = make_const(DataType::Float(32), -5.69250639462346e-05f); + auto alpha_9 = make_const(DataType::Float(32), -2.10102402082508e-06f); + auto alpha_11 = make_const(DataType::Float(32), 2.77068142495902e-08f); + auto alpha_13 = make_const(DataType::Float(32), -2.72614225801306e-10f); + + // The monomial coefficients of the denominator polynomial (even). + auto beta_0 = make_const(DataType::Float(32), -1.42647390514189e-02f); + auto beta_2 = make_const(DataType::Float(32), -7.37332916720468e-03f); + auto beta_4 = make_const(DataType::Float(32), -1.68282697438203e-03f); + auto beta_6 = make_const(DataType::Float(32), -2.13374055278905e-04f); + auto beta_8 = make_const(DataType::Float(32), -1.45660718464996e-05f); + + return compute(data->shape, [&](const Array &i) { + // clamp x + auto x = tvm::max(tvm::min(data(i), plus_4), minus_4); + auto x2 = x * x; + + // Evaluate the numerator polynomial p. + auto p = x2 * alpha_13 + alpha_11; + p = x2 * p + alpha_9; + p = x2 * p + alpha_7; + p = x2 * p + alpha_5; + p = x2 * p + alpha_3; + p = x2 * p + alpha_1; + p = x * p; + + // Evaluate the denominator polynomial p. + auto q = x2 * beta_8 + beta_6; + q = x2 * q + beta_4; + q = x2 * q + beta_2; + q = x2 * q + beta_0; + + return p / q; + }, name, tag); +} + +/*! + * \brief Fast erf implementation + * + * \param x The input tensor + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is erf operation + */ +inline Tensor fast_erf(const Tensor& x, + std::string name = "T_fast_erf", + std::string tag = kElementWise) { + if (x->dtype == DataType::Float(32)) { + auto ret = fast_erf_float32(x, name, tag); + return ret; + } else { + return topi::erf(x); + } +} + } // namespace topi #endif // TOPI_ELEMWISE_H_ diff --git a/topi/python/topi/math.py b/topi/python/topi/math.py index 6f71ae9..6f31cca 100644 --- a/topi/python/topi/math.py +++ b/topi/python/topi/math.py @@ -534,3 +534,19 @@ def fast_tanh(x): The result. """ return cpp.fast_tanh(x, x.dtype, tag.ELEMWISE) + + +def fast_erf(x): + """Take gauss error function of input x using fast_erf implementation. + + Parameters + ---------- + x : tvm.te.Tensor + Input argument. + + Returns + ------- + y : tvm.te.Tensor + The result. + """ + return cpp.fast_erf(x, x.dtype, tag.ELEMWISE) diff --git a/topi/src/elemwise.cc b/topi/src/elemwise.cc index ab9f6fd..71764cd 100644 --- a/topi/src/elemwise.cc +++ b/topi/src/elemwise.cc @@ -46,6 +46,11 @@ TVM_REGISTER_GLOBAL("topi.erf") *rv = erf(args[0]); }); +TVM_REGISTER_GLOBAL("topi.fast_erf") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = fast_erf(args[0]); + }); + TVM_REGISTER_GLOBAL("topi.tan") .set_body([](TVMArgs args, TVMRetValue *rv) { *rv = tan(args[0]); diff --git a/topi/tests/python/test_topi_math.py b/topi/tests/python/test_topi_math.py index 94b78a9..ea98083 100644 --- a/topi/tests/python/test_topi_math.py +++ b/topi/tests/python/test_topi_math.py @@ -16,6 +16,7 @@ # under the License. import numpy as np import scipy +from scipy import special import tvm from tvm import te import topi @@ -238,11 +239,11 @@ def test_fastmath(): test_apply(topi.fast_exp, "fast_exp", np.exp, - low=-88, high=88, - step = 0.01) + low=-88, high=88, step=0.01) + test_apply(topi.fast_erf, "fast_erf", scipy.special.erf, + low=-10, high=10, step=0.01) test_apply(topi.fast_tanh, "fast_tanh", np.tanh, - low=-10, high=10, - step = 0.01) + low=-10, high=10, step=0.01) if __name__ == "__main__": test_util()