[mlir][vector] Add patterns for vector distribution
authorThomas Raoux <thomasraoux@google.com>
Mon, 6 Jun 2022 20:56:50 +0000 (20:56 +0000)
committerThomas Raoux <thomasraoux@google.com>
Fri, 10 Jun 2022 17:46:51 +0000 (17:46 +0000)
Add pattern to hoist scalar code outside of warp distribute region as
those cannot be distributed and we would want to execute them on all
the lanes.
Add patterns to distribute transfer_write ops. Those operations can be
distributed in different ways and it is control by user.

Differential Revision: https://reviews.llvm.org/D127152

mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/test/Dialect/Vector/vector-warp-distribute.mlir
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

index 06ca024..b95b527 100644 (file)
@@ -39,6 +39,32 @@ void populateWarpExecuteOnLane0OpToScfForPattern(
     RewritePatternSet &patterns,
     const WarpExecuteOnLane0LoweringOptions &options);
 
+using DistributionMapFn = std::function<AffineMap(vector::TransferWriteOp)>;
+
+/// Distribute transfer_write ops based on the affine map returned by
+/// `distributionMapFn`.
+/// Example:
+/// ```
+/// %0 = vector.warp_execute_on_lane_0(%id){
+///   ...
+///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
+///   vector.yield
+/// }
+/// ```
+/// To
+/// ```
+/// %r:3 = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
+///   ...
+///   vector.yield %v : vector<32xf32>
+/// }
+/// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32>
+void populateDistributeTransferWriteOpPatterns(
+    RewritePatternSet &patterns, DistributionMapFn distributionMapFn);
+
+/// Move scalar operations with no dependency on the warp op outside of the
+/// region.
+void moveScalarUniformCode(WarpExecuteOnLane0Op op);
+
 } // namespace vector
 } // namespace mlir
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORDISTRIBUTION_H_
index 933d572..586604f 100644 (file)
@@ -6,10 +6,12 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
+#include "mlir/Transforms/SideEffectUtils.h"
 
 using namespace mlir;
 using namespace mlir::vector;
@@ -93,8 +95,8 @@ rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
     if (resultType == val.getType()) {
       // Result type and yielded value type are the same. This is a broadcast.
       // E.g.:
-      // %r = vector_ext.warp_execute_on_lane_0(...) -> (f32) {
-      //   vector_ext.yield %cst : f32
+      // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
+      //   vector.yield %cst : f32
       // }
       // Both types are f32. The constant %cst is broadcasted to all lanes.
       // This is described in more detail in the documentation of the op.
@@ -131,6 +133,54 @@ rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
   return success();
 }
 
+/// Helper to create a new WarpExecuteOnLane0Op with different signature.
+static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
+    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes) {
+  // Create a new op before the existing one, with the extra operands.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(warpOp);
+  auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
+      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
+      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
+
+  Region &opBody = warpOp.getBodyRegion();
+  Region &newOpBody = newWarpOp.getBodyRegion();
+  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
+  auto yield =
+      cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
+
+  rewriter.updateRootInPlace(
+      yield, [&]() { yield.operandsMutable().assign(newYieldedValues); });
+  return newWarpOp;
+}
+
+/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
+static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
+    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes) {
+  SmallVector<Type> types(warpOp.getResultTypes().begin(),
+                          warpOp.getResultTypes().end());
+  types.append(newReturnTypes.begin(), newReturnTypes.end());
+  auto yield = cast<vector::YieldOp>(
+      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+  SmallVector<Value> yieldValues(yield.getOperands().begin(),
+                                 yield.getOperands().end());
+  yieldValues.append(newYieldedValues.begin(), newYieldedValues.end());
+  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
+      rewriter, warpOp, yieldValues, types);
+  rewriter.replaceOp(warpOp,
+                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
+  return newWarpOp;
+}
+
+/// Helper to know if an op can be hoisted out of the region.
+static bool canBeHoisted(Operation *op,
+                         function_ref<bool(Value)> definedOutside) {
+  return llvm::all_of(op->getOperands(), definedOutside) &&
+         isSideEffectFree(op) && op->getNumRegions() == 0;
+}
+
 namespace {
 
 struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
@@ -149,6 +199,157 @@ private:
   const WarpExecuteOnLane0LoweringOptions &options;
 };
 
+/// Distribute transfer_write ops based on the affine map returned by
+/// `distributionMapFn`.
+/// Example:
+/// ```
+/// %0 = vector.warp_execute_on_lane_0(%id){
+///   ...
+///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
+///   vector.yield
+/// }
+/// ```
+/// To
+/// ```
+/// %r:3 = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
+///   ...
+///   vector.yield %v : vector<32xf32>
+/// }
+/// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32>
+struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
+  WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
+                      PatternBenefit b = 1)
+      : OpRewritePattern<vector::TransferWriteOp>(ctx, b),
+        distributionMapFn(fn) {}
+
+  /// Distribute the TransferWriteOp. Only 1D distributions and vector dims that
+  /// are multiples of the distribution ratio are supported at the moment.
+  LogicalResult tryDistributeOp(RewriterBase &rewriter,
+                                vector::TransferWriteOp writeOp,
+                                WarpExecuteOnLane0Op warpOp) const {
+    AffineMap map = distributionMapFn(writeOp);
+    SmallVector<int64_t> targetShape(writeOp.getVectorType().getShape().begin(),
+                                     writeOp.getVectorType().getShape().end());
+    assert(map.getNumResults() == 1 &&
+           "multi-dim distribution not implemented yet");
+    for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
+      unsigned position = map.getDimPosition(i);
+      if (targetShape[position] % warpOp.getWarpSize() != 0)
+        return failure();
+      targetShape[position] = targetShape[position] / warpOp.getWarpSize();
+    }
+    VectorType targetType =
+        VectorType::get(targetShape, writeOp.getVectorType().getElementType());
+
+    SmallVector<Value> yieldValues = {writeOp.getVector()};
+    SmallVector<Type> retTypes = {targetType};
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, yieldValues, retTypes);
+    rewriter.setInsertionPointAfter(newWarpOp);
+
+    // Move op outside of region: Insert clone at the insertion point and delete
+    // the old op.
+    auto newWriteOp =
+        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
+    rewriter.eraseOp(writeOp);
+
+    rewriter.setInsertionPoint(newWriteOp);
+    AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
+    Location loc = newWriteOp.getLoc();
+    SmallVector<Value> indices(newWriteOp.getIndices().begin(),
+                               newWriteOp.getIndices().end());
+    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
+      AffineExpr d0, d1;
+      bindDims(newWarpOp.getContext(), d0, d1);
+      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
+      if (!indexExpr)
+        continue;
+      unsigned indexPos = indexExpr.getPosition();
+      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
+      auto scale =
+          getAffineConstantExpr(targetShape[vectorPos], newWarpOp.getContext());
+      indices[indexPos] =
+          makeComposedAffineApply(rewriter, loc, d0 + scale * d1,
+                                  {indices[indexPos], newWarpOp.getLaneid()});
+    }
+    newWriteOp.getVectorMutable().assign(newWarpOp.getResults().back());
+    newWriteOp.getIndicesMutable().assign(indices);
+
+    return success();
+  }
+
+  /// Extract TransferWriteOps of vector<1x> into a separate warp op.
+  LogicalResult tryExtractOp(RewriterBase &rewriter,
+                             vector::TransferWriteOp writeOp,
+                             WarpExecuteOnLane0Op warpOp) const {
+    Location loc = writeOp.getLoc();
+    VectorType vecType = writeOp.getVectorType();
+
+    // Only vector<1x> is supported at the moment.
+    if (vecType.getShape().size() != 1 || vecType.getShape()[0] != 1)
+      return failure();
+
+    // Do not process warp ops that contain only TransferWriteOps.
+    if (llvm::all_of(warpOp.getOps(), [](Operation &op) {
+          return isa<vector::TransferWriteOp, vector::YieldOp>(&op);
+        }))
+      return failure();
+
+    SmallVector<Value> yieldValues = {writeOp.getVector()};
+    SmallVector<Type> retTypes = {vecType};
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, yieldValues, retTypes);
+    rewriter.setInsertionPointAfter(newWarpOp);
+
+    // Create a second warp op that contains only writeOp.
+    auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
+        loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
+    Block &body = secondWarpOp.getBodyRegion().front();
+    rewriter.setInsertionPointToStart(&body);
+    auto newWriteOp =
+        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
+    newWriteOp.getVectorMutable().assign(
+        newWarpOp.getResult(newWarpOp.getNumResults() - 1));
+    rewriter.eraseOp(writeOp);
+    rewriter.create<vector::YieldOp>(newWarpOp.getLoc());
+    return success();
+  }
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
+                                PatternRewriter &rewriter) const override {
+    // Ops with mask not supported yet.
+    if (writeOp.getMask())
+      return failure();
+
+    auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(writeOp->getParentOp());
+    if (!warpOp)
+      return failure();
+
+    // There must be no op with a side effect after writeOp.
+    Operation *nextOp = writeOp.getOperation();
+    while ((nextOp = nextOp->getNextNode()))
+      if (!isSideEffectFree(nextOp))
+        return failure();
+
+    if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
+          return writeOp.getVector() == value ||
+                 warpOp.isDefinedOutsideOfRegion(value);
+        }))
+      return failure();
+
+    if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
+      return success();
+
+    if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
+      return success();
+
+    return failure();
+  }
+
+private:
+  DistributionMapFn distributionMapFn;
+};
+
 } // namespace
 
 void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
@@ -156,3 +357,36 @@ void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
     const WarpExecuteOnLane0LoweringOptions &options) {
   patterns.add<WarpOpToScfForPattern>(patterns.getContext(), options);
 }
+
+void mlir::vector::populateDistributeTransferWriteOpPatterns(
+    RewritePatternSet &patterns, DistributionMapFn distributionMapFn) {
+  patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn);
+}
+
+void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
+  Block *body = warpOp.getBody();
+
+  // Keep track of the ops we want to hoist.
+  llvm::SmallSetVector<Operation *, 8> opsToMove;
+
+  // Helper to check if a value is or will be defined outside of the region.
+  auto isDefinedOutsideOfBody = [&](Value value) {
+    auto *definingOp = value.getDefiningOp();
+    return (definingOp && opsToMove.count(definingOp)) ||
+           warpOp.isDefinedOutsideOfRegion(value);
+  };
+
+  // Do not use walk here, as we do not want to go into nested regions and hoist
+  // operations from there.
+  for (auto &op : body->without_terminator()) {
+    bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
+      return result.getType().isa<VectorType>();
+    });
+    if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
+      opsToMove.insert(&op);
+  }
+
+  // Move all the ops marked as uniform outside of the region.
+  for (Operation *op : opsToMove)
+    op->moveBefore(warpOp);
+}
index cba8f05..dc4dfee 100644 (file)
@@ -1,4 +1,6 @@
 // RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute=rewrite-warp-ops-to-scf-if | FileCheck %s --check-prefix=CHECK-SCF-IF
+// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute="hoist-uniform" | FileCheck --check-prefixes=CHECK-HOIST %s
+// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute="hoist-uniform distribute-transfer-write" | FileCheck --check-prefixes=CHECK-D %s
 
 // CHECK-SCF-IF-DAG: memref.global "private" @__shared_32xf32 : memref<32xf32, 3>
 // CHECK-SCF-IF-DAG: memref.global "private" @__shared_64xf32 : memref<64xf32, 3>
@@ -52,3 +54,76 @@ func.func @rewrite_warp_op_to_scf_if(%laneid: index,
   "some_use"(%r#1) : (vector<2xf32>) -> ()
   return
 }
+
+// -----
+
+// CHECK-D-DAG: #[[MAP1:.*]] = affine_map<()[s0] -> (s0 * 2 + 32)>
+
+// CHECK-ALL-LABEL: func @warp(
+// CHECK-HOIST: memref.subview
+// CHECK-HOIST: memref.subview
+// CHECK-HOIST: memref.subview
+// CHECK-HOIST: vector.warp_execute_on_lane_0
+
+//     CHECK-D: %[[R:.*]]:2 = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2xf32>, vector<1xf32>) {
+//     CHECK-D:   arith.addf {{.*}} : vector<32xf32>
+//     CHECK-D:   arith.addf {{.*}} : vector<64xf32>
+//     CHECK-D:   vector.yield %{{.*}}, %{{.*}} : vector<64xf32>, vector<32xf32>
+// CHECK-D-DAG: vector.transfer_write %[[R]]#1, %{{.*}}[%{{.*}}] {in_bounds = [true]} : vector<1xf32>, memref<128xf32
+// CHECK-D-DAG: %[[ID1:.*]] = affine.apply #[[MAP1]]()[%{{.*}}]
+// CHECK-D-DAG: vector.transfer_write %[[R]]#0, %2[%[[ID1]]] {in_bounds = [true]} : vector<2xf32>, memref<128xf32
+
+// CHECK-ALL-NOT: vector.warp_execute_on_lane_0
+// CHECK-ALL: vector.transfer_read {{.*}} vector<1xf32>
+// CHECK-ALL: vector.transfer_read {{.*}} vector<1xf32>
+// CHECK-ALL: vector.transfer_read {{.*}} vector<2xf32>
+// CHECK-ALL: vector.transfer_read {{.*}} vector<2xf32>
+// CHECK-ALL: arith.addf {{.*}} : vector<1xf32>
+// CHECK-ALL: arith.addf {{.*}} : vector<2xf32>
+// CHECK-ALL: vector.transfer_write {{.*}} : vector<1xf32>
+// CHECK-ALL: vector.transfer_write {{.*}} : vector<2xf32>
+
+#map0 =  affine_map<(d0)[s0] -> (d0 + s0)>
+func.func @warp(%laneid: index, %arg1: memref<1024xf32>, %arg2: memref<1024xf32>,
+           %arg3: memref<1024xf32>, %gid : index) {
+  vector.warp_execute_on_lane_0(%laneid)[32] {
+    %sa = memref.subview %arg1[%gid] [128] [1] : memref<1024xf32> to memref<128xf32, #map0>
+    %sb = memref.subview %arg2[%gid] [128] [1] : memref<1024xf32> to memref<128xf32, #map0>
+    %sc = memref.subview %arg3[%gid] [128] [1] : memref<1024xf32> to memref<128xf32, #map0>
+    %c0 = arith.constant 0 : index
+    %c32 = arith.constant 32 : index
+    %cst = arith.constant 0.000000e+00 : f32
+    %2 = vector.transfer_read %sa[%c0], %cst : memref<128xf32, #map0>, vector<32xf32>
+    %3 = vector.transfer_read %sa[%c32], %cst : memref<128xf32, #map0>, vector<32xf32>
+    %4 = vector.transfer_read %sb[%c0], %cst : memref<128xf32, #map0>, vector<64xf32>
+    %5 = vector.transfer_read %sb[%c32], %cst : memref<128xf32, #map0>, vector<64xf32>
+    %6 = arith.addf %2, %3 : vector<32xf32>
+    %7 = arith.addf %4, %5 : vector<64xf32>
+    vector.transfer_write %6, %sc[%c0] : vector<32xf32>, memref<128xf32, #map0>
+    vector.transfer_write %7, %sc[%c32] : vector<64xf32>, memref<128xf32, #map0>
+  }
+  return
+}
+
+// -----
+
+// CHECK-D-LABEL: func @warp_extract(
+//       CHECK-D:   %[[WARPOP:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>)
+//       CHECK-D:     "test.dummy_op"
+//       CHECK-D:     vector.yield %{{.*}} : vector<1xf32>
+//       CHECK-D:   }
+//       CHECK-D:   vector.warp_execute_on_lane_0(%{{.*}})[32] {
+//       CHECK-D:     vector.transfer_write %[[WARPOP]], %{{.*}}[%{{.*}}] {{.*}} : vector<1xf32>
+//       CHECK-D:   }
+
+#map2 =  affine_map<(d0)[s0] -> (d0 + s0)>
+
+func.func @warp_extract(%laneid: index, %arg1: memref<1024xf32>, %gid : index) {
+  vector.warp_execute_on_lane_0(%laneid)[32] {
+    %sa = memref.subview %arg1[%gid] [128] [1] : memref<1024xf32> to memref<128xf32, #map2>
+    %c0 = arith.constant 0 : index
+    %v = "test.dummy_op"() : () -> (vector<1xf32>)
+    vector.transfer_write %v, %sa[%c0] : vector<1xf32>, memref<128xf32, #map2>
+  }
+  return
+}
\ No newline at end of file
index f216ba8..e1ffddc 100644 (file)
@@ -809,7 +809,8 @@ struct TestVectorDistribution
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistribution)
 
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<scf::SCFDialect, memref::MemRefDialect, gpu::GPUDialect>();
+    registry.insert<scf::SCFDialect, memref::MemRefDialect, gpu::GPUDialect,
+                    AffineDialect>();
   }
 
   StringRef getArgument() const final { return "test-vector-warp-distribute"; }
@@ -825,8 +826,43 @@ struct TestVectorDistribution
       llvm::cl::desc("Lower vector.warp_execute_on_lane0 to scf.if op"),
       llvm::cl::init(false)};
 
+  Option<bool> distributeTransferWriteOps{
+      *this, "distribute-transfer-write",
+      llvm::cl::desc("Test distribution of transfer write"),
+      llvm::cl::init(false)};
+
+  Option<bool> hoistUniform{*this, "hoist-uniform",
+                            llvm::cl::desc("Test hoist uniform"),
+                            llvm::cl::init(false)};
+
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
+
+    getOperation().walk([&](Operation *op) {
+      if (auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(op)) {
+        if (hoistUniform) {
+          moveScalarUniformCode(warpOp);
+        }
+        WalkResult::interrupt();
+      }
+    });
+    MLIRContext *ctx = &getContext();
+    if (distributeTransferWriteOps) {
+      auto distributionFn = [](vector::TransferWriteOp writeOp) {
+        // Create a map (d0, d1) -> (d1) to distribute along the inner
+        // dimension. Once we support n-d distribution we can add more
+        // complex cases.
+        int64_t vecRank = writeOp.getVectorType().getRank();
+        OpBuilder builder(writeOp.getContext());
+        auto map =
+            AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1));
+        return map;
+      };
+      RewritePatternSet patterns(ctx);
+      populateDistributeTransferWriteOpPatterns(patterns, distributionFn);
+      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+    }
+
     WarpExecuteOnLane0LoweringOptions options;
     options.warpAllocationFn = allocateGlobalSharedMemory;
     options.warpSyncronizationFn = [](Location loc, OpBuilder &builder,