string llvmBuilder = "$res = createIntrinsicCall(builder,"
# "llvm::Intrinsic::nvvm_" # !subst(".","_", mnemonic) # ");";
let parser = [{ return parseNVVMSpecialRegisterOp(parser, result); }];
- let printer = [{ printNVVMSpecialRegisterOp(p, this->getOperation()); }];
+ let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }];
}
def NVVM_ThreadIdXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.tid.x">;
def NVVM_GridDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.y">;
def NVVM_GridDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.z">;
+def NVVM_Barrier0Op : NVVM_Op<"barrier0"> {
+ string llvmBuilder = [{
+ createIntrinsicCall(builder, llvm::Intrinsic::nvvm_barrier0);
+ }];
+ let parser = [{ return success(); }];
+ let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }];
+}
+
+def NVVM_ShflBflyOp :
+ NVVM_Op<"shfl.sync.bfly">,
+ Results<(outs LLVM_Type:$res)>,
+ Arguments<(ins LLVM_Type:$dst,
+ LLVM_Type:$val,
+ LLVM_Type:$offset,
+ LLVM_Type:$mask_and_clamp)> {
+ string llvmBuilder = [{
+ auto intId = $val->getType()->isFloatTy() ?
+ llvm::Intrinsic::nvvm_shfl_sync_bfly_f32 :
+ llvm::Intrinsic::nvvm_shfl_sync_bfly_i32;
+ $res = createIntrinsicCall(builder,
+ intId, {$dst, $val, $offset, $mask_and_clamp});
+ }];
+ let parser = [{ return parseNVVMShflSyncBflyOp(parser, result); }];
+ let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }];
+}
+
+def NVVM_VoteBallotOp :
+ NVVM_Op<"vote.ballot.sync">,
+ Results<(outs LLVM_Type:$res)>,
+ Arguments<(ins LLVM_Type:$mask, LLVM_Type:$pred)> {
+ string llvmBuilder = [{
+ $res = createIntrinsicCall(builder,
+ llvm::Intrinsic::nvvm_vote_ballot_sync, {$mask, $pred});
+ }];
+ let parser = [{ return parseNVVMVoteBallotOp(parser, result); }];
+ let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }];
+}
+
#endif // NVVMIR_OPS
// Printing/parsing for NVVM ops
//===----------------------------------------------------------------------===//
-static void printNVVMSpecialRegisterOp(OpAsmPrinter *p, Operation *op) {
- *p << op->getName() << " : ";
- if (op->getNumResults() == 1) {
- *p << op->getResult(0)->getType();
- } else {
- *p << "###invalid type###";
- }
+static void printNVVMIntrinsicOp(OpAsmPrinter *p, Operation *op) {
+ *p << op->getName() << " ";
+ p->printOperands(op->getOperands());
+ if (op->getNumResults() > 0)
+ interleaveComma(op->getResultTypes(), *p << " : ");
}
// <operation> ::= `llvm.nvvm.XYZ` : type
return success();
}
+static LLVM::LLVMDialect *getLlvmDialect(OpAsmParser *parser) {
+ return parser->getBuilder()
+ .getContext()
+ ->getRegisteredDialect<LLVM::LLVMDialect>();
+}
+
+// <operation> ::=
+// `llvm.nvvm.shfl.sync.bfly %dst, %val, %offset, %clamp_and_mask`
+// : result_type
+static ParseResult parseNVVMShflSyncBflyOp(OpAsmParser *parser,
+ OperationState *result) {
+ auto llvmDialect = getLlvmDialect(parser);
+ auto int32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
+
+ SmallVector<OpAsmParser::OperandType, 8> ops;
+ Type type;
+ return failure(parser->parseOperandList(ops) ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
+ parser->parseColonType(type) ||
+ parser->addTypeToList(type, result->types) ||
+ parser->resolveOperands(ops, {int32Ty, type, int32Ty, int32Ty},
+ parser->getNameLoc(),
+ result->operands));
+}
+
+// <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
+static ParseResult parseNVVMVoteBallotOp(OpAsmParser *parser,
+ OperationState *result) {
+ auto llvmDialect = getLlvmDialect(parser);
+ auto int32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
+ auto int1Ty = LLVM::LLVMType::getInt1Ty(llvmDialect);
+
+ SmallVector<OpAsmParser::OperandType, 8> ops;
+ Type type;
+ return failure(parser->parseOperandList(ops) ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
+ parser->parseColonType(type) ||
+ parser->addTypeToList(type, result->types) ||
+ parser->resolveOperands(ops, {int32Ty, int1Ty},
+ parser->getNameLoc(),
+ result->operands));
+}
+
//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
namespace {
static llvm::Value *createIntrinsicCall(llvm::IRBuilder<> &builder,
- llvm::Intrinsic::ID intrinsic) {
+ llvm::Intrinsic::ID intrinsic,
+ ArrayRef<llvm::Value *> args = {}) {
llvm::Module *module = builder.GetInsertBlock()->getModule();
- llvm::Function *fn = llvm::Intrinsic::getDeclaration(module, intrinsic, {});
- return builder.CreateCall(fn);
+ llvm::Function *fn = llvm::Intrinsic::getDeclaration(module, intrinsic);
+ return builder.CreateCall(fn, args);
}
class ModuleTranslation : public LLVM::ModuleTranslation {
// RUN: mlir-opt %s | FileCheck %s
func @nvvm_special_regs() -> !llvm.i32 {
- // CHECK: %0 = nvvm.read.ptx.sreg.tid.x : !llvm.i32
+ // CHECK: nvvm.read.ptx.sreg.tid.x : !llvm.i32
%0 = nvvm.read.ptx.sreg.tid.x : !llvm.i32
- // CHECK: %1 = nvvm.read.ptx.sreg.tid.y : !llvm.i32
+ // CHECK: nvvm.read.ptx.sreg.tid.y : !llvm.i32
%1 = nvvm.read.ptx.sreg.tid.y : !llvm.i32
- // CHECK: %2 = nvvm.read.ptx.sreg.tid.z : !llvm.i32
+ // CHECK: nvvm.read.ptx.sreg.tid.z : !llvm.i32
%2 = nvvm.read.ptx.sreg.tid.z : !llvm.i32
- // CHECK: %3 = nvvm.read.ptx.sreg.ntid.x : !llvm.i32
+ // CHECK: nvvm.read.ptx.sreg.ntid.x : !llvm.i32
%3 = nvvm.read.ptx.sreg.ntid.x : !llvm.i32
- // CHECK: %4 = nvvm.read.ptx.sreg.ntid.y : !llvm.i32
+ // CHECK: nvvm.read.ptx.sreg.ntid.y : !llvm.i32
%4 = nvvm.read.ptx.sreg.ntid.y : !llvm.i32
- // CHECK: %5 = nvvm.read.ptx.sreg.ntid.z : !llvm.i32
+ // CHECK: nvvm.read.ptx.sreg.ntid.z : !llvm.i32
%5 = nvvm.read.ptx.sreg.ntid.z : !llvm.i32
- // CHECK: %6 = nvvm.read.ptx.sreg.ctaid.x : !llvm.i32
+ // CHECK: nvvm.read.ptx.sreg.ctaid.x : !llvm.i32
%6 = nvvm.read.ptx.sreg.ctaid.x : !llvm.i32
- // CHECK: %7 = nvvm.read.ptx.sreg.ctaid.y : !llvm.i32
+ // CHECK: nvvm.read.ptx.sreg.ctaid.y : !llvm.i32
%7 = nvvm.read.ptx.sreg.ctaid.y : !llvm.i32
- // CHECK: %8 = nvvm.read.ptx.sreg.ctaid.z : !llvm.i32
+ // CHECK: nvvm.read.ptx.sreg.ctaid.z : !llvm.i32
%8 = nvvm.read.ptx.sreg.ctaid.z : !llvm.i32
- // CHECK: %9 = nvvm.read.ptx.sreg.nctaid.x : !llvm.i32
+ // CHECK: nvvm.read.ptx.sreg.nctaid.x : !llvm.i32
%9 = nvvm.read.ptx.sreg.nctaid.x : !llvm.i32
- // CHECK: %10 = nvvm.read.ptx.sreg.nctaid.y : !llvm.i32
+ // CHECK: nvvm.read.ptx.sreg.nctaid.y : !llvm.i32
%10 = nvvm.read.ptx.sreg.nctaid.y : !llvm.i32
- // CHECK: %11 = nvvm.read.ptx.sreg.nctaid.z : !llvm.i32
+ // CHECK: nvvm.read.ptx.sreg.nctaid.z : !llvm.i32
%11 = nvvm.read.ptx.sreg.nctaid.z : !llvm.i32
llvm.return %0 : !llvm.i32
}
+
+func @llvm.nvvm.barrier0() {
+ // CHECK: nvvm.barrier0
+ nvvm.barrier0
+ llvm.return
+}
+
+func @nvvm_shfl(
+ %arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !llvm.i32,
+ %arg3 : !llvm.i32, %arg4 : !llvm.float) -> !llvm.i32 {
+ // CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i32
+ %0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 : !llvm.i32
+ // CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm.float
+ %1 = nvvm.shfl.sync.bfly %arg0, %arg4, %arg1, %arg2 : !llvm.float
+ llvm.return %0 : !llvm.i32
+}
+
+func @nvvm_vote(%arg0 : !llvm.i32, %arg1 : !llvm.i1) -> !llvm.i32 {
+ // CHECK: nvvm.vote.ballot.sync %{{.*}}, %{{.*}} : !llvm.i32
+ %0 = nvvm.vote.ballot.sync %arg0, %arg1 : !llvm.i32
+ llvm.return %0 : !llvm.i32
+}
// RUN: mlir-translate -mlir-to-nvvmir %s | FileCheck %s
func @nvvm_special_regs() -> !llvm.i32 {
- // CHECK: %1 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+ // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%1 = nvvm.read.ptx.sreg.tid.x : !llvm.i32
- // CHECK: %2 = call i32 @llvm.nvvm.read.ptx.sreg.tid.y()
+ // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.tid.y()
%2 = nvvm.read.ptx.sreg.tid.y : !llvm.i32
- // CHECK: %3 = call i32 @llvm.nvvm.read.ptx.sreg.tid.z()
+ // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.tid.z()
%3 = nvvm.read.ptx.sreg.tid.z : !llvm.i32
- // CHECK: %4 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
+ // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
%4 = nvvm.read.ptx.sreg.ntid.x : !llvm.i32
- // CHECK: %5 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
+ // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
%5 = nvvm.read.ptx.sreg.ntid.y : !llvm.i32
- // CHECK: %6 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.z()
+ // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.ntid.z()
%6 = nvvm.read.ptx.sreg.ntid.z : !llvm.i32
- // CHECK: %7 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
+ // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
%7 = nvvm.read.ptx.sreg.ctaid.x : !llvm.i32
- // CHECK: %8 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y()
+ // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y()
%8 = nvvm.read.ptx.sreg.ctaid.y : !llvm.i32
- // CHECK: %9 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.z()
+ // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.z()
%9 = nvvm.read.ptx.sreg.ctaid.z : !llvm.i32
- // CHECK: %10 = call i32 @llvm.nvvm.read.ptx.sreg.nctaid.x()
+ // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.nctaid.x()
%10 = nvvm.read.ptx.sreg.nctaid.x : !llvm.i32
- // CHECK: %11 = call i32 @llvm.nvvm.read.ptx.sreg.nctaid.y()
+ // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.nctaid.y()
%11 = nvvm.read.ptx.sreg.nctaid.y : !llvm.i32
- // CHECK: %12 = call i32 @llvm.nvvm.read.ptx.sreg.nctaid.z()
+ // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.nctaid.z()
%12 = nvvm.read.ptx.sreg.nctaid.z : !llvm.i32
llvm.return %1 : !llvm.i32
}
+func @llvm.nvvm.barrier0() {
+ // CHECK: call void @llvm.nvvm.barrier0()
+ nvvm.barrier0
+ llvm.return
+}
+
+func @nvvm_shfl(
+ %0 : !llvm.i32, %1 : !llvm.i32, %2 : !llvm.i32,
+ %3 : !llvm.i32, %4 : !llvm.float) -> !llvm.i32 {
+ // CHECK: call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %6 = nvvm.shfl.sync.bfly %0, %3, %1, %2 : !llvm.i32
+ // CHECK: call float @llvm.nvvm.shfl.sync.bfly.f32(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %7 = nvvm.shfl.sync.bfly %0, %4, %1, %2 : !llvm.float
+ llvm.return %6 : !llvm.i32
+}
+
+func @nvvm_vote(%0 : !llvm.i32, %1 : !llvm.i1) -> !llvm.i32 {
+ // CHECK: call i32 @llvm.nvvm.vote.ballot.sync(i32 %{{.*}}, i1 %{{.*}})
+ %3 = nvvm.vote.ballot.sync %0, %1 : !llvm.i32
+ llvm.return %3 : !llvm.i32
+}
+
// This function has the "kernel" attribute attached and should appear in the
// NVVM annotations after conversion.
func @kernel_func() attributes {gpu.kernel} {