set(LLVM_TARGET_DEFINITIONS ${dialect}.td)
mlir_tablegen(${dialect}.h.inc -gen-op-decls)
mlir_tablegen(${dialect}.cpp.inc -gen-op-defs)
+ mlir_tablegen(${dialect}Types.h.inc -gen-typedef-decls)
+ mlir_tablegen(${dialect}Types.cpp.inc -gen-typedef-defs)
mlir_tablegen(${dialect}Dialect.h.inc -gen-dialect-decls -dialect=${dialect_namespace})
add_public_tablegen_target(MLIR${dialect}IncGen)
add_dependencies(mlir-headers MLIR${dialect}IncGen)
// so to replace the matched DAG with an existing SSA value.
def replaceWithValue;
+
+//===----------------------------------------------------------------------===//
+// Data type generation
+//===----------------------------------------------------------------------===//
+
+// Define a new type belonging to a dialect and called 'name'.
+class TypeDef<Dialect owningdialect, string name> {
+ Dialect dialect = owningdialect;
+ string cppClassName = name # "Type";
+
+ // Short summary of the type.
+ string summary = ?;
+ // The longer description of this type.
+ string description = ?;
+
+ // Name of storage class to generate or use.
+ string storageClass = name # "TypeStorage";
+ // Namespace (withing dialect c++ namespace) in which the storage class
+ // resides.
+ string storageNamespace = "detail";
+ // Specify if the storage class is to be generated.
+ bit genStorageClass = 1;
+ // Specify that the generated storage class has a constructor which is written
+ // in C++.
+ bit hasStorageCustomConstructor = 0;
+
+ // The list of parameters for this type. Parameters will become both
+ // parameters to the get() method and storage class member variables.
+ //
+ // The format of this dag is:
+ // (ins
+ // "<c++ type>":$param1Name,
+ // "<c++ type>":$param2Name,
+ // TypeParameter<"c++ type", "param description">:$param3Name)
+ // TypeParameters (or more likely one of their subclasses) are required to add
+ // more information about the parameter, specifically:
+ // - Documentation
+ // - Code to allocate the parameter (if allocation is needed in the storage
+ // class constructor)
+ //
+ // For example:
+ // (ins
+ // "int":$width,
+ // ArrayRefParameter<"bool", "list of bools">:$yesNoArray)
+ //
+ // (ArrayRefParameter is a subclass of TypeParameter which has allocation code
+ // for re-allocating ArrayRefs. It is defined below.)
+ dag parameters = (ins);
+
+ // Use the lowercased name as the keyword for parsing/printing. Specify only
+ // if you want tblgen to generate declarations and/or definitions of
+ // printer/parser for this type.
+ string mnemonic = ?;
+ // If 'mnemonic' specified,
+ // If null, generate just the declarations.
+ // If a non-empty code block, just use that code as the definition code.
+ // Error if an empty code block.
+ code printer = ?;
+ code parser = ?;
+
+ // If set, generate accessors for each Type parameter.
+ bit genAccessors = 1;
+ // Generate the verifyConstructionInvariants declaration and getChecked
+ // method.
+ bit genVerifyInvariantsDecl = 0;
+ // Extra code to include in the class declaration.
+ code extraClassDeclaration = [{}];
+}
+
+// 'Parameters' should be subclasses of this or simple strings (which is a
+// shorthand for TypeParameter<"C++Type">).
+class TypeParameter<string type, string desc> {
+ // Custom memory allocation code for storage constructor.
+ code allocator = ?;
+ // The C++ type of this parameter.
+ string cppType = type;
+ // A description of this parameter.
+ string description = desc;
+ // The format string for the asm syntax (documentation only).
+ string syntax = ?;
+}
+
+// For StringRefs, which require allocation.
+class StringRefParameter<string desc> :
+ TypeParameter<"::llvm::StringRef", desc> {
+ let allocator = [{$_dst = $_allocator.copyInto($_self);}];
+}
+
+// For standard ArrayRefs, which require allocation.
+class ArrayRefParameter<string arrayOf, string desc> :
+ TypeParameter<"::llvm::ArrayRef<" # arrayOf # ">", desc> {
+ let allocator = [{$_dst = $_allocator.copyInto($_self);}];
+}
+
+// For classes which require allocation and have their own allocateInto method.
+class SelfAllocationParameter<string type, string desc> :
+ TypeParameter<type, desc> {
+ let allocator = [{$_dst = $_self.allocateInto($_allocator);}];
+}
+
+// For ArrayRefs which contain things which allocate themselves.
+class ArrayRefOfSelfAllocationParameter<string arrayOf, string desc> :
+ TypeParameter<"::llvm::ArrayRef<" # arrayOf # ">", desc> {
+ let allocator = [{
+ llvm::SmallVector<}] # arrayOf # [{, 4> tmpFields;
+ for (size_t i = 0, e = $_self.size(); i < e; ++i)
+ tmpFields.push_back($_self[i].allocateInto($_allocator));
+ $_dst = $_allocator.copyInto(ArrayRef<}] # arrayOf # [{>(tmpFields));
+ }];
+}
+
+
#endif // OP_BASE
--- /dev/null
+//===-- TypeDef.h - Record wrapper for type definitions ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// TypeDef wrapper to simplify using TableGen Record defining a MLIR type.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_TYPEDEF_H
+#define MLIR_TABLEGEN_TYPEDEF_H
+
+#include "mlir/Support/LLVM.h"
+#include "mlir/TableGen/Dialect.h"
+
+namespace llvm {
+class Record;
+class DagInit;
+class SMLoc;
+} // namespace llvm
+
+namespace mlir {
+namespace tblgen {
+
+class TypeParameter;
+
+/// Wrapper class that contains a TableGen TypeDef's record and provides helper
+/// methods for accessing them.
+class TypeDef {
+public:
+ explicit TypeDef(const llvm::Record *def) : def(def) {}
+
+ // Get the dialect for which this type belongs.
+ Dialect getDialect() const;
+
+ // Returns the name of this TypeDef record.
+ StringRef getName() const;
+
+ // Query functions for the documentation of the operator.
+ bool hasDescription() const;
+ StringRef getDescription() const;
+ bool hasSummary() const;
+ StringRef getSummary() const;
+
+ // Returns the name of the C++ class to generate.
+ StringRef getCppClassName() const;
+
+ // Returns the name of the storage class for this type.
+ StringRef getStorageClassName() const;
+
+ // Returns the C++ namespace for this types storage class.
+ StringRef getStorageNamespace() const;
+
+ // Returns true if we should generate the storage class.
+ bool genStorageClass() const;
+
+ // Indicates whether or not to generate the storage class constructor.
+ bool hasStorageCustomConstructor() const;
+
+ // Fill a list with this types parameters. See TypeDef in OpBase.td for
+ // documentation of parameter usage.
+ void getParameters(SmallVectorImpl<TypeParameter> &) const;
+ // Return the number of type parameters
+ unsigned getNumParameters() const;
+
+ // Return the keyword/mnemonic to use in the printer/parser methods if we are
+ // supposed to auto-generate them.
+ Optional<StringRef> getMnemonic() const;
+
+ // Returns the code to use as the types printer method. If not specified,
+ // return a non-value. Otherwise, return the contents of that code block.
+ Optional<StringRef> getPrinterCode() const;
+
+ // Returns the code to use as the types parser method. If not specified,
+ // return a non-value. Otherwise, return the contents of that code block.
+ Optional<StringRef> getParserCode() const;
+
+ // Returns true if the accessors based on the types parameters should be
+ // generated.
+ bool genAccessors() const;
+
+ // Return true if we need to generate the verifyConstructionInvariants
+ // declaration and getChecked method.
+ bool genVerifyInvariantsDecl() const;
+
+ // Returns the dialects extra class declaration code.
+ Optional<StringRef> getExtraDecls() const;
+
+ // Get the code location (for error printing).
+ ArrayRef<llvm::SMLoc> getLoc() const;
+
+ // Returns whether two TypeDefs are equal by checking the equality of the
+ // underlying record.
+ bool operator==(const TypeDef &other) const;
+
+ // Compares two TypeDefs by comparing the names of the dialects.
+ bool operator<(const TypeDef &other) const;
+
+ // Returns whether the TypeDef is defined.
+ operator bool() const { return def != nullptr; }
+
+private:
+ const llvm::Record *def;
+};
+
+// A wrapper class for tblgen TypeParameter, arrays of which belong to TypeDefs
+// to parameterize them.
+class TypeParameter {
+public:
+ explicit TypeParameter(const llvm::DagInit *def, unsigned num)
+ : def(def), num(num) {}
+
+ // Get the parameter name.
+ StringRef getName() const;
+ // If specified, get the custom allocator code for this parameter.
+ llvm::Optional<StringRef> getAllocator() const;
+ // Get the C++ type of this parameter.
+ StringRef getCppType() const;
+ // Get a description of this parameter for documentation purposes.
+ llvm::Optional<StringRef> getDescription() const;
+ // Get the assembly syntax documentation.
+ StringRef getSyntax() const;
+
+private:
+ const llvm::DagInit *def;
+ const unsigned num;
+};
+
+} // end namespace tblgen
+} // end namespace mlir
+
+#endif // MLIR_TABLEGEN_TYPEDEF_H
SideEffects.cpp
Successor.cpp
Type.cpp
+ TypeDef.cpp
DISABLE_LLVM_LINK_LLVM_DYLIB
--- /dev/null
+//===- TypeDef.cpp - TypeDef wrapper class --------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// TypeDef wrapper to simplify using TableGen Record defining a MLIR dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/TypeDef.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+Dialect TypeDef::getDialect() const {
+ auto *dialectDef =
+ dyn_cast<llvm::DefInit>(def->getValue("dialect")->getValue());
+ if (dialectDef == nullptr)
+ return Dialect(nullptr);
+ return Dialect(dialectDef->getDef());
+}
+
+StringRef TypeDef::getName() const { return def->getName(); }
+StringRef TypeDef::getCppClassName() const {
+ return def->getValueAsString("cppClassName");
+}
+
+bool TypeDef::hasDescription() const {
+ const llvm::RecordVal *s = def->getValue("description");
+ return s != nullptr && isa<llvm::StringInit>(s->getValue());
+}
+
+StringRef TypeDef::getDescription() const {
+ return def->getValueAsString("description");
+}
+
+bool TypeDef::hasSummary() const {
+ const llvm::RecordVal *s = def->getValue("summary");
+ return s != nullptr && isa<llvm::StringInit>(s->getValue());
+}
+
+StringRef TypeDef::getSummary() const {
+ return def->getValueAsString("summary");
+}
+
+StringRef TypeDef::getStorageClassName() const {
+ return def->getValueAsString("storageClass");
+}
+StringRef TypeDef::getStorageNamespace() const {
+ return def->getValueAsString("storageNamespace");
+}
+
+bool TypeDef::genStorageClass() const {
+ return def->getValueAsBit("genStorageClass");
+}
+bool TypeDef::hasStorageCustomConstructor() const {
+ return def->getValueAsBit("hasStorageCustomConstructor");
+}
+void TypeDef::getParameters(SmallVectorImpl<TypeParameter> ¶meters) const {
+ auto *parametersDag = def->getValueAsDag("parameters");
+ if (parametersDag != nullptr) {
+ size_t numParams = parametersDag->getNumArgs();
+ for (unsigned i = 0; i < numParams; i++)
+ parameters.push_back(TypeParameter(parametersDag, i));
+ }
+}
+unsigned TypeDef::getNumParameters() const {
+ auto *parametersDag = def->getValueAsDag("parameters");
+ return parametersDag ? parametersDag->getNumArgs() : 0;
+}
+llvm::Optional<StringRef> TypeDef::getMnemonic() const {
+ return def->getValueAsOptionalString("mnemonic");
+}
+llvm::Optional<StringRef> TypeDef::getPrinterCode() const {
+ return def->getValueAsOptionalCode("printer");
+}
+llvm::Optional<StringRef> TypeDef::getParserCode() const {
+ return def->getValueAsOptionalCode("parser");
+}
+bool TypeDef::genAccessors() const {
+ return def->getValueAsBit("genAccessors");
+}
+bool TypeDef::genVerifyInvariantsDecl() const {
+ return def->getValueAsBit("genVerifyInvariantsDecl");
+}
+llvm::Optional<StringRef> TypeDef::getExtraDecls() const {
+ auto value = def->getValueAsString("extraClassDeclaration");
+ return value.empty() ? llvm::Optional<StringRef>() : value;
+}
+llvm::ArrayRef<llvm::SMLoc> TypeDef::getLoc() const { return def->getLoc(); }
+bool TypeDef::operator==(const TypeDef &other) const {
+ return def == other.def;
+}
+
+bool TypeDef::operator<(const TypeDef &other) const {
+ return getName() < other.getName();
+}
+
+StringRef TypeParameter::getName() const {
+ return def->getArgName(num)->getValue();
+}
+llvm::Optional<StringRef> TypeParameter::getAllocator() const {
+ llvm::Init *parameterType = def->getArg(num);
+ if (auto *stringType = dyn_cast<llvm::StringInit>(parameterType))
+ return llvm::Optional<StringRef>();
+
+ if (auto *typeParameter = dyn_cast<llvm::DefInit>(parameterType)) {
+ llvm::RecordVal *code = typeParameter->getDef()->getValue("allocator");
+ if (llvm::CodeInit *ci = dyn_cast<llvm::CodeInit>(code->getValue()))
+ return ci->getValue();
+ if (isa<llvm::UnsetInit>(code->getValue()))
+ return llvm::Optional<StringRef>();
+
+ llvm::PrintFatalError(
+ typeParameter->getDef()->getLoc(),
+ "Record `" + def->getArgName(num)->getValue() +
+ "', field `printer' does not have a code initializer!");
+ }
+
+ llvm::PrintFatalError("Parameters DAG arguments must be either strings or "
+ "defs which inherit from TypeParameter\n");
+}
+StringRef TypeParameter::getCppType() const {
+ auto *parameterType = def->getArg(num);
+ if (auto *stringType = dyn_cast<llvm::StringInit>(parameterType))
+ return stringType->getValue();
+ if (auto *typeParameter = dyn_cast<llvm::DefInit>(parameterType))
+ return typeParameter->getDef()->getValueAsString("cppType");
+ llvm::PrintFatalError(
+ "Parameters DAG arguments must be either strings or defs "
+ "which inherit from TypeParameter\n");
+}
+llvm::Optional<StringRef> TypeParameter::getDescription() const {
+ auto *parameterType = def->getArg(num);
+ if (auto *typeParameter = dyn_cast<llvm::DefInit>(parameterType)) {
+ const auto *desc = typeParameter->getDef()->getValue("description");
+ if (llvm::StringInit *ci = dyn_cast<llvm::StringInit>(desc->getValue()))
+ return ci->getValue();
+ }
+ return llvm::Optional<StringRef>();
+}
+StringRef TypeParameter::getSyntax() const {
+ auto *parameterType = def->getArg(num);
+ if (auto *stringType = dyn_cast<llvm::StringInit>(parameterType))
+ return stringType->getValue();
+ if (auto *typeParameter = dyn_cast<llvm::DefInit>(parameterType)) {
+ const auto *syntax = typeParameter->getDef()->getValue("syntax");
+ if (syntax && isa<llvm::StringInit>(syntax->getValue()))
+ return dyn_cast<llvm::StringInit>(syntax->getValue())->getValue();
+ return getCppType();
+ }
+ llvm::PrintFatalError("Parameters DAG arguments must be either strings or "
+ "defs which inherit from TypeParameter");
+}
mlir_tablegen(TestTypeInterfaces.cpp.inc -gen-type-interface-defs)
add_public_tablegen_target(MLIRTestInterfaceIncGen)
+set(LLVM_TARGET_DEFINITIONS TestTypeDefs.td)
+mlir_tablegen(TestTypeDefs.h.inc -gen-typedef-decls)
+mlir_tablegen(TestTypeDefs.cpp.inc -gen-typedef-defs)
+add_public_tablegen_target(MLIRTestDefIncGen)
+
+
set(LLVM_TARGET_DEFINITIONS TestOps.td)
mlir_tablegen(TestOps.h.inc -gen-op-decls)
mlir_tablegen(TestOps.cpp.inc -gen-op-defs)
TestDialect.cpp
TestPatterns.cpp
TestTraits.cpp
+ TestTypes.cpp
EXCLUDE_FROM_LIBMLIR
DEPENDS
MLIRTestInterfaceIncGen
+ MLIRTestDefIncGen
MLIRTestOpsIncGen
LINK_LIBS PUBLIC
>();
addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
TestInlinerInterface>();
- addTypes<TestType, TestRecursiveType>();
+ addTypes<TestType, TestRecursiveType,
+#define GET_TYPEDEF_LIST
+#include "TestTypeDefs.cpp.inc"
+ >();
allowUnknownOperations();
}
-static Type parseTestType(DialectAsmParser &parser,
+static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser,
llvm::SetVector<Type> &stack) {
StringRef typeTag;
if (failed(parser.parseKeyword(&typeTag)))
return Type();
+ auto genType = generatedTypeParser(ctxt, parser, typeTag);
+ if (genType != Type())
+ return genType;
+
if (typeTag == "test_type")
return TestType::get(parser.getBuilder().getContext());
if (failed(parser.parseComma()))
return Type();
stack.insert(rec);
- Type subtype = parseTestType(parser, stack);
+ Type subtype = parseTestType(ctxt, parser, stack);
stack.pop_back();
if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype)))
return Type();
Type TestDialect::parseType(DialectAsmParser &parser) const {
llvm::SetVector<Type> stack;
- return parseTestType(parser, stack);
+ return parseTestType(getContext(), parser, stack);
}
static void printTestType(Type type, DialectAsmPrinter &printer,
llvm::SetVector<Type> &stack) {
+ if (succeeded(generatedTypePrinter(type, printer)))
+ return;
if (type.isa<TestType>()) {
printer << "test_type";
return;
--- /dev/null
+//===-- TestTypeDefs.td - Test dialect type definitions ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// TableGen data type definitions for Test dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TEST_TYPEDEFS
+#define TEST_TYPEDEFS
+
+// To get the test dialect def.
+include "TestOps.td"
+
+// All of the types will extend this class.
+class Test_Type<string name> : TypeDef<Test_Dialect, name> { }
+
+def SimpleTypeA : Test_Type<"SimpleA"> {
+ let mnemonic = "smpla";
+
+ let printer = [{ $_printer << "smpla"; }];
+ let parser = [{ return get($_ctxt); }];
+}
+
+// A more complex parameterized type.
+def CompoundTypeA : Test_Type<"CompoundA"> {
+ let mnemonic = "cmpnd_a";
+
+ // List of type parameters.
+ let parameters = (
+ ins
+ "int":$widthOfSomething,
+ "::mlir::Type":$oneType,
+ // This is special syntax since ArrayRefs require allocation in the
+ // constructor.
+ ArrayRefParameter<
+ "int", // The parameter C++ type.
+ "An example of an array of ints" // Parameter description.
+ >: $arrayOfInts
+ );
+
+ let extraClassDeclaration = [{
+ struct SomeCppStruct {};
+ }];
+}
+
+// An example of how one could implement a standard integer.
+def IntegerType : Test_Type<"TestInteger"> {
+ let mnemonic = "int";
+ let genVerifyInvariantsDecl = 1;
+ let parameters = (
+ ins
+ // SignednessSemantics is defined below.
+ "::mlir::TestIntegerType::SignednessSemantics":$signedness,
+ "unsigned":$width
+ );
+
+ // We define the printer inline.
+ let printer = [{
+ $_printer << "int<";
+ printSignedness($_printer, getImpl()->signedness);
+ $_printer << ", " << getImpl()->width << ">";
+ }];
+
+ // The parser is defined here also.
+ let parser = [{
+ if (parser.parseLess()) return Type();
+ SignednessSemantics signedness;
+ if (parseSignedness($_parser, signedness)) return mlir::Type();
+ if ($_parser.parseComma()) return Type();
+ int width;
+ if ($_parser.parseInteger(width)) return Type();
+ if ($_parser.parseGreater()) return Type();
+ return get(ctxt, signedness, width);
+ }];
+
+ // Any extra code one wants in the type's class declaration.
+ let extraClassDeclaration = [{
+ /// Signedness semantics.
+ enum SignednessSemantics {
+ Signless, /// No signedness semantics
+ Signed, /// Signed integer
+ Unsigned, /// Unsigned integer
+ };
+
+ /// This extra function is necessary since it doesn't include signedness
+ static IntegerType getChecked(unsigned width, Location location);
+
+ /// Return true if this is a signless integer type.
+ bool isSignless() const { return getSignedness() == Signless; }
+ /// Return true if this is a signed integer type.
+ bool isSigned() const { return getSignedness() == Signed; }
+ /// Return true if this is an unsigned integer type.
+ bool isUnsigned() const { return getSignedness() == Unsigned; }
+ }];
+}
+
+// A parent type for any type which is just a list of fields (e.g. structs,
+// unions).
+class FieldInfo_Type<string name> : Test_Type<name> {
+ let parameters = (
+ ins
+ // An ArrayRef of something which requires allocation in the storage
+ // constructor.
+ ArrayRefOfSelfAllocationParameter<
+ "::mlir::FieldInfo", // FieldInfo is defined/declared in TestTypes.h.
+ "Models struct fields">: $fields
+ );
+
+ // Prints the type in this format:
+ // struct<[{field1Name, field1Type}, {field2Name, field2Type}]
+ let printer = [{
+ $_printer << "struct" << "<";
+ for (size_t i=0, e = getImpl()->fields.size(); i < e; i++) {
+ const auto& field = getImpl()->fields[i];
+ $_printer << "{" << field.name << "," << field.type << "}";
+ if (i < getImpl()->fields.size() - 1)
+ $_printer << ",";
+ }
+ $_printer << ">";
+ }];
+
+ // Parses the above format
+ let parser = [{
+ llvm::SmallVector<FieldInfo, 4> parameters;
+ if ($_parser.parseLess()) return Type();
+ while (mlir::succeeded($_parser.parseOptionalLBrace())) {
+ StringRef name;
+ if ($_parser.parseKeyword(&name)) return Type();
+ if ($_parser.parseComma()) return Type();
+ Type type;
+ if ($_parser.parseType(type)) return Type();
+ if ($_parser.parseRBrace()) return Type();
+ parameters.push_back(FieldInfo {name, type});
+ if ($_parser.parseOptionalComma()) break;
+ }
+ if ($_parser.parseGreater()) return Type();
+ return get($_ctxt, parameters);
+ }];
+}
+
+def StructType : FieldInfo_Type<"Struct"> {
+ let mnemonic = "struct";
+}
+
+#endif // TEST_TYPEDEFS
--- /dev/null
+//===- TestTypes.cpp - MLIR Test Dialect Types ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains types defined by the TestDialect for testing various
+// features of MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestTypes.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/Types.h"
+#include "llvm/ADT/Hashing.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+
+// Custom parser for SignednessSemantics.
+static ParseResult
+parseSignedness(DialectAsmParser &parser,
+ TestIntegerType::SignednessSemantics &result) {
+ StringRef signStr;
+ auto loc = parser.getCurrentLocation();
+ if (parser.parseKeyword(&signStr))
+ return failure();
+ if (signStr.compare_lower("u") || signStr.compare_lower("unsigned"))
+ result = TestIntegerType::SignednessSemantics::Unsigned;
+ else if (signStr.compare_lower("s") || signStr.compare_lower("signed"))
+ result = TestIntegerType::SignednessSemantics::Signed;
+ else if (signStr.compare_lower("n") || signStr.compare_lower("none"))
+ result = TestIntegerType::SignednessSemantics::Signless;
+ else
+ return parser.emitError(loc, "expected signed, unsigned, or none");
+ return success();
+}
+
+// Custom printer for SignednessSemantics.
+static void printSignedness(DialectAsmPrinter &printer,
+ const TestIntegerType::SignednessSemantics &ss) {
+ switch (ss) {
+ case TestIntegerType::SignednessSemantics::Unsigned:
+ printer << "unsigned";
+ break;
+ case TestIntegerType::SignednessSemantics::Signed:
+ printer << "signed";
+ break;
+ case TestIntegerType::SignednessSemantics::Signless:
+ printer << "none";
+ break;
+ }
+}
+
+Type CompoundAType::parse(MLIRContext *ctxt, DialectAsmParser &parser) {
+ int widthOfSomething;
+ Type oneType;
+ SmallVector<int, 4> arrayOfInts;
+ if (parser.parseLess() || parser.parseInteger(widthOfSomething) ||
+ parser.parseComma() || parser.parseType(oneType) || parser.parseComma() ||
+ parser.parseLSquare())
+ return Type();
+
+ int i;
+ while (!*parser.parseOptionalInteger(i)) {
+ arrayOfInts.push_back(i);
+ if (parser.parseOptionalComma())
+ break;
+ }
+
+ if (parser.parseRSquare() || parser.parseGreater())
+ return Type();
+
+ return get(ctxt, widthOfSomething, oneType, arrayOfInts);
+}
+void CompoundAType::print(DialectAsmPrinter &printer) const {
+ printer << "cmpnd_a<" << getWidthOfSomething() << ", " << getOneType()
+ << ", [";
+ auto intArray = getArrayOfInts();
+ llvm::interleaveComma(intArray, printer);
+ printer << "]>";
+}
+
+// The functions don't need to be in the header file, but need to be in the mlir
+// namespace. Declare them here, then define them immediately below. Separating
+// the declaration and definition adheres to the LLVM coding standards.
+namespace mlir {
+// FieldInfo is used as part of a parameter, so equality comparison is
+// compulsory.
+static bool operator==(const FieldInfo &a, const FieldInfo &b);
+// FieldInfo is used as part of a parameter, so a hash will be computed.
+static llvm::hash_code hash_value(const FieldInfo &fi); // NOLINT
+} // namespace mlir
+
+// FieldInfo is used as part of a parameter, so equality comparison is
+// compulsory.
+static bool mlir::operator==(const FieldInfo &a, const FieldInfo &b) {
+ return a.name == b.name && a.type == b.type;
+}
+
+// FieldInfo is used as part of a parameter, so a hash will be computed.
+static llvm::hash_code mlir::hash_value(const FieldInfo &fi) { // NOLINT
+ return llvm::hash_combine(fi.name, fi.type);
+}
+
+// Example type validity checker.
+LogicalResult TestIntegerType::verifyConstructionInvariants(
+ Location loc, TestIntegerType::SignednessSemantics ss, unsigned int width) {
+ if (width > 8)
+ return failure();
+ return success();
+}
+
+#define GET_TYPEDEF_CLASSES
+#include "TestTypeDefs.cpp.inc"
#ifndef MLIR_TESTTYPES_H
#define MLIR_TESTTYPES_H
+#include <tuple>
+
#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
namespace mlir {
+/// FieldInfo represents a field in the StructType data type. It is used as a
+/// parameter in TestTypeDefs.td.
+struct FieldInfo {
+ StringRef name;
+ Type type;
+
+ // Custom allocation called from generated constructor code
+ FieldInfo allocateInto(TypeStorageAllocator &alloc) const {
+ return FieldInfo{alloc.copyInto(name), type};
+ }
+};
+
+} // namespace mlir
+
+#define GET_TYPEDEF_CLASSES
+#include "TestTypeDefs.h.inc"
+
+namespace mlir {
+
#include "TestTypeInterfaces.h.inc"
/// This class is a simple test type that uses a generated interface.
--- /dev/null
+// RUN: mlir-opt %s | mlir-opt -verify-diagnostics | FileCheck %s
+
+//////////////
+// Tests the types in the 'Test' dialect, not the ones in 'typedefs.mlir'
+
+// CHECK: @simpleA(%arg0: !test.smpla)
+func @simpleA(%A : !test.smpla) -> () {
+ return
+}
+
+// CHECK: @compoundA(%arg0: !test.cmpnd_a<1, !test.smpla, [5, 6]>)
+func @compoundA(%A : !test.cmpnd_a<1, !test.smpla, [5, 6]>)-> () {
+ return
+}
+
+// CHECK: @testInt(%arg0: !test.int<unsigned, 8>, %arg1: !test.int<unsigned, 2>, %arg2: !test.int<unsigned, 1>)
+func @testInt(%A : !test.int<s, 8>, %B : !test.int<unsigned, 2>, %C : !test.int<n, 1>) {
+ return
+}
+
+// CHECK: @structTest(%arg0: !test.struct<{field1,!test.smpla},{field2,!test.int<unsigned, 3>}>)
+func @structTest (%A : !test.struct< {field1, !test.smpla}, {field2, !test.int<none, 3>} > ) {
+ return
+}
--- /dev/null
+// RUN: mlir-tblgen -gen-typedef-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL
+// RUN: mlir-tblgen -gen-typedef-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEF
+
+include "mlir/IR/OpBase.td"
+
+// DECL: #ifdef GET_TYPEDEF_CLASSES
+// DECL: #undef GET_TYPEDEF_CLASSES
+
+// DECL: ::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser, ::llvm::StringRef mnenomic);
+// DECL: ::mlir::LogicalResult generatedTypePrinter(::mlir::Type type, ::mlir::DialectAsmPrinter& printer);
+
+// DEF: #ifdef GET_TYPEDEF_LIST
+// DEF: #undef GET_TYPEDEF_LIST
+// DEF: ::mlir::test::SimpleAType,
+// DEF: ::mlir::test::CompoundAType,
+// DEF: ::mlir::test::IndexType,
+// DEF: ::mlir::test::SingleParameterType,
+// DEF: ::mlir::test::IntegerType
+
+// DEF-LABEL: ::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser, ::llvm::StringRef mnemonic)
+// DEF: if (mnemonic == ::mlir::test::CompoundAType::getMnemonic()) return ::mlir::test::CompoundAType::parse(ctxt, parser);
+// DEF return ::mlir::Type();
+
+def Test_Dialect: Dialect {
+// DECL-NOT: TestDialect
+// DEF-NOT: TestDialect
+ let name = "TestDialect";
+ let cppNamespace = "::mlir::test";
+}
+
+class TestType<string name> : TypeDef<Test_Dialect, name> { }
+
+def A_SimpleTypeA : TestType<"SimpleA"> {
+// DECL: class SimpleAType: public ::mlir::Type
+}
+
+// A more complex parameterized type
+def B_CompoundTypeA : TestType<"CompoundA"> {
+ let summary = "A more complex parameterized type";
+ let description = "This type is to test a reasonably complex type";
+ let mnemonic = "cmpnd_a";
+ let parameters = (
+ ins
+ "int":$widthOfSomething,
+ "::mlir::test::SimpleTypeA": $exampleTdType,
+ "SomeCppStruct": $exampleCppType,
+ ArrayRefParameter<"int", "Matrix dimensions">:$dims
+ );
+
+ let genVerifyInvariantsDecl = 1;
+
+// DECL-LABEL: class CompoundAType: public ::mlir::Type
+// DECL: static ::mlir::LogicalResult verifyConstructionInvariants(Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims);
+// DECL: static CompoundAType getChecked(Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims);
+// DECL: static ::llvm::StringRef getMnemonic() { return "cmpnd_a"; }
+// DECL: static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser);
+// DECL: void print(::mlir::DialectAsmPrinter& printer) const;
+// DECL: int getWidthOfSomething() const;
+// DECL: ::mlir::test::SimpleTypeA getExampleTdType() const;
+// DECL: SomeCppStruct getExampleCppType() const;
+}
+
+def C_IndexType : TestType<"Index"> {
+ let mnemonic = "index";
+
+ let parameters = (
+ ins
+ StringRefParameter<"Label for index">:$label
+ );
+
+// DECL-LABEL: class IndexType: public ::mlir::Type
+// DECL: static ::llvm::StringRef getMnemonic() { return "index"; }
+// DECL: static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser);
+// DECL: void print(::mlir::DialectAsmPrinter& printer) const;
+}
+
+def D_SingleParameterType : TestType<"SingleParameter"> {
+ let parameters = (
+ ins
+ "int": $num
+ );
+// DECL-LABEL: struct SingleParameterTypeStorage;
+// DECL-LABEL: class SingleParameterType
+// DECL-NEXT: detail::SingleParameterTypeStorage
+}
+
+def E_IntegerType : TestType<"Integer"> {
+ let mnemonic = "int";
+ let genVerifyInvariantsDecl = 1;
+ let parameters = (
+ ins
+ "SignednessSemantics":$signedness,
+ TypeParameter<"unsigned", "Bitwdith of integer">:$width
+ );
+
+// DECL-LABEL: IntegerType: public ::mlir::Type
+
+ let extraClassDeclaration = [{
+ /// Signedness semantics.
+ enum SignednessSemantics {
+ Signless, /// No signedness semantics
+ Signed, /// Signed integer
+ Unsigned, /// Unsigned integer
+ };
+
+ /// This extra function is necessary since it doesn't include signedness
+ static IntegerType getChecked(unsigned width, Location location);
+
+ /// Return true if this is a signless integer type.
+ bool isSignless() const { return getSignedness() == Signless; }
+ /// Return true if this is a signed integer type.
+ bool isSigned() const { return getSignedness() == Signed; }
+ /// Return true if this is an unsigned integer type.
+ bool isUnsigned() const { return getSignedness() == Unsigned; }
+ }];
+
+// DECL: /// Signedness semantics.
+// DECL-NEXT: enum SignednessSemantics {
+// DECL-NEXT: Signless, /// No signedness semantics
+// DECL-NEXT: Signed, /// Signed integer
+// DECL-NEXT: Unsigned, /// Unsigned integer
+// DECL-NEXT: };
+// DECL: /// This extra function is necessary since it doesn't include signedness
+// DECL-NEXT: static IntegerType getChecked(unsigned width, Location location);
+
+// DECL: /// Return true if this is a signless integer type.
+// DECL-NEXT: bool isSignless() const { return getSignedness() == Signless; }
+// DECL-NEXT: /// Return true if this is a signed integer type.
+// DECL-NEXT: bool isSigned() const { return getSignedness() == Signed; }
+// DECL-NEXT: /// Return true if this is an unsigned integer type.
+// DECL-NEXT: bool isUnsigned() const { return getSignedness() == Unsigned; }
+}
RewriterGen.cpp
SPIRVUtilsGen.cpp
StructsGen.cpp
+ TypeDefGen.cpp
)
set_target_properties(mlir-tblgen PROPERTIES FOLDER "Tablegenning")
#include "mlir/Support/IndentedOstream.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
+#include "mlir/TableGen/TypeDef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
+#include <set>
+
using namespace llvm;
using namespace mlir;
using namespace mlir::tblgen;
}
//===----------------------------------------------------------------------===//
+// TypeDef Documentation
+//===----------------------------------------------------------------------===//
+
+/// Emit the assembly format of a type.
+static void emitTypeAssemblyFormat(TypeDef td, raw_ostream &os) {
+ SmallVector<TypeParameter, 4> parameters;
+ td.getParameters(parameters);
+ if (parameters.size() == 0) {
+ os << "\nSyntax: `!" << td.getDialect().getName() << "." << td.getMnemonic()
+ << "`\n";
+ return;
+ }
+
+ os << "\nSyntax:\n\n```\n!" << td.getDialect().getName() << "."
+ << td.getMnemonic() << "<\n";
+ for (auto *it = parameters.begin(), *e = parameters.end(); it < e; ++it) {
+ os << " " << it->getSyntax();
+ if (it < parameters.end() - 1)
+ os << ",";
+ os << " # " << it->getName() << "\n";
+ }
+ os << ">\n```\n";
+}
+
+static void emitTypeDefDoc(TypeDef td, raw_ostream &os) {
+ os << llvm::formatv("### `{0}` ({1})\n", td.getName(), td.getCppClassName());
+
+ // Emit the summary, syntax, and description if present.
+ if (td.hasSummary())
+ os << "\n" << td.getSummary() << "\n";
+ if (td.getMnemonic() && td.getPrinterCode() && *td.getPrinterCode() == "" &&
+ td.getParserCode() && *td.getParserCode() == "")
+ emitTypeAssemblyFormat(td, os);
+ if (td.hasDescription())
+ mlir::tblgen::emitDescription(td.getDescription(), os);
+
+ // Emit attribute documentation.
+ SmallVector<TypeParameter, 4> parameters;
+ td.getParameters(parameters);
+ if (parameters.size() != 0) {
+ os << "\n#### Type parameters:\n\n";
+ os << "| Parameter | C++ type | Description |\n"
+ << "| :-------: | :-------: | ----------- |\n";
+ for (const auto &it : parameters) {
+ auto desc = it.getDescription();
+ os << "| " << it.getName() << " | `" << td.getCppClassName() << "` | "
+ << (desc ? *desc : "") << " |\n";
+ }
+ }
+
+ os << "\n";
+}
+
+//===----------------------------------------------------------------------===//
// Dialect Documentation
//===----------------------------------------------------------------------===//
static void emitDialectDoc(const Dialect &dialect, ArrayRef<Operator> ops,
- ArrayRef<Type> types, raw_ostream &os) {
+ ArrayRef<Type> types, ArrayRef<TypeDef> typeDefs,
+ raw_ostream &os) {
os << "# '" << dialect.getName() << "' Dialect\n\n";
emitIfNotEmpty(dialect.getSummary(), os);
emitIfNotEmpty(dialect.getDescription(), os);
// TODO: Add link between use and def for types
if (!types.empty()) {
- os << "## Type definition\n\n";
+ os << "## Type constraint definition\n\n";
for (const Type &type : types)
emitTypeDoc(type, os);
}
for (const Operator &op : ops)
emitOpDoc(op, os);
}
+
+ if (!typeDefs.empty()) {
+ os << "## Type definition\n\n";
+ for (const TypeDef &td : typeDefs)
+ emitTypeDefDoc(td, os);
+ }
}
static void emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
const auto &opDefs = recordKeeper.getAllDerivedDefinitions("Op");
const auto &typeDefs = recordKeeper.getAllDerivedDefinitions("DialectType");
+ const auto &typeDefDefs = recordKeeper.getAllDerivedDefinitions("TypeDef");
+ std::set<Dialect> dialectsWithDocs;
std::map<Dialect, std::vector<Operator>> dialectOps;
std::map<Dialect, std::vector<Type>> dialectTypes;
+ std::map<Dialect, std::vector<TypeDef>> dialectTypeDefs;
for (auto *opDef : opDefs) {
Operator op(opDef);
dialectOps[op.getDialect()].push_back(op);
+ dialectsWithDocs.insert(op.getDialect());
}
for (auto *typeDef : typeDefs) {
Type type(typeDef);
if (auto dialect = type.getDialect())
dialectTypes[dialect].push_back(type);
}
+ for (auto *typeDef : typeDefDefs) {
+ TypeDef type(typeDef);
+ dialectTypeDefs[type.getDialect()].push_back(type);
+ dialectsWithDocs.insert(type.getDialect());
+ }
os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
- for (const auto &dialectWithOps : dialectOps)
- emitDialectDoc(dialectWithOps.first, dialectWithOps.second,
- dialectTypes[dialectWithOps.first], os);
+ for (auto dialect : dialectsWithDocs)
+ emitDialectDoc(dialect, dialectOps[dialect], dialectTypes[dialect],
+ dialectTypeDefs[dialect], os);
}
//===----------------------------------------------------------------------===//
--- /dev/null
+//===- TypeDefGen.cpp - MLIR typeDef definitions generator ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// TypeDefGen uses the description of typeDefs to generate C++ definitions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/TableGen/CodeGenHelpers.h"
+#include "mlir/TableGen/Format.h"
+#include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/TypeDef.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/TableGenBackend.h"
+
+#define DEBUG_TYPE "mlir-tblgen-typedefgen"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+static llvm::cl::OptionCategory typedefGenCat("Options for -gen-typedef-*");
+static llvm::cl::opt<std::string>
+ selectedDialect("typedefs-dialect",
+ llvm::cl::desc("Gen types for this dialect"),
+ llvm::cl::cat(typedefGenCat), llvm::cl::CommaSeparated);
+
+/// Find all the TypeDefs for the specified dialect. If no dialect specified and
+/// can only find one dialect's types, use that.
+static void findAllTypeDefs(const llvm::RecordKeeper &recordKeeper,
+ SmallVectorImpl<TypeDef> &typeDefs) {
+ auto recDefs = recordKeeper.getAllDerivedDefinitions("TypeDef");
+ auto defs = llvm::map_range(
+ recDefs, [&](const llvm::Record *rec) { return TypeDef(rec); });
+ if (defs.empty())
+ return;
+
+ StringRef dialectName;
+ if (selectedDialect.getNumOccurrences() == 0) {
+ if (defs.empty())
+ return;
+
+ llvm::SmallSet<Dialect, 4> dialects;
+ for (const TypeDef &typeDef : defs)
+ dialects.insert(typeDef.getDialect());
+ if (dialects.size() != 1)
+ llvm::PrintFatalError("TypeDefs belonging to more than one dialect. Must "
+ "select one via '--typedefs-dialect'");
+
+ dialectName = (*dialects.begin()).getName();
+ } else if (selectedDialect.getNumOccurrences() == 1) {
+ dialectName = selectedDialect.getValue();
+ } else {
+ llvm::PrintFatalError("Cannot select multiple dialects for which to "
+ "generate types via '--typedefs-dialect'.");
+ }
+
+ for (const TypeDef &typeDef : defs)
+ if (typeDef.getDialect().getName().equals(dialectName))
+ typeDefs.push_back(typeDef);
+}
+
+namespace {
+
+/// Pass an instance of this class to llvm::formatv() to emit a comma separated
+/// list of parameters in the format by 'EmitFormat'.
+class TypeParamCommaFormatter : public llvm::detail::format_adapter {
+public:
+ /// Choose the output format
+ enum EmitFormat {
+ /// Emit "parameter1Type parameter1Name, parameter2Type parameter2Name,
+ /// [...]".
+ TypeNamePairs,
+
+ /// Emit ", parameter1Type parameter1Name, parameter2Type parameter2Name,
+ /// [...]".
+ TypeNamePairsPrependComma,
+
+ /// Emit "parameter1(parameter1), parameter2(parameter2), [...]".
+ TypeNameInitializer
+ };
+
+ TypeParamCommaFormatter(EmitFormat emitFormat, ArrayRef<TypeParameter> params)
+ : emitFormat(emitFormat), params(params) {}
+
+ /// llvm::formatv will call this function when using an instance as a
+ /// replacement value.
+ void format(raw_ostream &os, StringRef options) {
+ if (params.size() && emitFormat == EmitFormat::TypeNamePairsPrependComma)
+ os << ", ";
+ switch (emitFormat) {
+ case EmitFormat::TypeNamePairs:
+ case EmitFormat::TypeNamePairsPrependComma:
+ interleaveComma(params, os,
+ [&](const TypeParameter &p) { emitTypeNamePair(p, os); });
+ break;
+ case EmitFormat::TypeNameInitializer:
+ interleaveComma(params, os, [&](const TypeParameter &p) {
+ emitTypeNameInitializer(p, os);
+ });
+ break;
+ }
+ }
+
+private:
+ // Emit "paramType paramName".
+ static void emitTypeNamePair(const TypeParameter ¶m, raw_ostream &os) {
+ os << param.getCppType() << " " << param.getName();
+ }
+ // Emit "paramName(paramName)"
+ void emitTypeNameInitializer(const TypeParameter ¶m, raw_ostream &os) {
+ os << param.getName() << "(" << param.getName() << ")";
+ }
+
+ EmitFormat emitFormat;
+ ArrayRef<TypeParameter> params;
+};
+
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// GEN: TypeDef declarations
+//===----------------------------------------------------------------------===//
+
+/// The code block for the start of a typeDef class declaration -- singleton
+/// case.
+///
+/// {0}: The name of the typeDef class.
+static const char *const typeDefDeclSingletonBeginStr = R"(
+ class {0}: public ::mlir::Type::TypeBase<{0}, ::mlir::Type, ::mlir::TypeStorage> {{
+ public:
+ /// Inherit some necessary constructors from 'TypeBase'.
+ using Base::Base;
+
+)";
+
+/// The code block for the start of a typeDef class declaration -- parametric
+/// case.
+///
+/// {0}: The name of the typeDef class.
+/// {1}: The typeDef storage class namespace.
+/// {2}: The storage class name.
+/// {3}: The list of parameters with types.
+static const char *const typeDefDeclParametricBeginStr = R"(
+ namespace {1} {
+ struct {2};
+ }
+ class {0}: public ::mlir::Type::TypeBase<{0}, ::mlir::Type,
+ {1}::{2}> {{
+ public:
+ /// Inherit some necessary constructors from 'TypeBase'.
+ using Base::Base;
+
+)";
+
+/// The snippet for print/parse.
+static const char *const typeDefParsePrint = R"(
+ static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser);
+ void print(::mlir::DialectAsmPrinter& printer) const;
+)";
+
+/// The code block for the verifyConstructionInvariants and getChecked.
+///
+/// {0}: List of parameters, parameters style.
+/// {1}: C++ type class name.
+static const char *const typeDefDeclVerifyStr = R"(
+ static ::mlir::LogicalResult verifyConstructionInvariants(Location loc{0});
+ static {1} getChecked(Location loc{0});
+)";
+
+/// Generate the declaration for the given typeDef class.
+static void emitTypeDefDecl(const TypeDef &typeDef, raw_ostream &os) {
+ SmallVector<TypeParameter, 4> params;
+ typeDef.getParameters(params);
+
+ // Emit the beginning string template: either the singleton or parametric
+ // template.
+ if (typeDef.getNumParameters() == 0)
+ os << formatv(typeDefDeclSingletonBeginStr, typeDef.getCppClassName(),
+ typeDef.getStorageNamespace(), typeDef.getStorageClassName());
+ else
+ os << formatv(typeDefDeclParametricBeginStr, typeDef.getCppClassName(),
+ typeDef.getStorageNamespace(), typeDef.getStorageClassName());
+
+ // Emit the extra declarations first in case there's a type definition in
+ // there.
+ if (Optional<StringRef> extraDecl = typeDef.getExtraDecls())
+ os << *extraDecl << "\n";
+
+ TypeParamCommaFormatter emitTypeNamePairsAfterComma(
+ TypeParamCommaFormatter::EmitFormat::TypeNamePairsPrependComma, params);
+ os << llvm::formatv(" static {0} get(::mlir::MLIRContext* ctxt{1});\n",
+ typeDef.getCppClassName(), emitTypeNamePairsAfterComma);
+
+ // Emit the verify invariants declaration.
+ if (typeDef.genVerifyInvariantsDecl())
+ os << llvm::formatv(typeDefDeclVerifyStr, emitTypeNamePairsAfterComma,
+ typeDef.getCppClassName());
+
+ // Emit the mnenomic, if specified.
+ if (auto mnenomic = typeDef.getMnemonic()) {
+ os << " static ::llvm::StringRef getMnemonic() { return \"" << mnenomic
+ << "\"; }\n";
+
+ // If mnemonic specified, emit print/parse declarations.
+ os << typeDefParsePrint;
+ }
+
+ if (typeDef.genAccessors()) {
+ SmallVector<TypeParameter, 4> parameters;
+ typeDef.getParameters(parameters);
+
+ for (TypeParameter ¶meter : parameters) {
+ SmallString<16> name = parameter.getName();
+ name[0] = llvm::toUpper(name[0]);
+ os << formatv(" {0} get{1}() const;\n", parameter.getCppType(), name);
+ }
+ }
+
+ // End the typeDef decl.
+ os << " };\n";
+}
+
+/// Main entry point for decls.
+static bool emitTypeDefDecls(const llvm::RecordKeeper &recordKeeper,
+ raw_ostream &os) {
+ emitSourceFileHeader("TypeDef Declarations", os);
+
+ SmallVector<TypeDef, 16> typeDefs;
+ findAllTypeDefs(recordKeeper, typeDefs);
+
+ IfDefScope scope("GET_TYPEDEF_CLASSES", os);
+ if (typeDefs.size() > 0) {
+ NamespaceEmitter nsEmitter(os, typeDefs.begin()->getDialect());
+
+ // Well known print/parse dispatch function declarations. These are called
+ // from Dialect::parseType() and Dialect::printType() methods.
+ os << " ::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, "
+ "::mlir::DialectAsmParser& parser, ::llvm::StringRef mnenomic);\n";
+ os << " ::mlir::LogicalResult generatedTypePrinter(::mlir::Type type, "
+ "::mlir::DialectAsmPrinter& printer);\n";
+ os << "\n";
+
+ // Declare all the type classes first (in case they reference each other).
+ for (const TypeDef &typeDef : typeDefs)
+ os << " class " << typeDef.getCppClassName() << ";\n";
+
+ // Declare all the typedefs.
+ for (const TypeDef &typeDef : typeDefs)
+ emitTypeDefDecl(typeDef, os);
+ }
+
+ return false;
+}
+
+//===----------------------------------------------------------------------===//
+// GEN: TypeDef list
+//===----------------------------------------------------------------------===//
+
+static void emitTypeDefList(SmallVectorImpl<TypeDef> &typeDefs,
+ raw_ostream &os) {
+ IfDefScope scope("GET_TYPEDEF_LIST", os);
+ for (auto *i = typeDefs.begin(); i != typeDefs.end(); i++) {
+ os << i->getDialect().getCppNamespace() << "::" << i->getCppClassName();
+ if (i < typeDefs.end() - 1)
+ os << ",\n";
+ else
+ os << "\n";
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// GEN: TypeDef definitions
+//===----------------------------------------------------------------------===//
+
+/// Beginning of storage class.
+/// {0}: Storage class namespace.
+/// {1}: Storage class c++ name.
+/// {2}: Parameters parameters.
+/// {3}: Parameter initialzer string.
+/// {4}: Parameter name list.
+/// {5}: Parameter types.
+static const char *const typeDefStorageClassBegin = R"(
+namespace {0} {{
+ struct {1} : public ::mlir::TypeStorage {{
+ {1} ({2})
+ : {3} {{ }
+
+ /// The hash key for this storage is a pair of the integer and type params.
+ using KeyTy = std::tuple<{5}>;
+
+ /// Define the comparison function for the key type.
+ bool operator==(const KeyTy &key) const {{
+ return key == KeyTy({4});
+ }
+)";
+
+/// The storage class' constructor template.
+/// {0}: storage class name.
+static const char *const typeDefStorageClassConstructorBegin = R"(
+ /// Define a construction method for creating a new instance of this storage.
+ static {0} *construct(::mlir::TypeStorageAllocator &allocator, const KeyTy &key) {{
+)";
+
+/// The storage class' constructor return template.
+/// {0}: storage class name.
+/// {1}: list of parameters.
+static const char *const typeDefStorageClassConstructorReturn = R"(
+ return new (allocator.allocate<{0}>())
+ {0}({1});
+ }
+)";
+
+/// Use tgfmt to emit custom allocation code for each parameter, if necessary.
+static void emitParameterAllocationCode(TypeDef &typeDef, raw_ostream &os) {
+ SmallVector<TypeParameter, 4> parameters;
+ typeDef.getParameters(parameters);
+ auto fmtCtxt = FmtContext().addSubst("_allocator", "allocator");
+ for (TypeParameter ¶meter : parameters) {
+ auto allocCode = parameter.getAllocator();
+ if (allocCode) {
+ fmtCtxt.withSelf(parameter.getName());
+ fmtCtxt.addSubst("_dst", parameter.getName());
+ os << " " << tgfmt(*allocCode, &fmtCtxt) << "\n";
+ }
+ }
+}
+
+/// Emit the storage class code for type 'typeDef'.
+/// This includes (in-order):
+/// 1) typeDefStorageClassBegin, which includes:
+/// - The class constructor.
+/// - The KeyTy definition.
+/// - The equality (==) operator.
+/// 2) The hashKey method.
+/// 3) The construct method.
+/// 4) The list of parameters as the storage class member variables.
+static void emitStorageClass(TypeDef typeDef, raw_ostream &os) {
+ SmallVector<TypeParameter, 4> parameters;
+ typeDef.getParameters(parameters);
+
+ // Initialize a bunch of variables to be used later on.
+ auto parameterNames = map_range(
+ parameters, [](TypeParameter parameter) { return parameter.getName(); });
+ auto parameterTypes = map_range(parameters, [](TypeParameter parameter) {
+ return parameter.getCppType();
+ });
+ auto parameterList = join(parameterNames, ", ");
+ auto parameterTypeList = join(parameterTypes, ", ");
+
+ // 1) Emit most of the storage class up until the hashKey body.
+ os << formatv(
+ typeDefStorageClassBegin, typeDef.getStorageNamespace(),
+ typeDef.getStorageClassName(),
+ TypeParamCommaFormatter(
+ TypeParamCommaFormatter::EmitFormat::TypeNamePairs, parameters),
+ TypeParamCommaFormatter(
+ TypeParamCommaFormatter::EmitFormat::TypeNameInitializer, parameters),
+ parameterList, parameterTypeList);
+
+ // 2) Emit the haskKey method.
+ os << " static ::llvm::hash_code hashKey(const KeyTy &key) {\n";
+ // Extract each parameter from the key.
+ for (size_t i = 0, e = parameters.size(); i < e; ++i)
+ os << formatv(" const auto &{0} = std::get<{1}>(key);\n",
+ parameters[i].getName(), i);
+ // Then combine them all. This requires all the parameters types to have a
+ // hash_value defined.
+ os << " return ::llvm::hash_combine(";
+ interleaveComma(parameterNames, os);
+ os << ");\n";
+ os << " }\n";
+
+ // 3) Emit the construct method.
+ if (typeDef.hasStorageCustomConstructor())
+ // If user wants to build the storage constructor themselves, declare it
+ // here and then they can write the definition elsewhere.
+ os << " static " << typeDef.getStorageClassName()
+ << " *construct(::mlir::TypeStorageAllocator &allocator, const KeyTy "
+ "&key);\n";
+ else {
+ // If not, autogenerate one.
+
+ // First, unbox the parameters.
+ os << formatv(typeDefStorageClassConstructorBegin,
+ typeDef.getStorageClassName());
+ for (size_t i = 0; i < parameters.size(); ++i) {
+ os << formatv(" auto {0} = std::get<{1}>(key);\n",
+ parameters[i].getName(), i);
+ }
+ // Second, reassign the parameter variables with allocation code, if it's
+ // specified.
+ emitParameterAllocationCode(typeDef, os);
+
+ // Last, return an allocated copy.
+ os << formatv(typeDefStorageClassConstructorReturn,
+ typeDef.getStorageClassName(), parameterList);
+ }
+
+ // 4) Emit the parameters as storage class members.
+ for (auto parameter : parameters) {
+ os << " " << parameter.getCppType() << " " << parameter.getName()
+ << ";\n";
+ }
+ os << " };\n";
+
+ os << "} // namespace " << typeDef.getStorageNamespace() << "\n";
+}
+
+/// Emit the parser and printer for a particular type, if they're specified.
+void emitParserPrinter(TypeDef typeDef, raw_ostream &os) {
+ // Emit the printer code, if specified.
+ if (auto printerCode = typeDef.getPrinterCode()) {
+ // Both the mnenomic and printerCode must be defined (for parity with
+ // parserCode).
+ os << "void " << typeDef.getCppClassName()
+ << "::print(::mlir::DialectAsmPrinter& printer) const {\n";
+ if (*printerCode == "") {
+ // If no code specified, emit error.
+ PrintFatalError(typeDef.getLoc(),
+ typeDef.getName() +
+ ": printer (if specified) must have non-empty code");
+ }
+ auto fmtCtxt = FmtContext().addSubst("_printer", "printer");
+ os << tgfmt(*printerCode, &fmtCtxt) << "\n}\n";
+ }
+
+ // emit a parser, if specified.
+ if (auto parserCode = typeDef.getParserCode()) {
+ // The mnenomic must be defined so the dispatcher knows how to dispatch.
+ os << "::mlir::Type " << typeDef.getCppClassName()
+ << "::parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& "
+ "parser) "
+ "{\n";
+ if (*parserCode == "") {
+ // if no code specified, emit error.
+ PrintFatalError(typeDef.getLoc(),
+ typeDef.getName() +
+ ": parser (if specified) must have non-empty code");
+ }
+ auto fmtCtxt =
+ FmtContext().addSubst("_parser", "parser").addSubst("_ctxt", "ctxt");
+ os << tgfmt(*parserCode, &fmtCtxt) << "\n}\n";
+ }
+}
+
+/// Print all the typedef-specific definition code.
+static void emitTypeDefDef(TypeDef typeDef, raw_ostream &os) {
+ NamespaceEmitter ns(os, typeDef.getDialect());
+ SmallVector<TypeParameter, 4> parameters;
+ typeDef.getParameters(parameters);
+
+ // Emit the storage class, if requested and necessary.
+ if (typeDef.genStorageClass() && typeDef.getNumParameters() > 0)
+ emitStorageClass(typeDef, os);
+
+ os << llvm::formatv(
+ "{0} {0}::get(::mlir::MLIRContext* ctxt{1}) {{\n"
+ " return Base::get(ctxt",
+ typeDef.getCppClassName(),
+ TypeParamCommaFormatter(
+ TypeParamCommaFormatter::EmitFormat::TypeNamePairsPrependComma,
+ parameters));
+ for (TypeParameter ¶m : parameters)
+ os << ", " << param.getName();
+ os << ");\n}\n";
+
+ // Emit the parameter accessors.
+ if (typeDef.genAccessors())
+ for (const TypeParameter ¶meter : parameters) {
+ SmallString<16> name = parameter.getName();
+ name[0] = llvm::toUpper(name[0]);
+ os << formatv("{0} {3}::get{1}() const { return getImpl()->{2}; }\n",
+ parameter.getCppType(), name, parameter.getName(),
+ typeDef.getCppClassName());
+ }
+
+ // If mnemonic is specified maybe print definitions for the parser and printer
+ // code, if they're specified.
+ if (typeDef.getMnemonic())
+ emitParserPrinter(typeDef, os);
+}
+
+/// Emit the dialect printer/parser dispatcher. User's code should call these
+/// functions from their dialect's print/parse methods.
+static void emitParsePrintDispatch(SmallVectorImpl<TypeDef> &typeDefs,
+ raw_ostream &os) {
+ if (typeDefs.size() == 0)
+ return;
+ const Dialect &dialect = typeDefs.begin()->getDialect();
+ NamespaceEmitter ns(os, dialect);
+
+ // The parser dispatch is just a list of if-elses, matching on the mnemonic
+ // and calling the class's parse function.
+ os << "::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, "
+ "::mlir::DialectAsmParser& parser, ::llvm::StringRef mnemonic) {\n";
+ for (const TypeDef &typeDef : typeDefs)
+ if (typeDef.getMnemonic())
+ os << formatv(" if (mnemonic == {0}::{1}::getMnemonic()) return "
+ "{0}::{1}::parse(ctxt, parser);\n",
+ typeDef.getDialect().getCppNamespace(),
+ typeDef.getCppClassName());
+ os << " return ::mlir::Type();\n";
+ os << "}\n\n";
+
+ // The printer dispatch uses llvm::TypeSwitch to find and call the correct
+ // printer.
+ os << "::mlir::LogicalResult generatedTypePrinter(::mlir::Type type, "
+ "::mlir::DialectAsmPrinter& printer) {\n"
+ << " ::mlir::LogicalResult found = ::mlir::success();\n"
+ << " ::llvm::TypeSwitch<::mlir::Type>(type)\n";
+ for (auto typeDef : typeDefs)
+ if (typeDef.getMnemonic())
+ os << formatv(" .Case<{0}::{1}>([&](::mlir::Type t) {{ "
+ "t.dyn_cast<{0}::{1}>().print(printer); })\n",
+ typeDef.getDialect().getCppNamespace(),
+ typeDef.getCppClassName());
+ os << " .Default([&found](::mlir::Type) { found = ::mlir::failure(); "
+ "});\n"
+ << " return found;\n"
+ << "}\n\n";
+}
+
+/// Entry point for typedef definitions.
+static bool emitTypeDefDefs(const llvm::RecordKeeper &recordKeeper,
+ raw_ostream &os) {
+ emitSourceFileHeader("TypeDef Definitions", os);
+
+ SmallVector<TypeDef, 16> typeDefs;
+ findAllTypeDefs(recordKeeper, typeDefs);
+ emitTypeDefList(typeDefs, os);
+
+ IfDefScope scope("GET_TYPEDEF_CLASSES", os);
+ emitParsePrintDispatch(typeDefs, os);
+ for (auto typeDef : typeDefs)
+ emitTypeDefDef(typeDef, os);
+
+ return false;
+}
+
+//===----------------------------------------------------------------------===//
+// GEN: TypeDef registration hooks
+//===----------------------------------------------------------------------===//
+
+static mlir::GenRegistration
+ genTypeDefDefs("gen-typedef-defs", "Generate TypeDef definitions",
+ [](const llvm::RecordKeeper &records, raw_ostream &os) {
+ return emitTypeDefDefs(records, os);
+ });
+
+static mlir::GenRegistration
+ genTypeDefDecls("gen-typedef-decls", "Generate TypeDef declarations",
+ [](const llvm::RecordKeeper &records, raw_ostream &os) {
+ return emitTypeDefDecls(records, os);
+ });