// loopToOperandRangesMaps are permutations-only.
unsigned loopPos = en2.value().cast<AffineDimExpr>().getPosition();
viewRanges[d] = loopRanges[loopPos];
- LLVM_DEBUG(dbgs() << "i,j: " << en.index() << ", " << en2.index() << "\t"
+ LLVM_DEBUG(dbgs() << "\ni,j: " << en.index() << ", " << en2.index()
+ << "\t"
<< "loopPos: " << loopPos << "\t" << viewRanges[d]);
}
// TODO(ntv) opportunities for folding/CSE here rather than build new IR.
for (auto en : llvm::enumerate(ios)) {
unsigned idx = en.index();
auto map = maps[idx];
- LLVM_DEBUG(dbgs() << "map: " << map << "\n");
+ LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange I/O idx: " << idx << "\n");
+ LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange map: " << map << "\n");
Value *view = en.value();
SmallVector<Value *, 8> viewRanges(map.getNumResults(), nullptr);
for (auto en2 : llvm::enumerate(map.getResults())) {
- if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition())
+ if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
+ LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth
+ << "\n");
+ LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << *view
+ << "\n");
return ViewDimension{view, static_cast<unsigned>(en2.index())};
+ }
}
}
llvm_unreachable("Expect to be able to extract a view defining loop range");
return llvm::None;
unsigned producerIdx = maybeProducerIdx.getValue();
- auto sliceOp = dyn_cast_or_null<SubViewOp>(
+ // If the view is the same between consumer and tiledConsumer, this means we
+ // don't have loops and the producer cannot be fused at this level.
+ if (consumer.getInput(consumerIdx) == tiledConsumer.getInput(consumerIdx))
+ return llvm::None;
+
+ auto tiledConsumerSubView = dyn_cast_or_null<SubViewOp>(
tiledConsumer.getInput(consumerIdx)->getDefiningOp());
+
// If we don't have a slice, this also means we don't have loops and the
// producer cannot be fused at this level.
- if (!sliceOp)
+ if (!tiledConsumerSubView)
return llvm::None;
+ // loopToOperandRangesMaps are permutations-only by construction:
+ // we can always identify a data dimension with (at least) one loop
+ // dimension.
AffineMap producerMap =
loopToOperandRangesMaps(producer)[producer.getNumInputs() + producerIdx];
- LLVM_DEBUG(dbgs() << "Consumer Idx: " << consumerIdx << "\tmap: "
+ LLVM_DEBUG(dbgs() << "Consumer Idx: " << consumerIdx << ", consumer map: "
<< loopToOperandRangesMaps(consumer)[consumerIdx] << "\n");
LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx
- << "\tmap: " << producerMap << "\n");
+ << ", producer map: " << producerMap << "\n");
unsigned nPar = producer.getNumParallelLoops();
unsigned nRed = producer.getNumReductionLoops();
unsigned nWin = producer.getNumWindowLoops();
SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin);
- DenseSet<unsigned> fromSlice;
+
+ // Iterate over dimensions identified by the producer map for `producerIdx`.
+ // This defines a subset of the loop ranges that we need to complete later.
for (auto en : llvm::enumerate(producerMap.getResults())) {
- // loopToOperandRangesMaps are permutations-only.
unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
- loopRanges[posInProducerLoop] = sliceOp.getRange(en.index());
- fromSlice.insert(posInProducerLoop);
+ loopRanges[posInProducerLoop] = tiledConsumerSubView.getRange(en.index());
}
OpBuilder b(tiledConsumer.getOperation());
auto loc = tiledConsumer.getLoc();
- for (unsigned i = 0; i < loopRanges.size(); ++i) {
- if (fromSlice.count(i))
- LLVM_DEBUG(llvm::dbgs() << "LR: " << loopRanges[i] << "\n");
+ // Iterate over all dimensions. For the dimensions not identified by the
+ // producer map for `producerIdx`, we need to explicitly compute the view that
+ // defines the loop ranges using the `producer`.
+ for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) {
+ if (loopRanges[i].min)
+ LLVM_DEBUG(llvm::dbgs()
+ << "existing LoopRange: " << loopRanges[i] << "\n");
else {
auto viewDim = getViewDefiningLoopRange(producer, i);
loopRanges[i] = SubViewOp::Range{
state.create<ConstantIndexOp>(b, loc, 0),
linalg::intrinsics::dim(viewDim.view, viewDim.dimension),
state.create<ConstantIndexOp>(b, loc, 1)};
- LLVM_DEBUG(llvm::dbgs() << "new LR: " << loopRanges[i] << "\n");
+ LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
}
}
OperationFolder state;
DenseSet<Operation *> eraseSet;
+ LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n"));
+
// 1. Record the linalg ops so we can traverse them in reverse order.
SmallVector<Operation *, 8> linalgOps;
f.walk<LinalgOp>(
consumer, LinalgDependenceGraph::DependenceType::RAW)) {
auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
LLVM_DEBUG(dbgs() << "\n***Consider producer:\t"
- << *producer.getOperation());
+ << *producer.getOperation() << "\n");
// a. For now we require fusion on identical SSA values, this allows us to
// not worry about partial writes etc.
continue;
// 6. Try to fuse `producer` just before `tiledOp`.
+ LLVM_DEBUG(f.print(dbgs() << "\nBefore tiledOp-fusion: \n"));
+
auto tOp = tiledOp->op;
OpBuilder builder(tOp.getOperation());
ScopedContext scope(builder, tOp.getLoc());
+ LLVM_DEBUG(dbgs() << "Try fuse into tiled consumer: " << *tOp << "\n");
auto maybeFusedProducer = fuse(view, producer, op, tOp, state);
if (!maybeFusedProducer) {
LLVM_DEBUG(dbgs() << "\nFusion did not do anything, skip.");
namespace {
struct LinalgFusionPass : public FunctionPass<LinalgFusionPass> {
- LinalgFusionPass();
+ LinalgFusionPass() = default;
LinalgFusionPass(ArrayRef<int64_t> sizes);
void runOnFunction() { fuseLinalgOps(getFunction(), tileSizes); }
};
} // namespace
-LinalgFusionPass::LinalgFusionPass()
- : tileSizes(clTileSizes.begin(), clTileSizes.end()) {}
-
LinalgFusionPass::LinalgFusionPass(ArrayRef<int64_t> sizes)
: LinalgFusionPass() {
if (!sizes.empty())
}
static PassRegistration<LinalgFusionPass>
- pass("linalg-fusion", "Fuse operations in the linalg dialect");
+ pass("linalg-fusion", "Fuse operations in the linalg dialect", [] {
+ auto *pass = new LinalgFusionPass();
+ pass->tileSizes.assign(clTileSizes.begin(), clTileSizes.end());
+ return pass;
+ });