From 1be9fc66115de245f469e3b09114a06603258ce0 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Sat, 8 Jun 2019 08:39:07 -0700 Subject: [PATCH] [TableGen] Generating enum definitions and utility functions Enum attributes can be defined using `EnumAttr`, which requires all its cases to be defined with `EnumAttrCase`. To facilitate the interaction between `EnumAttr`s and their C++ consumers, add a new EnumsGen TableGen backend to generate a few common utilities, including an enum class, `llvm::DenseMapInfo` for the enum class, conversion functions from/to strings. This is controlled via the `-gen-enum-decls` and `-gen-enum-defs` command-line options of `mlir-tblgen`. PiperOrigin-RevId: 252209623 --- mlir/g3doc/OpDefinitions.md | 90 +++++++++++++- mlir/include/mlir/IR/OpBase.td | 39 +++++- mlir/include/mlir/TableGen/Attribute.h | 18 +++ mlir/lib/TableGen/Attribute.cpp | 22 ++++ mlir/tools/mlir-tblgen/CMakeLists.txt | 1 + mlir/tools/mlir-tblgen/EnumsGen.cpp | 198 +++++++++++++++++++++++++++++++ mlir/unittests/TableGen/CMakeLists.txt | 9 ++ mlir/unittests/TableGen/EnumsGenTest.cpp | 66 +++++++++++ mlir/unittests/TableGen/enums.td | 31 +++++ 9 files changed, 472 insertions(+), 2 deletions(-) create mode 100644 mlir/tools/mlir-tblgen/EnumsGen.cpp create mode 100644 mlir/unittests/TableGen/EnumsGenTest.cpp create mode 100644 mlir/unittests/TableGen/enums.td diff --git a/mlir/g3doc/OpDefinitions.md b/mlir/g3doc/OpDefinitions.md index 810dd34..8146fce 100644 --- a/mlir/g3doc/OpDefinitions.md +++ b/mlir/g3doc/OpDefinitions.md @@ -618,7 +618,94 @@ duplication, which is being worked on right now. ## Attribute Definition -TODO: This section is outdated. Update it. +### Enum attributes + +Enum attributes can be defined using `EnumAttr`, which requires all its cases to +be defined with `EnumAttrCase`. 
To facilitate the interaction between +`EnumAttr`s and their C++ consumers, the [`EnumsGen`][EnumsGen] TableGen backend +can generate a few common utilities, including an enum class, +`llvm::DenseMapInfo` for the enum class, and conversion functions from/to strings. +This is controlled via the `-gen-enum-decls` and `-gen-enum-defs` command-line +options of `mlir-tblgen`. + +For example, given the following `EnumAttr`: + +```tablegen +def CaseA: EnumAttrCase<"caseA", 0>; +def CaseB: EnumAttrCase<"caseB", 10>; + +def MyEnum: EnumAttr<"MyEnum", "An example enum", [CaseA, CaseB]> { + let cppNamespace = "Outer::Inner"; + let underlyingType = "uint64_t"; + let stringToSymbolFnName = "ConvertToEnum"; + let symbolToStringFnName = "ConvertToString"; +} +``` + +The following will be generated via `mlir-tblgen -gen-enum-decls`: + +```c++ +namespace Outer { +namespace Inner { +// An example enum +enum class MyEnum : uint64_t { + caseA = 0, + caseB = 10, +}; + +llvm::StringRef ConvertToString(MyEnum); +llvm::Optional<MyEnum> ConvertToEnum(llvm::StringRef); +} // namespace Inner +} // namespace Outer + +namespace llvm { +template<> struct DenseMapInfo<Outer::Inner::MyEnum> { + using StorageInfo = llvm::DenseMapInfo<uint64_t>; + + static inline Outer::Inner::MyEnum getEmptyKey() { + return static_cast<Outer::Inner::MyEnum>(StorageInfo::getEmptyKey()); + } + + static inline Outer::Inner::MyEnum getTombstoneKey() { + return static_cast<Outer::Inner::MyEnum>(StorageInfo::getTombstoneKey()); + } + + static unsigned getHashValue(const Outer::Inner::MyEnum &val) { + return StorageInfo::getHashValue(static_cast<uint64_t>(val)); + } + + static bool isEqual(const Outer::Inner::MyEnum &lhs, + const Outer::Inner::MyEnum &rhs) { + return lhs == rhs; + } +}; +} +``` + +The following will be generated via `mlir-tblgen -gen-enum-defs`: + +```c++ +namespace Outer { +namespace Inner { +llvm::StringRef ConvertToString(MyEnum val) { + switch (val) { + case MyEnum::caseA: return "caseA"; + case MyEnum::caseB: return "caseB"; + } + return ""; +} + +llvm::Optional<MyEnum>
ConvertToEnum(llvm::StringRef str) { + return llvm::StringSwitch<llvm::Optional<MyEnum>>(str) + .Case("caseA", MyEnum::caseA) + .Case("caseB", MyEnum::caseB) + .Default(llvm::None); +} +} // namespace Inner +} // namespace Outer +``` + +TODO(b/132506080): The following is outdated. Update it. An attribute is a compile time known constant of an operation. Attributes are required to be known to construct an operation (e.g., the padding behavior is @@ -829,3 +916,4 @@ TODO: Describe the generation of benefit metric given pattern. [TableGenBackend]: https://llvm.org/docs/TableGen/BackEnds.html#introduction [OpBase]: https://github.com/tensorflow/mlir/blob/master/include/mlir/IR/OpBase.td [OpDefinitionsGen]: https://github.com/tensorflow/mlir/blob/master/tools/mlir-tblgen/OpDefinitionsGen.cpp +[EnumsGen]: https://github.com/tensorflow/mlir/blob/master/tools/mlir-tblgen/EnumsGen.cpp diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index a170069..ac8c652 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -634,11 +634,14 @@ class TypeAttrBase : def TypeAttr : TypeAttrBase<"Type", "any type attribute">; // An enum attribute case. -class EnumAttrCase : StringBasedAttr< +class EnumAttrCase : StringBasedAttr< CPred<"$_self.cast().getValue() == \"" # sym # "\"">, "case " # sym> { // The C++ enumerant symbol string symbol = sym; + // The C++ enumerant value + // A non-negative value must be provided to use the EnumsGen backend. + int value = val; } // An enum attribute. Its value can only be one from the given list of `cases`. @@ -651,8 +654,42 @@ class EnumAttr cases> : description> { // The C++ enum class name string className = name; + // List of all accepted cases list enumerants = cases; + + // The following fields are only used by the EnumsGen backend to generate + // an enum class definition and conversion utility functions. + + // The underlying type for the C++ enum class. 
An empty string means the + // underlying type is not explicitly specified. + string underlyingType = ""; + + // The C++ namespaces that the enum class definition and utility functions + // should be placed into. + // + // Normally you want to place the full namespace path here. If it is nested, + // use "::" as the delimiter, e.g., given "A::B", generated code will be + // placed in `namespace A { namespace B { ... } }`. To avoid placing in any + // namespace, use "". + // TODO(b/134741431): use dialect to provide the namespace. + string cppNamespace = ""; + + // The name of the utility function that converts a string to the + // corresponding symbol. It will have the following signature: + // + // ```c++ + // llvm::Optional<<enum-type>> <fn-name>(llvm::StringRef); + // ``` + string stringToSymbolFnName = "symbolize" # name; + + // The name of the utility function that converts a symbol to the + // corresponding string. It will have the following signature: + // + // ```c++ + // llvm::StringRef <fn-name>(<enum-type>); + // ``` + string symbolToStringFnName = "stringify" # name; } class ElementsAttrBase : diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h index f921605..f69961a 100644 --- a/mlir/include/mlir/TableGen/Attribute.h +++ b/mlir/include/mlir/TableGen/Attribute.h @@ -129,6 +129,9 @@ public: // Returns the symbol of this enum attribute case. StringRef getSymbol() const; + + // Returns the value of this enum attribute case. + int64_t getValue() const; }; // Wrapper class providing helper methods for accessing enum attributes defined @@ -137,11 +140,26 @@ public: class EnumAttr : public Attribute { public: explicit EnumAttr(const llvm::Record *record); + explicit EnumAttr(const llvm::Record &record); explicit EnumAttr(const llvm::DefInit *init); // Returns the enum class name. StringRef getEnumClassName() const; + // Returns the C++ namespaces this enum class should be placed in. + StringRef getCppNamespace() const; + + // Returns the underlying type. 
+ StringRef getUnderlyingType() const; + + // Returns the name of the utility function that converts a string to the + // corresponding symbol. + StringRef getStringToSymbolFnName() const; + + // Returns the name of the utility function that converts a symbol to the + // corresponding string. + StringRef getSymbolToStringFnName() const; + // Returns all allowed cases for this enum attribute. std::vector getAllCases() const; }; diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp index 9de95ef..29259be 100644 --- a/mlir/lib/TableGen/Attribute.cpp +++ b/mlir/lib/TableGen/Attribute.cpp @@ -144,11 +144,17 @@ StringRef tblgen::EnumAttrCase::getSymbol() const { return def->getValueAsString("symbol"); } +int64_t tblgen::EnumAttrCase::getValue() const { + return def->getValueAsInt("value"); +} + tblgen::EnumAttr::EnumAttr(const llvm::Record *record) : Attribute(record) { assert(def->isSubClassOf("EnumAttr") && "must be subclass of TableGen 'EnumAttr' class"); } +tblgen::EnumAttr::EnumAttr(const llvm::Record &record) : Attribute(&record) {} + tblgen::EnumAttr::EnumAttr(const llvm::DefInit *init) : EnumAttr(init->getDef()) {} @@ -156,6 +162,22 @@ StringRef tblgen::EnumAttr::getEnumClassName() const { return def->getValueAsString("className"); } +StringRef tblgen::EnumAttr::getCppNamespace() const { + return def->getValueAsString("cppNamespace"); +} + +StringRef tblgen::EnumAttr::getUnderlyingType() const { + return def->getValueAsString("underlyingType"); +} + +StringRef tblgen::EnumAttr::getStringToSymbolFnName() const { + return def->getValueAsString("stringToSymbolFnName"); +} + +StringRef tblgen::EnumAttr::getSymbolToStringFnName() const { + return def->getValueAsString("symbolToStringFnName"); +} + std::vector tblgen::EnumAttr::getAllCases() const { const auto *inits = def->getValueAsListInit("enumerants"); diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt index d341cab..65c5e91 100644 --- 
a/mlir/tools/mlir-tblgen/CMakeLists.txt +++ b/mlir/tools/mlir-tblgen/CMakeLists.txt @@ -4,6 +4,7 @@ set(LLVM_LINK_COMPONENTS ) add_tablegen(mlir-tblgen MLIR + EnumsGen.cpp LLVMIRConversionGen.cpp mlir-tblgen.cpp OpDefinitionsGen.cpp diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp new file mode 100644 index 0000000..ab86c9d --- /dev/null +++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp @@ -0,0 +1,198 @@ +//===- EnumsGen.cpp - MLIR enum utility generator -------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// EnumsGen generates common utility functions for enums. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/TableGen/Attribute.h" +#include "mlir/TableGen/GenInfo.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/Record.h" +#include "llvm/TableGen/TableGenBackend.h" + +using llvm::formatv; +using llvm::raw_ostream; +using llvm::Record; +using llvm::RecordKeeper; +using llvm::StringRef; +using mlir::tblgen::EnumAttr; +using mlir::tblgen::EnumAttrCase; + +static void emitEnumClass(const Record &enumDef, StringRef enumName, + StringRef underlyingType, StringRef description, + const std::vector<EnumAttrCase> &enumerants, + raw_ostream &os) { + os << "// " << description << "\n"; + os << "enum class " << enumName; + + if (!underlyingType.empty()) + os << " : " << underlyingType; + os << " {\n"; + + for (const auto &enumerant : enumerants) { + auto symbol = enumerant.getSymbol(); + auto value = enumerant.getValue(); + if (value < 0) { + llvm::PrintFatalError(enumDef.getLoc(), + "all enumerants must have a non-negative value"); + } + os << formatv(" {0} = {1},\n", symbol, value); + } + os << "};\n\n"; +} + +static void emitDenseMapInfo(StringRef enumName, std::string underlyingType, + StringRef cppNamespace, raw_ostream &os) { + std::string qualName = formatv("{0}::{1}", cppNamespace, enumName); + if (underlyingType.empty()) + underlyingType = formatv("std::underlying_type<{0}>::type", qualName); + + const char *const mapInfo = R"( +namespace llvm { +template<> struct DenseMapInfo<{0}> {{ + using StorageInfo = llvm::DenseMapInfo<{1}>; + + static inline {0} getEmptyKey() {{ + return static_cast<{0}>(StorageInfo::getEmptyKey()); + } + + static inline {0} getTombstoneKey() {{ + return static_cast<{0}>(StorageInfo::getTombstoneKey()); + } + + static unsigned getHashValue(const {0} &val) {{ + return 
StorageInfo::getHashValue(static_cast<{1}>(val)); + } + + static bool isEqual(const {0} &lhs, const {0} &rhs) {{ + return lhs == rhs; + } +}; +})"; + os << formatv(mapInfo, qualName, underlyingType); + os << "\n\n"; +} + +static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { + EnumAttr enumAttr(enumDef); + StringRef enumName = enumAttr.getEnumClassName(); + StringRef cppNamespace = enumAttr.getCppNamespace(); + std::string underlyingType = enumAttr.getUnderlyingType(); + StringRef description = enumAttr.getDescription(); + StringRef strToSymFnName = enumAttr.getStringToSymbolFnName(); + StringRef symToStrFnName = enumAttr.getSymbolToStringFnName(); + auto enumerants = enumAttr.getAllCases(); + + llvm::SmallVector<StringRef, 2> namespaces; + llvm::SplitString(cppNamespace, namespaces, "::"); + + for (auto ns : namespaces) + os << "namespace " << ns << " {\n"; + + // Emit the enum class definition + emitEnumClass(enumDef, enumName, underlyingType, description, enumerants, os); + + // Emit conversion function declarations + os << formatv("llvm::StringRef {1}({0});\n", enumName, symToStrFnName); + os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef);\n", enumName, + strToSymFnName); + + for (auto ns : llvm::reverse(namespaces)) + os << "} // namespace " << ns << "\n"; + + // Emit DenseMapInfo for this enum class + emitDenseMapInfo(enumName, underlyingType, cppNamespace, os); +} + +static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) { + llvm::emitSourceFileHeader("Enum Utility Declarations", os); + + auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttr"); + for (const auto *def : defs) + emitEnumDecl(*def, os); + + return false; +} + +static void emitEnumDef(const Record &enumDef, raw_ostream &os) { + EnumAttr enumAttr(enumDef); + StringRef enumName = enumAttr.getEnumClassName(); + StringRef cppNamespace = enumAttr.getCppNamespace(); + StringRef strToSymFnName = enumAttr.getStringToSymbolFnName(); + StringRef symToStrFnName = 
enumAttr.getSymbolToStringFnName(); + auto enumerants = enumAttr.getAllCases(); + + llvm::SmallVector<StringRef, 2> namespaces; + llvm::SplitString(cppNamespace, namespaces, "::"); + + for (auto ns : namespaces) + os << "namespace " << ns << " {\n"; + + os << formatv("llvm::StringRef {1}({0} val) {{\n", enumName, symToStrFnName); + os << " switch (val) {\n"; + for (const auto &enumerant : enumerants) { + auto symbol = enumerant.getSymbol(); + os << formatv(" case {0}::{1}: return \"{1}\";\n", enumName, symbol); + } + os << " }\n"; + os << " return \"\";\n"; + os << "}\n\n"; + + os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef str) {{\n", enumName, + strToSymFnName); + os << formatv(" return llvm::StringSwitch<llvm::Optional<{0}>>(str)\n", + enumName); + for (const auto &enumerant : enumerants) { + auto symbol = enumerant.getSymbol(); + os << formatv(" .Case(\"{1}\", {0}::{1})\n", enumName, symbol); + } + os << " .Default(llvm::None);\n"; + os << "}\n"; + + for (auto ns : llvm::reverse(namespaces)) + os << "} // namespace " << ns << "\n"; + os << "\n"; +} + +static bool emitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os) { + llvm::emitSourceFileHeader("Enum Utility Definitions", os); + + auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttr"); + for (const auto *def : defs) + emitEnumDef(*def, os); + + return false; +} + +// Registers the enum utility generator to mlir-tblgen. +static mlir::GenRegistration + genEnumDecls("gen-enum-decls", "Generate enum utility declarations", + [](const RecordKeeper &records, raw_ostream &os) { + return emitEnumDecls(records, os); + }); + +// Registers the enum utility generator to mlir-tblgen. 
+static mlir::GenRegistration + genEnumDefs("gen-enum-defs", "Generate enum utility definitions", + [](const RecordKeeper &records, raw_ostream &os) { + return emitEnumDefs(records, os); + }); diff --git a/mlir/unittests/TableGen/CMakeLists.txt b/mlir/unittests/TableGen/CMakeLists.txt index e400590..aa55adb 100644 --- a/mlir/unittests/TableGen/CMakeLists.txt +++ b/mlir/unittests/TableGen/CMakeLists.txt @@ -1,5 +1,14 @@ +set(LLVM_TARGET_DEFINITIONS enums.td) +mlir_tablegen(EnumsGenTest.h.inc -gen-enum-decls) +mlir_tablegen(EnumsGenTest.cpp.inc -gen-enum-defs) +add_public_tablegen_target(MLIRTableGenEnumsIncGen) + add_mlir_unittest(MLIRTableGenTests + EnumsGenTest.cpp FormatTest.cpp ) + +add_dependencies(MLIRTableGenTests MLIRTableGenEnumsIncGen) + target_link_libraries(MLIRTableGenTests PRIVATE LLVMMLIRTableGen) diff --git a/mlir/unittests/TableGen/EnumsGenTest.cpp b/mlir/unittests/TableGen/EnumsGenTest.cpp new file mode 100644 index 0000000..b9a98a4 --- /dev/null +++ b/mlir/unittests/TableGen/EnumsGenTest.cpp @@ -0,0 +1,66 @@ +//===- EnumsGenTest.cpp - TableGen EnumsGen Tests -------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/StringSwitch.h" +#include "gmock/gmock.h" +#include + +// Pull in generated enum utility declarations +#include "EnumsGenTest.h.inc" +// And definitions +#include "EnumsGenTest.cpp.inc" + +using ::testing::StrEq; + +// Test namespaces and enum class/utility names +using Outer::Inner::ConvertToEnum; +using Outer::Inner::ConvertToString; +using Outer::Inner::MyEnum; + +TEST(EnumsGenTest, GeneratedEnumDefinition) { + EXPECT_EQ(0u, static_cast(MyEnum::CaseA)); + EXPECT_EQ(10u, static_cast(MyEnum::CaseB)); +} + +TEST(EnumsGenTest, GeneratedDenseMapInfo) { + llvm::DenseMap myMap; + + myMap[MyEnum::CaseA] = "zero"; + myMap[MyEnum::CaseB] = "ten"; + + EXPECT_THAT(myMap[MyEnum::CaseA], StrEq("zero")); + EXPECT_THAT(myMap[MyEnum::CaseB], StrEq("ten")); +} + +TEST(EnumsGenTest, GeneratedSymbolToStringFn) { + EXPECT_THAT(ConvertToString(MyEnum::CaseA), StrEq("CaseA")); + EXPECT_THAT(ConvertToString(MyEnum::CaseB), StrEq("CaseB")); +} + +TEST(EnumsGenTest, GeneratedStringToSymbolFn) { + EXPECT_EQ(llvm::Optional(MyEnum::CaseA), ConvertToEnum("CaseA")); + EXPECT_EQ(llvm::Optional(MyEnum::CaseB), ConvertToEnum("CaseB")); + EXPECT_EQ(llvm::None, ConvertToEnum("X")); +} + +TEST(EnumsGenTest, GeneratedUnderlyingType) { + bool v = + std::is_same::type>::value; + EXPECT_TRUE(v); +} diff --git a/mlir/unittests/TableGen/enums.td b/mlir/unittests/TableGen/enums.td new file mode 100644 index 0000000..2898295 --- /dev/null +++ b/mlir/unittests/TableGen/enums.td @@ -0,0 +1,31 @@ +//===-- enums.td - EnumsGen test definition file -----------*- tablegen -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +include "mlir/IR/OpBase.td" + +def CaseA: EnumAttrCase<"CaseA", 0>; +def CaseB: EnumAttrCase<"CaseB", 10>; + +def MyEnum: EnumAttr<"MyEnum", "A test enum", [CaseA, CaseB]> { + let cppNamespace = "Outer::Inner"; + let stringToSymbolFnName = "ConvertToEnum"; + let symbolToStringFnName = "ConvertToString"; +} + +def Uint64Enum : EnumAttr<"Uint64Enum", "A test enum", [CaseA, CaseB]> { + let underlyingType = "uint64_t"; +} -- 2.7.4