#include "mlir/IR/OperationSupport.h"
namespace mlir {
+class OpBuilder;
class Type;
using DialectConstantDecodeHook =
///
class Dialect {
public:
+ virtual ~Dialect();
+
+ /// Utility function that returns if the given string is a valid dialect
+ /// namespace.
+ static bool isValidNamespace(StringRef str);
+
MLIRContext *getContext() const { return context; }
StringRef getNamespace() const { return name; }
/// addOperation.
bool allowsUnknownOperations() const { return allowUnknownOps; }
+ //===--------------------------------------------------------------------===//
+ // Constant Hooks
+ //===--------------------------------------------------------------------===//
+
/// Registered fallback constant fold hook for the dialect. Like the constant
/// fold hook of each operation, it attempts to constant fold the operation
/// with the specified constant operand values - the elements in "operands"
return Attribute();
};
+ /// Registered hook to materialize a single constant operation from a given
+ /// attribute value with the desired resultant type. This method should use
+ /// the provided builder to create the operation without changing the
+ /// insertion position. The generated operation is expected to be constant
+ /// like, i.e. single result, zero operands, non side-effecting, etc. On
+ /// success, this hook should return the value generated to represent the
+ /// constant value. Otherwise, it should return null on failure.
+ virtual Operation *materializeConstant(OpBuilder &builder, Attribute value,
+ Type type, Location loc) {
+ return nullptr;
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Parsing Hooks
+ //===--------------------------------------------------------------------===//
+
/// Parse an attribute registered to this dialect.
virtual Attribute parseAttribute(StringRef attrData, Location loc) const;
virtual void
getTypeAliases(SmallVectorImpl<std::pair<Type, StringRef>> &aliases) {}
+ //===--------------------------------------------------------------------===//
+ // Verification Hooks
+ //===--------------------------------------------------------------------===//
+
/// Verify an attribute from this dialect on the given function. Returns
/// failure if the verification failed, success otherwise.
virtual LogicalResult verifyFunctionAttribute(Function *, NamedAttribute) {
return success();
}
- virtual ~Dialect();
-
- /// Utility function that returns if the given string is a valid dialect
- /// namespace.
- static bool isValidNamespace(StringRef str);
-
protected:
/// The constructor takes a unique namespace for this dialect as well as the
/// context to bind to.
#define MLIR_TRANSFORMS_FOLDUTILS_H
#include "mlir/IR/Builders.h"
+#include "mlir/IR/Dialect.h"
namespace mlir {
class Function;
OperationFolder(Function *f) : function(f) {}
/// Tries to perform folding on the given `op`, including unifying
- /// deduplicated constants. If successful, calls `preReplaceAction` (if
- /// provided) by passing in `op`, then replaces `op`'s uses with folded
- /// results, and returns success. If the op was completely folded it is
+ /// deduplicated constants. If successful, replaces `op`'s uses with
+ /// folded results, and returns success. `preReplaceAction` is invoked on `op`
+ /// before it is replaced. 'processGeneratedConstants' is invoked for any new
+ /// operations generated when folding. If the op was completely folded it is
/// erased.
- LogicalResult
- tryToFold(Operation *op,
- std::function<void(Operation *)> preReplaceAction = {});
+ LogicalResult tryToFold(
+ Operation *op,
+ llvm::function_ref<void(Operation *)> processGeneratedConstants = nullptr,
+ llvm::function_ref<void(Operation *)> preReplaceAction = nullptr);
/// Notifies that the given constant `op` should be remove from this
/// OperationFolder's internal bookkeeping.
private:
/// Tries to perform folding on the given `op`. If successful, populates
/// `results` with the results of the folding.
- LogicalResult tryToFold(Operation *op, SmallVectorImpl<Value *> &results);
+ LogicalResult tryToFold(Operation *op, SmallVectorImpl<Value *> &results,
+ llvm::function_ref<void(Operation *)>
+ processGeneratedConstants = nullptr);
- /// Tries to deduplicate the given constant and returns success if that can be
- /// done. This moves the given constant to the top of the entry block if it
- /// is first seen. If there is already an existing constant that is the same,
- /// this does *not* erases the given constant.
- LogicalResult tryToUnify(Operation *op);
-
- /// Moves the given constant `op` to entry block to guarantee dominance.
- void moveConstantToEntryBlock(Operation *op);
+ /// Try to get or create a new constant entry. On success this returns the
+ /// constant operation, nullptr otherwise.
+ Operation *tryGetOrCreateConstant(Dialect *dialect, OpBuilder &builder,
+ Attribute value, Type type, Location loc);
/// The function where we are managing constant.
Function *function;
- /// This map keeps track of uniqued constants.
- DenseMap<std::pair<Attribute, Type>, Operation *> uniquedConstants;
+ /// This map keeps track of uniqued constants by dialect, attribute, and type.
+ /// A constant operation materializes an attribute with a type. Dialects may
+ /// generate different constants with the same input attribute and type, so we
+ /// also need to track per-dialect.
+ DenseMap<std::tuple<Dialect *, Attribute, Type>, Operation *>
+ uniquedConstants;
+
+ /// This map tracks all of the dialects that an operation is referenced by;
+ /// given that many dialects may generate the same constant.
+ DenseMap<Operation *, SmallVector<Dialect *, 2>> referencedDialects;
};
} // end namespace mlir
// OperationFolder
//===----------------------------------------------------------------------===//
-LogicalResult
-OperationFolder::tryToFold(Operation *op,
- std::function<void(Operation *)> preReplaceAction) {
+LogicalResult OperationFolder::tryToFold(
+ Operation *op,
+ llvm::function_ref<void(Operation *)> processGeneratedConstants,
+ llvm::function_ref<void(Operation *)> preReplaceAction) {
assert(op->getFunction() == function &&
"cannot constant fold op from another function");
- // The constant op also implements the constant fold hook; it can be folded
- // into the value it contains. We need to consider constants before the
- // constant folding logic to avoid re-creating the same constant later.
- // TODO: Extend to support dialect-specific constant ops.
- if (auto constant = dyn_cast<ConstantOp>(op)) {
- // If this constant is dead, update bookkeeping and signal the caller.
- if (constant.use_empty()) {
- notifyRemoval(op);
- op->erase();
- return success();
- }
- // Otherwise, try to see if we can de-duplicate it.
- return tryToUnify(op);
- }
+ // If this is a unique'd constant, return failure as we know that it has
+ // already been folded.
+ if (referencedDialects.count(op))
+ return failure();
// Try to fold the operation.
SmallVector<Value *, 8> results;
- if (failed(tryToFold(op, results)))
+ if (failed(tryToFold(op, results, processGeneratedConstants)))
return failure();
// Constant folding succeeded. We will start replacing this op's uses and
return success();
}
+/// Notifies that the given constant `op` should be remove from this
+/// OperationFolder's internal bookkeeping.
+void OperationFolder::notifyRemoval(Operation *op) {
+ assert(op->getFunction() == function &&
+ "cannot remove constant from another function");
+
+ // Check to see if this operation is uniqued within the folder.
+ auto it = referencedDialects.find(op);
+ if (it == referencedDialects.end())
+ return;
+
+ // Get the constant value for this operation, this is the value that was used
+ // to unique the operation internally.
+ Attribute constValue;
+ matchPattern(op, m_Constant(&constValue));
+ assert(constValue);
+
+ // Erase all of the references to this operation.
+ auto type = op->getResult(0)->getType();
+ for (auto *dialect : it->second)
+ uniquedConstants.erase(std::make_tuple(dialect, constValue, type));
+ referencedDialects.erase(it);
+}
+
+/// A utility function used to materialize a constant for a given attribute and
+/// type. On success, a valid constant value is returned. Otherwise, null is
+/// returned
+static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
+ Attribute value, Type type,
+ Location loc) {
+ auto insertPt = builder.getInsertionPoint();
+ (void)insertPt;
+
+ // Ask the dialect to materialize a constant operation for this value.
+ if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) {
+ assert(insertPt == builder.getInsertionPoint());
+ assert(matchPattern(constOp, m_Constant(&value)));
+ return constOp;
+ }
+
+ // If the dialect is unable to materialize a constant, check to see if the
+ // standard constant can be used.
+ if (ConstantOp::isBuildableWith(value, type))
+ return builder.create<ConstantOp>(loc, type, value);
+ return nullptr;
+}
+
/// Tries to perform folding on the given `op`. If successful, populates
/// `results` with the results of the folding.
-LogicalResult OperationFolder::tryToFold(Operation *op,
- SmallVectorImpl<Value *> &results) {
+LogicalResult OperationFolder::tryToFold(
+ Operation *op, SmallVectorImpl<Value *> &results,
+ llvm::function_ref<void(Operation *)> processGeneratedConstants) {
assert(op->getFunction() == function &&
"cannot constant fold op from another function");
return success();
assert(foldResults.size() == op->getNumResults());
+ // Create a builder to insert new operations into the entry block.
+ auto &entry = function->getBody().front();
+ OpBuilder builder(&entry, entry.empty() ? entry.end() : entry.begin());
+
// Create the result constants and replace the results.
- OpBuilder builder(op);
+ auto *dialect = op->getDialect();
for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
assert(!foldResults[i].isNull() && "expected valid OpFoldResult");
continue;
}
- // If we already have a canonicalized version of this constant, just reuse
- // it. Otherwise create a new one.
- Attribute attrRepl = foldResults[i].get<Attribute>();
+ // Check to see if there is a canonicalized version of this constant.
auto *res = op->getResult(i);
- auto &constInst =
- uniquedConstants[std::make_pair(attrRepl, res->getType())];
- if (!constInst) {
- // TODO: Extend to support dialect-specific constant ops.
- auto newOp =
- builder.create<ConstantOp>(op->getLoc(), res->getType(), attrRepl);
- // Register to the constant map and also move up to entry block to
- // guarantee dominance.
- constInst = newOp.getOperation();
- moveConstantToEntryBlock(constInst);
+ Attribute attrRepl = foldResults[i].get<Attribute>();
+ if (auto *constOp = tryGetOrCreateConstant(dialect, builder, attrRepl,
+ res->getType(), op->getLoc())) {
+ results.push_back(constOp->getResult(0));
+ continue;
+ }
+ // If materialization fails, cleanup any operations generated for the
+ // previous results and return failure.
+ for (Operation &op : llvm::make_early_inc_range(
+ llvm::make_range(entry.begin(), builder.getInsertionPoint()))) {
+ notifyRemoval(&op);
+ op.erase();
}
- results.push_back(constInst->getResult(0));
+ return failure();
+ }
+
+ // Process any newly generated operations.
+ if (processGeneratedConstants) {
+ for (auto i = entry.begin(), e = builder.getInsertionPoint(); i != e; ++i)
+ processGeneratedConstants(&*i);
}
return success();
}
-void OperationFolder::notifyRemoval(Operation *op) {
- assert(op->getFunction() == function &&
- "cannot remove constant from another function");
-
- Attribute constValue;
- if (!matchPattern(op, m_Constant(&constValue)))
- return;
-
- // This constant is dead. keep uniquedConstants up to date.
- auto it = uniquedConstants.find({constValue, op->getResult(0)->getType()});
- if (it != uniquedConstants.end() && it->second == op)
- uniquedConstants.erase(it);
-}
+/// Try to get or create a new constant entry. On success this returns the
+/// constant operation value, nullptr otherwise.
+Operation *OperationFolder::tryGetOrCreateConstant(Dialect *dialect,
+ OpBuilder &builder,
+ Attribute value, Type type,
+ Location loc) {
+ // Check if an existing mapping already exists.
+ auto constKey = std::make_tuple(dialect, value, type);
+ auto *&constInst = uniquedConstants[constKey];
+ if (constInst)
+ return constInst;
+
+ // If one doesn't exist, try to materialize one.
+ if (!(constInst = materializeConstant(dialect, builder, value, type, loc)))
+ return nullptr;
+
+ // Check to see if the generated constant is in the expected dialect.
+ auto *newDialect = constInst->getDialect();
+ if (newDialect == dialect) {
+ referencedDialects[constInst].push_back(dialect);
+ return constInst;
+ }
-LogicalResult OperationFolder::tryToUnify(Operation *op) {
- Attribute constValue;
- matchPattern(op, m_Constant(&constValue));
- assert(constValue);
+ // If it isn't, then we also need to make sure that the mapping for the new
+ // dialect is valid.
+ auto newKey = std::make_tuple(newDialect, value, type);
- // Check to see if we already have a constant with this type and value:
- auto &constInst =
- uniquedConstants[std::make_pair(constValue, op->getResult(0)->getType())];
- if (constInst) {
- // If this constant is already our uniqued one, then leave it alone.
- if (constInst == op)
- return failure();
-
- // Otherwise replace this redundant constant with the uniqued one. We know
- // this is safe because we move constants to the top of the function when
- // they are uniqued, so we know they dominate all uses.
- op->getResult(0)->replaceAllUsesWith(constInst->getResult(0));
- op->erase();
- return success();
+ // If an existing operation in the new dialect already exists, delete the
+ // materialized operation in favor of the existing one.
+ if (auto *existingOp = uniquedConstants.lookup(newKey)) {
+ constInst->erase();
+ referencedDialects[existingOp].push_back(dialect);
+ return constInst = existingOp;
}
- // If we have no entry, then we should unique this constant as the
- // canonical version. To ensure safe dominance, move the operation to the
- // entry block of the function.
- constInst = op;
- moveConstantToEntryBlock(op);
- return failure();
-}
-
-void OperationFolder::moveConstantToEntryBlock(Operation *op) {
- // Insert at the very top of the entry block.
- auto &entryBB = function->front();
- op->moveBefore(&entryBB, entryBB.begin());
+ // Otherwise, update the new dialect to the materialized operation.
+ referencedDialects[constInst].assign({dialect, newDialect});
+ auto newIt = uniquedConstants.insert({newKey, constInst});
+ return newIt.first->second;
}