[mlir][vector] Fix distribution of scf.for with value coming from above
authorThomas Raoux <thomasraoux@google.com>
Tue, 1 Nov 2022 06:25:47 +0000 (06:25 +0000)
committerThomas Raoux <thomasraoux@google.com>
Wed, 2 Nov 2022 04:15:18 +0000 (04:15 +0000)
When a value used in the forOp is defined outside the region but within
the parent warpOp we need to return and distribute the value to pass it
to new operations created within the loop.
Also simplify the lambda interface.

Differential Revision: https://reviews.llvm.org/D137146

mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/test/Dialect/Vector/vector-warp-distribute.mlir
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

index 204b322..49e3427 100644 (file)
@@ -40,7 +40,7 @@ void populateWarpExecuteOnLane0OpToScfForPattern(
     const WarpExecuteOnLane0LoweringOptions &options,
     PatternBenefit benefit = 1);
 
-using DistributionMapFn = std::function<AffineMap(vector::TransferWriteOp)>;
+using DistributionMapFn = std::function<AffineMap(Value)>;
 
 /// Distribute transfer_write ops based on the affine map returned by
 /// `distributionMapFn`.
@@ -67,9 +67,12 @@ void populateDistributeTransferWriteOpPatterns(
 /// region.
 void moveScalarUniformCode(WarpExecuteOnLane0Op op);
 
-/// Collect patterns to propagate warp distribution.
+/// Collect patterns to propagate warp distribution. `distributionMapFn` is used
+/// to decide how a value should be distributed when this cannot be inferred
+/// from its uses.
 void populatePropagateWarpVectorDistributionPatterns(
-    RewritePatternSet &pattern, PatternBenefit benefit = 1);
+    RewritePatternSet &pattern, const DistributionMapFn &distributionMapFn,
+    PatternBenefit benefit = 1);
 
 /// Lambda signature to compute a reduction of a distributed value for the given
 /// reduction kind and size.
index f730044..6dfdf76 100644 (file)
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/Transforms/RegionUtils.h"
 #include "mlir/Transforms/SideEffectUtils.h"
 #include "llvm/ADT/SetVector.h"
 #include <utility>
@@ -421,6 +422,31 @@ static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
   return newWriteOp;
 }
 
+/// Return the distributed vector type based on the original type and the
+/// distribution map. The map is expected to have a dimension equal to the
+/// original type rank and should be a projection where the results are the
+/// distributed dimensions. The number of results should be equal to the number
+/// of warp sizes which is currently limited to 1.
+/// Example: For a vector<16x32x64> distributed with a map(d0, d1, d2) -> (d1)
+/// and a warp size of 16 would distribute the second dimension (associated to
+/// d1) and return vector<16x2x64>
+static VectorType getDistributedType(VectorType originalType, AffineMap map,
+                                     int64_t warpSize) {
+  if (map.getNumResults() != 1)
+    return VectorType();
+  SmallVector<int64_t> targetShape(originalType.getShape().begin(),
+                                   originalType.getShape().end());
+  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
+    unsigned position = map.getDimPosition(i);
+    if (targetShape[position] % warpSize != 0)
+      return VectorType();
+    targetShape[position] = targetShape[position] / warpSize;
+  }
+  VectorType targetType =
+      VectorType::get(targetShape, originalType.getElementType());
+  return targetType;
+}
+
 /// Distribute transfer_write ops based on the affine map returned by
 /// `distributionMapFn`.
 /// Example:
@@ -456,29 +482,19 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
     if (writtenVectorType.getRank() == 0)
       return failure();
 
-    // 2. Compute the distribution map.
-    AffineMap map = distributionMapFn(writeOp);
-    if (map.getNumResults() != 1)
-      return writeOp->emitError("multi-dim distribution not implemented yet");
-
-    // 3. Compute the targetType using the distribution map.
-    SmallVector<int64_t> targetShape(writtenVectorType.getShape().begin(),
-                                     writtenVectorType.getShape().end());
-    for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
-      unsigned position = map.getDimPosition(i);
-      if (targetShape[position] % warpOp.getWarpSize() != 0)
-        return failure();
-      targetShape[position] = targetShape[position] / warpOp.getWarpSize();
-    }
+    // 2. Compute the distributed type.
+    AffineMap map = distributionMapFn(writeOp.getVector());
     VectorType targetType =
-        VectorType::get(targetShape, writtenVectorType.getElementType());
+        getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
+    if (!targetType)
+      return failure();
 
-    // 4. clone the write into a new WarpExecuteOnLane0Op to separate it from
+    // 3. clone the write into a new WarpExecuteOnLane0Op to separate it from
     // the rest.
     vector::TransferWriteOp newWriteOp =
         cloneWriteOp(rewriter, warpOp, writeOp, targetType);
 
-    // 5. Reindex the write using the distribution map.
+    // 4. Reindex the write using the distribution map.
     auto newWarpOp =
         newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
     rewriter.setInsertionPoint(newWriteOp);
@@ -494,7 +510,8 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
         continue;
       unsigned indexPos = indexExpr.getPosition();
       unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
-      auto scale = rewriter.getAffineConstantExpr(targetShape[vectorPos]);
+      auto scale =
+          rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
       indices[indexPos] =
           makeComposedAffineApply(rewriter, loc, d0 + scale * d1,
                                   {indices[indexPos], newWarpOp.getLaneid()});
@@ -956,6 +973,10 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
 ///  }
 /// ```
 struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
+
+  WarpOpScfForOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
+      : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
+        distributionMapFn(std::move(fn)) {}
   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
@@ -966,6 +987,35 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
     auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
     if (!forOp)
       return failure();
+    // Collect Values that come from the warp op but are outside the forOp.
+    // Those Value needs to be returned by the original warpOp and passed to the
+    // new op.
+    llvm::SmallSetVector<Value, 32> escapingValues;
+    SmallVector<Type> inputTypes;
+    SmallVector<Type> distTypes;
+    mlir::visitUsedValuesDefinedAbove(
+        forOp.getBodyRegion(), [&](OpOperand *operand) {
+          Operation *parent = operand->get().getParentRegion()->getParentOp();
+          if (warpOp->isAncestor(parent)) {
+            if (!escapingValues.insert(operand->get()))
+              return;
+            Type distType = operand->get().getType();
+            if (auto vecType = distType.cast<VectorType>()) {
+              AffineMap map = distributionMapFn(operand->get());
+              distType = getDistributedType(vecType, map, warpOp.getWarpSize());
+            }
+            inputTypes.push_back(operand->get().getType());
+            distTypes.push_back(distType);
+          }
+        });
+
+    SmallVector<size_t> newRetIndices;
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
+        newRetIndices);
+    yield = cast<vector::YieldOp>(
+        newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+
     SmallVector<Value> newOperands;
     SmallVector<unsigned> resultIdx;
     // Collect all the outputs coming from the forOp.
@@ -973,28 +1023,42 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
       if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
         continue;
       auto forResult = yieldOperand.get().cast<OpResult>();
-      newOperands.push_back(warpOp.getResult(yieldOperand.getOperandNumber()));
+      newOperands.push_back(
+          newWarpOp.getResult(yieldOperand.getOperandNumber()));
       yieldOperand.set(forOp.getIterOperands()[forResult.getResultNumber()]);
       resultIdx.push_back(yieldOperand.getOperandNumber());
     }
+
     OpBuilder::InsertionGuard g(rewriter);
-    rewriter.setInsertionPointAfter(warpOp);
+    rewriter.setInsertionPointAfter(newWarpOp);
+
     // Create a new for op outside the region with a WarpExecuteOnLane0Op region
     // inside.
     auto newForOp = rewriter.create<scf::ForOp>(
         forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
         forOp.getStep(), newOperands);
     rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
+
+    SmallVector<Value> warpInput(newForOp.getRegionIterArgs().begin(),
+                                 newForOp.getRegionIterArgs().end());
+    SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
+                                    forOp.getResultTypes().end());
+    llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
+    for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
+      warpInput.push_back(newWarpOp.getResult(retIdx));
+      argIndexMapping[escapingValues[i]] = warpInputType.size();
+      warpInputType.push_back(inputTypes[i]);
+    }
     auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
-        warpOp.getLoc(), newForOp.getResultTypes(), warpOp.getLaneid(),
-        warpOp.getWarpSize(), newForOp.getRegionIterArgs(),
-        forOp.getResultTypes());
+        newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
+        newWarpOp.getWarpSize(), warpInput, warpInputType);
 
     SmallVector<Value> argMapping;
     argMapping.push_back(newForOp.getInductionVar());
     for (Value args : innerWarp.getBody()->getArguments()) {
       argMapping.push_back(args);
     }
+    argMapping.resize(forOp.getBody()->getNumArguments());
     SmallVector<Value> yieldOperands;
     for (Value operand : forOp.getBody()->getTerminator()->getOperands())
       yieldOperands.push_back(operand);
@@ -1008,12 +1072,23 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
     rewriter.eraseOp(forOp);
     // Replace the warpOp result coming from the original ForOp.
     for (const auto &res : llvm::enumerate(resultIdx)) {
-      warpOp.getResult(res.value())
+      newWarpOp.getResult(res.value())
           .replaceAllUsesWith(newForOp.getResult(res.index()));
-      newForOp->setOperand(res.index() + 3, warpOp.getResult(res.value()));
+      newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
     }
+    newForOp.walk([&](Operation *op) {
+      for (OpOperand &operand : op->getOpOperands()) {
+        auto it = argIndexMapping.find(operand.get());
+        if (it == argIndexMapping.end())
+          continue;
+        operand.set(innerWarp.getBodyRegion().getArgument(it->second));
+      }
+    });
     return success();
   }
+
+private:
+  DistributionMapFn distributionMapFn;
 };
 
 /// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
@@ -1119,11 +1194,14 @@ void mlir::vector::populateDistributeTransferWriteOpPatterns(
 }
 
 void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit) {
+    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
+    PatternBenefit benefit) {
   patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
                WarpOpBroadcast, WarpOpExtract, WarpOpExtractElement,
-               WarpOpForwardOperand, WarpOpScfForOp, WarpOpConstant>(
-      patterns.getContext(), benefit);
+               WarpOpForwardOperand, WarpOpConstant>(patterns.getContext(),
+                                                     benefit);
+  patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
+                               benefit);
 }
 
 void mlir::vector::populateDistributeReduction(
index 49c36fe..daebccd 100644 (file)
@@ -349,6 +349,40 @@ func.func @warp_scf_for(%arg0: index) {
 
 // -----
 
+// CHECK-PROP-LABEL:   func @warp_scf_for_use_from_above(
+// CHECK-PROP: %[[INI:.*]]:2 = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>, vector<4xf32>) {
+// CHECK-PROP:   %[[INI1:.*]] = "some_def"() : () -> vector<128xf32>
+// CHECK-PROP:   %[[USE:.*]] = "some_def_above"() : () -> vector<128xf32>
+// CHECK-PROP:   vector.yield %[[INI1]], %[[USE]] : vector<128xf32>, vector<128xf32>
+// CHECK-PROP: }
+// CHECK-PROP: %[[F:.*]] = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[FARG:.*]] = %[[INI]]#0) -> (vector<4xf32>) {
+// CHECK-PROP:   %[[W:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] args(%[[FARG]], %[[INI]]#1 : vector<4xf32>, vector<4xf32>) -> (vector<4xf32>) {
+// CHECK-PROP:    ^bb0(%[[ARG0:.*]]: vector<128xf32>, %[[ARG1:.*]]: vector<128xf32>):
+// CHECK-PROP:      %[[ACC:.*]] = "some_def"(%[[ARG0]], %[[ARG1]]) : (vector<128xf32>, vector<128xf32>) -> vector<128xf32>
+// CHECK-PROP:      vector.yield %[[ACC]] : vector<128xf32>
+// CHECK-PROP:   }
+// CHECK-PROP:   scf.yield %[[W]] : vector<4xf32>
+// CHECK-PROP: }
+// CHECK-PROP: "some_use"(%[[F]]) : (vector<4xf32>) -> ()
+func.func @warp_scf_for_use_from_above(%arg0: index) {
+  %c128 = arith.constant 128 : index
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %0 = vector.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>) {
+    %ini = "some_def"() : () -> (vector<128xf32>)
+    %use_from_above = "some_def_above"() : () -> (vector<128xf32>)
+    %3 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini) -> (vector<128xf32>) {
+      %acc = "some_def"(%arg4, %use_from_above) : (vector<128xf32>, vector<128xf32>) -> (vector<128xf32>)
+      scf.yield %acc : vector<128xf32>
+    }
+    vector.yield %3 : vector<128xf32>
+  }
+  "some_use"(%0) : (vector<4xf32>) -> ()
+  return
+}
+
+// -----
+
 // CHECK-PROP-LABEL:   func @warp_scf_for_swap(
 // CHECK-PROP: %[[INI:.*]]:2 = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>, vector<4xf32>) {
 // CHECK-PROP:   %[[INI1:.*]] = "some_def"() : () -> vector<128xf32>
index 5547a96..b66b2fe 100644 (file)
@@ -746,24 +746,26 @@ struct TestVectorDistribution
       }
     });
     MLIRContext *ctx = &getContext();
+    auto distributionFn = [](Value val) {
+      // Create a map (d0, d1) -> (d1) to distribute along the inner
+      // dimension. Once we support n-d distribution we can add more
+      // complex cases.
+      VectorType vecType = val.getType().dyn_cast<VectorType>();
+      int64_t vecRank = vecType ? vecType.getRank() : 0;
+      OpBuilder builder(val.getContext());
+      if (vecRank == 0)
+        return AffineMap::get(val.getContext());
+      return AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1));
+    };
     if (distributeTransferWriteOps) {
-      auto distributionFn = [](vector::TransferWriteOp writeOp) {
-        // Create a map (d0, d1) -> (d1) to distribute along the inner
-        // dimension. Once we support n-d distribution we can add more
-        // complex cases.
-        int64_t vecRank = writeOp.getVectorType().getRank();
-        OpBuilder builder(writeOp.getContext());
-        auto map =
-            AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1));
-        return map;
-      };
       RewritePatternSet patterns(ctx);
       populateDistributeTransferWriteOpPatterns(patterns, distributionFn);
       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
     }
     if (propagateDistribution) {
       RewritePatternSet patterns(ctx);
-      vector::populatePropagateWarpVectorDistributionPatterns(patterns);
+      vector::populatePropagateWarpVectorDistributionPatterns(patterns,
+                                                              distributionFn);
       vector::populateDistributeReduction(patterns, warpReduction);
       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
     }