From 8a583dd22012294e253f735bf274628791fecb33 Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Tue, 21 Feb 2023 10:10:15 +0530 Subject: [PATCH] [MLIR] Add replaceUsesWithIf on Operation Add replaceUsesWithIf on Operation along the lines of Value::replaceUsesWithIf. This had been missing on Operation and is convenient to replace multi-result operations' results conditionally. Reviewed By: lattner Differential Revision: https://reviews.llvm.org/D144348 --- mlir/include/mlir/IR/Operation.h | 9 +++++++++ mlir/include/mlir/IR/ValueRange.h | 20 ++++++++++++++++++++ mlir/lib/IR/OperationSupport.cpp | 5 +++++ mlir/lib/Transforms/CSE.cpp | 9 +++------ 4 files changed, 37 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index e7ef01e..ac6bdfc 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -257,6 +257,15 @@ public: getResults().replaceAllUsesWith(std::forward(values)); } + /// Replace uses of results of this operation with the provided `values` if + /// the given callback returns true. + template + void replaceUsesWithIf(ValuesT &&values, + function_ref shouldReplace) { + getResults().replaceUsesWithIf(std::forward(values), + shouldReplace); + } + /// Destroys this operation and its subclass data. void destroy(); diff --git a/mlir/include/mlir/IR/ValueRange.h b/mlir/include/mlir/IR/ValueRange.h index 8873260..0f7354b 100644 --- a/mlir/include/mlir/IR/ValueRange.h +++ b/mlir/include/mlir/IR/ValueRange.h @@ -279,6 +279,26 @@ public: /// Replace all uses of results of this range with results of 'op'. void replaceAllUsesWith(Operation *op); + /// Replace uses of results of this range with the provided 'values' if the + /// given callback returns true. The size of `values` must match the size of + /// this range. + template + std::enable_if_t::value> + replaceUsesWithIf(ValuesT &&values, + function_ref shouldReplace) { + assert(static_cast(std::distance(values.begin(), values.end())) == + size() && + "expected 'values' to correspond 1-1 with the number of results"); + + for (auto it : llvm::zip(*this, values)) + std::get<0>(it).replaceUsesWithIf(std::get<1>(it), shouldReplace); + } + + /// Replace uses of results of this range with results of `op` if the given + /// callback returns true. + void replaceUsesWithIf(Operation *op, + function_ref shouldReplace); + //===--------------------------------------------------------------------===// // Users //===--------------------------------------------------------------------===// diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp index 20ce9b3..a38a12d 100644 --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -589,6 +589,11 @@ void ResultRange::replaceAllUsesWith(Operation *op) { replaceAllUsesWith(op->getResults()); } +void ResultRange::replaceUsesWithIf( + Operation *op, function_ref shouldReplace) { + replaceUsesWithIf(op->getResults(), shouldReplace); +} + //===----------------------------------------------------------------------===// // ValueRange diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index 93e5c95c..e98cccc 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -124,12 +124,9 @@ void CSE::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op, } else { // When the region does not have SSA dominance, we need to check if we // have visited a use before replacing any use. - for (auto it : llvm::zip(op->getResults(), existing->getResults())) { - std::get<0>(it).replaceUsesWithIf( - std::get<1>(it), [&](OpOperand &operand) { - return !knownValues.count(operand.getOwner()); - }); - } + op->replaceUsesWithIf(existing->getResults(), [&](OpOperand &operand) { + return !knownValues.count(operand.getOwner()); + }); // There may be some remaining uses of the operation. if (op->use_empty()) -- 2.7.4