From 948be58258dd81d56b1057657193f7dcf6dfa9bd Mon Sep 17 00:00:00 2001 From: River Riddle Date: Mon, 11 Jan 2021 11:55:00 -0800 Subject: [PATCH] [mlir][TypeDefGen] Add support for adding builders when generating a TypeDef This allows for specifying additional get/getChecked methods that should be generated on the type, and acts similarly to how OpBuilders work. TypeBuilders have two additional components though: * InferredContextParam - Bit indicating that the context parameter of a get method is inferred from one of the builder parameters * checkedBody - A code block representing the body of the equivalent getChecked method. Differential Revision: https://reviews.llvm.org/D94274 --- mlir/docs/OpDefinitions.md | 165 +++++++++++++++++++ mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td | 2 +- mlir/include/mlir/IR/OpBase.td | 82 ++++++++++ mlir/include/mlir/TableGen/TypeDef.h | 43 ++++- mlir/lib/TableGen/TypeDef.cpp | 53 +++++++ mlir/test/lib/Dialect/Test/TestTypeDefs.td | 20 ++- mlir/test/lib/Dialect/Test/TestTypes.cpp | 2 +- mlir/test/mlir-tblgen/typedefs.td | 24 +-- mlir/tools/mlir-tblgen/TypeDefGen.cpp | 245 ++++++++++++++++++++--------- 9 files changed, 541 insertions(+), 95 deletions(-) diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md index bfd3d43c..dd52290 100644 --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -1536,6 +1536,171 @@ responsible for parsing/printing the types in `Dialect::printType` and - The `extraClassDeclaration` field is used to include extra code in the class declaration. +### Type builder methods + +For each type, there are a few builders(`get`/`getChecked`) automatically +generated based on the parameters of the type. For example, given the following +type definition: + +```tablegen +def MyType : ... { + let parameters = (ins "int":$intParam); +} +``` + +The following builders are generated: + +```c++ +// Type builders are named `get`, and return a new instance of a type for a +// given set of parameters. +static MyType get(MLIRContext *context, int intParam); + +// If `genVerifyInvariantsDecl` is set to 1, the following method is also +// generated. +static MyType getChecked(Location loc, int intParam); +``` + +If these autogenerated methods are not desired, such as when they conflict with +a custom builder method, a type can set `skipDefaultBuilders` to 1 to signal +that they should not be generated. + +#### Custom type builder methods + +The default build methods may cover a majority of the simple cases related to +type construction, but when they cannot satisfy a type's needs, you can define +additional convenience get methods in the `builders` field as follows: + +```tablegen +def MyType : ... { + let parameters = (ins "int":$intParam); + + let builders = [ + TypeBuilder<(ins "int":$intParam)>, + TypeBuilder<(ins CArg<"int", "0">:$intParam)>, + TypeBuilder<(ins CArg<"int", "0">:$intParam), [{ + // Write the body of the `get` builder inline here. + return Base::get($_ctxt, intParam); + }]>, + TypeBuilderWithInferredContext<(ins "Type":$typeParam), [{ + // This builder states that it can infer an MLIRContext instance from + // its arguments. + return Base::get(typeParam.getContext(), ...); + }]>, + ]; +} +``` + +The `builders` field is a list of custom builders that are added to the type +class. In this example, we provide a several different convenience builders that +are useful in different scenarios. The `ins` prefix is common to many function +declarations in ODS, which use a TableGen [`dag`](#tablegen-syntax). What +follows is a comma-separated list of types (quoted string or CArg) and names +prefixed with the `$` sign. The use of `CArg` allows for providing a default +value to that argument. Let's take a look at each of these builders individually + +The first builder will generate the declaration of a builder method that looks +like: + +```tablegen + let builders = [ + TypeBuilder<(ins "int":$intParam)>, + ]; +``` + +```c++ +class MyType : /*...*/ { + /*...*/ + static MyType get(::mlir::MLIRContext *context, int intParam); +}; +``` + +This builder is identical to the one that will be automatically generated for +`MyType`. The `context` parameter is implicitly added by the generator, and is +used when building the file Type instance (with `Base::get`). The distinction +here is that we can provide the implementation of this `get` method. With this +style of builder definition only the declaration is generated, the implementor +of MyType will need to provide a definition of `MyType::get`. + +The second builder will generate the declaration of a builder method that looks +like: + +```tablegen + let builders = [ + TypeBuilder<(ins CArg<"int", "0">:$intParam)>, + ]; +``` + +```c++ +class MyType : /*...*/ { + /*...*/ + static MyType get(::mlir::MLIRContext *context, int intParam = 0); +}; +``` + +The constraints here are identical to the first builder example except for the +fact that `intParam` now has a default value attached. + +The third builder will generate the declaration of a builder method that looks +like: + +```tablegen + let builders = [ + TypeBuilder<(ins CArg<"int", "0">:$intParam), [{ + // Write the body of the `get` builder inline here. + return Base::get($_ctxt, intParam); + }]>, + ]; +``` + +```c++ +class MyType : /*...*/ { + /*...*/ + static MyType get(::mlir::MLIRContext *context, int intParam = 0); +}; + +MyType MyType::get(::mlir::MLIRContext *context, int intParam) { + // Write the body of the `get` builder inline here. + return Base::get(context, intParam); +} +``` + +This is identical to the second builder example. The difference is that now, a +definition for the builder method will be generated automatically using the +provided code block as the body. When specifying the body inline, `$_ctxt` may +be used to access the `MLIRContext *` parameter. + +The fourth builder will generate the declaration of a builder method that looks +like: + +```tablegen + let builders = [ + TypeBuilderWithInferredContext<(ins "Type":$typeParam), [{ + // This builder states that it can infer an MLIRContext instance from + // its arguments. + return Base::get(typeParam.getContext(), ...); + }]>, + ]; +``` + +```c++ +class MyType : /*...*/ { + /*...*/ + static MyType get(Type typeParam); +}; + +MyType MyType::get(Type typeParam) { + // This builder states that it can infer an MLIRContext instance from its + // arguments. + return Base::get(typeParam.getContext(), ...); +} +``` + +In this builder example, the main difference from the third builder example +three is that the `MLIRContext` parameter is no longer added. This is because +the builder type used `TypeBuilderWithInferredContext` implies that the context +parameter is not necessary as it can be inferred from the arguments to the +builder. + ## Debugging Tips ### Run `mlir-tblgen` to see the generated content diff --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td index b1a42e9..3bfbccf 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td +++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td @@ -74,7 +74,7 @@ def ScalableVectorType : ArmSVE_Type<"ScalableVector"> { VectorType vector; if ($_parser.parseType(vector)) return Type(); - return get(ctxt, vector.getShape(), vector.getElementType()); + return get($_ctxt, vector.getShape(), vector.getElementType()); }]; let extraClassDeclaration = [{ diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index dc3e8a6..73ddbc1 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2430,6 +2430,73 @@ def replaceWithValue; // Data type generation //===----------------------------------------------------------------------===// +// Class for defining a custom type getter. +// +// TableGen generates several generic getter methods for each type by default, +// corresponding to the specified dag parameters. If the default generated ones +// cannot cover some use case, custom getters can be defined using instances of +// this class. +// +// The signature of the `get` is always either: +// +// ```c++ +// static get(MLIRContext *context, ...) { +// ... +// } +// ``` +// +// or: +// +// ```c++ +// static get(MLIRContext *context, ...); +// ``` +// +// To define a custom getter, the parameter list and body should be passed +// in as separate template arguments to this class. The parameter list is a +// TableGen DAG with `ins` operation with named arguments, which has either: +// - string initializers ("Type":$name) to represent a typed parameter, or +// - CArg-typed initializers (CArg<"Type", "default">:$name) to represent a +// typed parameter that may have a default value. +// The type string is used verbatim to produce code and, therefore, must be a +// valid C++ type. It is used inside the C++ namespace of the parent Type's +// dialect; explicit namespace qualification like `::mlir` may be necessary if +// Types are not placed inside the `mlir` namespace. The default value string is +// used verbatim to produce code and must be a valid C++ initializer the given +// type. For example, the following signature specification +// +// ``` +// TypeBuilder<(ins "int":$integerArg, CArg<"float", "3.0f">:$floatArg)> +// ``` +// +// has an integer parameter and a float parameter with a default value. +// +// If an empty string is passed in for `body`, then *only* the builder +// declaration will be generated; this provides a way to define complicated +// builders entirely in C++. +// +// `checkedBody` is similar to `body`, but is the code block used when +// generating a `getChecked` method. +class TypeBuilder { + dag dagParams = parameters; + code body = bodyCode; + code checkedBody = checkedBodyCode; + + // The context parameter can be inferred from one of the other parameters and + // is not implicitly added to the parameter list. + bit hasInferredContextParam = 0; +} + +// A class of TypeBuilder that is able to infer the MLIRContext parameter from +// one of the other builder parameters. Instances of this builder do not have +// `MLIRContext *` implicitly added to the parameter list. +class TypeBuilderWithInferredContext + : TypeBuilder { + code checkedBody = checkedBodyCode; + let hasInferredContextParam = 1; +} + // Define a new type, named `name`, belonging to `dialect` that inherits from // the given C++ base class. class TypeDef get(MLIRContext *, ); + // ``` + // + // Note that builders should only be provided when a type has parameters. + list builders = ?; + // Use the lowercased name as the keyword for parsing/printing. Specify only // if you want tblgen to generate declarations and/or definitions of // printer/parser for this type. @@ -2488,6 +2567,9 @@ class TypeDef getCheckedBody() const; + + /// Returns true if this builder is able to infer the MLIRContext parameter. + bool hasInferredContextParameter() const; +}; + +//===----------------------------------------------------------------------===// +// TypeDef +//===----------------------------------------------------------------------===// + /// Wrapper class that contains a TableGen TypeDef's record and provides helper /// methods for accessing them. class TypeDef { public: - explicit TypeDef(const llvm::Record *def) : def(def) {} + explicit TypeDef(const llvm::Record *def); // Get the dialect for which this type belongs. Dialect getDialect() const; @@ -95,6 +116,13 @@ public: // Get the code location (for error printing). ArrayRef getLoc() const; + // Returns true if the default get/getChecked methods should be skipped during + // generation. + bool skipDefaultBuilders() const; + + // Returns the builders of this type. + ArrayRef getBuilders() const { return builders; } + // Returns whether two TypeDefs are equal by checking the equality of the // underlying record. bool operator==(const TypeDef &other) const; @@ -107,8 +135,15 @@ public: private: const llvm::Record *def; + + // The builders of this type definition. + SmallVector builders; }; +//===----------------------------------------------------------------------===// +// TypeParameter +//===----------------------------------------------------------------------===// + // A wrapper class for tblgen TypeParameter, arrays of which belong to TypeDefs // to parameterize them. class TypeParameter { diff --git a/mlir/lib/TableGen/TypeDef.cpp b/mlir/lib/TableGen/TypeDef.cpp index 5c4f534..d6adbc2 100644 --- a/mlir/lib/TableGen/TypeDef.cpp +++ b/mlir/lib/TableGen/TypeDef.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/TableGen/TypeDef.h" +#include "mlir/TableGen/Dialect.h" #include "llvm/ADT/StringExtras.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Record.h" @@ -18,6 +19,26 @@ using namespace mlir; using namespace mlir::tblgen; +//===----------------------------------------------------------------------===// +// TypeBuilder +//===----------------------------------------------------------------------===// + +/// Return an optional code body used for the `getChecked` variant of this +/// builder. +Optional TypeBuilder::getCheckedBody() const { + Optional body = def->getValueAsOptionalString("checkedBody"); + return body && !body->empty() ? body : llvm::None; +} + +/// Returns true if this builder is able to infer the MLIRContext parameter. +bool TypeBuilder::hasInferredContextParameter() const { + return def->getValueAsBit("hasInferredContextParam"); +} + +//===----------------------------------------------------------------------===// +// TypeDef +//===----------------------------------------------------------------------===// + Dialect TypeDef::getDialect() const { auto *dialectDef = dyn_cast(def->getValue("dialect")->getValue()); @@ -98,6 +119,11 @@ llvm::Optional TypeDef::getExtraDecls() const { return value.empty() ? llvm::Optional() : value; } llvm::ArrayRef TypeDef::getLoc() const { return def->getLoc(); } + +bool TypeDef::skipDefaultBuilders() const { + return def->getValueAsBit("skipDefaultBuilders"); +} + bool TypeDef::operator==(const TypeDef &other) const { return def == other.def; } @@ -106,6 +132,33 @@ bool TypeDef::operator<(const TypeDef &other) const { return getName() < other.getName(); } +//===----------------------------------------------------------------------===// +// TypeParameter +//===----------------------------------------------------------------------===// + +TypeDef::TypeDef(const llvm::Record *def) : def(def) { + // Populate the builders. + auto *builderList = + dyn_cast_or_null(def->getValueInit("builders")); + if (builderList && !builderList->empty()) { + for (llvm::Init *init : builderList->getValues()) { + TypeBuilder builder(cast(init)->getDef(), def->getLoc()); + + // Ensure that all parameters have names. + for (const TypeBuilder::Parameter ¶m : builder.getParameters()) { + if (!param.getName()) + PrintFatalError(def->getLoc(), + "type builder parameters must have a name"); + } + builders.emplace_back(builder); + } + } else if (skipDefaultBuilders()) { + PrintFatalError( + def->getLoc(), + "default builders are skipped and no custom builders provided"); + } +} + StringRef TypeParameter::getName() const { return def->getArgName(num)->getValue(); } diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td index 80927df..0e2c11a 100644 --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -51,9 +51,9 @@ def IntegerType : Test_Type<"TestInteger"> { let genVerifyInvariantsDecl = 1; let parameters = ( ins + "unsigned":$width, // SignednessSemantics is defined below. - "::mlir::test::TestIntegerType::SignednessSemantics":$signedness, - "unsigned":$width + "::mlir::test::TestIntegerType::SignednessSemantics":$signedness ); // We define the printer inline. @@ -63,6 +63,17 @@ def IntegerType : Test_Type<"TestInteger"> { $_printer << ", " << getImpl()->width << ">"; }]; + // Define custom builder methods. + let builders = [ + TypeBuilder<(ins "unsigned":$width, + CArg<"SignednessSemantics", "Signless">:$signedness), [{ + return Base::get($_ctxt, width, signedness); + }], [{ + return Base::getChecked($_loc, width, signedness); + }]> + ]; + let skipDefaultBuilders = 1; + // The parser is defined here also. let parser = [{ if (parser.parseLess()) return Type(); @@ -73,7 +84,7 @@ def IntegerType : Test_Type<"TestInteger"> { if ($_parser.parseInteger(width)) return Type(); if ($_parser.parseGreater()) return Type(); Location loc = $_parser.getEncodedSourceLoc($_parser.getNameLoc()); - return getChecked(loc, signedness, width); + return getChecked(loc, width, signedness); }]; // Any extra code one wants in the type's class declaration. @@ -85,9 +96,6 @@ def IntegerType : Test_Type<"TestInteger"> { Unsigned, /// Unsigned integer }; - /// This extra function is necessary since it doesn't include signedness - static IntegerType getChecked(unsigned width, Location location); - /// Return true if this is a signless integer type. bool isSignless() const { return getSignedness() == Signless; } /// Return true if this is a signed integer type. diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp index 14a3e86..094e5c9 100644 --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -113,7 +113,7 @@ static llvm::hash_code mlir::test::hash_value(const FieldInfo &fi) { // NOLINT // Example type validity checker. LogicalResult TestIntegerType::verifyConstructionInvariants( - Location loc, TestIntegerType::SignednessSemantics ss, unsigned int width) { + Location loc, unsigned width, TestIntegerType::SignednessSemantics ss) { if (width > 8) return failure(); return success(); diff --git a/mlir/test/mlir-tblgen/typedefs.td b/mlir/test/mlir-tblgen/typedefs.td index 6e6e1c0..5471519 100644 --- a/mlir/test/mlir-tblgen/typedefs.td +++ b/mlir/test/mlir-tblgen/typedefs.td @@ -19,8 +19,8 @@ include "mlir/IR/OpBase.td" // DEF: ::mlir::test::SingleParameterType, // DEF: ::mlir::test::IntegerType -// DEF-LABEL: ::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser, ::llvm::StringRef mnemonic) -// DEF: if (mnemonic == ::mlir::test::CompoundAType::getMnemonic()) return ::mlir::test::CompoundAType::parse(ctxt, parser); +// DEF-LABEL: ::mlir::Type generatedTypeParser(::mlir::MLIRContext *context, ::mlir::DialectAsmParser &parser, ::llvm::StringRef mnemonic) +// DEF: if (mnemonic == ::mlir::test::CompoundAType::getMnemonic()) return ::mlir::test::CompoundAType::parse(context, parser); // DEF return ::mlir::Type(); def Test_Dialect: Dialect { @@ -33,7 +33,7 @@ def Test_Dialect: Dialect { class TestType : TypeDef { } def A_SimpleTypeA : TestType<"SimpleA"> { -// DECL: class SimpleAType: public ::mlir::Type +// DECL: class SimpleAType : public ::mlir::Type } def RTLValueType : Type, "Type"> { @@ -56,12 +56,13 @@ def B_CompoundTypeA : TestType<"CompoundA"> { let genVerifyInvariantsDecl = 1; -// DECL-LABEL: class CompoundAType: public ::mlir::Type +// DECL-LABEL: class CompoundAType : public ::mlir::Type +// DECL: static CompoundAType getChecked(::mlir::Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef dims, ::mlir::Type inner); // DECL: static ::mlir::LogicalResult verifyConstructionInvariants(::mlir::Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef dims, ::mlir::Type inner); -// DECL: static ::mlir::Type getChecked(::mlir::Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef dims, ::mlir::Type inner); // DECL: static ::llvm::StringRef getMnemonic() { return "cmpnd_a"; } -// DECL: static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser); -// DECL: void print(::mlir::DialectAsmPrinter& printer) const; +// DECL: static ::mlir::Type parse(::mlir::MLIRContext *context, +// DECL-NEXT: ::mlir::DialectAsmParser &parser); +// DECL: void print(::mlir::DialectAsmPrinter &printer) const; // DECL: int getWidthOfSomething() const; // DECL: ::mlir::test::SimpleTypeA getExampleTdType() const; // DECL: SomeCppStruct getExampleCppType() const; @@ -75,10 +76,11 @@ def C_IndexType : TestType<"Index"> { StringRefParameter<"Label for index">:$label ); -// DECL-LABEL: class IndexType: public ::mlir::Type +// DECL-LABEL: class IndexType : public ::mlir::Type // DECL: static ::llvm::StringRef getMnemonic() { return "index"; } -// DECL: static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser); -// DECL: void print(::mlir::DialectAsmPrinter& printer) const; +// DECL: static ::mlir::Type parse(::mlir::MLIRContext *context, +// DECL-NEXT: ::mlir::DialectAsmParser &parser); +// DECL: void print(::mlir::DialectAsmPrinter &printer) const; } def D_SingleParameterType : TestType<"SingleParameter"> { @@ -100,7 +102,7 @@ def E_IntegerType : TestType<"Integer"> { TypeParameter<"unsigned", "Bitwidth of integer">:$width ); -// DECL-LABEL: IntegerType: public ::mlir::Type +// DECL-LABEL: IntegerType : public ::mlir::Type let extraClassDeclaration = [{ /// Signedness semantics. diff --git a/mlir/tools/mlir-tblgen/TypeDefGen.cpp b/mlir/tools/mlir-tblgen/TypeDefGen.cpp index 2016816..9cbd532 100644 --- a/mlir/tools/mlir-tblgen/TypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/TypeDefGen.cpp @@ -148,7 +148,7 @@ class DialectAsmPrinter; /// {0}: The name of the typeDef class. /// {1}: The name of the type base class. static const char *const typeDefDeclSingletonBeginStr = R"( - class {0}: public ::mlir::Type::TypeBase<{0}, {1}, ::mlir::TypeStorage> {{ + class {0} : public ::mlir::Type::TypeBase<{0}, {1}, ::mlir::TypeStorage> {{ public: /// Inherit some necessary constructors from 'TypeBase'. using Base::Base; @@ -166,9 +166,9 @@ static const char *const typeDefDeclSingletonBeginStr = R"( static const char *const typeDefDeclParametricBeginStr = R"( namespace {2} { struct {3}; - } - class {0}: public ::mlir::Type::TypeBase<{0}, {1}, - {2}::{3}> {{ + } // end namespace {2} + class {0} : public ::mlir::Type::TypeBase<{0}, {1}, + {2}::{3}> {{ public: /// Inherit some necessary constructors from 'TypeBase'. using Base::Base; @@ -177,18 +177,68 @@ static const char *const typeDefDeclParametricBeginStr = R"( /// The snippet for print/parse. static const char *const typeDefParsePrint = R"( - static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser); - void print(::mlir::DialectAsmPrinter& printer) const; + static ::mlir::Type parse(::mlir::MLIRContext *context, + ::mlir::DialectAsmParser &parser); + void print(::mlir::DialectAsmPrinter &printer) const; )"; /// The code block for the verifyConstructionInvariants and getChecked. /// -/// {0}: List of parameters, parameters style. +/// {0}: The name of the typeDef class. +/// {1}: List of parameters, parameters style. static const char *const typeDefDeclVerifyStr = R"( - static ::mlir::LogicalResult verifyConstructionInvariants(::mlir::Location loc{0}); - static ::mlir::Type getChecked(::mlir::Location loc{0}); + static ::mlir::LogicalResult verifyConstructionInvariants(::mlir::Location loc{1}); )"; +/// Emit the builders for the given type. +static void emitTypeBuilderDecls(const TypeDef &typeDef, raw_ostream &os, + TypeParamCommaFormatter ¶mTypes) { + StringRef typeClass = typeDef.getCppClassName(); + bool genCheckedMethods = typeDef.genVerifyInvariantsDecl(); + if (!typeDef.skipDefaultBuilders()) { + os << llvm::formatv( + " static {0} get(::mlir::MLIRContext *context{1});\n", typeClass, + paramTypes); + if (genCheckedMethods) { + os << llvm::formatv( + " static {0} getChecked(::mlir::Location loc{1});\n", typeClass, + paramTypes); + } + } + + // Generate the builders specified by the user. + for (const TypeBuilder &builder : typeDef.getBuilders()) { + std::string paramStr; + llvm::raw_string_ostream paramOS(paramStr); + llvm::interleaveComma( + builder.getParameters(), paramOS, + [&](const TypeBuilder::Parameter ¶m) { + // Note: TypeBuilder parameters are guaranteed to have names. + paramOS << param.getCppType() << " " << *param.getName(); + if (Optional defaultParamValue = param.getDefaultValue()) + paramOS << " = " << *defaultParamValue; + }); + paramOS.flush(); + + // Generate the `get` variant of the builder. + os << " static " << typeClass << " get("; + if (!builder.hasInferredContextParameter()) { + os << "::mlir::MLIRContext *context"; + if (!paramStr.empty()) + os << ", "; + } + os << paramStr << ");\n"; + + // Generate the `getChecked` variant of the builder. + if (genCheckedMethods) { + os << " static " << typeClass << " getChecked(::mlir::Location loc"; + if (!paramStr.empty()) + os << ", " << paramStr; + os << ");\n"; + } + } +} + /// Generate the declaration for the given typeDef class. static void emitTypeDefDecl(const TypeDef &typeDef, raw_ostream &os) { SmallVector params; @@ -212,13 +262,13 @@ static void emitTypeDefDecl(const TypeDef &typeDef, raw_ostream &os) { TypeParamCommaFormatter emitTypeNamePairsAfterComma( TypeParamCommaFormatter::EmitFormat::TypeNamePairs, params); if (!params.empty()) { - os << llvm::formatv(" static {0} get(::mlir::MLIRContext* ctxt{1});\n", - typeDef.getCppClassName(), emitTypeNamePairsAfterComma); - } + emitTypeBuilderDecls(typeDef, os, emitTypeNamePairsAfterComma); - // Emit the verify invariants declaration. - if (typeDef.genVerifyInvariantsDecl()) - os << llvm::formatv(typeDefDeclVerifyStr, emitTypeNamePairsAfterComma); + // Emit the verify invariants declaration. + if (typeDef.genVerifyInvariantsDecl()) + os << llvm::formatv(typeDefDeclVerifyStr, typeDef.getCppClassName(), + emitTypeNamePairsAfterComma); + } // Emit the mnenomic, if specified. if (auto mnenomic = typeDef.getMnemonic()) { @@ -226,7 +276,8 @@ static void emitTypeDefDecl(const TypeDef &typeDef, raw_ostream &os) { << "\"; }\n"; // If mnemonic specified, emit print/parse declarations. - os << typeDefParsePrint; + if (typeDef.getParserCode() || typeDef.getPrinterCode() || !params.empty()) + os << typeDefParsePrint; } if (typeDef.genAccessors()) { @@ -330,17 +381,6 @@ static const char *const typeDefStorageClassConstructorReturn = R"( } )"; -/// The code block for the getChecked definition. -/// -/// {0}: List of parameters, parameters style. -/// {1}: C++ type class name. -/// {2}: Comma separated list of parameter names. -static const char *const typeDefDefGetCheckeStr = R"( - ::mlir::Type {1}::getChecked(Location loc{0}) {{ - return Base::getChecked(loc{2}); - } -)"; - /// Use tgfmt to emit custom allocation code for each parameter, if necessary. static void emitParameterAllocationCode(TypeDef &typeDef, raw_ostream &os) { SmallVector parameters; @@ -403,13 +443,13 @@ static void emitStorageClass(TypeDef typeDef, raw_ostream &os) { parameters, /* prependComma */ false)); // 3) Emit the construct method. - if (typeDef.hasStorageCustomConstructor()) + if (typeDef.hasStorageCustomConstructor()) { // If user wants to build the storage constructor themselves, declare it // here and then they can write the definition elsewhere. os << " static " << typeDef.getStorageClassName() << " *construct(::mlir::TypeStorageAllocator &allocator, const KeyTy " "&key);\n"; - else { + } else { // If not, autogenerate one. // First, unbox the parameters. @@ -445,7 +485,7 @@ void emitParserPrinter(TypeDef typeDef, raw_ostream &os) { // Both the mnenomic and printerCode must be defined (for parity with // parserCode). os << "void " << typeDef.getCppClassName() - << "::print(::mlir::DialectAsmPrinter& printer) const {\n"; + << "::print(::mlir::DialectAsmPrinter &printer) const {\n"; if (*printerCode == "") { // If no code specified, emit error. PrintFatalError(typeDef.getLoc(), @@ -460,7 +500,7 @@ void emitParserPrinter(TypeDef typeDef, raw_ostream &os) { if (auto parserCode = typeDef.getParserCode()) { // The mnenomic must be defined so the dispatcher knows how to dispatch. os << "::mlir::Type " << typeDef.getCppClassName() - << "::parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& " + << "::parse(::mlir::MLIRContext *context, ::mlir::DialectAsmParser &" "parser) " "{\n"; if (*parserCode == "") { @@ -470,51 +510,112 @@ void emitParserPrinter(TypeDef typeDef, raw_ostream &os) { ": parser (if specified) must have non-empty code"); } auto fmtCtxt = - FmtContext().addSubst("_parser", "parser").addSubst("_ctxt", "ctxt"); + FmtContext().addSubst("_parser", "parser").addSubst("_ctxt", "context"); os << tgfmt(*parserCode, &fmtCtxt) << "\n}\n"; } } -/// Print all the typedef-specific definition code. -static void emitTypeDefDef(TypeDef typeDef, raw_ostream &os) { - NamespaceEmitter ns(os, typeDef.getDialect()); - SmallVector parameters; - typeDef.getParameters(parameters); - - // Emit the storage class, if requested and necessary. - if (typeDef.genStorageClass() && typeDef.getNumParameters() > 0) - emitStorageClass(typeDef, os); - - if (!parameters.empty()) { +/// Emit the builders for the given type. +static void emitTypeBuilderDefs(const TypeDef &typeDef, raw_ostream &os, + ArrayRef typeDefParams) { + bool genCheckedMethods = typeDef.genVerifyInvariantsDecl(); + StringRef typeClass = typeDef.getCppClassName(); + if (!typeDef.skipDefaultBuilders()) { os << llvm::formatv( - "{0} {0}::get(::mlir::MLIRContext* ctxt{1}) {{\n" - " return Base::get(ctxt{2});\n}\n", - typeDef.getCppClassName(), + "{0} {0}::get(::mlir::MLIRContext *context{1}) {{\n" + " return Base::get(context{2});\n}\n", + typeClass, TypeParamCommaFormatter( - TypeParamCommaFormatter::EmitFormat::TypeNamePairs, parameters), + TypeParamCommaFormatter::EmitFormat::TypeNamePairs, typeDefParams), TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams, - parameters)); + typeDefParams)); + if (genCheckedMethods) { + os << llvm::formatv( + "{0} {0}::getChecked(::mlir::Location loc{1}) {{\n" + " return Base::getChecked(loc{2});\n}\n", + typeClass, + TypeParamCommaFormatter( + TypeParamCommaFormatter::EmitFormat::TypeNamePairs, + typeDefParams), + TypeParamCommaFormatter( + TypeParamCommaFormatter::EmitFormat::JustParams, typeDefParams)); + } } - // Emit the parameter accessors. - if (typeDef.genAccessors()) - for (const TypeParameter ¶meter : parameters) { - SmallString<16> name = parameter.getName(); - name[0] = llvm::toUpper(name[0]); - os << formatv("{0} {3}::get{1}() const { return getImpl()->{2}; }\n", - parameter.getCppType(), name, parameter.getName(), - typeDef.getCppClassName()); + // Generate the builders specified by the user. + auto builderFmtCtx = FmtContext().addSubst("_ctxt", "context"); + auto checkedBuilderFmtCtx = FmtContext() + .addSubst("_loc", "loc") + .addSubst("_ctxt", "loc.getContext()"); + for (const TypeBuilder &builder : typeDef.getBuilders()) { + Optional body = builder.getBody(); + Optional checkedBody = + genCheckedMethods ? builder.getCheckedBody() : llvm::None; + if (!body && !checkedBody) + continue; + std::string paramStr; + llvm::raw_string_ostream paramOS(paramStr); + llvm::interleaveComma(builder.getParameters(), paramOS, + [&](const TypeBuilder::Parameter ¶m) { + // Note: TypeBuilder parameters are guaranteed to + // have names. + paramOS << param.getCppType() << " " + << *param.getName(); + }); + paramOS.flush(); + + // Emit the `get` variant of the builder. + if (body) { + os << llvm::formatv("{0} {0}::get(", typeClass); + if (!builder.hasInferredContextParameter()) { + os << "::mlir::MLIRContext *context"; + if (!paramStr.empty()) + os << ", "; + os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr, + tgfmt(*body, &builderFmtCtx).str()); + } else { + os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr, *body); + } } - // Generate getChecked() method. - if (typeDef.genVerifyInvariantsDecl()) { - os << llvm::formatv( - typeDefDefGetCheckeStr, - TypeParamCommaFormatter( - TypeParamCommaFormatter::EmitFormat::TypeNamePairs, parameters), - typeDef.getCppClassName(), - TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams, - parameters)); + // Emit the `getChecked` variant of the builder. + if (checkedBody) { + os << llvm::formatv("{0} {0}::getChecked(::mlir::Location loc", + typeClass); + if (!paramStr.empty()) + os << ", " << paramStr; + os << llvm::formatv(") {{\n {0};\n}\n", + tgfmt(*checkedBody, &checkedBuilderFmtCtx)); + } + } +} + +/// Print all the typedef-specific definition code. +static void emitTypeDefDef(const TypeDef &typeDef, raw_ostream &os) { + NamespaceEmitter ns(os, typeDef.getDialect()); + + SmallVector parameters; + typeDef.getParameters(parameters); + if (!parameters.empty()) { + // Emit the storage class, if requested and necessary. + if (typeDef.genStorageClass()) + emitStorageClass(typeDef, os); + + // Emit the builders for this type. + emitTypeBuilderDefs(typeDef, os, parameters); + + // Generate accessor definitions only if we also generate the storage class. + // Otherwise, let the user define the exact accessor definition. + if (typeDef.genAccessors() && typeDef.genStorageClass()) { + // Emit the parameter accessors. + for (const TypeParameter ¶meter : parameters) { + SmallString<16> name = parameter.getName(); + name[0] = llvm::toUpper(name[0]); + os << formatv("{0} {3}::get{1}() const { return getImpl()->{2}; }\n", + parameter.getCppType(), name, parameter.getName(), + typeDef.getCppClassName()); + } + } } // If mnemonic is specified maybe print definitions for the parser and printer @@ -534,9 +635,9 @@ static void emitParsePrintDispatch(ArrayRef types, raw_ostream &os) { // The parser dispatch is just a list of if-elses, matching on the // mnemonic and calling the class's parse function. - os << "static ::mlir::Type generatedTypeParser(::mlir::MLIRContext* " - "ctxt, " - "::mlir::DialectAsmParser& parser, ::llvm::StringRef mnemonic) {\n"; + os << "static ::mlir::Type generatedTypeParser(::mlir::MLIRContext *" + "context, ::mlir::DialectAsmParser &parser, " + "::llvm::StringRef mnemonic) {\n"; for (const TypeDef &type : types) { if (type.getMnemonic()) { os << formatv(" if (mnemonic == {0}::{1}::getMnemonic()) return " @@ -547,9 +648,9 @@ static void emitParsePrintDispatch(ArrayRef types, raw_ostream &os) { // If the type has no parameters and no parser code, just invoke a normal // `get`. if (type.getNumParameters() == 0 && !type.getParserCode()) - os << "get(ctxt);\n"; + os << "get(context);\n"; else - os << "parse(ctxt, parser);\n"; + os << "parse(context, parser);\n"; } } os << " return ::mlir::Type();\n"; @@ -559,7 +660,7 @@ static void emitParsePrintDispatch(ArrayRef types, raw_ostream &os) { // printer. os << "static ::mlir::LogicalResult generatedTypePrinter(::mlir::Type " "type, " - "::mlir::DialectAsmPrinter& printer) {\n" + "::mlir::DialectAsmPrinter &printer) {\n" << " return ::llvm::TypeSwitch<::mlir::Type, " "::mlir::LogicalResult>(type)\n"; for (const TypeDef &type : types) { @@ -594,7 +695,7 @@ static bool emitTypeDefDefs(const llvm::RecordKeeper &recordKeeper, IfDefScope scope("GET_TYPEDEF_CLASSES", os); emitParsePrintDispatch(typeDefs, os); - for (auto typeDef : typeDefs) + for (const TypeDef &typeDef : typeDefs) emitTypeDefDef(typeDef, os); return false; -- 2.7.4