GPU dialect: introduce custom syntax for gpu.launch
authorAlex Zinenko <zinenko@google.com>
Mon, 6 May 2019 09:30:50 +0000 (02:30 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Mon, 6 May 2019 15:29:57 +0000 (08:29 -0700)
    This syntax removes boilerplate and verbose list of region arguments in the
    header of the entry block.  It groups operands into segments related to GPU
    blocks, GPU threads as well as the operands that are forwarded to the kernel.
    The two former segments are also used to give names to the region arguments
    that are used for GPU blocks and threads inside the kernel body region.

--

PiperOrigin-RevId: 246792329

mlir/g3doc/Dialects/GPU.md
mlir/include/mlir/GPU/GPUDialect.h
mlir/include/mlir/IR/OpImplementation.h
mlir/lib/GPU/IR/GPUDialect.cpp
mlir/lib/Parser/Parser.cpp
mlir/test/GPU/ops.mlir

index 8f572ee..ddf3c3e 100644 (file)
@@ -38,11 +38,28 @@ body region. Nested regions inside the kernel body are allowed to use values
 defined in their ancestor regions as long as they don't cross the kernel body
 region boundary.
 
-Custom syntax for this operation is currently not available.
+Syntax:
+
+``` {.ebnf}
+operation ::= `gpu.launch` `block` `(` ssa-id-list `)` `in` ssa-reassignment
+                         `threads` `(` ssa-id-list `)` `in` ssa-reassignment
+                           (`args` ssa-reassignment `:` type-list)?
+                           region attr-dict?
+ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
+```
 
 Example:
 
 ```mlir {.mlir}
+gpu.launch blocks(%bx, %by, %bz) in (%sz_bx = %0, %sz_by = %1, %sz_bz = %2)
+           threads(%tx, %ty, %tz) in (%sz_tx = %3, %sz_ty = %4, %sz_tz = %5)
+           args(%arg0 = %6, %arg1 = 7) : f32, memref<?xf32, 1> {
+  // Block and thread identifiers, as well as block/grid sizes are
+  // immediately usable inside body region.
+  "some_op"(%bx, %tx) : (index, index) -> ()
+  %42 = load %arg1[%bx] : memref<?xf32, 1>
+}
+
 // Generic syntax explains how the pretty syntax maps to the IR structure.
 "gpu.launch"(%cst, %cst, %c1,  // Grid sizes.
                     %cst, %c1, %c1,   // Block sizes.
index 2229b2a..4a9b0f6 100644 (file)
@@ -39,12 +39,18 @@ public:
   static StringRef getDialectName();
 };
 
+/// Utility class for the GPU dialect to represent triples of `Value`s
+/// accessible through `.x`, `.y`, and `.z` similarly to CUDA notation.
 struct KernelDim3 {
   Value *x;
   Value *y;
   Value *z;
 };
 
+/// GPU kernel launch operation.  Takes a 3D grid of thread blocks as leading
+/// operands, followed by kernel data operands.  Has one region representing
+/// the kernel to be executed.  This region is not allowed to use values defined
+/// outside it.
 class LaunchOp : public Op<LaunchOp, OpTrait::AtLeastNOperands<6>::Impl,
                            OpTrait::ZeroResult,
                            OpTrait::NthRegionIsIsolatedAbove<0>::Impl> {
@@ -57,15 +63,38 @@ public:
                     Value *blockSizeY, Value *blockSizeZ,
                     ArrayRef<Value *> operands);
 
+  /// Get the kernel region.
   Region &getBody();
+
+  /// Get the SSA values corresponding to kernel block identifiers.
   KernelDim3 getBlockIds();
+  /// Get the SSA values corresponding to kernel thread identifiers.
   KernelDim3 getThreadIds();
+  /// Get the SSA values corresponding to kernel grid size.
   KernelDim3 getGridSize();
+  /// Get the SSA values corresponding to kernel block size.
   KernelDim3 getBlockSize();
 
   LogicalResult verify();
 
+  /// Custom syntax support.
+  void print(OpAsmPrinter *p);
+  static bool parse(OpAsmParser *parser, OperationState *result);
+
   static StringRef getOperationName() { return "gpu.launch"; }
+
+private:
+  static StringRef getBlocksKeyword() { return "blocks"; }
+  static StringRef getThreadsKeyword() { return "threads"; }
+  static StringRef getArgsKeyword() { return "args"; }
+
+  /// The number of launch configuration operands, placed at the leading
+  /// positions of the operand list.
+  static constexpr unsigned kNumConfigOperands = 6;
+
+  /// The number of region attributes containing the launch configuration,
+  /// placed in the leading positions of the argument list.
+  static constexpr unsigned kNumConfigRegionAttributes = 12;
 };
 
 } // end namespace mlir
index f402f74..a818c01 100644 (file)
@@ -159,6 +159,9 @@ public:
   /// This parses... a comma!
   virtual bool parseComma() = 0;
 
+  /// Parses a comma if present.
+  virtual bool parseOptionalComma() = 0;
+
   /// Parse a `:` token.
   virtual bool parseColon() = 0;
 
index 825ea5b..cc440b7 100644 (file)
@@ -21,6 +21,7 @@
 
 #include "mlir/GPU/GPUDialect.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/StandardTypes.h"
 
 using namespace mlir;
@@ -53,12 +54,13 @@ void LaunchOp::build(Builder *builder, OperationState *result, Value *gridSizeX,
       {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ});
   result->addOperands(operands);
 
-  // Create a kernel body region with 12 + N arguments, where the first 12
-  // arguments have `index` type and the rest have the same types as the data
-  // operands.
+  // Create a kernel body region with kNumConfigRegionAttributes + N arguments,
+  // where the first kNumConfigRegionAttributes arguments have `index` type and
+  // the rest have the same types as the data operands.
   Region *kernelRegion = result->addRegion();
   Block *body = new Block();
-  body->addArguments(std::vector<Type>(12, builder->getIndexType()));
+  body->addArguments(
+      std::vector<Type>(kNumConfigRegionAttributes, builder->getIndexType()));
   body->addArguments(getValueTypes(operands));
   kernelRegion->push_back(body);
 }
@@ -86,14 +88,172 @@ KernelDim3 LaunchOp::getBlockSize() {
 }
 
 LogicalResult LaunchOp::verify() {
-  // Kernel launch takes 6 leading operands for grid/block sizes and transforms
-  // them into 12 region arguments for block/thread identifiers and grid/block
-  // sizes.
+  // Kernel launch takes kNumConfigOperands leading operands for grid/block
+  // sizes and transforms them into kNumConfigRegionAttributes region arguments
+  // for block/thread identifiers and grid/block sizes.
   if (!getBody().empty()) {
     Block &entryBlock = getBody().front();
-    if (entryBlock.getNumArguments() != 6 + getNumOperands())
+    if (entryBlock.getNumArguments() != kNumConfigOperands + getNumOperands())
       return emitError("unexpected number of region arguments");
   }
 
   return success();
 }
+
+// Pretty-print the kernel grid/block size assignment as
+//   (%iter-x, %iter-y, %iter-z) in
+//   (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
+// where %size-* and %iter-* will correspond to the body region arguments.
+static void printSizeAssignment(OpAsmPrinter *p, KernelDim3 size,
+                                ArrayRef<Value *> operands, KernelDim3 ids) {
+  *p << '(' << *ids.x << ", " << *ids.y << ", " << *ids.z << ") in (";
+  *p << *size.x << " = " << *operands[0] << ", ";
+  *p << *size.y << " = " << *operands[1] << ", ";
+  *p << *size.z << " = " << *operands[2] << ')';
+}
+
+void LaunchOp::print(OpAsmPrinter *p) {
+  SmallVector<Value *, 12> operandContainer(operand_begin(), operand_end());
+  ArrayRef<Value *> operands(operandContainer);
+
+  // Print the launch configuration.
+  *p << getOperationName() << ' ' << getBlocksKeyword();
+  printSizeAssignment(p, getGridSize(), operands.take_front(3), getBlockIds());
+  *p << ' ' << getThreadsKeyword();
+  printSizeAssignment(p, getBlockSize(), operands.slice(3, 3), getThreadIds());
+
+  // From now on, the first kNumConfigOperands operands corresponding to grid
+  // and block sizes are irrelevant, so we can drop them.
+  operands = operands.drop_front(kNumConfigOperands);
+
+  // Print the data argument remapping.
+  if (!getBody().empty() && !operands.empty()) {
+    *p << ' ' << getArgsKeyword() << '(';
+    for (unsigned i = 0, e = operands.size(); i < e; ++i) {
+      if (i != 0)
+        *p << ", ";
+      *p << *getBody().front().getArgument(kNumConfigRegionAttributes + i)
+         << " = " << *operands[i];
+    }
+    *p << ") ";
+  }
+
+  // Print the types of data arguments.
+  if (!operands.empty()) {
+    *p << ": ";
+    for (unsigned i = 0, e = operands.size(); i < e; ++i) {
+      if (i != 0)
+        *p << ", ";
+      *p << operands[i]->getType();
+    }
+  }
+
+  p->printRegion(getBody(), /*printEntryBlockArgs=*/false);
+  p->printOptionalAttrDict(getAttrs());
+}
+
+// Parse the size assignment blocks for blocks and threads.  These have the form
+//   (%region_arg, %region_arg, %region_arg) in
+//   (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
+// where %region_arg are percent-identifiers for the region arguments to be
+// introduced futher (SSA defs), and %operand are percent-identifiers for the
+// SSA value uses.
+static bool
+parseSizeAssignment(OpAsmParser *parser,
+                    MutableArrayRef<OpAsmParser::OperandType> sizes,
+                    MutableArrayRef<OpAsmParser::OperandType> regionSizes,
+                    MutableArrayRef<OpAsmParser::OperandType> indices) {
+  if (parser->parseLParen() || parser->parseRegionArgument(indices[0]) ||
+      parser->parseComma() || parser->parseRegionArgument(indices[1]) ||
+      parser->parseComma() || parser->parseRegionArgument(indices[2]) ||
+      parser->parseRParen() || parser->parseKeyword("in") ||
+      parser->parseLParen())
+    return true;
+
+  for (int i = 0; i < 3; ++i) {
+    if (i != 0 && parser->parseComma())
+      return true;
+    if (parser->parseRegionArgument(regionSizes[i]) || parser->parseEqual() ||
+        parser->parseOperand(sizes[i]))
+      return true;
+  }
+
+  return parser->parseRParen();
+}
+
+// Parses a Launch operation.
+// operation ::= `gpu.launch` `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
+//                           `threads` `(` ssa-id-list `)` `in` ssa-reassignment
+//                             (`args` ssa-reassignment `:` type-list)?
+//                             region attr-dict?
+// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
+bool LaunchOp::parse(OpAsmParser *parser, OperationState *result) {
+  // Sizes of the grid and block.
+  SmallVector<OpAsmParser::OperandType, kNumConfigOperands> sizes(
+      kNumConfigOperands);
+  MutableArrayRef<OpAsmParser::OperandType> sizesRef(sizes);
+
+  // Actual (data) operands passed to the kernel.
+  SmallVector<OpAsmParser::OperandType, 4> dataOperands;
+
+  // Region arguments to be created.
+  SmallVector<OpAsmParser::OperandType, 16> regionArgs(
+      kNumConfigRegionAttributes);
+  MutableArrayRef<OpAsmParser::OperandType> regionArgsRef(regionArgs);
+
+  // Parse the size assignment segments: the first segment assigns grid siezs
+  // and defines values for block identifiers; the second segment assigns block
+  // sies and defines values for thread identifiers.  In the region argument
+  // list, identifiers preceed sizes, and block-related values preceed
+  // thread-related values.
+  if (parser->parseKeyword(getBlocksKeyword().data()) ||
+      parseSizeAssignment(parser, sizesRef.take_front(3),
+                          regionArgsRef.slice(6, 3),
+                          regionArgsRef.slice(0, 3)) ||
+      parser->parseKeyword(getThreadsKeyword().data()) ||
+      parseSizeAssignment(parser, sizesRef.drop_front(3),
+                          regionArgsRef.slice(9, 3),
+                          regionArgsRef.slice(3, 3)) ||
+      parser->resolveOperands(sizes, parser->getBuilder().getIndexType(),
+                              result->operands))
+    return true;
+
+  // If kernel argument renaming segment is present, parse it.  When present,
+  // the segment should have at least one element.  If this segment is present,
+  // so is the trailing type list.  Parse it as well and use the parsed types
+  // to resolve the operands passed to the kernel arguments.
+  SmallVector<Type, 4> dataTypes;
+  if (!parser->parseOptionalKeyword(getArgsKeyword().data())) {
+    llvm::SMLoc argsLoc;
+
+    regionArgs.push_back({});
+    dataOperands.push_back({});
+    if (parser->getCurrentLocation(&argsLoc) || parser->parseLParen() ||
+        parser->parseRegionArgument(regionArgs.back()) ||
+        parser->parseEqual() || parser->parseOperand(dataOperands.back()))
+      return true;
+
+    while (!parser->parseOptionalComma()) {
+      regionArgs.push_back({});
+      dataOperands.push_back({});
+      if (parser->parseRegionArgument(regionArgs.back()) ||
+          parser->parseEqual() || parser->parseOperand(dataOperands.back()))
+        return true;
+    }
+
+    if (parser->parseRParen() || parser->parseColonTypeList(dataTypes) ||
+        parser->resolveOperands(dataOperands, dataTypes, argsLoc,
+                                result->operands))
+      return true;
+  }
+
+  // Introduce the body region and parse it.  The region has
+  // kNumConfigRegionAttributes leading arguments that correspond to
+  // block/thread identifiers and grid/block sizes, all of the `index` type.
+  // Follow the actual kernel arguments.
+  Type index = parser->getBuilder().getIndexType();
+  dataTypes.insert(dataTypes.begin(), kNumConfigRegionAttributes, index);
+  Region *body = result->addRegion();
+  return parser->parseRegion(*body, regionArgs, dataTypes) ||
+         parser->parseOptionalAttributeDict(result->attributes);
+}
index 9ba043b..15807f6 100644 (file)
@@ -3216,6 +3216,8 @@ public:
     return false;
   }
 
+  bool parseOptionalComma() override { return !parser.consumeIf(Token::comma); }
+
   /// Parse an optional keyword.
   bool parseOptionalKeyword(const char *keyword) override {
     // Check that the current token is a bare identifier or keyword.
index 9d094ef..74b3705 100644 (file)
@@ -2,51 +2,36 @@
 
 // CHECK-LABEL:func @no_args(%arg0: index)
 func @no_args(%sz : index) {
-// CHECK:  "gpu.launch"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0)
-// CHECK-SAME: {
-  "gpu.launch"(%sz, %sz, %sz, %sz, %sz, %sz) ({
-  ^bb1(%bx: index, %by: index, %bz: index,
-       %tx: index, %ty: index, %tz: index,
-       %szbx: index, %szby: index, %szbz: index,
-       %sztx: index, %szty: index, %sztz: index):
+// CHECK: gpu.launch blocks(%i0, %i1, %i2) in (%i6 = %arg0, %i7 = %arg0, %i8 = %arg0) threads(%i3, %i4, %i5) in (%i9 = %arg0, %i10 = %arg0, %i11 = %arg0)
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %sz, %grid_y = %sz, %grid_z = %sz)
+             threads(%tx, %ty, %tz) in (%block_x = %sz, %block_y = %sz, %block_z = %sz) {
     return
-// CHECK: (index, index, index, index, index, index) -> ()
-  }) : (index, index, index, index, index, index) -> ()
+  }
   return
 }
 
 // CHECK-LABEL:func @args(%arg0: index, %arg1: index, %arg2: f32, %arg3: memref<?xf32, 1>) {
 func @args(%blk : index, %thrd : index, %float : f32, %data : memref<?xf32,1>) {
-// CHECK:  "gpu.launch"(%arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg2, %arg3)
-// CHECK-SAME: {
-  "gpu.launch"(%blk, %blk, %blk, %thrd, %thrd, %thrd, %float, %data) ({
-  ^bb1(%bx: index, %by: index, %bz: index,
-       %tx: index, %ty: index, %tz: index,
-       %szbx: index, %szby: index, %szbz: index,
-       %sztx: index, %szty: index, %sztz: index,
-       %data0: f32, %data1: memref<?xf32,1>):
+// CHECK: gpu.launch blocks(%i0, %i1, %i2) in (%i6 = %arg0, %i7 = %arg0, %i8 = %arg0) threads(%i3, %i4, %i5) in (%i9 = %arg1, %i10 = %arg1, %i11 = %arg1) args(%i12 = %arg2, %i13 = %arg3) : f32, memref<?xf32, 1>
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %blk, %grid_y = %blk, %grid_z = %blk)
+             threads(%tx, %ty, %tz) in (%block_x = %thrd, %block_y = %thrd, %block_z = %thrd)
+            args(%kernel_arg0 = %float, %kernel_arg1 = %data) : f32, memref<?xf32, 1> {
     return
-// CHECK: (index, index, index, index, index, index, f32, memref<?xf32, 1>) -> ()
-  }) : (index, index, index, index, index, index, f32, memref<?xf32,1>) -> ()
+  }
   return
 }
 
 // It is possible to use values passed into the region as arguments.
 // CHECK-LABEL: func @passing_values
 func @passing_values(%blk : index, %thrd : index, %float : f32, %data : memref<?xf32,1>) {
-// CHECK:  "gpu.launch"(%arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg2, %arg3)
-// CHECK-SAME: {
-  "gpu.launch"(%blk, %blk, %blk, %thrd, %thrd, %thrd, %float, %data) ({
-// CHECK: ^bb1(%i0: index, %i1: index, %i2: index, %i3: index, %i4: index, %i5: index, %i6: index, %i7: index, %i8: index, %i9: index, %i10: index, %i11: index, %i12: f32, %i13: memref<?xf32, 1>)
-  ^bb1(%bx: index, %by: index, %bz: index,
-       %tx: index, %ty: index, %tz: index,
-       %szbx: index, %szby: index, %szbz: index,
-       %sztx: index, %szty: index, %sztz: index,
-       %data0: f32, %data1: memref<?xf32,1>):
+// CHECK: gpu.launch blocks(%i0, %i1, %i2) in (%i6 = %arg0, %i7 = %arg0, %i8 = %arg0) threads(%i3, %i4, %i5) in (%i9 = %arg1, %i10 = %arg1, %i11 = %arg1) args(%i12 = %arg2, %i13 = %arg3) : f32, memref<?xf32, 1>
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %blk, %grid_y = %blk, %grid_z = %blk)
+             threads(%tx, %ty, %tz) in (%block_x = %thrd, %block_y = %thrd, %block_z = %thrd)
+            args(%kernel_arg0 = %float, %kernel_arg1 = %data) : f32, memref<?xf32, 1> {
 // CHECK: "use"(%i12)
-    "use"(%data0): (f32) -> ()
+    "use"(%kernel_arg0): (f32) -> ()
     return
-  }) : (index, index, index, index, index, index, f32, memref<?xf32,1>) -> ()
+  }
   return
 }
 
@@ -54,11 +39,8 @@ func @passing_values(%blk : index, %thrd : index, %float : f32, %data : memref<?
 // cross kernel launch region boundaries.
 // CHECK-LABEL: func @nested_isolation
 func @nested_isolation(%sz : index) {
-  "gpu.launch"(%sz, %sz, %sz, %sz, %sz, %sz) ({
-  ^bb1(%bx: index, %by: index, %bz: index,
-       %tx: index, %ty: index, %tz: index,
-       %szbx: index, %szby: index, %szbz: index,
-       %sztx: index, %szty: index, %sztz: index):
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %sz, %grid_y = %sz, %grid_z = %sz)
+             threads(%tx, %ty, %tz) in (%block_x = %sz, %block_y = %sz, %block_z = %sz) {
     "region"() ({
 // CHECK: %0 = "produce"()
       %val = "produce"() : () -> (index)
@@ -67,6 +49,6 @@ func @nested_isolation(%sz : index) {
         "use"(%val) : (index) -> ()
       }) : () -> ()
     }) : () -> ()
-  }) : (index, index, index, index, index, index) -> ()
+  }
   return
 }