This extends TileAndFuse to handle distribution on tensors.
Reviewed By: gysit
Differential Revision: https://reviews.llvm.org/D120441
struct LinalgTilingAndFusionOptions {
/// Tile sizes used to tile the root operation.
SmallVector<int64_t> tileSizes;
+ LinalgTilingAndFusionOptions &setTileSizes(ArrayRef<int64_t> ts) {
+ tileSizes.assign(ts.begin(), ts.end());
+ return *this;
+ }
/// Tile interchange used to permute the tile loops.
SmallVector<int64_t> tileInterchange;
+ /// When specified, specifies distribution of generated tile loops to
+ /// processors.
+ Optional<LinalgLoopDistributionOptions> tileDistribution = None;
+ LinalgTilingAndFusionOptions &
+ setDistributionOptions(LinalgLoopDistributionOptions distributionOptions) {
+ tileDistribution = std::move(distributionOptions);
+ return *this;
+ }
};
struct LinalgTilingOptions {
OpOperand &consumerOpOperand);
//===----------------------------------------------------------------------===//
-// Fusion on tensor utilities
-//===----------------------------------------------------------------------===//
-
-/// A struct to manage the tile loop nest specific information.
-class TileLoopNest {
-public:
- TileLoopNest(LinalgOp rootOp) : rootOp(rootOp) {}
-
- /// Tile the root operation using the given `tileSizes` and `tileInterchange`.
- LogicalResult tileRootOp(OpBuilder &b, ArrayRef<int64_t> tileSizes,
- ArrayRef<int64_t> tileInterchange);
-
- /// Fuse the producer of `consumerOpOperand` into the tile loop nest. Returns
- /// the fused producer or fails if fusion is not possible.
- FailureOr<LinalgOp> fuseProducer(OpBuilder &b, OpOperand *consumerOpOperand);
-
- /// Returns the replacement results for the original untiled root operation.
- ValueRange getRootOpReplacementResults();
-
- /// Returns the tiled root operation.
- LinalgOp getRootOp() { return rootOp; }
-
- /// Returns the tiled root operation and the fused producers.
- SmallVector<LinalgOp> getAllTiledAndFusedOps();
-
- /// Returns the loop ops generated from tiling.
- ArrayRef<scf::ForOp> getLoopOps() { return tileLoopOps; }
-
- /// Returns true if the tile loop nest has no tile loops.
- bool isEmpty();
-
-private:
- /// Returns true if the tile loop nest invariants are satisfied:
- /// - The `rootOp` has been tiled at least once.
- /// - The number of tile loop operations and dimensions match.
- /// - The innermost tile loop is the parent of `tiledOp`.
- /// - The tile loops are directly nested.
- // TODO: relax to support additional control flow, e.g., IfOp.
- bool isValid();
-
- /// Searches the block arguments tied to a block argument `bbArg` of the
- /// innermost tile loop. Returns the block argument from outermost to
- /// innermost or an empty vector if none are found.
- SmallVector<BlockArgument> getTiedBBArgs(BlockArgument bbArg);
-
- /// Returns the iteration argument of the outermost tile loop mapped to a
- /// block argument `bbArg` of the innermost tile loop.
- OpOperand *getTiedIterArg(BlockArgument bbArg);
-
- /// Returns true if `bbArg` has other used than `sliceOp` and its
- /// dependencies. Only if there are no other uses, the producer output
- /// iteration argument may reused to pass the producer result after fusion.
- bool hasOtherUses(BlockArgument bbArg, tensor::ExtractSliceOp sliceOp);
-
- LinalgOp rootOp;
- SmallVector<scf::ForOp> tileLoopOps;
- DenseMap<Operation *, SmallVector<int64_t>> tiledRootAndFusedOpsLoops;
-};
-
-/// Tiles `consumerOp` and fuses its dependencies if possible. Uses the
-/// `tileSizes` and `tileInterchange` parameters to control the tiling.
-FailureOr<TileLoopNest>
-tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp,
- ArrayRef<int64_t> tileSizes,
- ArrayRef<int64_t> tileInterchange);
-
-//===----------------------------------------------------------------------===//
// Distribution utilities
//===----------------------------------------------------------------------===//
Value &ub, Value &step);
//===----------------------------------------------------------------------===//
+// Fusion on tensor utilities
+//===----------------------------------------------------------------------===//
+
+/// A struct to manage the tile loop nest specific information.
+class TileLoopNest {
+public:
+ TileLoopNest(LinalgOp rootOp) : rootOp(rootOp) {}
+
+ /// Tile the root operation using the given `tileSizes` and `tileInterchange`,
+ /// and `tileDistribution`.
+ LogicalResult
+ tileRootOp(OpBuilder &b, ArrayRef<int64_t> tileSizes,
+ ArrayRef<int64_t> tileInterchange,
+ Optional<LinalgLoopDistributionOptions> tileDistribution);
+
+ /// Fuse the producer of `consumerOpOperand` into the tile loop nest. Returns
+ /// the fused producer or fails if fusion is not possible.
+ FailureOr<LinalgOp> fuseProducer(OpBuilder &b, OpOperand *consumerOpOperand);
+
+ /// Returns the replacement results for the original untiled root operation.
+ ValueRange getRootOpReplacementResults();
+
+ /// Returns the tiled root operation.
+ LinalgOp getRootOp() { return rootOp; }
+
+ /// Returns the tiled root operation and the fused producers.
+ SmallVector<LinalgOp> getAllTiledAndFusedOps();
+
+ /// Returns the loop ops generated from tiling.
+ ArrayRef<scf::ForOp> getLoopOps() { return tileLoopOps; }
+
+ /// Returns true if the tile loop nest has no tile loops.
+ bool isEmpty();
+
+private:
+ /// Returns true if the tile loop nest invariants are satisfied:
+ /// - The `rootOp` has been tiled at least once.
+ /// - The number of tile loop operations and dimensions match.
+ /// - The innermost tile loop is the parent of `tiledOp`.
+ /// - The tile loops are directly nested.
+ // TODO: relax to support additional control flow, e.g., IfOp.
+ bool isValid();
+
+ /// Searches the block arguments tied to a block argument `bbArg` of the
+ /// innermost tile loop. Returns the block argument from outermost to
+ /// innermost or an empty vector if none are found.
+ SmallVector<BlockArgument> getTiedBBArgs(BlockArgument bbArg);
+
+ /// Returns the iteration argument of the outermost tile loop mapped to a
+ /// block argument `bbArg` of the innermost tile loop.
+ OpOperand *getTiedIterArg(BlockArgument bbArg);
+
+ /// Returns true if `bbArg` has other used than `sliceOp` and its
+ /// dependencies. Only if there are no other uses, the producer output
+ /// iteration argument may reused to pass the producer result after fusion.
+ bool hasOtherUses(BlockArgument bbArg, tensor::ExtractSliceOp sliceOp);
+
+ LinalgOp rootOp;
+ SmallVector<scf::ForOp> tileLoopOps;
+ DenseMap<Operation *, SmallVector<int64_t>> tiledRootAndFusedOpsLoops;
+};
+
+/// Tiles `consumerOp` and fuses its dependencies if possible. Uses the
+/// `tileSizes`, `tileInterchange`, and `tileDistribution` parameters to control
+/// the tiling.
+FailureOr<TileLoopNest> tileConsumerAndFuseProducers(
+ OpBuilder &b, LinalgOp consumerOp, ArrayRef<int64_t> tileSizes,
+ ArrayRef<int64_t> tileInterchange,
+ Optional<LinalgLoopDistributionOptions> tileDistribution);
+
+//===----------------------------------------------------------------------===//
// Generic op region utilities
//===----------------------------------------------------------------------===//
});
}
-LogicalResult TileLoopNest::tileRootOp(OpBuilder &b,
- ArrayRef<int64_t> tileSizes,
- ArrayRef<int64_t> tileInterchange) {
+LogicalResult TileLoopNest::tileRootOp(
+ OpBuilder &b, ArrayRef<int64_t> tileSizes,
+ ArrayRef<int64_t> tileInterchange,
+ Optional<LinalgLoopDistributionOptions> tileDistribution) {
// Exit if all tile sizes are zero.
if (tileSizes.size() == static_cast<size_t>(count(tileSizes, 0)))
return success();
tileInterchange.begin(), tileInterchange.end()))
.setTileSizes(tileSizes)
.setLoopType(LinalgTilingLoopType::Loops);
+ if (tileDistribution)
+ tilingOptions =
+ tilingOptions.setDistributionOptions(tileDistribution.getValue());
// TODO: Propagate RewriterBase everywhere.
IRRewriter rewriter(b);
// Tile and fuse entry-points.
//===----------------------------------------------------------------------===//
-FailureOr<TileLoopNest>
-mlir::linalg::tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp,
- ArrayRef<int64_t> tileSizes,
- ArrayRef<int64_t> tileInterchange) {
+FailureOr<TileLoopNest> mlir::linalg::tileConsumerAndFuseProducers(
+ OpBuilder &b, LinalgOp consumerOp, ArrayRef<int64_t> tileSizes,
+ ArrayRef<int64_t> tileInterchange,
+ Optional<LinalgLoopDistributionOptions> tileDistribution) {
assert(tileSizes.size() == tileInterchange.size() &&
"expect the number of tile sizes and interchange dims to match");
assert(isPermutation(tileInterchange) &&
SmallVector<int64_t> outerTileSizes;
outerTileSizes.append(tileSizes.begin(), tileSizes.begin() + split);
outerTileSizes.append(tileSizes.size() - split, 0);
- if (failed(tileLoopNest.tileRootOp(b, outerTileSizes, tileInterchange)))
+ if (failed(tileLoopNest.tileRootOp(b, outerTileSizes, tileInterchange,
+ tileDistribution)))
return failure();
fuseProducersGreedily(tileLoopNest.getRootOp().getOutputOperands());
SmallVector<int64_t> innerTileSizes;
innerTileSizes.append(split, 0);
innerTileSizes.append(tileSizes.begin() + split, tileSizes.end());
- if (failed(tileLoopNest.tileRootOp(b, innerTileSizes, tileInterchange)))
+ if (failed(tileLoopNest.tileRootOp(b, innerTileSizes, tileInterchange,
+ tileDistribution)))
return failure();
fuseProducersGreedily(tileLoopNest.getRootOp().getInputOperands());
op, "expect the tile interchange permutes the root loops");
// Tile `rootOp` and fuse its producers.
- FailureOr<TileLoopNest> tileLoopNest = tileConsumerAndFuseProducers(
- rewriter, rootOp, rootTileSizes, rootInterchange);
+ FailureOr<TileLoopNest> tileLoopNest =
+ tileConsumerAndFuseProducers(rewriter, rootOp, rootTileSizes,
+ rootInterchange, options.tileDistribution);
if (failed(tileLoopNest))
return rewriter.notifyMatchFailure(
op, "tileConsumerAndFuseProducers failed unexpectedly");
--- /dev/null
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-tile-fuse-and-distribute-options -split-input-file | FileCheck %s
+
+// CHECK: #[[MULMAP:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
+// CHECK: #[[ADDMAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK: func @fill_matmul_tensors(
+// CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[TB:[0-9a-z]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+func @fill_matmul_tensors(
+ %arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>)
+ -> tensor<?x?xf32> {
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[BIDY:.*]] = gpu.block_id y
+// CHECK-DAG: %[[NBLOCKSY:.*]] = gpu.grid_dim y
+// CHECK-DAG: %[[BIDX:.*]] = gpu.block_id x
+// CHECK-DAG: %[[NBLOCKSX:.*]] = gpu.grid_dim x
+// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor
+// CHECK: %[[MUL:.+]] = affine.apply #[[MULMAP]]()[%[[BIDY]], %[[C8]]]
+// CHECK: %[[LBY:.+]] = affine.apply #[[ADDMAP]]()[%[[MUL]], %[[C0]]]
+// CHECK: %[[STEPY:.+]] = affine.apply #[[MULMAP]]()[%[[NBLOCKSY]], %[[C8]]]
+// CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[INIT]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[MUL:.+]] = affine.apply #[[MULMAP]]()[%[[BIDX]], %[[C8]]]
+// CHECK: %[[LBX:.+]] = affine.apply #[[ADDMAP]]()[%[[MUL]], %[[C0]]]
+// CHECK: %[[STEPX:.+]] = affine.apply #[[MULMAP]]()[%[[NBLOCKSX]], %[[C8]]]
+// CHECK: %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[TC1]]
+// CHECK: %[[FILL:.+]] = linalg.fill(%{{.+}}, %[[SLICE]])
+// CHECK: %[[sTD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[FILL]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[sTA:.*]] = tensor.extract_slice %[[TA]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[sTB:.*]] = tensor.extract_slice %[[TB]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[sTC:.*]] = tensor.extract_slice %[[TC2]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[sTD:.*]] = linalg.matmul ins(%[[sTA]], %[[sTB]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[sTC]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[TD:.*]] = tensor.insert_slice %[[sTD]] into %[[TC2]][{{.*}}] : tensor<?x?xf32> into tensor<?x?xf32>
+// CHECK: scf.yield %[[TD]] : tensor<?x?xf32>
+// CHECK: %[[TD2:.*]] = tensor.insert_slice %[[sTD2]] into %[[TC1]][{{.*}}] : tensor<?x?xf32> into tensor<?x?xf32>
+// CHECK: scf.yield %[[TD2]] : tensor<?x?xf32>
+// CHECK: scf.yield %[[TD1]] : tensor<?x?xf32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %cst = arith.constant 0.0 : f32
+ %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+ %1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
+ %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
+ %3 = linalg.fill(%cst, %2) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
+ %4 = linalg.matmul {__internal_linalg_transform__ = "tensors_fuse_distribute1"}
+ ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%3: tensor<?x?xf32>)
+ -> tensor<?x?xf32>
+
+// CHECK: return %[[TD0]] : tensor<?x?xf32>
+ return %4 : tensor<?x?xf32>
+}
*this, "test-tile-and-distribute-options",
llvm::cl::desc("Test tile and distribute options"),
llvm::cl::init(false)};
+ Option<bool> testTileFuseAndDistributionOptions{
+ *this, "test-tile-fuse-and-distribute-options",
+ llvm::cl::desc("Test tile, fuse and distribute options"),
+ llvm::cl::init(false)};
Option<bool> testVectorTransferForwardingPatterns{
*this, "test-vector-transfer-forwarding-patterns",
llvm::cl::desc(
}
}
+static void fillTileFuseAndDistributePatterns(MLIRContext *context,
+ RewritePatternSet &patterns) {
+ LinalgLoopDistributionOptions cyclicNprocsEqNiters;
+ cyclicNprocsEqNiters.distributionMethod.resize(2, DistributionMethod::Cyclic);
+ cyclicNprocsEqNiters.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
+ patterns.add<LinalgTileAndFuseTensorOpsPattern>(
+ MatmulOp::getOperationName(), context,
+ LinalgTilingAndFusionOptions()
+ .setTileSizes({8, 8, 4})
+ .setDistributionOptions(cyclicNprocsEqNiters),
+ LinalgTransformationFilter(
+ StringAttr::get(context, "tensors_fuse_distribute1"),
+ StringAttr::get(context, "tensors_after_fuse_distribute1")));
+}
+
static void
applyMatmulToVectorPatterns(FuncOp funcOp,
bool testMatmulToVectorPatterns1dTiling,
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
return;
}
+ if (testTileFuseAndDistributionOptions) {
+ RewritePatternSet patterns(&getContext());
+ fillTileFuseAndDistributePatterns(&getContext(), patterns);
+ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ return;
+ }
if (testPatterns)
return applyPatterns(getOperation());
if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)