let printer = [{ p << getOperationName(); }];
}
-def gpu_AllReduce : GPU_Op<"all_reduce", [SameOperandsAndResultType]>,
- Arguments<(ins AnyType)>, Results<(outs AnyType)> {
+def gpu_Yield : GPU_Op<"yield", [Terminator]>,
+ Arguments<(ins Variadic<AnyType>:$values)> {
+ let summary = "GPU yield operation";
+ let description = [{
+ "gpu.yield" is a special terminator operation for blocks inside regions
+ in gpu ops. It returns values to the immediately enclosing gpu op.
+
+ Example:
+
+ gpu.yield %f0, %f1 : f32, f32
+ }];
+}
+
+def gpu_AllReduce : GPU_Op<"all_reduce",
+ [SameOperandsAndResultType, IsolatedFromAbove]>,
+ Arguments<(ins AnyType:$value, OptionalAttr<StrAttr>:$op)>,
+ Results<(outs AnyType)> {
let summary = "Reduce values among workgroup.";
let description = [{
The "all_reduce" op reduces the value of every invocation across a local
- workgroup.
+ workgroup. The result is equal for all invocations of a local workgroup.
- For example,
+ For example, both
```
- %1 = gpu.all_reduce %0 : f32
+ %1 = "gpu.all_reduce"(%0) ({}) { op = "add" } : (f32) -> (f32)
+ %2 = "gpu.all_reduce"(%0) ({
+ ^bb(%lhs : f32, %rhs : f32):
+ %sum = addf %lhs, %rhs : f32
+ "gpu.yield"(%sum) : (f32) -> ()
+ }) : (f32) -> (f32)
```
- computes the sum of each invocation's %0 value. The value of %1 is always
- equal for all invocations of a local workgroup.
+    compute the sum of each invocation's %0 value. The first version specifies
+    the accumulation as an operation, whereas the second version specifies the
+    accumulation as a code region. The accumulation operation must be either
+    `add` or `mul`.
Either none or all invocations of a local workgroup need to execute this op
in convergence.
}];
+ let regions = (region AnyRegion:$body);
+ let verifier = [{ return ::verifyAllReduce(*this); }];
}
#endif // GPU_OPS
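The body of `::verifyAllReduce` referenced by the `verifier` field above is not part of this diff. For reference, a minimal sketch that would produce the diagnostics exercised by the invalid tests at the end of this change might look as follows; the exact accessor names and checks are assumptions inferred from those tests, not the change's actual implementation:

```cpp
// Sketch only: a verifier consistent with the expected-error messages below,
// written against 2019-era MLIR APIs (raw Value*, Optional attribute getters).
static LogicalResult verifyAllReduce(gpu::AllReduce allReduce) {
  // Exactly one of the attribute form and the region form must be used.
  if (allReduce.body().empty() != allReduce.op().hasValue())
    return allReduce.emitError(
        "expected either an op attribute or a non-empty body");
  if (!allReduce.body().empty()) {
    // Region form: two entry arguments of the operand type, one yielded value.
    if (allReduce.body().front().getNumArguments() != 2)
      return allReduce.emitError("expected two region arguments");
    for (auto *argument : allReduce.body().front().getArguments()) {
      if (argument->getType() != allReduce.getType())
        return allReduce.emitError("incorrect region argument type");
    }
    unsigned yieldCount = 0;
    for (Block &block : allReduce.body()) {
      if (auto yield = dyn_cast<gpu::Yield>(block.getTerminator())) {
        if (yield.getNumOperands() != 1)
          return allReduce.emitError("expected one gpu.yield operand");
        if (yield.getOperand(0)->getType() != allReduce.getType())
          return allReduce.emitError("incorrect gpu.yield type");
        ++yieldCount;
      }
    }
    if (yieldCount == 0)
      return allReduce.emitError("expected gpu.yield op in region");
  } else {
    // Attribute form: the lowering below only handles "add" and "mul".
    if (*allReduce.op() != "add" && *allReduce.op() != "mul")
      return allReduce.emitError("op \"") << *allReduce.op() << "\" is invalid";
  }
  return success();
}
```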
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
/// Converts all_reduce op to LLVM/NVVM ops.
struct GPUAllReduceOpLowering : public LLVMOpLowering {
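+  /// Emits the combination of two partially reduced values at the given
+  /// location and returns the accumulated result.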
+ using AccumulatorFactory = std::function<Value *(
+ Location, Value *, Value *, ConversionPatternRewriter &)>;
+
explicit GPUAllReduceOpLowering(LLVMTypeConverter &lowering_)
: LLVMOpLowering(gpu::AllReduce::getOperationName(),
lowering_.getDialect()->getContext(), lowering_),
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
- Value *result = createBlockReduce(op->getLoc(), operands.front(), rewriter);
+ Location loc = op->getLoc();
+ Value *operand = operands.front();
+
+ // TODO(csigg): Generalize to other types of accumulation.
+ assert(op->getOperand(0)->getType().isIntOrFloat());
+
+ // Create the reduction using an accumulator factory.
+ AccumulatorFactory factory = getFactory(cast<gpu::AllReduce>(op), operand);
+ assert(factory && "failed to create accumulator factory");
+ Value *result = createBlockReduce(loc, operand, factory, rewriter);
+
rewriter.replaceOp(op, {result});
return matchSuccess();
}
private:
+ /// Returns an accumulator factory using either the op attribute or the body
+ /// region.
+ AccumulatorFactory getFactory(gpu::AllReduce allReduce,
+ Value *operand) const {
+ if (!allReduce.body().empty()) {
+ return getFactory(allReduce.body());
+ }
+ if (allReduce.op()) {
+ auto type = operand->getType().cast<LLVM::LLVMType>();
+ return getFactory(*allReduce.op(), type.getUnderlyingType());
+ }
+ return AccumulatorFactory();
+ }
+
+ /// Returns an accumulator factory that clones the body. The body's entry
+  /// block is expected to have two arguments. The gpu.yield ops are expected
+  /// to return the accumulated value of the same type.
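+  /// On invocation, the factory splits the current block, clones the body in
+  /// between (mapping the entry block arguments to lhs and rhs), and rewrites
+  /// each gpu.yield into a branch that carries the accumulated value into the
+  /// continuation block, whose new block argument is returned as the result.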
+ AccumulatorFactory getFactory(Region &body) const {
+ return AccumulatorFactory([&](Location loc, Value *lhs, Value *rhs,
+ ConversionPatternRewriter &rewriter) {
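+      // Split the current block at the insertion point; the rest becomes the
+      // continuation block, which receives the accumulated value as argument.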
+ Block *block = rewriter.getInsertionBlock();
+ Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());
+
+      // Insert the accumulator body between the split blocks.
+ BlockAndValueMapping mapping;
+ mapping.map(body.front().getArgument(0), lhs);
+ mapping.map(body.front().getArgument(1), rhs);
+ rewriter.cloneRegionBefore(body, *split->getParent(),
+ split->getIterator(), mapping);
+
+      // Add a branch from the original block into the inserted body.
+ block = block->getNextNode();
+ rewriter.create<LLVM::BrOp>(loc, ArrayRef<Value *>{},
+ llvm::makeArrayRef(block),
+ llvm::ArrayRef<Value *>());
+
+      // Replace all gpu.yield ops with branches out of the body.
+ for (; block != split; block = block->getNextNode()) {
+ Operation *terminator = block->getTerminator();
+ if (!llvm::isa<gpu::Yield>(terminator))
+ continue;
+ rewriter.setInsertionPointToEnd(block);
+ rewriter.replaceOpWithNewOp<LLVM::BrOp>(
+ terminator, ArrayRef<Value *>{}, llvm::makeArrayRef(split),
+ llvm::makeArrayRef(terminator->getOperand(0)));
+ }
+
+ // Return accumulator result.
+ rewriter.setInsertionPointToStart(split);
+ return split->addArgument(lhs->getType());
+ });
+ }
+
+ /// Returns an accumulator factory that creates an op specified by opName.
+ AccumulatorFactory getFactory(StringRef opName, llvm::Type *type) const {
+ if (type->isVectorTy() || type->isArrayTy())
+ return getFactory(opName, type->getSequentialElementType());
+
+ bool isFloatingPoint = type->isFloatingPointTy();
+
+ if (opName == "add") {
+ return isFloatingPoint ? getFactory<LLVM::FAddOp>()
+ : getFactory<LLVM::AddOp>();
+ }
+ if (opName == "mul") {
+ return isFloatingPoint ? getFactory<LLVM::FMulOp>()
+ : getFactory<LLVM::MulOp>();
+ }
+
+ return AccumulatorFactory();
+ }
+
+ /// Returns an accumulator factory that creates an op of type T.
+ template <typename T> AccumulatorFactory getFactory() const {
+ return [](Location loc, Value *lhs, Value *rhs,
+ ConversionPatternRewriter &rewriter) {
+ return rewriter.create<T>(loc, lhs->getType(), lhs, rhs);
+ };
+ }
+
/// Creates an all_reduce across the block.
///
/// First reduce the elements within a warp. The first thread of each warp
/// return %result
///
Value *createBlockReduce(Location loc, Value *operand,
+ AccumulatorFactory &accumFactory,
ConversionPatternRewriter &rewriter) const {
auto type = operand->getType().cast<LLVM::LLVMType>();
Value *activeWidth = getActiveWidth(loc, threadIdx, blockSize, rewriter);
// Reduce elements within each warp to produce the intermediate results.
- Value *warpReduce =
- createWarpReduce(loc, activeWidth, laneId, operand, rewriter);
+ Value *warpReduce = createWarpReduce(loc, activeWidth, laneId, operand,
+ accumFactory, rewriter);
// Write the intermediate results to shared memory, using the first lane of
// each warp.
Value *loadSrc = rewriter.create<LLVM::GEPOp>(
loc, type, sharedMemPtr, ArrayRef<Value *>({zero, threadIdx}));
Value *value = rewriter.create<LLVM::LoadOp>(loc, type, loadSrc);
- Value *result = createWarpReduce(loc, numWarps, laneId, value, rewriter);
+ Value *result = createWarpReduce(loc, numWarps, laneId, value,
+ accumFactory, rewriter);
rewriter.create<LLVM::StoreOp>(loc, result, resultPtr);
});
rewriter.create<NVVM::Barrier0Op>(loc);
/// Creates a reduction across the first activeWidth lanes of a warp.
  /// The first lane returns the result; the return values of all other lanes
  /// are undefined.
Value *createWarpReduce(Location loc, Value *activeWidth, Value *laneId,
- Value *operand,
+ Value *operand, AccumulatorFactory accumFactory,
ConversionPatternRewriter &rewriter) const {
- // TODO(csigg): Generalize to other types of accumulation.
Value *warpSize = rewriter.create<LLVM::ConstantOp>(
loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize));
Value *maskAndClamp = rewriter.create<LLVM::ConstantOp>(
loc, rewriter, isActiveSrcLane,
[&] {
return llvm::SmallVector<Value *, 1>{
- rewriter.create<LLVM::FAddOp>(loc, type, value, shfl)};
+ accumFactory(loc, value, shfl, rewriter)};
},
[&] { return llvm::makeArrayRef(value); });
value = rewriter.getInsertionBlock()->getArgument(0);
loc, int32Type, rewriter.getI32IntegerAttr(i));
Value *shfl = rewriter.create<NVVM::ShflBflyOp>(
loc, type, activeMask, value, offset, maskAndClamp);
- value = rewriter.create<LLVM::FAddOp>(loc, type, value, shfl);
+ value = accumFactory(loc, value, shfl, rewriter);
}
return llvm::SmallVector<Value *, 1>{value};
});
target.addIllegalDialect<gpu::GPUDialect>();
target.addLegalDialect<LLVM::LLVMDialect>();
target.addLegalDialect<NVVM::NVVMDialect>();
+ // TODO(csigg): Remove once we support replacing non-root ops.
+ target.addLegalOp<gpu::Yield>();
if (failed(applyPartialConversion(m, target, patterns, &converter)))
signalPassFailure();
}
return
}
+
+// -----
+
+func @reduce_no_op_no_body(%arg0 : f32) {
+ // expected-error@+1 {{expected either an op attribute or a non-empty body}}
+ %res = "gpu.all_reduce"(%arg0) ({}) : (f32) -> (f32)
+ return
+}
+
+// -----
+
+func @reduce_op_and_body(%arg0 : f32) {
+ // expected-error@+1 {{expected either an op attribute or a non-empty body}}
+ %res = "gpu.all_reduce"(%arg0) ({
+ ^bb(%lhs : f32, %rhs : f32):
+ "gpu.yield"(%lhs) : (f32) -> ()
+ }) {op = "add"} : (f32) -> (f32)
+  return
+}
+
+// -----
+
+func @reduce_invalid_op(%arg0 : f32) {
+ // expected-error@+1 {{op "foo" is invalid}}
+ %res = "gpu.all_reduce"(%arg0) ({}) {op = "foo"} : (f32) -> (f32)
+ return
+}
+
+// -----
+
+func @reduce_incorrect_region_arguments(%arg0 : f32) {
+ // expected-error@+1 {{expected two region arguments}}
+ %res = "gpu.all_reduce"(%arg0) ({
+ ^bb(%lhs : f32):
+ "gpu.yield"(%lhs) : (f32) -> ()
+ }) : (f32) -> (f32)
+  return
+}
+
+// -----
+
+func @reduce_incorrect_region_arguments(%arg0 : f32) {
+ // expected-error@+1 {{incorrect region argument type}}
+ %res = "gpu.all_reduce"(%arg0) ({
+ ^bb(%lhs : f32, %rhs : i32):
+ "gpu.yield"(%lhs) : (f32) -> ()
+ }) : (f32) -> (f32)
+  return
+}
+
+// -----
+
+func @reduce_incorrect_yield(%arg0 : f32) {
+ // expected-error@+1 {{expected one gpu.yield operand}}
+ %res = "gpu.all_reduce"(%arg0) ({
+ ^bb(%lhs : f32, %rhs : f32):
+ "gpu.yield"(%lhs, %rhs) : (f32, f32) -> ()
+ }) : (f32) -> (f32)
+  return
+}
+
+// -----
+
+func @reduce_incorrect_yield(%arg0 : f32) {
+ // expected-error@+1 {{incorrect gpu.yield type}}
+ %res = "gpu.all_reduce"(%arg0) ({
+ ^bb(%lhs : f32, %rhs : f32):
+ %one = constant 1 : i32
+ "gpu.yield"(%one) : (i32) -> ()
+ }) : (f32) -> (f32)
+  return
+}
+
+// -----
+
+func @reduce_incorrect_yield(%arg0 : f32) {
+ // expected-error@+1 {{expected gpu.yield op in region}}
+ %res = "gpu.all_reduce"(%arg0) ({
+ ^bb(%lhs : f32, %rhs : f32):
+ return
+ }) : (f32) -> (f32)
+  return
+}
+