};
using DialectAllocatorFunction = std::function<void(MLIRContext *)>;
-using ConstantFoldHookAllocator = std::function<void(MLIRContext *)>;
/// Registers a specific dialect creation function with the system, typically
/// used through the DialectRegistration template.
void registerDialectAllocator(const DialectAllocatorFunction &function);
-/// Registers a constant fold hook for one or multiple dialects. The
-/// ConstantFoldHookAllocator defines how the hook gets mapped to the targeted
-/// dialect(s) in the context.
-/// Exmaple:
-/// registerConstantFoldHook([&](MLIRContext *ctx) {
-/// auto dialects = ctx->getRegisteredDialects();
-/// // then iterate and select the target dialect from dialects, or
-/// // get one dialect directly by the prefix:
-/// auto dialect = ctx->getRegisteredDialect("TARGET_PREFIX")
-///
-/// dialect->constantFoldHook = MyConstantFoldHook;
-/// });
-void registerConstantFoldHook(const ConstantFoldHookAllocator &function);
-
/// Registers all dialects with the specified MLIRContext.
void registerAllDialects(MLIRContext *context);
/// The subclass should override DialectHook methods for supported hooks.
class DialectHooks {
public:
+ // Returns hook to constant fold an operation.
+ DialectConstantFoldHook getConstantFoldHook() { return nullptr; }
// Returns hook to decode opaque constant tensor.
DialectConstantDecodeHook getDecodeHook() { return nullptr; }
// Returns hook to extract an element of an opaque constant tensor.
}
// Set hooks.
ConcreteHooks hooks;
+ if (auto h = hooks.getConstantFoldHook())
+ dialect->constantFoldHook = h;
if (auto h = hooks.getDecodeHook())
dialect->decodeHook = h;
if (auto h = hooks.getExtractElementHook())
static llvm::ManagedStatic<SmallVector<DialectAllocatorFunction, 8>>
dialectRegistry;
-// Registry for dialect's constant fold hooks.
-static llvm::ManagedStatic<SmallVector<ConstantFoldHookAllocator, 8>>
- constantFoldHookRegistry;
-
// Registry for functions that set dialect hooks.
static llvm::ManagedStatic<SmallVector<DialectHooksSetter, 8>>
dialectHooksRegistry;
dialectRegistry->push_back(function);
}
-/// Registers a constant fold hook for a specific dialect with the system.
-void mlir::registerConstantFoldHook(const ConstantFoldHookAllocator &function) {
- assert(
- function &&
- "Attempting to register an empty constant fold hook initialize function");
- constantFoldHookRegistry->push_back(function);
-}
-
/// Registers a function to set specific hooks for a specific dialect, typically
/// used through the DialectHooksRegistreation template.
void mlir::registerDialectHooksSetter(const DialectHooksSetter &function) {
void mlir::registerAllDialects(MLIRContext *context) {
for (const auto &fn : *dialectRegistry)
fn(context);
- for (const auto &fn : *constantFoldHookRegistry)
- fn(context);
for (const auto &fn : *dialectHooksRegistry) {
fn(context);
}