Remove DialectHooks and introduce a Dialect Interfaces instead
authorMehdi Amini <joker.eph@gmail.com>
Wed, 12 Aug 2020 09:36:54 +0000 (09:36 +0000)
committerMehdi Amini <joker.eph@gmail.com>
Thu, 13 Aug 2020 00:38:55 +0000 (00:38 +0000)
These hooks were introduced before the Interfaces mechanism was available.

DialectExtractElementHook is unused and entirely removed. The
DialectConstantFoldHook is used a fallback in the
operation fold() method, and is replaced by a DialectInterface.
The DialectConstantDecodeHook is used for interpreting OpaqueAttribute
and should be revamped, but is replaced with an interface in 1:1 fashion
for now.

Differential Revision: https://reviews.llvm.org/D85595

mlir/include/mlir/IR/Dialect.h
mlir/include/mlir/IR/DialectHooks.h [deleted file]
mlir/include/mlir/Interfaces/DecodeAttributesInterfaces.h [new file with mode: 0644]
mlir/include/mlir/Interfaces/FoldInterfaces.h [new file with mode: 0644]
mlir/lib/IR/Attributes.cpp
mlir/lib/IR/Dialect.cpp
mlir/lib/IR/Operation.cpp

index bd9f3c1..4f9e4cb 100644 (file)
@@ -23,12 +23,6 @@ class DialectInterface;
 class OpBuilder;
 class Type;
 
-using DialectConstantDecodeHook =
-    std::function<bool(const OpaqueElementsAttr, ElementsAttr &)>;
-using DialectConstantFoldHook = std::function<LogicalResult(
-    Operation *, ArrayRef<Attribute>, SmallVectorImpl<Attribute> &)>;
-using DialectExtractElementHook =
-    std::function<Attribute(const OpaqueElementsAttr, ArrayRef<uint64_t>)>;
 using DialectAllocatorFunction = std::function<void(MLIRContext *)>;
 
 /// Dialects are groups of MLIR operations and behavior associated with the
@@ -63,38 +57,6 @@ public:
   /// These are represented with OpaqueType.
   bool allowsUnknownTypes() const { return unknownTypesAllowed; }
 
-  //===--------------------------------------------------------------------===//
-  // Constant Hooks
-  //===--------------------------------------------------------------------===//
-
-  /// Registered fallback constant fold hook for the dialect. Like the constant
-  /// fold hook of each operation, it attempts to constant fold the operation
-  /// with the specified constant operand values - the elements in "operands"
-  /// will correspond directly to the operands of the operation, but may be null
-  /// if non-constant.  If constant folding is successful, this fills in the
-  /// `results` vector.  If not, this returns failure and `results` is
-  /// unspecified.
-  DialectConstantFoldHook constantFoldHook =
-      [](Operation *op, ArrayRef<Attribute> operands,
-         SmallVectorImpl<Attribute> &results) { return failure(); };
-
-  /// Registered hook to decode opaque constants associated with this
-  /// dialect. The hook function attempts to decode an opaque constant tensor
-  /// into a tensor with non-opaque content. If decoding is successful, this
-  /// method returns false and sets 'output' attribute. If not, it returns true
-  /// and leaves 'output' unspecified. The default hook fails to decode.
-  DialectConstantDecodeHook decodeHook =
-      [](const OpaqueElementsAttr input, ElementsAttr &output) { return true; };
-
-  /// Registered hook to extract an element from an opaque constant associated
-  /// with this dialect. If element has been successfully extracted, this
-  /// method returns that element. If not, it returns an empty attribute.
-  /// The default hook fails to extract an element.
-  DialectExtractElementHook extractElementHook =
-      [](const OpaqueElementsAttr input, ArrayRef<uint64_t> index) {
-        return Attribute();
-      };
-
   /// Registered hook to materialize a single constant operation from a given
   /// attribute value with the desired resultant type. This method should use
   /// the provided builder to create the operation without changing the
diff --git a/mlir/include/mlir/IR/DialectHooks.h b/mlir/include/mlir/IR/DialectHooks.h
deleted file mode 100644 (file)
index 3986266..0000000
+++ /dev/null
@@ -1,90 +0,0 @@
-//===- DialectHooks.h - MLIR DialectHooks mechanism -------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file defines abstraction and registration mechanism for dialect hooks.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_IR_DIALECT_HOOKS_H
-#define MLIR_IR_DIALECT_HOOKS_H
-
-#include "mlir/IR/Dialect.h"
-#include "llvm/Support/raw_ostream.h"
-
-namespace mlir {
-using DialectHooksSetter = std::function<void(MLIRContext *)>;
-
-/// Dialect hooks allow external components to register their functions to
-/// be called for specific tasks specialized per dialect, such as decoding
-/// of opaque constants. To register concrete dialect hooks, one should
-/// define a DialectHooks subclass and use it as a template
-/// argument to DialectHooksRegistration. For example,
-///     class MyHooks : public DialectHooks {...};
-///     static DialectHooksRegistration<MyHooks, MyDialect> hooksReg;
-/// The subclass should override DialectHook methods for supported hooks.
-class DialectHooks {
-public:
-  // Returns hook to constant fold an operation.
-  DialectConstantFoldHook getConstantFoldHook() { return nullptr; }
-  // Returns hook to decode opaque constant tensor.
-  DialectConstantDecodeHook getDecodeHook() { return nullptr; }
-  // Returns hook to extract an element of an opaque constant tensor.
-  DialectExtractElementHook getExtractElementHook() { return nullptr; }
-
-private:
-  /// Registers a function that will set hooks in the registered dialects.
-  /// Registrations are deduplicated by dialect TypeID and only the first
-  /// registration will be used.
-  static void registerDialectHooksSetter(TypeID typeID,
-                                         const DialectHooksSetter &function);
-  template <typename ConcreteHooks>
-  friend void registerDialectHooks(StringRef dialectName);
-};
-
-void registerDialectHooksSetter(TypeID typeID,
-                                const DialectHooksSetter &function);
-
-/// Utility to register dialect hooks. Client can register their dialect hooks
-/// with the global registry by calling
-/// registerDialectHooks<MyHooks>("dialect_namespace");
-template <typename ConcreteHooks>
-void registerDialectHooks(StringRef dialectName) {
-  DialectHooks::registerDialectHooksSetter(
-      TypeID::get<ConcreteHooks>(), [dialectName](MLIRContext *ctx) {
-        Dialect *dialect = ctx->getRegisteredDialect(dialectName);
-        if (!dialect) {
-          llvm::errs() << "error: cannot register hooks for unknown dialect '"
-                       << dialectName << "'\n";
-          abort();
-        }
-        // Set hooks.
-        ConcreteHooks hooks;
-        if (auto h = hooks.getConstantFoldHook())
-          dialect->constantFoldHook = h;
-        if (auto h = hooks.getDecodeHook())
-          dialect->decodeHook = h;
-        if (auto h = hooks.getExtractElementHook())
-          dialect->extractElementHook = h;
-      });
-}
-
-/// DialectHooksRegistration provides a global initializer that registers
-/// a dialect hooks setter routine.
-/// Usage:
-///
-///   // At namespace scope.
-///   static DialectHooksRegistration<MyHooks> Unused("dialect_namespace");
-template <typename ConcreteHooks> struct DialectHooksRegistration {
-  DialectHooksRegistration(StringRef dialectName) {
-    registerDialectHooks<ConcreteHooks>(dialectName);
-  }
-};
-
-} // namespace mlir
-
-#endif
diff --git a/mlir/include/mlir/Interfaces/DecodeAttributesInterfaces.h b/mlir/include/mlir/Interfaces/DecodeAttributesInterfaces.h
new file mode 100644 (file)
index 0000000..2b7607e
--- /dev/null
@@ -0,0 +1,37 @@
+//===- DecodeAttributesInterfaces.h - DecodeAttributes Interfaces -*- C++ -*-=//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_INTERFACES_DECODEATTRIBUTESINTERFACES_H_
+#define MLIR_INTERFACES_DECODEATTRIBUTESINTERFACES_H_
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/DialectInterface.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+
+/// Define an interface to decode opaque constant tensor.
+class DialectDecodeAttributesInterface
+    : public DialectInterface::Base<DialectDecodeAttributesInterface> {
+public:
+  DialectDecodeAttributesInterface(Dialect *dialect) : Base(dialect) {}
+
+  /// Registered hook to decode opaque constants associated with this
+  /// dialect. The hook function attempts to decode an opaque constant tensor
+  /// into a tensor with non-opaque content. If decoding is successful, this
+  /// method returns success() and sets 'output' attribute. If not, it returns
+  /// failure() and leaves 'output' unspecified. The default hook fails to
+  /// decode.
+  virtual LogicalResult decode(OpaqueElementsAttr input,
+                               ElementsAttr &output) const {
+    return failure();
+  }
+};
+
+} // end namespace mlir
+
+#endif // MLIR_INTERFACES_DECODEATTRIBUTESINTERFACES_H_
diff --git a/mlir/include/mlir/Interfaces/FoldInterfaces.h b/mlir/include/mlir/Interfaces/FoldInterfaces.h
new file mode 100644 (file)
index 0000000..e1f1787
--- /dev/null
@@ -0,0 +1,40 @@
+//===- FoldInterfaces.h - Folding Interfaces --------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_INTERFACES_FOLDINTERFACES_H_
+#define MLIR_INTERFACES_FOLDINTERFACES_H_
+
+#include "mlir/IR/DialectInterface.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace mlir {
+class Attribute;
+class OpFoldResult;
+
+/// Define a fold interface to allow for dialects to opt-in specific
+/// folding for operations they define.
+class DialectFoldInterface
+    : public DialectInterface::Base<DialectFoldInterface> {
+public:
+  DialectFoldInterface(Dialect *dialect) : Base(dialect) {}
+
+  /// Registered fallback fold for the dialect. Like the fold hook of each
+  /// operation, it attempts to fold the operation with the specified constant
+  /// operand values - the elements in "operands" will correspond directly to
+  /// the operands of the operation, but may be null if non-constant.  If
+  /// folding is successful, this fills in the `results` vector.  If not, this
+  /// returns failure and `results` is unspecified.
+  virtual LogicalResult Fold(Operation *op, ArrayRef<Attribute> operands,
+                             SmallVectorImpl<OpFoldResult> &results) const {
+    return failure();
+  }
+};
+
+} // end namespace mlir
+
+#endif // MLIR_INTERFACES_FOLDINTERFACES_H_
index e353b0b..dceb072 100644 (file)
@@ -14,6 +14,7 @@
 #include "mlir/IR/Function.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/Types.h"
+#include "mlir/Interfaces/DecodeAttributesInterfaces.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/Support/Endian.h"
@@ -1227,17 +1228,20 @@ StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; }
 /// element, then a null attribute is returned.
 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const {
   assert(isValidIndex(index) && "expected valid multi-dimensional index");
-  if (Dialect *dialect = getDialect())
-    return dialect->extractElementHook(*this, index);
   return Attribute();
 }
 
 Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; }
 
 bool OpaqueElementsAttr::decode(ElementsAttr &result) {
-  if (auto *d = getDialect())
-    return d->decodeHook(*this, result);
-  return true;
+  auto *d = getDialect();
+  if (!d)
+    return true;
+  auto *interface =
+      d->getRegisteredInterface<DialectDecodeAttributesInterface>();
+  if (!interface)
+    return true;
+  return failed(interface->decode(*this, result));
 }
 
 //===----------------------------------------------------------------------===//
index 02448b3..555bb2b 100644 (file)
@@ -8,7 +8,6 @@
 
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Diagnostics.h"
-#include "mlir/IR/DialectHooks.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/DialectInterface.h"
 #include "mlir/IR/MLIRContext.h"
@@ -31,10 +30,6 @@ DialectAsmParser::~DialectAsmParser() {}
 static llvm::ManagedStatic<llvm::MapVector<TypeID, DialectAllocatorFunction>>
     dialectRegistry;
 
-/// Registry for functions that set dialect hooks.
-static llvm::ManagedStatic<llvm::MapVector<TypeID, DialectHooksSetter>>
-    dialectHooksRegistry;
-
 void Dialect::registerDialectAllocator(
     TypeID typeID, const DialectAllocatorFunction &function) {
   assert(function &&
@@ -42,24 +37,11 @@ void Dialect::registerDialectAllocator(
   dialectRegistry->insert({typeID, function});
 }
 
-/// Registers a function to set specific hooks for a specific dialect, typically
-/// used through the DialectHooksRegistration template.
-void DialectHooks::registerDialectHooksSetter(
-    TypeID typeID, const DialectHooksSetter &function) {
-  assert(
-      function &&
-      "Attempting to register an empty dialect hooks initialization function");
-
-  dialectHooksRegistry->insert({typeID, function});
-}
-
 /// Registers all dialects and hooks from the global registries with the
 /// specified MLIRContext.
 void mlir::registerAllDialects(MLIRContext *context) {
   for (const auto &it : *dialectRegistry)
     it.second(context);
-  for (const auto &it : *dialectHooksRegistry)
-    it.second(context);
 }
 
 //===----------------------------------------------------------------------===//
index 8feab8e..152ed01 100644 (file)
@@ -13,6 +13,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/FoldInterfaces.h"
 #include <numeric>
 
 using namespace mlir;
@@ -570,11 +571,11 @@ LogicalResult Operation::fold(ArrayRef<Attribute> operands,
   if (!dialect)
     return failure();
 
-  SmallVector<Attribute, 8> constants;
-  if (failed(dialect->constantFoldHook(this, operands, constants)))
+  auto *interface = dialect->getRegisteredInterface<DialectFoldInterface>();
+  if (!interface)
     return failure();
-  results.assign(constants.begin(), constants.end());
-  return success();
+
+  return interface->Fold(this, operands, results);
 }
 
 /// Emit an error with the op name prefixed, like "'dim' op " which is