From 41d919aa29468ac072755b8449b8a38ff26f6979 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Wed, 6 Jan 2021 14:54:39 -0800 Subject: [PATCH] [mlir][TypeDefGen] Remove the need to define parser/printer for singleton types This allows for singleton types without an explicit parser/printer to simply use the mnemonic as the assembly format, removing the need for these types to provide the parser/printer fields. Differential Revision: https://reviews.llvm.org/D94194 --- mlir/test/lib/Dialect/Test/TestTypeDefs.td | 3 -- mlir/tools/mlir-tblgen/TypeDefGen.cpp | 46 +++++++++++++++++++++--------- 2 files changed, 32 insertions(+), 17 deletions(-) diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td index 75fffa1..80927df 100644 --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -21,9 +21,6 @@ class Test_Type : TypeDef { } def SimpleTypeA : Test_Type<"SimpleA"> { let mnemonic = "smpla"; - - let printer = [{ $_printer << "smpla"; }]; - let parser = [{ return get($_ctxt); }]; } // A more complex parameterized type. diff --git a/mlir/tools/mlir-tblgen/TypeDefGen.cpp b/mlir/tools/mlir-tblgen/TypeDefGen.cpp index 8fdb5f4..2016816 100644 --- a/mlir/tools/mlir-tblgen/TypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/TypeDefGen.cpp @@ -537,12 +537,21 @@ static void emitParsePrintDispatch(ArrayRef types, raw_ostream &os) { os << "static ::mlir::Type generatedTypeParser(::mlir::MLIRContext* " "ctxt, " "::mlir::DialectAsmParser& parser, ::llvm::StringRef mnemonic) {\n"; - for (const TypeDef &type : types) - if (type.getMnemonic()) + for (const TypeDef &type : types) { + if (type.getMnemonic()) { os << formatv(" if (mnemonic == {0}::{1}::getMnemonic()) return " - "{0}::{1}::parse(ctxt, parser);\n", + "{0}::{1}::", type.getDialect().getCppNamespace(), type.getCppClassName()); + + // If the type has no parameters and no parser code, just invoke a normal + // `get`. + if (type.getNumParameters() == 0 && !type.getParserCode()) + os << "get(ctxt);\n"; + else + os << "parse(ctxt, parser);\n"; + } + } os << " return ::mlir::Type();\n"; os << "}\n\n"; @@ -551,17 +560,26 @@ static void emitParsePrintDispatch(ArrayRef types, raw_ostream &os) { os << "static ::mlir::LogicalResult generatedTypePrinter(::mlir::Type " "type, " "::mlir::DialectAsmPrinter& printer) {\n" - << " ::mlir::LogicalResult found = ::mlir::success();\n" - << " ::llvm::TypeSwitch<::mlir::Type>(type)\n"; - for (const TypeDef &type : types) - if (type.getMnemonic()) - os << formatv(" .Case<{0}::{1}>([&](::mlir::Type t) {{ " - "t.dyn_cast<{0}::{1}>().print(printer); })\n", - type.getDialect().getCppNamespace(), - type.getCppClassName()); - os << " .Default([&found](::mlir::Type) { found = ::mlir::failure(); " - "});\n" - << " return found;\n" + << " return ::llvm::TypeSwitch<::mlir::Type, " + "::mlir::LogicalResult>(type)\n"; + for (const TypeDef &type : types) { + if (Optional mnemonic = type.getMnemonic()) { + StringRef cppNamespace = type.getDialect().getCppNamespace(); + StringRef cppClassName = type.getCppClassName(); + os << formatv(" .Case<{0}::{1}>([&]({0}::{1} t) {{\n ", + cppNamespace, cppClassName); + + // If the type has no parameters and no printer code, just print the + // mnemonic. + if (type.getNumParameters() == 0 && !type.getPrinterCode()) + os << formatv("printer << {0}::{1}::getMnemonic();", cppNamespace, + cppClassName); + else + os << "t.print(printer);"; + os << "\n return ::mlir::success();\n })\n"; + } + } + os << " .Default([](::mlir::Type) { return ::mlir::failure(); });\n" << "}\n\n"; } -- 2.7.4