TVM_DLL Pass SimplifyInference();
/*!
+ * \brief Replaces non linear activation functions with their fast but approximate counterparts.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass FastMath();
+
+/*!
* \brief Infer the type of an expression.
*
* The result of type checking is a new expression with unambigous
"CanonicalizeCast": 3,
"EliminateCommonSubexpr": 3,
"CombineParallelConv2D": 4,
- "CombineParallelDense": 4
+ "CombineParallelDense": 4,
+ "FastMath": 4
}
fallback_device : int, str, or tvmContext, optional
Returns
-------
ret: tvm.relay.Pass
- The registered to perform operator simplification.
+ The registered pass to perform operator simplification.
"""
return _transform.SimplifyInference()
+def FastMath():
+ """ Converts the expensive non linear functions to their fast but approximate counterparts.
+
+ Returns
+ -------
+ ret: tvm.relay.Pass
+ The registered pass to perform fast math operations.
+ """
+ return _transform.FastMath()
+
+
def CanonicalizeOps():
"""Canonicalize special operators to basic operators.
This can simplify followed analysis, e.g. expanding bias_add to
if (targets.size() == 1) {
pass_seqs.push_back(transform::AlterOpLayout());
}
+
+ // Fast math optimizations.
+ pass_seqs.push_back(transform::FastMath());
pass_seqs.push_back(transform::FoldConstant());
// Create a sequential pass and perform optimizations.
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::exp));
+RELAY_REGISTER_UNARY_OP("fast_exp")
+.describe(R"code(Returns the fast_exp input array, computed element-wise.
+
+.. math::
+ \fast_exp(x)
+
+)code" TVM_ADD_FILELINE)
+.set_support_level(1)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_exp));
+
+
RELAY_REGISTER_UNARY_OP("erf")
.describe(R"code(Returns the error function value for input array, computed element-wise.
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tanh));
+RELAY_REGISTER_UNARY_OP("fast_tanh")
+.describe(R"code(Returns the fast_tanh of input array, computed element-wise.
+
+.. math::
+ Y = sinh(X) / cosh(X)
+
+)code" TVM_ADD_FILELINE)
+.set_support_level(1)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_tanh));
+
+
RELAY_REGISTER_UNARY_OP("negative")
.describe(R"code(Returns the numeric negative of input array, computed element-wise.
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file fast_math.cc
+ * \brief Replaces non linear activation functions with their fast but approximate counterparts.
+ */
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/transform.h>
+#include <tvm/relay/op.h>
+#include "pattern_util.h"
+
+namespace tvm {
+namespace relay {
+
+class FastMathMutator : public ExprMutator {
+ public:
+ FastMathMutator()
+ : exp_op_(Op::Get("exp")),
+ tanh_op_(Op::Get("tanh")) {}
+
+ Expr VisitExpr_(const CallNode* n) {
+ auto new_n = ExprMutator::VisitExpr_(n);
+ if (n->op == exp_op_) {
+ return FastExp(new_n.as<CallNode>()->args[0]);
+ } else if (n->op == tanh_op_) {
+ return FastTanh(new_n.as<CallNode>()->args[0]);
+ }
+ return new_n;
+ }
+
+ private:
+ // Cache the following ops. They will be used in the passes repeatedly for
+ // operator equivalence checking so that the registry lookup overhead can be
+ // reduced.
+ const Op& exp_op_;
+ const Op& tanh_op_;
+};
+
+Expr FastMath(const Expr& e) {
+ return FastMathMutator().Mutate(e);
+}
+
+namespace transform {
+
+Pass FastMath() {
+ runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+ [=](Function f, IRModule m, PassContext pc) {
+ return Downcast<Function>(FastMath(f));
+ };
+ return CreateFunctionPass(pass_func, 4, "FastMath",
+ {tir::StringImmNode::make("InferType")});
+}
+
+TVM_REGISTER_GLOBAL("relay._transform.FastMath")
+.set_body_typed(FastMath);
+
+} // namespace transform
+
+} // namespace relay
+} // namespace tvm
return CallNode::make(op, {e});
}
+inline Expr FastExp(Expr e) {
+ static const Op& op = Op::Get("fast_exp");
+ return CallNode::make(op, {e});
+}
+
+inline Expr FastTanh(Expr e) {
+ static const Op& op = Op::Get("fast_tanh");
+ return CallNode::make(op, {e});
+}
+
inline Expr Log(Expr e) {
static const Op& op = Op::Get("log");
return CallNode::make(op, {e});
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm.ir import IRModule
+from tvm import relay
+from tvm.relay.transform import FastMath
+
+def test_exp():
+ x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32")
+ y = relay.exp(x)
+ func = relay.Function([x], y)
+ mod = tvm.IRModule.from_expr(func)
+
+ fast_mod = FastMath()(mod)
+ assert "fast_exp" in fast_mod.astext()
+
+ # Check that FastMath option works for relay.build.
+ with relay.build_config(opt_level=3, required_pass=['FastMath']):
+ fast_mod = relay.optimize(mod, target='llvm', params=None)
+ assert "fast_exp" in fast_mod[0].astext()
+
+def test_tanh():
+ x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32")
+ y = relay.tanh(x)
+ func = relay.Function([x], y)
+ mod = tvm.IRModule.from_expr(func)
+
+ fast_mod = FastMath()(mod)
+ assert "fast_tanh" in fast_mod.astext()
+
+ # Check that FastMath option works for relay.build.
+ with relay.build_config(opt_level=3, required_pass=['FastMath']):
+ fast_mod = relay.optimize(mod, target='llvm', params=None)
+ assert "fast_tanh" in fast_mod[0].astext()
+
+if __name__ == "__main__":
+ test_exp()
+ test_tanh()
TOPI_DECLARE_UNARY_OP(sin);
TOPI_DECLARE_UNARY_OP(atan);
TOPI_DECLARE_UNARY_OP(isnan);
+TOPI_DECLARE_UNARY_OP(tanh);
/*
* \brief Fast_tanh_float implementation from Eigen
*
* \return A Tensor whose op member is tanh
*/
-inline Tensor tanh(const Tensor& x,
- std::string name = "T_tanh",
- std::string tag = kElementWise) {
+inline Tensor fast_tanh(const Tensor& x,
+ std::string name = "T_fast_tanh",
+ std::string tag = kElementWise) {
if (x->dtype == DataType::Float(32)) {
// invoke fast_tanh_float implementation
return fast_tanh_float(x, name, tag);
The result.
"""
return cpp.fast_exp(x, x.dtype, tag.ELEMWISE)
+
+
+def fast_tanh(x):
+ """Take tanhonential of input x using fast_tanh implementation
+
+ Parameters
+ ----------
+ x : tvm.Tensor
+ Input argument.
+
+ Returns
+ -------
+ y : tvm.Tensor
+ The result.
+ """
+ return cpp.fast_tanh(x, x.dtype, tag.ELEMWISE)
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = tanh(args[0]);
});
-
+TVM_REGISTER_GLOBAL("topi.fast_tanh")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+ *rv = fast_tanh(args[0]);
+ });
TVM_REGISTER_GLOBAL("topi.atan")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = atan(args[0]);