[Relay][FastMath] Relay pass to use fast exp/tanh (#4873)
authorAnimesh Jain <anijain@umich.edu>
Sun, 1 Mar 2020 21:57:24 +0000 (13:57 -0800)
committerGitHub <noreply@github.com>
Sun, 1 Mar 2020 21:57:24 +0000 (13:57 -0800)
* [Relay][FastMath] Relay pass to use fast exp/tanh

* Adding required_pass to the tests.

* FastMath test changes.

include/tvm/relay/transform.h
python/tvm/relay/transform.py
src/relay/backend/build_module.cc
src/relay/op/tensor/unary.cc
src/relay/pass/fast_math.cc [new file with mode: 0644]
src/relay/pass/pattern_util.h
tests/python/relay/test_pass_fast_math.py [new file with mode: 0644]
topi/include/topi/elemwise.h
topi/python/topi/math.py
topi/src/topi.cc

index 8d886aa..2862800 100644 (file)
@@ -164,6 +164,13 @@ TVM_DLL Pass PartialEval();
 TVM_DLL Pass SimplifyInference();
 
 /*!
+ * \brief Replaces non linear activation functions with their fast but approximate counterparts.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass FastMath();
+
+/*!
  * \brief Infer the type of an expression.
  *
  * The result of type checking is a new expression with unambigous
index 45535af..f773835 100644 (file)
@@ -57,7 +57,8 @@ def build_config(opt_level=2,
                 "CanonicalizeCast": 3,
                 "EliminateCommonSubexpr": 3,
                 "CombineParallelConv2D": 4,
-                "CombineParallelDense": 4
+                "CombineParallelDense": 4,
+                "FastMath": 4
             }
 
     fallback_device : int, str, or tvmContext, optional
@@ -175,11 +176,22 @@ def SimplifyInference():
     Returns
     -------
     ret: tvm.relay.Pass
-        The registered to perform operator simplification.
+        The registered pass to perform operator simplification.
     """
     return _transform.SimplifyInference()
 
 
+def FastMath():
+    """ Converts the expensive non linear functions to their fast but approximate counterparts.
+
+    Returns
+    -------
+    ret: tvm.relay.Pass
+        The registered pass to perform fast math operations.
+    """
+    return _transform.FastMath()
+
+
 def CanonicalizeOps():
     """Canonicalize special operators to basic operators.
     This can simplify followed analysis, e.g. expanding bias_add to
index ff64d4a..0c0a8b8 100644 (file)
@@ -305,6 +305,9 @@ class RelayBuildModule : public runtime::ModuleNode {
     if (targets.size() == 1) {
       pass_seqs.push_back(transform::AlterOpLayout());
     }
+
+    // Fast math optimizations.
+    pass_seqs.push_back(transform::FastMath());
     pass_seqs.push_back(transform::FoldConstant());
 
     // Create a sequential pass and perform optimizations.
index 2c73458..1169fa8 100644 (file)
@@ -95,6 +95,17 @@ RELAY_REGISTER_UNARY_OP("exp")
 .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::exp));
 
 
+RELAY_REGISTER_UNARY_OP("fast_exp")
+.describe(R"code(Returns the fast_exp of input array, computed element-wise.
+
+.. math::
+   Y = exp(x)
+
+)code" TVM_ADD_FILELINE)
+.set_support_level(1)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_exp));
+
+
 RELAY_REGISTER_UNARY_OP("erf")
 .describe(R"code(Returns the error function value for input array, computed element-wise.
 
@@ -250,6 +261,17 @@ RELAY_REGISTER_UNARY_OP("tanh")
 .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tanh));
 
 
+RELAY_REGISTER_UNARY_OP("fast_tanh")
+.describe(R"code(Returns the fast_tanh of input array, computed element-wise.
+
+.. math::
+   Y = sinh(X) / cosh(X)
+
+)code" TVM_ADD_FILELINE)
+.set_support_level(1)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_tanh));
+
+
 RELAY_REGISTER_UNARY_OP("negative")
 .describe(R"code(Returns the numeric negative of input array, computed element-wise.
 
diff --git a/src/relay/pass/fast_math.cc b/src/relay/pass/fast_math.cc
new file mode 100644 (file)
index 0000000..898f760
--- /dev/null
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file fast_math.cc
+ * \brief Replaces non linear activation functions with their fast but approximate counterparts.
+ */
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/transform.h>
+#include <tvm/relay/op.h>
+#include "pattern_util.h"
+
+namespace tvm {
+namespace relay {
+
+class FastMathMutator : public ExprMutator {
+ public:
+  FastMathMutator()
+      : exp_op_(Op::Get("exp")),
+        tanh_op_(Op::Get("tanh")) {}
+
+  Expr VisitExpr_(const CallNode* n) {
+    auto new_n = ExprMutator::VisitExpr_(n);
+    if (n->op == exp_op_) {
+      return FastExp(new_n.as<CallNode>()->args[0]);
+    } else if (n->op == tanh_op_) {
+      return FastTanh(new_n.as<CallNode>()->args[0]);
+    }
+    return new_n;
+  }
+
+ private:
+  // Cache the following ops. They will be used in the passes repeatedly for
+  // operator equivalence checking so that the registry lookup overhead can be
+  // reduced.
+  const Op& exp_op_;
+  const Op& tanh_op_;
+};
+
+Expr FastMath(const Expr& e) {
+  return FastMathMutator().Mutate(e);
+}
+
+namespace transform {
+
+Pass FastMath() {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+    [=](Function f, IRModule m, PassContext pc) {
+    return Downcast<Function>(FastMath(f));
+  };
+  return CreateFunctionPass(pass_func, 4, "FastMath",
+                            {tir::StringImmNode::make("InferType")});
+}
+
+TVM_REGISTER_GLOBAL("relay._transform.FastMath")
+.set_body_typed(FastMath);
+
+}  // namespace transform
+
+}  // namespace relay
+}  // namespace tvm
index f7d8f9c..85750f5 100644 (file)
@@ -316,6 +316,16 @@ inline Expr Exp(Expr e) {
   return CallNode::make(op, {e});
 }
 
+inline Expr FastExp(Expr e) {
+  static const Op& op = Op::Get("fast_exp");
+  return CallNode::make(op, {e});
+}
+
+inline Expr FastTanh(Expr e) {
+  static const Op& op = Op::Get("fast_tanh");
+  return CallNode::make(op, {e});
+}
+
 inline Expr Log(Expr e) {
   static const Op& op = Op::Get("log");
   return CallNode::make(op, {e});
diff --git a/tests/python/relay/test_pass_fast_math.py b/tests/python/relay/test_pass_fast_math.py
new file mode 100644 (file)
index 0000000..e75316f
--- /dev/null
@@ -0,0 +1,52 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm.ir import IRModule
+from tvm import relay
+from tvm.relay.transform import FastMath
+
+def test_exp():
+    x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32")
+    y = relay.exp(x)
+    func = relay.Function([x], y)
+    mod = tvm.IRModule.from_expr(func)
+
+    fast_mod = FastMath()(mod)
+    assert "fast_exp" in fast_mod.astext()
+
+    # Check that FastMath option works for relay.build.
+    with relay.build_config(opt_level=3, required_pass=['FastMath']):
+        fast_mod = relay.optimize(mod, target='llvm', params=None)
+    assert "fast_exp" in fast_mod[0].astext()
+
+def test_tanh():
+    x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32")
+    y = relay.tanh(x)
+    func = relay.Function([x], y)
+    mod = tvm.IRModule.from_expr(func)
+
+    fast_mod = FastMath()(mod)
+    assert "fast_tanh" in fast_mod.astext()
+
+    # Check that FastMath option works for relay.build.
+    with relay.build_config(opt_level=3, required_pass=['FastMath']):
+        fast_mod = relay.optimize(mod, target='llvm', params=None)
+    assert "fast_tanh" in fast_mod[0].astext()
+
+if __name__ == "__main__":
+    test_exp()
+    test_tanh()
index e35e3e4..3c0822f 100644 (file)
@@ -58,6 +58,7 @@ TOPI_DECLARE_UNARY_OP(cos);
 TOPI_DECLARE_UNARY_OP(sin);
 TOPI_DECLARE_UNARY_OP(atan);
 TOPI_DECLARE_UNARY_OP(isnan);
+TOPI_DECLARE_UNARY_OP(tanh);
 
 /*
  * \brief Fast_tanh_float implementation from Eigen
@@ -113,9 +114,9 @@ inline Tensor fast_tanh_float(const Tensor& in,
 *
 * \return A Tensor whose op member is tanh
 */
-inline Tensor tanh(const Tensor& x,
-                   std::string name = "T_tanh",
-                   std::string tag = kElementWise) {
+inline Tensor fast_tanh(const Tensor& x,
+                        std::string name = "T_fast_tanh",
+                        std::string tag = kElementWise) {
   if (x->dtype == DataType::Float(32)) {
     // invoke fast_tanh_float implementation
     return fast_tanh_float(x, name, tag);
index 5b6b9ab..4a63c45 100644 (file)
@@ -467,3 +467,19 @@ def fast_exp(x):
         The result.
     """
     return cpp.fast_exp(x, x.dtype, tag.ELEMWISE)
+
+
+def fast_tanh(x):
+    """Take the hyperbolic tangent of input x using fast_tanh implementation
+
+    Parameters
+    ----------
+    x : tvm.Tensor
+        Input argument.
+
+    Returns
+    -------
+    y : tvm.Tensor
+        The result.
+    """
+    return cpp.fast_tanh(x, x.dtype, tag.ELEMWISE)
index 79e223c..75517b8 100644 (file)
@@ -188,7 +188,10 @@ TVM_REGISTER_GLOBAL("topi.tanh")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
   *rv = tanh(args[0]);
   });
-
+TVM_REGISTER_GLOBAL("topi.fast_tanh")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+  *rv = fast_tanh(args[0]);
+  });
 TVM_REGISTER_GLOBAL("topi.atan")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
   *rv = atan(args[0]);