Add gpu.shuffle op.
authorChristian Sigg <csigg@google.com>
Fri, 20 Dec 2019 10:52:21 +0000 (02:52 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 20 Dec 2019 10:52:52 +0000 (02:52 -0800)
This will allow us to lower most of gpu.all_reduce (when all_reduce
doesn't exist in the target dialect) within the GPU dialect, and only do
target-specific lowering for the shuffle op.

PiperOrigin-RevId: 286548256

mlir/include/mlir/Dialect/GPU/GPUDialect.h
mlir/include/mlir/Dialect/GPU/GPUOps.td
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
mlir/test/Dialect/GPU/invalid.mlir
mlir/test/Dialect/GPU/ops.mlir
mlir/test/mlir-cuda-runner/shuffle.mlir [new file with mode: 0644]

index 495238f..93c0b13 100644 (file)
@@ -26,6 +26,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/FunctionSupport.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/SymbolTable.h"
 
 namespace mlir {
index 46433c6..6751f0a 100644 (file)
@@ -536,6 +536,41 @@ def GPU_AllReduceOp : GPU_Op<"all_reduce",
   let verifier = [{ return ::verifyAllReduce(*this); }];
 }
 
+def GPU_ShuffleOpXor : StrEnumAttrCase<"xor">;
+
+def GPU_ShuffleModeAttr : StrEnumAttr<"ShuffleModeAttr",
+    "Indexing modes supported by gpu.shuffle.",
+    [
+      GPU_ShuffleOpXor,
+    ]>;
+
+def GPU_ShuffleOp : GPU_Op<"shuffle", [NoSideEffect]>,
+    Arguments<(ins AnyType:$value, I32:$offset, I32:$width,
+               GPU_ShuffleModeAttr:$mode)>,
+    Results<(outs AnyType:$result, I1:$valid)> {
+  let summary = "Shuffles values within a subgroup.";
+  let description = [{
+    The "shuffle" op moves values to a different invocation within the same
+    subgroup.
+
+    For example
+    ```
+      %1, %2 = gpu.shuffle %0, %offset, %width xor : f32
+    ```
+    for lane k returns the value from lane `k ^ offset` and `true` if that lane
+    is smaller than %width. Otherwise it returns an unspecified value and
+    `false`. A lane is the index of an invocation relative to its subgroup.
+
+    The width specifies the number of invocations that participate in the
+    shuffle. The width needs to be the same for all invocations that participate
+    in the shuffle. Exactly the first `width` invocations of a subgroup need to
+    execute this op in convergence.
+  }];
+  let verifier = [{ return ::verifyShuffleOp(*this); }];
+  let printer = [{ printShuffleOp(p, *this); }];
+  let parser = [{ return parseShuffleOp(parser, result); }];
+}
+
 def GPU_BarrierOp : GPU_Op<"barrier"> {
   let summary = "Synchronizes all work items of a workgroup.";
   let description = [{
index 78fe15d..220df53 100644 (file)
@@ -473,6 +473,64 @@ private:
   static constexpr int kWarpSize = 32;
 };
 
+struct GPUShuffleOpLowering : public LLVMOpLowering {
+  explicit GPUShuffleOpLowering(LLVMTypeConverter &lowering_)
+      : LLVMOpLowering(gpu::ShuffleOp::getOperationName(),
+                       lowering_.getDialect()->getContext(), lowering_) {}
+
+  /// Lowers a shuffle to the corresponding NVVM op.
+  ///
+  /// Convert the `width` argument into an activeMask (a bitmask which specifies
+  /// which threads participate in the shuffle) and a maskAndClamp (specifying
+  /// the highest lane which participates in the shuffle).
+  ///
+  ///     %one = llvm.constant(1 : i32) : !llvm.i32
+  ///     %shl = llvm.shl %one, %width : !llvm.i32
+  ///     %active_mask = llvm.sub %shl, %one : !llvm.i32
+  ///     %mask_and_clamp = llvm.sub %width, %one : !llvm.i32
+  ///     %shfl = nvvm.shfl.sync.bfly %active_mask, %value, %offset,
+  ///         %mask_and_clamp : !llvm<"{ float, i1 }">
+  ///     %shfl_value = llvm.extractvalue %shfl[0 : index] :
+  ///         !llvm<"{ float, i1 }">
+  ///     %shfl_pred = llvm.extractvalue %shfl[1 : index] :
+  ///         !llvm<"{ float, i1 }">
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    gpu::ShuffleOpOperandAdaptor adaptor(operands);
+
+    auto dialect = lowering.getDialect();
+    auto valueTy = adaptor.value()->getType().cast<LLVM::LLVMType>();
+    auto int32Type = LLVM::LLVMType::getInt32Ty(dialect);
+    auto predTy = LLVM::LLVMType::getInt1Ty(dialect);
+    auto resultTy = LLVM::LLVMType::getStructTy(dialect, {valueTy, predTy});
+
+    Value *one = rewriter.create<LLVM::ConstantOp>(
+        loc, int32Type, rewriter.getI32IntegerAttr(1));
+    // Bit mask of active lanes: `(1 << activeWidth) - 1`.
+    Value *activeMask = rewriter.create<LLVM::SubOp>(
+        loc, int32Type,
+        rewriter.create<LLVM::ShlOp>(loc, int32Type, one, adaptor.width()),
+        one);
+    // Clamp lane: `activeWidth - 1`
+    Value *maskAndClamp =
+        rewriter.create<LLVM::SubOp>(loc, int32Type, adaptor.width(), one);
+
+    auto returnValueAndIsValidAttr = rewriter.getUnitAttr();
+    Value *shfl = rewriter.create<NVVM::ShflBflyOp>(
+        loc, resultTy, activeMask, adaptor.value(), adaptor.offset(),
+        maskAndClamp, returnValueAndIsValidAttr);
+    Value *shflValue = rewriter.create<LLVM::ExtractValueOp>(
+        loc, valueTy, shfl, rewriter.getIndexArrayAttr(0));
+    Value *isActiveSrcLane = rewriter.create<LLVM::ExtractValueOp>(
+        loc, predTy, shfl, rewriter.getIndexArrayAttr(1));
+
+    rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
+    return matchSuccess();
+  }
+};
+
 struct GPUFuncOpLowering : LLVMOpLowering {
   explicit GPUFuncOpLowering(LLVMTypeConverter &typeConverter)
       : LLVMOpLowering(gpu::GPUFuncOp::getOperationName(),
@@ -688,8 +746,8 @@ void mlir::populateGpuToNVVMConversionPatterns(
                                           NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
               GPUIndexIntrinsicOpLowering<gpu::GridDimOp, NVVM::GridDimXOp,
                                           NVVM::GridDimYOp, NVVM::GridDimZOp>,
-              GPUAllReduceOpLowering, GPUFuncOpLowering, GPUReturnOpLowering>(
-          converter);
+              GPUAllReduceOpLowering, GPUShuffleOpLowering, GPUFuncOpLowering,
+              GPUReturnOpLowering>(converter);
   patterns.insert<OpToFuncCallLowering<ExpOp>>(converter, "__nv_expf",
                                                "__nv_exp");
 }
index 7324b96..9c0183e 100644 (file)
@@ -165,6 +165,47 @@ static LogicalResult verifyAllReduce(gpu::AllReduceOp allReduce) {
   return success();
 }
 
+static LogicalResult verifyShuffleOp(gpu::ShuffleOp shuffleOp) {
+  auto type = shuffleOp.value()->getType();
+  if (shuffleOp.result()->getType() != type) {
+    return shuffleOp.emitOpError()
+           << "requires the same type for value operand and result";
+  }
+  if (!type.isIntOrFloat() || type.getIntOrFloatBitWidth() != 32) {
+    return shuffleOp.emitOpError()
+           << "requires value operand type to be f32 or i32";
+  }
+  return success();
+}
+
+static void printShuffleOp(OpAsmPrinter &p, ShuffleOp op) {
+  p << ShuffleOp::getOperationName() << ' ';
+  p.printOperands(op.getOperands());
+  p << ' ' << op.mode() << " : ";
+  p.printType(op.value()->getType());
+}
+
+static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &state) {
+  SmallVector<OpAsmParser::OperandType, 3> operandInfo;
+  if (parser.parseOperandList(operandInfo, 3))
+    return failure();
+
+  StringRef mode;
+  if (parser.parseKeyword(&mode))
+    return failure();
+  state.addAttribute("mode", parser.getBuilder().getStringAttr(mode));
+
+  Type valueType;
+  Type int32Type = parser.getBuilder().getIntegerType(32);
+  Type int1Type = parser.getBuilder().getI1Type();
+  if (parser.parseColonType(valueType) ||
+      parser.resolveOperands(operandInfo, {valueType, int32Type, int32Type},
+                             parser.getCurrentLocation(), state.operands) ||
+      parser.addTypesToList({valueType, int1Type}, state.types))
+    return failure();
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // LaunchOp
 //===----------------------------------------------------------------------===//
index 525016b..b1820cb 100644 (file)
@@ -75,6 +75,31 @@ module attributes {gpu.kernel_module} {
 // -----
 
 module attributes {gpu.kernel_module} {
+  // CHECK-LABEL: func @gpu_shuffle()
+  func @gpu_shuffle()
+      attributes { gpu.kernel } {
+    // CHECK: %[[#VALUE:]] = llvm.mlir.constant(1.000000e+00 : f32) : !llvm.float
+    %arg0 = constant 1.0 : f32
+    // CHECK: %[[#OFFSET:]] = llvm.mlir.constant(4 : i32) : !llvm.i32
+    %arg1 = constant 4 : i32
+    // CHECK: %[[#WIDTH:]] = llvm.mlir.constant(23 : i32) : !llvm.i32
+    %arg2 = constant 23 : i32
+    // CHECK: %[[#ONE:]] = llvm.mlir.constant(1 : i32) : !llvm.i32
+    // CHECK: %[[#SHL:]] = llvm.shl %[[#ONE]], %[[#WIDTH]] : !llvm.i32
+    // CHECK: %[[#MASK:]] = llvm.sub %[[#SHL]], %[[#ONE]] : !llvm.i32
+    // CHECK: %[[#CLAMP:]] = llvm.sub %[[#WIDTH]], %[[#ONE]] : !llvm.i32
+    // CHECK: %[[#SHFL:]] = nvvm.shfl.sync.bfly %[[#MASK]], %[[#VALUE]], %[[#OFFSET]], %[[#CLAMP]] : !llvm<"{ float, i1 }">
+    // CHECK: llvm.extractvalue %[[#SHFL]][0 : index] : !llvm<"{ float, i1 }">
+    // CHECK: llvm.extractvalue %[[#SHFL]][1 : index] : !llvm<"{ float, i1 }">
+    %shfl, %pred = "gpu.shuffle"(%arg0, %arg1, %arg2) { mode = "xor" } : (f32, i32, i32) -> (f32, i1)
+
+    std.return
+  }
+}
+
+// -----
+
+module attributes {gpu.kernel_module} {
   // CHECK-LABEL: func @gpu_sync()
   func @gpu_sync()
       attributes { gpu.kernel } {
index f8ed1a9..8323fdf 100644 (file)
@@ -362,6 +362,20 @@ func @reduce_incorrect_yield(%arg0 : f32) {
 
 // -----
 
+func @shuffle_mismatching_type(%arg0 : f32, %arg1 : i32, %arg2 : i32) {
+  // expected-error@+1 {{'gpu.shuffle' op requires the same type for value operand and result}}
+  %shfl, %pred = "gpu.shuffle"(%arg0, %arg1, %arg2) { mode = "xor" } : (f32, i32, i32) -> (i32, i1)
+}
+
+// -----
+
+func @shuffle_unsupported_type(%arg0 : index, %arg1 : i32, %arg2 : i32) {
+  // expected-error@+1 {{'gpu.shuffle' op requires value operand type to be f32 or i32}}
+  %shfl, %pred = gpu.shuffle %arg0, %arg1, %arg2 xor : index
+}
+
+// -----
+
 module {
   module @gpu_funcs attributes {gpu.kernel_module} {
     // expected-error @+1 {{custom op 'gpu.func' gpu.func requires named arguments}}
index ff5a40d..1dd08ce 100644 (file)
@@ -81,6 +81,11 @@ module attributes {gpu.container_module} {
       %one = constant 1.0 : f32
       %sum = "gpu.all_reduce"(%one) ({}) {op = "add"} : (f32) -> (f32)
 
+      %width = constant 7 : i32
+      %offset = constant 3 : i32
+      // CHECK: gpu.shuffle %{{.*}}, %{{.*}}, %{{.*}} xor : f32
+      %shfl, %pred = gpu.shuffle %arg0, %offset, %width xor : f32
+
       "gpu.barrier"() : () -> ()
 
       "some_op"(%bIdX, %tIdX) : (index, index) -> ()
diff --git a/mlir/test/mlir-cuda-runner/shuffle.mlir b/mlir/test/mlir-cuda-runner/shuffle.mlir
new file mode 100644 (file)
index 0000000..1b01399
--- /dev/null
@@ -0,0 +1,32 @@
+// RUN: mlir-cuda-runner %s --shared-libs=%cuda_wrapper_library_dir/libcuda-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s
+
+// CHECK: [4, 5, 6, 7, 0, 1, 2, 3, 12, -1, -1, -1, 8]
+func @main() {
+  %arg = alloc() : memref<13xf32>
+  %dst = memref_cast %arg : memref<13xf32> to memref<?xf32>
+  %one = constant 1 : index
+  %sx = dim %dst, 0 : memref<?xf32>
+  call @mcuMemHostRegisterMemRef1dFloat(%dst) : (memref<?xf32>) -> ()
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
+             threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one)
+             args(%kernel_dst = %dst) : memref<?xf32> {
+    %t0 = index_cast %tx : index to i32
+    %val = sitofp %t0 : i32 to f32
+    %width = index_cast %block_x : index to i32
+    %offset = constant 4 : i32
+    %shfl, %valid = gpu.shuffle %val, %offset, %width xor : f32
+    cond_br %valid, ^bb1(%shfl : f32), ^bb0
+  ^bb0:
+    %m1 = constant -1.0 : f32
+    br ^bb1(%m1 : f32)
+  ^bb1(%value : f32):
+    store %value, %kernel_dst[%tx] : memref<?xf32>
+    gpu.return
+  }
+  %U = memref_cast %dst : memref<?xf32> to memref<*xf32>
+  call @print_memref_f32(%U) : (memref<*xf32>) -> ()
+  return
+}
+
+func @mcuMemHostRegisterMemRef1dFloat(%ptr : memref<?xf32>)
+func @print_memref_f32(%ptr : memref<*xf32>)