[TableGen] Generating enum definitions and utility functions
authorLei Zhang <antiagainst@google.com>
Sat, 8 Jun 2019 15:39:07 +0000 (08:39 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 9 Jun 2019 23:24:08 +0000 (16:24 -0700)
Enum attributes can be defined using `EnumAttr`, which requires all its cases
to be defined with `EnumAttrCase`. To facilitate the interaction between
`EnumAttr`s and their C++ consumers, add a new EnumsGen TableGen backend
to generate a few common utilities, including an enum class, `llvm::DenseMapInfo`
for the enum class, conversion functions from/to strings.

This is controlled via the `-gen-enum-decls` and `-gen-enum-defs` command-line
options of `mlir-tblgen`.

PiperOrigin-RevId: 252209623

mlir/g3doc/OpDefinitions.md
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/TableGen/Attribute.h
mlir/lib/TableGen/Attribute.cpp
mlir/tools/mlir-tblgen/CMakeLists.txt
mlir/tools/mlir-tblgen/EnumsGen.cpp [new file with mode: 0644]
mlir/unittests/TableGen/CMakeLists.txt
mlir/unittests/TableGen/EnumsGenTest.cpp [new file with mode: 0644]
mlir/unittests/TableGen/enums.td [new file with mode: 0644]

index 810dd34..8146fce 100644 (file)
@@ -618,7 +618,94 @@ duplication, which is being worked on right now.
 
 ## Attribute Definition
 
-TODO: This section is outdated. Update it.
+### Enum attributes
+
+Enum attributes can be defined using `EnumAttr`, which requires all its cases to
+be defined with `EnumAttrCase`. To facilitate the interaction between
+`EnumAttr`s and their C++ consumers, the [`EnumsGen`][EnumsGen] TableGen backend
+can generate a few common utilities, including an enum class,
+`llvm::DenseMapInfo` for the enum class, conversion functions from/to strings.
+This is controlled via the `-gen-enum-decls` and `-gen-enum-defs` command-line
+options of `mlir-tblgen`.
+
+For example, given the following `EnumAttr`:
+
+```tablegen
+def CaseA: EnumAttrCase<"caseA", 0>;
+def CaseB: EnumAttrCase<"caseB", 10>;
+
+def MyEnum: EnumAttr<"MyEnum", "An example enum", [CaseA, CaseB]> {
+  let cppNamespace = "Outer::Inner";
+  let underlyingType = "uint64_t";
+  let stringToSymbolFnName = "ConvertToEnum";
+  let symbolToStringFnName = "ConvertToString";
+}
+```
+
+The following will be generated via `mlir-tblgen -gen-enum-decls`:
+
+```c++
+namespace Outer {
+namespace Inner {
+// An example enum
+enum class MyEnum : uint64_t {
+  caseA = 0,
+  caseB = 10,
+};
+
+llvm::StringRef ConvertToString(MyEnum);
+llvm::Optional<MyEnum> ConvertToEnum(llvm::StringRef);
+} // namespace Inner
+} // namespace Outer
+
+namespace llvm {
+template<> struct DenseMapInfo<Outer::Inner::MyEnum> {
+  using StorageInfo = llvm::DenseMapInfo<uint64_t>;
+
+  static inline Outer::Inner::MyEnum getEmptyKey() {
+    return static_cast<Outer::Inner::MyEnum>(StorageInfo::getEmptyKey());
+  }
+
+  static inline Outer::Inner::MyEnum getTombstoneKey() {
+    return static_cast<Outer::Inner::MyEnum>(StorageInfo::getTombstoneKey());
+  }
+
+  static unsigned getHashValue(const Outer::Inner::MyEnum &val) {
+    return StorageInfo::getHashValue(static_cast<uint64_t>(val));
+  }
+
+  static bool isEqual(const Outer::Inner::MyEnum &lhs,
+                      const Outer::Inner::MyEnum &rhs) {
+    return lhs == rhs;
+  }
+};
+}
+```
+
+The following will be generated via `mlir-tblgen -gen-enum-defs`:
+
+```c++
+namespace Outer {
+namespace Inner {
+llvm::StringRef ConvertToString(MyEnum val) {
+  switch (val) {
+    case MyEnum::caseA: return "caseA";
+    case MyEnum::caseB: return "caseB";
+    default: return "";
+  }
+}
+
+llvm::Optional<MyEnum> ConvertToEnum(llvm::StringRef str) {
+  return llvm::StringSwitch<llvm::Optional<MyEnum>>(str)
+      .Case("caseA", MyEnum::caseA)
+      .Case("caseB", MyEnum::caseB)
+      .Default(llvm::None);
+}
+} // namespace Inner
+} // namespace Outer
+```
+
+TODO(b/132506080): This following is outdated. Update it.
 
 An attribute is a compile time known constant of an operation. Attributes are
 required to be known to construct an operation (e.g., the padding behavior is
@@ -829,3 +916,4 @@ TODO: Describe the generation of benefit metric given pattern.
 [TableGenBackend]: https://llvm.org/docs/TableGen/BackEnds.html#introduction
 [OpBase]: https://github.com/tensorflow/mlir/blob/master/include/mlir/IR/OpBase.td
 [OpDefinitionsGen]: https://github.com/tensorflow/mlir/blob/master/tools/mlir-tblgen/OpDefinitionsGen.cpp
+[EnumsGen]: https://github.com/tensorflow/mlir/blob/master/tools/mlir-tblgen/EnumsGen.cpp
index a170069..ac8c652 100644 (file)
@@ -634,11 +634,14 @@ class TypeAttrBase<string retType, string description> :
 def TypeAttr : TypeAttrBase<"Type", "any type attribute">;
 
 // An enum attribute case.
-class EnumAttrCase<string sym> : StringBasedAttr<
+class EnumAttrCase<string sym, int val = -1> : StringBasedAttr<
     CPred<"$_self.cast<StringAttr>().getValue() == \"" # sym # "\"">,
     "case " # sym> {
   // The C++ enumerant symbol
   string symbol = sym;
+  // The C++ enumerant value
+  // A non-negative value must be provided if to use EnumsGen backend.
+  int value = val;
 }
 
 // An enum attribute. Its value can only be one from the given list of `cases`.
@@ -651,8 +654,42 @@ class EnumAttr<string name, string description, list<EnumAttrCase> cases> :
                     description> {
   // The C++ enum class name
   string className = name;
+
   // List of all accepted cases
   list<EnumAttrCase> enumerants = cases;
+
+  // The following fields are only used by the EnumsGen backend to generate
+  // an enum class definition and conversion utility functions.
+
+  // The underlying type for the C++ enum class. An empty string mean the
+  // underlying type is not explicitly specified.
+  string underlyingType = "";
+
+  // The C++ namespaces that the enum class definition and utility functions
+  // should be placed into.
+  //
+  // Normally you want to place the full namespace path here. If it is nested,
+  // use "::" as the delimiter, e.g., given "A::B", generated code will be
+  // placed in `namespace A { namespace B { ... } }`. To avoid placing in any
+  // namespace, use "".
+  // TODO(b/134741431): use dialect to provide the namespace.
+  string cppNamespace = "";
+
+  // The name of the utility function that converts a string to the
+  // corresponding symbol. It will have the following signature:
+  //
+  // ```c++
+  // llvm::Optional<<qualified-enum-class-name>> <fn-name>(llvm::StringRef);
+  // ```
+  string stringToSymbolFnName = "symbolize" # name;
+
+  // The name of the utility function that converts a symbol to the
+  // corresponding string. It will have the following signature:
+  //
+  // ```c++
+  // llvm::StringRef <fn-name>(<qualified-enum-class-name>);
+  // ```
+  string symbolToStringFnName = "stringify" # name;
 }
 
 class ElementsAttrBase<Pred condition, string description> :
index f921605..f69961a 100644 (file)
@@ -129,6 +129,9 @@ public:
 
   // Returns the symbol of this enum attribute case.
   StringRef getSymbol() const;
+
+  // Returns the value of this enum attribute case.
+  int64_t getValue() const;
 };
 
 // Wrapper class providing helper methods for accessing enum attributes defined
@@ -137,11 +140,26 @@ public:
 class EnumAttr : public Attribute {
 public:
   explicit EnumAttr(const llvm::Record *record);
+  explicit EnumAttr(const llvm::Record &record);
   explicit EnumAttr(const llvm::DefInit *init);
 
   // Returns the enum class name.
   StringRef getEnumClassName() const;
 
+  // Returns the C++ namespaces this enum class should be placed in.
+  StringRef getCppNamespace() const;
+
+  // Returns the underlying type.
+  StringRef getUnderlyingType() const;
+
+  // Returns the name of the utility function that converts a string to the
+  // corresponding symbol.
+  StringRef getStringToSymbolFnName() const;
+
+  // Returns the name of the utility function that converts a symbol to the
+  // corresponding string.
+  StringRef getSymbolToStringFnName() const;
+
   // Returns all allowed cases for this enum attribute.
   std::vector<EnumAttrCase> getAllCases() const;
 };
index 9de95ef..29259be 100644 (file)
@@ -144,11 +144,17 @@ StringRef tblgen::EnumAttrCase::getSymbol() const {
   return def->getValueAsString("symbol");
 }
 
+int64_t tblgen::EnumAttrCase::getValue() const {
+  return def->getValueAsInt("value");
+}
+
 tblgen::EnumAttr::EnumAttr(const llvm::Record *record) : Attribute(record) {
   assert(def->isSubClassOf("EnumAttr") &&
          "must be subclass of TableGen 'EnumAttr' class");
 }
 
+tblgen::EnumAttr::EnumAttr(const llvm::Record &record) : Attribute(&record) {}
+
 tblgen::EnumAttr::EnumAttr(const llvm::DefInit *init)
     : EnumAttr(init->getDef()) {}
 
@@ -156,6 +162,22 @@ StringRef tblgen::EnumAttr::getEnumClassName() const {
   return def->getValueAsString("className");
 }
 
+StringRef tblgen::EnumAttr::getCppNamespace() const {
+  return def->getValueAsString("cppNamespace");
+}
+
+StringRef tblgen::EnumAttr::getUnderlyingType() const {
+  return def->getValueAsString("underlyingType");
+}
+
+StringRef tblgen::EnumAttr::getStringToSymbolFnName() const {
+  return def->getValueAsString("stringToSymbolFnName");
+}
+
+StringRef tblgen::EnumAttr::getSymbolToStringFnName() const {
+  return def->getValueAsString("symbolToStringFnName");
+}
+
 std::vector<tblgen::EnumAttrCase> tblgen::EnumAttr::getAllCases() const {
   const auto *inits = def->getValueAsListInit("enumerants");
 
index d341cab..65c5e91 100644 (file)
@@ -4,6 +4,7 @@ set(LLVM_LINK_COMPONENTS
   )
 
 add_tablegen(mlir-tblgen MLIR
+  EnumsGen.cpp
   LLVMIRConversionGen.cpp
   mlir-tblgen.cpp
   OpDefinitionsGen.cpp
diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp
new file mode 100644 (file)
index 0000000..ab86c9d
--- /dev/null
@@ -0,0 +1,198 @@
+//===- EnumsGen.cpp - MLIR enum utility generator -------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// EnumsGen generates common utility functions for enums.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/Attribute.h"
+#include "mlir/TableGen/GenInfo.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+#include "llvm/TableGen/TableGenBackend.h"
+
+using llvm::formatv;
+using llvm::raw_ostream;
+using llvm::Record;
+using llvm::RecordKeeper;
+using llvm::StringRef;
+using mlir::tblgen::EnumAttr;
+using mlir::tblgen::EnumAttrCase;
+
+static void emitEnumClass(const Record &enumDef, StringRef enumName,
+                          StringRef underlyingType, StringRef description,
+                          const std::vector<EnumAttrCase> &enumerants,
+                          raw_ostream &os) {
+  os << "// " << description << "\n";
+  os << "enum class " << enumName;
+
+  if (!underlyingType.empty())
+    os << " : " << underlyingType;
+  os << " {\n";
+
+  for (const auto &enumerant : enumerants) {
+    auto symbol = enumerant.getSymbol();
+    auto value = enumerant.getValue();
+    if (value < 0) {
+      llvm::PrintFatalError(enumDef.getLoc(),
+                            "all enumerants must have a non-negative value");
+    }
+    os << formatv("  {0} = {1},\n", symbol, value);
+  }
+  os << "};\n\n";
+}
+
+static void emitDenseMapInfo(StringRef enumName, std::string underlyingType,
+                             StringRef cppNamespace, raw_ostream &os) {
+  std::string qualName = formatv("{0}::{1}", cppNamespace, enumName);
+  if (underlyingType.empty())
+    underlyingType = formatv("std::underlying_type<{0}>::type", qualName);
+
+  const char *const mapInfo = R"(
+namespace llvm {
+template<> struct DenseMapInfo<{0}> {{
+  using StorageInfo = llvm::DenseMapInfo<{1}>;
+
+  static inline {0} getEmptyKey() {{
+    return static_cast<{0}>(StorageInfo::getEmptyKey());
+  }
+
+  static inline {0} getTombstoneKey() {{
+    return static_cast<{0}>(StorageInfo::getTombstoneKey());
+  }
+
+  static unsigned getHashValue(const {0} &val) {{
+    return StorageInfo::getHashValue(static_cast<{1}>(val));
+  }
+
+  static bool isEqual(const {0} &lhs, const {0} &rhs) {{
+    return lhs == rhs;
+  }
+};
+})";
+  os << formatv(mapInfo, qualName, underlyingType);
+  os << "\n\n";
+}
+
+static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
+  EnumAttr enumAttr(enumDef);
+  StringRef enumName = enumAttr.getEnumClassName();
+  StringRef cppNamespace = enumAttr.getCppNamespace();
+  std::string underlyingType = enumAttr.getUnderlyingType();
+  StringRef description = enumAttr.getDescription();
+  StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
+  StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
+  auto enumerants = enumAttr.getAllCases();
+
+  llvm::SmallVector<StringRef, 2> namespaces;
+  llvm::SplitString(cppNamespace, namespaces, "::");
+
+  for (auto ns : namespaces)
+    os << "namespace " << ns << " {\n";
+
+  // Emit the enum class definition
+  emitEnumClass(enumDef, enumName, underlyingType, description, enumerants, os);
+
+  // Emit coversion function declarations
+  os << formatv("llvm::StringRef {1}({0});\n", enumName, symToStrFnName);
+  os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef);\n", enumName,
+                strToSymFnName);
+
+  for (auto ns : llvm::reverse(namespaces))
+    os << "} // namespace " << ns << "\n";
+
+  // Emit DenseMapInfo for this enum class
+  emitDenseMapInfo(enumName, underlyingType, cppNamespace, os);
+}
+
+static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
+  llvm::emitSourceFileHeader("Enum Utility Declarations", os);
+
+  auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttr");
+  for (const auto *def : defs)
+    emitEnumDecl(*def, os);
+
+  return false;
+}
+
+static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
+  EnumAttr enumAttr(enumDef);
+  StringRef enumName = enumAttr.getEnumClassName();
+  StringRef cppNamespace = enumAttr.getCppNamespace();
+  StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
+  StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
+  auto enumerants = enumAttr.getAllCases();
+
+  llvm::SmallVector<StringRef, 2> namespaces;
+  llvm::SplitString(cppNamespace, namespaces, "::");
+
+  for (auto ns : namespaces)
+    os << "namespace " << ns << " {\n";
+
+  os << formatv("llvm::StringRef {1}({0} val) {{\n", enumName, symToStrFnName);
+  os << "  switch (val) {\n";
+  for (const auto &enumerant : enumerants) {
+    auto symbol = enumerant.getSymbol();
+    os << formatv("    case {0}::{1}: return \"{1}\";\n", enumName, symbol);
+  }
+  os << "  }\n";
+  os << "  return \"\";\n";
+  os << "}\n\n";
+
+  os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef str) {{\n", enumName,
+                strToSymFnName);
+  os << formatv("  return llvm::StringSwitch<llvm::Optional<{0}>>(str)\n",
+                enumName);
+  for (const auto &enumerant : enumerants) {
+    auto symbol = enumerant.getSymbol();
+    os << formatv("      .Case(\"{1}\", {0}::{1})\n", enumName, symbol);
+  }
+  os << "      .Default(llvm::None);\n";
+  os << "}\n";
+
+  for (auto ns : llvm::reverse(namespaces))
+    os << "} // namespace " << ns << "\n";
+  os << "\n";
+}
+
+static bool emitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
+  llvm::emitSourceFileHeader("Enum Utility Definitions", os);
+
+  auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttr");
+  for (const auto *def : defs)
+    emitEnumDef(*def, os);
+
+  return false;
+}
+
+// Registers the enum utility generator to mlir-tblgen.
+static mlir::GenRegistration
+    genEnumDecls("gen-enum-decls", "Generate enum utility declarations",
+                 [](const RecordKeeper &records, raw_ostream &os) {
+                   return emitEnumDecls(records, os);
+                 });
+
+// Registers the enum utility generator to mlir-tblgen.
+static mlir::GenRegistration
+    genEnumDefs("gen-enum-defs", "Generate enum utility definitions",
+                [](const RecordKeeper &records, raw_ostream &os) {
+                  return emitEnumDefs(records, os);
+                });
index e400590..aa55adb 100644 (file)
@@ -1,5 +1,14 @@
+set(LLVM_TARGET_DEFINITIONS enums.td)
+mlir_tablegen(EnumsGenTest.h.inc -gen-enum-decls)
+mlir_tablegen(EnumsGenTest.cpp.inc -gen-enum-defs)
+add_public_tablegen_target(MLIRTableGenEnumsIncGen)
+
 add_mlir_unittest(MLIRTableGenTests
+  EnumsGenTest.cpp
   FormatTest.cpp
 )
+
+add_dependencies(MLIRTableGenTests MLIRTableGenEnumsIncGen)
+
 target_link_libraries(MLIRTableGenTests
   PRIVATE LLVMMLIRTableGen)
diff --git a/mlir/unittests/TableGen/EnumsGenTest.cpp b/mlir/unittests/TableGen/EnumsGenTest.cpp
new file mode 100644 (file)
index 0000000..b9a98a4
--- /dev/null
@@ -0,0 +1,66 @@
+//===- EnumsGenTest.cpp - TableGen EnumsGen Tests -------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "gmock/gmock.h"
+#include <type_traits>
+
+// Pull in generated enum utility declarations
+#include "EnumsGenTest.h.inc"
+// And definitions
+#include "EnumsGenTest.cpp.inc"
+
+using ::testing::StrEq;
+
+// Test namespaces and enum class/utility names
+using Outer::Inner::ConvertToEnum;
+using Outer::Inner::ConvertToString;
+using Outer::Inner::MyEnum;
+
+TEST(EnumsGenTest, GeneratedEnumDefinition) {
+  EXPECT_EQ(0u, static_cast<uint64_t>(MyEnum::CaseA));
+  EXPECT_EQ(10u, static_cast<uint64_t>(MyEnum::CaseB));
+}
+
+TEST(EnumsGenTest, GeneratedDenseMapInfo) {
+  llvm::DenseMap<MyEnum, std::string> myMap;
+
+  myMap[MyEnum::CaseA] = "zero";
+  myMap[MyEnum::CaseB] = "ten";
+
+  EXPECT_THAT(myMap[MyEnum::CaseA], StrEq("zero"));
+  EXPECT_THAT(myMap[MyEnum::CaseB], StrEq("ten"));
+}
+
+TEST(EnumsGenTest, GeneratedSymbolToStringFn) {
+  EXPECT_THAT(ConvertToString(MyEnum::CaseA), StrEq("CaseA"));
+  EXPECT_THAT(ConvertToString(MyEnum::CaseB), StrEq("CaseB"));
+}
+
+TEST(EnumsGenTest, GeneratedStringToSymbolFn) {
+  EXPECT_EQ(llvm::Optional<MyEnum>(MyEnum::CaseA), ConvertToEnum("CaseA"));
+  EXPECT_EQ(llvm::Optional<MyEnum>(MyEnum::CaseB), ConvertToEnum("CaseB"));
+  EXPECT_EQ(llvm::None, ConvertToEnum("X"));
+}
+
+TEST(EnumsGenTest, GeneratedUnderlyingType) {
+  bool v =
+      std::is_same<uint64_t, std::underlying_type<Uint64Enum>::type>::value;
+  EXPECT_TRUE(v);
+}
diff --git a/mlir/unittests/TableGen/enums.td b/mlir/unittests/TableGen/enums.td
new file mode 100644 (file)
index 0000000..2898295
--- /dev/null
@@ -0,0 +1,31 @@
+//===-- enums.td - EnumsGen test definition file -----------*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+include "mlir/IR/OpBase.td"
+
+def CaseA: EnumAttrCase<"CaseA", 0>;
+def CaseB: EnumAttrCase<"CaseB", 10>;
+
+def MyEnum: EnumAttr<"MyEnum", "A test enum", [CaseA, CaseB]> {
+  let cppNamespace = "Outer::Inner";
+  let stringToSymbolFnName = "ConvertToEnum";
+  let symbolToStringFnName = "ConvertToString";
+}
+
+def Uint64Enum : EnumAttr<"Uint64Enum", "A test enum", [CaseA, CaseB]> {
+  let underlyingType = "uint64_t";
+}