[mlir][Linalg] Change LinalgDependenceGraph to use LinalgOp.
authorMaheshRavishankar <ravishankarm@google.com>
Fri, 13 Nov 2020 20:21:43 +0000 (12:21 -0800)
committerMaheshRavishankar <ravishankarm@google.com>
Fri, 13 Nov 2020 20:34:38 +0000 (12:34 -0800)
Using LinalgOp will reduce the repeated conversion from Operation <->
LinalgOp.

Differential Revision: https://reviews.llvm.org/D91101

mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp

index 372f6c4..f27b929 100644 (file)
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_
 #define MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_
 
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/OpDefinition.h"
 
@@ -67,7 +68,7 @@ public:
 
   // Builds a linalg dependence graph for the ops of type LinalgOp under `f`.
   static LinalgDependenceGraph buildDependenceGraph(Aliases &aliases, FuncOp f);
-  LinalgDependenceGraph(Aliases &aliases, ArrayRef<Operation *> ops);
+  LinalgDependenceGraph(Aliases &aliases, ArrayRef<LinalgOp> ops);
 
   /// Returns the X such that op -> X is a dependence of type dt.
   dependence_range getDependencesFrom(Operation *src, DependenceType dt) const;
@@ -168,7 +169,7 @@ private:
                                         ArrayRef<DependenceType> types) const;
 
   Aliases &aliases;
-  SmallVector<Operation *, 8> linalgOps;
+  SmallVector<LinalgOp, 8> linalgOps;
   DenseMap<Operation *, unsigned> linalgOpPositions;
 };
 } // namespace linalg
index 01e167d..96da933 100644 (file)
@@ -86,21 +86,21 @@ StringRef LinalgDependenceGraph::getDependenceTypeStr(DependenceType depType) {
 
 LinalgDependenceGraph
 LinalgDependenceGraph::buildDependenceGraph(Aliases &aliases, FuncOp f) {
-  SmallVector<Operation *, 8> linalgOps;
+  SmallVector<LinalgOp, 8> linalgOps;
   f.walk([&](LinalgOp op) { linalgOps.push_back(op); });
   return LinalgDependenceGraph(aliases, linalgOps);
 }
 
 LinalgDependenceGraph::LinalgDependenceGraph(Aliases &aliases,
-                                             ArrayRef<Operation *> ops)
+                                             ArrayRef<LinalgOp> ops)
     : aliases(aliases), linalgOps(ops.begin(), ops.end()) {
   for (auto en : llvm::enumerate(linalgOps)) {
-    assert(isa<LinalgOp>(en.value()) && "Expected value for LinalgOp");
-    linalgOpPositions.insert(std::make_pair(en.value(), en.index()));
+    linalgOpPositions.insert(
+        std::make_pair(en.value().getOperation(), en.index()));
   }
   for (unsigned i = 0, e = ops.size(); i < e; ++i) {
     for (unsigned j = i + 1; j < e; ++j) {
-      addDependencesBetween(cast<LinalgOp>(ops[i]), cast<LinalgOp>(ops[j]));
+      addDependencesBetween(ops[i], ops[j]);
     }
   }
 }
index e6e150b..eb9e3a5 100644 (file)
@@ -124,7 +124,7 @@ static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
   DenseSet<Operation *> eraseSet;
 
   // Save original Linalg ops, we only want to make a pass over those.
-  SmallVector<Operation *, 8> linalgOps;
+  SmallVector<LinalgOp, 8> linalgOps;
   f.walk([&](LinalgOp op) {
     // TODO: support multi-results.
     if (op.getOperation()->getNumResults() <= 1)
@@ -133,8 +133,7 @@ static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
 
   // Tile and Fuse for tensors inputs (TODO: all tensor operands).
   bool changed = false;
-  for (auto *op : llvm::reverse(linalgOps)) {
-    LinalgOp linalgOp = cast<LinalgOp>(op);
+  for (LinalgOp linalgOp : llvm::reverse(linalgOps)) {
     for (auto en : llvm::enumerate(linalgOp.getShapedOperands())) {
       if (en.value().getType().isa<MemRefType>()) {
         // TODO: LinalgDependenceGraph should be able to update itself.
@@ -142,7 +141,7 @@ static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
         // removed.
         linalg::Aliases aliases;
         linalg::LinalgDependenceGraph graph(aliases, linalgOps);
-        if (auto info = fuseProducerOfBuffer(b, op, en.index(), graph)) {
+        if (auto info = fuseProducerOfBuffer(b, linalgOp, en.index(), graph)) {
           auto *originalOp = info->originalProducer.getOperation();
           eraseSet.insert(originalOp);
           auto *originalOpInLinalgOpsVector =
@@ -155,7 +154,7 @@ static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
         // Tile and Fuse tensor input (TODO: init_tensors too).
         if (en.index() >= linalgOp.getNumInputs())
           continue;
-        if (auto info = fuseProducerOfTensor(b, op, en.index())) {
+        if (auto info = fuseProducerOfTensor(b, linalgOp, en.index())) {
           auto *originalOp = info->originalProducer.getOperation();
           auto *originalOpInLinalgOpsVector =
               std::find(linalgOps.begin(), linalgOps.end(), originalOp);