From 696fcb7520e13b8c712146b6bcb6a55815af660c Mon Sep 17 00:00:00 2001 From: MLIR Team Date: Tue, 27 Aug 2019 10:55:47 -0700 Subject: [PATCH] Add 3 additional intrinsic ops to NVVM dialect, in preparation to implement block-wide reduce. PiperOrigin-RevId: 265720077 --- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 40 ++++++++++++++++++++- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 55 +++++++++++++++++++++++++---- mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp | 7 ++-- mlir/test/LLVMIR/nvvm.mlir | 46 +++++++++++++++++------- mlir/test/Target/nvvmir.mlir | 46 +++++++++++++++++------- 5 files changed, 159 insertions(+), 35 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 72bbb13..224a580 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -41,7 +41,7 @@ class NVVM_SpecialRegisterOpgetOperation()); }]; + let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }]; } def NVVM_ThreadIdXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.tid.x">; @@ -57,4 +57,42 @@ def NVVM_GridDimXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.x">; def NVVM_GridDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.y">; def NVVM_GridDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.z">; +def NVVM_Barrier0Op : NVVM_Op<"barrier0"> { + string llvmBuilder = [{ + createIntrinsicCall(builder, llvm::Intrinsic::nvvm_barrier0); + }]; + let parser = [{ return success(); }]; + let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }]; +} + +def NVVM_ShflBflyOp : + NVVM_Op<"shfl.sync.bfly">, + Results<(outs LLVM_Type:$res)>, + Arguments<(ins LLVM_Type:$dst, + LLVM_Type:$val, + LLVM_Type:$offset, + LLVM_Type:$mask_and_clamp)> { + string llvmBuilder = [{ + auto intId = $val->getType()->isFloatTy() ? + llvm::Intrinsic::nvvm_shfl_sync_bfly_f32 : + llvm::Intrinsic::nvvm_shfl_sync_bfly_i32; + $res = createIntrinsicCall(builder, + intId, {$dst, $val, $offset, $mask_and_clamp}); + }]; + let parser = [{ return parseNVVMShflSyncBflyOp(parser, result); }]; + let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }]; +} + +def NVVM_VoteBallotOp : + NVVM_Op<"vote.ballot.sync">, + Results<(outs LLVM_Type:$res)>, + Arguments<(ins LLVM_Type:$mask, LLVM_Type:$pred)> { + string llvmBuilder = [{ + $res = createIntrinsicCall(builder, + llvm::Intrinsic::nvvm_vote_ballot_sync, {$mask, $pred}); + }]; + let parser = [{ return parseNVVMVoteBallotOp(parser, result); }]; + let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }]; +} + #endif // NVVMIR_OPS diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 8d6f308..90d285e 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -43,13 +43,11 @@ namespace NVVM { // Printing/parsing for NVVM ops //===----------------------------------------------------------------------===// -static void printNVVMSpecialRegisterOp(OpAsmPrinter *p, Operation *op) { - *p << op->getName() << " : "; - if (op->getNumResults() == 1) { - *p << op->getResult(0)->getType(); - } else { - *p << "###invalid type###"; - } +static void printNVVMIntrinsicOp(OpAsmPrinter *p, Operation *op) { + *p << op->getName() << " "; + p->printOperands(op->getOperands()); + if (op->getNumResults() > 0) + interleaveComma(op->getResultTypes(), *p << " : "); } // ::= `llvm.nvvm.XYZ` : type @@ -64,6 +62,49 @@ static ParseResult parseNVVMSpecialRegisterOp(OpAsmParser *parser, return success(); } +static LLVM::LLVMDialect *getLlvmDialect(OpAsmParser *parser) { + return parser->getBuilder() + .getContext() + ->getRegisteredDialect(); +} + +// ::= +// `llvm.nvvm.shfl.sync.bfly %dst, %val, %offset, %clamp_and_mask` +// : result_type +static ParseResult parseNVVMShflSyncBflyOp(OpAsmParser *parser, + OperationState *result) { + auto llvmDialect = getLlvmDialect(parser); + auto int32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect); + + SmallVector ops; + Type type; + return failure(parser->parseOperandList(ops) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(type) || + parser->addTypeToList(type, result->types) || + parser->resolveOperands(ops, {int32Ty, type, int32Ty, int32Ty}, + parser->getNameLoc(), + result->operands)); +} + +// ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type +static ParseResult parseNVVMVoteBallotOp(OpAsmParser *parser, + OperationState *result) { + auto llvmDialect = getLlvmDialect(parser); + auto int32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect); + auto int1Ty = LLVM::LLVMType::getInt1Ty(llvmDialect); + + SmallVector ops; + Type type; + return failure(parser->parseOperandList(ops) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(type) || + parser->addTypeToList(type, result->types) || + parser->resolveOperands(ops, {int32Ty, int1Ty}, + parser->getNameLoc(), + result->operands)); +} + //===----------------------------------------------------------------------===// // NVVMDialect initialization, type parsing, and registration. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp index 98dc43c..32fa167 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp @@ -38,10 +38,11 @@ using namespace mlir; namespace { static llvm::Value *createIntrinsicCall(llvm::IRBuilder<> &builder, - llvm::Intrinsic::ID intrinsic) { + llvm::Intrinsic::ID intrinsic, + ArrayRef args = {}) { llvm::Module *module = builder.GetInsertBlock()->getModule(); - llvm::Function *fn = llvm::Intrinsic::getDeclaration(module, intrinsic, {}); - return builder.CreateCall(fn); + llvm::Function *fn = llvm::Intrinsic::getDeclaration(module, intrinsic); + return builder.CreateCall(fn, args); } class ModuleTranslation : public LLVM::ModuleTranslation { diff --git a/mlir/test/LLVMIR/nvvm.mlir b/mlir/test/LLVMIR/nvvm.mlir index 8716df9..8ca439d 100644 --- a/mlir/test/LLVMIR/nvvm.mlir +++ b/mlir/test/LLVMIR/nvvm.mlir @@ -1,29 +1,51 @@ // RUN: mlir-opt %s | FileCheck %s func @nvvm_special_regs() -> !llvm.i32 { - // CHECK: %0 = nvvm.read.ptx.sreg.tid.x : !llvm.i32 + // CHECK: nvvm.read.ptx.sreg.tid.x : !llvm.i32 %0 = nvvm.read.ptx.sreg.tid.x : !llvm.i32 - // CHECK: %1 = nvvm.read.ptx.sreg.tid.y : !llvm.i32 + // CHECK: nvvm.read.ptx.sreg.tid.y : !llvm.i32 %1 = nvvm.read.ptx.sreg.tid.y : !llvm.i32 - // CHECK: %2 = nvvm.read.ptx.sreg.tid.z : !llvm.i32 + // CHECK: nvvm.read.ptx.sreg.tid.z : !llvm.i32 %2 = nvvm.read.ptx.sreg.tid.z : !llvm.i32 - // CHECK: %3 = nvvm.read.ptx.sreg.ntid.x : !llvm.i32 + // CHECK: nvvm.read.ptx.sreg.ntid.x : !llvm.i32 %3 = nvvm.read.ptx.sreg.ntid.x : !llvm.i32 - // CHECK: %4 = nvvm.read.ptx.sreg.ntid.y : !llvm.i32 + // CHECK: nvvm.read.ptx.sreg.ntid.y : !llvm.i32 %4 = nvvm.read.ptx.sreg.ntid.y : !llvm.i32 - // CHECK: %5 = nvvm.read.ptx.sreg.ntid.z : !llvm.i32 + // CHECK: nvvm.read.ptx.sreg.ntid.z : !llvm.i32 %5 = nvvm.read.ptx.sreg.ntid.z : !llvm.i32 - // CHECK: %6 = nvvm.read.ptx.sreg.ctaid.x : !llvm.i32 + // CHECK: nvvm.read.ptx.sreg.ctaid.x : !llvm.i32 %6 = nvvm.read.ptx.sreg.ctaid.x : !llvm.i32 - // CHECK: %7 = nvvm.read.ptx.sreg.ctaid.y : !llvm.i32 + // CHECK: nvvm.read.ptx.sreg.ctaid.y : !llvm.i32 %7 = nvvm.read.ptx.sreg.ctaid.y : !llvm.i32 - // CHECK: %8 = nvvm.read.ptx.sreg.ctaid.z : !llvm.i32 + // CHECK: nvvm.read.ptx.sreg.ctaid.z : !llvm.i32 %8 = nvvm.read.ptx.sreg.ctaid.z : !llvm.i32 - // CHECK: %9 = nvvm.read.ptx.sreg.nctaid.x : !llvm.i32 + // CHECK: nvvm.read.ptx.sreg.nctaid.x : !llvm.i32 %9 = nvvm.read.ptx.sreg.nctaid.x : !llvm.i32 - // CHECK: %10 = nvvm.read.ptx.sreg.nctaid.y : !llvm.i32 + // CHECK: nvvm.read.ptx.sreg.nctaid.y : !llvm.i32 %10 = nvvm.read.ptx.sreg.nctaid.y : !llvm.i32 - // CHECK: %11 = nvvm.read.ptx.sreg.nctaid.z : !llvm.i32 + // CHECK: nvvm.read.ptx.sreg.nctaid.z : !llvm.i32 %11 = nvvm.read.ptx.sreg.nctaid.z : !llvm.i32 llvm.return %0 : !llvm.i32 } + +func @llvm.nvvm.barrier0() { + // CHECK: nvvm.barrier0 + nvvm.barrier0 + llvm.return +} + +func @nvvm_shfl( + %arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !llvm.i32, + %arg3 : !llvm.i32, %arg4 : !llvm.float) -> !llvm.i32 { + // CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i32 + %0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 : !llvm.i32 + // CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm.float + %1 = nvvm.shfl.sync.bfly %arg0, %arg4, %arg1, %arg2 : !llvm.float + llvm.return %0 : !llvm.i32 +} + +func @nvvm_vote(%arg0 : !llvm.i32, %arg1 : !llvm.i1) -> !llvm.i32 { + // CHECK: nvvm.vote.ballot.sync %{{.*}}, %{{.*}} : !llvm.i32 + %0 = nvvm.vote.ballot.sync %arg0, %arg1 : !llvm.i32 + llvm.return %0 : !llvm.i32 +} diff --git a/mlir/test/Target/nvvmir.mlir b/mlir/test/Target/nvvmir.mlir index 85d7ef2..74a1ebc 100644 --- a/mlir/test/Target/nvvmir.mlir +++ b/mlir/test/Target/nvvmir.mlir @@ -1,33 +1,55 @@ // RUN: mlir-translate -mlir-to-nvvmir %s | FileCheck %s func @nvvm_special_regs() -> !llvm.i32 { - // CHECK: %1 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x() + // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.tid.x() %1 = nvvm.read.ptx.sreg.tid.x : !llvm.i32 - // CHECK: %2 = call i32 @llvm.nvvm.read.ptx.sreg.tid.y() + // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.tid.y() %2 = nvvm.read.ptx.sreg.tid.y : !llvm.i32 - // CHECK: %3 = call i32 @llvm.nvvm.read.ptx.sreg.tid.z() + // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.tid.z() %3 = nvvm.read.ptx.sreg.tid.z : !llvm.i32 - // CHECK: %4 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x() + // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.ntid.x() %4 = nvvm.read.ptx.sreg.ntid.x : !llvm.i32 - // CHECK: %5 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.y() + // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.ntid.y() %5 = nvvm.read.ptx.sreg.ntid.y : !llvm.i32 - // CHECK: %6 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.z() + // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.ntid.z() %6 = nvvm.read.ptx.sreg.ntid.z : !llvm.i32 - // CHECK: %7 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() + // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() %7 = nvvm.read.ptx.sreg.ctaid.x : !llvm.i32 - // CHECK: %8 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y() + // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y() %8 = nvvm.read.ptx.sreg.ctaid.y : !llvm.i32 - // CHECK: %9 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.z() + // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.z() %9 = nvvm.read.ptx.sreg.ctaid.z : !llvm.i32 - // CHECK: %10 = call i32 @llvm.nvvm.read.ptx.sreg.nctaid.x() + // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.nctaid.x() %10 = nvvm.read.ptx.sreg.nctaid.x : !llvm.i32 - // CHECK: %11 = call i32 @llvm.nvvm.read.ptx.sreg.nctaid.y() + // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.nctaid.y() %11 = nvvm.read.ptx.sreg.nctaid.y : !llvm.i32 - // CHECK: %12 = call i32 @llvm.nvvm.read.ptx.sreg.nctaid.z() + // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.nctaid.z() %12 = nvvm.read.ptx.sreg.nctaid.z : !llvm.i32 llvm.return %1 : !llvm.i32 } +func @llvm.nvvm.barrier0() { + // CHECK: call void @llvm.nvvm.barrier0() + nvvm.barrier0 + llvm.return +} + +func @nvvm_shfl( + %0 : !llvm.i32, %1 : !llvm.i32, %2 : !llvm.i32, + %3 : !llvm.i32, %4 : !llvm.float) -> !llvm.i32 { + // CHECK: call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + %6 = nvvm.shfl.sync.bfly %0, %3, %1, %2 : !llvm.i32 + // CHECK: call float @llvm.nvvm.shfl.sync.bfly.f32(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + %7 = nvvm.shfl.sync.bfly %0, %4, %1, %2 : !llvm.float + llvm.return %6 : !llvm.i32 +} + +func @nvvm_vote(%0 : !llvm.i32, %1 : !llvm.i1) -> !llvm.i32 { + // CHECK: call i32 @llvm.nvvm.vote.ballot.sync(i32 %{{.*}}, i1 %{{.*}}) + %3 = nvvm.vote.ballot.sync %0, %1 : !llvm.i32 + llvm.return %3 : !llvm.i32 +} + // This function has the "kernel" attribute attached and should appear in the // NVVM annotations after conversion. func @kernel_func() attributes {gpu.kernel} { -- 2.7.4