From 7d436d56b60b36508b94e39d08761f1405a9c770 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Fri, 12 May 2023 10:04:13 +0200 Subject: [PATCH] [mlir][transform] TrackingListener: Allow existing ops as replacements The TrackingListener was unnecessarily strict. Existing ops are now allowed when updating payload ops mappings due to `replaceOp` in the TrackingListener. Differential Revision: https://reviews.llvm.org/D150429 --- mlir/include/mlir/Dialect/Transform/IR/TransformOps.h | 12 ++---------- mlir/lib/Dialect/Transform/IR/TransformOps.cpp | 19 ------------------- 2 files changed, 2 insertions(+), 29 deletions(-) diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h index b6bc094..7a0f802 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h @@ -48,8 +48,8 @@ public: protected: /// Return a replacement payload op for the given op, which is going to be /// replaced with the given values. By default, if all values are defined by - /// the same newly-created op, which also has the same type as the given op, - /// that defining op is used as a replacement. + /// the same op, which also has the same type as the given op, that defining + /// op is used as a replacement. virtual Operation *findReplacementOp(Operation *op, ValueRange newValues) const; @@ -66,22 +66,14 @@ protected: virtual void notifyPayloadReplacementNotFound(Operation *op, ValueRange values) {} - /// Return "true" if the given op is a new op. - bool isNewOp(Operation *op) const; - /// Return the single op that defines all given values (if any). static Operation *getCommonDefiningOp(ValueRange values); private: - void notifyOperationInserted(Operation *op) override; - void notifyOperationRemoved(Operation *op) override; void notifyOperationReplaced(Operation *op, ValueRange newValues) override; - /// Ops that were newly created during the transform. - DenseMap> newOps; - /// The transform op in which this TrackingListener is used. TransformOpInterface transformOp; }; diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 7451193..ecd5d2a 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -180,20 +180,9 @@ transform::TrackingListener::findReplacementOp(Operation *op, if (op->getName() != defOp->getName()) return nullptr; - // If the replacement op is not a new op, drop the mapping. - if (!isNewOp(defOp)) - return nullptr; - return defOp; } -bool transform::TrackingListener::isNewOp(Operation *op) const { - auto it = newOps.find(op->getName()); - if (it == newOps.end()) - return false; - return it->second.contains(op); -} - LogicalResult transform::TrackingListener::notifyMatchFailure( Location loc, function_ref reasonCallback) { LLVM_DEBUG({ @@ -204,17 +193,9 @@ LogicalResult transform::TrackingListener::notifyMatchFailure( return failure(); } -void transform::TrackingListener::notifyOperationInserted(Operation *op) { - newOps[op->getName()].insert(op); -} - void transform::TrackingListener::notifyOperationRemoved(Operation *op) { // TODO: Walk can be removed when D144193 has landed. op->walk([&](Operation *op) { - // Keep set of new ops up-to-date. - auto it = newOps.find(op->getName()); - if (it != newOps.end()) - it->second.erase(op); // Remove mappings for result values. for (OpResult value : op->getResults()) (void)replacePayloadValue(value, nullptr); -- 2.7.4