[MLIR] Add support for defining Types in tblgen
authorJohn Demme <john.demme@microsoft.com>
Tue, 13 Oct 2020 22:07:27 +0000 (22:07 +0000)
committerJohn Demme <john.demme@microsoft.com>
Wed, 14 Oct 2020 00:32:18 +0000 (00:32 +0000)
Adds a TypeDef class to OpBase and backing generation code. Allows one
to define the Type, its parameters, and printer/parser methods in ODS.
Can generate the Type C++ class, accessors, storage class, per-parameter
custom allocators (for the storage constructor), and documentation.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D86904

15 files changed:
mlir/cmake/modules/AddMLIR.cmake
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/TableGen/TypeDef.h [new file with mode: 0644]
mlir/lib/TableGen/CMakeLists.txt
mlir/lib/TableGen/TypeDef.cpp [new file with mode: 0644]
mlir/test/lib/Dialect/Test/CMakeLists.txt
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestTypeDefs.td [new file with mode: 0644]
mlir/test/lib/Dialect/Test/TestTypes.cpp [new file with mode: 0644]
mlir/test/lib/Dialect/Test/TestTypes.h
mlir/test/mlir-tblgen/testdialect-typedefs.mlir [new file with mode: 0644]
mlir/test/mlir-tblgen/typedefs.td [new file with mode: 0644]
mlir/tools/mlir-tblgen/CMakeLists.txt
mlir/tools/mlir-tblgen/OpDocGen.cpp
mlir/tools/mlir-tblgen/TypeDefGen.cpp [new file with mode: 0644]

index 8394c05..0d99c29 100644 (file)
@@ -9,6 +9,8 @@ function(add_mlir_dialect dialect dialect_namespace)
   set(LLVM_TARGET_DEFINITIONS ${dialect}.td)
   mlir_tablegen(${dialect}.h.inc -gen-op-decls)
   mlir_tablegen(${dialect}.cpp.inc -gen-op-defs)
+  mlir_tablegen(${dialect}Types.h.inc -gen-typedef-decls)
+  mlir_tablegen(${dialect}Types.cpp.inc -gen-typedef-defs)
   mlir_tablegen(${dialect}Dialect.h.inc -gen-dialect-decls -dialect=${dialect_namespace})
   add_public_tablegen_target(MLIR${dialect}IncGen)
   add_dependencies(mlir-headers MLIR${dialect}IncGen)
index b38189f..bf27724 100644 (file)
@@ -2364,4 +2364,116 @@ def location;
 // so to replace the matched DAG with an existing SSA value.
 def replaceWithValue;
 
+
+//===----------------------------------------------------------------------===//
+// Data type generation
+//===----------------------------------------------------------------------===//
+
+// Define a new type belonging to a dialect and called 'name'.
+class TypeDef<Dialect owningdialect, string name> {
+  Dialect dialect = owningdialect;
+  string cppClassName = name # "Type";
+
+  // Short summary of the type.
+  string summary = ?;
+  // The longer description of this type.
+  string description = ?;
+
+  // Name of storage class to generate or use.
+  string storageClass = name # "TypeStorage";
+  // Namespace (withing dialect c++ namespace) in which the storage class
+  // resides.
+  string storageNamespace = "detail";
+  // Specify if the storage class is to be generated.
+  bit genStorageClass = 1;
+  // Specify that the generated storage class has a constructor which is written
+  // in C++.
+  bit hasStorageCustomConstructor = 0;
+
+  // The list of parameters for this type. Parameters will become both
+  // parameters to the get() method and storage class member variables.
+  //
+  // The format of this dag is:
+  //    (ins
+  //        "<c++ type>":$param1Name,
+  //        "<c++ type>":$param2Name,
+  //        TypeParameter<"c++ type", "param description">:$param3Name)
+  // TypeParameters (or more likely one of their subclasses) are required to add
+  // more information about the parameter, specifically:
+  //  - Documentation
+  //  - Code to allocate the parameter (if allocation is needed in the storage
+  //    class constructor)
+  //
+  // For example:
+  //    (ins
+  //        "int":$width,
+  //        ArrayRefParameter<"bool", "list of bools">:$yesNoArray)
+  //
+  // (ArrayRefParameter is a subclass of TypeParameter which has allocation code
+  // for re-allocating ArrayRefs. It is defined below.)
+  dag parameters = (ins);
+
+  // Use the lowercased name as the keyword for parsing/printing. Specify only
+  // if you want tblgen to generate declarations and/or definitions of
+  // printer/parser for this type.
+  string mnemonic = ?;
+  // If 'mnemonic' specified,
+  //   If null, generate just the declarations.
+  //   If a non-empty code block, just use that code as the definition code.
+  //   Error if an empty code block.
+  code printer = ?;
+  code parser = ?;
+
+  // If set, generate accessors for each Type parameter.
+  bit genAccessors = 1;
+  // Generate the verifyConstructionInvariants declaration and getChecked
+  // method.
+  bit genVerifyInvariantsDecl = 0;
+  // Extra code to include in the class declaration.
+  code extraClassDeclaration = [{}];
+}
+
+// 'Parameters' should be subclasses of this or simple strings (which is a
+// shorthand for TypeParameter<"C++Type">).
+class TypeParameter<string type, string desc> {
+  // Custom memory allocation code for storage constructor.
+  code allocator = ?;
+  // The C++ type of this parameter.
+  string cppType = type;
+  // A description of this parameter.
+  string description = desc;
+  // The format string for the asm syntax (documentation only).
+  string syntax = ?;
+}
+
+// For StringRefs, which require allocation.
+class StringRefParameter<string desc> :
+    TypeParameter<"::llvm::StringRef", desc> {
+  let allocator = [{$_dst = $_allocator.copyInto($_self);}];
+}
+
+// For standard ArrayRefs, which require allocation.
+class ArrayRefParameter<string arrayOf, string desc> :
+    TypeParameter<"::llvm::ArrayRef<" # arrayOf # ">", desc> {
+  let allocator = [{$_dst = $_allocator.copyInto($_self);}];
+}
+
+// For classes which require allocation and have their own allocateInto method.
+class SelfAllocationParameter<string type, string desc> :
+    TypeParameter<type, desc> {
+  let allocator = [{$_dst = $_self.allocateInto($_allocator);}];
+}
+
+// For ArrayRefs which contain things which allocate themselves.
+class ArrayRefOfSelfAllocationParameter<string arrayOf, string desc> :
+    TypeParameter<"::llvm::ArrayRef<" # arrayOf # ">", desc> {
+  let allocator = [{
+    llvm::SmallVector<}] # arrayOf # [{, 4> tmpFields;
+    for (size_t i = 0, e = $_self.size(); i < e; ++i)
+      tmpFields.push_back($_self[i].allocateInto($_allocator));
+    $_dst = $_allocator.copyInto(ArrayRef<}] # arrayOf # [{>(tmpFields));
+  }];
+}
+
+
 #endif // OP_BASE
diff --git a/mlir/include/mlir/TableGen/TypeDef.h b/mlir/include/mlir/TableGen/TypeDef.h
new file mode 100644 (file)
index 0000000..462fed3
--- /dev/null
@@ -0,0 +1,135 @@
+//===-- TypeDef.h - Record wrapper for type definitions ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// TypeDef wrapper to simplify using TableGen Record defining a MLIR type.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_TYPEDEF_H
+#define MLIR_TABLEGEN_TYPEDEF_H
+
+#include "mlir/Support/LLVM.h"
+#include "mlir/TableGen/Dialect.h"
+
+namespace llvm {
+class Record;
+class DagInit;
+class SMLoc;
+} // namespace llvm
+
+namespace mlir {
+namespace tblgen {
+
+class TypeParameter;
+
+/// Wrapper class that contains a TableGen TypeDef's record and provides helper
+/// methods for accessing them.
+class TypeDef {
+public:
+  explicit TypeDef(const llvm::Record *def) : def(def) {}
+
+  // Get the dialect for which this type belongs.
+  Dialect getDialect() const;
+
+  // Returns the name of this TypeDef record.
+  StringRef getName() const;
+
+  // Query functions for the documentation of the operator.
+  bool hasDescription() const;
+  StringRef getDescription() const;
+  bool hasSummary() const;
+  StringRef getSummary() const;
+
+  // Returns the name of the C++ class to generate.
+  StringRef getCppClassName() const;
+
+  // Returns the name of the storage class for this type.
+  StringRef getStorageClassName() const;
+
+  // Returns the C++ namespace for this types storage class.
+  StringRef getStorageNamespace() const;
+
+  // Returns true if we should generate the storage class.
+  bool genStorageClass() const;
+
+  // Indicates whether or not to generate the storage class constructor.
+  bool hasStorageCustomConstructor() const;
+
+  // Fill a list with this types parameters. See TypeDef in OpBase.td for
+  // documentation of parameter usage.
+  void getParameters(SmallVectorImpl<TypeParameter> &) const;
+  // Return the number of type parameters
+  unsigned getNumParameters() const;
+
+  // Return the keyword/mnemonic to use in the printer/parser methods if we are
+  // supposed to auto-generate them.
+  Optional<StringRef> getMnemonic() const;
+
+  // Returns the code to use as the types printer method. If not specified,
+  // return a non-value. Otherwise, return the contents of that code block.
+  Optional<StringRef> getPrinterCode() const;
+
+  // Returns the code to use as the types parser method. If not specified,
+  // return a non-value. Otherwise, return the contents of that code block.
+  Optional<StringRef> getParserCode() const;
+
+  // Returns true if the accessors based on the types parameters should be
+  // generated.
+  bool genAccessors() const;
+
+  // Return true if we need to generate the verifyConstructionInvariants
+  // declaration and getChecked method.
+  bool genVerifyInvariantsDecl() const;
+
+  // Returns the dialects extra class declaration code.
+  Optional<StringRef> getExtraDecls() const;
+
+  // Get the code location (for error printing).
+  ArrayRef<llvm::SMLoc> getLoc() const;
+
+  // Returns whether two TypeDefs are equal by checking the equality of the
+  // underlying record.
+  bool operator==(const TypeDef &other) const;
+
+  // Compares two TypeDefs by comparing the names of the dialects.
+  bool operator<(const TypeDef &other) const;
+
+  // Returns whether the TypeDef is defined.
+  operator bool() const { return def != nullptr; }
+
+private:
+  const llvm::Record *def;
+};
+
+// A wrapper class for tblgen TypeParameter, arrays of which belong to TypeDefs
+// to parameterize them.
+class TypeParameter {
+public:
+  explicit TypeParameter(const llvm::DagInit *def, unsigned num)
+      : def(def), num(num) {}
+
+  // Get the parameter name.
+  StringRef getName() const;
+  // If specified, get the custom allocator code for this parameter.
+  llvm::Optional<StringRef> getAllocator() const;
+  // Get the C++ type of this parameter.
+  StringRef getCppType() const;
+  // Get a description of this parameter for documentation purposes.
+  llvm::Optional<StringRef> getDescription() const;
+  // Get the assembly syntax documentation.
+  StringRef getSyntax() const;
+
+private:
+  const llvm::DagInit *def;
+  const unsigned num;
+};
+
+} // end namespace tblgen
+} // end namespace mlir
+
+#endif // MLIR_TABLEGEN_TYPEDEF_H
index ef74764..af3900f 100644 (file)
@@ -25,6 +25,7 @@ llvm_add_library(MLIRTableGen STATIC
   SideEffects.cpp
   Successor.cpp
   Type.cpp
+  TypeDef.cpp
 
   DISABLE_LLVM_LINK_LLVM_DYLIB
 
diff --git a/mlir/lib/TableGen/TypeDef.cpp b/mlir/lib/TableGen/TypeDef.cpp
new file mode 100644 (file)
index 0000000..e5327fc
--- /dev/null
@@ -0,0 +1,160 @@
+//===- TypeDef.cpp - TypeDef wrapper class --------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// TypeDef wrapper to simplify using TableGen Record defining a MLIR dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/TypeDef.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+Dialect TypeDef::getDialect() const {
+  auto *dialectDef =
+      dyn_cast<llvm::DefInit>(def->getValue("dialect")->getValue());
+  if (dialectDef == nullptr)
+    return Dialect(nullptr);
+  return Dialect(dialectDef->getDef());
+}
+
+StringRef TypeDef::getName() const { return def->getName(); }
+StringRef TypeDef::getCppClassName() const {
+  return def->getValueAsString("cppClassName");
+}
+
+bool TypeDef::hasDescription() const {
+  const llvm::RecordVal *s = def->getValue("description");
+  return s != nullptr && isa<llvm::StringInit>(s->getValue());
+}
+
+StringRef TypeDef::getDescription() const {
+  return def->getValueAsString("description");
+}
+
+bool TypeDef::hasSummary() const {
+  const llvm::RecordVal *s = def->getValue("summary");
+  return s != nullptr && isa<llvm::StringInit>(s->getValue());
+}
+
+StringRef TypeDef::getSummary() const {
+  return def->getValueAsString("summary");
+}
+
+StringRef TypeDef::getStorageClassName() const {
+  return def->getValueAsString("storageClass");
+}
+StringRef TypeDef::getStorageNamespace() const {
+  return def->getValueAsString("storageNamespace");
+}
+
+bool TypeDef::genStorageClass() const {
+  return def->getValueAsBit("genStorageClass");
+}
+bool TypeDef::hasStorageCustomConstructor() const {
+  return def->getValueAsBit("hasStorageCustomConstructor");
+}
+void TypeDef::getParameters(SmallVectorImpl<TypeParameter> &parameters) const {
+  auto *parametersDag = def->getValueAsDag("parameters");
+  if (parametersDag != nullptr) {
+    size_t numParams = parametersDag->getNumArgs();
+    for (unsigned i = 0; i < numParams; i++)
+      parameters.push_back(TypeParameter(parametersDag, i));
+  }
+}
+unsigned TypeDef::getNumParameters() const {
+  auto *parametersDag = def->getValueAsDag("parameters");
+  return parametersDag ? parametersDag->getNumArgs() : 0;
+}
+llvm::Optional<StringRef> TypeDef::getMnemonic() const {
+  return def->getValueAsOptionalString("mnemonic");
+}
+llvm::Optional<StringRef> TypeDef::getPrinterCode() const {
+  return def->getValueAsOptionalCode("printer");
+}
+llvm::Optional<StringRef> TypeDef::getParserCode() const {
+  return def->getValueAsOptionalCode("parser");
+}
+bool TypeDef::genAccessors() const {
+  return def->getValueAsBit("genAccessors");
+}
+bool TypeDef::genVerifyInvariantsDecl() const {
+  return def->getValueAsBit("genVerifyInvariantsDecl");
+}
+llvm::Optional<StringRef> TypeDef::getExtraDecls() const {
+  auto value = def->getValueAsString("extraClassDeclaration");
+  return value.empty() ? llvm::Optional<StringRef>() : value;
+}
+llvm::ArrayRef<llvm::SMLoc> TypeDef::getLoc() const { return def->getLoc(); }
+bool TypeDef::operator==(const TypeDef &other) const {
+  return def == other.def;
+}
+
+bool TypeDef::operator<(const TypeDef &other) const {
+  return getName() < other.getName();
+}
+
+StringRef TypeParameter::getName() const {
+  return def->getArgName(num)->getValue();
+}
+llvm::Optional<StringRef> TypeParameter::getAllocator() const {
+  llvm::Init *parameterType = def->getArg(num);
+  if (auto *stringType = dyn_cast<llvm::StringInit>(parameterType))
+    return llvm::Optional<StringRef>();
+
+  if (auto *typeParameter = dyn_cast<llvm::DefInit>(parameterType)) {
+    llvm::RecordVal *code = typeParameter->getDef()->getValue("allocator");
+    if (llvm::CodeInit *ci = dyn_cast<llvm::CodeInit>(code->getValue()))
+      return ci->getValue();
+    if (isa<llvm::UnsetInit>(code->getValue()))
+      return llvm::Optional<StringRef>();
+
+    llvm::PrintFatalError(
+        typeParameter->getDef()->getLoc(),
+        "Record `" + def->getArgName(num)->getValue() +
+            "', field `printer' does not have a code initializer!");
+  }
+
+  llvm::PrintFatalError("Parameters DAG arguments must be either strings or "
+                        "defs which inherit from TypeParameter\n");
+}
+StringRef TypeParameter::getCppType() const {
+  auto *parameterType = def->getArg(num);
+  if (auto *stringType = dyn_cast<llvm::StringInit>(parameterType))
+    return stringType->getValue();
+  if (auto *typeParameter = dyn_cast<llvm::DefInit>(parameterType))
+    return typeParameter->getDef()->getValueAsString("cppType");
+  llvm::PrintFatalError(
+      "Parameters DAG arguments must be either strings or defs "
+      "which inherit from TypeParameter\n");
+}
+llvm::Optional<StringRef> TypeParameter::getDescription() const {
+  auto *parameterType = def->getArg(num);
+  if (auto *typeParameter = dyn_cast<llvm::DefInit>(parameterType)) {
+    const auto *desc = typeParameter->getDef()->getValue("description");
+    if (llvm::StringInit *ci = dyn_cast<llvm::StringInit>(desc->getValue()))
+      return ci->getValue();
+  }
+  return llvm::Optional<StringRef>();
+}
+StringRef TypeParameter::getSyntax() const {
+  auto *parameterType = def->getArg(num);
+  if (auto *stringType = dyn_cast<llvm::StringInit>(parameterType))
+    return stringType->getValue();
+  if (auto *typeParameter = dyn_cast<llvm::DefInit>(parameterType)) {
+    const auto *syntax = typeParameter->getDef()->getValue("syntax");
+    if (syntax && isa<llvm::StringInit>(syntax->getValue()))
+      return dyn_cast<llvm::StringInit>(syntax->getValue())->getValue();
+    return getCppType();
+  }
+  llvm::PrintFatalError("Parameters DAG arguments must be either strings or "
+                        "defs which inherit from TypeParameter");
+}
index 31c8ccc..d1d84f9 100644 (file)
@@ -9,6 +9,12 @@ mlir_tablegen(TestTypeInterfaces.h.inc -gen-type-interface-decls)
 mlir_tablegen(TestTypeInterfaces.cpp.inc -gen-type-interface-defs)
 add_public_tablegen_target(MLIRTestInterfaceIncGen)
 
+set(LLVM_TARGET_DEFINITIONS TestTypeDefs.td)
+mlir_tablegen(TestTypeDefs.h.inc -gen-typedef-decls)
+mlir_tablegen(TestTypeDefs.cpp.inc -gen-typedef-defs)
+add_public_tablegen_target(MLIRTestDefIncGen)
+
+
 set(LLVM_TARGET_DEFINITIONS TestOps.td)
 mlir_tablegen(TestOps.h.inc -gen-op-decls)
 mlir_tablegen(TestOps.cpp.inc -gen-op-defs)
@@ -25,11 +31,13 @@ add_mlir_library(MLIRTestDialect
   TestDialect.cpp
   TestPatterns.cpp
   TestTraits.cpp
+  TestTypes.cpp
 
   EXCLUDE_FROM_LIBMLIR
 
   DEPENDS
   MLIRTestInterfaceIncGen
+  MLIRTestDefIncGen
   MLIRTestOpsIncGen
 
   LINK_LIBS PUBLIC
index 4ca89bc..3bfb824 100644 (file)
@@ -141,16 +141,23 @@ void TestDialect::initialize() {
       >();
   addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
                 TestInlinerInterface>();
-  addTypes<TestType, TestRecursiveType>();
+  addTypes<TestType, TestRecursiveType,
+#define GET_TYPEDEF_LIST
+#include "TestTypeDefs.cpp.inc"
+           >();
   allowUnknownOperations();
 }
 
-static Type parseTestType(DialectAsmParser &parser,
+static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser,
                           llvm::SetVector<Type> &stack) {
   StringRef typeTag;
   if (failed(parser.parseKeyword(&typeTag)))
     return Type();
 
+  auto genType = generatedTypeParser(ctxt, parser, typeTag);
+  if (genType != Type())
+    return genType;
+
   if (typeTag == "test_type")
     return TestType::get(parser.getBuilder().getContext());
 
@@ -174,7 +181,7 @@ static Type parseTestType(DialectAsmParser &parser,
   if (failed(parser.parseComma()))
     return Type();
   stack.insert(rec);
-  Type subtype = parseTestType(parser, stack);
+  Type subtype = parseTestType(ctxt, parser, stack);
   stack.pop_back();
   if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype)))
     return Type();
@@ -184,11 +191,13 @@ static Type parseTestType(DialectAsmParser &parser,
 
 Type TestDialect::parseType(DialectAsmParser &parser) const {
   llvm::SetVector<Type> stack;
-  return parseTestType(parser, stack);
+  return parseTestType(getContext(), parser, stack);
 }
 
 static void printTestType(Type type, DialectAsmPrinter &printer,
                           llvm::SetVector<Type> &stack) {
+  if (succeeded(generatedTypePrinter(type, printer)))
+    return;
   if (type.isa<TestType>()) {
     printer << "test_type";
     return;
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
new file mode 100644 (file)
index 0000000..cfab698
--- /dev/null
@@ -0,0 +1,150 @@
+//===-- TestTypeDefs.td - Test dialect type definitions ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// TableGen data type definitions for Test dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TEST_TYPEDEFS
+#define TEST_TYPEDEFS
+
+// To get the test dialect def.
+include "TestOps.td"
+
+// All of the types will extend this class.
+class Test_Type<string name> : TypeDef<Test_Dialect, name> { }
+
+def SimpleTypeA : Test_Type<"SimpleA"> {
+  let mnemonic = "smpla";
+
+  let printer = [{ $_printer << "smpla"; }];
+  let parser = [{ return get($_ctxt); }];
+}
+
+// A more complex parameterized type.
+def CompoundTypeA : Test_Type<"CompoundA"> {
+  let mnemonic = "cmpnd_a";
+
+  // List of type parameters.
+  let parameters = (
+    ins
+    "int":$widthOfSomething,
+    "::mlir::Type":$oneType,
+    // This is special syntax since ArrayRefs require allocation in the
+    // constructor.
+    ArrayRefParameter<
+      "int", // The parameter C++ type.
+      "An example of an array of ints" // Parameter description.
+      >: $arrayOfInts
+  );
+
+  let extraClassDeclaration = [{
+    struct SomeCppStruct {};
+  }];
+}
+
+// An example of how one could implement a standard integer.
+def IntegerType : Test_Type<"TestInteger"> {
+  let mnemonic = "int";
+  let genVerifyInvariantsDecl = 1;
+  let parameters = (
+    ins
+    // SignednessSemantics is defined below.
+    "::mlir::TestIntegerType::SignednessSemantics":$signedness,
+    "unsigned":$width
+  );
+
+  // We define the printer inline.
+  let printer = [{
+    $_printer << "int<";
+    printSignedness($_printer, getImpl()->signedness);
+    $_printer << ", " << getImpl()->width << ">";
+  }];
+
+  // The parser is defined here also.
+  let parser = [{
+    if (parser.parseLess()) return Type();
+    SignednessSemantics signedness;
+    if (parseSignedness($_parser, signedness)) return mlir::Type();
+    if ($_parser.parseComma()) return Type();
+    int width;
+    if ($_parser.parseInteger(width)) return Type();
+    if ($_parser.parseGreater()) return Type();
+    return get(ctxt, signedness, width);
+  }];
+
+  // Any extra code one wants in the type's class declaration.
+  let extraClassDeclaration = [{
+    /// Signedness semantics.
+    enum SignednessSemantics {
+      Signless, /// No signedness semantics
+      Signed,   /// Signed integer
+      Unsigned, /// Unsigned integer
+    };
+
+    /// This extra function is necessary since it doesn't include signedness
+    static IntegerType getChecked(unsigned width, Location location);
+
+    /// Return true if this is a signless integer type.
+    bool isSignless() const { return getSignedness() == Signless; }
+    /// Return true if this is a signed integer type.
+    bool isSigned() const { return getSignedness() == Signed; }
+    /// Return true if this is an unsigned integer type.
+    bool isUnsigned() const { return getSignedness() == Unsigned; }
+  }];
+}
+
+// A parent type for any type which is just a list of fields (e.g. structs,
+// unions).
+class FieldInfo_Type<string name> : Test_Type<name> {
+  let parameters = (
+    ins
+    // An ArrayRef of something which requires allocation in the storage
+    // constructor.
+    ArrayRefOfSelfAllocationParameter<
+      "::mlir::FieldInfo", // FieldInfo is defined/declared in TestTypes.h.
+      "Models struct fields">: $fields
+  );
+
+  // Prints the type in this format:
+  //   struct<[{field1Name, field1Type}, {field2Name, field2Type}]
+  let printer = [{
+    $_printer << "struct" << "<";
+    for (size_t i=0, e = getImpl()->fields.size(); i < e; i++) {
+      const auto& field = getImpl()->fields[i];
+      $_printer << "{" << field.name << "," << field.type << "}";
+      if (i < getImpl()->fields.size() - 1)
+          $_printer << ",";
+    }
+    $_printer << ">";
+  }];
+
+  // Parses the above format
+  let parser = [{
+    llvm::SmallVector<FieldInfo, 4> parameters;
+    if ($_parser.parseLess()) return Type();
+    while (mlir::succeeded($_parser.parseOptionalLBrace())) {
+      StringRef name;
+      if ($_parser.parseKeyword(&name)) return Type();
+      if ($_parser.parseComma()) return Type();
+      Type type;
+      if ($_parser.parseType(type)) return Type();
+      if ($_parser.parseRBrace()) return Type();
+      parameters.push_back(FieldInfo {name, type});
+      if ($_parser.parseOptionalComma()) break;
+    }
+    if ($_parser.parseGreater()) return Type();
+    return get($_ctxt, parameters);
+  }];
+}
+
+def StructType : FieldInfo_Type<"Struct"> {
+    let mnemonic = "struct";
+}
+
+#endif // TEST_TYPEDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
new file mode 100644 (file)
index 0000000..1eb5347
--- /dev/null
@@ -0,0 +1,117 @@
+//===- TestTypes.cpp - MLIR Test Dialect Types ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains types defined by the TestDialect for testing various
+// features of MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestTypes.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/Types.h"
+#include "llvm/ADT/Hashing.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+
+// Custom parser for SignednessSemantics.
+static ParseResult
+parseSignedness(DialectAsmParser &parser,
+                TestIntegerType::SignednessSemantics &result) {
+  StringRef signStr;
+  auto loc = parser.getCurrentLocation();
+  if (parser.parseKeyword(&signStr))
+    return failure();
+  if (signStr.compare_lower("u") || signStr.compare_lower("unsigned"))
+    result = TestIntegerType::SignednessSemantics::Unsigned;
+  else if (signStr.compare_lower("s") || signStr.compare_lower("signed"))
+    result = TestIntegerType::SignednessSemantics::Signed;
+  else if (signStr.compare_lower("n") || signStr.compare_lower("none"))
+    result = TestIntegerType::SignednessSemantics::Signless;
+  else
+    return parser.emitError(loc, "expected signed, unsigned, or none");
+  return success();
+}
+
+// Custom printer for SignednessSemantics.
+static void printSignedness(DialectAsmPrinter &printer,
+                            const TestIntegerType::SignednessSemantics &ss) {
+  switch (ss) {
+  case TestIntegerType::SignednessSemantics::Unsigned:
+    printer << "unsigned";
+    break;
+  case TestIntegerType::SignednessSemantics::Signed:
+    printer << "signed";
+    break;
+  case TestIntegerType::SignednessSemantics::Signless:
+    printer << "none";
+    break;
+  }
+}
+
+Type CompoundAType::parse(MLIRContext *ctxt, DialectAsmParser &parser) {
+  int widthOfSomething;
+  Type oneType;
+  SmallVector<int, 4> arrayOfInts;
+  if (parser.parseLess() || parser.parseInteger(widthOfSomething) ||
+      parser.parseComma() || parser.parseType(oneType) || parser.parseComma() ||
+      parser.parseLSquare())
+    return Type();
+
+  int i;
+  while (!*parser.parseOptionalInteger(i)) {
+    arrayOfInts.push_back(i);
+    if (parser.parseOptionalComma())
+      break;
+  }
+
+  if (parser.parseRSquare() || parser.parseGreater())
+    return Type();
+
+  return get(ctxt, widthOfSomething, oneType, arrayOfInts);
+}
+void CompoundAType::print(DialectAsmPrinter &printer) const {
+  printer << "cmpnd_a<" << getWidthOfSomething() << ", " << getOneType()
+          << ", [";
+  auto intArray = getArrayOfInts();
+  llvm::interleaveComma(intArray, printer);
+  printer << "]>";
+}
+
+// The functions don't need to be in the header file, but need to be in the mlir
+// namespace. Declare them here, then define them immediately below. Separating
+// the declaration and definition adheres to the LLVM coding standards.
+namespace mlir {
+// FieldInfo is used as part of a parameter, so equality comparison is
+// compulsory.
+static bool operator==(const FieldInfo &a, const FieldInfo &b);
+// FieldInfo is used as part of a parameter, so a hash will be computed.
+static llvm::hash_code hash_value(const FieldInfo &fi); // NOLINT
+} // namespace mlir
+
+// FieldInfo is used as part of a parameter, so equality comparison is
+// compulsory.
+static bool mlir::operator==(const FieldInfo &a, const FieldInfo &b) {
+  return a.name == b.name && a.type == b.type;
+}
+
+// FieldInfo is used as part of a parameter, so a hash will be computed.
+static llvm::hash_code mlir::hash_value(const FieldInfo &fi) { // NOLINT
+  return llvm::hash_combine(fi.name, fi.type);
+}
+
+// Example type validity checker.
+LogicalResult TestIntegerType::verifyConstructionInvariants(
+    Location loc, TestIntegerType::SignednessSemantics ss, unsigned int width) {
+  if (width > 8)
+    return failure();
+  return success();
+}
+
+#define GET_TYPEDEF_CLASSES
+#include "TestTypeDefs.cpp.inc"
index c7fd80e..371171f 100644 (file)
 #ifndef MLIR_TESTTYPES_H
 #define MLIR_TESTTYPES_H
 
+#include <tuple>
+
 #include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/Operation.h"
 #include "mlir/IR/Types.h"
 
 namespace mlir {
 
+/// FieldInfo represents a field in the StructType data type. It is used as a
+/// parameter in TestTypeDefs.td.
+struct FieldInfo {
+  StringRef name;
+  Type type;
+
+  // Custom allocation called from generated constructor code
+  FieldInfo allocateInto(TypeStorageAllocator &alloc) const {
+    return FieldInfo{alloc.copyInto(name), type};
+  }
+};
+
+} // namespace mlir
+
+#define GET_TYPEDEF_CLASSES
+#include "TestTypeDefs.h.inc"
+
+namespace mlir {
+
 #include "TestTypeInterfaces.h.inc"
 
 /// This class is a simple test type that uses a generated interface.
diff --git a/mlir/test/mlir-tblgen/testdialect-typedefs.mlir b/mlir/test/mlir-tblgen/testdialect-typedefs.mlir
new file mode 100644 (file)
index 0000000..c8500e4
--- /dev/null
@@ -0,0 +1,24 @@
+// RUN: mlir-opt %s | mlir-opt -verify-diagnostics | FileCheck %s
+
+//////////////
+// Tests the types in the 'Test' dialect, not the ones in 'typedefs.mlir'
+
+// CHECK: @simpleA(%arg0: !test.smpla)
+func @simpleA(%A : !test.smpla) -> () {
+  return
+}
+
+// CHECK: @compoundA(%arg0: !test.cmpnd_a<1, !test.smpla, [5, 6]>)
+func @compoundA(%A : !test.cmpnd_a<1, !test.smpla, [5, 6]>)-> () {
+  return
+}
+
+// CHECK: @testInt(%arg0: !test.int<unsigned, 8>, %arg1: !test.int<unsigned, 2>, %arg2: !test.int<unsigned, 1>)
+func @testInt(%A : !test.int<s, 8>, %B : !test.int<unsigned, 2>, %C : !test.int<n, 1>) {
+  return
+}
+
+// CHECK: @structTest(%arg0: !test.struct<{field1,!test.smpla},{field2,!test.int<unsigned, 3>}>)
+func @structTest (%A : !test.struct< {field1, !test.smpla}, {field2, !test.int<none, 3>} > ) {
+  return
+}
diff --git a/mlir/test/mlir-tblgen/typedefs.td b/mlir/test/mlir-tblgen/typedefs.td
new file mode 100644 (file)
index 0000000..5ba3fbc
--- /dev/null
@@ -0,0 +1,132 @@
+// RUN: mlir-tblgen -gen-typedef-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL
+// RUN: mlir-tblgen -gen-typedef-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEF
+
+include "mlir/IR/OpBase.td"
+
+// DECL: #ifdef GET_TYPEDEF_CLASSES
+// DECL: #undef GET_TYPEDEF_CLASSES
+
+// DECL: ::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser, ::llvm::StringRef mnenomic);
+// DECL: ::mlir::LogicalResult generatedTypePrinter(::mlir::Type type, ::mlir::DialectAsmPrinter& printer);
+
+// DEF: #ifdef GET_TYPEDEF_LIST
+// DEF: #undef GET_TYPEDEF_LIST
+// DEF: ::mlir::test::SimpleAType,
+// DEF: ::mlir::test::CompoundAType,
+// DEF: ::mlir::test::IndexType,
+// DEF: ::mlir::test::SingleParameterType,
+// DEF: ::mlir::test::IntegerType
+
+// DEF-LABEL: ::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser, ::llvm::StringRef mnemonic)
+// DEF: if (mnemonic == ::mlir::test::CompoundAType::getMnemonic()) return ::mlir::test::CompoundAType::parse(ctxt, parser);
+// DEF return ::mlir::Type();
+
+def Test_Dialect: Dialect {
+// DECL-NOT: TestDialect
+// DEF-NOT: TestDialect
+    let name = "TestDialect";
+    let cppNamespace = "::mlir::test";
+}
+
+class TestType<string name> : TypeDef<Test_Dialect, name> { }
+
+def A_SimpleTypeA : TestType<"SimpleA"> {
+// DECL: class SimpleAType: public ::mlir::Type
+}
+
+// A more complex parameterized type
+def B_CompoundTypeA : TestType<"CompoundA"> {
+  let summary = "A more complex parameterized type";
+  let description = "This type is to test a reasonably complex type";
+  let mnemonic = "cmpnd_a";
+  let parameters = (
+      ins
+      "int":$widthOfSomething,
+      "::mlir::test::SimpleTypeA": $exampleTdType,
+      "SomeCppStruct": $exampleCppType,
+      ArrayRefParameter<"int", "Matrix dimensions">:$dims
+  );
+
+  let genVerifyInvariantsDecl = 1;
+
+// DECL-LABEL: class CompoundAType: public ::mlir::Type
+// DECL: static ::mlir::LogicalResult verifyConstructionInvariants(Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims);
+// DECL: static CompoundAType getChecked(Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims);
+// DECL: static ::llvm::StringRef getMnemonic() { return "cmpnd_a"; }
+// DECL: static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser);
+// DECL: void print(::mlir::DialectAsmPrinter& printer) const;
+// DECL: int getWidthOfSomething() const;
+// DECL: ::mlir::test::SimpleTypeA getExampleTdType() const;
+// DECL: SomeCppStruct getExampleCppType() const;
+}
+
+def C_IndexType : TestType<"Index"> {
+    let mnemonic = "index";
+
+    let parameters = (
+      ins
+      StringRefParameter<"Label for index">:$label
+    );
+
+// DECL-LABEL: class IndexType: public ::mlir::Type
+// DECL: static ::llvm::StringRef getMnemonic() { return "index"; }
+// DECL: static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser);
+// DECL: void print(::mlir::DialectAsmPrinter& printer) const;
+}
+
+def D_SingleParameterType : TestType<"SingleParameter"> {
+  let parameters = (
+    ins
+    "int": $num
+  );
+// DECL-LABEL: struct SingleParameterTypeStorage;
+// DECL-LABEL: class SingleParameterType
+// DECL-NEXT:                   detail::SingleParameterTypeStorage
+}
+
+def E_IntegerType : TestType<"Integer"> {
+    let mnemonic = "int";
+    let genVerifyInvariantsDecl = 1;
+    let parameters = (
+        ins
+        "SignednessSemantics":$signedness,
+        TypeParameter<"unsigned", "Bitwdith of integer">:$width
+    );
+
+// DECL-LABEL: IntegerType: public ::mlir::Type
+
+    let extraClassDeclaration = [{
+  /// Signedness semantics.
+  enum SignednessSemantics {
+    Signless, /// No signedness semantics
+    Signed,   /// Signed integer
+    Unsigned, /// Unsigned integer
+  };
+
+  /// This extra function is necessary since it doesn't include signedness
+  static IntegerType getChecked(unsigned width, Location location);
+
+  /// Return true if this is a signless integer type.
+  bool isSignless() const { return getSignedness() == Signless; }
+  /// Return true if this is a signed integer type.
+  bool isSigned() const { return getSignedness() == Signed; }
+  /// Return true if this is an unsigned integer type.
+  bool isUnsigned() const { return getSignedness() == Unsigned; }
+    }];
+
+// DECL: /// Signedness semantics.
+// DECL-NEXT: enum SignednessSemantics {
+// DECL-NEXT:   Signless, /// No signedness semantics
+// DECL-NEXT:   Signed,   /// Signed integer
+// DECL-NEXT:   Unsigned, /// Unsigned integer
+// DECL-NEXT: };
+// DECL: /// This extra function is necessary since it doesn't include signedness
+// DECL-NEXT: static IntegerType getChecked(unsigned width, Location location);
+
+// DECL: /// Return true if this is a signless integer type.
+// DECL-NEXT: bool isSignless() const { return getSignedness() == Signless; }
+// DECL-NEXT: /// Return true if this is a signed integer type.
+// DECL-NEXT: bool isSigned() const { return getSignedness() == Signed; }
+// DECL-NEXT: /// Return true if this is an unsigned integer type.
+// DECL-NEXT: bool isUnsigned() const { return getSignedness() == Unsigned; }
+}
index 3904752..280fcc4 100644 (file)
@@ -20,6 +20,7 @@ add_tablegen(mlir-tblgen MLIR
   RewriterGen.cpp
   SPIRVUtilsGen.cpp
   StructsGen.cpp
+  TypeDefGen.cpp
   )
 
 set_target_properties(mlir-tblgen PROPERTIES FOLDER "Tablegenning")
index ff6a290..0dd59bc 100644 (file)
@@ -15,6 +15,7 @@
 #include "mlir/Support/IndentedOstream.h"
 #include "mlir/TableGen/GenInfo.h"
 #include "mlir/TableGen/Operator.h"
+#include "mlir/TableGen/TypeDef.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -23,6 +24,8 @@
 #include "llvm/TableGen/Record.h"
 #include "llvm/TableGen/TableGenBackend.h"
 
+#include <set>
+
 using namespace llvm;
 using namespace mlir;
 using namespace mlir::tblgen;
@@ -156,11 +159,66 @@ static void emitTypeDoc(const Type &type, raw_ostream &os) {
 }
 
 //===----------------------------------------------------------------------===//
+// TypeDef Documentation
+//===----------------------------------------------------------------------===//
+
+/// Emit the assembly format of a type.
+static void emitTypeAssemblyFormat(TypeDef td, raw_ostream &os) {
+  SmallVector<TypeParameter, 4> parameters;
+  td.getParameters(parameters);
+  if (parameters.size() == 0) {
+    os << "\nSyntax: `!" << td.getDialect().getName() << "." << td.getMnemonic()
+       << "`\n";
+    return;
+  }
+
+  os << "\nSyntax:\n\n```\n!" << td.getDialect().getName() << "."
+     << td.getMnemonic() << "<\n";
+  for (auto *it = parameters.begin(), *e = parameters.end(); it < e; ++it) {
+    os << "  " << it->getSyntax();
+    if (it < parameters.end() - 1)
+      os << ",";
+    os << "   # " << it->getName() << "\n";
+  }
+  os << ">\n```\n";
+}
+
+static void emitTypeDefDoc(TypeDef td, raw_ostream &os) {
+  os << llvm::formatv("### `{0}` ({1})\n", td.getName(), td.getCppClassName());
+
+  // Emit the summary, syntax, and description if present.
+  if (td.hasSummary())
+    os << "\n" << td.getSummary() << "\n";
+  if (td.getMnemonic() && td.getPrinterCode() && *td.getPrinterCode() == "" &&
+      td.getParserCode() && *td.getParserCode() == "")
+    emitTypeAssemblyFormat(td, os);
+  if (td.hasDescription())
+    mlir::tblgen::emitDescription(td.getDescription(), os);
+
+  // Emit attribute documentation.
+  SmallVector<TypeParameter, 4> parameters;
+  td.getParameters(parameters);
+  if (parameters.size() != 0) {
+    os << "\n#### Type parameters:\n\n";
+    os << "| Parameter | C++ type | Description |\n"
+       << "| :-------: | :-------: | ----------- |\n";
+    for (const auto &it : parameters) {
+      auto desc = it.getDescription();
+      os << "| " << it.getName() << " | `" << td.getCppClassName() << "` | "
+         << (desc ? *desc : "") << " |\n";
+    }
+  }
+
+  os << "\n";
+}
+
+//===----------------------------------------------------------------------===//
 // Dialect Documentation
 //===----------------------------------------------------------------------===//
 
 static void emitDialectDoc(const Dialect &dialect, ArrayRef<Operator> ops,
-                           ArrayRef<Type> types, raw_ostream &os) {
+                           ArrayRef<Type> types, ArrayRef<TypeDef> typeDefs,
+                           raw_ostream &os) {
   os << "# '" << dialect.getName() << "' Dialect\n\n";
   emitIfNotEmpty(dialect.getSummary(), os);
   emitIfNotEmpty(dialect.getDescription(), os);
@@ -169,7 +227,7 @@ static void emitDialectDoc(const Dialect &dialect, ArrayRef<Operator> ops,
 
   // TODO: Add link between use and def for types
   if (!types.empty()) {
-    os << "## Type definition\n\n";
+    os << "## Type constraint definition\n\n";
     for (const Type &type : types)
       emitTypeDoc(type, os);
   }
@@ -179,28 +237,43 @@ static void emitDialectDoc(const Dialect &dialect, ArrayRef<Operator> ops,
     for (const Operator &op : ops)
       emitOpDoc(op, os);
   }
+
+  if (!typeDefs.empty()) {
+    os << "## Type definition\n\n";
+    for (const TypeDef &td : typeDefs)
+      emitTypeDefDoc(td, os);
+  }
 }
 
 static void emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
   const auto &opDefs = recordKeeper.getAllDerivedDefinitions("Op");
   const auto &typeDefs = recordKeeper.getAllDerivedDefinitions("DialectType");
+  const auto &typeDefDefs = recordKeeper.getAllDerivedDefinitions("TypeDef");
 
+  std::set<Dialect> dialectsWithDocs;
   std::map<Dialect, std::vector<Operator>> dialectOps;
   std::map<Dialect, std::vector<Type>> dialectTypes;
+  std::map<Dialect, std::vector<TypeDef>> dialectTypeDefs;
   for (auto *opDef : opDefs) {
     Operator op(opDef);
     dialectOps[op.getDialect()].push_back(op);
+    dialectsWithDocs.insert(op.getDialect());
   }
   for (auto *typeDef : typeDefs) {
     Type type(typeDef);
     if (auto dialect = type.getDialect())
       dialectTypes[dialect].push_back(type);
   }
+  for (auto *typeDef : typeDefDefs) {
+    TypeDef type(typeDef);
+    dialectTypeDefs[type.getDialect()].push_back(type);
+    dialectsWithDocs.insert(type.getDialect());
+  }
 
   os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
-  for (const auto &dialectWithOps : dialectOps)
-    emitDialectDoc(dialectWithOps.first, dialectWithOps.second,
-                   dialectTypes[dialectWithOps.first], os);
+  for (auto dialect : dialectsWithDocs)
+    emitDialectDoc(dialect, dialectOps[dialect], dialectTypes[dialect],
+                   dialectTypeDefs[dialect], os);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-tblgen/TypeDefGen.cpp b/mlir/tools/mlir-tblgen/TypeDefGen.cpp
new file mode 100644 (file)
index 0000000..4e2ef48
--- /dev/null
@@ -0,0 +1,561 @@
+//===- TypeDefGen.cpp - MLIR typeDef definitions generator ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// TypeDefGen uses the description of typeDefs to generate C++ definitions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/TableGen/CodeGenHelpers.h"
+#include "mlir/TableGen/Format.h"
+#include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/TypeDef.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/TableGenBackend.h"
+
+#define DEBUG_TYPE "mlir-tblgen-typedefgen"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+static llvm::cl::OptionCategory typedefGenCat("Options for -gen-typedef-*");
+static llvm::cl::opt<std::string>
+    selectedDialect("typedefs-dialect",
+                    llvm::cl::desc("Gen types for this dialect"),
+                    llvm::cl::cat(typedefGenCat), llvm::cl::CommaSeparated);
+
+/// Find all the TypeDefs for the specified dialect. If no dialect specified and
+/// can only find one dialect's types, use that.
+static void findAllTypeDefs(const llvm::RecordKeeper &recordKeeper,
+                            SmallVectorImpl<TypeDef> &typeDefs) {
+  auto recDefs = recordKeeper.getAllDerivedDefinitions("TypeDef");
+  auto defs = llvm::map_range(
+      recDefs, [&](const llvm::Record *rec) { return TypeDef(rec); });
+  if (defs.empty())
+    return;
+
+  StringRef dialectName;
+  if (selectedDialect.getNumOccurrences() == 0) {
+    if (defs.empty())
+      return;
+
+    llvm::SmallSet<Dialect, 4> dialects;
+    for (const TypeDef &typeDef : defs)
+      dialects.insert(typeDef.getDialect());
+    if (dialects.size() != 1)
+      llvm::PrintFatalError("TypeDefs belonging to more than one dialect. Must "
+                            "select one via '--typedefs-dialect'");
+
+    dialectName = (*dialects.begin()).getName();
+  } else if (selectedDialect.getNumOccurrences() == 1) {
+    dialectName = selectedDialect.getValue();
+  } else {
+    llvm::PrintFatalError("Cannot select multiple dialects for which to "
+                          "generate types via '--typedefs-dialect'.");
+  }
+
+  for (const TypeDef &typeDef : defs)
+    if (typeDef.getDialect().getName().equals(dialectName))
+      typeDefs.push_back(typeDef);
+}
+
+namespace {
+
+/// Pass an instance of this class to llvm::formatv() to emit a comma separated
+/// list of parameters in the format by 'EmitFormat'.
+class TypeParamCommaFormatter : public llvm::detail::format_adapter {
+public:
+  /// Choose the output format
+  enum EmitFormat {
+    /// Emit "parameter1Type parameter1Name, parameter2Type parameter2Name,
+    /// [...]".
+    TypeNamePairs,
+
+    /// Emit ", parameter1Type parameter1Name, parameter2Type parameter2Name,
+    /// [...]".
+    TypeNamePairsPrependComma,
+
+    /// Emit "parameter1(parameter1), parameter2(parameter2), [...]".
+    TypeNameInitializer
+  };
+
+  TypeParamCommaFormatter(EmitFormat emitFormat, ArrayRef<TypeParameter> params)
+      : emitFormat(emitFormat), params(params) {}
+
+  /// llvm::formatv will call this function when using an instance as a
+  /// replacement value.
+  void format(raw_ostream &os, StringRef options) {
+    if (params.size() && emitFormat == EmitFormat::TypeNamePairsPrependComma)
+      os << ", ";
+    switch (emitFormat) {
+    case EmitFormat::TypeNamePairs:
+    case EmitFormat::TypeNamePairsPrependComma:
+      interleaveComma(params, os,
+                      [&](const TypeParameter &p) { emitTypeNamePair(p, os); });
+      break;
+    case EmitFormat::TypeNameInitializer:
+      interleaveComma(params, os, [&](const TypeParameter &p) {
+        emitTypeNameInitializer(p, os);
+      });
+      break;
+    }
+  }
+
+private:
+  // Emit "paramType paramName".
+  static void emitTypeNamePair(const TypeParameter &param, raw_ostream &os) {
+    os << param.getCppType() << " " << param.getName();
+  }
+  // Emit "paramName(paramName)"
+  void emitTypeNameInitializer(const TypeParameter &param, raw_ostream &os) {
+    os << param.getName() << "(" << param.getName() << ")";
+  }
+
+  EmitFormat emitFormat;
+  ArrayRef<TypeParameter> params;
+};
+
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// GEN: TypeDef declarations
+//===----------------------------------------------------------------------===//
+
+/// The code block for the start of a typeDef class declaration -- singleton
+/// case.
+///
+/// {0}: The name of the typeDef class.
+static const char *const typeDefDeclSingletonBeginStr = R"(
+  class {0}: public ::mlir::Type::TypeBase<{0}, ::mlir::Type, ::mlir::TypeStorage> {{
+  public:
+    /// Inherit some necessary constructors from 'TypeBase'.
+    using Base::Base;
+
+)";
+
+/// The code block for the start of a typeDef class declaration -- parametric
+/// case.
+///
+/// {0}: The name of the typeDef class.
+/// {1}: The typeDef storage class namespace.
+/// {2}: The storage class name.
+/// {3}: The list of parameters with types.
+static const char *const typeDefDeclParametricBeginStr = R"(
+  namespace {1} {
+    struct {2};
+  }
+  class {0}: public ::mlir::Type::TypeBase<{0}, ::mlir::Type,
+                                        {1}::{2}> {{
+  public:
+    /// Inherit some necessary constructors from 'TypeBase'.
+    using Base::Base;
+
+)";
+
+/// The snippet for print/parse.
+static const char *const typeDefParsePrint = R"(
+    static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser);
+    void print(::mlir::DialectAsmPrinter& printer) const;
+)";
+
+/// The code block for the verifyConstructionInvariants and getChecked.
+///
+/// {0}: List of parameters, parameters style.
+/// {1}: C++ type class name.
+static const char *const typeDefDeclVerifyStr = R"(
+    static ::mlir::LogicalResult verifyConstructionInvariants(Location loc{0});
+    static {1} getChecked(Location loc{0});
+)";
+
+/// Generate the declaration for the given typeDef class.
+static void emitTypeDefDecl(const TypeDef &typeDef, raw_ostream &os) {
+  SmallVector<TypeParameter, 4> params;
+  typeDef.getParameters(params);
+
+  // Emit the beginning string template: either the singleton or parametric
+  // template.
+  if (typeDef.getNumParameters() == 0)
+    os << formatv(typeDefDeclSingletonBeginStr, typeDef.getCppClassName(),
+                  typeDef.getStorageNamespace(), typeDef.getStorageClassName());
+  else
+    os << formatv(typeDefDeclParametricBeginStr, typeDef.getCppClassName(),
+                  typeDef.getStorageNamespace(), typeDef.getStorageClassName());
+
+  // Emit the extra declarations first in case there's a type definition in
+  // there.
+  if (Optional<StringRef> extraDecl = typeDef.getExtraDecls())
+    os << *extraDecl << "\n";
+
+  TypeParamCommaFormatter emitTypeNamePairsAfterComma(
+      TypeParamCommaFormatter::EmitFormat::TypeNamePairsPrependComma, params);
+  os << llvm::formatv("    static {0} get(::mlir::MLIRContext* ctxt{1});\n",
+                      typeDef.getCppClassName(), emitTypeNamePairsAfterComma);
+
+  // Emit the verify invariants declaration.
+  if (typeDef.genVerifyInvariantsDecl())
+    os << llvm::formatv(typeDefDeclVerifyStr, emitTypeNamePairsAfterComma,
+                        typeDef.getCppClassName());
+
+  // Emit the mnenomic, if specified.
+  if (auto mnenomic = typeDef.getMnemonic()) {
+    os << "    static ::llvm::StringRef getMnemonic() { return \"" << mnenomic
+       << "\"; }\n";
+
+    // If mnemonic specified, emit print/parse declarations.
+    os << typeDefParsePrint;
+  }
+
+  if (typeDef.genAccessors()) {
+    SmallVector<TypeParameter, 4> parameters;
+    typeDef.getParameters(parameters);
+
+    for (TypeParameter &parameter : parameters) {
+      SmallString<16> name = parameter.getName();
+      name[0] = llvm::toUpper(name[0]);
+      os << formatv("    {0} get{1}() const;\n", parameter.getCppType(), name);
+    }
+  }
+
+  // End the typeDef decl.
+  os << "  };\n";
+}
+
+/// Main entry point for decls.
+static bool emitTypeDefDecls(const llvm::RecordKeeper &recordKeeper,
+                             raw_ostream &os) {
+  emitSourceFileHeader("TypeDef Declarations", os);
+
+  SmallVector<TypeDef, 16> typeDefs;
+  findAllTypeDefs(recordKeeper, typeDefs);
+
+  IfDefScope scope("GET_TYPEDEF_CLASSES", os);
+  if (typeDefs.size() > 0) {
+    NamespaceEmitter nsEmitter(os, typeDefs.begin()->getDialect());
+
+    // Well known print/parse dispatch function declarations. These are called
+    // from Dialect::parseType() and Dialect::printType() methods.
+    os << "  ::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, "
+          "::mlir::DialectAsmParser& parser, ::llvm::StringRef mnenomic);\n";
+    os << "  ::mlir::LogicalResult generatedTypePrinter(::mlir::Type type, "
+          "::mlir::DialectAsmPrinter& printer);\n";
+    os << "\n";
+
+    // Declare all the type classes first (in case they reference each other).
+    for (const TypeDef &typeDef : typeDefs)
+      os << "  class " << typeDef.getCppClassName() << ";\n";
+
+    // Declare all the typedefs.
+    for (const TypeDef &typeDef : typeDefs)
+      emitTypeDefDecl(typeDef, os);
+  }
+
+  return false;
+}
+
+//===----------------------------------------------------------------------===//
+// GEN: TypeDef list
+//===----------------------------------------------------------------------===//
+
+static void emitTypeDefList(SmallVectorImpl<TypeDef> &typeDefs,
+                            raw_ostream &os) {
+  IfDefScope scope("GET_TYPEDEF_LIST", os);
+  for (auto *i = typeDefs.begin(); i != typeDefs.end(); i++) {
+    os << i->getDialect().getCppNamespace() << "::" << i->getCppClassName();
+    if (i < typeDefs.end() - 1)
+      os << ",\n";
+    else
+      os << "\n";
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// GEN: TypeDef definitions
+//===----------------------------------------------------------------------===//
+
+/// Beginning of storage class.
+/// {0}: Storage class namespace.
+/// {1}: Storage class c++ name.
+/// {2}: Parameters parameters.
+/// {3}: Parameter initialzer string.
+/// {4}: Parameter name list.
+/// {5}: Parameter types.
+static const char *const typeDefStorageClassBegin = R"(
+namespace {0} {{
+  struct {1} : public ::mlir::TypeStorage {{
+    {1} ({2})
+      : {3} {{ }
+
+    /// The hash key for this storage is a pair of the integer and type params.
+    using KeyTy = std::tuple<{5}>;
+
+    /// Define the comparison function for the key type.
+    bool operator==(const KeyTy &key) const {{
+      return key == KeyTy({4});
+    }
+)";
+
+/// The storage class' constructor template.
+/// {0}: storage class name.
+static const char *const typeDefStorageClassConstructorBegin = R"(
+    /// Define a construction method for creating a new instance of this storage.
+    static {0} *construct(::mlir::TypeStorageAllocator &allocator, const KeyTy &key) {{
+)";
+
+/// The storage class' constructor return template.
+/// {0}: storage class name.
+/// {1}: list of parameters.
+static const char *const typeDefStorageClassConstructorReturn = R"(
+      return new (allocator.allocate<{0}>())
+          {0}({1});
+    }
+)";
+
+/// Use tgfmt to emit custom allocation code for each parameter, if necessary.
+static void emitParameterAllocationCode(TypeDef &typeDef, raw_ostream &os) {
+  SmallVector<TypeParameter, 4> parameters;
+  typeDef.getParameters(parameters);
+  auto fmtCtxt = FmtContext().addSubst("_allocator", "allocator");
+  for (TypeParameter &parameter : parameters) {
+    auto allocCode = parameter.getAllocator();
+    if (allocCode) {
+      fmtCtxt.withSelf(parameter.getName());
+      fmtCtxt.addSubst("_dst", parameter.getName());
+      os << "      " << tgfmt(*allocCode, &fmtCtxt) << "\n";
+    }
+  }
+}
+
+/// Emit the storage class code for type 'typeDef'.
+/// This includes (in-order):
+///  1) typeDefStorageClassBegin, which includes:
+///      - The class constructor.
+///      - The KeyTy definition.
+///      - The equality (==) operator.
+///  2) The hashKey method.
+///  3) The construct method.
+///  4) The list of parameters as the storage class member variables.
+static void emitStorageClass(TypeDef typeDef, raw_ostream &os) {
+  SmallVector<TypeParameter, 4> parameters;
+  typeDef.getParameters(parameters);
+
+  // Initialize a bunch of variables to be used later on.
+  auto parameterNames = map_range(
+      parameters, [](TypeParameter parameter) { return parameter.getName(); });
+  auto parameterTypes = map_range(parameters, [](TypeParameter parameter) {
+    return parameter.getCppType();
+  });
+  auto parameterList = join(parameterNames, ", ");
+  auto parameterTypeList = join(parameterTypes, ", ");
+
+  // 1) Emit most of the storage class up until the hashKey body.
+  os << formatv(
+      typeDefStorageClassBegin, typeDef.getStorageNamespace(),
+      typeDef.getStorageClassName(),
+      TypeParamCommaFormatter(
+          TypeParamCommaFormatter::EmitFormat::TypeNamePairs, parameters),
+      TypeParamCommaFormatter(
+          TypeParamCommaFormatter::EmitFormat::TypeNameInitializer, parameters),
+      parameterList, parameterTypeList);
+
+  // 2) Emit the haskKey method.
+  os << "  static ::llvm::hash_code hashKey(const KeyTy &key) {\n";
+  // Extract each parameter from the key.
+  for (size_t i = 0, e = parameters.size(); i < e; ++i)
+    os << formatv("      const auto &{0} = std::get<{1}>(key);\n",
+                  parameters[i].getName(), i);
+  // Then combine them all. This requires all the parameters types to have a
+  // hash_value defined.
+  os << "      return ::llvm::hash_combine(";
+  interleaveComma(parameterNames, os);
+  os << ");\n";
+  os << "    }\n";
+
+  // 3) Emit the construct method.
+  if (typeDef.hasStorageCustomConstructor())
+    // If user wants to build the storage constructor themselves, declare it
+    // here and then they can write the definition elsewhere.
+    os << "    static " << typeDef.getStorageClassName()
+       << " *construct(::mlir::TypeStorageAllocator &allocator, const KeyTy "
+          "&key);\n";
+  else {
+    // If not, autogenerate one.
+
+    // First, unbox the parameters.
+    os << formatv(typeDefStorageClassConstructorBegin,
+                  typeDef.getStorageClassName());
+    for (size_t i = 0; i < parameters.size(); ++i) {
+      os << formatv("      auto {0} = std::get<{1}>(key);\n",
+                    parameters[i].getName(), i);
+    }
+    // Second, reassign the parameter variables with allocation code, if it's
+    // specified.
+    emitParameterAllocationCode(typeDef, os);
+
+    // Last, return an allocated copy.
+    os << formatv(typeDefStorageClassConstructorReturn,
+                  typeDef.getStorageClassName(), parameterList);
+  }
+
+  // 4) Emit the parameters as storage class members.
+  for (auto parameter : parameters) {
+    os << "      " << parameter.getCppType() << " " << parameter.getName()
+       << ";\n";
+  }
+  os << "  };\n";
+
+  os << "} // namespace " << typeDef.getStorageNamespace() << "\n";
+}
+
+/// Emit the parser and printer for a particular type, if they're specified.
+void emitParserPrinter(TypeDef typeDef, raw_ostream &os) {
+  // Emit the printer code, if specified.
+  if (auto printerCode = typeDef.getPrinterCode()) {
+    // Both the mnenomic and printerCode must be defined (for parity with
+    // parserCode).
+    os << "void " << typeDef.getCppClassName()
+       << "::print(::mlir::DialectAsmPrinter& printer) const {\n";
+    if (*printerCode == "") {
+      // If no code specified, emit error.
+      PrintFatalError(typeDef.getLoc(),
+                      typeDef.getName() +
+                          ": printer (if specified) must have non-empty code");
+    }
+    auto fmtCtxt = FmtContext().addSubst("_printer", "printer");
+    os << tgfmt(*printerCode, &fmtCtxt) << "\n}\n";
+  }
+
+  // emit a parser, if specified.
+  if (auto parserCode = typeDef.getParserCode()) {
+    // The mnenomic must be defined so the dispatcher knows how to dispatch.
+    os << "::mlir::Type " << typeDef.getCppClassName()
+       << "::parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& "
+          "parser) "
+          "{\n";
+    if (*parserCode == "") {
+      // if no code specified, emit error.
+      PrintFatalError(typeDef.getLoc(),
+                      typeDef.getName() +
+                          ": parser (if specified) must have non-empty code");
+    }
+    auto fmtCtxt =
+        FmtContext().addSubst("_parser", "parser").addSubst("_ctxt", "ctxt");
+    os << tgfmt(*parserCode, &fmtCtxt) << "\n}\n";
+  }
+}
+
+/// Print all the typedef-specific definition code.
+static void emitTypeDefDef(TypeDef typeDef, raw_ostream &os) {
+  NamespaceEmitter ns(os, typeDef.getDialect());
+  SmallVector<TypeParameter, 4> parameters;
+  typeDef.getParameters(parameters);
+
+  // Emit the storage class, if requested and necessary.
+  if (typeDef.genStorageClass() && typeDef.getNumParameters() > 0)
+    emitStorageClass(typeDef, os);
+
+  os << llvm::formatv(
+      "{0} {0}::get(::mlir::MLIRContext* ctxt{1}) {{\n"
+      "  return Base::get(ctxt",
+      typeDef.getCppClassName(),
+      TypeParamCommaFormatter(
+          TypeParamCommaFormatter::EmitFormat::TypeNamePairsPrependComma,
+          parameters));
+  for (TypeParameter &param : parameters)
+    os << ", " << param.getName();
+  os << ");\n}\n";
+
+  // Emit the parameter accessors.
+  if (typeDef.genAccessors())
+    for (const TypeParameter &parameter : parameters) {
+      SmallString<16> name = parameter.getName();
+      name[0] = llvm::toUpper(name[0]);
+      os << formatv("{0} {3}::get{1}() const { return getImpl()->{2}; }\n",
+                    parameter.getCppType(), name, parameter.getName(),
+                    typeDef.getCppClassName());
+    }
+
+  // If mnemonic is specified maybe print definitions for the parser and printer
+  // code, if they're specified.
+  if (typeDef.getMnemonic())
+    emitParserPrinter(typeDef, os);
+}
+
+/// Emit the dialect printer/parser dispatcher. User's code should call these
+/// functions from their dialect's print/parse methods.
+static void emitParsePrintDispatch(SmallVectorImpl<TypeDef> &typeDefs,
+                                   raw_ostream &os) {
+  if (typeDefs.size() == 0)
+    return;
+  const Dialect &dialect = typeDefs.begin()->getDialect();
+  NamespaceEmitter ns(os, dialect);
+
+  // The parser dispatch is just a list of if-elses, matching on the mnemonic
+  // and calling the class's parse function.
+  os << "::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, "
+        "::mlir::DialectAsmParser& parser, ::llvm::StringRef mnemonic) {\n";
+  for (const TypeDef &typeDef : typeDefs)
+    if (typeDef.getMnemonic())
+      os << formatv("  if (mnemonic == {0}::{1}::getMnemonic()) return "
+                    "{0}::{1}::parse(ctxt, parser);\n",
+                    typeDef.getDialect().getCppNamespace(),
+                    typeDef.getCppClassName());
+  os << "  return ::mlir::Type();\n";
+  os << "}\n\n";
+
+  // The printer dispatch uses llvm::TypeSwitch to find and call the correct
+  // printer.
+  os << "::mlir::LogicalResult generatedTypePrinter(::mlir::Type type, "
+        "::mlir::DialectAsmPrinter& printer) {\n"
+     << "  ::mlir::LogicalResult found = ::mlir::success();\n"
+     << "  ::llvm::TypeSwitch<::mlir::Type>(type)\n";
+  for (auto typeDef : typeDefs)
+    if (typeDef.getMnemonic())
+      os << formatv("    .Case<{0}::{1}>([&](::mlir::Type t) {{ "
+                    "t.dyn_cast<{0}::{1}>().print(printer); })\n",
+                    typeDef.getDialect().getCppNamespace(),
+                    typeDef.getCppClassName());
+  os << "    .Default([&found](::mlir::Type) { found = ::mlir::failure(); "
+        "});\n"
+     << "  return found;\n"
+     << "}\n\n";
+}
+
+/// Entry point for typedef definitions.
+static bool emitTypeDefDefs(const llvm::RecordKeeper &recordKeeper,
+                            raw_ostream &os) {
+  emitSourceFileHeader("TypeDef Definitions", os);
+
+  SmallVector<TypeDef, 16> typeDefs;
+  findAllTypeDefs(recordKeeper, typeDefs);
+  emitTypeDefList(typeDefs, os);
+
+  IfDefScope scope("GET_TYPEDEF_CLASSES", os);
+  emitParsePrintDispatch(typeDefs, os);
+  for (auto typeDef : typeDefs)
+    emitTypeDefDef(typeDef, os);
+
+  return false;
+}
+
+//===----------------------------------------------------------------------===//
+// GEN: TypeDef registration hooks
+//===----------------------------------------------------------------------===//
+
+static mlir::GenRegistration
+    genTypeDefDefs("gen-typedef-defs", "Generate TypeDef definitions",
+                   [](const llvm::RecordKeeper &records, raw_ostream &os) {
+                     return emitTypeDefDefs(records, os);
+                   });
+
+static mlir::GenRegistration
+    genTypeDefDecls("gen-typedef-decls", "Generate TypeDef declarations",
+                    [](const llvm::RecordKeeper &records, raw_ostream &os) {
+                      return emitTypeDefDecls(records, os);
+                    });