NFC: Rename FoldHelper to OperationFolder and split a large function in two.
authorRiver Riddle <riverriddle@google.com>
Tue, 4 Jun 2019 18:56:43 +0000 (11:56 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 9 Jun 2019 23:17:11 +0000 (16:17 -0700)
PiperOrigin-RevId: 251485843

mlir/include/mlir/Transforms/FoldUtils.h
mlir/lib/Transforms/TestConstantFold.cpp
mlir/lib/Transforms/Utils/FoldUtils.cpp
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

index 6264bd1..b6a43c5 100644 (file)
 namespace mlir {
 class Function;
 class Operation;
+class Value;
 
-/// A helper class for folding operations, and unifying duplicated constants
+/// A utility class for folding operations, and unifying duplicated constants
 /// generated along the way.
 ///
 /// To make sure constants properly dominate all their uses, constants are
 /// moved to the beginning of the entry block of the function when tracked by
 /// this class.
-class FoldHelper {
+class OperationFolder {
 public:
   /// Constructs an instance for managing constants in the given function `f`.
   /// Constants tracked by this instance will be moved to the entry block of
@@ -47,7 +48,7 @@ public:
   /// This instance does not proactively walk the operations inside `f`;
   /// instead, users must invoke the following methods to manually handle each
   /// operation of interest.
-  FoldHelper(Function *f);
+  OperationFolder(Function *f) : function(f) {}
 
   /// Tries to perform folding on the given `op`, including unifying
   /// deduplicated constants. If successful, calls `preReplaceAction` (if
@@ -59,13 +60,17 @@ public:
             std::function<void(Operation *)> preReplaceAction = {});
 
   /// Notifies that the given constant `op` should be remove from this
-  /// FoldHelper's internal bookkeeping.
+  /// OperationFolder's internal bookkeeping.
   ///
   /// Note: this method must be called if a constant op is to be deleted
-  /// externally to this FoldHelper. `op` must be a constant op.
+  /// externally to this OperationFolder. `op` must be a constant op.
   void notifyRemoval(Operation *op);
 
 private:
+  /// Tries to perform folding on the given `op`. If successful, populates
+  /// `results` with the results of the foldin.
+  LogicalResult tryToFold(Operation *op, SmallVectorImpl<Value *> &results);
+
   /// Tries to deduplicate the given constant and returns success if that can be
   /// done. This moves the given constant to the top of the entry block if it
   /// is first seen. If there is already an existing constant that is the same,
index 1169607..f360441 100644 (file)
@@ -32,12 +32,12 @@ struct TestConstantFold : public FunctionPass<TestConstantFold> {
   // All constants in the function post folding.
   SmallVector<Operation *, 8> existingConstants;
 
-  void foldOperation(Operation *op, FoldHelper &helper);
+  void foldOperation(Operation *op, OperationFolder &helper);
   void runOnFunction() override;
 };
 } // end anonymous namespace
 
-void TestConstantFold::foldOperation(Operation *op, FoldHelper &helper) {
+void TestConstantFold::foldOperation(Operation *op, OperationFolder &helper) {
   // Attempt to fold the specified operation, including handling unused or
   // duplicated constants.
   if (succeeded(helper.tryToFold(op)))
@@ -56,7 +56,7 @@ void TestConstantFold::runOnFunction() {
   existingConstants.clear();
 
   auto &f = getFunction();
-  FoldHelper helper(&f);
+  OperationFolder helper(&f);
 
   // Collect and fold the operations within the function.
   SmallVector<Operation *, 8> ops;
index 578b822..fbf1a2a 100644 (file)
 
 using namespace mlir;
 
-FoldHelper::FoldHelper(Function *f) : function(f) {}
+//===----------------------------------------------------------------------===//
+// OperationFolder
+//===----------------------------------------------------------------------===//
 
 LogicalResult
-FoldHelper::tryToFold(Operation *op,
-                      std::function<void(Operation *)> preReplaceAction) {
+OperationFolder::tryToFold(Operation *op,
+                           std::function<void(Operation *)> preReplaceAction) {
   assert(op->getFunction() == function &&
          "cannot constant fold op from another function");
 
@@ -52,8 +54,37 @@ FoldHelper::tryToFold(Operation *op,
     return tryToUnify(op);
   }
 
+  // Try to fold the operation.
+  SmallVector<Value *, 8> results;
+  if (failed(tryToFold(op, results)))
+    return failure();
+
+  // Constant folding succeeded. We will start replacing this op's uses and
+  // eventually erase this op. Invoke the callback provided by the caller to
+  // perform any pre-replacement action.
+  if (preReplaceAction)
+    preReplaceAction(op);
+
+  // Check to see if the operation was just updated in place.
+  if (results.empty())
+    return success();
+
+  // Otherwise, replace all of the result values and erase the operation.
+  for (unsigned i = 0, e = results.size(); i != e; ++i)
+    op->getResult(i)->replaceAllUsesWith(results[i]);
+  op->erase();
+  return success();
+}
+
+/// Tries to perform folding on the given `op`. If successful, populates
+/// `results` with the results of the foldin.
+LogicalResult OperationFolder::tryToFold(Operation *op,
+                                         SmallVectorImpl<Value *> &results) {
+  assert(op->getFunction() == function &&
+         "cannot constant fold op from another function");
+
   SmallVector<Attribute, 8> operandConstants;
-  SmallVector<OpFoldResult, 8> results;
+  SmallVector<OpFoldResult, 8> foldResults;
 
   // Check to see if any operands to the operation is constant and whether
   // the operation knows how to constant fold itself.
@@ -70,38 +101,29 @@ FoldHelper::tryToFold(Operation *op,
   }
 
   // Attempt to constant fold the operation.
-  if (failed(op->fold(operandConstants, results)))
+  if (failed(op->fold(operandConstants, foldResults)))
     return failure();
 
-  // Constant folding succeeded. We will start replacing this op's uses and
-  // eventually erase this op. Invoke the callback provided by the caller to
-  // perform any pre-replacement action.
-  if (preReplaceAction)
-    preReplaceAction(op);
-
   // Check to see if the operation was just updated in place.
-  if (results.empty())
+  if (foldResults.empty())
     return success();
-  assert(results.size() == op->getNumResults());
+  assert(foldResults.size() == op->getNumResults());
 
   // Create the result constants and replace the results.
   FuncBuilder builder(op);
   for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
-    auto *res = op->getResult(i);
-    if (res->use_empty()) // Ignore dead uses.
-      continue;
-    assert(!results[i].isNull() && "expected valid OpFoldResult");
+    assert(!foldResults[i].isNull() && "expected valid OpFoldResult");
 
     // Check if the result was an SSA value.
-    if (auto *repl = results[i].dyn_cast<Value *>()) {
-      if (repl != res)
-        res->replaceAllUsesWith(repl);
+    if (auto *repl = foldResults[i].dyn_cast<Value *>()) {
+      results.emplace_back(repl);
       continue;
     }
 
     // If we already have a canonicalized version of this constant, just reuse
-    // it.  Otherwise create a new one.
-    Attribute attrRepl = results[i].get<Attribute>();
+    // it. Otherwise create a new one.
+    Attribute attrRepl = foldResults[i].get<Attribute>();
+    auto *res = op->getResult(i);
     auto &constInst =
         uniquedConstants[std::make_pair(attrRepl, res->getType())];
     if (!constInst) {
@@ -113,14 +135,13 @@ FoldHelper::tryToFold(Operation *op,
       constInst = newOp.getOperation();
       moveConstantToEntryBlock(constInst);
     }
-    res->replaceAllUsesWith(constInst->getResult(0));
+    results.push_back(constInst->getResult(0));
   }
-  op->erase();
 
   return success();
 }
 
-void FoldHelper::notifyRemoval(Operation *op) {
+void OperationFolder::notifyRemoval(Operation *op) {
   assert(op->getFunction() == function &&
          "cannot remove constant from another function");
 
@@ -134,7 +155,7 @@ void FoldHelper::notifyRemoval(Operation *op) {
     uniquedConstants.erase(it);
 }
 
-LogicalResult FoldHelper::tryToUnify(Operation *op) {
+LogicalResult OperationFolder::tryToUnify(Operation *op) {
   Attribute constValue;
   matchPattern(op, m_Constant(&constValue));
   assert(constValue);
@@ -163,7 +184,7 @@ LogicalResult FoldHelper::tryToUnify(Operation *op) {
   return failure();
 }
 
-void FoldHelper::moveConstantToEntryBlock(Operation *op) {
+void OperationFolder::moveConstantToEntryBlock(Operation *op) {
   // Insert at the very top of the entry block.
   auto &entryBB = function->front();
   op->moveBefore(&entryBB, entryBB.begin());
index a2d2d03..a2e6427 100644 (file)
@@ -143,7 +143,7 @@ private:
 /// Perform the rewrites.
 bool GreedyPatternRewriteDriver::simplifyFunction(int maxIterations) {
   Function *fn = getFunction();
-  FoldHelper helper(fn);
+  OperationFolder helper(fn);
 
   bool changed = false;
   int i = 0;
@@ -166,8 +166,8 @@ bool GreedyPatternRewriteDriver::simplifyFunction(int maxIterations) {
       // If the operation has no side effects, and no users, then it is
       // trivially dead - remove it.
       if (op->hasNoSideEffect() && op->use_empty()) {
-        // Be careful to update bookkeeping in FoldHelper to keep consistency if
-        // this is a constant op.
+        // Be careful to update bookkeeping in OperationFolder to keep
+        // consistency if this is a constant op.
         helper.notifyRemoval(op);
         op->erase();
         continue;