def GPU_AllReduceOp : GPU_Op<"all_reduce",
[SameOperandsAndResultType, IsolatedFromAbove]>,
Arguments<(ins AnyType:$value,
- OptionalAttr<GPU_AllReduceOperationAttr>:$op)>,
+ OptionalAttr<GPU_AllReduceOperationAttr>:$op,
+ UnitAttr:$uniform)>,
Results<(outs AnyType)> {
let summary = "Reduce values among workgroup.";
let description = [{
accumulation as code region. The accumulation operation must be one of:
`add`, `and`, `max`, `min`, `mul`, `or`, `xor`.
- Either none or all work items of a workgroup need to execute this op
- in convergence.
+ If the `uniform` flag is set, either none or all work items of a workgroup
+ need to execute this op in convergence.
}];
let regions = (region AnyRegion:$body);
- let assemblyFormat = [{ custom<AllReduceOperation>($op) $value $body attr-dict
+ let assemblyFormat = [{ custom<AllReduceOperation>($op) $value
+ (`uniform` $uniform^)? $body attr-dict
`:` functional-type(operands, results) }];
let hasRegionVerifier = 1;
}
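With the optional `uniform` unit attribute and the updated assembly format, the op can be written with or without the `uniform` keyword after the reduced value. A minimal sketch of the resulting syntax (value names are illustrative; the placement follows the ``(`uniform` $uniform^)?`` clause and the tests further below):

```mlir
// Accumulation given as an operation attribute, marked uniform.
%sum = gpu.all_reduce add %value uniform {} : (f32) -> (f32)

// Accumulation given as a code region, without the `uniform` keyword.
%prod = gpu.all_reduce %value {
^bb(%lhs : f32, %rhs : f32):
  %partial = arith.mulf %lhs, %rhs : f32
  "gpu.yield"(%partial) : (f32) -> ()
} : (f32) -> (f32)
```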
def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce",
[SameOperandsAndResultType]>,
Arguments<(ins AnyType:$value,
- GPU_AllReduceOperationAttr:$op)>,
+ GPU_AllReduceOperationAttr:$op,
+ UnitAttr:$uniform)>,
Results<(outs AnyType)> {
let summary = "Reduce values among subgroup.";
let description = [{
%1 = gpu.subgroup_reduce add %0 : (f32) -> (f32)
```
- Either none or all work items of a subgroup need to execute this op
- in convergence.
+ If the `uniform` flag is set, either none or all work items of a subgroup
+ need to execute this op in convergence.
}];
- let assemblyFormat = [{ custom<AllReduceOperation>($op) $value attr-dict
+ let assemblyFormat = [{ custom<AllReduceOperation>($op) $value
+ (`uniform` $uniform^)? attr-dict
`:` functional-type(operands, results) }];
let hasVerifier = 1;
}
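The same keyword placement applies to `gpu.subgroup_reduce`, except that the op has no region and the accumulation attribute is mandatory. A short sketch under those assumptions:

```mlir
%red = gpu.subgroup_reduce add %value : (f32) -> f32
%red_uniform = gpu.subgroup_reduce add %value uniform : (f32) -> f32
```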
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
auto funcOp = cast<gpu::GPUFuncOp>(op);
- auto callback = [&](gpu::AllReduceOp reduceOp) {
- GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();
- // Performing a rewrite invalidates the walk iterator. Report interrupt
- // so that we can start a new walk until all all_reduce ops are replaced.
- return WalkResult::interrupt();
+
+ // Collect the reductions to expand. A non-uniform reduction interrupts the
+ // walk and aborts the pattern, since this rewriter only handles the uniform
+ // case.
+ SmallVector<gpu::AllReduceOp> reduceOps;
+ auto callback = [&](gpu::AllReduceOp reduceOp) -> WalkResult {
+ if (!reduceOp.getUniform())
+ return WalkResult::interrupt();
+
+ reduceOps.emplace_back(reduceOp);
+ return WalkResult::advance();
};
- while (funcOp.walk(callback).wasInterrupted()) {
- }
+
+ if (funcOp.walk(callback).wasInterrupted())
+ return rewriter.notifyMatchFailure(
+ op, "Non uniform reductions are not supported yet.");
+
+ for (gpu::AllReduceOp reduceOp : reduceOps)
+ GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();
+
return success();
}
};
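The updated pattern expands only uniform reductions and reports a match failure as soon as it encounters a non-uniform one. A sketch of the two situations, assuming the pattern is driven over `gpu.func` bodies as before (module and function names are illustrative):

```mlir
gpu.module @kernels {
  // Expanded by GpuAllReduceRewriter: the reduction is marked uniform.
  gpu.func @uniform_sum(%arg0 : f32) kernel {
    %sum = gpu.all_reduce add %arg0 uniform {} : (f32) -> (f32)
    gpu.return
  }

  // Not expanded: the non-uniform reduction makes matchAndRewrite bail out
  // for this function, leaving the op for a later lowering.
  gpu.func @non_uniform_sum(%arg0 : f32) kernel {
    %sum = gpu.all_reduce add %arg0 {} : (f32) -> (f32)
    gpu.return
  }
}
```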
// CHECK: nvvm.shfl.sync bfly {{.*}}
// CHECK: nvvm.barrier0
// CHECK: llvm.fadd
- %result = gpu.all_reduce add %arg0 {} : (f32) -> (f32)
+ %result = gpu.all_reduce add %arg0 uniform {} : (f32) -> (f32)
gpu.return
}
// TODO: Check full IR expansion once lowering has settled.
// CHECK: nvvm.shfl.sync bfly {{.*}}
// CHECK: nvvm.barrier0
- %result = gpu.all_reduce %arg0 {
+ %result = gpu.all_reduce %arg0 uniform {
^bb(%lhs : i32, %rhs : i32):
%xor = arith.xori %lhs, %rhs : i32
"gpu.yield"(%xor) : (i32) -> ()
// CHECK: cf.br ^bb42
// CHECK: ^bb42:
// CHECK: gpu.barrier
- %sum = gpu.all_reduce max %arg0 {} : (f32) -> (f32)
+ %sum = gpu.all_reduce max %arg0 uniform {} : (f32) -> (f32)
gpu.return
}
// CHECK: cf.br ^bb42
// CHECK: ^bb42:
// CHECK: gpu.barrier
- %sum = gpu.all_reduce add %arg0 {} : (f32) -> (f32)
+ %sum = gpu.all_reduce add %arg0 uniform {} : (f32) -> (f32)
gpu.return
}
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
threads(%tx, %ty, %tz) in (%block_x = %c1, %block_y = %c1, %block_z = %c1) {
%val = memref.load %data[%bx, %tx] : memref<2x6xf32>
- %reduced0 = gpu.all_reduce add %val {} : (f32) -> (f32)
+ %reduced0 = gpu.all_reduce add %val uniform {} : (f32) -> (f32)
memref.store %reduced0, %sum[%bx] : memref<2xf32>
- %reduced1 = gpu.all_reduce mul %val {} : (f32) -> (f32)
+ %reduced1 = gpu.all_reduce mul %val uniform {} : (f32) -> (f32)
memref.store %reduced1, %mul[%bx] : memref<2xf32>
gpu.terminator
}
%SgSi = gpu.subgroup_size : index
%one = arith.constant 1.0 : f32
+
+ // CHECK: %{{.*}} = gpu.all_reduce add %{{.*}} {
+ // CHECK-NEXT: } : (f32) -> f32
%sum = gpu.all_reduce add %one {} : (f32) -> (f32)
+ // CHECK: %{{.*}} = gpu.all_reduce add %{{.*}} uniform {
+ // CHECK-NEXT: } : (f32) -> f32
+ %sum1 = gpu.all_reduce add %one uniform {} : (f32) -> f32
+
// CHECK: %{{.*}} = gpu.subgroup_reduce add %{{.*}} : (f32) -> f32
%sum_subgroup = gpu.subgroup_reduce add %one : (f32) -> f32
+ // CHECK: %{{.*}} = gpu.subgroup_reduce add %{{.*}} uniform : (f32) -> f32
+ %sum_subgroup1 = gpu.subgroup_reduce add %one uniform : (f32) -> f32
+
%width = arith.constant 7 : i32
%offset = arith.constant 3 : i32
// CHECK: gpu.shuffle xor %{{.*}}, %{{.*}}, %{{.*}} : f32