## Attribute Definition
-TODO: This section is outdated. Update it.
+### Enum attributes
+
+Enum attributes can be defined using `EnumAttr`, which requires all its cases to
+be defined with `EnumAttrCase`. To facilitate the interaction between
+`EnumAttr`s and their C++ consumers, the [`EnumsGen`][EnumsGen] TableGen backend
+can generate a few common utilities, including an enum class,
+`llvm::DenseMapInfo` for the enum class, conversion functions from/to strings.
+This is controlled via the `-gen-enum-decls` and `-gen-enum-defs` command-line
+options of `mlir-tblgen`.
+
+For example, given the following `EnumAttr`:
+
+```tablegen
+def CaseA: EnumAttrCase<"caseA", 0>;
+def CaseB: EnumAttrCase<"caseB", 10>;
+
+def MyEnum: EnumAttr<"MyEnum", "An example enum", [CaseA, CaseB]> {
+ let cppNamespace = "Outer::Inner";
+ let underlyingType = "uint64_t";
+ let stringToSymbolFnName = "ConvertToEnum";
+ let symbolToStringFnName = "ConvertToString";
+}
+```
+
+The following will be generated via `mlir-tblgen -gen-enum-decls`:
+
+```c++
+namespace Outer {
+namespace Inner {
+// An example enum
+enum class MyEnum : uint64_t {
+ caseA = 0,
+ caseB = 10,
+};
+
+llvm::StringRef ConvertToString(MyEnum);
+llvm::Optional<MyEnum> ConvertToEnum(llvm::StringRef);
+} // namespace Inner
+} // namespace Outer
+
+namespace llvm {
+template<> struct DenseMapInfo<Outer::Inner::MyEnum> {
+ using StorageInfo = llvm::DenseMapInfo<uint64_t>;
+
+ static inline Outer::Inner::MyEnum getEmptyKey() {
+ return static_cast<Outer::Inner::MyEnum>(StorageInfo::getEmptyKey());
+ }
+
+ static inline Outer::Inner::MyEnum getTombstoneKey() {
+ return static_cast<Outer::Inner::MyEnum>(StorageInfo::getTombstoneKey());
+ }
+
+ static unsigned getHashValue(const Outer::Inner::MyEnum &val) {
+ return StorageInfo::getHashValue(static_cast<uint64_t>(val));
+ }
+
+ static bool isEqual(const Outer::Inner::MyEnum &lhs,
+ const Outer::Inner::MyEnum &rhs) {
+ return lhs == rhs;
+ }
+};
+}
+```
+
+The following will be generated via `mlir-tblgen -gen-enum-defs`:
+
+```c++
+namespace Outer {
+namespace Inner {
+llvm::StringRef ConvertToString(MyEnum val) {
+ switch (val) {
+ case MyEnum::caseA: return "caseA";
+ case MyEnum::caseB: return "caseB";
+ default: return "";
+ }
+}
+
+llvm::Optional<MyEnum> ConvertToEnum(llvm::StringRef str) {
+ return llvm::StringSwitch<llvm::Optional<MyEnum>>(str)
+ .Case("caseA", MyEnum::caseA)
+ .Case("caseB", MyEnum::caseB)
+ .Default(llvm::None);
+}
+} // namespace Inner
+} // namespace Outer
+```
+
+TODO(b/132506080): This following is outdated. Update it.
An attribute is a compile time known constant of an operation. Attributes are
required to be known to construct an operation (e.g., the padding behavior is
[TableGenBackend]: https://llvm.org/docs/TableGen/BackEnds.html#introduction
[OpBase]: https://github.com/tensorflow/mlir/blob/master/include/mlir/IR/OpBase.td
[OpDefinitionsGen]: https://github.com/tensorflow/mlir/blob/master/tools/mlir-tblgen/OpDefinitionsGen.cpp
+[EnumsGen]: https://github.com/tensorflow/mlir/blob/master/tools/mlir-tblgen/EnumsGen.cpp
def TypeAttr : TypeAttrBase<"Type", "any type attribute">;
// An enum attribute case.
-class EnumAttrCase<string sym> : StringBasedAttr<
+class EnumAttrCase<string sym, int val = -1> : StringBasedAttr<
CPred<"$_self.cast<StringAttr>().getValue() == \"" # sym # "\"">,
"case " # sym> {
// The C++ enumerant symbol
string symbol = sym;
+ // The C++ enumerant value
+ // A non-negative value must be provided if to use EnumsGen backend.
+ int value = val;
}
// An enum attribute. Its value can only be one from the given list of `cases`.
description> {
// The C++ enum class name
string className = name;
+
// List of all accepted cases
list<EnumAttrCase> enumerants = cases;
+
+ // The following fields are only used by the EnumsGen backend to generate
+ // an enum class definition and conversion utility functions.
+
+ // The underlying type for the C++ enum class. An empty string mean the
+ // underlying type is not explicitly specified.
+ string underlyingType = "";
+
+ // The C++ namespaces that the enum class definition and utility functions
+ // should be placed into.
+ //
+ // Normally you want to place the full namespace path here. If it is nested,
+ // use "::" as the delimiter, e.g., given "A::B", generated code will be
+ // placed in `namespace A { namespace B { ... } }`. To avoid placing in any
+ // namespace, use "".
+ // TODO(b/134741431): use dialect to provide the namespace.
+ string cppNamespace = "";
+
+ // The name of the utility function that converts a string to the
+ // corresponding symbol. It will have the following signature:
+ //
+ // ```c++
+ // llvm::Optional<<qualified-enum-class-name>> <fn-name>(llvm::StringRef);
+ // ```
+ string stringToSymbolFnName = "symbolize" # name;
+
+ // The name of the utility function that converts a symbol to the
+ // corresponding string. It will have the following signature:
+ //
+ // ```c++
+ // llvm::StringRef <fn-name>(<qualified-enum-class-name>);
+ // ```
+ string symbolToStringFnName = "stringify" # name;
}
class ElementsAttrBase<Pred condition, string description> :
// Returns the symbol of this enum attribute case.
StringRef getSymbol() const;
+
+ // Returns the value of this enum attribute case.
+ int64_t getValue() const;
};
// Wrapper class providing helper methods for accessing enum attributes defined
class EnumAttr : public Attribute {
public:
explicit EnumAttr(const llvm::Record *record);
+ explicit EnumAttr(const llvm::Record &record);
explicit EnumAttr(const llvm::DefInit *init);
// Returns the enum class name.
StringRef getEnumClassName() const;
+ // Returns the C++ namespaces this enum class should be placed in.
+ StringRef getCppNamespace() const;
+
+ // Returns the underlying type.
+ StringRef getUnderlyingType() const;
+
+ // Returns the name of the utility function that converts a string to the
+ // corresponding symbol.
+ StringRef getStringToSymbolFnName() const;
+
+ // Returns the name of the utility function that converts a symbol to the
+ // corresponding string.
+ StringRef getSymbolToStringFnName() const;
+
// Returns all allowed cases for this enum attribute.
std::vector<EnumAttrCase> getAllCases() const;
};
return def->getValueAsString("symbol");
}
+int64_t tblgen::EnumAttrCase::getValue() const {
+ return def->getValueAsInt("value");
+}
+
tblgen::EnumAttr::EnumAttr(const llvm::Record *record) : Attribute(record) {
assert(def->isSubClassOf("EnumAttr") &&
"must be subclass of TableGen 'EnumAttr' class");
}
+tblgen::EnumAttr::EnumAttr(const llvm::Record &record) : Attribute(&record) {}
+
tblgen::EnumAttr::EnumAttr(const llvm::DefInit *init)
: EnumAttr(init->getDef()) {}
return def->getValueAsString("className");
}
+StringRef tblgen::EnumAttr::getCppNamespace() const {
+ return def->getValueAsString("cppNamespace");
+}
+
+StringRef tblgen::EnumAttr::getUnderlyingType() const {
+ return def->getValueAsString("underlyingType");
+}
+
+StringRef tblgen::EnumAttr::getStringToSymbolFnName() const {
+ return def->getValueAsString("stringToSymbolFnName");
+}
+
+StringRef tblgen::EnumAttr::getSymbolToStringFnName() const {
+ return def->getValueAsString("symbolToStringFnName");
+}
+
std::vector<tblgen::EnumAttrCase> tblgen::EnumAttr::getAllCases() const {
const auto *inits = def->getValueAsListInit("enumerants");
)
add_tablegen(mlir-tblgen MLIR
+ EnumsGen.cpp
LLVMIRConversionGen.cpp
mlir-tblgen.cpp
OpDefinitionsGen.cpp
--- /dev/null
+//===- EnumsGen.cpp - MLIR enum utility generator -------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// EnumsGen generates common utility functions for enums.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/Attribute.h"
+#include "mlir/TableGen/GenInfo.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+#include "llvm/TableGen/TableGenBackend.h"
+
+using llvm::formatv;
+using llvm::raw_ostream;
+using llvm::Record;
+using llvm::RecordKeeper;
+using llvm::StringRef;
+using mlir::tblgen::EnumAttr;
+using mlir::tblgen::EnumAttrCase;
+
+static void emitEnumClass(const Record &enumDef, StringRef enumName,
+ StringRef underlyingType, StringRef description,
+ const std::vector<EnumAttrCase> &enumerants,
+ raw_ostream &os) {
+ os << "// " << description << "\n";
+ os << "enum class " << enumName;
+
+ if (!underlyingType.empty())
+ os << " : " << underlyingType;
+ os << " {\n";
+
+ for (const auto &enumerant : enumerants) {
+ auto symbol = enumerant.getSymbol();
+ auto value = enumerant.getValue();
+ if (value < 0) {
+ llvm::PrintFatalError(enumDef.getLoc(),
+ "all enumerants must have a non-negative value");
+ }
+ os << formatv(" {0} = {1},\n", symbol, value);
+ }
+ os << "};\n\n";
+}
+
+static void emitDenseMapInfo(StringRef enumName, std::string underlyingType,
+ StringRef cppNamespace, raw_ostream &os) {
+ std::string qualName = formatv("{0}::{1}", cppNamespace, enumName);
+ if (underlyingType.empty())
+ underlyingType = formatv("std::underlying_type<{0}>::type", qualName);
+
+ const char *const mapInfo = R"(
+namespace llvm {
+template<> struct DenseMapInfo<{0}> {{
+ using StorageInfo = llvm::DenseMapInfo<{1}>;
+
+ static inline {0} getEmptyKey() {{
+ return static_cast<{0}>(StorageInfo::getEmptyKey());
+ }
+
+ static inline {0} getTombstoneKey() {{
+ return static_cast<{0}>(StorageInfo::getTombstoneKey());
+ }
+
+ static unsigned getHashValue(const {0} &val) {{
+ return StorageInfo::getHashValue(static_cast<{1}>(val));
+ }
+
+ static bool isEqual(const {0} &lhs, const {0} &rhs) {{
+ return lhs == rhs;
+ }
+};
+})";
+ os << formatv(mapInfo, qualName, underlyingType);
+ os << "\n\n";
+}
+
+static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
+ EnumAttr enumAttr(enumDef);
+ StringRef enumName = enumAttr.getEnumClassName();
+ StringRef cppNamespace = enumAttr.getCppNamespace();
+ std::string underlyingType = enumAttr.getUnderlyingType();
+ StringRef description = enumAttr.getDescription();
+ StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
+ StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
+ auto enumerants = enumAttr.getAllCases();
+
+ llvm::SmallVector<StringRef, 2> namespaces;
+ llvm::SplitString(cppNamespace, namespaces, "::");
+
+ for (auto ns : namespaces)
+ os << "namespace " << ns << " {\n";
+
+ // Emit the enum class definition
+ emitEnumClass(enumDef, enumName, underlyingType, description, enumerants, os);
+
+ // Emit coversion function declarations
+ os << formatv("llvm::StringRef {1}({0});\n", enumName, symToStrFnName);
+ os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef);\n", enumName,
+ strToSymFnName);
+
+ for (auto ns : llvm::reverse(namespaces))
+ os << "} // namespace " << ns << "\n";
+
+ // Emit DenseMapInfo for this enum class
+ emitDenseMapInfo(enumName, underlyingType, cppNamespace, os);
+}
+
+static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
+ llvm::emitSourceFileHeader("Enum Utility Declarations", os);
+
+ auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttr");
+ for (const auto *def : defs)
+ emitEnumDecl(*def, os);
+
+ return false;
+}
+
+static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
+ EnumAttr enumAttr(enumDef);
+ StringRef enumName = enumAttr.getEnumClassName();
+ StringRef cppNamespace = enumAttr.getCppNamespace();
+ StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
+ StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
+ auto enumerants = enumAttr.getAllCases();
+
+ llvm::SmallVector<StringRef, 2> namespaces;
+ llvm::SplitString(cppNamespace, namespaces, "::");
+
+ for (auto ns : namespaces)
+ os << "namespace " << ns << " {\n";
+
+ os << formatv("llvm::StringRef {1}({0} val) {{\n", enumName, symToStrFnName);
+ os << " switch (val) {\n";
+ for (const auto &enumerant : enumerants) {
+ auto symbol = enumerant.getSymbol();
+ os << formatv(" case {0}::{1}: return \"{1}\";\n", enumName, symbol);
+ }
+ os << " }\n";
+ os << " return \"\";\n";
+ os << "}\n\n";
+
+ os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef str) {{\n", enumName,
+ strToSymFnName);
+ os << formatv(" return llvm::StringSwitch<llvm::Optional<{0}>>(str)\n",
+ enumName);
+ for (const auto &enumerant : enumerants) {
+ auto symbol = enumerant.getSymbol();
+ os << formatv(" .Case(\"{1}\", {0}::{1})\n", enumName, symbol);
+ }
+ os << " .Default(llvm::None);\n";
+ os << "}\n";
+
+ for (auto ns : llvm::reverse(namespaces))
+ os << "} // namespace " << ns << "\n";
+ os << "\n";
+}
+
+static bool emitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
+ llvm::emitSourceFileHeader("Enum Utility Definitions", os);
+
+ auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttr");
+ for (const auto *def : defs)
+ emitEnumDef(*def, os);
+
+ return false;
+}
+
+// Registers the enum utility generator to mlir-tblgen.
+static mlir::GenRegistration
+ genEnumDecls("gen-enum-decls", "Generate enum utility declarations",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ return emitEnumDecls(records, os);
+ });
+
+// Registers the enum utility generator to mlir-tblgen.
+static mlir::GenRegistration
+ genEnumDefs("gen-enum-defs", "Generate enum utility definitions",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ return emitEnumDefs(records, os);
+ });
+set(LLVM_TARGET_DEFINITIONS enums.td)
+mlir_tablegen(EnumsGenTest.h.inc -gen-enum-decls)
+mlir_tablegen(EnumsGenTest.cpp.inc -gen-enum-defs)
+add_public_tablegen_target(MLIRTableGenEnumsIncGen)
+
add_mlir_unittest(MLIRTableGenTests
+ EnumsGenTest.cpp
FormatTest.cpp
)
+
+add_dependencies(MLIRTableGenTests MLIRTableGenEnumsIncGen)
+
target_link_libraries(MLIRTableGenTests
PRIVATE LLVMMLIRTableGen)
--- /dev/null
+//===- EnumsGenTest.cpp - TableGen EnumsGen Tests -------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "gmock/gmock.h"
+#include <type_traits>
+
+// Pull in generated enum utility declarations
+#include "EnumsGenTest.h.inc"
+// And definitions
+#include "EnumsGenTest.cpp.inc"
+
+using ::testing::StrEq;
+
+// Test namespaces and enum class/utility names
+using Outer::Inner::ConvertToEnum;
+using Outer::Inner::ConvertToString;
+using Outer::Inner::MyEnum;
+
+TEST(EnumsGenTest, GeneratedEnumDefinition) {
+ EXPECT_EQ(0u, static_cast<uint64_t>(MyEnum::CaseA));
+ EXPECT_EQ(10u, static_cast<uint64_t>(MyEnum::CaseB));
+}
+
+TEST(EnumsGenTest, GeneratedDenseMapInfo) {
+ llvm::DenseMap<MyEnum, std::string> myMap;
+
+ myMap[MyEnum::CaseA] = "zero";
+ myMap[MyEnum::CaseB] = "ten";
+
+ EXPECT_THAT(myMap[MyEnum::CaseA], StrEq("zero"));
+ EXPECT_THAT(myMap[MyEnum::CaseB], StrEq("ten"));
+}
+
+TEST(EnumsGenTest, GeneratedSymbolToStringFn) {
+ EXPECT_THAT(ConvertToString(MyEnum::CaseA), StrEq("CaseA"));
+ EXPECT_THAT(ConvertToString(MyEnum::CaseB), StrEq("CaseB"));
+}
+
+TEST(EnumsGenTest, GeneratedStringToSymbolFn) {
+ EXPECT_EQ(llvm::Optional<MyEnum>(MyEnum::CaseA), ConvertToEnum("CaseA"));
+ EXPECT_EQ(llvm::Optional<MyEnum>(MyEnum::CaseB), ConvertToEnum("CaseB"));
+ EXPECT_EQ(llvm::None, ConvertToEnum("X"));
+}
+
+TEST(EnumsGenTest, GeneratedUnderlyingType) {
+ bool v =
+ std::is_same<uint64_t, std::underlying_type<Uint64Enum>::type>::value;
+ EXPECT_TRUE(v);
+}
--- /dev/null
+//===-- enums.td - EnumsGen test definition file -----------*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+include "mlir/IR/OpBase.td"
+
+def CaseA: EnumAttrCase<"CaseA", 0>;
+def CaseB: EnumAttrCase<"CaseB", 10>;
+
+def MyEnum: EnumAttr<"MyEnum", "A test enum", [CaseA, CaseB]> {
+ let cppNamespace = "Outer::Inner";
+ let stringToSymbolFnName = "ConvertToEnum";
+ let symbolToStringFnName = "ConvertToString";
+}
+
+def Uint64Enum : EnumAttr<"Uint64Enum", "A test enum", [CaseA, CaseB]> {
+ let underlyingType = "uint64_t";
+}