Make OpBuilder::insert virtual instead of OpBuilder::createOperation.
authorRiver Riddle <riverriddle@google.com>
Thu, 12 Dec 2019 00:26:08 +0000 (16:26 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 12 Dec 2019 00:26:45 +0000 (16:26 -0800)
It is sometimes useful to create operations separately from the builder before insertion as it may be easier to erase them in isolation if necessary. One example use case for this is folding, as we will only want to insert newly generated constant operations on success. This has the added benefit of fixing some silent PatternRewriter failures related to cloning, as the OpBuilder 'clone' methods don't call createOperation.

PiperOrigin-RevId: 285086242

mlir/include/mlir/IR/Builders.h
mlir/include/mlir/IR/PatternMatch.h
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/IR/Builders.cpp
mlir/lib/Transforms/DialectConversion.cpp
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

index c5ed7b1..9c787c1 100644 (file)
@@ -281,6 +281,9 @@ public:
   /// Returns the current insertion point of the builder.
   Block::iterator getInsertionPoint() const { return insertPoint; }
 
+  /// Insert the given operation at the current insertion point and return it.
+  virtual Operation *insert(Operation *op);
+
   /// Add new block and set the insertion point to the end of it. The block is
   /// inserted at the provided insertion point of 'parent'.
   Block *createBlock(Region *parent, Region::iterator insertPt = {});
@@ -293,7 +296,7 @@ public:
   Block *getBlock() const { return block; }
 
   /// Creates an operation given the fields represented as an OperationState.
-  virtual Operation *createOperation(const OperationState &state);
+  Operation *createOperation(const OperationState &state);
 
   /// Create an operation of specific op type at the current insertion point.
   template <typename OpTy, typename... Args>
@@ -346,28 +349,21 @@ public:
   /// cloned sub-operations to the corresponding operation that is copied,
   /// and adds those mappings to the map.
   Operation *clone(Operation &op, BlockAndValueMapping &mapper) {
-    Operation *cloneOp = op.clone(mapper);
-    insert(cloneOp);
-    return cloneOp;
-  }
-  Operation *clone(Operation &op) {
-    Operation *cloneOp = op.clone();
-    insert(cloneOp);
-    return cloneOp;
+    return insert(op.clone(mapper));
   }
+  Operation *clone(Operation &op) { return insert(op.clone()); }
 
   /// Creates a deep copy of this operation but keep the operation regions
   /// empty. Operands are remapped using `mapper` (if present), and `mapper` is
   /// updated to contain the results.
   Operation *cloneWithoutRegions(Operation &op, BlockAndValueMapping &mapper) {
-    Operation *cloneOp = op.cloneWithoutRegions(mapper);
-    insert(cloneOp);
-    return cloneOp;
+    return insert(op.cloneWithoutRegions(mapper));
   }
   Operation *cloneWithoutRegions(Operation &op) {
-    Operation *cloneOp = op.cloneWithoutRegions();
-    insert(cloneOp);
-    return cloneOp;
+    return insert(op.cloneWithoutRegions());
+  }
+  template <typename OpT> OpT cloneWithoutRegions(OpT op) {
+    return cast<OpT>(cloneWithoutRegions(*op.getOperation()));
   }
 
 private:
@@ -375,9 +371,6 @@ private:
   /// 'results'.
   void tryFold(Operation *op, SmallVectorImpl<Value *> &results);
 
-  /// Insert the given operation at the current insertion point.
-  void insert(Operation *op);
-
   Block *block = nullptr;
   Block::iterator insertPoint;
 };
index 366d2b8..4805152 100644 (file)
@@ -302,9 +302,9 @@ public:
     return OpTy();
   }
 
-  /// This is implemented to create the specified operations and serves as a
+  /// This is implemented to insert the specified operation and serves as a
   /// notification hook for rewriters that want to know about new operations.
-  virtual Operation *createOperation(const OperationState &state) = 0;
+  virtual Operation *insert(Operation *op) = 0;
 
   /// Move the blocks that belong to "region" before the given position in
   /// another region "parent". The two regions must be different. The caller
index fee58a4..249b4c1 100644 (file)
@@ -332,12 +332,6 @@ public:
   /// Replace all the uses of the block argument `from` with value `to`.
   void replaceUsesOfBlockArgument(BlockArgument *from, Value *to);
 
-  /// Clone the given operation without cloning its regions.
-  Operation *cloneWithoutRegions(Operation *op);
-  template <typename OpT> OpT cloneWithoutRegions(OpT op) {
-    return cast<OpT>(cloneWithoutRegions(op.getOperation()));
-  }
-
   /// Return the converted value that replaces 'key'. Return 'key' if there is
   /// no such a converted value.
   Value *getRemappedValue(Value *key);
@@ -376,8 +370,8 @@ public:
                          BlockAndValueMapping &mapping) override;
   using PatternRewriter::cloneRegionBefore;
 
-  /// PatternRewriter hook for creating a new operation.
-  Operation *createOperation(const OperationState &state) override;
+  /// PatternRewriter hook for inserting a new operation.
+  Operation *insert(Operation *op) override;
 
   /// PatternRewriter hook for updating the root operation in-place.
   void notifyRootUpdated(Operation *op) override;
index 4d6cd35..8c54df4 100644 (file)
@@ -306,6 +306,13 @@ AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) {
 
 OpBuilder::~OpBuilder() {}
 
+/// Insert the given operation at the current insertion point and return it.
+Operation *OpBuilder::insert(Operation *op) {
+  if (block)
+    block->getOperations().insert(insertPoint, op);
+  return op;
+}
+
 /// Add new block and set the insertion point to the end of it. The block is
 /// inserted at the provided insertion point of 'parent'.
 Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt) {
@@ -328,10 +335,7 @@ Block *OpBuilder::createBlock(Block *insertBefore) {
 
 /// Create an operation given the fields represented as an OperationState.
 Operation *OpBuilder::createOperation(const OperationState &state) {
-  assert(block && "createOperation() called without setting builder's block");
-  auto *op = Operation::create(state);
-  insert(op);
-  return op;
+  return insert(Operation::create(state));
 }
 
 /// Attempts to fold the given operation and places new results within
@@ -359,9 +363,3 @@ void OpBuilder::tryFold(Operation *op, SmallVectorImpl<Value *> &results) {
                   [](OpFoldResult result) { return result.get<Value *>(); });
   op->erase();
 }
-
-/// Insert the given operation at the current insertion point.
-void OpBuilder::insert(Operation *op) {
-  if (block)
-    block->getOperations().insert(insertPoint, op);
-}
index 6d34db9..ea4ad68 100644 (file)
@@ -802,13 +802,6 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument *from,
   impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
 }
 
-/// Clone the given operation without cloning its regions.
-Operation *ConversionPatternRewriter::cloneWithoutRegions(Operation *op) {
-  Operation *newOp = OpBuilder::cloneWithoutRegions(*op);
-  impl->createdOps.push_back(newOp);
-  return newOp;
-}
-
 /// Return the converted value that replaces 'key'. Return 'key' if there is
 /// no such a converted value.
 Value *ConversionPatternRewriter::getRemappedValue(Value *key) {
@@ -854,12 +847,11 @@ void ConversionPatternRewriter::cloneRegionBefore(
 }
 
 /// PatternRewriter hook for creating a new operation.
-Operation *
-ConversionPatternRewriter::createOperation(const OperationState &state) {
-  LLVM_DEBUG(llvm::dbgs() << "** Creating operation : " << state.name << "\n");
-  auto *result = OpBuilder::createOperation(state);
-  impl->createdOps.push_back(result);
-  return result;
+Operation *ConversionPatternRewriter::insert(Operation *op) {
+  LLVM_DEBUG(llvm::dbgs() << "** Inserting operation : " << op->getName()
+                          << "\n");
+  impl->createdOps.push_back(op);
+  return OpBuilder::insert(op);
 }
 
 /// PatternRewriter hook for updating the root operation in-place.
index aa4563c..e2ca3f8 100644 (file)
@@ -86,12 +86,11 @@ public:
 
   // These are hooks implemented for PatternRewriter.
 protected:
-  // Implement the hook for creating operations, and make sure that newly
-  // created ops are added to the worklist for processing.
-  Operation *createOperation(const OperationState &state) override {
-    auto *result = OpBuilder::createOperation(state);
-    addToWorklist(result);
-    return result;
+  // Implement the hook for inserting operations, and make sure that newly
+  // inserted ops are added to the worklist for processing.
+  Operation *insert(Operation *op) override {
+    addToWorklist(op);
+    return OpBuilder::insert(op);
   }
 
   // If an operation is about to be removed, make sure it is not in our