These hooks were introduced before the Interfaces mechanism was available.
DialectExtractElementHook is unused and entirely removed. The
DialectConstantFoldHook is used a fallback in the
operation fold() method, and is replaced by a DialectInterface.
The DialectConstantDecodeHook is used for interpreting OpaqueAttribute
and should be revamped, but is replaced with an interface in 1:1 fashion
for now.
Differential Revision: https://reviews.llvm.org/D85595
class OpBuilder;
class Type;
-using DialectConstantDecodeHook =
- std::function<bool(const OpaqueElementsAttr, ElementsAttr &)>;
-using DialectConstantFoldHook = std::function<LogicalResult(
- Operation *, ArrayRef<Attribute>, SmallVectorImpl<Attribute> &)>;
-using DialectExtractElementHook =
- std::function<Attribute(const OpaqueElementsAttr, ArrayRef<uint64_t>)>;
using DialectAllocatorFunction = std::function<void(MLIRContext *)>;
/// Dialects are groups of MLIR operations and behavior associated with the
/// These are represented with OpaqueType.
bool allowsUnknownTypes() const { return unknownTypesAllowed; }
- //===--------------------------------------------------------------------===//
- // Constant Hooks
- //===--------------------------------------------------------------------===//
-
- /// Registered fallback constant fold hook for the dialect. Like the constant
- /// fold hook of each operation, it attempts to constant fold the operation
- /// with the specified constant operand values - the elements in "operands"
- /// will correspond directly to the operands of the operation, but may be null
- /// if non-constant. If constant folding is successful, this fills in the
- /// `results` vector. If not, this returns failure and `results` is
- /// unspecified.
- DialectConstantFoldHook constantFoldHook =
- [](Operation *op, ArrayRef<Attribute> operands,
- SmallVectorImpl<Attribute> &results) { return failure(); };
-
- /// Registered hook to decode opaque constants associated with this
- /// dialect. The hook function attempts to decode an opaque constant tensor
- /// into a tensor with non-opaque content. If decoding is successful, this
- /// method returns false and sets 'output' attribute. If not, it returns true
- /// and leaves 'output' unspecified. The default hook fails to decode.
- DialectConstantDecodeHook decodeHook =
- [](const OpaqueElementsAttr input, ElementsAttr &output) { return true; };
-
- /// Registered hook to extract an element from an opaque constant associated
- /// with this dialect. If element has been successfully extracted, this
- /// method returns that element. If not, it returns an empty attribute.
- /// The default hook fails to extract an element.
- DialectExtractElementHook extractElementHook =
- [](const OpaqueElementsAttr input, ArrayRef<uint64_t> index) {
- return Attribute();
- };
-
/// Registered hook to materialize a single constant operation from a given
/// attribute value with the desired resultant type. This method should use
/// the provided builder to create the operation without changing the
+++ /dev/null
-//===- DialectHooks.h - MLIR DialectHooks mechanism -------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file defines abstraction and registration mechanism for dialect hooks.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_IR_DIALECT_HOOKS_H
-#define MLIR_IR_DIALECT_HOOKS_H
-
-#include "mlir/IR/Dialect.h"
-#include "llvm/Support/raw_ostream.h"
-
-namespace mlir {
-using DialectHooksSetter = std::function<void(MLIRContext *)>;
-
-/// Dialect hooks allow external components to register their functions to
-/// be called for specific tasks specialized per dialect, such as decoding
-/// of opaque constants. To register concrete dialect hooks, one should
-/// define a DialectHooks subclass and use it as a template
-/// argument to DialectHooksRegistration. For example,
-/// class MyHooks : public DialectHooks {...};
-/// static DialectHooksRegistration<MyHooks, MyDialect> hooksReg;
-/// The subclass should override DialectHook methods for supported hooks.
-class DialectHooks {
-public:
- // Returns hook to constant fold an operation.
- DialectConstantFoldHook getConstantFoldHook() { return nullptr; }
- // Returns hook to decode opaque constant tensor.
- DialectConstantDecodeHook getDecodeHook() { return nullptr; }
- // Returns hook to extract an element of an opaque constant tensor.
- DialectExtractElementHook getExtractElementHook() { return nullptr; }
-
-private:
- /// Registers a function that will set hooks in the registered dialects.
- /// Registrations are deduplicated by dialect TypeID and only the first
- /// registration will be used.
- static void registerDialectHooksSetter(TypeID typeID,
- const DialectHooksSetter &function);
- template <typename ConcreteHooks>
- friend void registerDialectHooks(StringRef dialectName);
-};
-
-void registerDialectHooksSetter(TypeID typeID,
- const DialectHooksSetter &function);
-
-/// Utility to register dialect hooks. Client can register their dialect hooks
-/// with the global registry by calling
-/// registerDialectHooks<MyHooks>("dialect_namespace");
-template <typename ConcreteHooks>
-void registerDialectHooks(StringRef dialectName) {
- DialectHooks::registerDialectHooksSetter(
- TypeID::get<ConcreteHooks>(), [dialectName](MLIRContext *ctx) {
- Dialect *dialect = ctx->getRegisteredDialect(dialectName);
- if (!dialect) {
- llvm::errs() << "error: cannot register hooks for unknown dialect '"
- << dialectName << "'\n";
- abort();
- }
- // Set hooks.
- ConcreteHooks hooks;
- if (auto h = hooks.getConstantFoldHook())
- dialect->constantFoldHook = h;
- if (auto h = hooks.getDecodeHook())
- dialect->decodeHook = h;
- if (auto h = hooks.getExtractElementHook())
- dialect->extractElementHook = h;
- });
-}
-
-/// DialectHooksRegistration provides a global initializer that registers
-/// a dialect hooks setter routine.
-/// Usage:
-///
-/// // At namespace scope.
-/// static DialectHooksRegistration<MyHooks> Unused("dialect_namespace");
-template <typename ConcreteHooks> struct DialectHooksRegistration {
- DialectHooksRegistration(StringRef dialectName) {
- registerDialectHooks<ConcreteHooks>(dialectName);
- }
-};
-
-} // namespace mlir
-
-#endif
--- /dev/null
+//===- DecodeAttributesInterfaces.h - DecodeAttributes Interfaces -*- C++ -*-=//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_INTERFACES_DECODEATTRIBUTESINTERFACES_H_
+#define MLIR_INTERFACES_DECODEATTRIBUTESINTERFACES_H_
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/DialectInterface.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+
+/// Define an interface to decode opaque constant tensor.
+class DialectDecodeAttributesInterface
+ : public DialectInterface::Base<DialectDecodeAttributesInterface> {
+public:
+ DialectDecodeAttributesInterface(Dialect *dialect) : Base(dialect) {}
+
+ /// Registered hook to decode opaque constants associated with this
+ /// dialect. The hook function attempts to decode an opaque constant tensor
+ /// into a tensor with non-opaque content. If decoding is successful, this
+ /// method returns success() and sets 'output' attribute. If not, it returns
+ /// failure() and leaves 'output' unspecified. The default hook fails to
+ /// decode.
+ virtual LogicalResult decode(OpaqueElementsAttr input,
+ ElementsAttr &output) const {
+ return failure();
+ }
+};
+
+} // end namespace mlir
+
+#endif // MLIR_INTERFACES_DECODEATTRIBUTESINTERFACES_H_
--- /dev/null
+//===- FoldInterfaces.h - Folding Interfaces --------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_INTERFACES_FOLDINTERFACES_H_
+#define MLIR_INTERFACES_FOLDINTERFACES_H_
+
+#include "mlir/IR/DialectInterface.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace mlir {
+class Attribute;
+class OpFoldResult;
+
+/// Define a fold interface to allow for dialects to opt-in specific
+/// folding for operations they define.
+class DialectFoldInterface
+ : public DialectInterface::Base<DialectFoldInterface> {
+public:
+ DialectFoldInterface(Dialect *dialect) : Base(dialect) {}
+
+ /// Registered fallback fold for the dialect. Like the fold hook of each
+ /// operation, it attempts to fold the operation with the specified constant
+ /// operand values - the elements in "operands" will correspond directly to
+ /// the operands of the operation, but may be null if non-constant. If
+ /// folding is successful, this fills in the `results` vector. If not, this
+ /// returns failure and `results` is unspecified.
+ virtual LogicalResult Fold(Operation *op, ArrayRef<Attribute> operands,
+ SmallVectorImpl<OpFoldResult> &results) const {
+ return failure();
+ }
+};
+
+} // end namespace mlir
+
+#endif // MLIR_INTERFACES_FOLDINTERFACES_H_
#include "mlir/IR/Function.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Types.h"
+#include "mlir/Interfaces/DecodeAttributesInterfaces.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Endian.h"
/// element, then a null attribute is returned.
Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const {
assert(isValidIndex(index) && "expected valid multi-dimensional index");
- if (Dialect *dialect = getDialect())
- return dialect->extractElementHook(*this, index);
return Attribute();
}
Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; }
bool OpaqueElementsAttr::decode(ElementsAttr &result) {
- if (auto *d = getDialect())
- return d->decodeHook(*this, result);
- return true;
+ auto *d = getDialect();
+ if (!d)
+ return true;
+ auto *interface =
+ d->getRegisteredInterface<DialectDecodeAttributesInterface>();
+ if (!interface)
+ return true;
+ return failed(interface->decode(*this, result));
}
//===----------------------------------------------------------------------===//
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Diagnostics.h"
-#include "mlir/IR/DialectHooks.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/DialectInterface.h"
#include "mlir/IR/MLIRContext.h"
static llvm::ManagedStatic<llvm::MapVector<TypeID, DialectAllocatorFunction>>
dialectRegistry;
-/// Registry for functions that set dialect hooks.
-static llvm::ManagedStatic<llvm::MapVector<TypeID, DialectHooksSetter>>
- dialectHooksRegistry;
-
void Dialect::registerDialectAllocator(
TypeID typeID, const DialectAllocatorFunction &function) {
assert(function &&
dialectRegistry->insert({typeID, function});
}
-/// Registers a function to set specific hooks for a specific dialect, typically
-/// used through the DialectHooksRegistration template.
-void DialectHooks::registerDialectHooksSetter(
- TypeID typeID, const DialectHooksSetter &function) {
- assert(
- function &&
- "Attempting to register an empty dialect hooks initialization function");
-
- dialectHooksRegistry->insert({typeID, function});
-}
-
/// Registers all dialects and hooks from the global registries with the
/// specified MLIRContext.
void mlir::registerAllDialects(MLIRContext *context) {
for (const auto &it : *dialectRegistry)
it.second(context);
- for (const auto &it : *dialectHooksRegistry)
- it.second(context);
}
//===----------------------------------------------------------------------===//
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/FoldInterfaces.h"
#include <numeric>
using namespace mlir;
if (!dialect)
return failure();
- SmallVector<Attribute, 8> constants;
- if (failed(dialect->constantFoldHook(this, operands, constants)))
+ auto *interface = dialect->getRegisteredInterface<DialectFoldInterface>();
+ if (!interface)
return failure();
- results.assign(constants.begin(), constants.end());
- return success();
+
+ return interface->Fold(this, operands, results);
}
/// Emit an error with the op name prefixed, like "'dim' op " which is