[mlir][Linalg] Add support for tileFuseAndDistribute on tensors.
authorHanhan Wang <hanchung@google.com>
Fri, 25 Feb 2022 18:52:08 +0000 (10:52 -0800)
committerHanhan Wang <hanchung@google.com>
Fri, 25 Feb 2022 19:51:11 +0000 (11:51 -0800)
This extends TileAndFuse to handle distribution on tensors.

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D120441

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/Dialect/Linalg/tile-fuse-and-distribute.mlir [new file with mode: 0644]
mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

index a8cd374..50e6191 100644 (file)
@@ -638,8 +638,20 @@ struct LinalgPaddingOptions {
 struct LinalgTilingAndFusionOptions {
   /// Tile sizes used to tile the root operation.
   SmallVector<int64_t> tileSizes;
+  LinalgTilingAndFusionOptions &setTileSizes(ArrayRef<int64_t> ts) {
+    tileSizes.assign(ts.begin(), ts.end());
+    return *this;
+  }
   /// Tile interchange used to permute the tile loops.
   SmallVector<int64_t> tileInterchange;
+  /// When specified, specifies distribution of generated tile loops to
+  /// processors.
+  Optional<LinalgLoopDistributionOptions> tileDistribution = None;
+  LinalgTilingAndFusionOptions &
+  setDistributionOptions(LinalgLoopDistributionOptions distributionOptions) {
+    tileDistribution = std::move(distributionOptions);
+    return *this;
+  }
 };
 
 struct LinalgTilingOptions {
index 673c27e..dfc2a30 100644 (file)
@@ -246,73 +246,6 @@ FailureOr<FusionInfo> fuseProducerOfTensor(OpBuilder &b,
                                            OpOperand &consumerOpOperand);
 
 //===----------------------------------------------------------------------===//
-// Fusion on tensor utilities
-//===----------------------------------------------------------------------===//
-
-/// A struct to manage the tile loop nest specific information.
-class TileLoopNest {
-public:
-  TileLoopNest(LinalgOp rootOp) : rootOp(rootOp) {}
-
-  /// Tile the root operation using the given `tileSizes` and `tileInterchange`.
-  LogicalResult tileRootOp(OpBuilder &b, ArrayRef<int64_t> tileSizes,
-                           ArrayRef<int64_t> tileInterchange);
-
-  /// Fuse the producer of `consumerOpOperand` into the tile loop nest. Returns
-  /// the fused producer or fails if fusion is not possible.
-  FailureOr<LinalgOp> fuseProducer(OpBuilder &b, OpOperand *consumerOpOperand);
-
-  /// Returns the replacement results for the original untiled root operation.
-  ValueRange getRootOpReplacementResults();
-
-  /// Returns the tiled root operation.
-  LinalgOp getRootOp() { return rootOp; }
-
-  /// Returns the tiled root operation and the fused producers.
-  SmallVector<LinalgOp> getAllTiledAndFusedOps();
-
-  /// Returns the loop ops generated from tiling.
-  ArrayRef<scf::ForOp> getLoopOps() { return tileLoopOps; }
-
-  /// Returns true if the tile loop nest has no tile loops.
-  bool isEmpty();
-
-private:
-  /// Returns true if the tile loop nest invariants are satisfied:
-  /// - The `rootOp` has been tiled at least once.
-  /// - The number of tile loop operations and dimensions match.
-  /// - The innermost tile loop is the parent of `tiledOp`.
-  /// - The tile loops are directly nested.
-  // TODO: relax to support additional control flow, e.g., IfOp.
-  bool isValid();
-
-  /// Searches the block arguments tied to a block argument `bbArg` of the
-  /// innermost tile loop. Returns the block argument from outermost to
-  /// innermost or an empty vector if none are found.
-  SmallVector<BlockArgument> getTiedBBArgs(BlockArgument bbArg);
-
-  /// Returns the iteration argument of the outermost tile loop mapped to a
-  /// block argument `bbArg` of the innermost tile loop.
-  OpOperand *getTiedIterArg(BlockArgument bbArg);
-
-  /// Returns true if `bbArg` has other used than `sliceOp` and its
-  /// dependencies. Only if there are no other uses, the producer output
-  /// iteration argument may reused to pass the producer result after fusion.
-  bool hasOtherUses(BlockArgument bbArg, tensor::ExtractSliceOp sliceOp);
-
-  LinalgOp rootOp;
-  SmallVector<scf::ForOp> tileLoopOps;
-  DenseMap<Operation *, SmallVector<int64_t>> tiledRootAndFusedOpsLoops;
-};
-
-/// Tiles `consumerOp` and fuses its dependencies if possible. Uses the
-/// `tileSizes` and `tileInterchange` parameters to control the tiling.
-FailureOr<TileLoopNest>
-tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp,
-                             ArrayRef<int64_t> tileSizes,
-                             ArrayRef<int64_t> tileInterchange);
-
-//===----------------------------------------------------------------------===//
 // Distribution utilities
 //===----------------------------------------------------------------------===//
 
@@ -397,6 +330,77 @@ void updateBoundsForCyclicDistribution(OpBuilder &builder, Location loc,
                                        Value &ub, Value &step);
 
 //===----------------------------------------------------------------------===//
+// Fusion on tensor utilities
+//===----------------------------------------------------------------------===//
+
+/// A struct to manage the tile loop nest specific information.
+class TileLoopNest {
+public:
+  TileLoopNest(LinalgOp rootOp) : rootOp(rootOp) {}
+
+  /// Tile the root operation using the given `tileSizes` and `tileInterchange`,
+  /// and `tileDistribution`.
+  LogicalResult
+  tileRootOp(OpBuilder &b, ArrayRef<int64_t> tileSizes,
+             ArrayRef<int64_t> tileInterchange,
+             Optional<LinalgLoopDistributionOptions> tileDistribution);
+
+  /// Fuse the producer of `consumerOpOperand` into the tile loop nest. Returns
+  /// the fused producer or fails if fusion is not possible.
+  FailureOr<LinalgOp> fuseProducer(OpBuilder &b, OpOperand *consumerOpOperand);
+
+  /// Returns the replacement results for the original untiled root operation.
+  ValueRange getRootOpReplacementResults();
+
+  /// Returns the tiled root operation.
+  LinalgOp getRootOp() { return rootOp; }
+
+  /// Returns the tiled root operation and the fused producers.
+  SmallVector<LinalgOp> getAllTiledAndFusedOps();
+
+  /// Returns the loop ops generated from tiling.
+  ArrayRef<scf::ForOp> getLoopOps() { return tileLoopOps; }
+
+  /// Returns true if the tile loop nest has no tile loops.
+  bool isEmpty();
+
+private:
+  /// Returns true if the tile loop nest invariants are satisfied:
+  /// - The `rootOp` has been tiled at least once.
+  /// - The number of tile loop operations and dimensions match.
+  /// - The innermost tile loop is the parent of `tiledOp`.
+  /// - The tile loops are directly nested.
+  // TODO: relax to support additional control flow, e.g., IfOp.
+  bool isValid();
+
+  /// Searches the block arguments tied to a block argument `bbArg` of the
+  /// innermost tile loop. Returns the block argument from outermost to
+  /// innermost or an empty vector if none are found.
+  SmallVector<BlockArgument> getTiedBBArgs(BlockArgument bbArg);
+
+  /// Returns the iteration argument of the outermost tile loop mapped to a
+  /// block argument `bbArg` of the innermost tile loop.
+  OpOperand *getTiedIterArg(BlockArgument bbArg);
+
+  /// Returns true if `bbArg` has other used than `sliceOp` and its
+  /// dependencies. Only if there are no other uses, the producer output
+  /// iteration argument may reused to pass the producer result after fusion.
+  bool hasOtherUses(BlockArgument bbArg, tensor::ExtractSliceOp sliceOp);
+
+  LinalgOp rootOp;
+  SmallVector<scf::ForOp> tileLoopOps;
+  DenseMap<Operation *, SmallVector<int64_t>> tiledRootAndFusedOpsLoops;
+};
+
+/// Tiles `consumerOp` and fuses its dependencies if possible. Uses the
+/// `tileSizes`, `tileInterchange`, and `tileDistribution` parameters to control
+/// the tiling.
+FailureOr<TileLoopNest> tileConsumerAndFuseProducers(
+    OpBuilder &b, LinalgOp consumerOp, ArrayRef<int64_t> tileSizes,
+    ArrayRef<int64_t> tileInterchange,
+    Optional<LinalgLoopDistributionOptions> tileDistribution);
+
+//===----------------------------------------------------------------------===//
 // Generic op region utilities
 //===----------------------------------------------------------------------===//
 
index 154b8f4..25808ed 100644 (file)
@@ -269,9 +269,10 @@ bool TileLoopNest::hasOtherUses(BlockArgument bbArg,
   });
 }
 
-LogicalResult TileLoopNest::tileRootOp(OpBuilder &b,
-                                       ArrayRef<int64_t> tileSizes,
-                                       ArrayRef<int64_t> tileInterchange) {
+LogicalResult TileLoopNest::tileRootOp(
+    OpBuilder &b, ArrayRef<int64_t> tileSizes,
+    ArrayRef<int64_t> tileInterchange,
+    Optional<LinalgLoopDistributionOptions> tileDistribution) {
   // Exit if all tile sizes are zero.
   if (tileSizes.size() == static_cast<size_t>(count(tileSizes, 0)))
     return success();
@@ -283,6 +284,9 @@ LogicalResult TileLoopNest::tileRootOp(OpBuilder &b,
                           tileInterchange.begin(), tileInterchange.end()))
                       .setTileSizes(tileSizes)
                       .setLoopType(LinalgTilingLoopType::Loops);
+  if (tileDistribution)
+    tilingOptions =
+        tilingOptions.setDistributionOptions(tileDistribution.getValue());
 
   // TODO: Propagate RewriterBase everywhere.
   IRRewriter rewriter(b);
@@ -408,10 +412,10 @@ SmallVector<LinalgOp> TileLoopNest::getAllTiledAndFusedOps() {
 // Tile and fuse entry-points.
 //===----------------------------------------------------------------------===//
 
-FailureOr<TileLoopNest>
-mlir::linalg::tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp,
-                                           ArrayRef<int64_t> tileSizes,
-                                           ArrayRef<int64_t> tileInterchange) {
+FailureOr<TileLoopNest> mlir::linalg::tileConsumerAndFuseProducers(
+    OpBuilder &b, LinalgOp consumerOp, ArrayRef<int64_t> tileSizes,
+    ArrayRef<int64_t> tileInterchange,
+    Optional<LinalgLoopDistributionOptions> tileDistribution) {
   assert(tileSizes.size() == tileInterchange.size() &&
          "expect the number of tile sizes and interchange dims to match");
   assert(isPermutation(tileInterchange) &&
@@ -446,7 +450,8 @@ mlir::linalg::tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp,
   SmallVector<int64_t> outerTileSizes;
   outerTileSizes.append(tileSizes.begin(), tileSizes.begin() + split);
   outerTileSizes.append(tileSizes.size() - split, 0);
-  if (failed(tileLoopNest.tileRootOp(b, outerTileSizes, tileInterchange)))
+  if (failed(tileLoopNest.tileRootOp(b, outerTileSizes, tileInterchange,
+                                     tileDistribution)))
     return failure();
   fuseProducersGreedily(tileLoopNest.getRootOp().getOutputOperands());
 
@@ -454,7 +459,8 @@ mlir::linalg::tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp,
   SmallVector<int64_t> innerTileSizes;
   innerTileSizes.append(split, 0);
   innerTileSizes.append(tileSizes.begin() + split, tileSizes.end());
-  if (failed(tileLoopNest.tileRootOp(b, innerTileSizes, tileInterchange)))
+  if (failed(tileLoopNest.tileRootOp(b, innerTileSizes, tileInterchange,
+                                     tileDistribution)))
     return failure();
   fuseProducersGreedily(tileLoopNest.getRootOp().getInputOperands());
 
index 486e8b6..9f177f0 100644 (file)
@@ -613,8 +613,9 @@ LogicalResult mlir::linalg::LinalgTileAndFuseTensorOpsPattern::matchAndRewrite(
         op, "expect the tile interchange permutes the root loops");
 
   // Tile `rootOp` and fuse its producers.
-  FailureOr<TileLoopNest> tileLoopNest = tileConsumerAndFuseProducers(
-      rewriter, rootOp, rootTileSizes, rootInterchange);
+  FailureOr<TileLoopNest> tileLoopNest =
+      tileConsumerAndFuseProducers(rewriter, rootOp, rootTileSizes,
+                                   rootInterchange, options.tileDistribution);
   if (failed(tileLoopNest))
     return rewriter.notifyMatchFailure(
         op, "tileConsumerAndFuseProducers failed unexpectedly");
diff --git a/mlir/test/Dialect/Linalg/tile-fuse-and-distribute.mlir b/mlir/test/Dialect/Linalg/tile-fuse-and-distribute.mlir
new file mode 100644 (file)
index 0000000..58f8297
--- /dev/null
@@ -0,0 +1,53 @@
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-tile-fuse-and-distribute-options -split-input-file | FileCheck %s
+
+//      CHECK: #[[MULMAP:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
+//      CHECK: #[[ADDMAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+//      CHECK: func @fill_matmul_tensors(
+// CHECK-SAME:    %[[TA:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME:    %[[TB:[0-9a-z]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+func @fill_matmul_tensors(
+  %arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>)
+    -> tensor<?x?xf32> {
+//  CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+//  CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+//  CHECK-DAG: %[[BIDY:.*]] = gpu.block_id y
+//  CHECK-DAG: %[[NBLOCKSY:.*]] = gpu.grid_dim y
+//  CHECK-DAG: %[[BIDX:.*]] = gpu.block_id x
+//  CHECK-DAG: %[[NBLOCKSX:.*]] = gpu.grid_dim x
+//  CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor
+//      CHECK: %[[MUL:.+]] = affine.apply #[[MULMAP]]()[%[[BIDY]], %[[C8]]]
+//      CHECK: %[[LBY:.+]] = affine.apply #[[ADDMAP]]()[%[[MUL]], %[[C0]]]
+//      CHECK: %[[STEPY:.+]] = affine.apply #[[MULMAP]]()[%[[NBLOCKSY]], %[[C8]]]
+//      CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[INIT]]) -> (tensor<?x?xf32>) {
+//      CHECK: %[[MUL:.+]] = affine.apply #[[MULMAP]]()[%[[BIDX]], %[[C8]]]
+//      CHECK: %[[LBX:.+]] = affine.apply #[[ADDMAP]]()[%[[MUL]], %[[C0]]]
+//      CHECK: %[[STEPX:.+]] = affine.apply #[[MULMAP]]()[%[[NBLOCKSX]], %[[C8]]]
+//      CHECK:   %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<?x?xf32>) {
+//      CHECK:     %[[SLICE:.+]] = tensor.extract_slice %[[TC1]]
+//      CHECK:     %[[FILL:.+]] = linalg.fill(%{{.+}}, %[[SLICE]])
+//      CHECK:     %[[sTD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[FILL]]) -> (tensor<?x?xf32>) {
+//      CHECK:       %[[sTA:.*]] = tensor.extract_slice %[[TA]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
+//      CHECK:       %[[sTB:.*]] = tensor.extract_slice %[[TB]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
+//      CHECK:       %[[sTC:.*]] = tensor.extract_slice %[[TC2]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
+//      CHECK:       %[[sTD:.*]] = linalg.matmul ins(%[[sTA]], %[[sTB]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME:                                  outs(%[[sTC]] : tensor<?x?xf32>)  -> tensor<?x?xf32>
+//      CHECK:       %[[TD:.*]] = tensor.insert_slice %[[sTD]] into %[[TC2]][{{.*}}]  : tensor<?x?xf32> into tensor<?x?xf32>
+//      CHECK:       scf.yield %[[TD]] : tensor<?x?xf32>
+//      CHECK:     %[[TD2:.*]] = tensor.insert_slice %[[sTD2]] into %[[TC1]][{{.*}}]  : tensor<?x?xf32> into tensor<?x?xf32>
+//      CHECK:     scf.yield %[[TD2]] : tensor<?x?xf32>
+//      CHECK:   scf.yield %[[TD1]] : tensor<?x?xf32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.0 : f32
+  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
+  %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
+  %3 = linalg.fill(%cst, %2) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
+  %4 = linalg.matmul {__internal_linalg_transform__ = "tensors_fuse_distribute1"}
+       ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%3: tensor<?x?xf32>)
+    -> tensor<?x?xf32>
+
+//      CHECK: return %[[TD0]] : tensor<?x?xf32>
+  return %4 : tensor<?x?xf32>
+}
index 32f4538..8af5c46 100644 (file)
@@ -78,6 +78,10 @@ struct TestLinalgTransforms
       *this, "test-tile-and-distribute-options",
       llvm::cl::desc("Test tile and distribute options"),
       llvm::cl::init(false)};
+  Option<bool> testTileFuseAndDistributionOptions{
+      *this, "test-tile-fuse-and-distribute-options",
+      llvm::cl::desc("Test tile, fuse and distribute options"),
+      llvm::cl::init(false)};
   Option<bool> testVectorTransferForwardingPatterns{
       *this, "test-vector-transfer-forwarding-patterns",
       llvm::cl::desc(
@@ -505,6 +509,21 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
   }
 }
 
+static void fillTileFuseAndDistributePatterns(MLIRContext *context,
+                                              RewritePatternSet &patterns) {
+  LinalgLoopDistributionOptions cyclicNprocsEqNiters;
+  cyclicNprocsEqNiters.distributionMethod.resize(2, DistributionMethod::Cyclic);
+  cyclicNprocsEqNiters.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
+  patterns.add<LinalgTileAndFuseTensorOpsPattern>(
+      MatmulOp::getOperationName(), context,
+      LinalgTilingAndFusionOptions()
+          .setTileSizes({8, 8, 4})
+          .setDistributionOptions(cyclicNprocsEqNiters),
+      LinalgTransformationFilter(
+          StringAttr::get(context, "tensors_fuse_distribute1"),
+          StringAttr::get(context, "tensors_after_fuse_distribute1")));
+}
+
 static void
 applyMatmulToVectorPatterns(FuncOp funcOp,
                             bool testMatmulToVectorPatterns1dTiling,
@@ -698,6 +717,12 @@ void TestLinalgTransforms::runOnOperation() {
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
     return;
   }
+  if (testTileFuseAndDistributionOptions) {
+    RewritePatternSet patterns(&getContext());
+    fillTileFuseAndDistributePatterns(&getContext(), patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+    return;
+  }
   if (testPatterns)
     return applyPatterns(getOperation());
   if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)