[Relay][OP] Add fast_erf implementation (#5241)
authorHaichen Shen <shenhaichen@gmail.com>
Tue, 7 Apr 2020 19:05:33 +0000 (12:05 -0700)
committerGitHub <noreply@github.com>
Tue, 7 Apr 2020 19:05:33 +0000 (12:05 -0700)
* add fast erf

* doc

* lint

* fix

* fix indent

include/tvm/target/generic_func.h
python/tvm/relay/op/_tensor.py
src/relay/op/tensor/unary.cc
src/relay/transforms/fast_math.cc
src/relay/transforms/pattern_util.h
tests/python/relay/test_op_fast_math.py
topi/include/topi/elemwise.h
topi/python/topi/math.py
topi/src/elemwise.cc
topi/tests/python/test_topi_math.py

index 89a7f57..f2a361b 100644 (file)
@@ -72,7 +72,7 @@ class GenericFunc : public ObjectRef {
    *
    * \code
    *   // Example code on how to call generic function
-   *   void CallGeneirc(GenericFunc f) {
+   *   void CallGeneric(GenericFunc f) {
    *     // call like normal functions by pass in arguments
    *     // return value is automatically converted back
    *     int rvalue = f(1, 2.0);
index f24da05..a607a47 100644 (file)
@@ -76,6 +76,7 @@ register_injective_schedule("shape_of")
 register_injective_schedule("ndarray_size")
 register_broadcast_schedule("fast_exp")
 register_broadcast_schedule("fast_tanh")
+register_broadcast_schedule("fast_erf")
 
 
 # zeros
@@ -222,3 +223,4 @@ register_shape_func("exp", False, elemwise_shape_func)
 register_shape_func("tan", False, elemwise_shape_func)
 register_shape_func("fast_exp", False, elemwise_shape_func)
 register_shape_func("fast_tanh", False, elemwise_shape_func)
+register_shape_func("fast_erf", False, elemwise_shape_func)
index 3da77e9..4cca8b0 100644 (file)
@@ -128,6 +128,17 @@ RELAY_REGISTER_UNARY_OP("erf")
 .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::erf));
 
 
+RELAY_REGISTER_UNARY_OP("fast_erf")
+.describe(R"code(Returns the error function value for input array, computed element-wise.
+
+.. math::
+   \fast_erf(x)
+
+)code" TVM_ADD_FILELINE)
+.set_support_level(1)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_erf));
+
+
 RELAY_REGISTER_UNARY_OP("sqrt")
 .describe(R"code(Returns the sqrt input array, computed element-wise.
 
index 861566f..cf00a89 100644 (file)
@@ -35,11 +35,14 @@ class FastMathMutator : public ExprRewriter {
  public:
   FastMathMutator()
       : exp_op_(Op::Get("exp")),
+        erf_op_(Op::Get("erf")),
         tanh_op_(Op::Get("tanh")) {}
 
   Expr Rewrite_(const CallNode* pre, const Expr& post) override {
     if (pre->op == exp_op_) {
       return FastExp(post.as<CallNode>()->args[0]);
+    } else if (pre->op == erf_op_) {
+      return FastErf(post.as<CallNode>()->args[0]);
     } else if (pre->op == tanh_op_) {
       return FastTanh(post.as<CallNode>()->args[0]);
     }
@@ -51,6 +54,7 @@ class FastMathMutator : public ExprRewriter {
   // operator equivalence checking so that the registry lookup overhead can be
   // reduced.
   const Op& exp_op_;
+  const Op& erf_op_;
   const Op& tanh_op_;
 };
 
index 350d9e1..cd2af9f 100644 (file)
@@ -322,6 +322,11 @@ inline Expr FastExp(Expr e) {
   return Call(op, {e});
 }
 
+inline Expr FastErf(Expr e) {
+  static const Op& op = Op::Get("fast_erf");
+  return Call(op, {e});
+}
+
 inline Expr FastTanh(Expr e) {
   static const Op& op = Op::Get("fast_tanh");
   return Call(op, {e});
index 1d661c3..215b83e 100644 (file)
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 import numpy as np
+import scipy
+from scipy import special
 import tvm
 import tvm.relay as relay
 import topi
@@ -52,6 +54,7 @@ def test_fastmath():
                                     rtol=1e-5, atol=1e-5)
 
     test_apply(relay.exp, "fast_exp", np.exp, low=-88, high=88, step=0.01)
+    test_apply(relay.erf, "fast_erf", scipy.special.erf, low=-10, high=10, step=0.01)
     test_apply(relay.tanh, "fast_tanh", np.tanh, low=-10, high=10, step=0.01)
 
 
index 88d5732..49eb088 100644 (file)
@@ -27,6 +27,7 @@
 #include <tvm/tir/expr.h>
 #include <tvm/tir/ir_pass.h>
 #include <topi/tags.h>
+#include <algorithm>
 #include <string>
 #include "broadcast.h"
 
@@ -63,7 +64,7 @@ TOPI_DECLARE_UNARY_OP(tanh);
 TOPI_DECLARE_UNARY_OP(isfinite);
 TOPI_DECLARE_UNARY_OP(isinf);
 
-/*
+/*!
  * \brief Fast_tanh_float implementation from Eigen
  * https://github.com/eigenteam/eigen-git-mirror/blob/master/Eigen/src/Core/MathFunctionsImpl.h#L26
  */
@@ -461,5 +462,75 @@ inline Tensor fast_exp(const Tensor& x,
   }
 }
 
+/*!
+ * \brief Fast_tanh_float implementation from Eigen
+ * https://github.com/eigenteam/eigen-git-mirror/blob/master/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h#L290
+ */
+inline Tensor fast_erf_float32(const Tensor& data,
+                               std::string name,
+                               std::string tag) {
+  auto plus_4 = make_const(DataType::Float(32), 4.f);
+  auto minus_4 = make_const(DataType::Float(32), -4.f);
+
+  // The monomial coefficients of the numerator polynomial (odd).
+  auto alpha_1 = make_const(DataType::Float(32), -1.60960333262415e-02f);
+  auto alpha_3 = make_const(DataType::Float(32), -2.95459980854025e-03f);
+  auto alpha_5 = make_const(DataType::Float(32), -7.34990630326855e-04f);
+  auto alpha_7 = make_const(DataType::Float(32), -5.69250639462346e-05f);
+  auto alpha_9 = make_const(DataType::Float(32), -2.10102402082508e-06f);
+  auto alpha_11 = make_const(DataType::Float(32), 2.77068142495902e-08f);
+  auto alpha_13 = make_const(DataType::Float(32), -2.72614225801306e-10f);
+
+  // The monomial coefficients of the denominator polynomial (even).
+  auto beta_0 = make_const(DataType::Float(32), -1.42647390514189e-02f);
+  auto beta_2 = make_const(DataType::Float(32), -7.37332916720468e-03f);
+  auto beta_4 = make_const(DataType::Float(32), -1.68282697438203e-03f);
+  auto beta_6 = make_const(DataType::Float(32), -2.13374055278905e-04f);
+  auto beta_8 = make_const(DataType::Float(32), -1.45660718464996e-05f);
+
+  return compute(data->shape, [&](const Array<Var> &i) {
+    // clamp x
+    auto x = tvm::max(tvm::min(data(i), plus_4), minus_4);
+    auto x2 = x * x;
+
+    // Evaluate the numerator polynomial p.
+    auto p = x2 * alpha_13 + alpha_11;
+    p = x2 * p + alpha_9;
+    p = x2 * p + alpha_7;
+    p = x2 * p + alpha_5;
+    p = x2 * p + alpha_3;
+    p = x2 * p + alpha_1;
+    p = x * p;
+
+    // Evaluate the denominator polynomial p.
+    auto q = x2 * beta_8 + beta_6;
+    q = x2 * q + beta_4;
+    q = x2 * q + beta_2;
+    q = x2 * q + beta_0;
+
+    return p / q;
+  }, name, tag);
+}
+
+/*!
+ * \brief Fast erf implementation
+ *
+ * \param x The input tensor
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is erf operation
+ */
+inline Tensor fast_erf(const Tensor& x,
+                       std::string name = "T_fast_erf",
+                       std::string tag = kElementWise) {
+  if (x->dtype == DataType::Float(32)) {
+    auto ret = fast_erf_float32(x, name, tag);
+    return ret;
+  } else {
+    return topi::erf(x);
+  }
+}
+
 }  // namespace topi
 #endif  // TOPI_ELEMWISE_H_
index 6f71ae9..6f31cca 100644 (file)
@@ -534,3 +534,19 @@ def fast_tanh(x):
         The result.
     """
     return cpp.fast_tanh(x, x.dtype, tag.ELEMWISE)
+
+
+def fast_erf(x):
+    """Take gauss error function of input x using fast_erf implementation.
+
+    Parameters
+    ----------
+    x : tvm.te.Tensor
+        Input argument.
+
+    Returns
+    -------
+    y : tvm.te.Tensor
+        The result.
+    """
+    return cpp.fast_erf(x, x.dtype, tag.ELEMWISE)
index ab9f6fd..71764cd 100644 (file)
@@ -46,6 +46,11 @@ TVM_REGISTER_GLOBAL("topi.erf")
   *rv = erf(args[0]);
   });
 
+TVM_REGISTER_GLOBAL("topi.fast_erf")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+  *rv = fast_erf(args[0]);
+  });
+
 TVM_REGISTER_GLOBAL("topi.tan")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
   *rv = tan(args[0]);
index 94b78a9..ea98083 100644 (file)
@@ -16,6 +16,7 @@
 # under the License.
 import numpy as np
 import scipy
+from scipy import special
 import tvm
 from tvm import te
 import topi
@@ -238,11 +239,11 @@ def test_fastmath():
 
 
     test_apply(topi.fast_exp, "fast_exp", np.exp,
-               low=-88, high=88,
-               step = 0.01)
+               low=-88, high=88, step=0.01)
+    test_apply(topi.fast_erf, "fast_erf", scipy.special.erf,
+               low=-10, high=10, step=0.01)
     test_apply(topi.fast_tanh, "fast_tanh", np.tanh,
-               low=-10, high=10,
-               step = 0.01)
+               low=-10, high=10, step=0.01)
 
 if __name__ == "__main__":
     test_util()