//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
+#include "mlir/Transforms/SideEffectUtils.h"
using namespace mlir;
using namespace mlir::vector;
if (resultType == val.getType()) {
// Result type and yielded value type are the same. This is a broadcast.
// E.g.:
- // %r = vector_ext.warp_execute_on_lane_0(...) -> (f32) {
- // vector_ext.yield %cst : f32
+ // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
+ // vector.yield %cst : f32
// }
// Both types are f32. The constant %cst is broadcasted to all lanes.
// This is described in more detail in the documentation of the op.
return success();
}
+/// Builds a WarpExecuteOnLane0Op immediately before `warpOp` that reuses its
+/// lane id, warp size, operands and block argument types but returns
+/// `newReturnTypes`. The blocks of `warpOp` are moved into the new op and the
+/// terminator is rewired to yield `newYieldedValues`. This helper neither
+/// erases nor replaces `warpOp`. Returns the newly created op.
+static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
+    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes) {
+  // The guard restores the caller's insertion point when this function exits.
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(warpOp);
+  Location loc = warpOp.getLoc();
+  auto replacement = rewriter.create<WarpExecuteOnLane0Op>(
+      loc, newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
+      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
+
+  // Move the blocks of the old region to the front of the new region.
+  Region &oldRegion = warpOp.getBodyRegion();
+  Region &newRegion = replacement.getBodyRegion();
+  rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin());
+
+  // Re-point the terminator of the first block at the requested values.
+  Operation *terminator = newRegion.front().getTerminator();
+  auto yieldOp = cast<vector::YieldOp>(terminator);
+  rewriter.updateRootInPlace(yieldOp, [&]() {
+    yieldOp.operandsMutable().assign(newYieldedValues);
+  });
+  return replacement;
+}
+
+/// Rebuilds `warpOp` so that it additionally yields `newYieldedValues`, with
+/// result types `newReturnTypes`, after its existing results. `warpOp` is
+/// replaced by the leading results of the new op, which is returned.
+static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
+    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes) {
+  // Existing result types first, then the appended ones.
+  SmallVector<Type> allTypes;
+  llvm::append_range(allTypes, warpOp.getResultTypes());
+  llvm::append_range(allTypes, newReturnTypes);
+
+  // Same ordering for the values handed to the terminator.
+  Operation *terminator = warpOp.getBodyRegion().front().getTerminator();
+  auto oldYield = cast<vector::YieldOp>(terminator);
+  SmallVector<Value> allYielded;
+  llvm::append_range(allYielded, oldYield.getOperands());
+  llvm::append_range(allYielded, newYieldedValues);
+
+  WarpExecuteOnLane0Op extended = moveRegionToNewWarpOpAndReplaceReturns(
+      rewriter, warpOp, allYielded, allTypes);
+  // Only the first getNumResults() results correspond to the old op.
+  ResultRange forwarded =
+      extended.getResults().take_front(warpOp.getNumResults());
+  rewriter.replaceOp(warpOp, forwarded);
+  return extended;
+}
+
+/// Returns true when `op` may be moved out of the region: every operand must
+/// satisfy `definedOutside`, the op must be side-effect free, and it must not
+/// hold any region.
+static bool canBeHoisted(Operation *op,
+                         function_ref<bool(Value)> definedOutside) {
+  if (!llvm::all_of(op->getOperands(), definedOutside))
+    return false;
+  if (!isSideEffectFree(op))
+    return false;
+  return op->getNumRegions() == 0;
+}
+
namespace {
struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
const WarpExecuteOnLane0LoweringOptions &options;
};
+/// Distribute transfer_write ops based on the affine map returned by
+/// `distributionMapFn`. Example:
+/// ```
+/// vector.warp_execute_on_lane_0(%id) {
+/// ...
+/// vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
+/// vector.yield
+/// }
+/// ```
+/// To
+/// ```
+/// %r = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
+/// ...
+/// vector.yield %v : vector<32xf32>
+/// }
+/// vector.transfer_write %r, %A[%id] : vector<1xf32>, memref<128xf32>
+/// ```
+struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
+ WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
+ PatternBenefit b = 1)
+ : OpRewritePattern<vector::TransferWriteOp>(ctx, b),
+ distributionMapFn(fn) {}
+
+ /// Sink `writeOp` below `warpOp`, dividing the distributed dim by warpSize.
+ /// Only 1D maps are supported; fails if the dim is not a warpSize multiple.
+ LogicalResult tryDistributeOp(RewriterBase &rewriter,
+ vector::TransferWriteOp writeOp,
+ WarpExecuteOnLane0Op warpOp) const {
+ AffineMap map = distributionMapFn(writeOp);
+ SmallVector<int64_t> targetShape(writeOp.getVectorType().getShape().begin(),
+ writeOp.getVectorType().getShape().end());
+ assert(map.getNumResults() == 1 &&
+ "multi-dim distribution not implemented yet");
+ for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
+ unsigned position = map.getDimPosition(i);
+ if (targetShape[position] % warpOp.getWarpSize() != 0)
+ return failure(); // Bail out before any IR has been modified.
+ targetShape[position] = targetShape[position] / warpOp.getWarpSize();
+ }
+ VectorType targetType =
+ VectorType::get(targetShape, writeOp.getVectorType().getElementType());
+
+ SmallVector<Value> yieldValues = {writeOp.getVector()};
+ SmallVector<Type> retTypes = {targetType};
+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, yieldValues, retTypes); // Appended result has targetType.
+ rewriter.setInsertionPointAfter(newWarpOp);
+
+ // Move op outside of region: Insert clone at the insertion point and delete
+ // the old op.
+ auto newWriteOp =
+ cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
+ rewriter.eraseOp(writeOp);
+
+ rewriter.setInsertionPoint(newWriteOp); // Index math goes before the write.
+ AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
+ Location loc = newWriteOp.getLoc();
+ SmallVector<Value> indices(newWriteOp.getIndices().begin(),
+ newWriteOp.getIndices().end());
+ for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
+ AffineExpr d0, d1;
+ bindDims(newWarpOp.getContext(), d0, d1);
+ auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
+ if (!indexExpr)
+ continue; // Results that are not a plain dim leave the indices as-is.
+ unsigned indexPos = indexExpr.getPosition();
+ unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
+ auto scale =
+ getAffineConstantExpr(targetShape[vectorPos], newWarpOp.getContext());
+ indices[indexPos] =
+ makeComposedAffineApply(rewriter, loc, d0 + scale * d1,
+ {indices[indexPos], newWarpOp.getLaneid()}); // index + size * laneid
+ }
+ newWriteOp.getVectorMutable().assign(newWarpOp.getResults().back());
+ newWriteOp.getIndicesMutable().assign(indices);
+
+ return success();
+ }
+
+ /// Move a vector<1x> `writeOp` into a new warp op created after `warpOp`.
+ LogicalResult tryExtractOp(RewriterBase &rewriter,
+ vector::TransferWriteOp writeOp,
+ WarpExecuteOnLane0Op warpOp) const {
+ Location loc = writeOp.getLoc();
+ VectorType vecType = writeOp.getVectorType();
+
+ // Only vector<1x> is supported at the moment.
+ if (vecType.getShape().size() != 1 || vecType.getShape()[0] != 1)
+ return failure();
+
+ // Skip warp ops holding only writes and the yield (e.g. the one built below).
+ if (llvm::all_of(warpOp.getOps(), [](Operation &op) {
+ return isa<vector::TransferWriteOp, vector::YieldOp>(&op);
+ }))
+ return failure();
+
+ SmallVector<Value> yieldValues = {writeOp.getVector()};
+ SmallVector<Type> retTypes = {vecType};
+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, yieldValues, retTypes);
+ rewriter.setInsertionPointAfter(newWarpOp);
+
+ // Create a second warp op that contains only writeOp.
+ auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
+ loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
+ Block &body = secondWarpOp.getBodyRegion().front();
+ rewriter.setInsertionPointToStart(&body);
+ auto newWriteOp =
+ cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
+ newWriteOp.getVectorMutable().assign(
+ newWarpOp.getResult(newWarpOp.getNumResults() - 1)); // The appended result.
+ rewriter.eraseOp(writeOp);
+ rewriter.create<vector::YieldOp>(newWarpOp.getLoc());
+ return success();
+ }
+
+ LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
+ PatternRewriter &rewriter) const override {
+ // Ops with mask not supported yet.
+ if (writeOp.getMask())
+ return failure();
+
+ auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(writeOp->getParentOp());
+ if (!warpOp)
+ return failure();
+
+ // No op after writeOp may have side effects: the write is sunk past them.
+ Operation *nextOp = writeOp.getOperation();
+ while ((nextOp = nextOp->getNextNode()))
+ if (!isSideEffectFree(nextOp))
+ return failure();
+
+ if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
+ return writeOp.getVector() == value ||
+ warpOp.isDefinedOutsideOfRegion(value);
+ }))
+ return failure(); // Only the written vector may be defined in the region.
+
+ if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
+ return success();
+
+ if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
+ return success();
+
+ return failure();
+ }
+
+private:
+ DistributionMapFn distributionMapFn;
+};
+
} // namespace
void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
const WarpExecuteOnLane0LoweringOptions &options) {
patterns.add<WarpOpToScfForPattern>(patterns.getContext(), options);
}
+
+/// Adds the WarpOpTransferWrite pattern to `patterns`, handing it
+/// `distributionMapFn` as the callback that supplies the distribution map.
+void mlir::vector::populateDistributeTransferWriteOpPatterns(
+    RewritePatternSet &patterns, DistributionMapFn distributionMapFn) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<WarpOpTransferWrite>(context, distributionMapFn);
+}
+
+/// Moves in front of `warpOp` every top-level op of its body that produces no
+/// vector result and satisfies canBeHoisted. Operands count as available when
+/// they are defined outside the region or by an op already selected.
+void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
+  // Selected ops; iteration follows insertion order, i.e. body order.
+  llvm::SmallSetVector<Operation *, 8> hoistable;
+
+  // True if `v` is, or will be after hoisting, defined outside of the region.
+  auto availableOutside = [&](Value v) -> bool {
+    if (Operation *producer = v.getDefiningOp())
+      if (hoistable.count(producer))
+        return true;
+    return warpOp.isDefinedOutsideOfRegion(v);
+  };
+
+  // Iterate over the body block directly rather than walking, so that ops in
+  // nested regions are never considered.
+  Block *block = warpOp.getBody();
+  for (Operation &candidate : block->without_terminator()) {
+    bool yieldsVector = llvm::any_of(candidate.getResults(), [](Value res) {
+      return res.getType().isa<VectorType>();
+    });
+    if (yieldsVector)
+      continue;
+    if (canBeHoisted(&candidate, availableOutside))
+      hoistable.insert(&candidate);
+  }
+
+  // Relocate the selected ops, keeping their relative order.
+  for (Operation *toMove : hoistable)
+    toMove->moveBefore(warpOp);
+}
// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute=rewrite-warp-ops-to-scf-if | FileCheck %s --check-prefix=CHECK-SCF-IF
+// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute="hoist-uniform" | FileCheck --check-prefixes=CHECK-HOIST %s
+// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute="hoist-uniform distribute-transfer-write" | FileCheck --check-prefixes=CHECK-D %s
// CHECK-SCF-IF-DAG: memref.global "private" @__shared_32xf32 : memref<32xf32, 3>
// CHECK-SCF-IF-DAG: memref.global "private" @__shared_64xf32 : memref<64xf32, 3>
"some_use"(%r#1) : (vector<2xf32>) -> ()
return
}
+
+// -----
+
+// Both transfer_write ops are expected to be sunk out of the warp op once the
+// subviews and constants have been hoisted; the checks below avoid hard-coded
+// SSA value names so that renumbering does not break them.
+// CHECK-D-DAG: #[[MAP1:.*]] = affine_map<()[s0] -> (s0 * 2 + 32)>
+
+// CHECK-ALL-LABEL: func @warp(
+// CHECK-HOIST: memref.subview
+// CHECK-HOIST: memref.subview
+// CHECK-HOIST: memref.subview
+// CHECK-HOIST: vector.warp_execute_on_lane_0
+
+// CHECK-D: %[[R:.*]]:2 = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2xf32>, vector<1xf32>) {
+// CHECK-D: arith.addf {{.*}} : vector<32xf32>
+// CHECK-D: arith.addf {{.*}} : vector<64xf32>
+// CHECK-D: vector.yield %{{.*}}, %{{.*}} : vector<64xf32>, vector<32xf32>
+// CHECK-D-DAG: vector.transfer_write %[[R]]#1, %{{.*}}[%{{.*}}] {in_bounds = [true]} : vector<1xf32>, memref<128xf32
+// CHECK-D-DAG: %[[ID1:.*]] = affine.apply #[[MAP1]]()[%{{.*}}]
+// CHECK-D-DAG: vector.transfer_write %[[R]]#0, %{{.*}}[%[[ID1]]] {in_bounds = [true]} : vector<2xf32>, memref<128xf32
+
+// CHECK-ALL-NOT: vector.warp_execute_on_lane_0
+// CHECK-ALL: vector.transfer_read {{.*}} vector<1xf32>
+// CHECK-ALL: vector.transfer_read {{.*}} vector<1xf32>
+// CHECK-ALL: vector.transfer_read {{.*}} vector<2xf32>
+// CHECK-ALL: vector.transfer_read {{.*}} vector<2xf32>
+// CHECK-ALL: arith.addf {{.*}} : vector<1xf32>
+// CHECK-ALL: arith.addf {{.*}} : vector<2xf32>
+// CHECK-ALL: vector.transfer_write {{.*}} : vector<1xf32>
+// CHECK-ALL: vector.transfer_write {{.*}} : vector<2xf32>
+
+#map0 = affine_map<(d0)[s0] -> (d0 + s0)>
+func.func @warp(%laneid: index, %arg1: memref<1024xf32>, %arg2: memref<1024xf32>,
+ %arg3: memref<1024xf32>, %gid : index) {
+ vector.warp_execute_on_lane_0(%laneid)[32] {
+ %sa = memref.subview %arg1[%gid] [128] [1] : memref<1024xf32> to memref<128xf32, #map0>
+ %sb = memref.subview %arg2[%gid] [128] [1] : memref<1024xf32> to memref<128xf32, #map0>
+ %sc = memref.subview %arg3[%gid] [128] [1] : memref<1024xf32> to memref<128xf32, #map0>
+ %c0 = arith.constant 0 : index
+ %c32 = arith.constant 32 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %2 = vector.transfer_read %sa[%c0], %cst : memref<128xf32, #map0>, vector<32xf32>
+ %3 = vector.transfer_read %sa[%c32], %cst : memref<128xf32, #map0>, vector<32xf32>
+ %4 = vector.transfer_read %sb[%c0], %cst : memref<128xf32, #map0>, vector<64xf32>
+ %5 = vector.transfer_read %sb[%c32], %cst : memref<128xf32, #map0>, vector<64xf32>
+ %6 = arith.addf %2, %3 : vector<32xf32>
+ %7 = arith.addf %4, %5 : vector<64xf32>
+ vector.transfer_write %6, %sc[%c0] : vector<32xf32>, memref<128xf32, #map0>
+ vector.transfer_write %7, %sc[%c32] : vector<64xf32>, memref<128xf32, #map0>
+ }
+ return
+}
+
+// -----
+
+// CHECK-D-LABEL: func @warp_extract(
+// CHECK-D: %[[WARPOP:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>)
+// CHECK-D: "test.dummy_op"
+// CHECK-D: vector.yield %{{.*}} : vector<1xf32>
+// CHECK-D: }
+// CHECK-D: vector.warp_execute_on_lane_0(%{{.*}})[32] {
+// CHECK-D: vector.transfer_write %[[WARPOP]], %{{.*}}[%{{.*}}] {{.*}} : vector<1xf32>
+// CHECK-D: }
+// A vector<1xf32> write cannot be split over 32 lanes, so it is expected to move into a second warp op fed by a new result of the first.
+#map2 = affine_map<(d0)[s0] -> (d0 + s0)>
+
+func.func @warp_extract(%laneid: index, %arg1: memref<1024xf32>, %gid : index) {
+ vector.warp_execute_on_lane_0(%laneid)[32] {
+ %sa = memref.subview %arg1[%gid] [128] [1] : memref<1024xf32> to memref<128xf32, #map2>
+ %c0 = arith.constant 0 : index
+ %v = "test.dummy_op"() : () -> (vector<1xf32>)
+ vector.transfer_write %v, %sa[%c0] : vector<1xf32>, memref<128xf32, #map2>
+ }
+ return
+}
\ No newline at end of file
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistribution)
void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<scf::SCFDialect, memref::MemRefDialect, gpu::GPUDialect>();
+ registry.insert<scf::SCFDialect, memref::MemRefDialect, gpu::GPUDialect,
+ AffineDialect>();
}
StringRef getArgument() const final { return "test-vector-warp-distribute"; }
llvm::cl::desc("Lower vector.warp_execute_on_lane0 to scf.if op"),
llvm::cl::init(false)};
+ Option<bool> distributeTransferWriteOps{
+ *this, "distribute-transfer-write",
+ llvm::cl::desc("Test distribution of transfer write"),
+ llvm::cl::init(false)}; // Off by default; gates the transfer_write distribution patterns in runOnOperation.
+
+ Option<bool> hoistUniform{*this, "hoist-uniform",
+ llvm::cl::desc("Test hoist uniform"),
+ llvm::cl::init(false)}; // Off by default; gates the moveScalarUniformCode call in runOnOperation.
+
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
+
+ getOperation().walk([&](Operation *op) {
+ if (auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(op)) {
+ if (hoistUniform) {
+ moveScalarUniformCode(warpOp);
+ }
+ WalkResult::interrupt();
+ }
+ });
+ MLIRContext *ctx = &getContext();
+ if (distributeTransferWriteOps) {
+ auto distributionFn = [](vector::TransferWriteOp writeOp) {
+ // Create a map (d0, d1) -> (d1) to distribute along the inner
+ // dimension. Once we support n-d distribution we can add more
+ // complex cases.
+ int64_t vecRank = writeOp.getVectorType().getRank();
+ OpBuilder builder(writeOp.getContext());
+ auto map =
+ AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1));
+ return map;
+ };
+ RewritePatternSet patterns(ctx);
+ populateDistributeTransferWriteOpPatterns(patterns, distributionFn);
+ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ }
+
WarpExecuteOnLane0LoweringOptions options;
options.warpAllocationFn = allocateGlobalSharedMemory;
options.warpSyncronizationFn = [](Location loc, OpBuilder &builder,