Deduplicate constant folding logic in ConstantFold and GreedyPatternRewriteDriver
authorLei Zhang <antiagainst@google.com>
Thu, 4 Apr 2019 18:40:57 +0000 (11:40 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Fri, 5 Apr 2019 14:41:32 +0000 (07:41 -0700)
    There are two places containing constant folding logic right now: the ConstantFold
    pass and the GreedyPatternRewriteDriver. The logic was not shared and started to
    drift apart. We were testing constant folding logic using the ConstantFold pass,
    but lagged behind the GreedyPatternRewriteDriver, where we really want the constant
    folding to happen.

    This CL pulled the logic into utility functions and classes for sharing between
    these two places. A new ConstantFoldHelper class is created to help constant fold
    and de-duplication.

    Also, renamed the ConstantFold pass to TestConstantFold to make it clear that it is
    intended for testing purpose.

--

PiperOrigin-RevId: 241971681

mlir/include/mlir/IR/Matchers.h
mlir/include/mlir/Transforms/ConstantFoldUtils.h [new file with mode: 0644]
mlir/include/mlir/Transforms/Passes.h
mlir/lib/ExecutionEngine/ExecutionEngine.cpp
mlir/lib/Transforms/ConstantFold.cpp [deleted file]
mlir/lib/Transforms/TestConstantFold.cpp [new file with mode: 0644]
mlir/lib/Transforms/Utils/ConstantFoldUtils.cpp [new file with mode: 0644]
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
mlir/test/Transforms/constant-fold.mlir

index 2ac334b..1b8b982 100644 (file)
@@ -60,8 +60,8 @@ struct attr_value_binder {
   }
 };
 
-/// The matcher that matches a constant foldable operation that has no operands
-/// and produces a single result.
+/// The matcher that matches a constant foldable operation that has no side
+/// effect, no operands and produces a single result.
 struct constant_op_binder {
   Attribute *bind_value;
 
@@ -72,6 +72,9 @@ struct constant_op_binder {
   bool match(Operation *op) {
     if (op->getNumOperands() > 0 || op->getNumResults() != 1)
       return false;
+    if (!op->hasNoSideEffect())
+      return false;
+
     SmallVector<Attribute, 1> foldedAttr;
     if (succeeded(op->constantFold(/*operands=*/llvm::None, foldedAttr))) {
       *bind_value = foldedAttr.front();
@@ -134,6 +137,12 @@ inline bool matchPattern(Value *value, const Pattern &pattern) {
   return false;
 }
 
+/// Entry point for matching a pattern over an Operation.
+template <typename Pattern>
+inline bool matchPattern(Operation *op, const Pattern &pattern) {
+  return const_cast<Pattern &>(pattern).match(op);
+}
+
 /// Matches a constant holding a scalar/vector/tensor integer (splat) and
 /// writes the integer value to bind_value.
 inline detail::constant_int_op_binder
diff --git a/mlir/include/mlir/Transforms/ConstantFoldUtils.h b/mlir/include/mlir/Transforms/ConstantFoldUtils.h
new file mode 100644 (file)
index 0000000..325f6be
--- /dev/null
@@ -0,0 +1,94 @@
+//===- ConstantFoldUtils.h - Constant Fold Utilities ------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This header file declares various constant fold utilities. These utilities
+// are intended to be used by passes to unify and simply their logic.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_CONSTANT_UTILS_H
+#define MLIR_TRANSFORMS_CONSTANT_UTILS_H
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/DenseMap.h"
+
+namespace mlir {
+class Function;
+class Operation;
+
+/// A helper class for constant folding operations, and unifying duplicated
+/// constants along the way.
+///
+/// To make sure constants' proper dominance of all their uses, constants are
+/// moved to the beginning of the entry block of the function when tracked by
+/// this class.
+class ConstantFoldHelper {
+public:
+  /// Constructs an instance for managing constants in the given function `f`.
+  /// Constants tracked by this instance will be moved to the entry block of
+  /// `f`. If `insertAtHead` is true, the insertion always happen at the very
+  /// top of the entry block; otherwise, the insertion happens after the last
+  /// one of consecutive constant ops at the beginning of the entry block.
+  ///
+  /// This instance does not proactively walk the operations inside `f`;
+  /// instead, users must invoke the following methods to manually handle each
+  /// operation of interest.
+  ConstantFoldHelper(Function *f, bool insertAtHead = true);
+
+  /// Tries to perform constant folding on the given `op`, including unifying
+  /// deplicated constants. If successful, calls `preReplaceAction` (if
+  /// provided) by passing in `op`, then replaces `op`'s uses with folded
+  /// constants, and returns true.
+  ///
+  /// Note: `op` will *not* be erased to avoid invalidating potential walkers in
+  /// the caller.
+  bool
+  tryToConstantFold(Operation *op,
+                    std::function<void(Operation *)> preReplaceAction = {});
+
+  /// Notifies that the given constant `op` should be remove from this
+  /// ConstantFoldHelper's internal bookkeeping.
+  ///
+  /// Note: this method must be called if a constant op is to be deleted
+  /// externally to this ConstantFoldHelper. `op` must be a constant op.
+  void notifyRemoval(Operation *op);
+
+private:
+  /// Tries to deduplicate the given constant and returns true if that can be
+  /// done. This moves the given constant to the top of the entry block if it
+  /// is first seen. If there is already an existing constant that is the same,
+  /// this does *not* erases the given constant.
+  bool tryToUnify(Operation *op);
+
+  /// Moves the given constant `op` to entry block to guarantee dominance.
+  void moveConstantToEntryBlock(Operation *op);
+
+  /// The function where we are managing constant.
+  Function *function;
+
+  /// Whether to always insert constants at the very top of the entry block.
+  bool isInsertAtHead;
+
+  /// This map keeps track of uniqued constants.
+  DenseMap<std::pair<Attribute, Type>, Operation *> uniquedConstants;
+};
+
+} // end namespace mlir
+
+#endif // MLIR_TRANSFORMS_CONSTANT_UTILS_H
index bbfd43a..845e398 100644 (file)
@@ -33,8 +33,11 @@ class AffineForOp;
 class FunctionPassBase;
 class ModulePassBase;
 
-/// Creates a constant folding pass.
-FunctionPassBase *createConstantFoldPass();
+/// Creates a constant folding pass. Note that this pass solely provides simple
+/// top-down constant folding functionality; it is intended to be used for
+/// testing purpose. Use Canonicalizer pass, which exploits more simplification
+/// opportunties exposed by constant folding, for the general cases.
+FunctionPassBase *createTestConstantFoldPass();
 
 /// Creates an instance of the Canonicalizer pass.
 FunctionPassBase *createCanonicalizerPass();
index 7e1241d..e63a3e8 100644 (file)
@@ -179,7 +179,7 @@ static void getDefaultPasses(
     passEntry->addToPipeline(manager);
 
   // Append the extra passes for lowering to MLIR.
-  manager.addPass(mlir::createConstantFoldPass());
+  manager.addPass(mlir::createCanonicalizerPass());
   manager.addPass(mlir::createCSEPass());
   manager.addPass(mlir::createCanonicalizerPass());
   manager.addPass(mlir::createLowerAffinePass());
diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp
deleted file mode 100644 (file)
index 364c3dc..0000000
+++ /dev/null
@@ -1,123 +0,0 @@
-//===- ConstantFold.cpp - Pass that does constant folding -----------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-#include "mlir/AffineOps/AffineOps.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Function.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/StandardOps/Ops.h"
-#include "mlir/Transforms/Passes.h"
-#include "mlir/Transforms/Utils.h"
-
-using namespace mlir;
-
-namespace {
-/// Simple constant folding pass.
-struct ConstantFold : public FunctionPass<ConstantFold> {
-  // All constants in the function post folding.
-  SmallVector<Value *, 8> existingConstants;
-  // Operations that were folded and that need to be erased.
-  std::vector<Operation *> opInstsToErase;
-
-  void foldOperation(Operation *op);
-  void runOnFunction() override;
-};
-} // end anonymous namespace
-
-/// Attempt to fold the specified operation, updating the IR to match.  If
-/// constants are found, we keep track of them in the existingConstants list.
-///
-void ConstantFold::foldOperation(Operation *op) {
-  // If this operation is already a constant, just remember it for cleanup
-  // later, and don't try to fold it.
-  if (auto constant = op->dyn_cast<ConstantOp>()) {
-    existingConstants.push_back(constant);
-    return;
-  }
-
-  // Get values for operands that are trivial constants. nullptr is used as
-  // placeholder for non-constant operands.
-  SmallVector<Attribute, 8> operandConstants;
-  for (auto *operand : op->getOperands()) {
-    Attribute operandCst = nullptr;
-    if (auto *operandOp = operand->getDefiningOp()) {
-      if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
-        operandCst = operandConstantOp.getValue();
-    }
-    operandConstants.push_back(operandCst);
-  }
-
-  // Attempt to constant fold the operation.
-  SmallVector<Attribute, 8> resultConstants;
-  if (failed(op->constantFold(operandConstants, resultConstants)))
-    return;
-
-  // Ok, if everything succeeded, then we can create constants corresponding
-  // to the result of the call.
-  // TODO: We can try to reuse existing constants if we see them laying
-  // around.
-  assert(resultConstants.size() == op->getNumResults() &&
-         "constant folding produced the wrong number of results");
-
-  FuncBuilder builder(op);
-  for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
-    auto *res = op->getResult(i);
-    if (res->use_empty()) // ignore dead uses.
-      continue;
-
-    auto cst = builder.create<ConstantOp>(op->getLoc(), res->getType(),
-                                          resultConstants[i]);
-    existingConstants.push_back(cst);
-    res->replaceAllUsesWith(cst);
-  }
-
-  // At this point the operation is dead, so we can remove it.  We add it to
-  // a vector to avoid invalidating our walker.
-  opInstsToErase.push_back(op);
-}
-
-// For now, we do a simple top-down pass over a function folding constants.  We
-// don't handle conditional control flow, block arguments, folding
-// conditional branches, or anything else fancy.
-void ConstantFold::runOnFunction() {
-  existingConstants.clear();
-  opInstsToErase.clear();
-
-  getFunction().walk([&](Operation *op) { foldOperation(op); });
-
-  // At this point, these operations are dead, remove them.
-  // TODO: This is assuming that all constant foldable operations have no
-  // side effects.  When we have side effect modeling, we should verify that
-  // the operation is effect-free before we remove it.  Until then this is
-  // close enough.
-  for (auto *op : opInstsToErase) {
-    op->erase();
-  }
-
-  // By the time we are done, we may have simplified a bunch of code, leaving
-  // around dead constants.  Check for them now and remove them.
-  for (auto *cst : existingConstants) {
-    if (cst->use_empty())
-      cst->getDefiningOp()->erase();
-  }
-}
-
-/// Creates a constant folding pass.
-FunctionPassBase *mlir::createConstantFoldPass() { return new ConstantFold(); }
-
-static PassRegistration<ConstantFold>
-    pass("constant-fold", "Constant fold operations in functions");
diff --git a/mlir/lib/Transforms/TestConstantFold.cpp b/mlir/lib/Transforms/TestConstantFold.cpp
new file mode 100644 (file)
index 0000000..60407cd
--- /dev/null
@@ -0,0 +1,89 @@
+//===- TestConstantFold.cpp - Pass to test constant folding ---------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/AffineOps/AffineOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/StandardOps/Ops.h"
+#include "mlir/Transforms/ConstantFoldUtils.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/Utils.h"
+
+using namespace mlir;
+
+namespace {
+/// Simple constant folding pass.
+struct TestConstantFold : public FunctionPass<TestConstantFold> {
+  // All constants in the function post folding.
+  SmallVector<Operation *, 8> existingConstants;
+  // Operations that were folded and that need to be erased.
+  std::vector<Operation *> opsToErase;
+
+  void foldOperation(Operation *op, ConstantFoldHelper &helper);
+  void runOnFunction() override;
+};
+} // end anonymous namespace
+
+void TestConstantFold::foldOperation(Operation *op,
+                                     ConstantFoldHelper &helper) {
+  // Attempt to fold the specified operation, including handling unused or
+  // duplicated constants.
+  if (helper.tryToConstantFold(op)) {
+    opsToErase.push_back(op);
+  }
+  // If this op is a constant that are used and cannot be de-duplicated,
+  // remember it for cleanup later.
+  else if (auto constant = op->dyn_cast<ConstantOp>()) {
+    existingConstants.push_back(op);
+  }
+}
+
+// For now, we do a simple top-down pass over a function folding constants.  We
+// don't handle conditional control flow, block arguments, folding conditional
+// branches, or anything else fancy.
+void TestConstantFold::runOnFunction() {
+  existingConstants.clear();
+  opsToErase.clear();
+
+  auto &f = getFunction();
+
+  ConstantFoldHelper helper(&f, /*insertAtHead=*/false);
+
+  f.walk([&](Operation *op) { foldOperation(op, helper); });
+
+  // At this point, these operations are dead, remove them.
+  for (auto *op : opsToErase) {
+    assert(op->hasNoSideEffect() && "Constant folded op with side effects?");
+    op->erase();
+  }
+
+  // By the time we are done, we may have simplified a bunch of code, leaving
+  // around dead constants.  Check for them now and remove them.
+  for (auto *cst : existingConstants) {
+    if (cst->use_empty())
+      cst->erase();
+  }
+}
+
+/// Creates a constant folding pass.
+FunctionPassBase *mlir::createTestConstantFoldPass() {
+  return new TestConstantFold();
+}
+
+static PassRegistration<TestConstantFold>
+    pass("test-constant-fold", "Test operation constant folding");
diff --git a/mlir/lib/Transforms/Utils/ConstantFoldUtils.cpp b/mlir/lib/Transforms/Utils/ConstantFoldUtils.cpp
new file mode 100644 (file)
index 0000000..5908ec2
--- /dev/null
@@ -0,0 +1,163 @@
+//===- ConstantFoldUtils.cpp ---- Constant Fold Utilities -----------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines various constant fold utilities. These utilities are
+// intended to be used by passes to unify and simply their logic.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/ConstantFoldUtils.h"
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/StandardOps/Ops.h"
+
+using namespace mlir;
+
+ConstantFoldHelper::ConstantFoldHelper(Function *f, bool insertAtHead)
+    : function(f), isInsertAtHead(insertAtHead) {}
+
+bool ConstantFoldHelper::tryToConstantFold(
+    Operation *op, std::function<void(Operation *)> preReplaceAction) {
+  assert(op->getFunction() == function &&
+         "cannot constant fold op from another function");
+
+  // The constant op also implements the constant fold hook; it can be folded
+  // into the value it contains. We need to consider constants before the
+  // constant folding logic to avoid re-creating the same constant later.
+  // TODO: Extend to support dialect-specific constant ops.
+  if (auto constant = op->dyn_cast<ConstantOp>()) {
+    // If this constant is dead, update bookkeeping and signal the caller.
+    if (constant.use_empty()) {
+      notifyRemoval(op);
+      return true;
+    }
+    // Otherwise, try to see if we can de-duplicate it.
+    return tryToUnify(op);
+  }
+
+  SmallVector<Attribute, 8> operandConstants, resultConstants;
+
+  // Check to see if any operands to the operation is constant and whether
+  // the operation knows how to constant fold itself.
+  operandConstants.assign(op->getNumOperands(), Attribute());
+  for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
+    matchPattern(op->getOperand(i), m_Constant(&operandConstants[i]));
+
+  // If this is a commutative binary operation with a constant on the left
+  // side move it to the right side.
+  if (operandConstants.size() == 2 && operandConstants[0] &&
+      !operandConstants[1] && op->isCommutative()) {
+    std::swap(op->getOpOperand(0), op->getOpOperand(1));
+    std::swap(operandConstants[0], operandConstants[1]);
+  }
+
+  // Attempt to constant fold the operation.
+  if (failed(op->constantFold(operandConstants, resultConstants)))
+    return false;
+
+  // Constant folding succeeded. We will start replacing this op's uses and
+  // eventually erase this op. Invoke the callback provided by the caller to
+  // perform any pre-replacement action.
+  if (preReplaceAction)
+    preReplaceAction(op);
+
+  // Create the result constants and replace the results.
+  FuncBuilder builder(op);
+  for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
+    auto *res = op->getResult(i);
+    if (res->use_empty()) // Ignore dead uses.
+      continue;
+
+    // If we already have a canonicalized version of this constant, just reuse
+    // it.  Otherwise create a new one.
+    auto &constInst =
+        uniquedConstants[std::make_pair(resultConstants[i], res->getType())];
+    if (!constInst) {
+      // TODO: Extend to support dialect-specific constant ops.
+      auto newOp = builder.create<ConstantOp>(op->getLoc(), res->getType(),
+                                              resultConstants[i]);
+      // Register to the constant map and also move up to entry block to
+      // guarantee dominance.
+      constInst = newOp.getOperation();
+      moveConstantToEntryBlock(constInst);
+    }
+    res->replaceAllUsesWith(constInst->getResult(0));
+  }
+
+  return true;
+}
+
+void ConstantFoldHelper::notifyRemoval(Operation *op) {
+  assert(op->getFunction() == function &&
+         "cannot remove constant from another function");
+
+  Attribute constValue;
+  matchPattern(op, m_Constant(&constValue));
+  assert(constValue);
+
+  // This constant is dead. keep uniquedConstants up to date.
+  auto it = uniquedConstants.find({constValue, op->getResult(0)->getType()});
+  if (it != uniquedConstants.end() && it->second == op)
+    uniquedConstants.erase(it);
+}
+
+bool ConstantFoldHelper::tryToUnify(Operation *op) {
+  Attribute constValue;
+  matchPattern(op, m_Constant(&constValue));
+  assert(constValue);
+
+  // Check to see if we already have a constant with this type and value:
+  auto &constInst =
+      uniquedConstants[std::make_pair(constValue, op->getResult(0)->getType())];
+  if (constInst) {
+    // If this constant is already our uniqued one, then leave it alone.
+    if (constInst == op)
+      return false;
+
+    // Otherwise replace this redundant constant with the uniqued one.  We know
+    // this is safe because we move constants to the top of the function when
+    // they are uniqued, so we know they dominate all uses.
+    op->getResult(0)->replaceAllUsesWith(constInst->getResult(0));
+    return true;
+  }
+
+  // If we have no entry, then we should unique this constant as the
+  // canonical version.  To ensure safe dominance, move the operation to the
+  // entry block of the function.
+  constInst = op;
+  moveConstantToEntryBlock(op);
+  return false;
+}
+
+void ConstantFoldHelper::moveConstantToEntryBlock(Operation *op) {
+  auto &entryBB = function->front();
+  if (isInsertAtHead || entryBB.empty()) {
+    // Insert at the very top of the entry block.
+    op->moveBefore(&entryBB, entryBB.begin());
+  } else {
+    // TODO: This is only used by TestConstantFold and not very clean. We should
+    // figure out a better way to work around this.
+
+    // Move to be ahead of the first non-constant op.
+    auto it = entryBB.begin();
+    while (it != entryBB.end() && it->isa<ConstantOp>())
+      ++it;
+    op->moveBefore(&entryBB, it);
+  }
+}
index 3505ed1..dae0bf4 100644 (file)
 //===----------------------------------------------------------------------===//
 
 #include "mlir/IR/Builders.h"
-#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/StandardOps/Ops.h"
+#include "mlir/Transforms/ConstantFoldUtils.h"
 #include "llvm/ADT/DenseMap.h"
+
 using namespace mlir;
 
 namespace {
@@ -133,18 +134,14 @@ private:
   /// the function, even if they aren't the root of a pattern.
   std::vector<Operation *> worklist;
   DenseMap<Operation *, unsigned> worklistMap;
-
-  /// As part of canonicalization, we move constants to the top of the entry
-  /// block of the current function and de-duplicate them.  This keeps track of
-  /// constants we have done this for.
-  DenseMap<std::pair<Attribute, Type>, Operation *> uniquedConstants;
 };
 }; // end anonymous namespace
 
 /// Perform the rewrites.
 void GreedyPatternRewriteDriver::simplifyFunction() {
-  // These are scratch vectors used in the constant folding loop below.
-  SmallVector<Attribute, 8> operandConstants, resultConstants;
+  ConstantFoldHelper helper(builder.getFunction());
+
+  // These are scratch vectors used in the folding loop below.
   SmallVector<Value *, 8> originalOperands, resultValues;
 
   while (!worklist.empty()) {
@@ -154,99 +151,34 @@ void GreedyPatternRewriteDriver::simplifyFunction() {
     if (op == nullptr)
       continue;
 
-    // If we have a constant op, unique it into the entry block.
-    if (auto constant = op->dyn_cast<ConstantOp>()) {
-      // If this constant is dead, remove it, being careful to keep
-      // uniquedConstants up to date.
-      if (constant.use_empty()) {
-        auto it =
-            uniquedConstants.find({constant.getValue(), constant.getType()});
-        if (it != uniquedConstants.end() && it->second == op)
-          uniquedConstants.erase(it);
-        constant.erase();
-        continue;
-      }
-
-      // Check to see if we already have a constant with this type and value:
-      auto &entry = uniquedConstants[std::make_pair(constant.getValue(),
-                                                    constant.getType())];
-      if (entry) {
-        // If this constant is already our uniqued one, then leave it alone.
-        if (entry == op)
-          continue;
-
-        // Otherwise replace this redundant constant with the uniqued one.  We
-        // know this is safe because we move constants to the top of the
-        // function when they are uniqued, so we know they dominate all uses.
-        constant.replaceAllUsesWith(entry->getResult(0));
-        constant.erase();
-        continue;
-      }
-
-      // If we have no entry, then we should unique this constant as the
-      // canonical version.  To ensure safe dominance, move the operation to the
-      // top of the function.
-      entry = op;
-      auto &entryBB = builder.getInsertionBlock()->getFunction()->front();
-      op->moveBefore(&entryBB, entryBB.begin());
-      continue;
-    }
-
     // If the operation has no side effects, and no users, then it is trivially
     // dead - remove it.
     if (op->hasNoSideEffect() && op->use_empty()) {
+      // Be careful to update bookkeeping in ConstantHelper to keep consistency
+      // if this is a constant op.
+      if (op->isa<ConstantOp>())
+        helper.notifyRemoval(op);
       op->erase();
       continue;
     }
 
-    // Check to see if any operands to the operation is constant and whether
-    // the operation knows how to constant fold itself.
-    operandConstants.assign(op->getNumOperands(), Attribute());
-    for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
-      matchPattern(op->getOperand(i), m_Constant(&operandConstants[i]));
-
-    // If this is a commutative binary operation with a constant on the left
-    // side move it to the right side.
-    if (operandConstants.size() == 2 && operandConstants[0] &&
-        !operandConstants[1] && op->isCommutative()) {
-      std::swap(op->getOpOperand(0), op->getOpOperand(1));
-      std::swap(operandConstants[0], operandConstants[1]);
-    }
-
-    // If constant folding was successful, create the result constants, RAUW the
-    // operation and remove it.
-    resultConstants.clear();
-    if (succeeded(op->constantFold(operandConstants, resultConstants))) {
-      builder.setInsertionPoint(op);
-
+    // Collects all the operands and result uses of the given `op` into work
+    // list.
+    auto collectOperandsAndUses = [this](Operation *op) {
       // Add the operands to the worklist for visitation.
       addToWorklist(op->getOperands());
-
+      // Add all the users of the result to the worklist so we make sure
+      // to revisit them.
+      //
+      // TODO: Add a result->getUsers() iterator.
       for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
-        auto *res = op->getResult(i);
-        if (res->use_empty()) // ignore dead uses.
-          continue;
-
-        // If we already have a canonicalized version of this constant, just
-        // reuse it.  Otherwise create a new one.
-        Value *cstValue;
-        auto it = uniquedConstants.find({resultConstants[i], res->getType()});
-        if (it != uniquedConstants.end())
-          cstValue = it->second->getResult(0);
-        else
-          cstValue = create<ConstantOp>(op->getLoc(), res->getType(),
-                                        resultConstants[i]);
-
-        // Add all the users of the result to the worklist so we make sure to
-        // revisit them.
-        //
-        // TODO: Add a result->getUsers() iterator.
         for (auto &operand : op->getResult(i)->getUses())
           addToWorklist(operand.getOwner());
-
-        res->replaceAllUsesWith(cstValue);
       }
+    };
 
+    // Try to constant fold this op.
+    if (helper.tryToConstantFold(op, collectOperandsAndUses)) {
       assert(op->hasNoSideEffect() && "Constant folded op with side effects?");
       op->erase();
       continue;
@@ -292,8 +224,6 @@ void GreedyPatternRewriteDriver::simplifyFunction() {
     // to do here.
     matcher.matchAndRewrite(op);
   }
-
-  uniquedConstants.clear();
 }
 
 /// Rewrite the specified function by repeatedly applying the highest benefit
index b40daa1..257323b 100644 (file)
@@ -1,13 +1,13 @@
-// RUN: mlir-opt %s -constant-fold | FileCheck %s
+// RUN: mlir-opt %s -test-constant-fold | FileCheck %s
 
 // CHECK-LABEL: @test(%arg0: memref<f32>) {
 func @test(%p : memref<f32>) {
+  // CHECK: %cst = constant 6.000000e+00 : f32
   affine.for %i0 = 0 to 128 {
     affine.for %i1 = 0 to 8 { // CHECK: affine.for %i1 = 0 to 8 {
       %0 = constant 4.5 : f32
       %1 = constant 1.5 : f32
 
-      // CHECK-NEXT: %cst = constant 6.000000e+00 : f32
       %2 = addf %0, %1 : f32
 
       // CHECK-NEXT: store %cst, %arg0[]
@@ -201,12 +201,11 @@ func @simple_remis(%a : i32) -> (i32, i32, i32) {
 
   // CHECK-NEXT: %c1_i32 = constant 1 : i32
   %4 = remis %0, %1 : i32
-  // CHECK-NEXT: %c1_i32_0 = constant 1 : i32
   %5 = remis %0, %3 : i32
   // CHECK-NEXT: %c0_i32 = constant 0 : i32
   %6 = remis %a, %2 : i32
 
-  // CHECK-NEXT: return %c1_i32, %c1_i32_0, %c0_i32 : i32, i32, i32
+  // CHECK-NEXT: return %c1_i32, %c1_i32, %c0_i32 : i32, i32, i32
   return %4, %5, %6 : i32, i32, i32
 }
 
@@ -217,11 +216,11 @@ func @simple_remiu(%a : i32) -> (i32, i32, i32) {
   %2 = constant 1 : i32
   %3 = constant -2 : i32
 
-  // CHECK-NEXT: %c1_i32 = constant 1 : i32
+  // CHECK-DAG: %c1_i32 = constant 1 : i32
   %4 = remiu %0, %1 : i32
-  // CHECK-NEXT: %c5_i32 = constant 5 : i32
+  // CHECK-DAG: %c5_i32 = constant 5 : i32
   %5 = remiu %0, %3 : i32
-  // CHECK-NEXT: %c0_i32 = constant 0 : i32
+  // CHECK-DAG: %c0_i32 = constant 0 : i32
   %6 = remiu %a, %2 : i32
 
   // CHECK-NEXT: return %c1_i32, %c5_i32, %c0_i32 : i32, i32, i32
@@ -267,24 +266,26 @@ func @cmpi() -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1) {
   %c42 = constant 42 : i32
   %cm1 = constant -1 : i32
 // CHECK-NEXT: %false = constant 0 : i1
-  %0 = cmpi "eq", %c42, %cm1 : i32
 // CHECK-NEXT: %true = constant 1 : i1
+// CHECK-NEXT: return %false,
+  %0 = cmpi "eq", %c42, %cm1 : i32
+// CHECK-SAME: %true,
   %1 = cmpi "ne", %c42, %cm1 : i32
-// CHECK-NEXT: %false_0 = constant 0 : i1
+// CHECK-SAME: %false,
   %2 = cmpi "slt", %c42, %cm1 : i32
-// CHECK-NEXT: %false_1 = constant 0 : i1
+// CHECK-SAME: %false,
   %3 = cmpi "sle", %c42, %cm1 : i32
-// CHECK-NEXT: %true_2 = constant 1 : i1
+// CHECK-SAME: %true,
   %4 = cmpi "sgt", %c42, %cm1 : i32
-// CHECK-NEXT: %true_3 = constant 1 : i1
+// CHECK-SAME: %true,
   %5 = cmpi "sge", %c42, %cm1 : i32
-// CHECK-NEXT: %true_4 = constant 1 : i1
+// CHECK-SAME: %true,
   %6 = cmpi "ult", %c42, %cm1 : i32
-// CHECK-NEXT: %true_5 = constant 1 : i1
+// CHECK-SAME: %true,
   %7 = cmpi "ule", %c42, %cm1 : i32
-// CHECK-NEXT: %false_6 = constant 0 : i1
+// CHECK-SAME: %false,
   %8 = cmpi "ugt", %c42, %cm1 : i32
-// CHECK-NEXT: %false_7 = constant 0 : i1
+// CHECK-SAME: %false
   %9 = cmpi "uge", %c42, %cm1 : i32
   return %0, %1, %2, %3, %4, %5, %6, %7, %8, %9 : i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
 }