protected:
/// Return a replacement payload op for the given op, which is going to be
/// replaced with the given values. By default, if all values are defined by
- /// the same newly-created op, which also has the same type as the given op,
- /// that defining op is used as a replacement.
+ /// the same op, which also has the same type as the given op, that defining
+ /// op is used as a replacement.
virtual Operation *findReplacementOp(Operation *op,
ValueRange newValues) const;
virtual void notifyPayloadReplacementNotFound(Operation *op,
ValueRange values) {}
- /// Return "true" if the given op is a new op.
- bool isNewOp(Operation *op) const;
-
/// Return the single op that defines all given values (if any).
static Operation *getCommonDefiningOp(ValueRange values);
private:
- void notifyOperationInserted(Operation *op) override;
-
void notifyOperationRemoved(Operation *op) override;
void notifyOperationReplaced(Operation *op, ValueRange newValues) override;
- /// Ops that were newly created during the transform.
- DenseMap<OperationName, DenseSet<Operation *>> newOps;
-
/// The transform op in which this TrackingListener is used.
TransformOpInterface transformOp;
};
if (op->getName() != defOp->getName())
return nullptr;
- // If the replacement op is not a new op, drop the mapping.
- if (!isNewOp(defOp))
- return nullptr;
-
return defOp;
}
-bool transform::TrackingListener::isNewOp(Operation *op) const {
- auto it = newOps.find(op->getName());
- if (it == newOps.end())
- return false;
- return it->second.contains(op);
-}
-
LogicalResult transform::TrackingListener::notifyMatchFailure(
Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
LLVM_DEBUG({
return failure();
}
-void transform::TrackingListener::notifyOperationInserted(Operation *op) {
- newOps[op->getName()].insert(op);
-}
-
void transform::TrackingListener::notifyOperationRemoved(Operation *op) {
// TODO: Walk can be removed when D144193 has landed.
op->walk([&](Operation *op) {
- // Keep set of new ops up-to-date.
- auto it = newOps.find(op->getName());
- if (it != newOps.end())
- it->second.erase(op);
// Remove mappings for result values.
for (OpResult value : op->getResults())
(void)replacePayloadValue(value, nullptr);