Linalg "tile-and-fuse" is currently exposed as the Linalg pass "-linalg-fusion", but only the mechanics of the transformation are relevant.
Instead, turn it into a "-test-linalg-greedy-fusion" test pass that performs canonicalizations so that more fusions can compose.
This allows dropping the OperationFolder, which is not meant to be used with the pattern rewrite infrastructure.
Differential Revision: https://reviews.llvm.org/D90394
namespace mlir {
std::unique_ptr<OperationPass<FuncOp>> createLinalgFoldUnitExtentDimsPass();
-std::unique_ptr<OperationPass<FuncOp>> createLinalgFusionPass();
std::unique_ptr<Pass> createLinalgFusionOfTensorOpsPass();
std::unique_ptr<Pass> createFoldReshapeOpsByLinearizationPass();
let dependentDialects = ["linalg::LinalgDialect"];
}
-def LinalgFusion : FunctionPass<"linalg-fusion"> {
- let summary = "Fuse operations in the linalg dialect";
- let constructor = "mlir::createLinalgFusionPass()";
- let dependentDialects = ["linalg::LinalgDialect"];
-}
-
def LinalgFusionOfTensorOps : Pass<"linalg-fusion-for-tensor-ops"> {
let summary = "Fuse operations on RankedTensorType in linalg dialect";
let constructor = "mlir::createLinalgFusionOfTensorOpsPass()";
/// transformation and thus requires the `consumerIdx`^th operand of `consumer`
/// to be a `subview` op (generally obtained by applying the tiling
/// transformation).
-/// When non-null, the optional pointer `folder` is used to call into the
-/// `createAndFold` builder method. If `folder` is null, the regular `create`
-/// method is called.
Optional<FusionInfo> fuseProducerOfBuffer(OpBuilder &b, LinalgOp consumer,
unsigned consumerIdx,
- const LinalgDependenceGraph &graph,
- OperationFolder *folder = nullptr);
+ const LinalgDependenceGraph &graph);
/// Tensor counterpart of `fuseProducerOfBuffer`.
/// This implements the fusion part of the "tileAndFuse on tensors"
/// transformation and thus requires the `consumerIdx`^th operand of `consumer`
/// to be the result of a `subtensor` op (generally obtained by applying the
/// tiling transformation).
Optional<FusionInfo> fuseProducerOfTensor(OpBuilder &b, LinalgOp consumer,
- unsigned consumerIdx,
- OperationFolder *folder);
+ unsigned consumerIdx);
/// Fuse linalg operation on tensors, with the producer of the operand at
/// position `consumerIdx` of the consumer.
-Optional<SmallVector<Value, 1>>
-fuseTensorOps(PatternRewriter &rewriter, Operation *consumer,
- unsigned consumerIdx, OperationFolder *folder = nullptr);
+Optional<SmallVector<Value, 1>> fuseTensorOps(PatternRewriter &rewriter,
+ Operation *consumer,
+ unsigned consumerIdx);
/// Returns the linearized list of all shape dimensions in a `linalgOp`.
/// Applying the inverse, concatenated loopToOperandRangeMaps to this list
/// Returns the loop ranges of the `linalgOp`. Applies the inverse of the
/// concatenated indexing maps to the result of `getShape`. Returns None if
/// the bounds computation fails.
-Optional<SmallVector<Value, 4>>
-getLoopRanges(OpBuilder &builder, LinalgOp linalgOp,
- OperationFolder *folder = nullptr);
+Optional<SmallVector<Value, 4>> getLoopRanges(OpBuilder &builder,
+ LinalgOp linalgOp);
/// Returns the values obtained by applying `map` to the list of values.
-/// When non-null, the optional pointer `folder` is used to call into the
-/// `createAndFold` builder method. If `folder` is null, the regular `create`
-/// method is called.
SmallVector<Value, 4> applyMapToValues(OpBuilder &b, Location loc,
- AffineMap map, ValueRange values,
- OperationFolder *folder = nullptr);
+ AffineMap map, ValueRange values);
/// Apply the permutation defined by `permutation` to `inVec`.
/// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
#include "PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
-#include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Support/LLVM.h"
-#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;
-using folded_std_constant_index = FoldedValueBuilder<ConstantIndexOp>;
-
using llvm::dbgs;
/// Implements a simple high-level fusion pass on linalg structured operations.
/// 2. Tensor case: `producerIdx` is the index of the tensor in
/// `producer.getResults()`.
static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
- LinalgOp consumer, unsigned consumerIdx,
- OperationFolder *folder = nullptr) {
+ LinalgOp consumer, unsigned consumerIdx) {
Operation *shapeProducingOp =
consumer.getShapedOperand(consumerIdx).getDefiningOp();
assert((isa<SubViewOp>(shapeProducingOp) ||
<< "existing LoopRange: " << loopRanges[i] << "\n");
else {
auto shapeDim = getShapeDefiningLoopRange(producer, i);
- loopRanges[i] = Range{folded_std_constant_index(folder, 0),
+ loopRanges[i] = Range{std_constant_index(0),
std_dim(shapeDim.shape, shapeDim.dimension),
- folded_std_constant_index(folder, 1)};
+ std_constant_index(1)};
LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
}
}
return {};
}
-Optional<FusionInfo> mlir::linalg::fuseProducerOfBuffer(
- OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
- const LinalgDependenceGraph &graph, OperationFolder *folder) {
+Optional<FusionInfo>
+mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, LinalgOp consumer,
+ unsigned consumerIdx,
+ const LinalgDependenceGraph &graph) {
Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence =
findFusableProducer(consumer, consumerIdx, graph);
if (!fusableDependence)
return {};
LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
+ // If producer is already in the same block as consumer, we are done.
+ if (consumer.getOperation()->getBlock() ==
+ producerOp.getOperation()->getBlock())
+ return {};
+
Value producerView = fusableDependence->dependentOpView.view;
Value consumerView = fusableDependence->indexingView;
assert(producerIdxOpt.hasValue() && "incorrect operand index");
unsigned producerIdx = producerIdxOpt.getValue();
- auto fusedProducer =
- fuse(b, producerOp, producerIdx, consumer, consumerIdx, folder);
+ auto fusedProducer = fuse(b, producerOp, producerIdx, consumer, consumerIdx);
return FusionInfo{producerOp, fusedProducer};
}
}
}
-Optional<FusionInfo>
-mlir::linalg::fuseProducerOfTensor(OpBuilder &b, LinalgOp consumer,
- unsigned consumerIdx,
- OperationFolder *folder) {
+Optional<FusionInfo> mlir::linalg::fuseProducerOfTensor(OpBuilder &b,
+ LinalgOp consumer,
+ unsigned consumerIdx) {
Value inputTensor = consumer.getInput(consumerIdx);
LinalgOp producerOp;
unsigned producerIdx;
return {};
}
+ // If producer is already in the same block as consumer, we are done.
+ if (consumer.getOperation()->getBlock() ==
+ producerOp.getOperation()->getBlock())
+ return {};
+
// Insert fused `producer` just before `consumer`.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(consumer.getOperation());
ScopedContext scope(b, consumer.getLoc());
LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
LinalgOp fusedProducer =
- fuse(b, producerOp, producerIdx, consumer, consumerIdx, folder);
+ fuse(b, producerOp, producerIdx, consumer, consumerIdx);
// Replace use.
// Canonicalizations are not guaranteed to have happened before constructing
}
return llvm::None;
}
-
-static void fuseLinalgOpsGreedily(FuncOp f) {
- LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n"));
-
- OpBuilder b(f);
- OperationFolder folder(f.getContext());
- DenseSet<Operation *> eraseSet;
-
- // Save original Linalg ops, we only want to make a pass over those.
- SmallVector<Operation *, 8> linalgOps;
- f.walk([&](LinalgOp op) {
- // TODO: support multi-results.
- if (op.getOperation()->getNumResults() <= 1)
- linalgOps.push_back(op);
- });
-
- // Tile and Fuse for tensors inputs (TODO: all tensor operands).
- for (auto *op : llvm::reverse(linalgOps)) {
- LinalgOp linalgOp = cast<LinalgOp>(op);
- for (auto en : llvm::enumerate(linalgOp.getShapedOperands())) {
- if (en.value().getType().isa<MemRefType>()) {
- // TODO: LinalgDependenceGraph should be able to update itself.
- // The current naive and expensive reconstruction of the graph should be
- // removed.
- linalg::Aliases aliases;
- linalg::LinalgDependenceGraph graph(aliases, linalgOps);
- if (auto info =
- fuseProducerOfBuffer(b, op, en.index(), graph, &folder)) {
- auto *originalOp = info->originalProducer.getOperation();
- eraseSet.insert(originalOp);
- auto *originalOpInLinalgOpsVector =
- std::find(linalgOps.begin(), linalgOps.end(), originalOp);
- *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
- }
- } else {
- assert(en.value().getType().isa<RankedTensorType>());
- // Tile and Fuse tensor input (TODO: init_tensors too).
- if (en.index() >= linalgOp.getNumInputs())
- continue;
- if (auto info = fuseProducerOfTensor(b, op, en.index(), &folder)) {
- auto *originalOp = info->originalProducer.getOperation();
- auto *originalOpInLinalgOpsVector =
- std::find(linalgOps.begin(), linalgOps.end(), originalOp);
- *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
- // Don't mark for erasure in the tensor case, let DCE handle this.
- }
- }
- }
- }
- // The `fuseProducerOfBuffer` function performs structural checks and in
- // particular that no covering read or write exist between the consumer and
- // the producer. As a consequence, the only fusions that may occur preserve
- // subsequent dependences and are guaranteed by construction to produce the
- // whole view. We may thus erase the producer once it is fused.
- for (auto *e : eraseSet)
- e->erase();
-
- LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n"));
-}
-
-namespace {
-struct LinalgFusionPass : public LinalgFusionBase<LinalgFusionPass> {
- void runOnFunction() override { fuseLinalgOpsGreedily(getFunction()); }
-};
-} // namespace
-
-std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgFusionPass() {
- return std::make_unique<LinalgFusionPass>();
-}
static Optional<SmallVector<Value, 1>>
fuseTensorOpsImpl(LinalgOp producer, LinalgOp consumer, unsigned consumerIdx,
- PatternRewriter &rewriter,
- OperationFolder *folder = nullptr) {
+ PatternRewriter &rewriter) {
if (!areTensorOpsFusable(producer, consumer, consumerIdx))
return llvm::None;
/// conditions have been satisfied.
static Optional<SmallVector<Value, 1>>
fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
- unsigned fusedTensorIndex, PatternRewriter &rewriter,
- OperationFolder *folder = nullptr) {
+ unsigned fusedTensorIndex,
+ PatternRewriter &rewriter) {
assert(isFusableWithReshapeByDimExpansion(linalgOp, fusedTensorIndex) &&
"preconditions for fuse operation failed");
// Check if reshape is expanding or collapsing.
Optional<SmallVector<Value, 1>>
mlir::linalg::fuseTensorOps(PatternRewriter &rewriter, Operation *consumer,
- unsigned consumerIdx, OperationFolder *folder) {
+ unsigned consumerIdx) {
if (consumerIdx >= consumer->getNumOperands())
return llvm::None;
Operation *producer = consumer->getOperand(consumerIdx).getDefiningOp();
return llvm::None;
return fuseTensorOpsImpl(cast<LinalgOp>(producer), cast<LinalgOp>(consumer),
- consumerIdx, rewriter, folder);
+ consumerIdx, rewriter);
}
namespace {
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/FoldUtils.h"
using namespace mlir;
using namespace mlir::linalg;
return llvm::None;
}
-static Value emitOrFoldComposedAffineApply(OpBuilder &b, Location loc,
- AffineMap map,
- ValueRange operandsRef,
- OperationFolder *folder) {
+static Value createFoldedComposedAffineApply(OpBuilder &b, Location loc,
+ AffineMap map,
+ ValueRange operandsRef) {
SmallVector<Value, 4> operands(operandsRef.begin(), operandsRef.end());
fullyComposeAffineMapAndOperands(&map, &operands);
canonicalizeMapAndOperands(&map, &operands);
- return folder ? folder->create<AffineApplyOp>(b, loc, map, operands)
- : b.create<AffineApplyOp>(loc, map, operands);
+ return b.createOrFold<AffineApplyOp>(loc, map, operands);
}
SmallVector<Value, 4> mlir::linalg::applyMapToValues(OpBuilder &b, Location loc,
AffineMap map,
- ValueRange values,
- OperationFolder *folder) {
+ ValueRange values) {
SmallVector<Value, 4> res;
res.reserve(map.getNumResults());
unsigned numDims = map.getNumDims(), numSym = map.getNumSymbols();
// For each `expr` in `map`, applies the `expr` to the values extracted from
// ranges. If the resulting application can be folded into a Value, the
- // folding occurs eagerly. Otherwise, an affine.apply operation is emitted.
+ // folding occurs eagerly.
for (auto expr : map.getResults()) {
AffineMap map = AffineMap::get(numDims, numSym, expr);
- res.push_back(emitOrFoldComposedAffineApply(b, loc, map, values, folder));
+ res.push_back(createFoldedComposedAffineApply(b, loc, map, values));
}
return res;
}
return res;
}
-Optional<SmallVector<Value, 4>>
-getLoopRanges(OpBuilder &builder, LinalgOp linalgOp, OperationFolder *folder) {
+Optional<SmallVector<Value, 4>> getLoopRanges(OpBuilder &builder,
+ LinalgOp linalgOp) {
SmallVector<Value, 8> viewSizes = getShape(builder, linalgOp);
AffineMap invertedMap =
inversePermutation(concatAffineMaps(linalgOp.getIndexingMaps()));
if (!invertedMap)
return {};
- return applyMapToValues(builder, linalgOp.getLoc(), invertedMap, viewSizes,
- folder);
+ return applyMapToValues(builder, linalgOp.getLoc(), invertedMap, viewSizes);
}
/// Specialization to build an scf "for" nest.
-// RUN: mlir-opt %s -linalg-fusion | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-greedy-fusion | FileCheck %s
func @f1(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>, %B: memref<?x?xf32, offset: ?, strides: [?, 1]>, %C: memref<?x?xf32, offset: ?, strides: [?, 1]>, %D: memref<?x?xf32, offset: ?, strides: [?, 1]>, %E: memref<?x?xf32, offset: ?, strides: [?, 1]>) -> memref<?x?xf32, offset: ?, strides: [?, 1]> {
%c1 = constant 1 : index
-// RUN: mlir-opt %s -linalg-fusion -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-greedy-fusion -split-input-file | FileCheck %s
#map = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
#id_2d = affine_map<(d0, d1) -> (d0, d1)>
^bb0(%i: index, %j: index, %a: f32, %b: f32, %c: f32): // no predecessors
%i_int = index_cast %i: index to i32
%i_float = sitofp %i_int : i32 to f32
+ %j_int = index_cast %j: index to i32
+ %j_float = sitofp %j_int : i32 to f32
%ab = addf %a, %b : f32
- %out = addf %ab, %i_float : f32
+ %tmp = addf %ab, %i_float : f32
+ %out = addf %tmp, %j_float : f32
linalg.yield %out : f32
}
%C_X = dim %C, %c0 : memref<?x?xf32>
// CHECK: [[i_new:%.*]] = addi [[i]], [[I]] : index
// CHECK: [[j_new:%.*]] = addi [[j]], [[J]] : index
// CHECK: {{.*}} = index_cast [[i_new]] : index to i32
+// CHECK: {{.*}} = index_cast [[j_new]] : index to i32
// CHECK: linalg.generic
// CHECK: addf
ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
outs(%C : memref<?x?xf32>) {
^bb0(%i: index, %j: index, %a: f32, %b: f32, %c: f32): // no predecessors
+ %i_int = index_cast %i: index to i32
+ %i_float = sitofp %i_int : i32 to f32
%j_int = index_cast %j: index to i32
%j_float = sitofp %j_int : i32 to f32
%ab = addf %a, %b : f32
- %out = addf %ab, %j_float : f32
+ %tmp = addf %ab, %i_float : f32
+ %out = addf %tmp, %j_float : f32
linalg.yield %out : f32
}
%C_X = dim %C, %c0 : memref<?x?xf32>
// CHECK-NOT: scf.parallel
// CHECK: linalg.indexed_generic
// CHECK: ^bb0([[i:%.*]]: index, [[j:%.*]]: index
-// CHECK: [[i_new:%.*]] = addi [[i]], [[C0]] : index
// CHECK: [[j_new:%.*]] = addi [[j]], [[J]] : index
+// CHECK: {{.*}} = index_cast [[i]] : index to i32
// CHECK: {{.*}} = index_cast [[j_new]] : index to i32
// CHECK: linalg.generic
// CHECK: addf
-// RUN: mlir-opt %s -linalg-fusion -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-greedy-fusion -split-input-file | FileCheck %s
func @f1(%A: memref<?x?xf32, offset: 0, strides: [?, 1]>,
%B: memref<?x?xf32, offset: 0, strides: [?, 1]>,
// -----
+// CHECK-DAG: #[[$strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>
+
func @f3(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
%B: memref<?x?xf32, offset: 0, strides: [?, ?]>,
%C: memref<?x?xf32, offset: 0, strides: [?, ?]>,
}
// CHECK-LABEL: func @f3
// CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
-// CHECK: %[[D_0:.*]] = dim %[[D]], %c0{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
-// CHECK: %[[D_1:.*]] = dim %[[D]], %c1{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
-// CHECK: %[[C_1:.*]] = dim %[[C]], %c1{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[D_0:.*]] = dim %[[D]], %[[C0]] : memref<?x?xf32, #[[$strided2D]]>
+// CHECK: %[[D_1:.*]] = dim %[[D]], %[[C1]] : memref<?x?xf32, #[[$strided2D]]>
+// CHECK: %[[C_1:.*]] = dim %[[C]], %[[C1]] : memref<?x?xf32, #[[$strided2D]]>
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[D_0]] step %{{.*}} {
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} {
// -----
+// CHECK-DAG: #[[$strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>
+
func @f4(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
%B: memref<?x?xf32, offset: 0, strides: [?, ?]>,
%C: memref<?x?xf32, offset: 0, strides: [?, ?]>,
}
// CHECK-LABEL: func @f4
// CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
-// CHECK: %[[C_0:.*]] = dim %[[C]], %c0{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
-// CHECK: %[[C_1:.*]] = dim %[[C]], %c1{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
-// CHECK: %[[D_1:.*]] = dim %[[D]], %c1{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[C_0:.*]] = dim %[[C]], %[[C0]] : memref<?x?xf32, #[[$strided2D]]>
+// CHECK: %[[C_1:.*]] = dim %[[C]], %[[C1]] : memref<?x?xf32, #[[$strided2D]]>
+// CHECK: %[[D_1:.*]] = dim %[[D]], %[[C1]] : memref<?x?xf32, #[[$strided2D]]>
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} {
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} {
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {
}
// CHECK-LABEL: func @f5
// CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
-// CHECK-DAG: %[[B_1:.*]] = dim %[[B]], %c1{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
-// CHECK-DAG: %[[D_0:.*]] = dim %[[D]], %c0{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
-// CHECK-DAG: %[[D_1:.*]] = dim %[[D]], %c1{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
-// CHECK: scf.for %[[I:.*]] = %{{.*}} to %[[D_0]] step %{{.*}} {
-// CHECK: scf.for %[[J:.*]] = %{{.*}} to %[[B_1]] step %{{.*}} {
-// CHECK: scf.for %[[K:.*]] = %{{.*}} to %[[D_1]] step %{{.*}} {
-// CHECK-DAG: %[[D_IK:.*]] = subview %[[D]][%[[I]], %[[K]]]
-// CHECK-DAG: %[[B_KJ:.*]] = subview %[[B]][%[[K]], %[[J]]]
-// CHECK-DAG: %[[E_IJ:.*]] = subview %[[E]][%[[I]], %[[J]]]
-// CHECK: dim
-// CHECK-DAG: %[[C_I0:.*]] = subview %[[C]][%[[I]], %{{.*}}]
-// CHECK-DAG: %[[B_0K:.*]] = subview %[[B]][%{{.*}}, %[[K]]]
-// CHECK-DAG: %[[D_IK_:.*]] = subview %[[D]][%[[I]], %[[K]]]
-// CHECK: dim
-// CHECK-DAG: %[[A_I0:.*]] = subview %[[A]][%[[I]], %{{.*}}]
-// CHECK-DAG: %[[B_00:.*]] = subview %[[B]][%{{.*}}, %{{.*}}]
-// CHECK-DAG: %[[C_I0_:.*]] = subview %[[C]][%[[I]], %{{.*}}]
-// CHECK: linalg.matmul ins(%[[A_I0]], %[[B_00]]{{.*}} outs(%[[C_I0_]]
-// CHECK: linalg.matmul ins(%[[C_I0]], %[[B_0K]]{{.*}} outs(%[[D_IK_]]
-// CHECK: linalg.matmul ins(%[[D_IK]], %[[B_KJ]]{{.*}} outs(%[[E_IJ]]
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = constant 1 : index
+// CHECK-DAG: %[[B_1:.*]] = dim %[[B]], %[[C1]] : memref<?x?xf32, #[[$strided2D]]>
+// CHECK-DAG: %[[D_0:.*]] = dim %[[D]], %[[C0]] : memref<?x?xf32, #[[$strided2D]]>
+// CHECK-DAG: %[[D_1:.*]] = dim %[[D]], %[[C1]] : memref<?x?xf32, #[[$strided2D]]>
+// CHECK-DAG: %[[B_00:.*]] = subview %[[B]][0, 0]{{.*}}
+// CHECK: scf.for %[[I:.*]] = %{{.*}} to %[[D_0]] step %{{.*}} {
+// CHECK-DAG: %[[A_I0:.*]] = subview %[[A]][%[[I]], 0]
+// CHECK-DAG: %[[C_I0:.*]] = subview %[[C]][%[[I]], 0]
+// CHECK: scf.for %[[J:.*]] = %{{.*}} to %[[B_1]] step %{{.*}} {
+// CHECK: %[[E_IJ:.*]] = subview %[[E]][%[[I]], %[[J]]]
+// CHECK: scf.for %[[K:.*]] = %{{.*}} to %[[D_1]] step %{{.*}} {
+// CHECK-DAG: %[[D_IK:.*]] = subview %[[D]][%[[I]], %[[K]]]
+// CHECK-DAG: %[[B_0K:.*]] = subview %[[B]][0, %[[K]]]
+// CHECK-DAG: %[[B_KJ:.*]] = subview %[[B]][%[[K]], %[[J]]]
+// CHECK: linalg.matmul ins(%[[A_I0]], %[[B_00]]{{.*}} outs(%[[C_I0]]
+// CHECK: linalg.matmul ins(%[[C_I0]], %[[B_0K]]{{.*}} outs(%[[D_IK]]
+// CHECK: linalg.matmul ins(%[[D_IK]], %[[B_KJ]]{{.*}} outs(%[[E_IJ]]
// -----
}
// CHECK-LABEL: func @f7
// CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
-// CHECK: %[[A_0:.*]] = dim %[[A]], %c0{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
-// CHECK: %[[A_1:.*]] = dim %[[A]], %c1{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
-// CHECK: %[[C_1:.*]] = dim %[[C]], %c1{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
-// CHECK: %[[C_0:.*]] = dim %[[C]], %c0{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
-// CHECK: %[[D_1:.*]] = dim %[[D]], %c1{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[A_0:.*]] = dim %[[A]], %[[C0]] : memref<?x?xf32, #[[$strided2D]]>
+// CHECK: %[[A_1:.*]] = dim %[[A]], %[[C1]] : memref<?x?xf32, #[[$strided2D]]>
+// CHECK: %[[C_1:.*]] = dim %[[C]], %[[C1]] : memref<?x?xf32, #[[$strided2D]]>
+// CHECK: %[[C_0:.*]] = dim %[[C]], %[[C0]] : memref<?x?xf32, #[[$strided2D]]>
+// CHECK: %[[D_1:.*]] = dim %[[D]], %[[C1]] : memref<?x?xf32, #[[$strided2D]]>
// CHECK: linalg.matmul ins(%[[A]], %[[C]]{{.*}} outs(%[[E]]
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[A_0]] step %{{.*}} {
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {
-// RUN: mlir-opt %s -linalg-fusion -split-input-file | FileCheck %s
-// RUN: mlir-opt %s -linalg-fusion -canonicalize -cse -split-input-file | FileCheck %s --check-prefix=CANONICALIZED
+// RUN: mlir-opt %s -test-linalg-greedy-fusion -split-input-file | FileCheck %s
#map0 = affine_map<(d0)[s0] -> (2, -d0 + s0)>
#map1 = affine_map<(d0)[s0] -> (4, -d0 + s0)>
// CHECK-SAME: %[[A:[0-9a-z]*]]: tensor<?x?xf32>
// CHECK-SAME: %[[B:[0-9a-z]*]]: tensor<?x?xf32>
// CHECK-SAME: %[[C:[0-9a-z]*]]: tensor<?x?xf32>
-// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = constant 1 : index
+// CHECK-DAG: %[[dA1:.*]] = dim %[[A]], %[[C1]] : tensor<?x?xf32>
// CHECK: scf.for %[[I:[0-9a-z]*]]
+// CHECK: %[[stA:.*]] = subtensor %[[A]][%[[I]], 0] [2, %[[dA1]]] [1, 1] : tensor<?x?xf32> to tensor<2x?xf32>
// CHECK-NEXT: scf.for %[[J:[0-9a-z]*]]
-// CHECK-NEXT: scf.for %[[K:[0-9a-z]*]]
-//
-// subtensor of the original program, first one refers to the unfused matmul and becomes a dead SSA value.
-// CHECK: subtensor %{{.*}}[%[[I]], %[[K]]] {{.*}} : tensor<?x?xf32> to tensor<?x4xf32>
-// CHECK: %[[stB1:.*]] = subtensor %[[B]][%[[K]], %[[J]]] {{.*}} : tensor<?x?xf32> to tensor<4x?xf32>
-// CHECK: %[[stF:.*]] = subtensor %{{.*}}[%[[I]], %[[J]]] {{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
-//
-// subtensors of the producing matmul.
-// CHECK: %[[stA:.*]] = subtensor %[[A]][%[[I]], %[[C0]]] {{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
-// CHECK-NEXT: %[[stB2:.*]] = subtensor %[[B]][%[[C0]], %[[K]]] {{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
-// CHECK-NEXT: %[[stC:.*]] = subtensor %[[C]][%[[I]], %[[K]]] {{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
-// CHECK-NEXT: %[[stD:.*]] = linalg.matmul ins(%[[stA]], %[[stB2]] : tensor<?x?xf32>, tensor<?x?xf32>) init(%[[stC]] : tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK-NEXT: %[[stD2:.*]] = tensor_cast %[[stD]] : tensor<?x?xf32> to tensor<?x4xf32>
-// CHECK-NEXT: %[[stG:.*]] = linalg.matmul ins(%[[stD2]], %[[stB1]] : tensor<?x4xf32>, tensor<4x?xf32>) init(%[[stF]] : tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK-NEXT: subtensor_insert %[[stG]]
-
-
-// CANONICALIZED-LABEL: func @matmul_tensors(
-// CANONICALIZED-SAME: %[[A:[0-9a-z]*]]: tensor<?x?xf32>
-// CANONICALIZED-SAME: %[[B:[0-9a-z]*]]: tensor<?x?xf32>
-// CANONICALIZED-SAME: %[[C:[0-9a-z]*]]: tensor<?x?xf32>
-// CANONICALIZED: %[[C0:.*]] = constant 0 : index
-// CANONICALIZED: %[[C1:.*]] = constant 1 : index
-// CANONICALIZED: scf.for %[[I:[0-9a-z]*]]
-// CANONICALIZED-NEXT: scf.for %[[J:[0-9a-z]*]]
-// CANONICALIZED-NEXT: scf.for %[[K:[0-9a-z]*]]
-//
-// CANONICALIZED: %[[stB1:.*]] = subtensor %[[B]][%[[K]], %[[J]]] [4, 3] [1, 1] : tensor<?x?xf32> to tensor<4x3xf32>
-// CANONICALIZED: %[[stF:.*]] = subtensor %{{.*}}[%[[I]], %[[J]]] [2, 3] [1, 1] : tensor<?x?xf32> to tensor<2x3xf32>
+// CHECK-NEXT: scf.for %[[K:[0-9a-z]*]] {{.*}} iter_args(%[[RES:[0-9a-z]*]]
+// CHECK-DAG: %[[stB1:.*]] = subtensor %[[B]][%[[K]], %[[J]]] [4, 3] [1, 1] : tensor<?x?xf32> to tensor<4x3xf32>
+// CHECK-DAG: %[[stF:.*]] = subtensor %[[RES]][%[[I]], %[[J]]] [2, 3] [1, 1] : tensor<?x?xf32> to tensor<2x3xf32>
//
// subtensors of the producing matmul.
-// CANONICALIZED: %[[dA1:.*]] = dim %[[A]], %[[C1]] : tensor<?x?xf32>
-// CANONICALIZED: %[[stA:.*]] = subtensor %[[A]][%[[I]], 0] [2, %[[dA1]]] [1, 1] : tensor<?x?xf32> to tensor<2x?xf32>
-// CANONICALIZED-NEXT: %[[stB2:.*]] = subtensor %[[B]][0, %[[K]]] [%[[dA1]], 4] [1, 1] : tensor<?x?xf32> to tensor<?x4xf32>
-// CANONICALIZED-NEXT: %[[stC:.*]] = subtensor %[[C]][%[[I]], %[[K]]] [2, 4] [1, 1] : tensor<?x?xf32> to tensor<2x4xf32>
-// CANONICALIZED-NEXT: %[[stD:.*]] = linalg.matmul ins(%[[stA]], %[[stB2]] : tensor<2x?xf32>, tensor<?x4xf32>) init(%[[stC]] : tensor<2x4xf32>) -> tensor<2x4xf32>
-// CANONICALIZED-NEXT: %[[stG:.*]] = linalg.matmul ins(%[[stD]], %[[stB1]] : tensor<2x4xf32>, tensor<4x3xf32>) init(%[[stF]] : tensor<2x3xf32>) -> tensor<2x3xf32>
-// CANONICALIZED-NEXT: subtensor_insert %[[stG]]
+// CHECK-DAG: %[[stB2:.*]] = subtensor %[[B]][0, %[[K]]] [%[[dA1]], 4] [1, 1] : tensor<?x?xf32> to tensor<?x4xf32>
+// CHECK-DAG: %[[stC:.*]] = subtensor %[[C]][%[[I]], %[[K]]] [2, 4] [1, 1] : tensor<?x?xf32> to tensor<2x4xf32>
+// CHECK: %[[stD:.*]] = linalg.matmul ins(%[[stA]], %[[stB2]] : tensor<2x?xf32>, tensor<?x4xf32>) init(%[[stC]] : tensor<2x4xf32>) -> tensor<2x4xf32>
+// CHECK-NEXT: %[[stG:.*]] = linalg.matmul ins(%[[stD]], %[[stB1]] : tensor<2x4xf32>, tensor<4x3xf32>) init(%[[stF]] : tensor<2x3xf32>) -> tensor<2x3xf32>
+// CHECK-NEXT: subtensor_insert %[[stG]] into %[[RES]][%[[I]], %[[J]]]
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
using namespace mlir;
using namespace mlir::linalg;
applyFusionPatterns(&getContext(), getFunction());
}
+/// Greedily fuses producers into the Linalg ops of `f`.
+///
+/// The Linalg ops with at most one result are collected up front and visited
+/// in reverse order. For every shaped operand of a visited op:
+///   - a memref operand is handed to `fuseProducerOfBuffer`; on success the
+///     original producer is recorded and erased after the traversal.
+///   - a tensor *input* operand is handed to `fuseProducerOfTensor`; on success
+///     the original producer is left in place for DCE.
+///
+/// Returns success() if at least one fusion happened and failure() otherwise,
+/// so that a caller can iterate until no more fusion applies.
+static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
+  OpBuilder b(f);
+  DenseSet<Operation *> eraseSet;
+
+  // Save original Linalg ops, we only want to make a pass over those.
+  SmallVector<Operation *, 8> linalgOps;
+  f.walk([&](LinalgOp op) {
+    // TODO: support multi-results.
+    if (op.getOperation()->getNumResults() <= 1)
+      linalgOps.push_back(op);
+  });
+
+  // Points the worklist entry of a fused producer at its fused clone. The
+  // original producer may have no entry: multi-result ops are never collected
+  // above, and an entry that an earlier fusion already redirected no longer
+  // holds the original op. In that case there is nothing to update, and writing
+  // through `end()` would be out of bounds.
+  auto redirectWorklistEntry = [&](Operation *originalOp,
+                                   Operation *fusedOp) {
+    auto *entry = std::find(linalgOps.begin(), linalgOps.end(), originalOp);
+    if (entry != linalgOps.end())
+      *entry = fusedOp;
+  };
+
+  // Tile and Fuse for tensors inputs (TODO: all tensor operands).
+  bool changed = false;
+  for (auto *op : llvm::reverse(linalgOps)) {
+    LinalgOp linalgOp = cast<LinalgOp>(op);
+    for (auto en : llvm::enumerate(linalgOp.getShapedOperands())) {
+      if (en.value().getType().isa<MemRefType>()) {
+        // TODO: LinalgDependenceGraph should be able to update itself.
+        // The current naive and expensive reconstruction of the graph should be
+        // removed.
+        linalg::Aliases aliases;
+        linalg::LinalgDependenceGraph graph(aliases, linalgOps);
+        if (auto info = fuseProducerOfBuffer(b, op, en.index(), graph)) {
+          auto *originalOp = info->originalProducer.getOperation();
+          eraseSet.insert(originalOp);
+          redirectWorklistEntry(originalOp,
+                                info->fusedProducer.getOperation());
+          changed = true;
+        }
+      } else {
+        assert(en.value().getType().isa<RankedTensorType>());
+        // Tile and Fuse tensor input (TODO: init_tensors too).
+        if (en.index() >= linalgOp.getNumInputs())
+          continue;
+        if (auto info = fuseProducerOfTensor(b, op, en.index())) {
+          redirectWorklistEntry(info->originalProducer.getOperation(),
+                                info->fusedProducer.getOperation());
+          // Don't mark for erasure in the tensor case, let DCE handle this.
+          changed = true;
+        }
+      }
+    }
+  }
+  // The `fuseProducerOfBuffer` function performs structural checks and in
+  // particular that no covering read or write exist between the consumer and
+  // the producer. As a consequence, the only fusions that may occur preserve
+  // subsequent dependences and are guaranteed by construction to produce the
+  // whole view. We may thus erase the producer once it is fused.
+  for (auto *e : eraseSet)
+    e->erase();
+
+  return changed ? success() : failure();
+}
+
+namespace {
+/// Test pass that alternates greedy Linalg fusion with cleanup until
+/// `fuseLinalgOpsGreedily` reports that nothing was fused.
+struct TestLinalgGreedyFusion
+    : public PassWrapper<TestLinalgGreedyFusion, FunctionPass> {
+  void runOnFunction() override {
+    MLIRContext *context = &getContext();
+    OwningRewritePatternList patterns =
+        linalg::getLinalgTilingCanonicalizationPatterns(context);
+    patterns.insert<AffineMinSCFCanonicalizationPattern>(context);
+    FrozenRewritePatternList frozenPatterns(std::move(patterns));
+    // Every successful fusion round is followed by the tiling canonicalization
+    // patterns and a LICM + canonicalize + CSE pipeline before fusion is
+    // attempted again.
+    while (succeeded(fuseLinalgOpsGreedily(getFunction()))) {
+      applyPatternsAndFoldGreedily(getFunction(), frozenPatterns);
+      PassManager pm(context);
+      pm.addPass(createLoopInvariantCodeMotionPass());
+      pm.addPass(createCanonicalizerPass());
+      pm.addPass(createCSEPass());
+      // NOTE(review): the cleanup pipeline runs on the parent module, not only
+      // on the current function.
+      LogicalResult res = pm.run(getFunction().getParentOfType<ModuleOp>());
+      if (failed(res)) {
+        // Do not keep fusing on IR for which the cleanup pipeline failed.
+        this->signalPassFailure();
+        return;
+      }
+    }
+  }
+};
+} // namespace
+
namespace mlir {
void registerTestLinalgFusionTransforms() {
PassRegistration<TestLinalgFusionTransforms> testFusionTransformsPass(
"test-linalg-fusion-transform-patterns",
"Test Linalg fusion transformation patterns by applying them greedily.");
}
+/// Registers `TestLinalgGreedyFusion` under the `test-linalg-greedy-fusion`
+/// command-line argument.
+void registerTestLinalgGreedyFusion() {
+  PassRegistration<TestLinalgGreedyFusion> greedyFusionPass(
+      "test-linalg-greedy-fusion",
+      "Test Linalg fusion by applying a greedy test transformation.");
+}
} // namespace mlir
void registerTestInterfaces();
void registerTestLinalgCodegenStrategy();
void registerTestLinalgFusionTransforms();
+void registerTestLinalgGreedyFusion();
void registerTestLinalgHoisting();
void registerTestLinalgTransforms();
void registerTestLivenessPass();
registerTestInterfaces();
registerTestLinalgCodegenStrategy();
registerTestLinalgFusionTransforms();
+ registerTestLinalgGreedyFusion();
registerTestLinalgHoisting();
registerTestLinalgTransforms();
registerTestLivenessPass();