[mlir][transform] TrackingListener: Allow existing ops as replacements
authorMatthias Springer <me@m-sp.org>
Fri, 12 May 2023 08:04:13 +0000 (10:04 +0200)
committerMatthias Springer <me@m-sp.org>
Fri, 12 May 2023 13:07:20 +0000 (15:07 +0200)
The TrackingListener was unnecessarily strict. Existing ops are now allowed when updating payload ops mappings due to `replaceOp` in the TrackingListener.

Differential Revision: https://reviews.llvm.org/D150429

mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
mlir/lib/Dialect/Transform/IR/TransformOps.cpp

index b6bc094..7a0f802 100644 (file)
@@ -48,8 +48,8 @@ public:
 protected:
   /// Return a replacement payload op for the given op, which is going to be
   /// replaced with the given values. By default, if all values are defined by
-  /// the same newly-created op, which also has the same type as the given op,
-  /// that defining op is used as a replacement.
+  /// the same op, which also has the same type as the given op, that defining
+  /// op is used as a replacement.
   virtual Operation *findReplacementOp(Operation *op,
                                        ValueRange newValues) const;
 
@@ -66,22 +66,14 @@ protected:
   virtual void notifyPayloadReplacementNotFound(Operation *op,
                                                 ValueRange values) {}
 
-  /// Return "true" if the given op is a new op.
-  bool isNewOp(Operation *op) const;
-
   /// Return the single op that defines all given values (if any).
   static Operation *getCommonDefiningOp(ValueRange values);
 
 private:
-  void notifyOperationInserted(Operation *op) override;
-
   void notifyOperationRemoved(Operation *op) override;
 
   void notifyOperationReplaced(Operation *op, ValueRange newValues) override;
 
-  /// Ops that were newly created during the transform.
-  DenseMap<OperationName, DenseSet<Operation *>> newOps;
-
   /// The transform op in which this TrackingListener is used.
   TransformOpInterface transformOp;
 };
index 7451193..ecd5d2a 100644 (file)
@@ -180,20 +180,9 @@ transform::TrackingListener::findReplacementOp(Operation *op,
   if (op->getName() != defOp->getName())
     return nullptr;
 
-  // If the replacement op is not a new op, drop the mapping.
-  if (!isNewOp(defOp))
-    return nullptr;
-
   return defOp;
 }
 
-bool transform::TrackingListener::isNewOp(Operation *op) const {
-  auto it = newOps.find(op->getName());
-  if (it == newOps.end())
-    return false;
-  return it->second.contains(op);
-}
-
 LogicalResult transform::TrackingListener::notifyMatchFailure(
     Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
   LLVM_DEBUG({
@@ -204,17 +193,9 @@ LogicalResult transform::TrackingListener::notifyMatchFailure(
   return failure();
 }
 
-void transform::TrackingListener::notifyOperationInserted(Operation *op) {
-  newOps[op->getName()].insert(op);
-}
-
 void transform::TrackingListener::notifyOperationRemoved(Operation *op) {
   // TODO: Walk can be removed when D144193 has landed.
   op->walk([&](Operation *op) {
-    // Keep set of new ops up-to-date.
-    auto it = newOps.find(op->getName());
-    if (it != newOps.end())
-      it->second.erase(op);
     // Remove mappings for result values.
     for (OpResult value : op->getResults())
       (void)replacePayloadValue(value, nullptr);