Revert "Refactor OperationName to use virtual tables for dispatch (NFC)"
authorMehdi Amini <joker.eph@gmail.com>
Mon, 16 Jan 2023 23:11:12 +0000 (23:11 +0000)
committerMehdi Amini <joker.eph@gmail.com>
Mon, 16 Jan 2023 23:11:38 +0000 (23:11 +0000)
This reverts commit e055aad5ffb348472c65dfcbede85f39efe8f906.

This crashes on Windows at the moment for some reasons.

12 files changed:
mlir/include/mlir/IR/ExtensibleDialect.h
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/IR/Operation.h
mlir/include/mlir/IR/OperationSupport.h
mlir/lib/AsmParser/Parser.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/ExtensibleDialect.cpp
mlir/lib/IR/MLIRContext.cpp
mlir/lib/IR/Operation.cpp
mlir/lib/Rewrite/ByteCode.cpp
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
mlir/unittests/Interfaces/InferTypeOpInterfaceTest.cpp

index 9820aa6..662a735 100644 (file)
@@ -336,15 +336,12 @@ public:
 
 /// The definition of a dynamic op. A dynamic op is an op that is defined at
 /// runtime, and that can be registered at runtime by an extensible dialect (a
-/// dialect inheriting ExtensibleDialect). This class implements the method
-/// exposed by the OperationName class, and in addition defines the TypeID of
-/// the op that will be defined. Each dynamic operation definition refers to one
-/// instance of this class.
-class DynamicOpDefinition : public OperationName::Impl {
+/// dialect inheriting ExtensibleDialect). This class stores the functions that
+/// are in the OperationName class, and in addition defines the TypeID of the op
+/// that will be defined.
+/// Each dynamic operation definition refers to one instance of this class.
+class DynamicOpDefinition {
 public:
-  using GetCanonicalizationPatternsFn =
-      llvm::unique_function<void(RewritePatternSet &, MLIRContext *) const>;
-
   /// Create a new op at runtime. The op is registered only after passing it to
   /// the dialect using registerDynamicOp.
   static std::unique_ptr<DynamicOpDefinition>
@@ -364,7 +361,8 @@ public:
       OperationName::ParseAssemblyFn &&parseFn,
       OperationName::PrintAssemblyFn &&printFn,
       OperationName::FoldHookFn &&foldHookFn,
-      GetCanonicalizationPatternsFn &&getCanonicalizationPatternsFn,
+      OperationName::GetCanonicalizationPatternsFn
+          &&getCanonicalizationPatternsFn,
       OperationName::PopulateDefaultAttrsFn &&populateDefaultAttrsFn);
 
   /// Returns the op typeID.
@@ -402,8 +400,9 @@ public:
 
   /// Set the hook returning any canonicalization pattern rewrites that the op
   /// supports, for use by the canonicalization pass.
-  void setGetCanonicalizationPatternsFn(
-      GetCanonicalizationPatternsFn &&getCanonicalizationPatterns) {
+  void
+  setGetCanonicalizationPatternsFn(OperationName::GetCanonicalizationPatternsFn
+                                       &&getCanonicalizationPatterns) {
     getCanonicalizationPatternsFn = std::move(getCanonicalizationPatterns);
   }
 
@@ -413,29 +412,6 @@ public:
     populateDefaultAttrsFn = std::move(populateDefaultAttrs);
   }
 
-  LogicalResult foldHook(Operation *op, ArrayRef<Attribute> attrs,
-                         SmallVectorImpl<OpFoldResult> &results) final {
-    return foldHookFn(op, attrs, results);
-  }
-  void getCanonicalizationPatterns(RewritePatternSet &set,
-                                   MLIRContext *context) final {
-    getCanonicalizationPatternsFn(set, context);
-  }
-  bool hasTrait(TypeID id) final { return false; }
-  OperationName::ParseAssemblyFn getParseAssemblyFn() final { return parseFn; }
-  void populateDefaultAttrs(const OperationName &name,
-                            NamedAttrList &attrs) final {
-    populateDefaultAttrsFn(name, attrs);
-  }
-  void printAssembly(Operation *op, OpAsmPrinter &printer,
-                     StringRef name) final {
-    printFn(op, printer, name);
-  }
-  LogicalResult verifyInvariants(Operation *op) final { return verifyFn(op); }
-  LogicalResult verifyRegionInvariants(Operation *op) final {
-    return verifyRegionFn(op);
-  }
-
 private:
   DynamicOpDefinition(
       StringRef name, ExtensibleDialect *dialect,
@@ -444,18 +420,26 @@ private:
       OperationName::ParseAssemblyFn &&parseFn,
       OperationName::PrintAssemblyFn &&printFn,
       OperationName::FoldHookFn &&foldHookFn,
-      GetCanonicalizationPatternsFn &&getCanonicalizationPatternsFn,
+      OperationName::GetCanonicalizationPatternsFn
+          &&getCanonicalizationPatternsFn,
       OperationName::PopulateDefaultAttrsFn &&populateDefaultAttrsFn);
 
+  /// Unique identifier for this operation.
+  TypeID typeID;
+
+  /// Name of the operation.
+  /// The name is prefixed with the dialect name.
+  std::string name;
+
   /// Dialect defining this operation.
-  ExtensibleDialect *getdialect();
+  ExtensibleDialect *dialect;
 
   OperationName::VerifyInvariantsFn verifyFn;
   OperationName::VerifyRegionInvariantsFn verifyRegionFn;
   OperationName::ParseAssemblyFn parseFn;
   OperationName::PrintAssemblyFn printFn;
   OperationName::FoldHookFn foldHookFn;
-  GetCanonicalizationPatternsFn getCanonicalizationPatternsFn;
+  OperationName::GetCanonicalizationPatternsFn getCanonicalizationPatternsFn;
   OperationName::PopulateDefaultAttrsFn populateDefaultAttrsFn;
 
   friend ExtensibleDialect;
index 385597c..944f933 100644 (file)
@@ -184,7 +184,8 @@ public:
                                           MLIRContext *context) {}
 
   /// This hook populates any unset default attrs.
-  static void populateDefaultAttrs(const OperationName &, NamedAttrList &) {}
+  static void populateDefaultAttrs(const RegisteredOperationName &,
+                                   NamedAttrList &) {}
 
 protected:
   /// If the concrete type didn't implement a custom verifier hook, just fall
@@ -1832,11 +1833,20 @@ private:
     return result;
   }
 
+  /// Implementation of `GetCanonicalizationPatternsFn` OperationName hook.
+  static OperationName::GetCanonicalizationPatternsFn
+  getGetCanonicalizationPatternsFn() {
+    return &ConcreteType::getCanonicalizationPatterns;
+  }
   /// Implementation of `GetHasTraitFn`
   static OperationName::HasTraitFn getHasTraitFn() {
     return
         [](TypeID id) { return op_definition_impl::hasTrait<Traits...>(id); };
   }
+  /// Implementation of `ParseAssemblyFn` OperationName hook.
+  static OperationName::ParseAssemblyFn getParseAssemblyFn() {
+    return &ConcreteType::parse;
+  }
   /// Implementation of `PrintAssemblyFn` OperationName hook.
   static OperationName::PrintAssemblyFn getPrintAssemblyFn() {
     if constexpr (detect_has_print<ConcreteType>::value)
index 59d450e..2dd5c5d 100644 (file)
@@ -506,9 +506,11 @@ public:
 
   /// Sets default attributes on unset attributes.
   void populateDefaultAttrs() {
+    if (auto registered = getRegisteredInfo()) {
       NamedAttrList attrs(getAttrDictionary());
-      name.populateDefaultAttrs(attrs);
+      registered->populateDefaultAttrs(attrs);
       setAttrs(attrs.getDictionary(getContext()));
+    }
   }
 
   //===--------------------------------------------------------------------===//
index 8ec11c1..a6d8a35 100644 (file)
@@ -23,7 +23,6 @@
 #include "mlir/Support/InterfaceSupport.h"
 #include "llvm/ADT/BitmaskEnum.h"
 #include "llvm/ADT/PointerUnion.h"
-#include "llvm/ADT/STLFunctionalExtras.h"
 #include "llvm/Support/PointerLikeTypeTraits.h"
 #include "llvm/Support/TrailingObjects.h"
 #include <memory>
@@ -65,15 +64,17 @@ class ValueTypeRange;
 
 class OperationName {
 public:
+  using GetCanonicalizationPatternsFn =
+      llvm::unique_function<void(RewritePatternSet &, MLIRContext *) const>;
   using FoldHookFn = llvm::unique_function<LogicalResult(
       Operation *, ArrayRef<Attribute>, SmallVectorImpl<OpFoldResult> &) const>;
   using HasTraitFn = llvm::unique_function<bool(TypeID) const>;
   using ParseAssemblyFn =
-      llvm::function_ref<ParseResult(OpAsmParser &, OperationState &)>;
+      llvm::unique_function<ParseResult(OpAsmParser &, OperationState &) const>;
   // Note: RegisteredOperationName is passed as reference here as the derived
   // class is defined below.
-  using PopulateDefaultAttrsFn =
-      llvm::unique_function<void(const OperationName &, NamedAttrList &) const>;
+  using PopulateDefaultAttrsFn = llvm::unique_function<void(
+      const RegisteredOperationName &, NamedAttrList &) const>;
   using PrintAssemblyFn =
       llvm::unique_function<void(Operation *, OpAsmPrinter &, StringRef) const>;
   using VerifyInvariantsFn =
@@ -81,132 +82,63 @@ public:
   using VerifyRegionInvariantsFn =
       llvm::unique_function<LogicalResult(Operation *) const>;
 
+protected:
   /// This class represents a type erased version of an operation. It contains
   /// all of the components necessary for opaquely interacting with an
   /// operation. If the operation is not registered, some of these components
   /// may not be populated.
-  struct InterfaceConcept {
-    virtual ~InterfaceConcept() = default;
-    virtual LogicalResult foldHook(Operation *, ArrayRef<Attribute>,
-                                   SmallVectorImpl<OpFoldResult> &) = 0;
-    virtual void getCanonicalizationPatterns(RewritePatternSet &,
-                                             MLIRContext *) = 0;
-    virtual bool hasTrait(TypeID) = 0;
-    virtual OperationName::ParseAssemblyFn getParseAssemblyFn() = 0;
-    virtual void populateDefaultAttrs(const OperationName &,
-                                      NamedAttrList &) = 0;
-    virtual void printAssembly(Operation *, OpAsmPrinter &, StringRef) = 0;
-    virtual LogicalResult verifyInvariants(Operation *) = 0;
-    virtual LogicalResult verifyRegionInvariants(Operation *) = 0;
-  };
-
-public:
-  class Impl : public InterfaceConcept {
-  public:
-    Impl(StringRef, Dialect *dialect, TypeID typeID,
-         detail::InterfaceMap interfaceMap);
-    Impl(StringAttr name, Dialect *dialect, TypeID typeID,
-         detail::InterfaceMap interfaceMap)
-        : name(name), typeID(typeID), dialect(dialect),
-          interfaceMap(std::move(interfaceMap)) {}
-
-    /// Returns true if this is a registered operation.
-    bool isRegistered() const { return typeID != TypeID::get<void>(); }
-    detail::InterfaceMap &getInterfaceMap() { return interfaceMap; }
-    Dialect *getDialect() const { return dialect; }
-    StringAttr getName() const { return name; }
-    TypeID getTypeID() const { return typeID; }
-    ArrayRef<StringAttr> getAttributeNames() const { return attributeNames; }
-
-  protected:
-    //===------------------------------------------------------------------===//
-    // Registered Operation Info
+  struct Impl {
+    Impl(StringAttr name)
+        : name(name), dialect(nullptr), interfaceMap(std::nullopt) {}
 
     /// The name of the operation.
     StringAttr name;
 
-    /// The unique identifier of the derived Op class.
-    TypeID typeID;
+    //===------------------------------------------------------------------===//
+    // Registered Operation Info
 
     /// The following fields are only populated when the operation is
     /// registered.
 
+    /// Returns true if the operation has been registered, i.e. if the
+    /// registration info has been populated.
+    bool isRegistered() const { return dialect; }
+
     /// This is the dialect that this operation belongs to.
     Dialect *dialect;
 
+    /// The unique identifier of the derived Op class.
+    TypeID typeID;
+
     /// A map of interfaces that were registered to this operation.
     detail::InterfaceMap interfaceMap;
 
+    /// Internal callback hooks provided by the op implementation.
+    FoldHookFn foldHookFn;
+    GetCanonicalizationPatternsFn getCanonicalizationPatternsFn;
+    HasTraitFn hasTraitFn;
+    ParseAssemblyFn parseAssemblyFn;
+    PopulateDefaultAttrsFn populateDefaultAttrsFn;
+    PrintAssemblyFn printAssemblyFn;
+    VerifyInvariantsFn verifyInvariantsFn;
+    VerifyRegionInvariantsFn verifyRegionInvariantsFn;
+
     /// A list of attribute names registered to this operation in StringAttr
     /// form. This allows for operation classes to use StringAttr for attribute
     /// lookup/creation/etc., as opposed to raw strings.
     ArrayRef<StringAttr> attributeNames;
-
-    friend class RegisteredOperationName;
-  };
-
-protected:
-  /// Default implementation for unregistered operations.
-  struct UnregisteredOpModel : public Impl {
-    using Impl::Impl;
-    LogicalResult foldHook(Operation *, ArrayRef<Attribute>,
-                           SmallVectorImpl<OpFoldResult> &) final;
-    void getCanonicalizationPatterns(RewritePatternSet &, MLIRContext *) final;
-    bool hasTrait(TypeID) final;
-    virtual OperationName::ParseAssemblyFn getParseAssemblyFn() final;
-    void populateDefaultAttrs(const OperationName &, NamedAttrList &) final;
-    void printAssembly(Operation *, OpAsmPrinter &, StringRef) final;
-    LogicalResult verifyInvariants(Operation *) final;
-    LogicalResult verifyRegionInvariants(Operation *) final;
   };
 
 public:
   OperationName(StringRef name, MLIRContext *context);
 
   /// Return if this operation is registered.
-  bool isRegistered() const { return getImpl()->isRegistered(); }
-
-  /// Return the unique identifier of the derived Op class, or null if not
-  /// registered.
-  TypeID getTypeID() const { return getImpl()->getTypeID(); }
+  bool isRegistered() const { return impl->isRegistered(); }
 
   /// If this operation is registered, returns the registered information,
   /// std::nullopt otherwise.
   std::optional<RegisteredOperationName> getRegisteredInfo() const;
 
-  /// This hook implements a generalized folder for this operation. Operations
-  /// can implement this to provide simplifications rules that are applied by
-  /// the Builder::createOrFold API and the canonicalization pass.
-  ///
-  /// This is an intentionally limited interface - implementations of this
-  /// hook can only perform the following changes to the operation:
-  ///
-  ///  1. They can leave the operation alone and without changing the IR, and
-  ///     return failure.
-  ///  2. They can mutate the operation in place, without changing anything
-  ///  else
-  ///     in the IR.  In this case, return success.
-  ///  3. They can return a list of existing values that can be used instead
-  ///  of
-  ///     the operation.  In this case, fill in the results list and return
-  ///     success.  The caller will remove the operation and use those results
-  ///     instead.
-  ///
-  /// This allows expression of some simple in-place canonicalizations (e.g.
-  /// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as
-  /// generalized constant folding.
-  LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
-                         SmallVectorImpl<OpFoldResult> &results) const {
-    return getImpl()->foldHook(op, operands, results);
-  }
-
-  /// This hook returns any canonicalization pattern rewrites that the
-  /// operation supports, for use by the canonicalization pass.
-  void getCanonicalizationPatterns(RewritePatternSet &results,
-                                   MLIRContext *context) const {
-    return getImpl()->getCanonicalizationPatterns(results, context);
-  }
-
   /// Returns true if the operation was registered with a particular trait, e.g.
   /// hasTrait<OperandsAreSignlessIntegerLike>(). Returns false if the operation
   /// is unregistered.
@@ -214,7 +146,9 @@ public:
   bool hasTrait() const {
     return hasTrait(TypeID::get<Trait>());
   }
-  bool hasTrait(TypeID traitID) const { return getImpl()->hasTrait(traitID); }
+  bool hasTrait(TypeID traitID) const {
+    return isRegistered() && impl->hasTraitFn(traitID);
+  }
 
   /// Returns true if the operation *might* have the provided trait. This
   /// means that either the operation is unregistered, or it was registered with
@@ -224,54 +158,7 @@ public:
     return mightHaveTrait(TypeID::get<Trait>());
   }
   bool mightHaveTrait(TypeID traitID) const {
-    return !isRegistered() || getImpl()->hasTrait(traitID);
-  }
-
-  /// Return the static hook for parsing this operation assembly.
-  ParseAssemblyFn getParseAssemblyFn() const {
-    return getImpl()->getParseAssemblyFn();
-  }
-
-  /// This hook implements the method to populate defaults attributes that are
-  /// unset.
-  void populateDefaultAttrs(NamedAttrList &attrs) const {
-    getImpl()->populateDefaultAttrs(*this, attrs);
-  }
-
-  /// This hook implements the AsmPrinter for this operation.
-  void printAssembly(Operation *op, OpAsmPrinter &p,
-                     StringRef defaultDialect) const {
-    return getImpl()->printAssembly(op, p, defaultDialect);
-  }
-
-  /// These hooks implement the verifiers for this operation.  It should emits
-  /// an error message and returns failure if a problem is detected, or
-  /// returns success if everything is ok.
-  LogicalResult verifyInvariants(Operation *op) const {
-    return getImpl()->verifyInvariants(op);
-  }
-  LogicalResult verifyRegionInvariants(Operation *op) const {
-    return getImpl()->verifyRegionInvariants(op);
-  }
-
-  /// Return the list of cached attribute names registered to this operation.
-  /// The order of attributes cached here is unique to each type of operation,
-  /// and the interpretation of this attribute list should generally be driven
-  /// by the respective operation. In many cases, this caching removes the
-  /// need to use the raw string name of a known attribute.
-  ///
-  /// For example the ODS generator, with an op defining the following
-  /// attributes:
-  ///
-  ///   let arguments = (ins I32Attr:$attr1, I32Attr:$attr2);
-  ///
-  /// ... may produce an order here of ["attr1", "attr2"]. This allows for the
-  /// ODS generator to directly access the cached name for a known attribute,
-  /// greatly simplifying the cost and complexity of attribute usage produced
-  /// by the generator.
-  ///
-  ArrayRef<StringAttr> getAttributeNames() const {
-    return getImpl()->getAttributeNames();
+    return !isRegistered() || impl->hasTraitFn(traitID);
   }
 
   /// Returns an instance of the concept object for the given interface if it
@@ -279,13 +166,7 @@ public:
   /// directly.
   template <typename T>
   typename T::Concept *getInterface() const {
-    return getImpl()->getInterfaceMap().lookup<T>();
-  }
-
-  /// Attach the given models as implementations of the corresponding
-  /// interfaces for the concrete operation.
-  template <typename... Models> void attachInterface() {
-    getImpl()->getInterfaceMap().insert<Models...>();
+    return impl->interfaceMap.lookup<T>();
   }
 
   /// Returns true if this operation has the given interface registered to it.
@@ -294,7 +175,7 @@ public:
     return hasInterface(TypeID::get<T>());
   }
   bool hasInterface(TypeID interfaceID) const {
-    return getImpl()->getInterfaceMap().contains(interfaceID);
+    return impl->interfaceMap.contains(interfaceID);
   }
 
   /// Returns true if the operation *might* have the provided interface. This
@@ -311,8 +192,7 @@ public:
   /// Return the dialect this operation is registered to if the dialect is
   /// loaded in the context, or nullptr if the dialect isn't loaded.
   Dialect *getDialect() const {
-    return isRegistered() ? getImpl()->getDialect()
-                          : getImpl()->getName().getReferencedDialect();
+    return isRegistered() ? impl->dialect : impl->name.getReferencedDialect();
   }
 
   /// Return the name of the dialect this operation is registered to.
@@ -325,7 +205,7 @@ public:
   StringRef getStringRef() const { return getIdentifier(); }
 
   /// Return the name of this operation as a StringAttr.
-  StringAttr getIdentifier() const { return getImpl()->getName(); }
+  StringAttr getIdentifier() const { return impl->name; }
 
   void print(raw_ostream &os) const;
   void dump() const;
@@ -343,17 +223,12 @@ public:
 
 protected:
   OperationName(Impl *impl) : impl(impl) {}
-  Impl *getImpl() const { return impl; }
-  void setImpl(Impl *rhs) { impl = rhs; }
 
-private:
   /// The internal implementation of the operation name.
-  Impl *impl = nullptr;
+  Impl *impl;
 
   /// Allow access to the Impl struct.
   friend MLIRContextImpl;
-  friend DenseMapInfo<mlir::OperationName>;
-  friend DenseMapInfo<mlir::RegisteredOperationName>;
 };
 
 inline raw_ostream &operator<<(raw_ostream &os, OperationName info) {
@@ -376,62 +251,137 @@ inline llvm::hash_code hash_value(OperationName arg) {
 /// the concrete operation types.
 class RegisteredOperationName : public OperationName {
 public:
-  /// Implementation of the InterfaceConcept for operation APIs that forwarded
-  /// to a concrete op implementation.
-  template <typename ConcreteOp> struct Model : public Impl {
-    Model(Dialect *dialect)
-        : Impl(ConcreteOp::getOperationName(), dialect,
-               TypeID::get<ConcreteOp>(), ConcreteOp::getInterfaceMap()) {}
-    LogicalResult foldHook(Operation *op, ArrayRef<Attribute> attrs,
-                           SmallVectorImpl<OpFoldResult> &results) final {
-      return ConcreteOp::getFoldHookFn()(op, attrs, results);
-    }
-    void getCanonicalizationPatterns(RewritePatternSet &set,
-                                     MLIRContext *context) final {
-      ConcreteOp::getCanonicalizationPatterns(set, context);
-    }
-    bool hasTrait(TypeID id) final { return ConcreteOp::getHasTraitFn()(id); }
-    OperationName::ParseAssemblyFn getParseAssemblyFn() final {
-      return ConcreteOp::parse;
-    }
-    void populateDefaultAttrs(const OperationName &name,
-                              NamedAttrList &attrs) final {
-      ConcreteOp::populateDefaultAttrs(name, attrs);
-    }
-    void printAssembly(Operation *op, OpAsmPrinter &printer,
-                       StringRef name) final {
-      ConcreteOp::getPrintAssemblyFn()(op, printer, name);
-    }
-    LogicalResult verifyInvariants(Operation *op) final {
-      return ConcreteOp::getVerifyInvariantsFn()(op);
-    }
-    LogicalResult verifyRegionInvariants(Operation *op) final {
-      return ConcreteOp::getVerifyRegionInvariantsFn()(op);
-    }
-  };
-
   /// Lookup the registered operation information for the given operation.
   /// Returns std::nullopt if the operation isn't registered.
   static std::optional<RegisteredOperationName> lookup(StringRef name,
                                                        MLIRContext *ctx);
 
   /// Register a new operation in a Dialect object.
-  /// This constructor is used by Dialect objects when they register the list
-  /// of operations they contain.
-  template <typename T> static void insert(Dialect &dialect) {
-    insert(std::make_unique<Model<T>>(&dialect), T::getAttributeNames());
+  /// This constructor is used by Dialect objects when they register the list of
+  /// operations they contain.
+  template <typename T>
+  static void insert(Dialect &dialect) {
+    insert(T::getOperationName(), dialect, TypeID::get<T>(),
+           T::getParseAssemblyFn(), T::getPrintAssemblyFn(),
+           T::getVerifyInvariantsFn(), T::getVerifyRegionInvariantsFn(),
+           T::getFoldHookFn(), T::getGetCanonicalizationPatternsFn(),
+           T::getInterfaceMap(), T::getHasTraitFn(), T::getAttributeNames(),
+           T::getPopulateDefaultAttrsFn());
   }
   /// The use of this method is in general discouraged in favor of
   /// 'insert<CustomOp>(dialect)'.
-  static void insert(std::unique_ptr<OperationName::Impl> ownedImpl,
-                     ArrayRef<StringRef> attrNames);
+  static void
+  insert(StringRef name, Dialect &dialect, TypeID typeID,
+         ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
+         VerifyInvariantsFn &&verifyInvariants,
+         VerifyRegionInvariantsFn &&verifyRegionInvariants,
+         FoldHookFn &&foldHook,
+         GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
+         detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
+         ArrayRef<StringRef> attrNames,
+         PopulateDefaultAttrsFn &&populateDefaultAttrs);
 
   /// Return the dialect this operation is registered to.
-  Dialect &getDialect() const { return *getImpl()->getDialect(); }
+  Dialect &getDialect() const { return *impl->dialect; }
+
+  /// Return the unique identifier of the derived Op class.
+  TypeID getTypeID() const { return impl->typeID; }
 
   /// Use the specified object to parse this ops custom assembly format.
   ParseResult parseAssembly(OpAsmParser &parser, OperationState &result) const;
 
+  /// Return the static hook for parsing this operation assembly.
+  const ParseAssemblyFn &getParseAssemblyFn() const {
+    return impl->parseAssemblyFn;
+  }
+
+  /// This hook implements the AsmPrinter for this operation.
+  void printAssembly(Operation *op, OpAsmPrinter &p,
+                     StringRef defaultDialect) const {
+    return impl->printAssemblyFn(op, p, defaultDialect);
+  }
+
+  /// These hooks implement the verifiers for this operation.  It should emits
+  /// an error message and returns failure if a problem is detected, or returns
+  /// success if everything is ok.
+  LogicalResult verifyInvariants(Operation *op) const {
+    return impl->verifyInvariantsFn(op);
+  }
+  LogicalResult verifyRegionInvariants(Operation *op) const {
+    return impl->verifyRegionInvariantsFn(op);
+  }
+
+  /// This hook implements a generalized folder for this operation.  Operations
+  /// can implement this to provide simplifications rules that are applied by
+  /// the Builder::createOrFold API and the canonicalization pass.
+  ///
+  /// This is an intentionally limited interface - implementations of this hook
+  /// can only perform the following changes to the operation:
+  ///
+  ///  1. They can leave the operation alone and without changing the IR, and
+  ///     return failure.
+  ///  2. They can mutate the operation in place, without changing anything else
+  ///     in the IR.  In this case, return success.
+  ///  3. They can return a list of existing values that can be used instead of
+  ///     the operation.  In this case, fill in the results list and return
+  ///     success.  The caller will remove the operation and use those results
+  ///     instead.
+  ///
+  /// This allows expression of some simple in-place canonicalizations (e.g.
+  /// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as
+  /// generalized constant folding.
+  LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
+                         SmallVectorImpl<OpFoldResult> &results) const {
+    return impl->foldHookFn(op, operands, results);
+  }
+
+  /// This hook returns any canonicalization pattern rewrites that the operation
+  /// supports, for use by the canonicalization pass.
+  void getCanonicalizationPatterns(RewritePatternSet &results,
+                                   MLIRContext *context) const {
+    return impl->getCanonicalizationPatternsFn(results, context);
+  }
+
+  /// Attach the given models as implementations of the corresponding interfaces
+  /// for the concrete operation.
+  template <typename... Models>
+  void attachInterface() {
+    impl->interfaceMap.insert<Models...>();
+  }
+
+  /// Returns true if the operation has a particular trait.
+  template <template <typename T> class Trait>
+  bool hasTrait() const {
+    return hasTrait(TypeID::get<Trait>());
+  }
+
+  /// Returns true if the operation has a particular trait.
+  bool hasTrait(TypeID traitID) const { return impl->hasTraitFn(traitID); }
+
+  /// Return the list of cached attribute names registered to this operation.
+  /// The order of attributes cached here is unique to each type of operation,
+  /// and the interpretation of this attribute list should generally be driven
+  /// by the respective operation. In many cases, this caching removes the need
+  /// to use the raw string name of a known attribute.
+  ///
+  /// For example the ODS generator, with an op defining the following
+  /// attributes:
+  ///
+  ///   let arguments = (ins I32Attr:$attr1, I32Attr:$attr2);
+  ///
+  /// ... may produce an order here of ["attr1", "attr2"]. This allows for the
+  /// ODS generator to directly access the cached name for a known attribute,
+  /// greatly simplifying the cost and complexity of attribute usage produced by
+  /// the generator.
+  ///
+  ArrayRef<StringAttr> getAttributeNames() const {
+    return impl->attributeNames;
+  }
+
+  /// This hook implements the method to populate defaults attributes that are
+  /// unset.
+  void populateDefaultAttrs(NamedAttrList &attrs) const;
+
   /// Represent the operation name as an opaque pointer. (Used to support
   /// PointerLikeTypeTraits).
   static RegisteredOperationName getFromOpaquePointer(const void *pointer) {
index c56befe..a1af481 100644 (file)
@@ -1368,18 +1368,14 @@ Operation *OperationParser::parseGenericOperation() {
   if (!result.name.isRegistered()) {
     StringRef dialectName = StringRef(name).split('.').first;
     if (!getContext()->getLoadedDialect(dialectName) &&
-        !getContext()->getOrLoadDialect(dialectName)) {
-      if (!getContext()->allowsUnregisteredDialects()) {
-        // Emit an error if the dialect couldn't be loaded (i.e., it was not
-        // registered) and unregistered dialects aren't allowed.
-        emitError("operation being parsed with an unregistered dialect. If "
-                  "this is intended, please use -allow-unregistered-dialect "
-                  "with the MLIR tool used");
-        return nullptr;
-      }
-    } else {
-      // Reload the OperationName now that the dialect is loaded.
-      result.name = OperationName(name, getContext());
+        !getContext()->getOrLoadDialect(dialectName) &&
+        !getContext()->allowsUnregisteredDialects()) {
+      // Emit an error if the dialect couldn't be loaded (i.e., it was not
+      // registered) and unregistered dialects aren't allowed.
+      emitError("operation being parsed with an unregistered dialect. If "
+                "this is intended, please use -allow-unregistered-dialect "
+                "with the MLIR tool used");
+      return nullptr;
     }
   }
 
index 0981b3b..76807a5 100644 (file)
@@ -603,8 +603,12 @@ public:
 
     // If requested, always print the generic form.
     if (!printerFlags.shouldPrintGenericOpForm()) {
-      op->getName().printAssembly(op, *this, /*defaultDialect=*/"");
-      return;
+      // Check to see if this is a known operation.  If so, use the registered
+      // custom printer hook.
+      if (auto opInfo = op->getRegisteredInfo()) {
+        opInfo->printAssembly(op, *this, /*defaultDialect=*/"");
+        return;
+      }
     }
 
     // Otherwise print with the generic assembly form.
index 5190b5f..fd169a8 100644 (file)
@@ -294,19 +294,16 @@ DynamicOpDefinition::DynamicOpDefinition(
     OperationName::ParseAssemblyFn &&parseFn,
     OperationName::PrintAssemblyFn &&printFn,
     OperationName::FoldHookFn &&foldHookFn,
-    GetCanonicalizationPatternsFn &&getCanonicalizationPatternsFn,
+    OperationName::GetCanonicalizationPatternsFn
+        &&getCanonicalizationPatternsFn,
     OperationName::PopulateDefaultAttrsFn &&populateDefaultAttrsFn)
-    : Impl(StringAttr::get(dialect->getContext(),
-                           (dialect->getNamespace() + "." + name).str()),
-           dialect, dialect->allocateTypeID(),
-           /*interfaceMap=*/detail::InterfaceMap(std::nullopt)),
+    : typeID(dialect->allocateTypeID()),
+      name((dialect->getNamespace() + "." + name).str()), dialect(dialect),
       verifyFn(std::move(verifyFn)), verifyRegionFn(std::move(verifyRegionFn)),
       parseFn(std::move(parseFn)), printFn(std::move(printFn)),
       foldHookFn(std::move(foldHookFn)),
       getCanonicalizationPatternsFn(std::move(getCanonicalizationPatternsFn)),
-      populateDefaultAttrsFn(std::move(populateDefaultAttrsFn)) {
-  typeID = dialect->allocateTypeID();
-}
+      populateDefaultAttrsFn(std::move(populateDefaultAttrsFn)) {}
 
 std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get(
     StringRef name, ExtensibleDialect *dialect,
@@ -341,7 +338,8 @@ std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get(
   auto getCanonicalizationPatternsFn = [](RewritePatternSet &, MLIRContext *) {
   };
 
-  auto populateDefaultAttrsFn = [](const OperationName &, NamedAttrList &) {};
+  auto populateDefaultAttrsFn = [](const RegisteredOperationName &,
+                                   NamedAttrList &) {};
 
   return DynamicOpDefinition::get(name, dialect, std::move(verifyFn),
                                   std::move(verifyRegionFn), std::move(parseFn),
@@ -357,7 +355,8 @@ std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get(
     OperationName::ParseAssemblyFn &&parseFn,
     OperationName::PrintAssemblyFn &&printFn,
     OperationName::FoldHookFn &&foldHookFn,
-    GetCanonicalizationPatternsFn &&getCanonicalizationPatternsFn,
+    OperationName::GetCanonicalizationPatternsFn
+        &&getCanonicalizationPatternsFn,
     OperationName::PopulateDefaultAttrsFn &&populateDefaultAttrsFn) {
   return std::unique_ptr<DynamicOpDefinition>(new DynamicOpDefinition(
       name, dialect, std::move(verifyFn), std::move(verifyRegionFn),
@@ -449,7 +448,15 @@ void ExtensibleDialect::registerDynamicOp(
     std::unique_ptr<DynamicOpDefinition> &&op) {
   assert(op->dialect == this &&
          "trying to register a dynamic op in the wrong dialect");
-  RegisteredOperationName::insert(std::move(op), /*attrNames=*/{});
+  auto hasTraitFn = [](TypeID traitId) { return false; };
+
+  RegisteredOperationName::insert(
+      op->name, *op->dialect, op->typeID, std::move(op->parseFn),
+      std::move(op->printFn), std::move(op->verifyFn),
+      std::move(op->verifyRegionFn), std::move(op->foldHookFn),
+      std::move(op->getCanonicalizationPatternsFn),
+      detail::InterfaceMap::get<>(), std::move(hasTraitFn), {},
+      std::move(op->populateDefaultAttrsFn));
 }
 
 bool ExtensibleDialect::classof(const Dialect *dialect) {
index b0fe94f..8e3edc8 100644 (file)
@@ -180,7 +180,7 @@ public:
   llvm::BumpPtrAllocator abstractDialectSymbolAllocator;
 
   /// This is a mapping from operation name to the operation info describing it.
-  llvm::StringMap<std::unique_ptr<OperationName::Impl>> operations;
+  llvm::StringMap<OperationName::Impl> operations;
 
   /// A vector of operation info specifically for registered operations.
   llvm::StringMap<RegisteredOperationName> registeredOperations;
@@ -706,11 +706,6 @@ AbstractAttribute *AbstractAttribute::lookupMutable(TypeID typeID,
 // OperationName
 //===----------------------------------------------------------------------===//
 
-OperationName::Impl::Impl(StringRef name, Dialect *dialect, TypeID typeID,
-                          detail::InterfaceMap interfaceMap)
-    : Impl(StringAttr::get(dialect->getContext(), name), dialect, typeID,
-           std::move(interfaceMap)) {}
-
 OperationName::OperationName(StringRef name, MLIRContext *context) {
   MLIRContextImpl &ctxImpl = context->getImpl();
 
@@ -729,7 +724,7 @@ OperationName::OperationName(StringRef name, MLIRContext *context) {
     llvm::sys::SmartScopedReader<true> contextLock(ctxImpl.operationInfoMutex);
     auto it = ctxImpl.operations.find(name);
     if (it != ctxImpl.operations.end()) {
-      impl = it->second.get();
+      impl = &it->second;
       return;
     }
   }
@@ -737,14 +732,10 @@ OperationName::OperationName(StringRef name, MLIRContext *context) {
   // Acquire a writer-lock so that we can safely create the new instance.
   ScopedWriterLock lock(ctxImpl.operationInfoMutex, isMultithreadingEnabled);
 
-  auto it = ctxImpl.operations.insert({name, nullptr});
-  if (it.second) {
-    auto nameAttr = StringAttr::get(context, name);
-    it.first->second = std::make_unique<UnregisteredOpModel>(
-        nameAttr, nameAttr.getReferencedDialect(), TypeID::get<void>(),
-        detail::InterfaceMap(std::nullopt));
-  }
-  impl = it.first->second.get();
+  auto it = ctxImpl.operations.insert({name, OperationName::Impl(nullptr)});
+  if (it.second)
+    it.first->second.name = StringAttr::get(context, name);
+  impl = &it.first->second;
 }
 
 StringRef OperationName::getDialectNamespace() const {
@@ -753,34 +744,6 @@ StringRef OperationName::getDialectNamespace() const {
   return getStringRef().split('.').first;
 }
 
-LogicalResult
-OperationName::UnregisteredOpModel::foldHook(Operation *, ArrayRef<Attribute>,
-                                             SmallVectorImpl<OpFoldResult> &) {
-  return failure();
-}
-void OperationName::UnregisteredOpModel::getCanonicalizationPatterns(
-    RewritePatternSet &, MLIRContext *) {}
-bool OperationName::UnregisteredOpModel::hasTrait(TypeID) { return false; }
-
-OperationName::ParseAssemblyFn
-OperationName::UnregisteredOpModel::getParseAssemblyFn() {
-  llvm::report_fatal_error("getParseAssemblyFn hook called on unregistered op");
-}
-void OperationName::UnregisteredOpModel::populateDefaultAttrs(
-    const OperationName &, NamedAttrList &) {}
-void OperationName::UnregisteredOpModel::printAssembly(
-    Operation *op, OpAsmPrinter &p, StringRef defaultDialect) {
-  p.printGenericOp(op);
-}
-LogicalResult
-OperationName::UnregisteredOpModel::verifyInvariants(Operation *) {
-  return success();
-}
-LogicalResult
-OperationName::UnregisteredOpModel::verifyRegionInvariants(Operation *) {
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // RegisteredOperationName
 //===----------------------------------------------------------------------===//
@@ -794,11 +757,26 @@ RegisteredOperationName::lookup(StringRef name, MLIRContext *ctx) {
   return std::nullopt;
 }
 
+ParseResult
+RegisteredOperationName::parseAssembly(OpAsmParser &parser,
+                                       OperationState &result) const {
+  return impl->parseAssemblyFn(parser, result);
+}
+
+void RegisteredOperationName::populateDefaultAttrs(NamedAttrList &attrs) const {
+  impl->populateDefaultAttrsFn(*this, attrs);
+}
+
 void RegisteredOperationName::insert(
-    std::unique_ptr<RegisteredOperationName::Impl> ownedImpl,
-    ArrayRef<StringRef> attrNames) {
-  RegisteredOperationName::Impl *impl = ownedImpl.get();
-  MLIRContext *ctx = impl->getDialect()->getContext();
+    StringRef name, Dialect &dialect, TypeID typeID,
+    ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
+    VerifyInvariantsFn &&verifyInvariants,
+    VerifyRegionInvariantsFn &&verifyRegionInvariants, FoldHookFn &&foldHook,
+    GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
+    detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
+    ArrayRef<StringRef> attrNames,
+    PopulateDefaultAttrsFn &&populateDefaultAttrs) {
+  MLIRContext *ctx = dialect.getContext();
   auto &ctxImpl = ctx->getImpl();
   assert(ctxImpl.multiThreadedExecutionContext == 0 &&
          "registering a new operation kind while in a multi-threaded execution "
@@ -813,16 +791,21 @@ void RegisteredOperationName::insert(
         attrNames.size());
     for (unsigned i : llvm::seq<unsigned>(0, attrNames.size()))
       new (&cachedAttrNames[i]) StringAttr(StringAttr::get(ctx, attrNames[i]));
-    impl->attributeNames = cachedAttrNames;
   }
-  StringRef name = impl->getName().strref();
-  // Insert the operation info if it doesn't exist yet.
-  auto it = ctxImpl.operations.insert({name, nullptr});
-  it.first->second = std::move(ownedImpl);
 
-  // Update the registered info for this operation.
+  // Insert the operation info if it doesn't exist yet.
+  auto it = ctxImpl.operations.insert({name, OperationName::Impl(nullptr)});
+  if (it.second)
+    it.first->second.name = StringAttr::get(ctx, name);
+  OperationName::Impl &impl = it.first->second;
+
+  if (impl.isRegistered()) {
+    llvm::errs() << "error: operation named '" << name
+                 << "' is already registered.\n";
+    abort();
+  }
   auto emplaced = ctxImpl.registeredOperations.try_emplace(
-      name, RegisteredOperationName(impl));
+      name, RegisteredOperationName(&impl));
   assert(emplaced.second && "operation name registration must be successful");
 
   // Add emplaced operation name to the sorted operations container.
@@ -834,6 +817,20 @@ void RegisteredOperationName::insert(
                               rhs.getIdentifier());
                         }),
       value);
+
+  // Update the registered info for this operation.
+  impl.dialect = &dialect;
+  impl.typeID = typeID;
+  impl.interfaceMap = std::move(interfaceMap);
+  impl.foldHookFn = std::move(foldHook);
+  impl.getCanonicalizationPatternsFn = std::move(getCanonicalizationPatterns);
+  impl.hasTraitFn = std::move(hasTrait);
+  impl.parseAssemblyFn = std::move(parseAssembly);
+  impl.printAssemblyFn = std::move(printAssembly);
+  impl.verifyInvariantsFn = std::move(verifyInvariants);
+  impl.verifyRegionInvariantsFn = std::move(verifyRegionInvariants);
+  impl.attributeNames = cachedAttrNames;
+  impl.populateDefaultAttrsFn = std::move(populateDefaultAttrs);
 }
 
 //===----------------------------------------------------------------------===//
index 50815b3..0c5869b 100644 (file)
@@ -78,7 +78,8 @@ Operation *Operation::create(Location location, OperationName name,
   void *rawMem = mallocMem + prefixByteSize;
 
   // Populate default attributes.
-  name.populateDefaultAttrs(attributes);
+  if (Optional<RegisteredOperationName> info = name.getRegisteredInfo())
+    info->populateDefaultAttrs(attributes);
 
   // Create the new Operation.
   Operation *op = ::new (rawMem) Operation(
@@ -490,7 +491,8 @@ LogicalResult Operation::fold(ArrayRef<Attribute> operands,
                               SmallVectorImpl<OpFoldResult> &results) {
   // If we have a registered operation definition matching this one, use it to
   // try to constant fold the operation.
-  if (succeeded(name.foldHook(this, operands, results)))
+  Optional<RegisteredOperationName> info = getRegisteredInfo();
+  if (info && succeeded(info->foldHook(this, operands, results)))
     return success();
 
   // Otherwise, fall back on the dialect hook to handle it.
index 1ea4bef..91441ce 100644 (file)
@@ -1606,7 +1606,7 @@ void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
   unsigned numResults = read();
   if (numResults == kInferTypesMarker) {
     InferTypeOpInterface::Concept *inferInterface =
-        state.name.getInterface<InferTypeOpInterface>();
+        state.name.getRegisteredInfo()->getInterface<InferTypeOpInterface>();
     assert(inferInterface &&
            "expected operation to provide InferTypeOpInterface");
 
index 83937f4..c9d056b 100644 (file)
@@ -926,7 +926,7 @@ void OpEmitter::genAttrNameGetters() {
     const char *const getAttrName = R"(
   assert(index < {0} && "invalid attribute index");
   assert(name.getStringRef() == getOperationName() && "invalid operation name");
-  return name.getAttributeNames()[index];
+  return name.getRegisteredInfo()->getAttributeNames()[index];
 )";
     method->body() << formatv(getAttrName, attributes.size());
   }
@@ -1739,7 +1739,7 @@ void OpEmitter::genPopulateDefaultAttributes() {
     return;
 
   SmallVector<MethodParameter> paramList;
-  paramList.emplace_back("const ::mlir::OperationName &", "opName");
+  paramList.emplace_back("const ::mlir::RegisteredOperationName &", "opName");
   paramList.emplace_back("::mlir::NamedAttrList &", "attributes");
   auto *m = opClass.addStaticMethod("void", "populateDefaultAttrs", paramList);
   ERROR_IF_PRUNED(m, "populateDefaultAttrs", op);
index 2fc8a43..27b0978 100644 (file)
@@ -36,7 +36,6 @@ protected:
     registry.insert<func::FuncDialect, arith::ArithDialect>();
     ctx.appendDialectRegistry(registry);
     module = parseSourceString<ModuleOp>(ir, &ctx);
-    assert(module);
     mapFn = cast<func::FuncOp>(module->front());
   }