return call;
}
+llvm::Function* CodeGenLLVM::GetIntrinsicDecl(
+ llvm::Intrinsic::ID id, llvm::Type* ret_type,
+ llvm::ArrayRef<llvm::Type*> arg_types) {
+ llvm::Module* module = module_.get();
+ // Look up the declaration of intrinsic `id` for a call with the given types.
+ if (!llvm::Intrinsic::isOverloaded(id)) {
+ return llvm::Intrinsic::getDeclaration(module, id, {});  // ret_type/arg_types are not checked on this path
+ }
+ // Overloaded: derive the overload types by matching the call's types against the descriptor table.
+ llvm::SmallVector<llvm::Intrinsic::IITDescriptor, 4> infos;
+ llvm::Intrinsic::getIntrinsicInfoTableEntries(id, infos);
+ llvm::SmallVector<llvm::Type*, 4> overload_types;  // handed to getDeclaration once a match is found
+ // With TVM_LLVM_VERSION >= 90 whole function types are matched; otherwise types are matched one by one.
+#if TVM_LLVM_VERSION >= 90
+ auto try_match = [&](llvm::FunctionType* f_ty, bool var_arg) {
+ overload_types.clear();  // drop whatever an earlier attempt collected
+ llvm::ArrayRef<llvm::Intrinsic::IITDescriptor> ref(infos);  // fresh view of the descriptors for this attempt
+ auto match =
+ llvm::Intrinsic::matchIntrinsicSignature(f_ty, ref, overload_types);
+ if (match == llvm::Intrinsic::MatchIntrinsicTypes_Match) {
+ bool error = llvm::Intrinsic::matchIntrinsicVarArg(var_arg, ref);
+ if (error) {
+ return llvm::Intrinsic::MatchIntrinsicTypes_NoMatchArg;  // a failed vararg check is reported as an argument mismatch
+ }
+ }
+ return match;
+ };
+
+ // First, try matching the signature assuming non-vararg case.
+ auto* fn_ty = llvm::FunctionType::get(ret_type, arg_types, false);
+ switch (try_match(fn_ty, false)) {
+ case llvm::Intrinsic::MatchIntrinsicTypes_NoMatchRet:
+ // The return type doesn't match, there is nothing else to do.
+ return nullptr;
+ case llvm::Intrinsic::MatchIntrinsicTypes_Match:
+ return llvm::Intrinsic::getDeclaration(module, id, overload_types);
+ case llvm::Intrinsic::MatchIntrinsicTypes_NoMatchArg:
+ break;  // continue with the vararg attempts below
+ }
+
+ // Try every prefix of arg_types (from the empty list up to the full list) as the
+ // fixed parameters of a vararg signature, and return the first one that matches.
+ llvm::SmallVector<llvm::Type*, 4> var_types;
+ for (int i = 0, e = arg_types.size(); i <= e; ++i) {
+ if (i > 0) var_types.push_back(arg_types[i - 1]);  // iteration 0 tests the empty parameter list
+ auto* ft = llvm::FunctionType::get(ret_type, var_types, true);
+ if (try_match(ft, true) == llvm::Intrinsic::MatchIntrinsicTypes_Match) {
+ return llvm::Intrinsic::getDeclaration(module, id, overload_types);
+ }
+ }
+ // Failed to identify the type.
+ return nullptr;
+
+#else // TVM_LLVM_VERSION
+ llvm::ArrayRef<llvm::Intrinsic::IITDescriptor> ref(infos);  // one view shared by all the calls below
+ // matchIntrinsicType returns true on error; the return type is matched first.
+ if (llvm::Intrinsic::matchIntrinsicType(ret_type, ref, overload_types)) {
+ return nullptr;
+ }
+ for (llvm::Type* t : arg_types) {  // then each argument type, in call order
+ if (llvm::Intrinsic::matchIntrinsicType(t, ref, overload_types)) {
+ return nullptr;
+ }
+ }
+ return llvm::Intrinsic::getDeclaration(module, id, overload_types);
+#endif // TVM_LLVM_VERSION
+}
+
llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
if (op->is_intrinsic("llvm_intrin")) {
CHECK_GE(op->args.size(), 2U);
Downcast<IntImm>(op->args[0])->value);
int64_t num_signature = Downcast<IntImm>(op->args[1])->value;
std::vector<llvm::Value*> arg_value;
- std::vector<llvm::Type*> sig_type;
+ std::vector<llvm::Type*> arg_type;
for (size_t i = 2; i < op->args.size(); ++i) {
arg_value.push_back(MakeValue(op->args[i]));
if (i - 2 < static_cast<size_t>(num_signature)) {
- sig_type.push_back(arg_value.back()->getType());
+ arg_type.push_back(arg_value.back()->getType());
}
}
- llvm::Type *return_type = GetLLVMType(GetRef<PrimExpr>(op));
- if (sig_type.size() > 0 && return_type != sig_type[0]) {
- sig_type.insert(sig_type.begin(), return_type);
- }
- llvm::Function* f = llvm::Intrinsic::getDeclaration(
- module_.get(), id, sig_type);
+ // LLVM's prefetch intrinsic returns "void", while TVM's prefetch
+ // returns int32. This causes problems because prefetch is one of
+ // those intrinsics that is generated automatically via the
+ // tvm.intrin.rule mechanism. Any other intrinsic with a type
+ // mismatch will have to be treated specially here.
+ // TODO(kparzysz-quic): fix this once TVM prefetch uses the same
+ // type as LLVM.
+ llvm::Type *return_type = (id != llvm::Intrinsic::prefetch)
+ ? GetLLVMType(GetRef<PrimExpr>(op))
+ : llvm::Type::getVoidTy(*ctx_);
+
+ llvm::Function* f = GetIntrinsicDecl(id, return_type, arg_type);
+ CHECK(f) << "Cannot find intrinsic declaration, possible type mismatch: "
+ << llvm::Intrinsic::getName(id, {});
return builder_->CreateCall(f, arg_value);
} else if (op->is_intrinsic(CallNode::bitwise_and)) {
return builder_->CreateAnd(MakeValue(op->args[0]), MakeValue(op->args[1]));
* \param type The corresponding TVM Type.
*/
llvm::Type* GetLLVMType(const PrimExpr& expr) const;
+ /*!
+ * \brief Get the declaration of the LLVM intrinsic based on the intrinsic
+ * id, and the type of the actual call.
+ *
+ * For a non-overloaded intrinsic the declaration is returned directly and
+ * the given types are not checked. For an overloaded intrinsic the overload
+ * types are derived by matching the given types against the intrinsic.
+ *
+ * \param id The intrinsic id.
+ * \param ret_type The call return type.
+ * \param arg_types The types of the call arguments.
+ *
+ * \return The llvm::Function pointer, or nullptr if the declaration
+ * could not be generated (e.g. if the argument/return types do not
+ * match).
+ */
+ llvm::Function* GetIntrinsicDecl(llvm::Intrinsic::ID id,
+ llvm::Type* ret_type,
+ llvm::ArrayRef<llvm::Type*> arg_types);
// initialize the function state.
void InitFuncState();
// Get alignment given index.
namespace llvm {
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.prefetch")
-.set_body(DispatchLLVMIntrin<::llvm::Intrinsic::prefetch, 0>);
+.set_body(DispatchLLVMIntrin<::llvm::Intrinsic::prefetch, 4>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>);
});
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.fma")
-.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 1>);
+.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>);
});
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.pow")
-.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 1>);
+.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.popcount")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>);
fcode = tvm.build(func, None, "llvm")
+def test_llvm_overloaded_intrin():
+ # Name lookup for overloaded intrinsics in LLVM 4 and older requires a name
+ # that includes the overloaded types, so the test is skipped for those versions.
+ if tvm.target.codegen.llvm_version_major() < 5:
+ return
+ # Body of the extern op: load A[0, 0], call llvm.ctlz on it, and store the result to C[0, 0].
+ def use_llvm_intrinsic(A, C):
+ ib = tvm.tir.ir_builder.create()
+ L = A.vload((0,0))
+ I = tvm.tir.call_llvm_intrin('int32', 'llvm.ctlz',  # intrinsic name is given without a type suffix
+ tvm.tir.const(2, 'uint32'), L, tvm.tir.const(0, 'int1'))  # presumably 2 = number of signature arguments (L and the int1 flag) -- confirm against call_llvm_intrin
+ S = C.vstore((0,0), I)
+ ib.emit(S)
+ return ib.get()
+
+ A = tvm.te.placeholder((1,1), dtype = 'int32', name = 'A')
+ C = tvm.te.extern((1,1), [A],
+ lambda ins, outs: use_llvm_intrinsic(ins[0], outs[0]),
+ name = 'C' , dtype = 'int32')
+
+ s = tvm.te.create_schedule(C.op)
+ f = tvm.build(s, [A, C], target = 'llvm')  # a successful build is the whole check; f is neither run nor inspected
+
+
def test_llvm_import():
# extern "C" is necessary to get the correct signature
cc_code = """
def test_llvm_lookup_intrin():
ib = tvm.tir.ir_builder.create()
- m = te.size_var("m")
A = ib.pointer("uint8x8", name="A")
- x = tvm.tir.call_llvm_intrin("uint8x8", "llvm.ctpop.i8", tvm.tir.const(1, 'uint32'), A)
+ z = tvm.tir.const(0, 'int32')
+ x = tvm.tir.call_llvm_intrin("uint8x8", "llvm.ctpop.v8i8", tvm.tir.const(1, 'uint32'), A[z])
ib.emit(x)
body = ib.get()
func = tvm.testing.MakeAPILegacy(body, "ctpop", [A], 1, True)
test_llvm_vadd_pipeline()
test_llvm_add_pipeline()
test_llvm_intrin()
+ test_llvm_overloaded_intrin()
test_llvm_flip_pipeline()
test_llvm_madd_pipeline()
test_llvm_temp_space()
ww, xx = ins
zz = outs[0]
- args_1 = tvm.tir.const(1, 'uint32')
args_2 = tvm.tir.const(2, 'uint32')
if unipolar:
cnts8[i] = upper_half + lower_half
for i in range(m//2):
cnts4[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd,
- args_1, cnts8[i*2], cnts8[i*2+1])
+ args_2, cnts8[i*2], cnts8[i*2+1])
for i in range(m//4):
cnts2[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd,
- args_1, cnts4[i*2], cnts4[i*2+1])
+ args_2, cnts4[i*2], cnts4[i*2+1])
cnts = tvm.tir.call_pure_intrin(
full_dtype, 'vectorcombine', cnts2[0], cnts2[1])
shifted_cnts = cnts << tvm.tir.const(bw+bx, pack_dtype)
cnts8[i] = tvm.tir.popcount(w_ & x_)
for i in range(m//2):
cnts4[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd,
- args_1, cnts8[i*2], cnts8[i*2+1])
+ args_2, cnts8[i*2], cnts8[i*2+1])
for i in range(m//4):
cnts2[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd,
- args_1, cnts4[i*2], cnts4[i*2+1])
+ args_2, cnts4[i*2], cnts4[i*2+1])
cnts = tvm.tir.call_pure_intrin(
full_dtype, 'vectorcombine', cnts2[0], cnts2[1])
shifted_cnts = cnts << tvm.tir.const(bw+bx, pack_dtype)