register_intrin_rule("opencl", "exp", _rule_float_direct, override=True)
# default pattern for exp
register_intrin_rule("default", "exp", _rule_float_suffix, override=True)
-
-# default pattern for sigmoid
-register_intrin_rule("default", "sigmoid", lambda op: 1.0 / (1.0 + exp(-op.args[0])))
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.pow")
.set_body(DispatchExtern<FloatSuffix>);
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sigmoid")
+.set_body([](const TVMArgs& args, TVMRetValue* rv){
+ Expr e = args[0];
+ const Call* call = e.as<Call>();
+ CHECK(call != nullptr);
+
+ auto one = make_const(call->args[0].type(), 1);
+ *rv = one / (one + exp(-call->args[0]));
+ });
+
} // namespace intrin
} // namespace codegen
} // namespace tvm