From 490510d463bd760fc474f956a3098b88ebf7260a Mon Sep 17 00:00:00 2001
From: Thomas Viehmann
Date: Thu, 4 Jun 2020 17:46:55 +0200
Subject: [PATCH] codegen llvm: move nvptx-specific intrinsic handling into
 codegen_nvptx (#5726)

See discussion in #5600.

I'm also throwing in a pointer lifetime fix for the context held by
NVPTX, because otherwise topi/tests/python/test_topi_softmax.py would
segfault for me.

With the fix, I can also run resnet-18 on the nvptx target in
gpu_imagenet_bench.py.
---
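A note on the lifetime fix: C++ destroys locals in reverse order of
declaration, so an object that keeps a naked pointer into another local
must be declared after that local. A minimal standalone sketch of the
pattern (hypothetical names, not the TVM classes):

    #include <memory>

    struct Context {};

    struct CodeGen {
      Context* ctx_ = nullptr;  // naked pointer; CodeGen does not own it
      void Init(Context* ctx) { ctx_ = ctx; }
      ~CodeGen() { /* may still touch *ctx_ during teardown */ }
    };

    int main() {
      std::unique_ptr<Context> ctx(new Context());
      // Declared after ctx, so cg is destroyed first and its destructor
      // never sees a dangling ctx_.
      std::unique_ptr<CodeGen> cg(new CodeGen());
      cg->Init(ctx.get());
      return 0;
    }
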
 src/target/llvm/codegen_llvm.cc  | 52 ---------------------------------------
 src/target/llvm/codegen_nvptx.cc | 62 +++++++++++++++++++++++++++++++++++++++-
 topi/python/topi/cuda/softmax.py |  3 +++
 3 files changed, 64 insertions(+), 53 deletions(-)

diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index 5c2c41a..b43e988 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -736,40 +736,7 @@ llvm::Function* CodeGenLLVM::GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type
 #endif  // TVM_LLVM_VERSION
 }
 
-// Check if this is a warp shuffle intrinsic call and match its
-// corresponding nvvm intrinsic. Return true if the match is successful.
-static bool GetWarpShuffleIntrinsic(const CallNode* op, llvm::Intrinsic::ID* id) {
-  // Only 32 bit data type is supported.
-  if (op->dtype.is_vector() || op->dtype.bits() != 32) {
-    return false;
-  }
-
-  // Intrinsic lookup table.
-  // It is difficult to emit a _sync version that works on Pascal.
-  // We ignore the mask and only emit the non-sync version for nvptx.
-  llvm::Intrinsic::ID ids[] = {
-      llvm::Intrinsic::nvvm_shfl_idx_i32,  llvm::Intrinsic::nvvm_shfl_idx_f32,
-      llvm::Intrinsic::nvvm_shfl_up_i32,   llvm::Intrinsic::nvvm_shfl_up_f32,
-      llvm::Intrinsic::nvvm_shfl_down_i32, llvm::Intrinsic::nvvm_shfl_down_f32};
-
-  int offset = 0;
-  if (op->is_intrinsic(intrinsic::tvm_warp_shuffle)) {
-    offset = 0;
-  } else if (op->is_intrinsic(intrinsic::tvm_warp_shuffle_up)) {
-    offset = 2;
-  } else if (op->is_intrinsic(intrinsic::tvm_warp_shuffle_down)) {
-    offset = 4;
-  } else {
-    return false;
-  }
-
-  *id = ids[offset + op->dtype.is_float()];
-  return true;
-}
-
 llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
-  llvm::Intrinsic::ID id = llvm::Intrinsic::not_intrinsic;
-
   if (op->is_intrinsic("llvm_intrin")) {
     CHECK_GE(op->args.size(), 2U);
     llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value);
@@ -814,25 +781,6 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
     }
   } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
     return CreateStorageSync(op);
-  } else if (GetWarpShuffleIntrinsic(op, &id)) {
-    std::vector<llvm::Value*> arg_value;
-    std::vector<llvm::Type*> arg_type;
-    // Ignore the first mask operand and remove the last
-    // redundant warp_size.
-    size_t n_args = op->args.size() - 1;
-    for (size_t i = 1; i < n_args; ++i) {
-      arg_value.push_back(MakeValue(op->args[i]));
-      arg_type.push_back(arg_value.back()->getType());
-    }
-    llvm::Type* return_type = arg_type[0];
-    llvm::Function* func = GetIntrinsicDecl(id, return_type, arg_type);
-    return builder_->CreateCall(func, arg_value);
-  } else if (op->is_intrinsic(intrinsic::tvm_warp_activemask)) {
-    // Only nvptx target may keep this intrinsic at this point.
-    // PTX assembly: asm "activemask.b32 r1;"
-    auto fty = llvm::FunctionType::get(t_int32_, false);
-    auto val = llvm::InlineAsm::get(fty, "activemask.b32 %0", "=r", true);
-    return builder_->CreateCall(val);
   } else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
     const LoadNode* l = op->args[0].as<LoadNode>();
     CHECK(op->args.size() == 1 && l);
diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc
index a0687b9..353f322 100644
--- a/src/target/llvm/codegen_nvptx.cc
+++ b/src/target/llvm/codegen_nvptx.cc
@@ -170,6 +170,8 @@ class CodeGenNVPTX : public CodeGenLLVM {
     CodeGenLLVM::Optimize();
   }
 
+  llvm::Value* CreateIntrinsic(const CallNode* op) override;
+
  protected:
   void InitTarget(llvm::TargetMachine* tm) final {
     // Maximum vector lane = float4
@@ -178,6 +180,62 @@ class CodeGenNVPTX : public CodeGenLLVM {
   }
 };
 
+// Check if this is a warp shuffle intrinsic call and match its
+// corresponding nvvm intrinsic. Return true if the match is successful.
+static bool GetWarpShuffleIntrinsic(const CallNode* op, llvm::Intrinsic::ID* id) {
+  // Only 32 bit data type is supported.
+  if (op->dtype.is_vector() || op->dtype.bits() != 32) {
+    return false;
+  }
+
+  // Intrinsic lookup table.
+  // It is difficult to emit a _sync version that works on Pascal.
+  // We ignore the mask and only emit the non-sync version for nvptx.
+  llvm::Intrinsic::ID ids[] = {
+      llvm::Intrinsic::nvvm_shfl_idx_i32,  llvm::Intrinsic::nvvm_shfl_idx_f32,
+      llvm::Intrinsic::nvvm_shfl_up_i32,   llvm::Intrinsic::nvvm_shfl_up_f32,
+      llvm::Intrinsic::nvvm_shfl_down_i32, llvm::Intrinsic::nvvm_shfl_down_f32};
+
+  int offset = 0;
+  if (op->is_intrinsic(intrinsic::tvm_warp_shuffle)) {
+    offset = 0;
+  } else if (op->is_intrinsic(intrinsic::tvm_warp_shuffle_up)) {
+    offset = 2;
+  } else if (op->is_intrinsic(intrinsic::tvm_warp_shuffle_down)) {
+    offset = 4;
+  } else {
+    return false;
+  }
+
+  *id = ids[offset + op->dtype.is_float()];
+  return true;
+}
+
+llvm::Value* CodeGenNVPTX::CreateIntrinsic(const CallNode* op) {
+  llvm::Intrinsic::ID id = llvm::Intrinsic::not_intrinsic;
+  if (GetWarpShuffleIntrinsic(op, &id)) {
+    std::vector<llvm::Value*> arg_value;
+    std::vector<llvm::Type*> arg_type;
+    // Ignore the first mask operand and remove the last
+    // redundant warp_size.
+    size_t n_args = op->args.size() - 1;
+    for (size_t i = 1; i < n_args; ++i) {
+      arg_value.push_back(MakeValue(op->args[i]));
+      arg_type.push_back(arg_value.back()->getType());
+    }
+    llvm::Type* return_type = arg_type[0];
+    llvm::Function* func = GetIntrinsicDecl(id, return_type, arg_type);
+    return builder_->CreateCall(func, arg_value);
+  } else if (op->is_intrinsic(intrinsic::tvm_warp_activemask)) {
+    // Only nvptx target may keep this intrinsic at this point.
+    // PTX assembly: asm "activemask.b32 r1;"
+    auto fty = llvm::FunctionType::get(t_int32_, false);
+    auto val = llvm::InlineAsm::get(fty, "activemask.b32 %0", "=r", true);
+    return builder_->CreateCall(val);
+  }
+  return CodeGenLLVM::CreateIntrinsic(op);
+}
+
 inline int DetectCUDAComputeVersion() {
   TVMContext tvm_ctx;
   tvm_ctx.device_type = kDLGPU;
@@ -204,8 +262,10 @@ runtime::Module BuildNVPTX(IRModule mod, std::string target) {
   config << "-mtriple=nvptx64-nvidia-cuda -mcpu=sm_" << compute_ver
          << target.substr(5, target.length() - 5);
   std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(config.str());
-  std::unique_ptr<CodeGenNVPTX> cg(new CodeGenNVPTX());
   std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
+  // careful: cg will hold a naked pointer reference to ctx, so it should
+  // have a shorter lifetime than the ctx.
+  std::unique_ptr<CodeGenNVPTX> cg(new CodeGenNVPTX());
 
   cg->Init("TVMPTXModule", tm.get(), ctx.get(), false, false);
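A detail worth spelling out in the moved code: GetWarpShuffleIntrinsic
selects the nvvm intrinsic by arithmetic over the lookup table, which
works because the i32/f32 variants of each shuffle kind occupy adjacent
slots. The offset picks the shuffle kind in steps of two, and
op->dtype.is_float() adds 0 or 1 to pick the column. A minimal sketch of
the same indexing trick (hypothetical strings standing in for the
llvm::Intrinsic::ID values):

    #include <cassert>
    #include <string>

    enum Kind { kIdx = 0, kUp = 2, kDown = 4 };  // row offsets into the table

    static const std::string ids[] = {
        "shfl.idx.i32",  "shfl.idx.f32",
        "shfl.up.i32",   "shfl.up.f32",
        "shfl.down.i32", "shfl.down.f32",
    };

    std::string Lookup(Kind kind, bool is_float) {
      // bool converts to 0 or 1, selecting the integer or float column.
      return ids[kind + static_cast<int>(is_float)];
    }

    int main() {
      assert(Lookup(kIdx, false) == "shfl.idx.i32");
      assert(Lookup(kUp, true) == "shfl.up.f32");
      return 0;
    }
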
diff --git a/topi/python/topi/cuda/softmax.py b/topi/python/topi/cuda/softmax.py
index 6142f48..50e2b0d 100644
--- a/topi/python/topi/cuda/softmax.py
+++ b/topi/python/topi/cuda/softmax.py
@@ -60,6 +60,9 @@ def schedule_softmax(outs):
     def sched_warp_softmax():
         if tgt.target_name == "nvptx":
             return softmax.dtype == "float32" or softmax.dtype == "int32"
+        if tgt.target_name != "cuda":
+            # this is used as the gpu schedule for other arches which may not have warp reductions
+            return False
         return True
 
     if len(softmax.shape) > 2:
-- 
2.7.4
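Taken together, the C++ side of this patch is a specialize-then-delegate
refactor: CodeGenNVPTX now intercepts only the intrinsics it can lower
(warp shuffles, activemask) and hands everything else back to
CodeGenLLVM::CreateIntrinsic. A minimal sketch of that dispatch shape
(hypothetical types, not the actual TVM hierarchy):

    #include <iostream>

    struct CodeGenBase {
      virtual ~CodeGenBase() = default;
      virtual const char* CreateIntrinsic(int op) {
        return "generic lowering";  // target-independent cases live here
      }
    };

    struct CodeGenTarget : CodeGenBase {
      const char* CreateIntrinsic(int op) override {
        if (op == 42) return "target-specific lowering";  // intercepted here
        return CodeGenBase::CreateIntrinsic(op);  // everything else falls through
      }
    };

    int main() {
      CodeGenTarget cg;
      std::cout << cg.CreateIntrinsic(42) << "\n";  // target-specific lowering
      std::cout << cg.CreateIntrinsic(7) << "\n";   // generic lowering
      return 0;
    }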