update rocm intrin rule (#4499)
authorPeter Yeh <petrex@users.noreply.github.com>
Wed, 11 Dec 2019 09:14:36 +0000 (01:14 -0800)
committermasahi <masahi129@gmail.com>
Wed, 11 Dec 2019 09:14:36 +0000 (18:14 +0900)
src/codegen/llvm/intrin_rule_rocm.cc

index aaaf0ca..5ad5261 100644 (file)
@@ -22,7 +22,6 @@
  */
 #ifdef TVM_LLVM_VERSION
 
-#include "intrin_rule_llvm.h"
 #include <tvm/ir.h>
 #include <tvm/expr.h>
 #include <tvm/api_registry.h>
@@ -45,27 +44,28 @@ inline void DispatchExternOCML(const TVMArgs& args, TVMRetValue* rv) {
 namespace llvm {
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.floor")
-.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::floor, 1>);
+.set_body(DispatchExternOCML);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.ceil")
-.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>);
+.set_body(DispatchExternOCML);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.round")
-.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>);
+.set_body(DispatchExternOCML);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.trunc")
-.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>);
+.set_body(DispatchExternOCML);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fabs")
-.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1>);
+.set_body(DispatchExternOCML);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp")
 .set_body(DispatchExternOCML);
 
-// On AMD GPU, fma is slower than mac
-// removing fma dispatch allows backend to generate faster mac instruction
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.erf")
+.set_body(DispatchExternOCML);
+
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fma")
-.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 1>);
+.set_body(DispatchExternOCML);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log")
 .set_body(DispatchExternOCML);
@@ -78,6 +78,16 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.pow")
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tanh")
 .set_body(DispatchExternOCML);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cos")
+.set_body(DispatchExternOCML);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sin")
+.set_body(DispatchExternOCML);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.atan")
+.set_body(DispatchExternOCML);
+
 }  // namespace llvm
 }  // namespace codegen
 }  // namespace tvm