[mlir][gpu] Add `uniform` flag to gpu reduction ops
authorIvan Butygin <ivan.butygin@gmail.com>
Sun, 27 Nov 2022 14:19:56 +0000 (15:19 +0100)
committerIvan Butygin <ivan.butygin@gmail.com>
Wed, 14 Dec 2022 12:15:58 +0000 (13:15 +0100)
Differential Revision: https://reviews.llvm.org/D138758

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
mlir/test/Dialect/GPU/all-reduce-max.mlir
mlir/test/Dialect/GPU/all-reduce.mlir
mlir/test/Dialect/GPU/multiple-all-reduce.mlir
mlir/test/Dialect/GPU/ops.mlir

index f9fff78..baf9540 100644 (file)
@@ -688,7 +688,8 @@ def GPU_AllReduceOperationAttr : EnumAttr<GPU_Dialect, GPU_AllReduceOperation,
 def GPU_AllReduceOp : GPU_Op<"all_reduce",
     [SameOperandsAndResultType, IsolatedFromAbove]>,
     Arguments<(ins AnyType:$value,
-               OptionalAttr<GPU_AllReduceOperationAttr>:$op)>,
+               OptionalAttr<GPU_AllReduceOperationAttr>:$op,
+               UnitAttr:$uniform)>,
     Results<(outs AnyType)> {
   let summary = "Reduce values among workgroup.";
   let description = [{
@@ -711,11 +712,12 @@ def GPU_AllReduceOp : GPU_Op<"all_reduce",
     accumulation as code region. The accumulation operation must be one of:
     `add`, `and`, `max`, `min`, `mul`, `or`, `xor`.
 
-    Either none or all work items of a workgroup need to execute this op
-    in convergence.
+    If `uniform` flag is set either none or all work items of a workgroup
+    need to execute this op in convergence.
   }];
   let regions = (region AnyRegion:$body);
-  let assemblyFormat = [{ custom<AllReduceOperation>($op) $value $body attr-dict
+  let assemblyFormat = [{ custom<AllReduceOperation>($op) $value
+                          (`uniform` $uniform^)? $body attr-dict
                           `:` functional-type(operands, results) }];
   let hasRegionVerifier = 1;
 }
@@ -723,7 +725,8 @@ def GPU_AllReduceOp : GPU_Op<"all_reduce",
 def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce",
     [SameOperandsAndResultType]>,
     Arguments<(ins AnyType:$value,
-               GPU_AllReduceOperationAttr:$op)>,
+               GPU_AllReduceOperationAttr:$op,
+               UnitAttr:$uniform)>,
     Results<(outs AnyType)> {
   let summary = "Reduce values among subgroup.";
   let description = [{
@@ -736,10 +739,11 @@ def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce",
     %1 = gpu.subgroup_reduce add %0 : (f32) -> (f32)
     ```
 
-    Either none or all work items of a subgroup need to execute this op
-    in convergence.
+    If `uniform` flag is set either none or all work items of a subgroup
+    need to execute this op in convergence.
   }];
-  let assemblyFormat = [{ custom<AllReduceOperation>($op) $value attr-dict
+  let assemblyFormat = [{ custom<AllReduceOperation>($op) $value
+                          (`uniform` $uniform^)? attr-dict
                           `:` functional-type(operands, results) }];
   let hasVerifier = 1;
 }
index 32bf6e3..87bcf33 100644 (file)
@@ -394,14 +394,23 @@ struct GpuAllReduceConversion : public RewritePattern {
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
     auto funcOp = cast<gpu::GPUFuncOp>(op);
-    auto callback = [&](gpu::AllReduceOp reduceOp) {
-      GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();
-      // Performing a rewrite invalidates the walk iterator. Report interrupt
-      // so that we can start a new walk until all all_reduce ops are replaced.
-      return WalkResult::interrupt();
+
+    SmallVector<gpu::AllReduceOp> reduceOps;
+    auto callback = [&](gpu::AllReduceOp reduceOp) -> WalkResult {
+      if (!reduceOp.getUniform())
+        return WalkResult::interrupt();
+
+      reduceOps.emplace_back(reduceOp);
+      return WalkResult::advance();
     };
-    while (funcOp.walk(callback).wasInterrupted()) {
-    }
+
+    if (funcOp.walk(callback).wasInterrupted())
+      return rewriter.notifyMatchFailure(
+          op, "Non uniform reductions are not supported yet.");
+
+    for (gpu::AllReduceOp reduceOp : reduceOps)
+      GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();
+
     return success();
   }
 };
index 4014a19..fc25972 100644 (file)
@@ -89,7 +89,7 @@ gpu.module @test_module {
     // CHECK: nvvm.shfl.sync bfly {{.*}}
     // CHECK: nvvm.barrier0
     // CHECK: llvm.fadd
-    %result = gpu.all_reduce add %arg0 {} : (f32) -> (f32)
+    %result = gpu.all_reduce add %arg0 uniform {} : (f32) -> (f32)
 
     gpu.return
   }
@@ -104,7 +104,7 @@ gpu.module @test_module {
     // TODO: Check full IR expansion once lowering has settled.
     // CHECK: nvvm.shfl.sync bfly {{.*}}
     // CHECK: nvvm.barrier0
-    %result = gpu.all_reduce %arg0 {
+    %result = gpu.all_reduce %arg0 uniform {
     ^bb(%lhs : i32, %rhs : i32):
       %xor = arith.xori %lhs, %rhs : i32
       "gpu.yield"(%xor) : (i32) -> ()
index a1dcdb4..d39b961 100644 (file)
@@ -195,7 +195,7 @@ gpu.module @kernels {
     // CHECK:   cf.br ^bb42
     // CHECK: ^bb42:
     // CHECK:   gpu.barrier
-    %sum = gpu.all_reduce max %arg0 {} : (f32) -> (f32)
+    %sum = gpu.all_reduce max %arg0 uniform {} : (f32) -> (f32)
     gpu.return
   }
 
index 4d8654d..67d8335 100644 (file)
@@ -175,7 +175,7 @@ gpu.module @kernels {
     // CHECK:   cf.br ^bb42
     // CHECK: ^bb42:
     // CHECK:   gpu.barrier
-    %sum = gpu.all_reduce add %arg0 {} : (f32) -> (f32)
+    %sum = gpu.all_reduce add %arg0 uniform {} : (f32) -> (f32)
     gpu.return
   }
 
index 9b8d1c9..4153bfb 100644 (file)
@@ -10,9 +10,9 @@ func.func @main() {
   gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
              threads(%tx, %ty, %tz) in (%block_x = %c1, %block_y = %c1, %block_z = %c1) {
     %val = memref.load %data[%bx, %tx] : memref<2x6xf32>
-    %reduced0 = gpu.all_reduce add %val {} : (f32) -> (f32)
+    %reduced0 = gpu.all_reduce add %val uniform {} : (f32) -> (f32)
     memref.store %reduced0, %sum[%bx] : memref<2xf32>
-    %reduced1 = gpu.all_reduce mul %val {} : (f32) -> (f32)
+    %reduced1 = gpu.all_reduce mul %val uniform {} : (f32) -> (f32)
     memref.store %reduced1, %mul[%bx] : memref<2xf32>
     gpu.terminator
   }
index b68a109..301ab91 100644 (file)
@@ -83,11 +83,21 @@ module attributes {gpu.container_module} {
       %SgSi = gpu.subgroup_size : index
 
       %one = arith.constant 1.0 : f32
+
+      // CHECK: %{{.*}} = gpu.all_reduce add %{{.*}} {
+      // CHECK-NEXT: } : (f32) -> f32
       %sum = gpu.all_reduce add %one {} : (f32) -> (f32)
 
+      // CHECK: %{{.*}} = gpu.all_reduce add %{{.*}} uniform {
+      // CHECK-NEXT: } : (f32) -> f32
+      %sum1 = gpu.all_reduce add %one uniform {} : (f32) -> f32
+
       // CHECK: %{{.*}} = gpu.subgroup_reduce add %{{.*}} : (f32) -> f32
       %sum_subgroup = gpu.subgroup_reduce add %one : (f32) -> f32
 
+      // CHECK: %{{.*}} = gpu.subgroup_reduce add %{{.*}} uniform : (f32) -> f32
+      %sum_subgroup1 = gpu.subgroup_reduce add %one uniform : (f32) -> f32
+
       %width = arith.constant 7 : i32
       %offset = arith.constant 3 : i32
       // CHECK: gpu.shuffle xor %{{.*}}, %{{.*}}, %{{.*}} : f32