[mlir][gpu] Add `subgroup_reduce` operation
authorIvan Butygin <ivan.butygin@gmail.com>
Wed, 5 Oct 2022 21:57:33 +0000 (23:57 +0200)
committerIvan Butygin <ivan.butygin@gmail.com>
Tue, 11 Oct 2022 09:47:15 +0000 (11:47 +0200)
Introduce `subgroup_reduce` operation, similar to `all_reduce`, but operating on subgroup scope instead of workgroup.
It is intended as low-level building block for more high level abstractions (e.g for workgroup-wide `all_reduce` ops).
Only introduce version taking reduce operation enum for simplicity sake.

Differential Revision: https://reviews.llvm.org/D135323

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/test/Dialect/GPU/invalid.mlir
mlir/test/Dialect/GPU/ops.mlir

index f3d1046..f1d894a 100644 (file)
@@ -717,6 +717,30 @@ def GPU_AllReduceOp : GPU_Op<"all_reduce",
   let hasRegionVerifier = 1;
 }
 
+def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce",
+    [SameOperandsAndResultType]>,
+    Arguments<(ins AnyType:$value,
+               GPU_AllReduceOperationAttr:$op)>,
+    Results<(outs AnyType)> {
+  let summary = "Reduce values among subgroup.";
+  let description = [{
+    The `subgroup_reduce` op reduces the value of every work item across a
+    subgroup. The result is equal for all work items of a subgroup.
+
+    Example:
+
+    ```mlir
+    %1 = gpu.subgroup_reduce add %0 : (f32) -> (f32)
+    ```
+
+    Either none or all work items of a subgroup need to execute this op
+    in convergence.
+  }];
+  let assemblyFormat = [{ custom<AllReduceOperation>($op) $value attr-dict
+                          `:` functional-type(operands, results) }];
+  let hasVerifier = 1;
+}
+
 def GPU_ShuffleOpXor  : I32EnumAttrCase<"XOR",  0, "xor">;
 def GPU_ShuffleOpDown : I32EnumAttrCase<"DOWN", 1, "down">;
 def GPU_ShuffleOpUp   : I32EnumAttrCase<"UP",   2, "up">;
index 83ee9bb..bfdcedf 100644 (file)
@@ -309,6 +309,17 @@ static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
 // AllReduceOp
 //===----------------------------------------------------------------------===//
 
+static bool verifyReduceOpAndType(gpu::AllReduceOperation opName,
+                                  Type resType) {
+  if ((opName == gpu::AllReduceOperation::AND ||
+       opName == gpu::AllReduceOperation::OR ||
+       opName == gpu::AllReduceOperation::XOR) &&
+      !resType.isa<IntegerType>())
+    return false;
+
+  return true;
+}
+
 LogicalResult gpu::AllReduceOp::verifyRegions() {
   if (getBody().empty() != getOp().has_value())
     return emitError("expected either an op attribute or a non-empty body");
@@ -333,10 +344,7 @@ LogicalResult gpu::AllReduceOp::verifyRegions() {
       return emitError("expected gpu.yield op in region");
   } else {
     gpu::AllReduceOperation opName = *getOp();
-    if ((opName == gpu::AllReduceOperation::AND ||
-         opName == gpu::AllReduceOperation::OR ||
-         opName == gpu::AllReduceOperation::XOR) &&
-        !getType().isa<IntegerType>()) {
+    if (!verifyReduceOpAndType(opName, getType())) {
       return emitError()
              << '`' << gpu::stringifyAllReduceOperation(opName)
              << "` accumulator is only compatible with Integer type";
@@ -365,6 +373,19 @@ static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
 }
 
 //===----------------------------------------------------------------------===//
+// SubgroupReduceOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult gpu::SubgroupReduceOp::verify() {
+  gpu::AllReduceOperation opName = getOp();
+  if (!verifyReduceOpAndType(opName, getType())) {
+    return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
+                       << "` accumulator is only compatible with Integer type";
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // AsyncOpInterface
 //===----------------------------------------------------------------------===//
 
index f3c8123..b029d2f 100644 (file)
@@ -245,6 +245,14 @@ func.func @reduce_invalid_op_type(%arg0 : f32) {
 
 // -----
 
+func.func @subgroup_reduce_invalid_op_type(%arg0 : f32) {
+  // expected-error@+1 {{`and` accumulator is only compatible with Integer type}}
+  %res = gpu.subgroup_reduce and %arg0 : (f32) -> (f32)
+  return
+}
+
+// -----
+
 func.func @reduce_incorrect_region_arguments(%arg0 : f32) {
   // expected-error@+1 {{expected two region arguments}}
   %res = gpu.all_reduce %arg0 {
index 5232074..9b31a32 100644 (file)
@@ -85,6 +85,9 @@ module attributes {gpu.container_module} {
       %one = arith.constant 1.0 : f32
       %sum = gpu.all_reduce add %one {} : (f32) -> (f32)
 
+      // CHECK: %{{.*}} = gpu.subgroup_reduce add %{{.*}} : (f32) -> f32
+      %sum_subgroup = gpu.subgroup_reduce add %one : (f32) -> f32
+
       %width = arith.constant 7 : i32
       %offset = arith.constant 3 : i32
       // CHECK: gpu.shuffle xor %{{.*}}, %{{.*}}, %{{.*}} : f32