Fast exponent (#4790)
authorAlex Gladkov <gladkova@lab126.com>
Mon, 17 Feb 2020 17:22:11 +0000 (09:22 -0800)
committerGitHub <noreply@github.com>
Mon, 17 Feb 2020 17:22:11 +0000 (09:22 -0800)
topi/include/topi/elemwise.h
topi/python/topi/math.py
topi/src/topi.cc
topi/tests/python/test_topi_math.py

index e3f4678..e35e3e4 100644 (file)
@@ -377,5 +377,85 @@ inline Tensor full_like(const Tensor& x,
   }, name, tag);
 }
 
+/*!
+ * \brief Fast exponential function implementation
+ *
+ * \param _x The input tensor
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is exponent operation
+ *
+ * \note Function computes:
+ * log2(e^x) = x * log2(e) * log2(2) =>
+ * log2(e^x) = log2(2^(x*log2(e))) =>
+ * e^x = 2^(x*log2(e))
+ * Splitting power x*log2(e) into integer and fractional parts:
+ * e^(n+f) = e^n * e^f
+ * n = floor(x*log2(e) + 1/2)
+ * f = x - n * ln(2)
+ * exp(x) = 2^n * exp(y)
+ * Approximation for fractional part:
+ * y = exp(f) = 1 + 2 * P(x**2)/(Q(x**2) - P(x**2))
+ */
+inline Tensor fast_exp_float32(const Tensor& _x,
+                               std::string name,
+                               std::string tag) {
+  auto x_hi = make_const(DataType::Float(32), 88.3762626647950f);
+  auto x_lo = make_const(DataType::Float(32), -88.3762626647949f);
+  auto log2e = make_const(DataType::Float(32), 1.44269504088896341f);
+  auto ln2 = make_const(DataType::Float(32), 0.6931471805599453f);
+  PrimExpr p[6] = {make_const(DataType::Float(32), 1.9875691500E-4f),
+                   make_const(DataType::Float(32), 1.3981999507E-3f),
+                   make_const(DataType::Float(32), 8.3334519073E-3f),
+                   make_const(DataType::Float(32), 4.1665795894E-2f),
+                   make_const(DataType::Float(32), 1.6666665459E-1f),
+                   make_const(DataType::Float(32), 5.0000001201E-1f)};
+  auto one = make_const(DataType::Float(32), 1.0f);
+  auto one_half = make_const(DataType::Float(32), 0.5f);
+  auto b = make_const(DataType::Float(32), 127.0f);
+
+  return compute(_x->shape,
+                 [&](const Array<Var>& i) {
+                   // clamp x
+                   auto x = ::tvm::max(::tvm::min(_x(i), x_hi), x_lo);
+                   // integer part
+                   auto n = ::tvm::floor(x * log2e + one_half);
+                   // fractional part
+                   auto f = x - n * ln2;
+                   auto y = (((((p[0] * f + p[1]) * f + p[2]) * f + p[3])* f+ p[4]) * f
+                             + p[5]) * f * f + f + one;
+                   // Return 2^m * exp(r).
+                   auto ef = tvm::reinterpret(DataType::Float(32),
+                                              ::tvm::cast(DataType::Int(32), n + b) << 23);
+                   return ::tvm::max(ef * y, _x(i)); // NOLINT(*)
+                 },
+                 name, tag);
+}
+
+
+/*!
+ * \brief Fast exponential function implementation
+ *
+ * \param x The input tensor
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is exponent operation
+ *
+ */
+inline Tensor fast_exp(const Tensor& x,
+                  std::string name = "T_fast_exp",
+                  std::string tag = kElementWise) {
+  if (x->dtype == DataType::Float(32)) {
+    auto ret = fast_exp_float32(x, name, tag);
+    return ret;
+  } else {
+    return compute(x->shape, [&](const Array<Var>& i) {
+        return ::tvm::exp(x(i));
+      }, name, tag);
+  }
+}
+
 }  // namespace topi
 #endif  // TOPI_ELEMWISE_H_
index c3e1a10..148d53a 100644 (file)
@@ -451,3 +451,19 @@ def reinterpret(x, dtype):
         The result.
     """
     return cpp.reinterpret(x, dtype)
+
+
+def fast_exp(x):
+    """Take exponential of input x using fast_exp implementation
+
+    Parameters
+    ----------
+    x : tvm.Tensor
+        Input argument.
+
+    Returns
+    -------
+    y : tvm.Tensor
+        The result.
+    """
+    return cpp.fast_exp(x, x.dtype, tag.ELEMWISE)
index 2b2142b..a7b9160 100644 (file)
@@ -165,6 +165,11 @@ TVM_REGISTER_GLOBAL("topi.exp")
   *rv = exp(args[0]);
   });
 
+TVM_REGISTER_GLOBAL("topi.fast_exp")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+  *rv = fast_exp(args[0]);
+  });
+
 TVM_REGISTER_GLOBAL("topi.erf")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
   *rv = erf(args[0]);
index bb67436..5bb95ba 100644 (file)
@@ -185,7 +185,45 @@ def test_cast():
     verify("bool", "int32")
 
 
+def test_fastmath():
+    def test_apply(
+        func,
+        name,
+        f_numpy,
+        low,
+        high,
+        step,
+        dtype=tvm.float32
+    ):
+        a_np = np.arange(low, high, step).astype(dtype)
+        b_np = f_numpy(a_np)
+        A = tvm.placeholder(a_np.shape, dtype=dtype, name="A")
+        B = func(A)
+        assert tuple(B.shape) == tuple(A.shape)
+
+        def check_device(device):
+            ctx = tvm.context(device, 0)
+            if not ctx.exist:
+                print("Skip because %s is not enabled" % device)
+                return
+            with tvm.target.create(device):
+                s = topi.generic.schedule_injective(B)
+            func = tvm.build(s, [A, B], device, name=name)
+            a = tvm.nd.array(a_np, ctx)
+            b = tvm.nd.array(np.zeros_like(b_np), ctx)
+            func(a, b)
+            tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5)
+
+        check_device('llvm')
+        check_device('llvm -device=arm-cpu')
+
+
+    test_apply(topi.fast_exp, "fast_exp", np.exp,
+               low=-88, high=88,
+               step = 0.01)
+
 if __name__ == "__main__":
     test_util()
     test_ewise()
     test_cast()
+    test_fastmath()