Add 3 additional intrinsic ops to NVVM dialect, in preparation to implement block...
authorMLIR Team <no-reply@google.com>
Tue, 27 Aug 2019 17:55:47 +0000 (10:55 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 27 Aug 2019 17:56:18 +0000 (10:56 -0700)
PiperOrigin-RevId: 265720077

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp
mlir/test/LLVMIR/nvvm.mlir
mlir/test/Target/nvvmir.mlir

index 72bbb13..224a580 100644 (file)
@@ -41,7 +41,7 @@ class NVVM_SpecialRegisterOp<string mnemonic,
   string llvmBuilder = "$res = createIntrinsicCall(builder,"
     # "llvm::Intrinsic::nvvm_" # !subst(".","_", mnemonic) # ");";
   let parser = [{ return parseNVVMSpecialRegisterOp(parser, result); }];
-  let printer = [{ printNVVMSpecialRegisterOp(p, this->getOperation()); }];
+  let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }];
 }
 
 def NVVM_ThreadIdXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.tid.x">;
@@ -57,4 +57,42 @@ def NVVM_GridDimXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.x">;
 def NVVM_GridDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.y">;
 def NVVM_GridDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.z">;
 
+def NVVM_Barrier0Op : NVVM_Op<"barrier0"> {
+  string llvmBuilder = [{
+      createIntrinsicCall(builder, llvm::Intrinsic::nvvm_barrier0);
+  }];
+  let parser = [{ return success(); }];
+  let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }];
+}
+
+def NVVM_ShflBflyOp :
+  NVVM_Op<"shfl.sync.bfly">,
+  Results<(outs LLVM_Type:$res)>,
+  Arguments<(ins LLVM_Type:$dst,
+                 LLVM_Type:$val,
+                 LLVM_Type:$offset,
+                 LLVM_Type:$mask_and_clamp)> {
+  string llvmBuilder = [{
+      auto intId = $val->getType()->isFloatTy() ?
+          llvm::Intrinsic::nvvm_shfl_sync_bfly_f32 :
+          llvm::Intrinsic::nvvm_shfl_sync_bfly_i32;
+      $res = createIntrinsicCall(builder,
+          intId, {$dst, $val, $offset, $mask_and_clamp});
+  }];
+  let parser = [{ return parseNVVMShflSyncBflyOp(parser, result); }];
+  let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }];
+}
+
+def NVVM_VoteBallotOp :
+  NVVM_Op<"vote.ballot.sync">,
+  Results<(outs LLVM_Type:$res)>,
+  Arguments<(ins LLVM_Type:$mask, LLVM_Type:$pred)> {
+  string llvmBuilder = [{
+      $res = createIntrinsicCall(builder,
+            llvm::Intrinsic::nvvm_vote_ballot_sync, {$mask, $pred});
+  }];
+  let parser = [{ return parseNVVMVoteBallotOp(parser, result); }];
+  let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }];
+}
+
 #endif // NVVMIR_OPS
index 8d6f308..90d285e 100644 (file)
@@ -43,13 +43,11 @@ namespace NVVM {
 // Printing/parsing for NVVM ops
 //===----------------------------------------------------------------------===//
 
-static void printNVVMSpecialRegisterOp(OpAsmPrinter *p, Operation *op) {
-  *p << op->getName() << " : ";
-  if (op->getNumResults() == 1) {
-    *p << op->getResult(0)->getType();
-  } else {
-    *p << "###invalid type###";
-  }
+static void printNVVMIntrinsicOp(OpAsmPrinter *p, Operation *op) {
+  *p << op->getName() << " ";
+  p->printOperands(op->getOperands());
+  if (op->getNumResults() > 0)
+    interleaveComma(op->getResultTypes(), *p << " : ");
 }
 
 // <operation> ::= `llvm.nvvm.XYZ` : type
@@ -64,6 +62,49 @@ static ParseResult parseNVVMSpecialRegisterOp(OpAsmParser *parser,
   return success();
 }
 
+static LLVM::LLVMDialect *getLlvmDialect(OpAsmParser *parser) {
+  return parser->getBuilder()
+      .getContext()
+      ->getRegisteredDialect<LLVM::LLVMDialect>();
+}
+
+// <operation> ::=
+//     `llvm.nvvm.shfl.sync.bfly %dst, %val, %offset, %clamp_and_mask`
+//     : result_type
+static ParseResult parseNVVMShflSyncBflyOp(OpAsmParser *parser,
+                                           OperationState *result) {
+  auto llvmDialect = getLlvmDialect(parser);
+  auto int32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
+
+  SmallVector<OpAsmParser::OperandType, 8> ops;
+  Type type;
+  return failure(parser->parseOperandList(ops) ||
+                 parser->parseOptionalAttributeDict(result->attributes) ||
+                 parser->parseColonType(type) ||
+                 parser->addTypeToList(type, result->types) ||
+                 parser->resolveOperands(ops, {int32Ty, type, int32Ty, int32Ty},
+                                         parser->getNameLoc(),
+                                         result->operands));
+}
+
+// <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
+static ParseResult parseNVVMVoteBallotOp(OpAsmParser *parser,
+                                         OperationState *result) {
+  auto llvmDialect = getLlvmDialect(parser);
+  auto int32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
+  auto int1Ty = LLVM::LLVMType::getInt1Ty(llvmDialect);
+
+  SmallVector<OpAsmParser::OperandType, 8> ops;
+  Type type;
+  return failure(parser->parseOperandList(ops) ||
+                 parser->parseOptionalAttributeDict(result->attributes) ||
+                 parser->parseColonType(type) ||
+                 parser->addTypeToList(type, result->types) ||
+                 parser->resolveOperands(ops, {int32Ty, int1Ty},
+                                         parser->getNameLoc(),
+                                         result->operands));
+}
+
 //===----------------------------------------------------------------------===//
 // NVVMDialect initialization, type parsing, and registration.
 //===----------------------------------------------------------------------===//
index 98dc43c..32fa167 100644 (file)
@@ -38,10 +38,11 @@ using namespace mlir;
 
 namespace {
 static llvm::Value *createIntrinsicCall(llvm::IRBuilder<> &builder,
-                                        llvm::Intrinsic::ID intrinsic) {
+                                        llvm::Intrinsic::ID intrinsic,
+                                        ArrayRef<llvm::Value *> args = {}) {
   llvm::Module *module = builder.GetInsertBlock()->getModule();
-  llvm::Function *fn = llvm::Intrinsic::getDeclaration(module, intrinsic, {});
-  return builder.CreateCall(fn);
+  llvm::Function *fn = llvm::Intrinsic::getDeclaration(module, intrinsic);
+  return builder.CreateCall(fn, args);
 }
 
 class ModuleTranslation : public LLVM::ModuleTranslation {
index 8716df9..8ca439d 100644 (file)
@@ -1,29 +1,51 @@
 // RUN: mlir-opt %s | FileCheck %s
 
 func @nvvm_special_regs() -> !llvm.i32 {
-  // CHECK: %0 = nvvm.read.ptx.sreg.tid.x : !llvm.i32
+  // CHECK: nvvm.read.ptx.sreg.tid.x : !llvm.i32
   %0 = nvvm.read.ptx.sreg.tid.x : !llvm.i32
-  // CHECK: %1 = nvvm.read.ptx.sreg.tid.y : !llvm.i32
+  // CHECK: nvvm.read.ptx.sreg.tid.y : !llvm.i32
   %1 = nvvm.read.ptx.sreg.tid.y : !llvm.i32
-  // CHECK: %2 = nvvm.read.ptx.sreg.tid.z : !llvm.i32
+  // CHECK: nvvm.read.ptx.sreg.tid.z : !llvm.i32
   %2 = nvvm.read.ptx.sreg.tid.z : !llvm.i32
-  // CHECK: %3 = nvvm.read.ptx.sreg.ntid.x : !llvm.i32
+  // CHECK: nvvm.read.ptx.sreg.ntid.x : !llvm.i32
   %3 = nvvm.read.ptx.sreg.ntid.x : !llvm.i32
-  // CHECK: %4 = nvvm.read.ptx.sreg.ntid.y : !llvm.i32
+  // CHECK: nvvm.read.ptx.sreg.ntid.y : !llvm.i32
   %4 = nvvm.read.ptx.sreg.ntid.y : !llvm.i32
-  // CHECK: %5 = nvvm.read.ptx.sreg.ntid.z : !llvm.i32
+  // CHECK: nvvm.read.ptx.sreg.ntid.z : !llvm.i32
   %5 = nvvm.read.ptx.sreg.ntid.z : !llvm.i32
-  // CHECK: %6 = nvvm.read.ptx.sreg.ctaid.x : !llvm.i32
+  // CHECK: nvvm.read.ptx.sreg.ctaid.x : !llvm.i32
   %6 = nvvm.read.ptx.sreg.ctaid.x : !llvm.i32
-  // CHECK: %7 = nvvm.read.ptx.sreg.ctaid.y : !llvm.i32
+  // CHECK: nvvm.read.ptx.sreg.ctaid.y : !llvm.i32
   %7 = nvvm.read.ptx.sreg.ctaid.y : !llvm.i32
-  // CHECK: %8 = nvvm.read.ptx.sreg.ctaid.z : !llvm.i32
+  // CHECK: nvvm.read.ptx.sreg.ctaid.z : !llvm.i32
   %8 = nvvm.read.ptx.sreg.ctaid.z : !llvm.i32
-  // CHECK: %9 = nvvm.read.ptx.sreg.nctaid.x : !llvm.i32
+  // CHECK: nvvm.read.ptx.sreg.nctaid.x : !llvm.i32
   %9 = nvvm.read.ptx.sreg.nctaid.x : !llvm.i32
-  // CHECK: %10 = nvvm.read.ptx.sreg.nctaid.y : !llvm.i32
+  // CHECK: nvvm.read.ptx.sreg.nctaid.y : !llvm.i32
   %10 = nvvm.read.ptx.sreg.nctaid.y : !llvm.i32
-  // CHECK: %11 = nvvm.read.ptx.sreg.nctaid.z : !llvm.i32
+  // CHECK: nvvm.read.ptx.sreg.nctaid.z : !llvm.i32
   %11 = nvvm.read.ptx.sreg.nctaid.z : !llvm.i32
   llvm.return %0 : !llvm.i32
 }
+
+func @llvm.nvvm.barrier0() {
+  // CHECK: nvvm.barrier0
+  nvvm.barrier0
+  llvm.return
+}
+
+func @nvvm_shfl(
+    %arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !llvm.i32,
+    %arg3 : !llvm.i32, %arg4 : !llvm.float) -> !llvm.i32 {
+  // CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i32
+  %0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 : !llvm.i32
+  // CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm.float
+  %1 = nvvm.shfl.sync.bfly %arg0, %arg4, %arg1, %arg2 : !llvm.float
+  llvm.return %0 : !llvm.i32
+}
+
+func @nvvm_vote(%arg0 : !llvm.i32, %arg1 : !llvm.i1) -> !llvm.i32 {
+  // CHECK: nvvm.vote.ballot.sync %{{.*}}, %{{.*}} : !llvm.i32
+  %0 = nvvm.vote.ballot.sync %arg0, %arg1 : !llvm.i32
+  llvm.return %0 : !llvm.i32
+}
index 85d7ef2..74a1ebc 100644 (file)
@@ -1,33 +1,55 @@
 // RUN: mlir-translate -mlir-to-nvvmir %s | FileCheck %s
 
 func @nvvm_special_regs() -> !llvm.i32 {
-  // CHECK: %1 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
   %1 = nvvm.read.ptx.sreg.tid.x : !llvm.i32
-  // CHECK: %2 = call i32 @llvm.nvvm.read.ptx.sreg.tid.y()
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.tid.y()
   %2 = nvvm.read.ptx.sreg.tid.y : !llvm.i32
-  // CHECK: %3 = call i32 @llvm.nvvm.read.ptx.sreg.tid.z()
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.tid.z()
   %3 = nvvm.read.ptx.sreg.tid.z : !llvm.i32
-  // CHECK: %4 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
   %4 = nvvm.read.ptx.sreg.ntid.x : !llvm.i32
-  // CHECK: %5 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
   %5 = nvvm.read.ptx.sreg.ntid.y : !llvm.i32
-  // CHECK: %6 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.z()
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.ntid.z()
   %6 = nvvm.read.ptx.sreg.ntid.z : !llvm.i32
-  // CHECK: %7 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
   %7 = nvvm.read.ptx.sreg.ctaid.x : !llvm.i32
-  // CHECK: %8 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y()
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y()
   %8 = nvvm.read.ptx.sreg.ctaid.y : !llvm.i32
-  // CHECK: %9 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.z()
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.z()
   %9 = nvvm.read.ptx.sreg.ctaid.z : !llvm.i32
-  // CHECK: %10 = call i32 @llvm.nvvm.read.ptx.sreg.nctaid.x()
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.nctaid.x()
   %10 = nvvm.read.ptx.sreg.nctaid.x : !llvm.i32
-  // CHECK: %11 = call i32 @llvm.nvvm.read.ptx.sreg.nctaid.y()
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.nctaid.y()
   %11 = nvvm.read.ptx.sreg.nctaid.y : !llvm.i32
-  // CHECK: %12 = call i32 @llvm.nvvm.read.ptx.sreg.nctaid.z()
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.nctaid.z()
   %12 = nvvm.read.ptx.sreg.nctaid.z : !llvm.i32
   llvm.return %1 : !llvm.i32
 }
 
+func @llvm.nvvm.barrier0() {
+  // CHECK: call void @llvm.nvvm.barrier0()
+  nvvm.barrier0
+  llvm.return
+}
+
+func @nvvm_shfl(
+    %0 : !llvm.i32, %1 : !llvm.i32, %2 : !llvm.i32,
+    %3 : !llvm.i32, %4 : !llvm.float) -> !llvm.i32 {
+  // CHECK: call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %6 = nvvm.shfl.sync.bfly %0, %3, %1, %2 : !llvm.i32
+  // CHECK: call float @llvm.nvvm.shfl.sync.bfly.f32(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %7 = nvvm.shfl.sync.bfly %0, %4, %1, %2 : !llvm.float
+  llvm.return %6 : !llvm.i32
+}
+
+func @nvvm_vote(%0 : !llvm.i32, %1 : !llvm.i1) -> !llvm.i32 {
+  // CHECK: call i32 @llvm.nvvm.vote.ballot.sync(i32 %{{.*}}, i1 %{{.*}})
+  %3 = nvvm.vote.ballot.sync %0, %1 : !llvm.i32
+  llvm.return %3 : !llvm.i32
+}
+
 // This function has the "kernel" attribute attached and should appear in the
 // NVVM annotations after conversion.
 func @kernel_func() attributes {gpu.kernel} {