From 91f62f0e352a4f5c755f1cbec6f27e40a60ff109 Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Tue, 1 Nov 2022 06:25:47 +0000 Subject: [PATCH] [mlir][vector] Fix distribution of scf.for with value coming from above When a value used in the forOp is defined outside the region but within the parent warpOp we need to return and distribute the value to pass it to new operations created within the loop. Also simplify the lambda interface. Differential Revision: https://reviews.llvm.org/D137146 --- .../Dialect/Vector/Transforms/VectorDistribution.h | 9 +- .../Dialect/Vector/Transforms/VectorDistribute.cpp | 134 ++++++++++++++++----- .../Dialect/Vector/vector-warp-distribute.mlir | 34 ++++++ .../lib/Dialect/Vector/TestVectorTransforms.cpp | 24 ++-- 4 files changed, 159 insertions(+), 42 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h index 204b322..49e3427 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h @@ -40,7 +40,7 @@ void populateWarpExecuteOnLane0OpToScfForPattern( const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit = 1); -using DistributionMapFn = std::function; +using DistributionMapFn = std::function; /// Distribute transfer_write ops based on the affine map returned by /// `distributionMapFn`. @@ -67,9 +67,12 @@ void populateDistributeTransferWriteOpPatterns( /// region. void moveScalarUniformCode(WarpExecuteOnLane0Op op); -/// Collect patterns to propagate warp distribution. +/// Collect patterns to propagate warp distribution. `distributionMapFn` is used +/// to decide how a value should be distributed when this cannot be inferred +/// from its uses. 
void populatePropagateWarpVectorDistributionPatterns( - RewritePatternSet &pattern, PatternBenefit benefit = 1); + RewritePatternSet &pattern, const DistributionMapFn &distributionMapFn, + PatternBenefit benefit = 1); /// Lambda signature to compute a reduction of a distributed value for the given /// reduction kind and size. diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index f730044..6dfdf766 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h" #include "mlir/IR/AffineExpr.h" +#include "mlir/Transforms/RegionUtils.h" #include "mlir/Transforms/SideEffectUtils.h" #include "llvm/ADT/SetVector.h" #include @@ -421,6 +422,31 @@ static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter, return newWriteOp; } +/// Return the distributed vector type based on the original type and the +/// distribution map. The map is expected to have a dimension equal to the +/// original type rank and should be a projection where the results are the +/// distributed dimensions. The number of results should be equal to the number +/// of warp sizes which is currently limited to 1. 
+/// Example: For a vector<16x32x64> distributed with a map(d0, d1, d2) -> (d1) +/// and a warp size of 16 would distribute the second dimension (associated to +/// d1) and return vector<16x2x64> +static VectorType getDistributedType(VectorType originalType, AffineMap map, + int64_t warpSize) { + if (map.getNumResults() != 1) + return VectorType(); + SmallVector targetShape(originalType.getShape().begin(), + originalType.getShape().end()); + for (unsigned i = 0, e = map.getNumResults(); i < e; i++) { + unsigned position = map.getDimPosition(i); + if (targetShape[position] % warpSize != 0) + return VectorType(); + targetShape[position] = targetShape[position] / warpSize; + } + VectorType targetType = + VectorType::get(targetShape, originalType.getElementType()); + return targetType; +} + /// Distribute transfer_write ops based on the affine map returned by /// `distributionMapFn`. /// Example: @@ -456,29 +482,19 @@ struct WarpOpTransferWrite : public OpRewritePattern { if (writtenVectorType.getRank() == 0) return failure(); - // 2. Compute the distribution map. - AffineMap map = distributionMapFn(writeOp); - if (map.getNumResults() != 1) - return writeOp->emitError("multi-dim distribution not implemented yet"); - - // 3. Compute the targetType using the distribution map. - SmallVector targetShape(writtenVectorType.getShape().begin(), - writtenVectorType.getShape().end()); - for (unsigned i = 0, e = map.getNumResults(); i < e; i++) { - unsigned position = map.getDimPosition(i); - if (targetShape[position] % warpOp.getWarpSize() != 0) - return failure(); - targetShape[position] = targetShape[position] / warpOp.getWarpSize(); - } + // 2. Compute the distributed type. + AffineMap map = distributionMapFn(writeOp.getVector()); VectorType targetType = - VectorType::get(targetShape, writtenVectorType.getElementType()); + getDistributedType(writtenVectorType, map, warpOp.getWarpSize()); + if (!targetType) + return failure(); - // 4. 
clone the write into a new WarpExecuteOnLane0Op to separate it from + // 3. clone the write into a new WarpExecuteOnLane0Op to separate it from // the rest. vector::TransferWriteOp newWriteOp = cloneWriteOp(rewriter, warpOp, writeOp, targetType); - // 5. Reindex the write using the distribution map. + // 4. Reindex the write using the distribution map. auto newWarpOp = newWriteOp.getVector().getDefiningOp(); rewriter.setInsertionPoint(newWriteOp); @@ -494,7 +510,8 @@ struct WarpOpTransferWrite : public OpRewritePattern { continue; unsigned indexPos = indexExpr.getPosition(); unsigned vectorPos = std::get<1>(it).cast().getPosition(); - auto scale = rewriter.getAffineConstantExpr(targetShape[vectorPos]); + auto scale = + rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos)); indices[indexPos] = makeComposedAffineApply(rewriter, loc, d0 + scale * d1, {indices[indexPos], newWarpOp.getLaneid()}); @@ -956,6 +973,10 @@ struct WarpOpExtractElement : public OpRewritePattern { /// } /// ``` struct WarpOpScfForOp : public OpRewritePattern { + + WarpOpScfForOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1) + : OpRewritePattern(ctx, b), + distributionMapFn(std::move(fn)) {} using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { @@ -966,6 +987,35 @@ struct WarpOpScfForOp : public OpRewritePattern { auto forOp = dyn_cast_or_null(lastNode); if (!forOp) return failure(); + // Collect Values that come from the warp op but are outside the forOp. + // Those Value needs to be returned by the original warpOp and passed to the + // new op. 
+ llvm::SmallSetVector<Value, 32> escapingValues; + SmallVector<Type> inputTypes; + SmallVector<Type> distTypes; + mlir::visitUsedValuesDefinedAbove( + forOp.getBodyRegion(), [&](OpOperand *operand) { + Operation *parent = operand->get().getParentRegion()->getParentOp(); + if (warpOp->isAncestor(parent)) { + if (!escapingValues.insert(operand->get())) + return; + Type distType = operand->get().getType(); + if (auto vecType = distType.cast<VectorType>()) { + AffineMap map = distributionMapFn(operand->get()); + distType = getDistributedType(vecType, map, warpOp.getWarpSize()); + } + inputTypes.push_back(operand->get().getType()); + distTypes.push_back(distType); + } + }); + + SmallVector<size_t> newRetIndices; + WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, escapingValues.getArrayRef(), distTypes, + newRetIndices); + yield = cast<vector::YieldOp>( + newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator()); + SmallVector<Value> newOperands; SmallVector<unsigned> resultIdx; // Collect all the outputs coming from the forOp. @@ -973,28 +1023,42 @@ if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) continue; auto forResult = yieldOperand.get().cast<OpResult>(); - newOperands.push_back(warpOp.getResult(yieldOperand.getOperandNumber())); + newOperands.push_back( + newWarpOp.getResult(yieldOperand.getOperandNumber())); yieldOperand.set(forOp.getIterOperands()[forResult.getResultNumber()]); resultIdx.push_back(yieldOperand.getOperandNumber()); } + OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPointAfter(warpOp); + rewriter.setInsertionPointAfter(newWarpOp); + + // Create a new for op outside the region with a WarpExecuteOnLane0Op region // inside. 
auto newForOp = rewriter.create( forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(), newOperands); rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin()); + + SmallVector warpInput(newForOp.getRegionIterArgs().begin(), + newForOp.getRegionIterArgs().end()); + SmallVector warpInputType(forOp.getResultTypes().begin(), + forOp.getResultTypes().end()); + llvm::SmallDenseMap argIndexMapping; + for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) { + warpInput.push_back(newWarpOp.getResult(retIdx)); + argIndexMapping[escapingValues[i]] = warpInputType.size(); + warpInputType.push_back(inputTypes[i]); + } auto innerWarp = rewriter.create( - warpOp.getLoc(), newForOp.getResultTypes(), warpOp.getLaneid(), - warpOp.getWarpSize(), newForOp.getRegionIterArgs(), - forOp.getResultTypes()); + newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(), + newWarpOp.getWarpSize(), warpInput, warpInputType); SmallVector argMapping; argMapping.push_back(newForOp.getInductionVar()); for (Value args : innerWarp.getBody()->getArguments()) { argMapping.push_back(args); } + argMapping.resize(forOp.getBody()->getNumArguments()); SmallVector yieldOperands; for (Value operand : forOp.getBody()->getTerminator()->getOperands()) yieldOperands.push_back(operand); @@ -1008,12 +1072,23 @@ struct WarpOpScfForOp : public OpRewritePattern { rewriter.eraseOp(forOp); // Replace the warpOp result coming from the original ForOp. 
for (const auto &res : llvm::enumerate(resultIdx)) { - warpOp.getResult(res.value()) + newWarpOp.getResult(res.value()) .replaceAllUsesWith(newForOp.getResult(res.index())); - newForOp->setOperand(res.index() + 3, warpOp.getResult(res.value())); + newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value())); } + newForOp.walk([&](Operation *op) { + for (OpOperand &operand : op->getOpOperands()) { + auto it = argIndexMapping.find(operand.get()); + if (it == argIndexMapping.end()) + continue; + operand.set(innerWarp.getBodyRegion().getArgument(it->second)); + } + }); return success(); } + +private: + DistributionMapFn distributionMapFn; }; /// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op. @@ -1119,11 +1194,14 @@ void mlir::vector::populateDistributeTransferWriteOpPatterns( } void mlir::vector::populatePropagateWarpVectorDistributionPatterns( - RewritePatternSet &patterns, PatternBenefit benefit) { + RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn, + PatternBenefit benefit) { patterns.add( - patterns.getContext(), benefit); + WarpOpForwardOperand, WarpOpConstant>(patterns.getContext(), + benefit); + patterns.add(patterns.getContext(), distributionMapFn, + benefit); } void mlir::vector::populateDistributeReduction( diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir index 49c36fe..daebccd 100644 --- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir +++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir @@ -349,6 +349,40 @@ func.func @warp_scf_for(%arg0: index) { // ----- +// CHECK-PROP-LABEL: func @warp_scf_for_use_from_above( +// CHECK-PROP: %[[INI:.*]]:2 = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>, vector<4xf32>) { +// CHECK-PROP: %[[INI1:.*]] = "some_def"() : () -> vector<128xf32> +// CHECK-PROP: %[[USE:.*]] = "some_def_above"() : () -> vector<128xf32> +// CHECK-PROP: vector.yield %[[INI1]], %[[USE]] : 
vector<128xf32>, vector<128xf32> +// CHECK-PROP: } +// CHECK-PROP: %[[F:.*]] = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[FARG:.*]] = %[[INI]]#0) -> (vector<4xf32>) { +// CHECK-PROP: %[[W:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] args(%[[FARG]], %[[INI]]#1 : vector<4xf32>, vector<4xf32>) -> (vector<4xf32>) { +// CHECK-PROP: ^bb0(%[[ARG0:.*]]: vector<128xf32>, %[[ARG1:.*]]: vector<128xf32>): +// CHECK-PROP: %[[ACC:.*]] = "some_def"(%[[ARG0]], %[[ARG1]]) : (vector<128xf32>, vector<128xf32>) -> vector<128xf32> +// CHECK-PROP: vector.yield %[[ACC]] : vector<128xf32> +// CHECK-PROP: } +// CHECK-PROP: scf.yield %[[W]] : vector<4xf32> +// CHECK-PROP: } +// CHECK-PROP: "some_use"(%[[F]]) : (vector<4xf32>) -> () +func.func @warp_scf_for_use_from_above(%arg0: index) { + %c128 = arith.constant 128 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %0 = vector.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>) { + %ini = "some_def"() : () -> (vector<128xf32>) + %use_from_above = "some_def_above"() : () -> (vector<128xf32>) + %3 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini) -> (vector<128xf32>) { + %acc = "some_def"(%arg4, %use_from_above) : (vector<128xf32>, vector<128xf32>) -> (vector<128xf32>) + scf.yield %acc : vector<128xf32> + } + vector.yield %3 : vector<128xf32> + } + "some_use"(%0) : (vector<4xf32>) -> () + return +} + +// ----- + // CHECK-PROP-LABEL: func @warp_scf_for_swap( // CHECK-PROP: %[[INI:.*]]:2 = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>, vector<4xf32>) { // CHECK-PROP: %[[INI1:.*]] = "some_def"() : () -> vector<128xf32> diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 5547a96..b66b2fe 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -746,24 +746,26 @@ struct TestVectorDistribution } }); MLIRContext *ctx = 
&getContext(); + auto distributionFn = [](Value val) { + // Create a map (d0, d1) -> (d1) to distribute along the inner + // dimension. Once we support n-d distribution we can add more + // complex cases. + VectorType vecType = val.getType().dyn_cast<VectorType>(); + int64_t vecRank = vecType ? vecType.getRank() : 0; + OpBuilder builder(val.getContext()); + if (vecRank == 0) + return AffineMap::get(val.getContext()); + return AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1)); + }; if (distributeTransferWriteOps) { - auto distributionFn = [](vector::TransferWriteOp writeOp) { - // Create a map (d0, d1) -> (d1) to distribute along the inner - // dimension. Once we support n-d distribution we can add more - // complex cases. - int64_t vecRank = writeOp.getVectorType().getRank(); - OpBuilder builder(writeOp.getContext()); - auto map = - AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1)); - return map; - }; RewritePatternSet patterns(ctx); populateDistributeTransferWriteOpPatterns(patterns, distributionFn); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } if (propagateDistribution) { RewritePatternSet patterns(ctx); - vector::populatePropagateWarpVectorDistributionPatterns(patterns); + vector::populatePropagateWarpVectorDistributionPatterns(patterns, + distributionFn); vector::populateDistributeReduction(patterns, warpReduction); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } -- 2.7.4