[mlir] Add extensible dialects
authorMathieu Fehr <mathieu.fehr@gmail.com>
Wed, 27 Apr 2022 02:48:13 +0000 (19:48 -0700)
committerRiver Riddle <riddleriver@gmail.com>
Wed, 27 Apr 2022 02:48:22 +0000 (19:48 -0700)
Depends on D104534
Add support for extensible dialects, which are dialects that can be
extended at runtime with new operations and types.

These operations and types cannot at the moment implement traits
or interfaces.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D104554

18 files changed:
mlir/docs/ExtensibleDialects.md [new file with mode: 0644]
mlir/include/mlir/IR/AttributeSupport.h
mlir/include/mlir/IR/Dialect.h
mlir/include/mlir/IR/DialectBase.td
mlir/include/mlir/IR/ExtensibleDialect.h [new file with mode: 0644]
mlir/include/mlir/IR/TypeSupport.h
mlir/include/mlir/TableGen/Dialect.h
mlir/lib/IR/CMakeLists.txt
mlir/lib/IR/ExtensibleDialect.cpp [new file with mode: 0644]
mlir/lib/TableGen/Dialect.cpp
mlir/test/IR/dynamic.mlir [new file with mode: 0644]
mlir/test/lib/Dialect/Test/TestAttributes.cpp
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestDialect.h
mlir/test/lib/Dialect/Test/TestDialect.td
mlir/test/lib/Dialect/Test/TestTypes.cpp
mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
mlir/tools/mlir-tblgen/DialectGen.cpp

diff --git a/mlir/docs/ExtensibleDialects.md b/mlir/docs/ExtensibleDialects.md
new file mode 100644 (file)
index 0000000..e46e406
--- /dev/null
@@ -0,0 +1,369 @@
+# Extensible dialects
+
+This file documents the design and API of the extensible dialects. Extensible
+dialects are dialects that can be extended with new operations and types defined
+at runtime. This allows for users to define dialects via with meta-programming,
+or from another language, without having to recompile C++ code.
+
+[TOC]
+
+## Usage
+
+### Defining an extensible dialect
+
+Dialects defined in C++ can be extended with new operations, types, etc., at
+runtime by inheriting from `mlir::ExtensibleDialect` instead of `mlir::Dialect`
+(note that `ExtensibleDialect` inherits from `Dialect`). The `ExtensibleDialect`
+class contains the necessary fields and methods to extend the dialect at
+runtime.
+
+```c++
+class MyDialect : public mlir::ExtensibleDialect {
+    ...
+}
+```
+
+For dialects defined in TableGen, this is done by setting the `isExtensible`
+flag to `1`.
+
+```tablegen
+def Test_Dialect : Dialect {
+  let isExtensible = 1;
+  ...
+}
+```
+
+An extensible `Dialect` can be casted back to `ExtensibleDialect` using
+`llvm::dyn_cast`, or `llvm::cast`:
+
+```c++
+if (auto extensibleDialect = llvm::dyn_cast<ExtensibleDialect>(dialect)) {
+    ...
+}
+```
+
+### Defining an operation at runtime
+
+The `DynamicOpDefinition` class represents the definition of an operation
+defined at runtime. It is created using the `DynamicOpDefinition::get`
+functions. An operation defined at runtime must provide a name, a dialect in
+which the operation will be registered in, an operation verifier. It may also
+optionally define a custom parser and a printer, fold hook, and more.
+
+```c++
+// The operation name, without the dialect name prefix.
+StringRef name = "my_operation_name";
+
+// The dialect defining the operation.
+Dialect* dialect = ctx->getOrLoadDialect<MyDialect>();
+
+// Operation verifier definition.
+AbstractOperation::VerifyInvariantsFn verifyFn = [](Operation* op) {
+    // Logic for the operation verification.
+    ...
+}
+
+// Parser function definition.
+AbstractOperation::ParseAssemblyFn parseFn =
+    [](OpAsmParser &parser, OperationState &state) {
+        // Parse the operation, given that the name is already parsed.
+        ...    
+};
+
+// Printer function
+auto printFn = [](Operation *op, OpAsmPrinter &printer) {
+        printer << op->getName();
+        // Print the operation, given that the name is already printed.
+        ...
+};
+
+// General folder implementation, see AbstractOperation::foldHook for more
+// information.
+auto foldHookFn = [](Operation * op, ArrayRef<Attribute> operands, 
+                                   SmallVectorImpl<OpFoldResult> &result) {
+    ...
+};
+
+// Returns any canonicalization pattern rewrites that the operation
+// supports, for use by the canonicalization pass.
+auto getCanonicalizationPatterns = 
+        [](RewritePatternSet &results, MLIRContext *context) {
+    ...
+}
+
+// Definition of the operation.
+std::unique_ptr<DynamicOpDefinition> opDef =
+    DynamicOpDefinition::get(name, dialect, std::move(verifyFn),
+        std::move(parseFn), std::move(printFn), std::move(foldHookFn),
+        std::move(getCanonicalizationPatterns));
+```
+
+Once the operation is defined, it can be registered by an `ExtensibleDialect`:
+
+```c++
+extensibleDialect->registerDynamicOperation(std::move(opDef));
+```
+
+Note that the `Dialect` given to the operation should be the one registering
+the operation.
+
+### Using an operation defined at runtime
+
+It is possible to match on an operation defined at runtime using their names:
+
+```c++
+if (op->getName().getStringRef() == "my_dialect.my_dynamic_op") {
+    ...
+}
+```
+
+An operation defined at runtime can be created by instantiating an
+`OperationState` with the operation name, and using it with a rewriter
+(for instance a `PatternRewriter`) to create the operation.
+
+```c++
+OperationState state(location, "my_dialect.my_dynamic_op",
+                     operands, resultTypes, attributes);
+
+rewriter.createOperation(state);
+```
+
+
+### Defining a type at runtime
+
+Contrary to types defined in C++ or in TableGen, types defined at runtime can
+only have as argument a list of `Attribute`.
+
+Similarily to operations, a type is defined at runtime using the class
+`DynamicTypeDefinition`, which is created using the `DynamicTypeDefinition::get`
+functions. A type definition requires a name, the dialect that will register the
+type, and a parameter verifier. It can also define optionally a custom parser
+and printer for the arguments (the type name is assumed to be already
+parsed/printed).
+
+```c++
+// The type name, without the dialect name prefix.
+StringRef name = "my_type_name";
+
+// The dialect defining the type.
+Dialect* dialect = ctx->getOrLoadDialect<MyDialect>();
+
+// The type verifier.
+// A type defined at runtime has a list of attributes as parameters.
+auto verifier = [](function_ref<InFlightDiagnostic()> emitError,
+                   ArrayRef<Attribute> args) {
+    ...
+};
+
+// The type parameters parser.
+auto parser = [](DialectAsmParser &parser,
+                 llvm::SmallVectorImpl<Attribute> &parsedParams) {
+    ...
+};
+
+// The type parameters printer.
+auto printer =[](DialectAsmPrinter &printer, ArrayRef<Attribute> params) {
+    ...
+};
+
+std::unique_ptr<DynamicTypeDefinition> typeDef =
+    DynamicTypeDefinition::get(std::move(name), std::move(dialect),
+                               std::move(verifier), std::move(printer),
+                               std::move(parser));
+```
+
+If the printer and the parser are ommited, a default parser and printer is
+generated with the format `!dialect.typename<arg1, arg2, ..., argN>`.
+
+The type can then be registered by the `ExtensibleDialect`:
+
+```c++
+dialect->registerDynamicType(std::move(typeDef));
+```
+
+### Parsing types defined at runtime in an extensible dialect
+
+`parseType` methods generated by TableGen can parse types defined at runtime,
+though overriden `parseType` methods need to add the necessary support for them.
+
+```c++
+Type MyDialect::parseType(DialectAsmParser &parser) const {
+    ...
+    
+    // The type name.
+    StringRef typeTag;
+    if (failed(parser.parseKeyword(&typeTag)))
+        return Type();
+
+    // Try to parse a dynamic type with 'typeTag' name.
+    Type dynType;
+    auto parseResult = parseOptionalDynamicType(typeTag, parser, dynType);
+    if (parseResult.hasValue()) {
+        if (succeeded(parseResult.getValue()))
+            return dynType;
+         return Type();
+    }
+    
+    ...
+}
+```
+
+### Using a type defined at runtime
+
+Dynamic types are instances of `DynamicType`. It is possible to get a dynamic
+type with `DynamicType::get` and `ExtensibleDialect::lookupTypeDefinition`.
+
+```c++
+auto typeDef = extensibleDialect->lookupTypeDefinition("my_dynamic_type");
+ArrayRef<Attribute> params = ...;
+auto type = DynamicType::get(typeDef, params);
+```
+
+It is also possible to cast a `Type` known to be defined at runtime to a
+`DynamicType`.
+
+```c++
+auto dynType = type.cast<DynamicType>();
+auto typeDef = dynType.getTypeDef();
+auto args = dynType.getParams();
+```
+
+### Defining an attribute at runtime
+
+Similar to types defined at runtime, attributes defined at runtime can only have
+as argument a list of `Attribute`.
+
+Similarily to types, an attribute is defined at runtime using the class
+`DynamicAttrDefinition`, which is created using the `DynamicAttrDefinition::get`
+functions. An attribute definition requires a name, the dialect that will
+register the attribute, and a parameter verifier. It can also define optionally
+a custom parser and printer for the arguments (the attribute name is assumed to
+be already parsed/printed).
+
+```c++
+// The attribute name, without the dialect name prefix.
+StringRef name = "my_attribute_name";
+
+// The dialect defining the attribute.
+Dialect* dialect = ctx->getOrLoadDialect<MyDialect>();
+
+// The attribute verifier.
+// An attribute defined at runtime has a list of attributes as parameters.
+auto verifier = [](function_ref<InFlightDiagnostic()> emitError,
+                   ArrayRef<Attribute> args) {
+    ...
+};
+
+// The attribute parameters parser.
+auto parser = [](DialectAsmParser &parser,
+                 llvm::SmallVectorImpl<Attribute> &parsedParams) {
+    ...
+};
+
+// The attribute parameters printer.
+auto printer =[](DialectAsmPrinter &printer, ArrayRef<Attribute> params) {
+    ...
+};
+
+std::unique_ptr<DynamicAttrDefinition> attrDef =
+    DynamicAttrDefinition::get(std::move(name), std::move(dialect),
+                               std::move(verifier), std::move(printer),
+                               std::move(parser));
+```
+
+If the printer and the parser are ommited, a default parser and printer is
+generated with the format `!dialect.attrname<arg1, arg2, ..., argN>`.
+
+The attribute can then be registered by the `ExtensibleDialect`:
+
+```c++
+dialect->registerDynamicAttr(std::move(typeDef));
+```
+
+### Parsing attributes defined at runtime in an extensible dialect
+
+`parseAttribute` methods generated by TableGen can parse attributes defined at
+runtime, though overriden `parseAttribute` methods need to add the necessary
+support for them.
+
+```c++
+Attribute MyDialect::parseAttribute(DialectAsmParser &parser,
+                                    Type type) const override {
+    ...
+    // The attribute name.
+    StringRef attrTag;
+    if (failed(parser.parseKeyword(&attrTag)))
+        return Attribute();
+
+    // Try to parse a dynamic attribute with 'attrTag' name.
+    Attribute dynAttr;
+    auto parseResult = parseOptionalDynamicAttr(attrTag, parser, dynAttr);
+    if (parseResult.hasValue()) {
+        if (succeeded(parseResult.getValue()))
+            return dynAttr;
+         return Attribute();
+    }
+```
+
+### Using an attribute defined at runtime
+
+Similar to types, attributes defined at runtime are instances of `DynamicAttr`.
+It is possible to get a dynamic attribute with `DynamicAttr::get` and
+`ExtensibleDialect::lookupAttrDefinition`.
+
+```c++
+auto attrDef = extensibleDialect->lookupAttrDefinition("my_dynamic_attr");
+ArrayRef<Attribute> params = ...;
+auto attr = DynamicAttr::get(attrDef, params);
+```
+
+It is also possible to cast an `Attribute` known to be defined at runtime to a
+`DynamicAttr`.
+
+```c++
+auto dynAttr = attr.cast<DynamicAttr>();
+auto attrDef = dynAttr.getAttrDef();
+auto args = dynAttr.getParams();
+```
+
+## Implementation details
+
+### Extensible dialect
+
+The role of extensible dialects is to own the necessary data for defined
+operations and types. They also contain the necessary accessors to easily
+access them.
+
+In order to cast a `Dialect` back to an `ExtensibleDialect`, we implement the
+`IsExtensibleDialect` interface to all `ExtensibleDialect`. The casting is done
+by checking if the `Dialect` implements `IsExtensibleDialect` or not.
+
+### Operation representation and registration
+
+Operations are represented in mlir using the `AbstractOperation` class. They are
+registered in dialects the same way operations defined in C++ are registered,
+which is by calling `AbstractOperation::insert`.
+
+The only difference is that a new `TypeID` needs to be created for each
+operation, since operations are not represented by a C++ class. This is done
+using a `TypeIDAllocator`, which can allocate a new unique `TypeID` at runtime.
+
+### Type representation and registration
+
+Unlike operations, types need to define a C++ storage class that takes care of
+type parameters. They also need to define another C++ class to access that
+storage. `DynamicTypeStorage` defines the storage of types defined at runtime,
+and `DynamicType` gives access to the storage, as well as defining useful
+functions. A `DynamicTypeStorage` contains a list of `Attribute` type
+parameters, as well as a pointer to the type definition.
+
+Types are registered using the `Dialect::addType` method, which expect a
+`TypeID` that is generated using a `TypeIDAllocator`. The type uniquer also
+register the type with the given `TypeID`. This mean that we can reuse our
+single `DynamicType` with different `TypeID` to represent the different types
+defined at runtime.
+
+Since the different types defined at runtime have different `TypeID`, it is not
+possible to use `TypeID` to cast a `Type` into a `DynamicType`. Thus, similar to
+`Dialect`, all `DynamicType` define a `IsDynamicTypeTrait`, so casting a `Type`
+to a `DynamicType` boils down to querying the `IsDynamicTypeTrait` trait.
index 9745207..948eaf8 100644 (file)
@@ -45,6 +45,17 @@ public:
                              T::getTypeID());
   }
 
+  /// This method is used by Dialect objects to register attributes with
+  /// custom TypeIDs.
+  /// The use of this method is in general discouraged in favor of
+  /// 'get<CustomAttribute>(dialect)'.
+  static AbstractAttribute get(Dialect &dialect,
+                               detail::InterfaceMap &&interfaceMap,
+                               HasTraitFn &&hasTrait, TypeID typeID) {
+    return AbstractAttribute(dialect, std::move(interfaceMap),
+                             std::move(hasTrait), typeID);
+  }
+
   /// Return the dialect this attribute was registered to.
   Dialect &getDialect() const { return const_cast<Dialect &>(dialect); }
 
@@ -175,14 +186,22 @@ namespace detail {
 // MLIRContext. This class manages all creation and uniquing of attributes.
 class AttributeUniquer {
 public:
+  /// Get an uniqued instance of an attribute T.
+  template <typename T, typename... Args>
+  static T get(MLIRContext *ctx, Args &&...args) {
+    return getWithTypeID<T, Args...>(ctx, T::getTypeID(),
+                                     std::forward<Args>(args)...);
+  }
+
   /// Get an uniqued instance of a parametric attribute T.
+  /// The use of this method is in general discouraged in favor of
+  /// 'get<T, Args>(ctx, args)'.
   template <typename T, typename... Args>
   static typename std::enable_if_t<
       !std::is_same<typename T::ImplType, AttributeStorage>::value, T>
-  get(MLIRContext *ctx, Args &&...args) {
+  getWithTypeID(MLIRContext *ctx, TypeID typeID, Args &&...args) {
 #ifndef NDEBUG
-    if (!ctx->getAttributeUniquer().isParametricStorageInitialized(
-            T::getTypeID()))
+    if (!ctx->getAttributeUniquer().isParametricStorageInitialized(typeID))
       llvm::report_fatal_error(
           llvm::Twine("can't create Attribute '") + llvm::getTypeName<T>() +
           "' because storage uniquer isn't initialized: the dialect was likely "
@@ -190,30 +209,31 @@ public:
           "in the Dialect::initialize() method.");
 #endif
     return ctx->getAttributeUniquer().get<typename T::ImplType>(
-        [ctx](AttributeStorage *storage) {
-          initializeAttributeStorage(storage, ctx, T::getTypeID());
+        [typeID, ctx](AttributeStorage *storage) {
+          initializeAttributeStorage(storage, ctx, typeID);
 
           // Execute any additional attribute storage initialization with the
           // context.
           static_cast<typename T::ImplType *>(storage)->initialize(ctx);
         },
-        T::getTypeID(), std::forward<Args>(args)...);
+        typeID, std::forward<Args>(args)...);
   }
   /// Get an uniqued instance of a singleton attribute T.
+  /// The use of this method is in general discouraged in favor of
+  /// 'get<T, Args>(ctx, args)'.
   template <typename T>
   static typename std::enable_if_t<
       std::is_same<typename T::ImplType, AttributeStorage>::value, T>
-  get(MLIRContext *ctx) {
+  getWithTypeID(MLIRContext *ctx, TypeID typeID) {
 #ifndef NDEBUG
-    if (!ctx->getAttributeUniquer().isSingletonStorageInitialized(
-            T::getTypeID()))
+    if (!ctx->getAttributeUniquer().isSingletonStorageInitialized(typeID))
       llvm::report_fatal_error(
           llvm::Twine("can't create Attribute '") + llvm::getTypeName<T>() +
           "' because storage uniquer isn't initialized: the dialect was likely "
           "not loaded, or the attribute wasn't added with addAttributes<...>() "
           "in the Dialect::initialize() method.");
 #endif
-    return ctx->getAttributeUniquer().get<typename T::ImplType>(T::getTypeID());
+    return ctx->getAttributeUniquer().get<typename T::ImplType>(typeID);
   }
 
   template <typename T, typename... Args>
@@ -224,23 +244,33 @@ public:
                                              std::forward<Args>(args)...);
   }
 
+  /// Register an attribute instance T with the uniquer.
+  template <typename T>
+  static void registerAttribute(MLIRContext *ctx) {
+    registerAttribute<T>(ctx, T::getTypeID());
+  }
+
   /// Register a parametric attribute instance T with the uniquer.
+  /// The use of this method is in general discouraged in favor of
+  /// 'registerAttribute<T>(ctx)'.
   template <typename T>
   static typename std::enable_if_t<
       !std::is_same<typename T::ImplType, AttributeStorage>::value>
-  registerAttribute(MLIRContext *ctx) {
+  registerAttribute(MLIRContext *ctx, TypeID typeID) {
     ctx->getAttributeUniquer()
-        .registerParametricStorageType<typename T::ImplType>(T::getTypeID());
+        .registerParametricStorageType<typename T::ImplType>(typeID);
   }
   /// Register a singleton attribute instance T with the uniquer.
+  /// The use of this method is in general discouraged in favor of
+  /// 'registerAttribute<T>(ctx)'.
   template <typename T>
   static typename std::enable_if_t<
       std::is_same<typename T::ImplType, AttributeStorage>::value>
-  registerAttribute(MLIRContext *ctx) {
+  registerAttribute(MLIRContext *ctx, TypeID typeID) {
     ctx->getAttributeUniquer()
         .registerSingletonStorageType<typename T::ImplType>(
-            T::getTypeID(), [ctx](AttributeStorage *storage) {
-              initializeAttributeStorage(storage, ctx, T::getTypeID());
+            typeID, [ctx, typeID](AttributeStorage *storage) {
+              initializeAttributeStorage(storage, ctx, typeID);
             });
   }
 
index c7a70bf..ee14ec0 100644 (file)
@@ -221,6 +221,11 @@ protected:
     (void)std::initializer_list<int>{0, (addAttribute<Args>(), 0)...};
   }
 
+  /// Register an attribute instance with this dialect.
+  /// The use of this method is in general discouraged in favor of
+  /// 'addAttributes<CustomAttr>()'.
+  void addAttribute(TypeID typeID, AbstractAttribute &&attrInfo);
+
   /// Enable support for unregistered operations.
   void allowUnknownOperations(bool allow = true) { unknownOpsAllowed = allow; }
 
@@ -237,7 +242,6 @@ private:
     addAttribute(T::getTypeID(), AbstractAttribute::get<T>(*this));
     detail::AttributeUniquer::registerAttribute<T>(context);
   }
-  void addAttribute(TypeID typeID, AbstractAttribute &&attrInfo);
 
   /// Register a type instance with this dialect.
   template <typename T> void addType() {
index ce7d014..8ce5d95 100644 (file)
@@ -94,6 +94,9 @@ class Dialect {
   // UpperCamel) and prefixed with `get` or `set` depending on if it is a getter
   // or setter.
   int emitAccessorPrefix = kEmitAccessorPrefix_Raw;
+
+  // If this dialect can be extended at runtime with new operations or types.
+  bit isExtensible = 0;
 }
 
 #endif // DIALECTBASE_TD
diff --git a/mlir/include/mlir/IR/ExtensibleDialect.h b/mlir/include/mlir/IR/ExtensibleDialect.h
new file mode 100644 (file)
index 0000000..ee83ef5
--- /dev/null
@@ -0,0 +1,556 @@
+//===- ExtensibleDialect.h - Extensible dialect -----------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the DynamicOpDefinition class, the DynamicTypeDefinition
+// class, and the DynamicAttrDefinition class, which represent respectively
+// operations, types, and attributes that can be defined at runtime. They can
+// be registered at runtime to an extensible dialect, using the
+// ExtensibleDialect class defined in this file.
+//
+// For a more complete documentation, see
+// https://mlir.llvm.org/docs/ExtensibleDialects/ .
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_EXTENSIBLEDIALECT_H
+#define MLIR_IR_EXTENSIBLEDIALECT_H
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectInterface.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Support/TypeID.h"
+#include "llvm/ADT/StringMap.h"
+
+namespace mlir {
+class AsmParser;
+class AsmPrinter;
+class DynamicAttr;
+class DynamicType;
+class ExtensibleDialect;
+class MLIRContext;
+class OptionalParseResult;
+class ParseResult;
+
+namespace detail {
+struct DynamicAttrStorage;
+struct DynamicTypeStorage;
+} // namespace detail
+
+//===----------------------------------------------------------------------===//
+// Dynamic attribute
+//===----------------------------------------------------------------------===//
+
+/// The definition of a dynamic attribute. A dynamic attribute is an attribute
+/// that is defined at runtime, and that can be registered at runtime by an
+/// extensible dialect (a dialect inheriting ExtensibleDialect). This class
+/// stores the parser, the printer, and the verifier of the attribute. Each
+/// dynamic attribute definition refers to one instance of this class.
+class DynamicAttrDefinition : SelfOwningTypeID {
+public:
+  using VerifierFn = llvm::unique_function<LogicalResult(
+      function_ref<InFlightDiagnostic()>, ArrayRef<Attribute>) const>;
+  using ParserFn = llvm::unique_function<ParseResult(
+      AsmParser &parser, llvm::SmallVectorImpl<Attribute> &parsedAttributes)
+                                             const>;
+  using PrinterFn = llvm::unique_function<void(
+      AsmPrinter &printer, ArrayRef<Attribute> params) const>;
+
+  /// Create a new attribute definition at runtime. The attribute is registered
+  /// only after passing it to the dialect using registerDynamicAttr.
+  static std::unique_ptr<DynamicAttrDefinition>
+  get(StringRef name, ExtensibleDialect *dialect, VerifierFn &&verifier);
+  static std::unique_ptr<DynamicAttrDefinition>
+  get(StringRef name, ExtensibleDialect *dialect, VerifierFn &&verifier,
+      ParserFn &&parser, PrinterFn &&printer);
+
+  /// Check that the attribute parameters are valid.
+  LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                       ArrayRef<Attribute> params) const {
+    return verifier(emitError, params);
+  }
+
+  /// Return the MLIRContext in which the dynamic attributes are uniqued.
+  MLIRContext &getContext() const { return *ctx; }
+
+  /// Return the name of the attribute, in the format 'attrname' and
+  /// not 'dialectname.attrname'.
+  StringRef getName() const { return name; }
+
+  /// Return the dialect defining the attribute.
+  ExtensibleDialect *getDialect() const { return dialect; }
+
+private:
+  DynamicAttrDefinition(StringRef name, ExtensibleDialect *dialect,
+                        VerifierFn &&verifier, ParserFn &&parser,
+                        PrinterFn &&printer);
+
+  /// This constructor should only be used when we need a pointer to
+  /// the DynamicAttrDefinition in the verifier, the parser, or the printer.
+  /// The verifier, parser, and printer need thus to be initialized after the
+  /// constructor.
+  DynamicAttrDefinition(ExtensibleDialect *dialect, StringRef name);
+
+  /// Register the concrete attribute in the attribute Uniquer.
+  void registerInAttrUniquer();
+
+  /// The name should be prefixed with the dialect name followed by '.'.
+  std::string name;
+
+  /// Dialect in which this attribute is defined.
+  ExtensibleDialect *dialect;
+
+  /// The attribute verifier. It checks that the attribute parameters satisfy
+  /// the invariants.
+  VerifierFn verifier;
+
+  /// The attribute parameters parser. It parses only the parameters, and
+  /// expects the attribute name to have already been parsed.
+  ParserFn parser;
+
+  /// The attribute parameters printer. It prints only the parameters, and
+  /// expects the attribute name to have already been printed.
+  PrinterFn printer;
+
+  /// Context in which the concrete attributes are uniqued.
+  MLIRContext *ctx;
+
+  friend ExtensibleDialect;
+  friend DynamicAttr;
+};
+
+/// This trait is used to determine if an attribute is a dynamic attribute or
+/// not; it should only be implemented by dynamic attributes.
+/// Note: This is only required because dynamic attributes do not have a
+/// static/single TypeID.
+namespace AttributeTrait {
+template <typename ConcreteType>
+class IsDynamicAttr : public TraitBase<ConcreteType, IsDynamicAttr> {};
+} // namespace AttributeTrait
+
+/// A dynamic attribute instance. This is an attribute whose definition is
+/// defined at runtime.
+/// It is possible to check if an attribute is a dynamic attribute using
+/// `my_attr.isa<DynamicAttr>()`, and getting the attribute definition of a
+/// dynamic attribute using the `DynamicAttr::getAttrDef` method.
+/// All dynamic attributes have the same storage, which is an array of
+/// attributes.
+
+class DynamicAttr : public Attribute::AttrBase<DynamicAttr, Attribute,
+                                               detail::DynamicAttrStorage,
+                                               AttributeTrait::IsDynamicAttr> {
+public:
+  // Inherit Base constructors.
+  using Base::Base;
+
+  /// Return an instance of a dynamic attribute given a dynamic attribute
+  /// definition and attribute parameters.
+  /// This asserts that the attribute verifier succeeded.
+  static DynamicAttr get(DynamicAttrDefinition *attrDef,
+                         ArrayRef<Attribute> params = {});
+
+  /// Return an instance of a dynamic attribute given a dynamic attribute
+  /// definition and attribute parameters. If the parameters provided are
+  /// invalid, errors are emitted using the provided location and a null object
+  /// is returned.
+  static DynamicAttr getChecked(function_ref<InFlightDiagnostic()> emitError,
+                                DynamicAttrDefinition *attrDef,
+                                ArrayRef<Attribute> params = {});
+
+  /// Return the attribute definition of the concrete attribute.
+  DynamicAttrDefinition *getAttrDef();
+
+  /// Return the attribute parameters.
+  ArrayRef<Attribute> getParams();
+
+  /// Check if an attribute is a specific dynamic attribute.
+  static bool isa(Attribute attr, DynamicAttrDefinition *attrDef) {
+    return attr.getTypeID() == attrDef->getTypeID();
+  }
+
+  /// Check if an attribute is a dynamic attribute.
+  static bool classof(Attribute attr);
+
+  /// Parse the dynamic attribute parameters and construct the attribute.
+  /// The parameters are either empty, and nothing is parsed,
+  /// or they are in the format '<>' or '<attr (,attr)*>'.
+  static ParseResult parse(AsmParser &parser, DynamicAttrDefinition *attrDef,
+                           DynamicAttr &parsedAttr);
+
+  /// Print the dynamic attribute with the format 'attrname' if there is no
+  /// parameters, or 'attrname<attr (,attr)*>'.
+  void print(AsmPrinter &printer);
+};
+
+//===----------------------------------------------------------------------===//
+// Dynamic type
+//===----------------------------------------------------------------------===//
+
+/// The definition of a dynamic type. A dynamic type is a type that is
+/// defined at runtime, and that can be registered at runtime by an
+/// extensible dialect (a dialect inheriting ExtensibleDialect). This class
+/// stores the parser, the printer, and the verifier of the type. Each dynamic
+/// type definition refers to one instance of this class.
+class DynamicTypeDefinition : SelfOwningTypeID {
+public:
+  using VerifierFn = llvm::unique_function<LogicalResult(
+      function_ref<InFlightDiagnostic()>, ArrayRef<Attribute>) const>;
+  using ParserFn = llvm::unique_function<ParseResult(
+      AsmParser &parser, llvm::SmallVectorImpl<Attribute> &parsedAttributes)
+                                             const>;
+  using PrinterFn = llvm::unique_function<void(
+      AsmPrinter &printer, ArrayRef<Attribute> params) const>;
+
+  /// Create a new dynamic type definition. The type is registered only after
+  /// passing it to the dialect using registerDynamicType.
+  static std::unique_ptr<DynamicTypeDefinition>
+  get(StringRef name, ExtensibleDialect *dialect, VerifierFn &&verifier);
+  static std::unique_ptr<DynamicTypeDefinition>
+  get(StringRef name, ExtensibleDialect *dialect, VerifierFn &&verifier,
+      ParserFn &&parser, PrinterFn &&printer);
+
+  /// Check that the type parameters are valid.
+  LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                       ArrayRef<Attribute> params) const {
+    return verifier(emitError, params);
+  }
+
+  /// Return the MLIRContext in which the dynamic types is uniqued.
+  MLIRContext &getContext() const { return *ctx; }
+
+  /// Return the name of the type, in the format 'typename' and
+  /// not 'dialectname.typename'.
+  StringRef getName() const { return name; }
+
+  /// Return the dialect defining the type.
+  ExtensibleDialect *getDialect() const { return dialect; }
+
+private:
+  DynamicTypeDefinition(StringRef name, ExtensibleDialect *dialect,
+                        VerifierFn &&verifier, ParserFn &&parser,
+                        PrinterFn &&printer);
+
+  /// This constructor should only be used when we need a pointer to
+  /// the DynamicTypeDefinition in the verifier, the parser, or the printer.
+  /// The verifier, parser, and printer need thus to be initialized after the
+  /// constructor.
+  DynamicTypeDefinition(ExtensibleDialect *dialect, StringRef name);
+
+  /// Register the concrete type in the type Uniquer.
+  void registerInTypeUniquer();
+
+  /// The name should be prefixed with the dialect name followed by '.'.
+  std::string name;
+
+  /// Dialect in which this type is defined.
+  ExtensibleDialect *dialect;
+
+  /// The type verifier. It checks that the type parameters satisfy the
+  /// invariants.
+  VerifierFn verifier;
+
+  /// The type parameters parser. It parses only the parameters, and expects the
+  /// type name to have already been parsed.
+  ParserFn parser;
+
+  /// The type parameters printer. It prints only the parameters, and expects
+  /// the type name to have already been printed.
+  PrinterFn printer;
+
+  /// Context in which the concrete types are uniqued.
+  MLIRContext *ctx;
+
+  friend ExtensibleDialect;
+  friend DynamicType;
+};
+
+/// This trait is used to determine if a type is a dynamic type or not;
+/// it should only be implemented by dynamic types.
+/// Note: This is only required because dynamic type do not have a
+/// static/single TypeID.
+namespace TypeTrait {
+template <typename ConcreteType>
+class IsDynamicType : public TypeTrait::TraitBase<ConcreteType, IsDynamicType> {
+};
+} // namespace TypeTrait
+
+/// A dynamic type instance. This is a type whose definition is defined at
+/// runtime.
+/// It is possible to check if a type is a dynamic type using
+/// `my_type.isa<DynamicType>()`, and getting the type definition of a dynamic
+/// type using the `DynamicType::getTypeDef` method.
+/// All dynamic types have the same storage, which is an array of attributes.
+class DynamicType
+    : public Type::TypeBase<DynamicType, Type, detail::DynamicTypeStorage,
+                            TypeTrait::IsDynamicType> {
+public:
+  // Inherit Base constructors.
+  using Base::Base;
+
+  /// Return an instance of a dynamic type given a dynamic type definition and
+  /// type parameters.
+  /// This asserts that the type verifier succeeded.
+  static DynamicType get(DynamicTypeDefinition *typeDef,
+                         ArrayRef<Attribute> params = {});
+
+  /// Return an instance of a dynamic type given a dynamic type definition and
+  /// type parameters. If the parameters provided are invalid, errors are
+  /// emitted using the provided location and a null object is returned.
+  static DynamicType getChecked(function_ref<InFlightDiagnostic()> emitError,
+                                DynamicTypeDefinition *typeDef,
+                                ArrayRef<Attribute> params = {});
+
+  /// Return the type definition of the concrete type.
+  DynamicTypeDefinition *getTypeDef();
+
+  /// Return the type parameters.
+  ArrayRef<Attribute> getParams();
+
+  /// Check if a type is a specific dynamic type.
+  static bool isa(Type type, DynamicTypeDefinition *typeDef) {
+    return type.getTypeID() == typeDef->getTypeID();
+  }
+
+  /// Check if a type is a dynamic type.
+  static bool classof(Type type);
+
+  /// Parse the dynamic type parameters and construct the type.
+  /// The parameters are either empty, and nothing is parsed,
+  /// or they are in the format '<>' or '<attr (,attr)*>'.
+  static ParseResult parse(AsmParser &parser, DynamicTypeDefinition *typeDef,
+                           DynamicType &parsedType);
+
+  /// Print the dynamic type with the format
+  /// 'type' or 'type<>' if there is no parameters, or 'type<attr (,attr)*>'.
+  void print(AsmPrinter &printer);
+};
+
+//===----------------------------------------------------------------------===//
+// Dynamic operation
+//===----------------------------------------------------------------------===//
+
+/// The definition of a dynamic op. A dynamic op is an op that is defined at
+/// runtime, and that can be registered at runtime by an extensible dialect (a
+/// dialect inheriting ExtensibleDialect). This class stores the functions that
+/// are in the OperationName class, and in addition defines the TypeID of the op
+/// that will be defined.
+/// Each dynamic operation definition refers to one instance of this class.
+class DynamicOpDefinition {
+public:
+  /// Create a new op at runtime. The op is registered only after passing it to
+  /// the dialect using registerDynamicOp.
+  static std::unique_ptr<DynamicOpDefinition>
+  get(StringRef name, ExtensibleDialect *dialect,
+      OperationName::VerifyInvariantsFn &&verifyFn,
+      OperationName::VerifyRegionInvariantsFn &&verifyRegionFn);
+  static std::unique_ptr<DynamicOpDefinition>
+  get(StringRef name, ExtensibleDialect *dialect,
+      OperationName::VerifyInvariantsFn &&verifyFn,
+      OperationName::VerifyRegionInvariantsFn &&verifyRegionFn,
+      OperationName::ParseAssemblyFn &&parseFn,
+      OperationName::PrintAssemblyFn &&printFn);
+  static std::unique_ptr<DynamicOpDefinition>
+  get(StringRef name, ExtensibleDialect *dialect,
+      OperationName::VerifyInvariantsFn &&verifyFn,
+      OperationName::VerifyRegionInvariantsFn &&verifyRegionFn,
+      OperationName::ParseAssemblyFn &&parseFn,
+      OperationName::PrintAssemblyFn &&printFn,
+      OperationName::FoldHookFn &&foldHookFn,
+      OperationName::GetCanonicalizationPatternsFn
+          &&getCanonicalizationPatternsFn);
+
+  /// Returns the op typeID.
+  TypeID getTypeID() { return typeID; }
+
+  /// Sets the verifier function for this operation. It should emits an error
+  /// message and returns failure if a problem is detected, or returns success
+  /// if everything is ok.
+  void setVerifyFn(OperationName::VerifyInvariantsFn &&verify) {
+    verifyFn = std::move(verify);
+  }
+
+  /// Sets the region verifier function for this operation. It should emits an
+  /// error message and returns failure if a problem is detected, or returns
+  /// success if everything is ok.
+  void setVerifyRegionFn(OperationName::VerifyRegionInvariantsFn &&verify) {
+    verifyRegionFn = std::move(verify);
+  }
+
+  /// Sets the static hook for parsing this op assembly.
+  void setParseFn(OperationName::ParseAssemblyFn &&parse) {
+    parseFn = std::move(parse);
+  }
+
+  /// Sets the static hook for printing this op assembly.
+  void setPrintFn(OperationName::PrintAssemblyFn &&print) {
+    printFn = std::move(print);
+  }
+
+  /// Sets the hook implementing a generalized folder for the op. See
+  /// `RegisteredOperationName::foldHook` for more details
+  void setFoldHookFn(OperationName::FoldHookFn &&foldHook) {
+    foldHookFn = std::move(foldHook);
+  }
+
+  /// Set the hook returning any canonicalization pattern rewrites that the op
+  /// supports, for use by the canonicalization pass.
+  void
+  setGetCanonicalizationPatternsFn(OperationName::GetCanonicalizationPatternsFn
+                                       &&getCanonicalizationPatterns) {
+    getCanonicalizationPatternsFn = std::move(getCanonicalizationPatterns);
+  }
+
+private:
+  DynamicOpDefinition(StringRef name, ExtensibleDialect *dialect,
+                      OperationName::VerifyInvariantsFn &&verifyFn,
+                      OperationName::VerifyRegionInvariantsFn &&verifyRegionFn,
+                      OperationName::ParseAssemblyFn &&parseFn,
+                      OperationName::PrintAssemblyFn &&printFn,
+                      OperationName::FoldHookFn &&foldHookFn,
+                      OperationName::GetCanonicalizationPatternsFn
+                          &&getCanonicalizationPatternsFn);
+
+  /// Unique identifier for this operation.
+  TypeID typeID;
+
+  /// Name of the operation.
+  /// The name is prefixed with the dialect name.
+  std::string name;
+
+  /// Dialect defining this operation.
+  ExtensibleDialect *dialect;
+
+  OperationName::VerifyInvariantsFn verifyFn;
+  OperationName::VerifyRegionInvariantsFn verifyRegionFn;
+  OperationName::ParseAssemblyFn parseFn;
+  OperationName::PrintAssemblyFn printFn;
+  OperationName::FoldHookFn foldHookFn;
+  OperationName::GetCanonicalizationPatternsFn getCanonicalizationPatternsFn;
+
+  friend ExtensibleDialect;
+};
+
+//===----------------------------------------------------------------------===//
+// Extensible dialect
+//===----------------------------------------------------------------------===//
+
+/// A dialect that can be extended with new operations/types/attributes at
+/// runtime.
+class ExtensibleDialect : public mlir::Dialect {
+public:
+  ExtensibleDialect(StringRef name, MLIRContext *ctx, TypeID typeID);
+
+  /// Add a new type defined at runtime to the dialect.
+  void registerDynamicType(std::unique_ptr<DynamicTypeDefinition> &&type);
+
+  /// Add a new attribute defined at runtime to the dialect.
+  void registerDynamicAttr(std::unique_ptr<DynamicAttrDefinition> &&attr);
+
+  /// Add a new operation defined at runtime to the dialect.
+  void registerDynamicOp(std::unique_ptr<DynamicOpDefinition> &&type);
+
+  /// Check if the dialect is an extensible dialect.
+  static bool classof(const Dialect *dialect);
+
+  /// Returns nullptr if the definition was not found.
+  DynamicTypeDefinition *lookupTypeDefinition(StringRef name) const {
+    auto it = nameToDynTypes.find(name);
+    if (it == nameToDynTypes.end())
+      return nullptr;
+    return it->second;
+  }
+
+  /// Returns nullptr if the definition was not found.
+  DynamicTypeDefinition *lookupTypeDefinition(TypeID id) const {
+    auto it = dynTypes.find(id);
+    if (it == dynTypes.end())
+      return nullptr;
+    return it->second.get();
+  }
+
+  /// Returns nullptr if the definition was not found.
+  DynamicAttrDefinition *lookupAttrDefinition(StringRef name) const {
+    auto it = nameToDynAttrs.find(name);
+    if (it == nameToDynAttrs.end())
+      return nullptr;
+    return it->second;
+  }
+
+  /// Returns nullptr if the definition was not found.
+  DynamicAttrDefinition *lookupAttrDefinition(TypeID id) const {
+    auto it = dynAttrs.find(id);
+    if (it == dynAttrs.end())
+      return nullptr;
+    return it->second.get();
+  }
+
+protected:
+  /// Parse the dynamic type 'typeName' in the dialect 'dialect'.
+  /// typename should not be prefixed with the dialect name.
+  /// If the dynamic type does not exist, return no value.
+  /// Otherwise, parse it, and return the parse result.
+  /// If the parsing succeed, put the resulting type in 'resultType'.
+  OptionalParseResult parseOptionalDynamicType(StringRef typeName,
+                                               AsmParser &parser,
+                                               Type &resultType) const;
+
+  /// If 'type' is a dynamic type, print it.
+  /// Returns success if the type was printed, and failure if the type was not a
+  /// dynamic type.
+  static LogicalResult printIfDynamicType(Type type, AsmPrinter &printer);
+
+  /// Parse the dynamic attribute 'attrName' in the dialect 'dialect'.
+  /// attrname should not be prefixed with the dialect name.
+  /// If the dynamic attribute does not exist, return no value.
+  /// Otherwise, parse it, and return the parse result.
+  /// If the parsing succeed, put the resulting attribute in 'resultAttr'.
+  OptionalParseResult parseOptionalDynamicAttr(StringRef attrName,
+                                               AsmParser &parser,
+                                               Attribute &resultAttr) const;
+
+  /// If 'attr' is a dynamic attribute, print it.
+  /// Returns success if the attribute was printed, and failure if the
+  /// attribute was not a dynamic attribute.
+  static LogicalResult printIfDynamicAttr(Attribute attr, AsmPrinter &printer);
+
+private:
+  /// The set of all dynamic types registered.
+  DenseMap<TypeID, std::unique_ptr<DynamicTypeDefinition>> dynTypes;
+
+  /// This structure allows to get in O(1) a dynamic type given its name.
+  llvm::StringMap<DynamicTypeDefinition *> nameToDynTypes;
+
+  /// The set of all dynamic attributes registered.
+  DenseMap<TypeID, std::unique_ptr<DynamicAttrDefinition>> dynAttrs;
+
+  /// This structure allows to get in O(1) a dynamic attribute given its name.
+  llvm::StringMap<DynamicAttrDefinition *> nameToDynAttrs;
+
+  /// Give DynamicOpDefinition access to allocateTypeID.
+  friend DynamicOpDefinition;
+
+  /// Allocates a type ID to uniquify operations.
+  TypeID allocateTypeID() { return typeIDAllocator.allocate(); }
+
+  /// Owns the TypeID generated at runtime for operations.
+  TypeIDAllocator typeIDAllocator;
+};
+} // namespace mlir
+
+namespace llvm {
+/// Provide isa functionality for ExtensibleDialect.
+/// This is to override the isa functionality for Dialect.
+template <>
+struct isa_impl<mlir::ExtensibleDialect, mlir::Dialect> {
+  static inline bool doit(const ::mlir::Dialect &dialect) {
+    return mlir::ExtensibleDialect::classof(&dialect);
+  }
+};
+} // namespace llvm
+
+#endif // MLIR_IR_EXTENSIBLEDIALECT_H
index 6f93270..358ab40 100644 (file)
@@ -163,13 +163,22 @@ namespace detail {
 /// A utility class to get, or create, unique instances of types within an
 /// MLIRContext. This class manages all creation and uniquing of types.
 struct TypeUniquer {
+  /// Get an uniqued instance of a type T.
+  template <typename T, typename... Args>
+  static T get(MLIRContext *ctx, Args &&...args) {
+    return getWithTypeID<T, Args...>(ctx, T::getTypeID(),
+                                     std::forward<Args>(args)...);
+  }
+
   /// Get an uniqued instance of a parametric type T.
+  /// The use of this method is in general discouraged in favor of
+  /// 'get<T, Args>(ctx, args)'.
   template <typename T, typename... Args>
   static typename std::enable_if_t<
       !std::is_same<typename T::ImplType, TypeStorage>::value, T>
-  get(MLIRContext *ctx, Args &&...args) {
+  getWithTypeID(MLIRContext *ctx, TypeID typeID, Args &&...args) {
 #ifndef NDEBUG
-    if (!ctx->getTypeUniquer().isParametricStorageInitialized(T::getTypeID()))
+    if (!ctx->getTypeUniquer().isParametricStorageInitialized(typeID))
       llvm::report_fatal_error(
           llvm::Twine("can't create type '") + llvm::getTypeName<T>() +
           "' because storage uniquer isn't initialized: the dialect was likely "
@@ -177,25 +186,27 @@ struct TypeUniquer {
           "in the Dialect::initialize() method.");
 #endif
     return ctx->getTypeUniquer().get<typename T::ImplType>(
-        [&](TypeStorage *storage) {
-          storage->initialize(AbstractType::lookup(T::getTypeID(), ctx));
+        [&, typeID](TypeStorage *storage) {
+          storage->initialize(AbstractType::lookup(typeID, ctx));
         },
-        T::getTypeID(), std::forward<Args>(args)...);
+        typeID, std::forward<Args>(args)...);
   }
   /// Get an uniqued instance of a singleton type T.
+  /// The use of this method is in general discouraged in favor of
+  /// 'get<T, Args>(ctx, args)'.
   template <typename T>
   static typename std::enable_if_t<
       std::is_same<typename T::ImplType, TypeStorage>::value, T>
-  get(MLIRContext *ctx) {
+  getWithTypeID(MLIRContext *ctx, TypeID typeID) {
 #ifndef NDEBUG
-    if (!ctx->getTypeUniquer().isSingletonStorageInitialized(T::getTypeID()))
+    if (!ctx->getTypeUniquer().isSingletonStorageInitialized(typeID))
       llvm::report_fatal_error(
           llvm::Twine("can't create type '") + llvm::getTypeName<T>() +
           "' because storage uniquer isn't initialized: the dialect was likely "
           "not loaded, or the type wasn't added with addTypes<...>() "
           "in the Dialect::initialize() method.");
 #endif
-    return ctx->getTypeUniquer().get<typename T::ImplType>(T::getTypeID());
+    return ctx->getTypeUniquer().get<typename T::ImplType>(typeID);
   }
 
   /// Change the mutable component of the given type instance in the provided
@@ -208,22 +219,32 @@ struct TypeUniquer {
                                         std::forward<Args>(args)...);
   }
 
+  /// Register a type instance T with the uniquer.
+  template <typename T>
+  static void registerType(MLIRContext *ctx) {
+    registerType<T>(ctx, T::getTypeID());
+  }
+
   /// Register a parametric type instance T with the uniquer.
+  /// The use of this method is in general discouraged in favor of
+  /// 'registerType<T>(ctx)'.
   template <typename T>
   static typename std::enable_if_t<
       !std::is_same<typename T::ImplType, TypeStorage>::value>
-  registerType(MLIRContext *ctx) {
+  registerType(MLIRContext *ctx, TypeID typeID) {
     ctx->getTypeUniquer().registerParametricStorageType<typename T::ImplType>(
-        T::getTypeID());
+        typeID);
   }
   /// Register a singleton type instance T with the uniquer.
+  /// The use of this method is in general discouraged in favor of
+  /// 'registerType<T>(ctx)'.
   template <typename T>
   static typename std::enable_if_t<
       std::is_same<typename T::ImplType, TypeStorage>::value>
-  registerType(MLIRContext *ctx) {
+  registerType(MLIRContext *ctx, TypeID typeID) {
     ctx->getTypeUniquer().registerSingletonStorageType<TypeStorage>(
-        T::getTypeID(), [&](TypeStorage *storage) {
-          storage->initialize(AbstractType::lookup(T::getTypeID(), ctx));
+        typeID, [&ctx, typeID](TypeStorage *storage) {
+          storage->initialize(AbstractType::lookup(typeID, ctx));
         });
   }
 };
index 3a37889..297d2df 100644 (file)
@@ -82,6 +82,10 @@ public:
   /// type printing/parsing.
   bool useDefaultTypePrinterParser() const;
 
+  /// Returns true if this dialect can be extended at runtime with new
+  /// operations or types.
+  bool isExtensible() const;
+
   // Returns whether two dialects are equal by checking the equality of the
   // underlying record.
   bool operator==(const Dialect &other) const;
index 899194f..699ef7e 100644 (file)
@@ -13,6 +13,7 @@ add_mlir_library(MLIRIR
   Diagnostics.cpp
   Dialect.cpp
   Dominance.cpp
+  ExtensibleDialect.cpp
   FunctionImplementation.cpp
   FunctionInterfaces.cpp
   IntegerSet.cpp
diff --git a/mlir/lib/IR/ExtensibleDialect.cpp b/mlir/lib/IR/ExtensibleDialect.cpp
new file mode 100644 (file)
index 0000000..3e96b83
--- /dev/null
@@ -0,0 +1,500 @@
+//===- ExtensibleDialect.cpp - Extensible dialect ---------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/ExtensibleDialect.h"
+#include "mlir/IR/AttributeSupport.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/StorageUniquerSupport.h"
+#include "mlir/Support/LogicalResult.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Dynamic types and attributes shared functions
+//===----------------------------------------------------------------------===//
+
+/// Default parser for dynamic attribute or type parameters.
+/// Parse in the format '(<>)?' or '<attr (,attr)*>'.
+static LogicalResult
+typeOrAttrParser(AsmParser &parser, SmallVectorImpl<Attribute> &parsedParams) {
+  // No parameters
+  if (parser.parseOptionalLess() || !parser.parseOptionalGreater())
+    return success();
+
+  Attribute attr;
+  if (parser.parseAttribute(attr))
+    return failure();
+  parsedParams.push_back(attr);
+
+  while (parser.parseOptionalGreater()) {
+    Attribute attr;
+    if (parser.parseComma() || parser.parseAttribute(attr))
+      return failure();
+    parsedParams.push_back(attr);
+  }
+
+  return success();
+}
+
+/// Default printer for dynamic attribute or type parameters.
+/// Print in the format '(<>)?' or '<attr (,attr)*>'.
+static void typeOrAttrPrinter(AsmPrinter &printer, ArrayRef<Attribute> params) {
+  if (params.empty())
+    return;
+
+  printer << "<";
+  interleaveComma(params, printer.getStream());
+  printer << ">";
+}
+
+//===----------------------------------------------------------------------===//
+// Dynamic type
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<DynamicTypeDefinition>
+DynamicTypeDefinition::get(StringRef name, ExtensibleDialect *dialect,
+                           VerifierFn &&verifier) {
+  return DynamicTypeDefinition::get(name, dialect, std::move(verifier),
+                                    typeOrAttrParser, typeOrAttrPrinter);
+}
+
+std::unique_ptr<DynamicTypeDefinition>
+DynamicTypeDefinition::get(StringRef name, ExtensibleDialect *dialect,
+                           VerifierFn &&verifier, ParserFn &&parser,
+                           PrinterFn &&printer) {
+  return std::unique_ptr<DynamicTypeDefinition>(
+      new DynamicTypeDefinition(name, dialect, std::move(verifier),
+                                std::move(parser), std::move(printer)));
+}
+
+DynamicTypeDefinition::DynamicTypeDefinition(StringRef nameRef,
+                                             ExtensibleDialect *dialect,
+                                             VerifierFn &&verifier,
+                                             ParserFn &&parser,
+                                             PrinterFn &&printer)
+    : name(nameRef), dialect(dialect), verifier(std::move(verifier)),
+      parser(std::move(parser)), printer(std::move(printer)),
+      ctx(dialect->getContext()) {}
+
+DynamicTypeDefinition::DynamicTypeDefinition(ExtensibleDialect *dialect,
+                                             StringRef nameRef)
+    : name(nameRef), dialect(dialect), ctx(dialect->getContext()) {}
+
+void DynamicTypeDefinition::registerInTypeUniquer() {
+  detail::TypeUniquer::registerType<DynamicType>(&getContext(), getTypeID());
+}
+
+namespace mlir {
+namespace detail {
+/// Storage of DynamicType.
+/// Contains a pointer to the type definition and type parameters.
+struct DynamicTypeStorage : public TypeStorage {
+
+  using KeyTy = std::pair<DynamicTypeDefinition *, ArrayRef<Attribute>>;
+
+  explicit DynamicTypeStorage(DynamicTypeDefinition *typeDef,
+                              ArrayRef<Attribute> params)
+      : typeDef(typeDef), params(params) {}
+
+  bool operator==(const KeyTy &key) const {
+    return typeDef == key.first && params == key.second;
+  }
+
+  static llvm::hash_code hashKey(const KeyTy &key) {
+    return llvm::hash_value(key);
+  }
+
+  static DynamicTypeStorage *construct(TypeStorageAllocator &alloc,
+                                       const KeyTy &key) {
+    return new (alloc.allocate<DynamicTypeStorage>())
+        DynamicTypeStorage(key.first, alloc.copyInto(key.second));
+  }
+
+  /// Definition of the type.
+  DynamicTypeDefinition *typeDef;
+
+  /// The type parameters.
+  ArrayRef<Attribute> params;
+};
+} // namespace detail
+} // namespace mlir
+
+DynamicType DynamicType::get(DynamicTypeDefinition *typeDef,
+                             ArrayRef<Attribute> params) {
+  auto &ctx = typeDef->getContext();
+  auto emitError = detail::getDefaultDiagnosticEmitFn(&ctx);
+  assert(succeeded(typeDef->verify(emitError, params)));
+  return detail::TypeUniquer::getWithTypeID<DynamicType>(
+      &ctx, typeDef->getTypeID(), typeDef, params);
+}
+
+DynamicType
+DynamicType::getChecked(function_ref<InFlightDiagnostic()> emitError,
+                        DynamicTypeDefinition *typeDef,
+                        ArrayRef<Attribute> params) {
+  if (failed(typeDef->verify(emitError, params)))
+    return {};
+  auto &ctx = typeDef->getContext();
+  return detail::TypeUniquer::getWithTypeID<DynamicType>(
+      &ctx, typeDef->getTypeID(), typeDef, params);
+}
+
+DynamicTypeDefinition *DynamicType::getTypeDef() { return getImpl()->typeDef; }
+
+ArrayRef<Attribute> DynamicType::getParams() { return getImpl()->params; }
+
+bool DynamicType::classof(Type type) {
+  return type.hasTrait<TypeTrait::IsDynamicType>();
+}
+
+ParseResult DynamicType::parse(AsmParser &parser,
+                               DynamicTypeDefinition *typeDef,
+                               DynamicType &parsedType) {
+  SmallVector<Attribute> params;
+  if (failed(typeDef->parser(parser, params)))
+    return failure();
+  parsedType = parser.getChecked<DynamicType>(typeDef, params);
+  if (!parsedType)
+    return failure();
+  return success();
+}
+
+void DynamicType::print(AsmPrinter &printer) {
+  printer << getTypeDef()->getName();
+  getTypeDef()->printer(printer, getParams());
+}
+
+//===----------------------------------------------------------------------===//
+// Dynamic attribute
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<DynamicAttrDefinition>
+DynamicAttrDefinition::get(StringRef name, ExtensibleDialect *dialect,
+                           VerifierFn &&verifier) {
+  return DynamicAttrDefinition::get(name, dialect, std::move(verifier),
+                                    typeOrAttrParser, typeOrAttrPrinter);
+}
+
+std::unique_ptr<DynamicAttrDefinition>
+DynamicAttrDefinition::get(StringRef name, ExtensibleDialect *dialect,
+                           VerifierFn &&verifier, ParserFn &&parser,
+                           PrinterFn &&printer) {
+  return std::unique_ptr<DynamicAttrDefinition>(
+      new DynamicAttrDefinition(name, dialect, std::move(verifier),
+                                std::move(parser), std::move(printer)));
+}
+
+DynamicAttrDefinition::DynamicAttrDefinition(StringRef nameRef,
+                                             ExtensibleDialect *dialect,
+                                             VerifierFn &&verifier,
+                                             ParserFn &&parser,
+                                             PrinterFn &&printer)
+    : name(nameRef), dialect(dialect), verifier(std::move(verifier)),
+      parser(std::move(parser)), printer(std::move(printer)),
+      ctx(dialect->getContext()) {}
+
+DynamicAttrDefinition::DynamicAttrDefinition(ExtensibleDialect *dialect,
+                                             StringRef nameRef)
+    : name(nameRef), dialect(dialect), ctx(dialect->getContext()) {}
+
+void DynamicAttrDefinition::registerInAttrUniquer() {
+  detail::AttributeUniquer::registerAttribute<DynamicAttr>(&getContext(),
+                                                           getTypeID());
+}
+
+namespace mlir {
+namespace detail {
+/// Storage of DynamicAttr.
+/// Contains a pointer to the attribute definition and attribute parameters.
+struct DynamicAttrStorage : public AttributeStorage {
+  using KeyTy = std::pair<DynamicAttrDefinition *, ArrayRef<Attribute>>;
+
+  explicit DynamicAttrStorage(DynamicAttrDefinition *attrDef,
+                              ArrayRef<Attribute> params)
+      : attrDef(attrDef), params(params) {}
+
+  bool operator==(const KeyTy &key) const {
+    return attrDef == key.first && params == key.second;
+  }
+
+  static llvm::hash_code hashKey(const KeyTy &key) {
+    return llvm::hash_value(key);
+  }
+
+  static DynamicAttrStorage *construct(AttributeStorageAllocator &alloc,
+                                       const KeyTy &key) {
+    return new (alloc.allocate<DynamicAttrStorage>())
+        DynamicAttrStorage(key.first, alloc.copyInto(key.second));
+  }
+
+  /// Definition of the type.
+  DynamicAttrDefinition *attrDef;
+
+  /// The type parameters.
+  ArrayRef<Attribute> params;
+};
+} // namespace detail
+} // namespace mlir
+
+DynamicAttr DynamicAttr::get(DynamicAttrDefinition *attrDef,
+                             ArrayRef<Attribute> params) {
+  auto &ctx = attrDef->getContext();
+  return detail::AttributeUniquer::getWithTypeID<DynamicAttr>(
+      &ctx, attrDef->getTypeID(), attrDef, params);
+}
+
+DynamicAttr
+DynamicAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
+                        DynamicAttrDefinition *attrDef,
+                        ArrayRef<Attribute> params) {
+  if (failed(attrDef->verify(emitError, params)))
+    return {};
+  return get(attrDef, params);
+}
+
+DynamicAttrDefinition *DynamicAttr::getAttrDef() { return getImpl()->attrDef; }
+
+ArrayRef<Attribute> DynamicAttr::getParams() { return getImpl()->params; }
+
+bool DynamicAttr::classof(Attribute attr) {
+  return attr.hasTrait<AttributeTrait::IsDynamicAttr>();
+}
+
+ParseResult DynamicAttr::parse(AsmParser &parser,
+                               DynamicAttrDefinition *attrDef,
+                               DynamicAttr &parsedAttr) {
+  SmallVector<Attribute> params;
+  if (failed(attrDef->parser(parser, params)))
+    return failure();
+  parsedAttr = parser.getChecked<DynamicAttr>(attrDef, params);
+  if (!parsedAttr)
+    return failure();
+  return success();
+}
+
+void DynamicAttr::print(AsmPrinter &printer) {
+  printer << getAttrDef()->getName();
+  getAttrDef()->printer(printer, getParams());
+}
+
+//===----------------------------------------------------------------------===//
+// Dynamic operation
+//===----------------------------------------------------------------------===//
+
+DynamicOpDefinition::DynamicOpDefinition(
+    StringRef name, ExtensibleDialect *dialect,
+    OperationName::VerifyInvariantsFn &&verifyFn,
+    OperationName::VerifyRegionInvariantsFn &&verifyRegionFn,
+    OperationName::ParseAssemblyFn &&parseFn,
+    OperationName::PrintAssemblyFn &&printFn,
+    OperationName::FoldHookFn &&foldHookFn,
+    OperationName::GetCanonicalizationPatternsFn
+        &&getCanonicalizationPatternsFn)
+    : typeID(dialect->allocateTypeID()),
+      name((dialect->getNamespace() + "." + name).str()), dialect(dialect),
+      verifyFn(std::move(verifyFn)), verifyRegionFn(std::move(verifyRegionFn)),
+      parseFn(std::move(parseFn)), printFn(std::move(printFn)),
+      foldHookFn(std::move(foldHookFn)),
+      getCanonicalizationPatternsFn(std::move(getCanonicalizationPatternsFn)) {}
+
+std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get(
+    StringRef name, ExtensibleDialect *dialect,
+    OperationName::VerifyInvariantsFn &&verifyFn,
+    OperationName::VerifyRegionInvariantsFn &&verifyRegionFn) {
+  auto parseFn = [](OpAsmParser &parser, OperationState &result) {
+    return parser.emitError(
+        parser.getCurrentLocation(),
+        "dynamic operation do not define any parser function");
+  };
+
+  auto printFn = [](Operation *op, OpAsmPrinter &printer, StringRef) {
+    printer.printGenericOp(op);
+  };
+
+  return DynamicOpDefinition::get(name, dialect, std::move(verifyFn),
+                                  std::move(verifyRegionFn), std::move(parseFn),
+                                  std::move(printFn));
+}
+
+std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get(
+    StringRef name, ExtensibleDialect *dialect,
+    OperationName::VerifyInvariantsFn &&verifyFn,
+    OperationName::VerifyRegionInvariantsFn &&verifyRegionFn,
+    OperationName::ParseAssemblyFn &&parseFn,
+    OperationName::PrintAssemblyFn &&printFn) {
+  auto foldHookFn = [](Operation *op, ArrayRef<Attribute> operands,
+                       SmallVectorImpl<OpFoldResult> &results) {
+    return failure();
+  };
+
+  auto getCanonicalizationPatternsFn = [](RewritePatternSet &, MLIRContext *) {
+  };
+
+  return DynamicOpDefinition::get(name, dialect, std::move(verifyFn),
+                                  std::move(verifyRegionFn), std::move(parseFn),
+                                  std::move(printFn), std::move(foldHookFn),
+                                  std::move(getCanonicalizationPatternsFn));
+}
+
+std::unique_ptr<DynamicOpDefinition>
+DynamicOpDefinition::get(StringRef name, ExtensibleDialect *dialect,
+                         OperationName::VerifyInvariantsFn &&verifyFn,
+                         OperationName::VerifyInvariantsFn &&verifyRegionFn,
+                         OperationName::ParseAssemblyFn &&parseFn,
+                         OperationName::PrintAssemblyFn &&printFn,
+                         OperationName::FoldHookFn &&foldHookFn,
+                         OperationName::GetCanonicalizationPatternsFn
+                             &&getCanonicalizationPatternsFn) {
+  return std::unique_ptr<DynamicOpDefinition>(new DynamicOpDefinition(
+      name, dialect, std::move(verifyFn), std::move(verifyRegionFn),
+      std::move(parseFn), std::move(printFn), std::move(foldHookFn),
+      std::move(getCanonicalizationPatternsFn)));
+}
+
+//===----------------------------------------------------------------------===//
+// Extensible dialect
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Interface that can only be implemented by extensible dialects.
+/// The interface is used to check if a dialect is extensible or not.
+class IsExtensibleDialect : public DialectInterface::Base<IsExtensibleDialect> {
+public:
+  IsExtensibleDialect(Dialect *dialect) : Base(dialect) {}
+
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(IsExtensibleDialect)
+};
+} // namespace
+
+ExtensibleDialect::ExtensibleDialect(StringRef name, MLIRContext *ctx,
+                                     TypeID typeID)
+    : Dialect(name, ctx, typeID) {
+  addInterfaces<IsExtensibleDialect>();
+}
+
+void ExtensibleDialect::registerDynamicType(
+    std::unique_ptr<DynamicTypeDefinition> &&type) {
+  DynamicTypeDefinition *typePtr = type.get();
+  TypeID typeID = type->getTypeID();
+  StringRef name = type->getName();
+  ExtensibleDialect *dialect = type->getDialect();
+
+  assert(dialect == this &&
+         "trying to register a dynamic type in the wrong dialect");
+
+  // If a type with the same name is already defined, fail.
+  auto registered = dynTypes.try_emplace(typeID, std::move(type)).second;
+  (void)registered;
+  assert(registered && "type TypeID was not unique");
+
+  registered = nameToDynTypes.insert({name, typePtr}).second;
+  (void)registered;
+  assert(registered &&
+         "Trying to create a new dynamic type with an existing name");
+
+  auto abstractType =
+      AbstractType::get(*dialect, DynamicAttr::getInterfaceMap(),
+                        DynamicType::getHasTraitFn(), typeID);
+
+  /// Add the type to the dialect and the type uniquer.
+  addType(typeID, std::move(abstractType));
+  typePtr->registerInTypeUniquer();
+}
+
+void ExtensibleDialect::registerDynamicAttr(
+    std::unique_ptr<DynamicAttrDefinition> &&attr) {
+  auto *attrPtr = attr.get();
+  auto typeID = attr->getTypeID();
+  auto name = attr->getName();
+  auto *dialect = attr->getDialect();
+
+  assert(dialect == this &&
+         "trying to register a dynamic attribute in the wrong dialect");
+
+  // If an attribute with the same name is already defined, fail.
+  auto registered = dynAttrs.try_emplace(typeID, std::move(attr)).second;
+  (void)registered;
+  assert(registered && "attribute TypeID was not unique");
+
+  registered = nameToDynAttrs.insert({name, attrPtr}).second;
+  (void)registered;
+  assert(registered &&
+         "Trying to create a new dynamic attribute with an existing name");
+
+  auto abstractAttr =
+      AbstractAttribute::get(*dialect, DynamicAttr::getInterfaceMap(),
+                             DynamicAttr::getHasTraitFn(), typeID);
+
+  /// Add the type to the dialect and the type uniquer.
+  addAttribute(typeID, std::move(abstractAttr));
+  attrPtr->registerInAttrUniquer();
+}
+
+void ExtensibleDialect::registerDynamicOp(
+    std::unique_ptr<DynamicOpDefinition> &&op) {
+  assert(op->dialect == this &&
+         "trying to register a dynamic op in the wrong dialect");
+  auto hasTraitFn = [](TypeID traitId) { return false; };
+
+  RegisteredOperationName::insert(
+      op->name, *op->dialect, op->typeID, std::move(op->parseFn),
+      std::move(op->printFn), std::move(op->verifyFn),
+      std::move(op->verifyRegionFn), std::move(op->foldHookFn),
+      std::move(op->getCanonicalizationPatternsFn),
+      detail::InterfaceMap::get<>(), std::move(hasTraitFn), {});
+}
+
+bool ExtensibleDialect::classof(const Dialect *dialect) {
+  return const_cast<Dialect *>(dialect)
+      ->getRegisteredInterface<IsExtensibleDialect>();
+}
+
+OptionalParseResult ExtensibleDialect::parseOptionalDynamicType(
+    StringRef typeName, AsmParser &parser, Type &resultType) const {
+  DynamicTypeDefinition *typeDef = lookupTypeDefinition(typeName);
+  if (!typeDef)
+    return llvm::None;
+
+  DynamicType dynType;
+  if (DynamicType::parse(parser, typeDef, dynType))
+    return failure();
+  resultType = dynType;
+  return success();
+}
+
+LogicalResult ExtensibleDialect::printIfDynamicType(Type type,
+                                                    AsmPrinter &printer) {
+  if (auto dynType = type.dyn_cast<DynamicType>()) {
+    dynType.print(printer);
+    return success();
+  }
+  return failure();
+}
+
+OptionalParseResult ExtensibleDialect::parseOptionalDynamicAttr(
+    StringRef attrName, AsmParser &parser, Attribute &resultAttr) const {
+  DynamicAttrDefinition *attrDef = lookupAttrDefinition(attrName);
+  if (!attrDef)
+    return llvm::None;
+
+  DynamicAttr dynAttr;
+  if (DynamicAttr::parse(parser, attrDef, dynAttr))
+    return failure();
+  resultAttr = dynAttr;
+  return success();
+}
+
+LogicalResult ExtensibleDialect::printIfDynamicAttr(Attribute attribute,
+                                                    AsmPrinter &printer) {
+  if (auto dynAttr = attribute.dyn_cast<DynamicAttr>()) {
+    dynAttr.print(printer);
+    return success();
+  }
+  return failure();
+}
index 57586dd..11f972e 100644 (file)
@@ -102,9 +102,14 @@ Dialect::EmitPrefix Dialect::getEmitAccessorPrefix() const {
   int prefix = def->getValueAsInt("emitAccessorPrefix");
   if (prefix < 0 || prefix > static_cast<int>(EmitPrefix::Both))
     PrintFatalError(def->getLoc(), "Invalid accessor prefix value");
+
   return static_cast<EmitPrefix>(prefix);
 }
 
+bool Dialect::isExtensible() const {
+  return def->getValueAsBit("isExtensible");
+}
+
 bool Dialect::operator==(const Dialect &other) const {
   return def == other.def;
 }
diff --git a/mlir/test/IR/dynamic.mlir b/mlir/test/IR/dynamic.mlir
new file mode 100644 (file)
index 0000000..7c512a1
--- /dev/null
@@ -0,0 +1,126 @@
+// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -verify-diagnostics | FileCheck %s
+// Verify that extensible dialects can register dynamic operations and types.
+
+//===----------------------------------------------------------------------===//
+// Dynamic type
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @succeededDynamicTypeVerifier
+func @succeededDynamicTypeVerifier() {
+  // CHECK: %{{.*}} = "unregistered_op"() : () -> !test.dynamic_singleton
+  "unregistered_op"() : () -> !test.dynamic_singleton
+  // CHECK-NEXT: "unregistered_op"() : () -> !test.dynamic_pair<i32, f64>
+  "unregistered_op"() : () -> !test.dynamic_pair<i32, f64>
+  // CHECK_NEXT: %{{.*}} = "unregistered_op"() : () -> !test.dynamic_pair<!test.dynamic_pair<i32, f64>, !test.dynamic_singleton>
+  "unregistered_op"() : () -> !test.dynamic_pair<!test.dynamic_pair<i32, f64>, !test.dynamic_singleton>
+  return
+}
+
+// -----
+
+func @failedDynamicTypeVerifier() {
+  // expected-error@+1 {{expected 0 type arguments, but had 1}}
+  "unregistered_op"() : () -> !test.dynamic_singleton<f64>
+  return
+}
+
+// -----
+
+func @failedDynamicTypeVerifier2() {
+  // expected-error@+1 {{expected 2 type arguments, but had 1}}
+  "unregistered_op"() : () -> !test.dynamic_pair<f64>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @customTypeParserPrinter
+func @customTypeParserPrinter() {
+  // CHECK: "unregistered_op"() : () -> !test.dynamic_custom_assembly_format<f32:f64>
+  "unregistered_op"() : () -> !test.dynamic_custom_assembly_format<f32 : f64>
+  return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Dynamic attribute
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @succeededDynamicAttributeVerifier
+func @succeededDynamicAttributeVerifier() {
+  // CHECK: "unregistered_op"() {test_attr = #test.dynamic_singleton} : () -> ()
+  "unregistered_op"() {test_attr = #test.dynamic_singleton} : () -> ()
+  // CHECK-NEXT: "unregistered_op"() {test_attr = #test.dynamic_pair<3 : i32, 5 : i32>} : () -> ()
+  "unregistered_op"() {test_attr = #test.dynamic_pair<3 : i32, 5 : i32>} : () -> ()
+  // CHECK_NEXT: "unregistered_op"() {test_attr = #test.dynamic_pair<3 : i32, 5 : i32>} : () -> ()
+  "unregistered_op"() {test_attr = #test.dynamic_pair<#test.dynamic_pair<3 : i32, 5 : i32>, f64>} : () -> ()
+  return
+}
+
+// -----
+
+func @failedDynamicAttributeVerifier() {
+  // expected-error@+1 {{expected 0 attribute arguments, but had 1}}
+  "unregistered_op"() {test_attr = #test.dynamic_singleton<f64>} : () -> ()
+  return
+}
+
+// -----
+
+func @failedDynamicAttributeVerifier2() {
+  // expected-error@+1 {{expected 2 attribute arguments, but had 1}}
+  "unregistered_op"() {test_attr = #test.dynamic_pair<f64>} : () -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @customAttributeParserPrinter
+func @customAttributeParserPrinter() {
+  // CHECK: "unregistered_op"() {test_attr = #test.dynamic_custom_assembly_format<f32:f64>} : () -> ()
+  "unregistered_op"() {test_attr = #test.dynamic_custom_assembly_format<f32:f64>} : () -> ()
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// Dynamic op
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: func @succeededDynamicOpVerifier
+func @succeededDynamicOpVerifier(%a: f32) {
+  // CHECK: "test.dynamic_generic"() : () -> ()
+  // CHECK-NEXT: %{{.*}} = "test.dynamic_generic"(%{{.*}}) : (f32) -> f64
+  // CHECK-NEXT: %{{.*}}:2 = "test.dynamic_one_operand_two_results"(%{{.*}}) : (f32) -> (f64, f64)
+  "test.dynamic_generic"() : () -> ()
+  "test.dynamic_generic"(%a) : (f32) -> f64
+  "test.dynamic_one_operand_two_results"(%a) : (f32) -> (f64, f64)
+  return
+}
+
+// -----
+
+func @failedDynamicOpVerifier() {
+  // expected-error@+1 {{expected 1 operand, but had 0}}
+  "test.dynamic_one_operand_two_results"() : () -> (f64, f64)
+  return
+}
+
+// -----
+
+func @failedDynamicOpVerifier2(%a: f32) {
+  // expected-error@+1 {{expected 2 results, but had 0}}
+  "test.dynamic_one_operand_two_results"(%a) : (f32) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @customOpParserPrinter
+func @customOpParserPrinter() {
+  // CHECK: test.dynamic_custom_parser_printer custom_keyword
+  test.dynamic_custom_parser_printer custom_keyword
+  return
+}
index 3a86099..45f6599 100644 (file)
 #include "TestDialect.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/ExtensibleDialect.h"
 #include "mlir/IR/Types.h"
 #include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/Hashing.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/ADT/bit.h"
+#include "llvm/Support/ErrorHandling.h"
 
 using namespace mlir;
 using namespace test;
@@ -217,6 +219,74 @@ SubElementAttrInterface TestSubElementsAccessAttr::replaceImmediateSubAttribute(
 #include "TestAttrDefs.cpp.inc"
 
 //===----------------------------------------------------------------------===//
+// Dynamic Attributes
+//===----------------------------------------------------------------------===//
+
+/// Define a singleton dynamic attribute.
+static std::unique_ptr<DynamicAttrDefinition>
+getDynamicSingletonAttr(TestDialect *testDialect) {
+  return DynamicAttrDefinition::get(
+      "dynamic_singleton", testDialect,
+      [](function_ref<InFlightDiagnostic()> emitError,
+         ArrayRef<Attribute> args) {
+        if (!args.empty()) {
+          emitError() << "expected 0 attribute arguments, but had "
+                      << args.size();
+          return failure();
+        }
+        return success();
+      });
+}
+
+/// Define a dynamic attribute representing a pair or attributes.
+static std::unique_ptr<DynamicAttrDefinition>
+getDynamicPairAttr(TestDialect *testDialect) {
+  return DynamicAttrDefinition::get(
+      "dynamic_pair", testDialect,
+      [](function_ref<InFlightDiagnostic()> emitError,
+         ArrayRef<Attribute> args) {
+        if (args.size() != 2) {
+          emitError() << "expected 2 attribute arguments, but had "
+                      << args.size();
+          return failure();
+        }
+        return success();
+      });
+}
+
+static std::unique_ptr<DynamicAttrDefinition>
+getDynamicCustomAssemblyFormatAttr(TestDialect *testDialect) {
+  auto verifier = [](function_ref<InFlightDiagnostic()> emitError,
+                     ArrayRef<Attribute> args) {
+    if (args.size() != 2) {
+      emitError() << "expected 2 attribute arguments, but had " << args.size();
+      return failure();
+    }
+    return success();
+  };
+
+  auto parser = [](AsmParser &parser,
+                   llvm::SmallVectorImpl<Attribute> &parsedParams) {
+    Attribute leftAttr, rightAttr;
+    if (parser.parseLess() || parser.parseAttribute(leftAttr) ||
+        parser.parseColon() || parser.parseAttribute(rightAttr) ||
+        parser.parseGreater())
+      return failure();
+    parsedParams.push_back(leftAttr);
+    parsedParams.push_back(rightAttr);
+    return success();
+  };
+
+  auto printer = [](AsmPrinter &printer, ArrayRef<Attribute> params) {
+    printer << "<" << params[0] << ":" << params[1] << ">";
+  };
+
+  return DynamicAttrDefinition::get("dynamic_custom_assembly_format",
+                                    testDialect, std::move(verifier),
+                                    std::move(parser), std::move(printer));
+}
+
+//===----------------------------------------------------------------------===//
 // TestDialect
 //===----------------------------------------------------------------------===//
 
@@ -225,4 +295,7 @@ void TestDialect::registerAttributes() {
 #define GET_ATTRDEF_LIST
 #include "TestAttrDefs.cpp.inc"
       >();
+  registerDynamicAttr(getDynamicSingletonAttr(this));
+  registerDynamicAttr(getDynamicPairAttr(this));
+  registerDynamicAttr(getDynamicCustomAssemblyFormatAttr(this));
 }
index eaa5649..0f8cfd8 100644 (file)
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/ExtensibleDialect.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Verifier.h"
@@ -209,6 +210,58 @@ public:
 } // namespace
 
 //===----------------------------------------------------------------------===//
+// Dynamic operations
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<DynamicOpDefinition> getDynamicGenericOp(TestDialect *dialect) {
+  return DynamicOpDefinition::get(
+      "dynamic_generic", dialect, [](Operation *op) { return success(); },
+      [](Operation *op) { return success(); });
+}
+
+std::unique_ptr<DynamicOpDefinition>
+getDynamicOneOperandTwoResultsOp(TestDialect *dialect) {
+  return DynamicOpDefinition::get(
+      "dynamic_one_operand_two_results", dialect,
+      [](Operation *op) {
+        if (op->getNumOperands() != 1) {
+          op->emitOpError()
+              << "expected 1 operand, but had " << op->getNumOperands();
+          return failure();
+        }
+        if (op->getNumResults() != 2) {
+          op->emitOpError()
+              << "expected 2 results, but had " << op->getNumResults();
+          return failure();
+        }
+        return success();
+      },
+      [](Operation *op) { return success(); });
+}
+
+std::unique_ptr<DynamicOpDefinition>
+getDynamicCustomParserPrinterOp(TestDialect *dialect) {
+  auto verifier = [](Operation *op) {
+    if (op->getNumOperands() == 0 && op->getNumResults() == 0)
+      return success();
+    op->emitError() << "operation should have no operands and no results";
+    return failure();
+  };
+  auto regionVerifier = [](Operation *op) { return success(); };
+
+  auto parser = [](OpAsmParser &parser, OperationState &state) {
+    return parser.parseKeyword("custom_keyword");
+  };
+
+  auto printer = [](Operation *op, OpAsmPrinter &printer, llvm::StringRef) {
+    printer << op->getName() << " custom_keyword";
+  };
+
+  return DynamicOpDefinition::get("dynamic_custom_parser_printer", dialect,
+                                  verifier, regionVerifier, parser, printer);
+}
+
+//===----------------------------------------------------------------------===//
 // TestDialect
 //===----------------------------------------------------------------------===//
 
@@ -242,6 +295,10 @@ void TestDialect::initialize() {
 #define GET_OP_LIST
 #include "TestOps.cpp.inc"
       >();
+  registerDynamicOp(getDynamicGenericOp(this));
+  registerDynamicOp(getDynamicOneOperandTwoResultsOp(this));
+  registerDynamicOp(getDynamicCustomParserPrinterOp(this));
+
   addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
                 TestInlinerInterface, TestReductionPatternInterface>();
   allowUnknownOperations();
index da2eead..4805850 100644 (file)
@@ -24,6 +24,7 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/ExtensibleDialect.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/RegionKindInterface.h"
index 2ac446f..d51ddba 100644 (file)
@@ -23,6 +23,8 @@ def Test_Dialect : Dialect {
   let hasOperationInterfaceFallback = 1;
   let hasNonDefaultDestructor = 1;
   let useDefaultTypePrinterParser = 0;
+  let useDefaultAttributePrinterParser = 1;
+  let isExtensible = 1;
   let dependentDialects = ["::mlir::DLTIDialect"];
 
   let extraClassDeclaration = [{
@@ -43,6 +45,10 @@ def Test_Dialect : Dialect {
     // Storage for a custom fallback interface.
     void *fallbackEffectOpInterfaces;
 
+    ::mlir::Type parseTestType(::mlir::AsmParser &parser,
+                               ::llvm::SetVector<::mlir::Type> &stack) const;
+    void printTestType(::mlir::Type type, ::mlir::AsmPrinter &printer,
+                       ::llvm::SetVector<::mlir::Type> &stack) const;
   }];
 }
 
index f72b837..5418025 100644 (file)
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/ExtensibleDialect.h"
 #include "mlir/IR/Types.h"
 #include "llvm/ADT/Hashing.h"
 #include "llvm/ADT/SetVector.h"
@@ -316,6 +317,72 @@ unsigned TestTypeWithLayoutType::extractKind(DataLayoutEntryListRef params,
 }
 
 //===----------------------------------------------------------------------===//
+// Dynamic Types
+//===----------------------------------------------------------------------===//
+
+/// Define a singleton dynamic type.
+static std::unique_ptr<DynamicTypeDefinition>
+getSingletonDynamicType(TestDialect *testDialect) {
+  return DynamicTypeDefinition::get(
+      "dynamic_singleton", testDialect,
+      [](function_ref<InFlightDiagnostic()> emitError,
+         ArrayRef<Attribute> args) {
+        if (!args.empty()) {
+          emitError() << "expected 0 type arguments, but had " << args.size();
+          return failure();
+        }
+        return success();
+      });
+}
+
+/// Define a dynamic type representing a pair.
+static std::unique_ptr<DynamicTypeDefinition>
+getPairDynamicType(TestDialect *testDialect) {
+  return DynamicTypeDefinition::get(
+      "dynamic_pair", testDialect,
+      [](function_ref<InFlightDiagnostic()> emitError,
+         ArrayRef<Attribute> args) {
+        if (args.size() != 2) {
+          emitError() << "expected 2 type arguments, but had " << args.size();
+          return failure();
+        }
+        return success();
+      });
+}
+
+static std::unique_ptr<DynamicTypeDefinition>
+getCustomAssemblyFormatDynamicType(TestDialect *testDialect) {
+  auto verifier = [](function_ref<InFlightDiagnostic()> emitError,
+                     ArrayRef<Attribute> args) {
+    if (args.size() != 2) {
+      emitError() << "expected 2 type arguments, but had " << args.size();
+      return failure();
+    }
+    return success();
+  };
+
+  auto parser = [](AsmParser &parser,
+                   llvm::SmallVectorImpl<Attribute> &parsedParams) {
+    Attribute leftAttr, rightAttr;
+    if (parser.parseLess() || parser.parseAttribute(leftAttr) ||
+        parser.parseColon() || parser.parseAttribute(rightAttr) ||
+        parser.parseGreater())
+      return failure();
+    parsedParams.push_back(leftAttr);
+    parsedParams.push_back(rightAttr);
+    return success();
+  };
+
+  auto printer = [](AsmPrinter &printer, ArrayRef<Attribute> params) {
+    printer << "<" << params[0] << ":" << params[1] << ">";
+  };
+
+  return DynamicTypeDefinition::get("dynamic_custom_assembly_format",
+                                    testDialect, std::move(verifier),
+                                    std::move(parser), std::move(printer));
+}
+
+//===----------------------------------------------------------------------===//
 // TestDialect
 //===----------------------------------------------------------------------===//
 
@@ -332,9 +399,14 @@ void TestDialect::registerTypes() {
 #include "TestTypeDefs.cpp.inc"
            >();
   SimpleAType::attachInterface<PtrElementModel>(*getContext());
+
+  registerDynamicType(getSingletonDynamicType(this));
+  registerDynamicType(getPairDynamicType(this));
+  registerDynamicType(getCustomAssemblyFormatDynamicType(this));
 }
 
-static Type parseTestType(AsmParser &parser, SetVector<Type> &stack) {
+Type TestDialect::parseTestType(AsmParser &parser,
+                                SetVector<Type> &stack) const {
   StringRef typeTag;
   if (failed(parser.parseKeyword(&typeTag)))
     return Type();
@@ -346,6 +418,16 @@ static Type parseTestType(AsmParser &parser, SetVector<Type> &stack) {
       return genType;
   }
 
+  {
+    Type dynType;
+    auto parseResult = parseOptionalDynamicType(typeTag, parser, dynType);
+    if (parseResult.hasValue()) {
+      if (succeeded(parseResult.getValue()))
+        return dynType;
+      return Type();
+    }
+  }
+
   if (typeTag != "test_rec") {
     parser.emitError(parser.getNameLoc()) << "unknown type!";
     return Type();
@@ -381,11 +463,14 @@ Type TestDialect::parseType(DialectAsmParser &parser) const {
   return parseTestType(parser, stack);
 }
 
-static void printTestType(Type type, AsmPrinter &printer,
-                          SetVector<Type> &stack) {
+void TestDialect::printTestType(Type type, AsmPrinter &printer,
+                                SetVector<Type> &stack) const {
   if (succeeded(generatedTypePrinter(type, printer)))
     return;
 
+  if (succeeded(printIfDynamicType(type, printer)))
+    return;
+
   auto rec = type.cast<TestRecursiveType>();
   printer << "test_rec<" << rec.getName();
   if (!stack.contains(rec)) {
index cd37655..64bccc5 100644 (file)
@@ -689,6 +689,8 @@ void DefGenerator::emitTypeDefList(ArrayRef<AttrOrTypeDef> defs) {
 
 /// The code block for default attribute parser/printer dispatch boilerplate.
 /// {0}: the dialect fully qualified class name.
+/// {1}: the optional code for the dynamic attribute parser dispatch.
+/// {2}: the optional code for the dynamic attribute printer dispatch.
 static const char *const dialectDefaultAttrPrinterParserDispatch = R"(
 /// Parse an attribute registered to this dialect.
 ::mlir::Attribute {0}::parseAttribute(::mlir::DialectAsmParser &parser,
@@ -703,6 +705,7 @@ static const char *const dialectDefaultAttrPrinterParserDispatch = R"(
     if (parseResult.hasValue())
       return attr;
   }
+  {1}
   parser.emitError(typeLoc) << "unknown attribute `"
       << attrTag << "` in dialect `" << getNamespace() << "`";
   return {{};
@@ -712,11 +715,33 @@ void {0}::printAttribute(::mlir::Attribute attr,
                          ::mlir::DialectAsmPrinter &printer) const {{
   if (::mlir::succeeded(generatedAttributePrinter(attr, printer)))
     return;
+  {2}
 }
 )";
 
+/// The code block for dynamic attribute parser dispatch boilerplate.
+static const char *const dialectDynamicAttrParserDispatch = R"(
+  {
+    ::mlir::Attribute genAttr;
+    auto parseResult = parseOptionalDynamicAttr(attrTag, parser, genAttr);
+    if (parseResult.hasValue()) {
+      if (::mlir::succeeded(parseResult.getValue()))
+        return genAttr;
+      return Attribute();
+    }
+  }
+)";
+
+/// The code block for dynamic type printer dispatch boilerplate.
+static const char *const dialectDynamicAttrPrinterDispatch = R"(
+  if (::mlir::succeeded(printIfDynamicAttr(attr, printer)))
+    return;
+)";
+
 /// The code block for default type parser/printer dispatch boilerplate.
 /// {0}: the dialect fully qualified class name.
+/// {1}: the optional code for the dynamic type parser dispatch.
+/// {2}: the optional code for the dynamic type printer dispatch.
 static const char *const dialectDefaultTypePrinterParserDispatch = R"(
 /// Parse a type registered to this dialect.
 ::mlir::Type {0}::parseType(::mlir::DialectAsmParser &parser) const {{
@@ -728,6 +753,7 @@ static const char *const dialectDefaultTypePrinterParserDispatch = R"(
   auto parseResult = generatedTypeParser(parser, mnemonic, genType);
   if (parseResult.hasValue())
     return genType;
+  {1}
   parser.emitError(typeLoc) << "unknown  type `"
       << mnemonic << "` in dialect `" << getNamespace() << "`";
   return {{};
@@ -737,9 +763,28 @@ void {0}::printType(::mlir::Type type,
                     ::mlir::DialectAsmPrinter &printer) const {{
   if (::mlir::succeeded(generatedTypePrinter(type, printer)))
     return;
+  {2}
 }
 )";
 
+/// The code block for dynamic type parser dispatch boilerplate.
+static const char *const dialectDynamicTypeParserDispatch = R"(
+  {
+    auto parseResult = parseOptionalDynamicType(mnemonic, parser, genType);
+    if (parseResult.hasValue()) {
+      if (::mlir::succeeded(parseResult.getValue()))
+        return genType;
+      return Type();
+    }
+  }
+)";
+
+/// The code block for dynamic type printer dispatch boilerplate.
+static const char *const dialectDynamicTypePrinterDispatch = R"(
+  if (::mlir::succeeded(printIfDynamicType(type, printer)))
+    return;
+)";
+
 /// Emit the dialect printer/parser dispatcher. User's code should call these
 /// functions from their dialect's print/parse methods.
 void DefGenerator::emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs) {
@@ -839,16 +884,30 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) {
   if (valueType == "Attribute" && needsDialectParserPrinter &&
       firstDialect.useDefaultAttributePrinterParser()) {
     NamespaceEmitter nsEmitter(os, firstDialect);
-    os << llvm::formatv(dialectDefaultAttrPrinterParserDispatch,
-                        firstDialect.getCppClassName());
+    if (firstDialect.isExtensible()) {
+      os << llvm::formatv(dialectDefaultAttrPrinterParserDispatch,
+                          firstDialect.getCppClassName(),
+                          dialectDynamicAttrParserDispatch,
+                          dialectDynamicAttrPrinterDispatch);
+    } else {
+      os << llvm::formatv(dialectDefaultAttrPrinterParserDispatch,
+                          firstDialect.getCppClassName(), "", "");
+    }
   }
 
   // Emit the default parser/printer for Types if the dialect asked for it.
   if (valueType == "Type" && needsDialectParserPrinter &&
       firstDialect.useDefaultTypePrinterParser()) {
     NamespaceEmitter nsEmitter(os, firstDialect);
-    os << llvm::formatv(dialectDefaultTypePrinterParserDispatch,
-                        firstDialect.getCppClassName());
+    if (firstDialect.isExtensible()) {
+      os << llvm::formatv(dialectDefaultTypePrinterParserDispatch,
+                          firstDialect.getCppClassName(),
+                          dialectDynamicTypeParserDispatch,
+                          dialectDynamicTypePrinterDispatch);
+    } else {
+      os << llvm::formatv(dialectDefaultTypePrinterParserDispatch,
+                          firstDialect.getCppClassName(), "", "");
+    }
   }
 
   return false;
index b6497eb..347f08f 100644 (file)
@@ -87,8 +87,9 @@ findSelectedDialect(ArrayRef<const llvm::Record *> dialectDefs) {
 ///
 /// {0}: The name of the dialect class.
 /// {1}: The dialect namespace.
+/// {2}: The dialect parent class.
 static const char *const dialectDeclBeginStr = R"(
-class {0} : public ::mlir::Dialect {
+class {0} : public ::mlir::{2} {
   explicit {0}(::mlir::MLIRContext *context);
 
   void initialize();
@@ -189,7 +190,10 @@ emitDialectDecl(Dialect &dialect,
 
     // Emit the start of the decl.
     std::string cppName = dialect.getCppClassName();
-    os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName());
+    StringRef superClassName =
+        dialect.isExtensible() ? "ExtensibleDialect" : "Dialect";
+    os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName(),
+                        superClassName);
 
     // Check for any attributes/types registered to this dialect.  If there are,
     // add the hooks for parsing/printing.
@@ -250,9 +254,10 @@ static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper,
 /// {0}: The name of the dialect class.
 /// {1}: initialization code that is emitted in the ctor body before calling
 ///      initialize().
+/// {2}: The dialect parent class.
 static const char *const dialectConstructorStr = R"(
 {0}::{0}(::mlir::MLIRContext *context) 
-    : ::mlir::Dialect(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>()) {{
+    : ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>()) {{
   {1}
   initialize();
 }
@@ -287,8 +292,10 @@ static void emitDialectDef(Dialect &dialect, raw_ostream &os) {
   }
 
   // Emit the constructor and destructor.
+  StringRef superClassName =
+      dialect.isExtensible() ? "ExtensibleDialect" : "Dialect";
   os << llvm::formatv(dialectConstructorStr, cppClassName,
-                      dependentDialectRegistrations);
+                      dependentDialectRegistrations, superClassName);
   if (!dialect.hasNonDefaultDestructor())
     os << llvm::formatv(dialectDestructorStr, cppClassName);
 }