Fix missing sigmoid intrinsic in C++ (#2231)
authorSergey Mironov <grrwlf@gmail.com>
Fri, 7 Dec 2018 14:43:28 +0000 (17:43 +0300)
committerTianqi Chen <tqchen@users.noreply.github.com>
Fri, 7 Dec 2018 14:43:28 +0000 (09:43 -0500)
python/tvm/intrin.py
src/codegen/intrin_rule.cc

index 3207b6112b1d40b92b01501183776691a54a25be..cd9a108c546aedabf825f15d4d3ba25b566f14be 100644 (file)
@@ -492,6 +492,3 @@ def _rule_float_direct(op):
 register_intrin_rule("opencl", "exp", _rule_float_direct, override=True)
 # default pattern for exp
 register_intrin_rule("default", "exp", _rule_float_suffix, override=True)
-
-# default pattern for sigmoid
-register_intrin_rule("default", "sigmoid", lambda op: 1.0 / (1.0 + exp(-op.args[0])))
index 822d515fb8a54569ae45abd8d4e2045f7580b7fd..f326fceb6ee834907a8b7940c5708143f3ab309f 100644 (file)
@@ -24,6 +24,16 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt")
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.pow")
 .set_body(DispatchExtern<FloatSuffix>);
 
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sigmoid")
+.set_body([](const TVMArgs& args, TVMRetValue* rv){
+    Expr e = args[0];
+    const Call* call = e.as<Call>();
+    CHECK(call != nullptr);
+
+    auto one = make_const(call->args[0].type(), 1);
+    *rv = one / (one + exp(-call->args[0]));
+  });
+
 }  // namespace intrin
 }  // namespace codegen
 }  // namespace tvm