#ifndef MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_
#define MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpDefinition.h"
// Builds a linalg dependence graph for the ops of type LinalgOp under `f`.
static LinalgDependenceGraph buildDependenceGraph(Aliases &aliases, FuncOp f);
- LinalgDependenceGraph(Aliases &aliases, ArrayRef<Operation *> ops);
+ LinalgDependenceGraph(Aliases &aliases, ArrayRef<LinalgOp> ops);
/// Returns the X such that op -> X is a dependence of type dt.
dependence_range getDependencesFrom(Operation *src, DependenceType dt) const;
ArrayRef<DependenceType> types) const;
Aliases &aliases;
- SmallVector<Operation *, 8> linalgOps;
+ SmallVector<LinalgOp, 8> linalgOps;
DenseMap<Operation *, unsigned> linalgOpPositions;
};
} // namespace linalg
/// Builds a dependence graph over all LinalgOps nested under `f`.
///
/// `aliases` supplies the view-aliasing analysis used when computing
/// dependences; `f` is walked in program order so the resulting op list is
/// position-stable.
LinalgDependenceGraph
LinalgDependenceGraph::buildDependenceGraph(Aliases &aliases, FuncOp f) {
  // Collect ops as LinalgOp directly; the graph only ever deals with
  // LinalgOps, so there is no reason to erase to Operation * here.
  SmallVector<LinalgOp, 8> linalgOps;
  f.walk([&](LinalgOp op) { linalgOps.push_back(op); });
  return LinalgDependenceGraph(aliases, linalgOps);
}
/// Constructs the dependence graph for an explicit, program-ordered list of
/// LinalgOps. Dependences are recorded pairwise, only from each op to ops
/// that appear later in `ops`.
LinalgDependenceGraph::LinalgDependenceGraph(Aliases &aliases,
                                             ArrayRef<LinalgOp> ops)
    : aliases(aliases), linalgOps(ops.begin(), ops.end()) {
  // Remember each op's position in program order; keyed by Operation * so
  // lookups work from either a LinalgOp or a generic Operation *. Taking
  // ArrayRef<LinalgOp> makes the old isa<LinalgOp> assertion unnecessary —
  // the type system now guarantees it.
  for (auto en : llvm::enumerate(linalgOps)) {
    linalgOpPositions.insert(
        std::make_pair(en.value().getOperation(), en.index()));
  }
  // Pairwise scan in program order: i < j, so edges are only considered
  // from earlier ops to later ops. No cast<LinalgOp> needed anymore.
  for (unsigned i = 0, e = ops.size(); i < e; ++i) {
    for (unsigned j = i + 1; j < e; ++j) {
      addDependencesBetween(ops[i], ops[j]);
    }
  }
}
DenseSet<Operation *> eraseSet;
// Save original Linalg ops, we only want to make a pass over those.
- SmallVector<Operation *, 8> linalgOps;
+ SmallVector<LinalgOp, 8> linalgOps;
f.walk([&](LinalgOp op) {
// TODO: support multi-results.
if (op.getOperation()->getNumResults() <= 1)
// Tile and Fuse for tensors inputs (TODO: all tensor operands).
bool changed = false;
- for (auto *op : llvm::reverse(linalgOps)) {
- LinalgOp linalgOp = cast<LinalgOp>(op);
+ for (LinalgOp linalgOp : llvm::reverse(linalgOps)) {
for (auto en : llvm::enumerate(linalgOp.getShapedOperands())) {
if (en.value().getType().isa<MemRefType>()) {
// TODO: LinalgDependenceGraph should be able to update itself.
// removed.
linalg::Aliases aliases;
linalg::LinalgDependenceGraph graph(aliases, linalgOps);
- if (auto info = fuseProducerOfBuffer(b, op, en.index(), graph)) {
+ if (auto info = fuseProducerOfBuffer(b, linalgOp, en.index(), graph)) {
auto *originalOp = info->originalProducer.getOperation();
eraseSet.insert(originalOp);
auto *originalOpInLinalgOpsVector =
// Tile and Fuse tensor input (TODO: init_tensors too).
if (en.index() >= linalgOp.getNumInputs())
continue;
- if (auto info = fuseProducerOfTensor(b, op, en.index())) {
+ if (auto info = fuseProducerOfTensor(b, linalgOp, en.index())) {
auto *originalOp = info->originalProducer.getOperation();
auto *originalOpInLinalgOpsVector =
std::find(linalgOps.begin(), linalgOps.end(), originalOp);