Add a new dialect hook 'materializeConstant' to create a constant operation that...
authorRiver Riddle <riverriddle@google.com>
Sat, 22 Jun 2019 18:48:43 +0000 (11:48 -0700)
committerjpienaar <jpienaar@google.com>
Sat, 22 Jun 2019 20:05:27 +0000 (13:05 -0700)
PiperOrigin-RevId: 254570153

mlir/include/mlir/IR/Dialect.h
mlir/include/mlir/StandardOps/Ops.td
mlir/include/mlir/Transforms/FoldUtils.h
mlir/lib/StandardOps/Ops.cpp
mlir/lib/Transforms/TestConstantFold.cpp
mlir/lib/Transforms/Utils/FoldUtils.cpp
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

index 67ba2ea..56d0661 100644 (file)
@@ -25,6 +25,7 @@
 #include "mlir/IR/OperationSupport.h"
 
 namespace mlir {
+class OpBuilder;
 class Type;
 
 using DialectConstantDecodeHook =
@@ -43,6 +44,12 @@ using DialectExtractElementHook =
 ///
 class Dialect {
 public:
+  virtual ~Dialect();
+
+  /// Utility function that returns if the given string is a valid dialect
+  /// namespace.
+  static bool isValidNamespace(StringRef str);
+
   MLIRContext *getContext() const { return context; }
 
   StringRef getNamespace() const { return name; }
@@ -52,6 +59,10 @@ public:
   /// addOperation.
   bool allowsUnknownOperations() const { return allowUnknownOps; }
 
+  //===--------------------------------------------------------------------===//
+  // Constant Hooks
+  //===--------------------------------------------------------------------===//
+
   /// Registered fallback constant fold hook for the dialect. Like the constant
   /// fold hook of each operation, it attempts to constant fold the operation
   /// with the specified constant operand values - the elements in "operands"
@@ -80,6 +91,22 @@ public:
         return Attribute();
       };
 
+  /// Registered hook to materialize a single constant operation from a given
+  /// attribute value with the desired resultant type. This method should use
+  /// the provided builder to create the operation without changing the
+  /// insertion position. The generated operation is expected to be constant
+  /// like, i.e. single result, zero operands, non side-effecting, etc. On
+  /// success, this hook should return the value generated to represent the
+  /// constant value. Otherwise, it should return null on failure.
+  virtual Operation *materializeConstant(OpBuilder &builder, Attribute value,
+                                         Type type, Location loc) {
+    return nullptr;
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Parsing Hooks
+  //===--------------------------------------------------------------------===//
+
   /// Parse an attribute registered to this dialect.
   virtual Attribute parseAttribute(StringRef attrData, Location loc) const;
 
@@ -112,6 +139,10 @@ public:
   virtual void
   getTypeAliases(SmallVectorImpl<std::pair<Type, StringRef>> &aliases) {}
 
+  //===--------------------------------------------------------------------===//
+  // Verification Hooks
+  //===--------------------------------------------------------------------===//
+
   /// Verify an attribute from this dialect on the given function. Returns
   /// failure if the verification failed, success otherwise.
   virtual LogicalResult verifyFunctionAttribute(Function *, NamedAttribute) {
@@ -132,12 +163,6 @@ public:
     return success();
   }
 
-  virtual ~Dialect();
-
-  /// Utility function that returns if the given string is a valid dialect
-  /// namespace.
-  static bool isValidNamespace(StringRef str);
-
 protected:
   /// The constructor takes a unique namespace for this dialect as well as the
   /// context to bind to.
index a7fc094..1b14e2a 100644 (file)
@@ -499,6 +499,10 @@ def ConstantOp : Std_Op<"constant", [NoSideEffect]> {
 
   let extraClassDeclaration = [{
     Attribute getValue() { return getAttr("value"); }
+
+    /// Returns true if a constant operation can be built with the given value
+    /// and result type.
+    static bool isBuildableWith(Attribute value, Type type);
   }];
 
   let hasFolder = 1;
index 1882ae2..7c50c04 100644 (file)
@@ -24,6 +24,7 @@
 #define MLIR_TRANSFORMS_FOLDUTILS_H
 
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/Dialect.h"
 
 namespace mlir {
 class Function;
@@ -48,13 +49,15 @@ public:
   OperationFolder(Function *f) : function(f) {}
 
   /// Tries to perform folding on the given `op`, including unifying
-  /// deduplicated constants. If successful, calls `preReplaceAction` (if
-  /// provided) by passing in `op`, then replaces `op`'s uses with folded
-  /// results, and returns success. If the op was completely folded it is
+  /// deduplicated constants. If successful, replaces `op`'s uses with
+  /// folded results, and returns success. `preReplaceAction` is invoked on `op`
+  /// before it is replaced. 'processGeneratedConstants' is invoked for any new
+  /// operations generated when folding. If the op was completely folded it is
   /// erased.
-  LogicalResult
-  tryToFold(Operation *op,
-            std::function<void(Operation *)> preReplaceAction = {});
+  LogicalResult tryToFold(
+      Operation *op,
+      llvm::function_ref<void(Operation *)> processGeneratedConstants = nullptr,
+      llvm::function_ref<void(Operation *)> preReplaceAction = nullptr);
 
   /// Notifies that the given constant `op` should be remove from this
   /// OperationFolder's internal bookkeeping.
@@ -103,22 +106,28 @@ public:
 private:
   /// Tries to perform folding on the given `op`. If successful, populates
   /// `results` with the results of the folding.
-  LogicalResult tryToFold(Operation *op, SmallVectorImpl<Value *> &results);
+  LogicalResult tryToFold(Operation *op, SmallVectorImpl<Value *> &results,
+                          llvm::function_ref<void(Operation *)>
+                              processGeneratedConstants = nullptr);
 
-  /// Tries to deduplicate the given constant and returns success if that can be
-  /// done. This moves the given constant to the top of the entry block if it
-  /// is first seen. If there is already an existing constant that is the same,
-  /// this does *not* erases the given constant.
-  LogicalResult tryToUnify(Operation *op);
-
-  /// Moves the given constant `op` to entry block to guarantee dominance.
-  void moveConstantToEntryBlock(Operation *op);
+  /// Try to get or create a new constant entry. On success this returns the
+  /// constant operation, nullptr otherwise.
+  Operation *tryGetOrCreateConstant(Dialect *dialect, OpBuilder &builder,
+                                    Attribute value, Type type, Location loc);
 
   /// The function where we are managing constant.
   Function *function;
 
-  /// This map keeps track of uniqued constants.
-  DenseMap<std::pair<Attribute, Type>, Operation *> uniquedConstants;
+  /// This map keeps track of uniqued constants by dialect, attribute, and type.
+  /// A constant operation materializes an attribute with a type. Dialects may
+  /// generate different constants with the same input attribute and type, so we
+  /// also need to track per-dialect.
+  DenseMap<std::tuple<Dialect *, Attribute, Type>, Operation *>
+      uniquedConstants;
+
+  /// This map tracks all of the dialects that an operation is referenced by;
+  /// given that many dialects may generate the same constant.
+  DenseMap<Operation *, SmallVector<Dialect *, 2>> referencedDialects;
 };
 
 } // end namespace mlir
index 9a4a3f2..202896b 100644 (file)
@@ -1131,6 +1131,20 @@ OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
   return getValue();
 }
 
+/// Returns true if a constant operation can be built with the given value and
+/// result type.
+bool ConstantOp::isBuildableWith(Attribute value, Type type) {
+  // FunctionAttr can only be used with a function type.
+  if (value.isa<FunctionAttr>())
+    return type.isa<FunctionType>();
+  // Otherwise, the attribute must have the same type as 'type'.
+  if (value.getType() != type)
+    return false;
+  // Finally, check that the attribute kind is handled.
+  return value.isa<IntegerAttr>() || value.isa<FloatAttr>() ||
+         value.isa<ElementsAttr>() || value.isa<UnitAttr>();
+}
+
 void ConstantFloatOp::build(Builder *builder, OperationState *result,
                             const APFloat &value, FloatType type) {
   ConstantOp::build(builder, result, type, builder->getFloatAttr(type, value));
index f360441..dd87d98 100644 (file)
@@ -38,15 +38,13 @@ struct TestConstantFold : public FunctionPass<TestConstantFold> {
 } // end anonymous namespace
 
 void TestConstantFold::foldOperation(Operation *op, OperationFolder &helper) {
+  auto processGeneratedConstants = [this](Operation *op) {
+    existingConstants.push_back(op);
+  };
+
   // Attempt to fold the specified operation, including handling unused or
   // duplicated constants.
-  if (succeeded(helper.tryToFold(op)))
-    return;
-
-  // If this op is a constant that are used and cannot be de-duplicated,
-  // remember it for cleanup later.
-  if (auto constant = dyn_cast<ConstantOp>(op))
-    existingConstants.push_back(op);
+  (void)helper.tryToFold(op, processGeneratedConstants);
 }
 
 // For now, we do a simple top-down pass over a function folding constants.  We
index e06756d..e25215d 100644 (file)
@@ -33,30 +33,21 @@ using namespace mlir;
 // OperationFolder
 //===----------------------------------------------------------------------===//
 
-LogicalResult
-OperationFolder::tryToFold(Operation *op,
-                           std::function<void(Operation *)> preReplaceAction) {
+LogicalResult OperationFolder::tryToFold(
+    Operation *op,
+    llvm::function_ref<void(Operation *)> processGeneratedConstants,
+    llvm::function_ref<void(Operation *)> preReplaceAction) {
   assert(op->getFunction() == function &&
          "cannot constant fold op from another function");
 
-  // The constant op also implements the constant fold hook; it can be folded
-  // into the value it contains. We need to consider constants before the
-  // constant folding logic to avoid re-creating the same constant later.
-  // TODO: Extend to support dialect-specific constant ops.
-  if (auto constant = dyn_cast<ConstantOp>(op)) {
-    // If this constant is dead, update bookkeeping and signal the caller.
-    if (constant.use_empty()) {
-      notifyRemoval(op);
-      op->erase();
-      return success();
-    }
-    // Otherwise, try to see if we can de-duplicate it.
-    return tryToUnify(op);
-  }
+  // If this is a unique'd constant, return failure as we know that it has
+  // already been folded.
+  if (referencedDialects.count(op))
+    return failure();
 
   // Try to fold the operation.
   SmallVector<Value *, 8> results;
-  if (failed(tryToFold(op, results)))
+  if (failed(tryToFold(op, results, processGeneratedConstants)))
     return failure();
 
   // Constant folding succeeded. We will start replacing this op's uses and
@@ -76,10 +67,58 @@ OperationFolder::tryToFold(Operation *op,
   return success();
 }
 
+/// Notifies that the given constant `op` should be remove from this
+/// OperationFolder's internal bookkeeping.
+void OperationFolder::notifyRemoval(Operation *op) {
+  assert(op->getFunction() == function &&
+         "cannot remove constant from another function");
+
+  // Check to see if this operation is uniqued within the folder.
+  auto it = referencedDialects.find(op);
+  if (it == referencedDialects.end())
+    return;
+
+  // Get the constant value for this operation, this is the value that was used
+  // to unique the operation internally.
+  Attribute constValue;
+  matchPattern(op, m_Constant(&constValue));
+  assert(constValue);
+
+  // Erase all of the references to this operation.
+  auto type = op->getResult(0)->getType();
+  for (auto *dialect : it->second)
+    uniquedConstants.erase(std::make_tuple(dialect, constValue, type));
+  referencedDialects.erase(it);
+}
+
+/// A utility function used to materialize a constant for a given attribute and
+/// type. On success, a valid constant value is returned. Otherwise, null is
+/// returned
+static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
+                                      Attribute value, Type type,
+                                      Location loc) {
+  auto insertPt = builder.getInsertionPoint();
+  (void)insertPt;
+
+  // Ask the dialect to materialize a constant operation for this value.
+  if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) {
+    assert(insertPt == builder.getInsertionPoint());
+    assert(matchPattern(constOp, m_Constant(&value)));
+    return constOp;
+  }
+
+  // If the dialect is unable to materialize a constant, check to see if the
+  // standard constant can be used.
+  if (ConstantOp::isBuildableWith(value, type))
+    return builder.create<ConstantOp>(loc, type, value);
+  return nullptr;
+}
+
 /// Tries to perform folding on the given `op`. If successful, populates
 /// `results` with the results of the folding.
-LogicalResult OperationFolder::tryToFold(Operation *op,
-                                         SmallVectorImpl<Value *> &results) {
+LogicalResult OperationFolder::tryToFold(
+    Operation *op, SmallVectorImpl<Value *> &results,
+    llvm::function_ref<void(Operation *)> processGeneratedConstants) {
   assert(op->getFunction() == function &&
          "cannot constant fold op from another function");
 
@@ -109,8 +148,12 @@ LogicalResult OperationFolder::tryToFold(Operation *op,
     return success();
   assert(foldResults.size() == op->getNumResults());
 
+  // Create a builder to insert new operations into the entry block.
+  auto &entry = function->getBody().front();
+  OpBuilder builder(&entry, entry.empty() ? entry.end() : entry.begin());
+
   // Create the result constants and replace the results.
-  OpBuilder builder(op);
+  auto *dialect = op->getDialect();
   for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
     assert(!foldResults[i].isNull() && "expected valid OpFoldResult");
 
@@ -120,72 +163,70 @@ LogicalResult OperationFolder::tryToFold(Operation *op,
       continue;
     }
 
-    // If we already have a canonicalized version of this constant, just reuse
-    // it. Otherwise create a new one.
-    Attribute attrRepl = foldResults[i].get<Attribute>();
+    // Check to see if there is a canonicalized version of this constant.
     auto *res = op->getResult(i);
-    auto &constInst =
-        uniquedConstants[std::make_pair(attrRepl, res->getType())];
-    if (!constInst) {
-      // TODO: Extend to support dialect-specific constant ops.
-      auto newOp =
-          builder.create<ConstantOp>(op->getLoc(), res->getType(), attrRepl);
-      // Register to the constant map and also move up to entry block to
-      // guarantee dominance.
-      constInst = newOp.getOperation();
-      moveConstantToEntryBlock(constInst);
+    Attribute attrRepl = foldResults[i].get<Attribute>();
+    if (auto *constOp = tryGetOrCreateConstant(dialect, builder, attrRepl,
+                                               res->getType(), op->getLoc())) {
+      results.push_back(constOp->getResult(0));
+      continue;
+    }
+    // If materialization fails, cleanup any operations generated for the
+    // previous results and return failure.
+    for (Operation &op : llvm::make_early_inc_range(
+             llvm::make_range(entry.begin(), builder.getInsertionPoint()))) {
+      notifyRemoval(&op);
+      op.erase();
     }
-    results.push_back(constInst->getResult(0));
+    return failure();
+  }
+
+  // Process any newly generated operations.
+  if (processGeneratedConstants) {
+    for (auto i = entry.begin(), e = builder.getInsertionPoint(); i != e; ++i)
+      processGeneratedConstants(&*i);
   }
 
   return success();
 }
 
-void OperationFolder::notifyRemoval(Operation *op) {
-  assert(op->getFunction() == function &&
-         "cannot remove constant from another function");
-
-  Attribute constValue;
-  if (!matchPattern(op, m_Constant(&constValue)))
-    return;
-
-  // This constant is dead. keep uniquedConstants up to date.
-  auto it = uniquedConstants.find({constValue, op->getResult(0)->getType()});
-  if (it != uniquedConstants.end() && it->second == op)
-    uniquedConstants.erase(it);
-}
+/// Try to get or create a new constant entry. On success this returns the
+/// constant operation value, nullptr otherwise.
+Operation *OperationFolder::tryGetOrCreateConstant(Dialect *dialect,
+                                                   OpBuilder &builder,
+                                                   Attribute value, Type type,
+                                                   Location loc) {
+  // Check if an existing mapping already exists.
+  auto constKey = std::make_tuple(dialect, value, type);
+  auto *&constInst = uniquedConstants[constKey];
+  if (constInst)
+    return constInst;
+
+  // If one doesn't exist, try to materialize one.
+  if (!(constInst = materializeConstant(dialect, builder, value, type, loc)))
+    return nullptr;
+
+  // Check to see if the generated constant is in the expected dialect.
+  auto *newDialect = constInst->getDialect();
+  if (newDialect == dialect) {
+    referencedDialects[constInst].push_back(dialect);
+    return constInst;
+  }
 
-LogicalResult OperationFolder::tryToUnify(Operation *op) {
-  Attribute constValue;
-  matchPattern(op, m_Constant(&constValue));
-  assert(constValue);
+  // If it isn't, then we also need to make sure that the mapping for the new
+  // dialect is valid.
+  auto newKey = std::make_tuple(newDialect, value, type);
 
-  // Check to see if we already have a constant with this type and value:
-  auto &constInst =
-      uniquedConstants[std::make_pair(constValue, op->getResult(0)->getType())];
-  if (constInst) {
-    // If this constant is already our uniqued one, then leave it alone.
-    if (constInst == op)
-      return failure();
-
-    // Otherwise replace this redundant constant with the uniqued one.  We know
-    // this is safe because we move constants to the top of the function when
-    // they are uniqued, so we know they dominate all uses.
-    op->getResult(0)->replaceAllUsesWith(constInst->getResult(0));
-    op->erase();
-    return success();
+  // If an existing operation in the new dialect already exists, delete the
+  // materialized operation in favor of the existing one.
+  if (auto *existingOp = uniquedConstants.lookup(newKey)) {
+    constInst->erase();
+    referencedDialects[existingOp].push_back(dialect);
+    return constInst = existingOp;
   }
 
-  // If we have no entry, then we should unique this constant as the
-  // canonical version.  To ensure safe dominance, move the operation to the
-  // entry block of the function.
-  constInst = op;
-  moveConstantToEntryBlock(op);
-  return failure();
-}
-
-void OperationFolder::moveConstantToEntryBlock(Operation *op) {
-  // Insert at the very top of the entry block.
-  auto &entryBB = function->front();
-  op->moveBefore(&entryBB, entryBB.begin());
+  // Otherwise, update the new dialect to the materialized operation.
+  referencedDialects[constInst].assign({dialect, newDialect});
+  auto newIt = uniquedConstants.insert({newKey, constInst});
+  return newIt.first->second;
 }
index 0cd3225..30a256d 100644 (file)
@@ -147,11 +147,14 @@ bool GreedyPatternRewriteDriver::simplifyFunction(int maxIterations) {
   // TODO(riverriddle) OperationFolder should take a region to insert into.
   OperationFolder helper(region->getContainingFunction());
 
+  // Add the given operation to the worklist.
+  auto collectOps = [this](Operation *op) { addToWorklist(op); };
+
   bool changed = false;
   int i = 0;
   do {
     // Add all operations to the worklist.
-    region->walk([&](Operation *op) { addToWorklist(op); });
+    region->walk(collectOps);
 
     // These are scratch vectors used in the folding loop below.
     SmallVector<Value *, 8> originalOperands, resultValues;
@@ -190,7 +193,7 @@ bool GreedyPatternRewriteDriver::simplifyFunction(int maxIterations) {
       };
 
       // Try to fold this op.
-      if (succeeded(helper.tryToFold(op, collectOperandsAndUses))) {
+      if (succeeded(helper.tryToFold(op, collectOps, collectOperandsAndUses))) {
         changed |= true;
         continue;
       }