Add support for walking the use list of an SSAValue and converting owners to
authorChris Lattner <clattner@google.com>
Sun, 28 Oct 2018 17:03:19 +0000 (10:03 -0700)
committerjpienaar <jpienaar@google.com>
Fri, 29 Mar 2019 20:43:01 +0000 (13:43 -0700)
Operation*'s, simplifying some code in GreedyPatternRewriteDriver.cpp.

Also add print/dump methods on Operation.

PiperOrigin-RevId: 219045764

mlir/include/mlir/IR/Instructions.h
mlir/include/mlir/IR/Operation.h
mlir/include/mlir/IR/Statements.h
mlir/include/mlir/IR/UseDefLists.h
mlir/lib/IR/Operation.cpp
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

index 64507a6f50666c89d239b1a550dd5c9198c3992e..e74c5616a451614e9a3d6ef1ac1f89202474e07c 100644 (file)
@@ -206,11 +206,13 @@ public:
                                ArrayRef<NamedAttribute> attributes,
                                MLIRContext *context);
 
+  using Instruction::dump;
   using Instruction::emitError;
   using Instruction::emitNote;
   using Instruction::emitWarning;
   using Instruction::getContext;
   using Instruction::getLoc;
+  using Instruction::print;
 
   OperationInst *clone() const;
 
@@ -341,8 +343,8 @@ public:
                   llvm::iplist<OperationInst>::iterator iterator);
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
-  static bool classof(const Instruction *inst) {
-    return inst->getKind() == Kind::Operation;
+  static bool classof(const IROperandOwner *ptr) {
+    return ptr->getKind() == IROperandOwner::Kind::OperationInst;
   }
   static bool classof(const Operation *op) {
     return op->getOperationKind() == OperationKind::Instruction;
@@ -433,8 +435,8 @@ public:
   ArrayRef<BasicBlockOperand> getBasicBlockOperands() const { return dest; }
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
-  static bool classof(const Instruction *inst) {
-    return inst->getKind() == Kind::Branch;
+  static bool classof(const IROperandOwner *ptr) {
+    return ptr->getKind() == IROperandOwner::Kind::BranchInst;
   }
 
 private:
@@ -479,10 +481,7 @@ public:
 
   unsigned getNumOperands() const { return operands.size(); }
 
-  //
-  // Accessors for operands to the 'true' destination
-  //
-
+  // Accessors for operands to the 'true' destination.
   CFGValue *getTrueOperand(unsigned idx) {
     return getTrueInstOperand(idx).get();
   }
@@ -530,10 +529,7 @@ public:
   /// Add a list of values to the operand list.
   void addTrueOperands(ArrayRef<CFGValue *> values);
 
-  //
-  // Accessors for operands to the 'false' destination
-  //
-
+  // Accessors for operands to the 'false' destination.
   CFGValue *getFalseOperand(unsigned idx) {
     return getFalseInstOperand(idx).get();
   }
@@ -592,8 +588,8 @@ public:
   ArrayRef<BasicBlockOperand> getBasicBlockOperands() const { return dests; }
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
-  static bool classof(const Instruction *inst) {
-    return inst->getKind() == Kind::CondBranch;
+  static bool classof(const IROperandOwner *ptr) {
+    return ptr->getKind() == IROperandOwner::Kind::CondBranchInst;
   }
 
 private:
@@ -631,8 +627,8 @@ public:
   void destroy();
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
-  static bool classof(const Instruction *inst) {
-    return inst->getKind() == Kind::Return;
+  static bool classof(const IROperandOwner *ptr) {
+    return ptr->getKind() == IROperandOwner::Kind::ReturnInst;
   }
 
 private:
index a0e820966c56e094f5fd9f29069714144e11fa31..2d294e1bb1a9599744b99db193113d278feeb21b 100644 (file)
@@ -29,6 +29,7 @@ template <typename OpType> class OpPointer;
 template <typename ObjectType, typename ElementType> class OperandIterator;
 template <typename ObjectType, typename ElementType> class ResultIterator;
 class Function;
+class IROperandOwner;
 class Instruction;
 class Statement;
 
@@ -232,12 +233,16 @@ public:
 
   // Returns whether the operation is commutative.
   bool isCommutative() const {
-    return getAbstractOperation()->hasProperty(OperationProperty::Commutative);
+    if (auto *absOp = getAbstractOperation())
+      return absOp->hasProperty(OperationProperty::Commutative);
+    return false;
   }
 
   // Returns whether the operation has side-effects.
   bool hasNoSideEffect() const {
-    return getAbstractOperation()->hasProperty(OperationProperty::NoSideEffect);
+    if (auto *absOp = getAbstractOperation())
+      return absOp->hasProperty(OperationProperty::NoSideEffect);
+    return false;
   }
 
   /// Remove this operation from its parent block and delete it.
@@ -251,9 +256,13 @@ public:
   bool constantFold(ArrayRef<Attribute> operands,
                     SmallVectorImpl<Attribute> &results) const;
 
+  void print(raw_ostream &os) const;
+  void dump() const;
+
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool classof(const Instruction *inst);
   static bool classof(const Statement *stmt);
+  static bool classof(const IROperandOwner *ptr);
 
 protected:
   Operation(bool isInstruction, OperationName name,
@@ -410,4 +419,16 @@ inline auto Operation::getResults() const
 }
 } // end namespace mlir
 
+/// We need to teach the LLVM cast/dyn_cast etc logic how to cast from an
+/// IROperandOwner* to Operation*.  This can't be done with a simple pointer to
+/// pointer cast because the pointer adjustment depends on whether the Owner is
+/// dynamically an Instruction or Statement, because of multiple inheritance.
+namespace llvm {
+template <>
+struct cast_convert_val<mlir::Operation, mlir::IROperandOwner *,
+                        mlir::IROperandOwner *> {
+  static mlir::Operation *doit(const mlir::IROperandOwner *value);
+};
+} // namespace llvm
+
 #endif
index 2ad815f1c5fc30d97fbfcaa6ce61544ca9a91071..7e7a49ffa15ff966027afb198585cb7530b40dab 100644 (file)
@@ -51,10 +51,12 @@ public:
   /// Return the context this operation is associated with.
   MLIRContext *getContext() const;
 
+  using Statement::dump;
   using Statement::emitError;
   using Statement::emitNote;
   using Statement::emitWarning;
   using Statement::getLoc;
+  using Statement::print;
 
   /// Check if this statement is a return statement.
   bool isReturn() const;
index 3f191e0e3f21ff536c6d887ff22b50e70d4fbcee..4a2594774a38a3b025dfdc8d4f707e22842b2b32 100644 (file)
@@ -29,6 +29,7 @@
 namespace mlir {
 
 class IROperand;
+class IROperandOwner;
 template <typename OperandType, typename OwnerType> class SSAValueUseIterator;
 
 class IRObjectWithUseList {
@@ -43,7 +44,7 @@ public:
   /// Returns true if this value has exactly one use.
   inline bool hasOneUse() const;
 
-  using use_iterator = SSAValueUseIterator<IROperand, void>;
+  using use_iterator = SSAValueUseIterator<IROperand, IROperandOwner>;
   using use_range = llvm::iterator_range<use_iterator>;
 
   inline use_iterator use_begin() const;
index 99da81df99016dec7ae26c559b0f3516c781091e..2ed09b83b53c403ae99ae474f6c7827a39b23331 100644 (file)
@@ -237,6 +237,18 @@ bool Operation::constantFold(ArrayRef<Attribute> operands,
   return true;
 }
 
+void Operation::print(raw_ostream &os) const {
+  if (auto *inst = llvm::dyn_cast<OperationInst>(this))
+    return inst->print(os);
+  return llvm::cast<OperationStmt>(this)->print(os);
+}
+
+void Operation::dump() const {
+  if (auto *inst = llvm::dyn_cast<OperationInst>(this))
+    return inst->dump();
+  return llvm::cast<OperationStmt>(this)->dump();
+}
+
 /// Methods for support type inquiry through isa, cast, and dyn_cast.
 bool Operation::classof(const Instruction *inst) {
   return inst->getKind() == Instruction::Kind::Operation;
@@ -244,6 +256,26 @@ bool Operation::classof(const Instruction *inst) {
 bool Operation::classof(const Statement *stmt) {
   return stmt->getKind() == Statement::Kind::Operation;
 }
+bool Operation::classof(const IROperandOwner *ptr) {
+  return ptr->getKind() == IROperandOwner::Kind::OperationInst ||
+         ptr->getKind() == IROperandOwner::Kind::OperationStmt;
+}
+
+/// We need to teach the LLVM cast/dyn_cast etc logic how to cast from an
+/// IROperandOwner* to Operation*.  This can't be done with a simple pointer to
+/// pointer cast because the pointer adjustment depends on whether the Owner is
+/// dynamically an Instruction or Statement, because of multiple inheritance.
+Operation *
+llvm::cast_convert_val<mlir::Operation, mlir::IROperandOwner *,
+                       mlir::IROperandOwner *>::doit(const mlir::IROperandOwner
+                                                         *value) {
+  const Operation *op;
+  if (auto *ptr = dyn_cast<OperationStmt>(value))
+    op = ptr;
+  else
+    op = cast<OperationInst>(value);
+  return const_cast<Operation *>(op);
+}
 
 //===----------------------------------------------------------------------===//
 // OpState trait class.
index 30034b6fce512e1667d91a26be35953b6f423e07..cdf5b7166a084d755a64c6afce0c1de692a29672 100644 (file)
@@ -98,6 +98,22 @@ public:
     driver.removeFromWorklist(op);
   }
 
+  // When the root of a pattern is about to be replaced, it can trigger
+  // simplifications to its users - make sure to add them to the worklist
+  // before the root is changed.
+  void notifyRootReplaced(Operation *op) override {
+    for (auto *result : op->getResults())
+      // TODO: Add a result->getUsers() iterator.
+      for (auto &user : result->getUses()) {
+        if (auto *op = dyn_cast<Operation>(user.getOwner()))
+          driver.addToWorklist(op);
+      }
+
+    // TODO: Walk the operand list dropping them as we go.  If any of them
+    // drop to zero uses, then add them to the worklist to allow them to be
+    // deleted as dead.
+  }
+
   GreedyPatternRewriteDriver &driver;
 };
 
@@ -206,22 +222,10 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
         // Add all the users of the result to the worklist so we make sure to
         // revisit them.
         //
-        // TODO: This is super gross. SSAValue use iterators should have an
-        // "owner" that can be downcasted to operation and other things.  This
-        // will require a rejiggering of the class hierarchies.
-        if (auto *stmt = dyn_cast<OperationStmt>(op)) {
-          // TODO: Add a result->getUsers() iterator.
-          for (auto &operand : stmt->getResult(i)->getUses()) {
-            if (auto *op = dyn_cast<OperationStmt>(operand.getOwner()))
-              addToWorklist(op);
-          }
-        } else {
-          auto *inst = cast<OperationInst>(op);
-          // TODO: Add a result->getUsers() iterator.
-          for (auto &operand : inst->getResult(i)->getUses()) {
-            if (auto *op = dyn_cast<OperationInst>(operand.getOwner()))
-              addToWorklist(op);
-          }
+        // TODO: Add a result->getUsers() iterator.
+        for (auto &operand : op->getResult(i)->getUses()) {
+          if (auto *op = dyn_cast<Operation>(operand.getOwner()))
+            addToWorklist(op);
         }
 
         res->replaceAllUsesWith(cstValue);
@@ -268,23 +272,6 @@ static void processMLFunction(MLFunction *fn, OwningPatternList &&patterns) {
       return result;
     }
 
-    // When the root of a pattern is about to be replaced, it can trigger
-    // simplifications to its users - make sure to add them to the worklist
-    // before the root is changed.
-    void notifyRootReplaced(Operation *op) override {
-      auto *opStmt = cast<OperationStmt>(op);
-      for (auto *result : opStmt->getResults())
-        // TODO: Add a result->getUsers() iterator.
-        for (auto &user : result->getUses()) {
-          if (auto *op = dyn_cast<OperationStmt>(user.getOwner()))
-            driver.addToWorklist(op);
-        }
-
-      // TODO: Walk the operand list dropping them as we go.  If any of them
-      // drop to zero uses, then add them to the worklist to allow them to be
-      // deleted as dead.
-    }
-
     void setInsertionPoint(Operation *op) override {
       // Any new operations should be added before this statement.
       builder.setInsertionPoint(cast<OperationStmt>(op));
@@ -316,23 +303,6 @@ static void processCFGFunction(CFGFunction *fn, OwningPatternList &&patterns) {
       return result;
     }
 
-    // When the root of a pattern is about to be replaced, it can trigger
-    // simplifications to its users - make sure to add them to the worklist
-    // before the root is changed.
-    void notifyRootReplaced(Operation *op) override {
-      auto *opStmt = cast<OperationInst>(op);
-      for (auto *result : opStmt->getResults())
-        // TODO: Add a result->getUsers() iterator.
-        for (auto &user : result->getUses()) {
-          if (auto *op = dyn_cast<OperationInst>(user.getOwner()))
-            driver.addToWorklist(op);
-        }
-
-      // TODO: Walk the operand list dropping them as we go.  If any of them
-      // drop to zero uses, then add them to the worklist to allow them to be
-      // deleted as dead.
-    }
-
     void setInsertionPoint(Operation *op) override {
       // Any new operations should be added before this instruction.
       builder.setInsertionPoint(cast<OperationInst>(op));