namespace mlir {
+using DialectConstantFoldHook = std::function<bool(
+ const Operation *, ArrayRef<Attribute>, SmallVectorImpl<Attribute> &)>;
+
/// Dialects are groups of MLIR operations and behavior associated with the
/// entire group. For example, hooks into other systems for constant folding,
/// default named types for asm printing, etc.
StringRef getOperationPrefix() const { return opPrefix; }
- /// Dialect implementations can implement this hook. It should attempt to
- /// constant fold this operation with the specified constant operand values -
- /// the elements in "operands" will correspond directly to the operands of the
- /// operation, but may be null if non-constant. If constant folding is
- /// successful, this returns false and fills in the `results` vector. If not,
- /// this returns true and `results` is unspecified.
- ///
- /// If not overridden, this fallback implementation always fails to fold.
- ///
- virtual bool constantFold(const Operation *op, ArrayRef<Attribute> operands,
- SmallVectorImpl<Attribute> &results) const {
- return true;
- }
+ /// Registered fallback constant fold hook for the dialect. Like the constant
+ /// fold hook of each operation, it attempts to constant fold the operation
+ /// with the specified constant operand values - the elements in "operands"
+ /// will correspond directly to the operands of the operation, but may be null
+ /// if non-constant. If constant folding is successful, this returns false
+ /// and fills in the `results` vector. If not, this returns true and
+ /// `results` is unspecified.
+ DialectConstantFoldHook constantFoldHook =
+ [](const Operation *op, ArrayRef<Attribute> operands,
+ SmallVectorImpl<Attribute> &results) { return true; };
// TODO: Hook to return the list of named types that are known.
};
using DialectAllocatorFunction = std::function<void(MLIRContext *)>;
+using ConstantFoldHookAllocator = std::function<void(MLIRContext *)>;
-/// Register a specific dialect creation function with the system, typically
+/// Registers a specific dialect creation function with the system, typically
/// used through the DialectRegistration template.
void registerDialectAllocator(const DialectAllocatorFunction &function);
+/// Registers a constant fold hook for one or multiple dialects. The
+/// ConstantFoldHookAllocator defines how the hook gets mapped to the targeted
+/// dialect(s) in the context.
+/// Exmaple:
+/// registerConstantFoldHook([&](MLIRContext *ctx) {
+/// auto dialects = ctx->getRegisteredDialects();
+/// // then iterate and select the target dialect from dialects, or
+/// // get one dialect directly by the prefix:
+/// auto dialect = ctx->getRegisteredDialect("TARGET_PREFIX")
+///
+/// dialect->constantFoldHook = MyConstantFoldHook;
+/// });
+void registerConstantFoldHook(const ConstantFoldHookAllocator &function);
+
/// Registers all dialects with the specified MLIRContext.
void registerAllDialects(MLIRContext *context);
/// Return information about all registered IR dialects.
std::vector<Dialect *> getRegisteredDialects() const;
+ /// Get registered IR dialect which has the longest matching with the given
+ /// prefix. If none is found, returns nullptr.
+ Dialect *getRegisteredDialect(StringRef prefix) const;
+
/// Return information about all registered operations. This isn't very
/// efficient: typically you should ask the operations about their properties
/// directly.
};
} // end namespace mlir
-#endif // MLIR_IR_MLIRCONTEXT_H
+#endif // MLIR_IR_MLIRCONTEXT_H
static llvm::ManagedStatic<SmallVector<DialectAllocatorFunction, 8>>
dialectRegistry;
-/// Register a specific dialect creation function with the system, typically
+// Registry for dialect's constant fold hooks.
+static llvm::ManagedStatic<SmallVector<ConstantFoldHookAllocator, 8>>
+ constantFoldHookRegistry;
+
+/// Registers a specific dialect creation function with the system, typically
/// used through the DialectRegistration template.
void mlir::registerDialectAllocator(const DialectAllocatorFunction &function) {
- assert(function && "Attempting to register an empty op initialize function");
+ assert(function &&
+ "Attempting to register an empty dialect initialize function");
dialectRegistry->push_back(function);
}
-/// Registers all dialects with the specified MLIRContext.
+/// Registers a constant fold hook for a specific dialect with the system.
+void mlir::registerConstantFoldHook(const ConstantFoldHookAllocator &function) {
+ assert(
+ function &&
+ "Attempting to register an empty constant fold hook initialize function");
+ constantFoldHookRegistry->push_back(function);
+}
+
+/// Registers all dialects and their const folding hooks with the specified
+/// MLIRContext.
void mlir::registerAllDialects(MLIRContext *context) {
for (const auto &fn : *dialectRegistry)
fn(context);
+ for (const auto &fn : *constantFoldHookRegistry)
+ fn(context);
}
Dialect::Dialect(StringRef opPrefix, MLIRContext *context)
return result;
}
+/// Get registered IR dialect which has the longest matching with the given
+/// prefix. If none is found, returns nullptr.
+Dialect *MLIRContext::getRegisteredDialect(StringRef prefix) const {
+ Dialect *result = nullptr;
+ for (auto &dialect : getImpl().dialects) {
+ if (prefix.startswith(dialect->getOperationPrefix()))
+ if (!result || result->getOperationPrefix().size() <
+ dialect->getOperationPrefix().size())
+ result = dialect.get();
+ }
+ return result;
+}
+
/// Register this dialect object with the specified context. The context
/// takes ownership of the heap allocated dialect.
void Dialect::registerDialect(MLIRContext *context) {
return false;
// Otherwise, fall back on the dialect hook to handle it.
- Dialect &dialect = abstractOp->dialect;
- return dialect.constantFold(this, operands, results);
+ return abstractOp->dialect.constantFoldHook(this, operands, results);
}
+
+ // If this operation hasn't been registered or doesn't have abstract
+ // operation, fall back to a dialect which matches the prefix.
+ auto opName = getName().getStringRef();
+ if (auto *dialect = getContext()->getRegisteredDialect(opName)) {
+ return dialect->constantFoldHook(this, operands, results);
+ }
+
return true;
}