OpBuilder<(ins "ICmpPredicate":$predicate, "Value":$lhs, "Value":$rhs),
[{
build($_builder, $_state, IntegerType::get(lhs.getType().getContext(), 1),
- $_builder.getI64IntegerAttr(static_cast<int64_t>(predicate)), lhs, rhs);
+ predicate, lhs, rhs);
}]>];
let parser = [{ return parseCmpOp<ICmpPredicate>(parser, result); }];
let printer = [{ printICmpOp(p, *this); }];
let llvmBuilder = [{
$res = builder.CreateFCmp(getLLVMCmpPredicate($predicate), $lhs, $rhs);
}];
- let builders = [
- OpBuilder<(ins "FCmpPredicate":$predicate, "Value":$lhs, "Value":$rhs,
- CArg<"FastmathFlags", "{}">:$fmf),
- [{
- build($_builder, $_state, IntegerType::get(lhs.getType().getContext(), 1),
- $_builder.getI64IntegerAttr(static_cast<int64_t>(predicate)), lhs, rhs,
- ::mlir::LLVM::FMFAttr::get($_builder.getContext(), fmf));
- }]>];
let parser = [{ return parseCmpOp<FCmpPredicate>(parser, result); }];
let printer = [{ printFCmpOp(p, *this); }];
}
#ifndef MLIR_DIALECT_SPIRV_IR_SPIRVENUMS_H_
#define MLIR_DIALECT_SPIRV_IR_SPIRVENUMS_H_
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/StringRef.h"
let builders = [
OpBuilder<(ins "Value":$basePtr,
- CArg<"IntegerAttr", "{}">:$memory_access,
+ CArg<"MemoryAccessAttr", "{}">:$memory_access,
CArg<"IntegerAttr", "{}">:$alignment)>
];
}
COMBINING_KIND_MAX, COMBINING_KIND_AND, COMBINING_KIND_OR,
COMBINING_KIND_XOR]> {
let cppNamespace = "::mlir::vector";
+ let genSpecializedAttr = 0;
}
def Vector_CombiningKindAttr : DialectAttr<
}
// Additional information for an enum attribute.
-class EnumAttrInfo<string name, list<EnumAttrCaseInfo> cases> {
+class EnumAttrInfo<
+ string name, list<EnumAttrCaseInfo> cases, Attr baseClass> :
+ Attr<baseClass.predicate, baseClass.summary> {
// The C++ enum class name
string className = name;
// static constexpr unsigned <fn-name>();
// ```
string maxEnumValFnName = "getMaxEnumValFor" # name;
+
+ // Generate specialized Attribute class
+ bit genSpecializedAttr = 1;
+ // The underlying Attribute class, which holds the enum value
+ Attr baseAttrClass = baseClass;
+ // The name of specialized Enum Attribute class
+ string specializedAttrClassName = name # Attr;
+
+ // Override Attr class fields for specialized class
+ let predicate = !if(genSpecializedAttr,
+ CPred<"$_self.isa<" # cppNamespace # "::" # specializedAttrClassName # ">()">,
+ baseAttrClass.predicate);
+ let storageType = !if(genSpecializedAttr,
+ cppNamespace # "::" # specializedAttrClassName,
+ baseAttrClass.storageType);
+ let returnType = !if(genSpecializedAttr,
+ cppNamespace # "::" # className,
+ baseAttrClass.returnType);
+ let constBuilderCall = !if(genSpecializedAttr,
+ cppNamespace # "::" # specializedAttrClassName # "::get($_builder.getContext(), $0)",
+ baseAttrClass.constBuilderCall);
+ let valueType = baseAttrClass.valueType;
}
// An enum attribute backed by StringAttr.
// Op attributes of this kind are stored as StringAttr. Extra verification will
// be generated on the string though: only the symbols of the allowed cases are
// permitted as the string value.
-class StrEnumAttr<string name, string summary, list<StrEnumAttrCase> cases>
- : EnumAttrInfo<name, cases>,
+class StrEnumAttr<string name, string summary, list<StrEnumAttrCase> cases> :
+ EnumAttrInfo<name, cases,
StringBasedAttr<
And<[StrAttr.predicate, Or<!foreach(case, cases, case.predicate)>]>,
!if(!empty(summary), "allowed string cases: " #
!interleave(!foreach(case, cases, "'" # case.symbol # "'"), ", "),
- summary)>;
+ summary)>> {
+ // Disable specialized Attribute class for `StringAttr` backend by default.
+ let genSpecializedAttr = 0;
+}
// An enum attribute backed by IntegerAttr.
//
// Op attributes of this kind are stored as IntegerAttr. Extra verification will
// be generated on the integer though: only the values of the allowed cases are
// permitted as the integer value.
-class IntEnumAttr<I intType, string name, string summary,
- list<IntEnumAttrCaseBase> cases> :
- EnumAttrInfo<name, cases>,
- SignlessIntegerAttrBase<intType,
- !if(!empty(summary), "allowed " # intType.summary # " cases: " #
- !interleave(!foreach(case, cases, case.value), ", "), summary)> {
+class IntEnumAttrBase<I intType, list<IntEnumAttrCaseBase> cases, string summary> :
+ SignlessIntegerAttrBase<intType, summary> {
let predicate = And<[
- SignlessIntegerAttrBase<intType, "">.predicate,
+ SignlessIntegerAttrBase<intType, summary>.predicate,
Or<!foreach(case, cases, case.predicate)>]>;
}
-class I32EnumAttr<string name, string summary,
- list<I32EnumAttrCase> cases> :
+class IntEnumAttr<I intType, string name, string summary,
+ list<IntEnumAttrCaseBase> cases> :
+ EnumAttrInfo<name, cases,
+ IntEnumAttrBase<intType, cases,
+ !if(!empty(summary), "allowed " # intType.summary # " cases: " #
+ !interleave(!foreach(case, cases, case.value), ", "),
+ summary)>>;
+
+class I32EnumAttr<string name, string summary, list<I32EnumAttrCase> cases> :
IntEnumAttr<I32, name, summary, cases> {
- let returnType = cppNamespace # "::" # name;
let underlyingType = "uint32_t";
- let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())";
- let constBuilderCall =
- "$_builder.getI32IntegerAttr(static_cast<int32_t>($0))";
}
-class I64EnumAttr<string name, string summary,
- list<I64EnumAttrCase> cases> :
+class I64EnumAttr<string name, string summary, list<I64EnumAttrCase> cases> :
IntEnumAttr<I64, name, summary, cases> {
- let returnType = cppNamespace # "::" # name;
let underlyingType = "uint64_t";
- let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())";
- let constBuilderCall =
- "$_builder.getI64IntegerAttr(static_cast<int64_t>($0))";
}
// A bit enum stored with 32-bit IntegerAttr.
// be generated on the integer to make sure only allowed bit are set. Besides,
// helper methods are generated to parse a string separated with a specified
// delimiter to a symbol and vice versa.
-class BitEnumAttr<string name, string summary,
- list<BitEnumAttrCase> cases> :
- EnumAttrInfo<name, cases>, SignlessIntegerAttrBase<I32, summary> {
+class BitEnumAttrBase<list<BitEnumAttrCase> cases, string summary> :
+ SignlessIntegerAttrBase<I32, summary> {
let predicate = And<[
I32Attr.predicate,
// Make sure we don't have unknown bit set.
# !interleave(!foreach(case, cases, case.value # "u"), "|") #
")))">
]>;
+}
- let returnType = cppNamespace # "::" # name;
+class BitEnumAttr<string name, string summary, list<BitEnumAttrCase> cases> :
+ EnumAttrInfo<name, cases, BitEnumAttrBase<cases, summary>> {
let underlyingType = "uint32_t";
- let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())";
- let constBuilderCall =
- "$_builder.getI32IntegerAttr(static_cast<int32_t>($0))";
// We need to return a string because we may concatenate symbols for multiple
// bits together.
// Returns all allowed cases for this enum attribute.
std::vector<EnumAttrCase> getAllCases() const;
+
+ bool genSpecializedAttr() const;
+ llvm::Record *getBaseAttrClass() const;
+ StringRef getSpecializedAttrClassName() const;
};
class StructFieldAttr {
// header to merge.
scf::ForOpAdaptor forOperands(operands);
auto loc = forOp.getLoc();
- auto loopControl = rewriter.getI32IntegerAttr(
- static_cast<uint32_t>(spirv::LoopControl::None));
- auto loopOp = rewriter.create<spirv::LoopOp>(loc, loopControl);
+ auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
loopOp.addEntryAndMergeBlock();
OpBuilder::InsertionGuard guard(rewriter);
scf::IfOpAdaptor ifOperands(operands);
auto loc = ifOp.getLoc();
- // Create `spv.mlir.selection` operation, selection header block and merge
- // block.
- auto selectionControl = rewriter.getI32IntegerAttr(
- static_cast<uint32_t>(spirv::SelectionControl::None));
- auto selectionOp = rewriter.create<spirv::SelectionOp>(loc, selectionControl);
+ // Create `spv.selection` operation, selection header block and merge block.
+ auto selectionOp =
+ rewriter.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
auto *mergeBlock =
rewriter.createBlock(&selectionOp.body(), selectionOp.body().end());
rewriter.create<spirv::MergeOp>(loc);
return failure();
rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
- operation, dstType,
- rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate)),
- operation.operand1(), operation.operand2(),
- LLVM::FMFAttr::get(operation.getContext(), {}));
+ operation, dstType, predicate, operation.operand1(),
+ operation.operand2());
return success();
}
};
return failure();
rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
- operation, dstType,
- rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate)),
- operation.operand1(), operation.operand2());
+ operation, dstType, predicate, operation.operand1(),
+ operation.operand2());
return success();
}
};
rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
cmpiOp, typeConverter->convertType(cmpiOp.getResult().getType()),
- rewriter.getI64IntegerAttr(static_cast<int64_t>(
- convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()))),
+ convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()),
transformed.lhs(), transformed.rhs());
return success();
ConversionPatternRewriter &rewriter) const override {
CmpFOpAdaptor transformed(operands);
- auto fmf = LLVM::FMFAttr::get(cmpfOp.getContext(), {});
rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
cmpfOp, typeConverter->convertType(cmpfOp.getResult().getType()),
- rewriter.getI64IntegerAttr(static_cast<int64_t>(
- convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()))),
- transformed.lhs(), transformed.rhs(), fmf);
+ convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()),
+ transformed.lhs(), transformed.rhs());
return success();
}
srcBits, dstBits, rewriter);
Value spvLoadOp = rewriter.create<spirv::LoadOp>(
loc, dstType, adjustedPtr,
- loadOp->getAttrOfType<IntegerAttr>(
+ loadOp->getAttrOfType<spirv::MemoryAccessAttr>(
spirv::attributeName<spirv::MemoryAccess>()),
loadOp->getAttrOfType<IntegerAttr>("alignment"));
MLIRContext *context = map.getContext();
OpBuilder builder(context);
return ParallelLoopDimMapping::get(
- builder.getI64IntegerAttr(static_cast<int32_t>(processor)),
+ ProcessorAttr::get(builder.getContext(), processor),
AffineMapAttr::get(map), AffineMapAttr::get(bound), context);
}
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
+#include "mlir/IR/BuiltinTypes.h"
+
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
spirv::FuncOp function,
ArrayRef<Attribute> interfaceVars) {
build(builder, state,
- builder.getI32IntegerAttr(static_cast<int32_t>(executionModel)),
+ spirv::ExecutionModelAttr::get(builder.getContext(), executionModel),
builder.getSymbolRefAttr(function),
builder.getArrayAttr(interfaceVars));
}
spirv::ExecutionMode executionMode,
ArrayRef<int32_t> params) {
build(builder, state, builder.getSymbolRefAttr(function),
- builder.getI32IntegerAttr(static_cast<int32_t>(executionMode)),
+ spirv::ExecutionModeAttr::get(builder.getContext(), executionMode),
builder.getI32ArrayAttr(params));
}
//===----------------------------------------------------------------------===//
void spirv::LoadOp::build(OpBuilder &builder, OperationState &state,
- Value basePtr, IntegerAttr memory_access,
+ Value basePtr, MemoryAccessAttr memoryAccess,
IntegerAttr alignment) {
auto ptrType = basePtr.getType().cast<spirv::PointerType>();
- build(builder, state, ptrType.getPointeeType(), basePtr, memory_access,
+ build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess,
alignment);
}
spirv::SelectionOp spirv::SelectionOp::createIfThen(
Location loc, Value condition,
function_ref<void(OpBuilder &builder)> thenBody, OpBuilder &builder) {
- auto selectionControl = builder.getI32IntegerAttr(
- static_cast<uint32_t>(spirv::SelectionControl::None));
- auto selectionOp = builder.create<spirv::SelectionOp>(loc, selectionControl);
+ auto selectionOp =
+ builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
selectionOp.addMergeBlock();
Block *mergeBlock = selectionOp.getMergeBlock();
return cases;
}
+bool EnumAttr::genSpecializedAttr() const {
+ return def->getValueAsBit("genSpecializedAttr");
+}
+
+llvm::Record *EnumAttr::getBaseAttrClass() const {
+ return def->getValueAsDef("baseAttrClass");
+}
+
+StringRef EnumAttr::getSpecializedAttrClassName() const {
+ return def->getValueAsString("specializedAttrClassName");
+}
+
StructFieldAttr::StructFieldAttr(const llvm::Record *record) : def(record) {
assert(def->isSubClassOf("StructFieldAttr") &&
"must be subclass of TableGen 'StructFieldAttr' class");
return emitError(unknownLoc,
"missing Execution Model specification in OpEntryPoint");
}
- auto execModel = opBuilder.getI32IntegerAttr(words[wordIndex++]);
+ auto execModel = spirv::ExecutionModelAttr::get(
+ context, static_cast<spirv::ExecutionModel>(words[wordIndex++]));
if (wordIndex >= words.size()) {
return emitError(unknownLoc, "missing <id> in OpEntryPoint");
}
if (wordIndex >= words.size()) {
return emitError(unknownLoc, "missing Execution Mode in OpExecutionMode");
}
- auto execMode = opBuilder.getI32IntegerAttr(words[wordIndex++]);
+ auto execMode = spirv::ExecutionModeAttr::get(
+ context, static_cast<spirv::ExecutionMode>(words[wordIndex++]));
// Get the values
SmallVector<Attribute, 4> attrListElems;
argAttrs.push_back(argAttr);
}
- opBuilder.create<spirv::ControlBarrierOp>(unknownLoc, argAttrs[0],
- argAttrs[1], argAttrs[2]);
+ opBuilder.create<spirv::ControlBarrierOp>(
+ unknownLoc, argAttrs[0].cast<spirv::ScopeAttr>(),
+ argAttrs[1].cast<spirv::ScopeAttr>(),
+ argAttrs[2].cast<spirv::MemorySemanticsAttr>());
+
return success();
}
argAttrs.push_back(argAttr);
}
- opBuilder.create<spirv::MemoryBarrierOp>(unknownLoc, argAttrs[0],
- argAttrs[1]);
+ opBuilder.create<spirv::MemoryBarrierOp>(
+ unknownLoc, argAttrs[0].cast<spirv::ScopeAttr>(),
+ argAttrs[1].cast<spirv::MemorySemanticsAttr>());
return success();
}
// merge block so that the newly created SelectionOp will be inserted there.
OpBuilder builder(&mergeBlock->front());
- auto control = builder.getI32IntegerAttr(selectionControl);
+ auto control = static_cast<spirv::SelectionControl>(selectionControl);
auto selectionOp = builder.create<spirv::SelectionOp>(location, control);
selectionOp.addMergeBlock();
// merge block so that the newly created LoopOp will be inserted there.
OpBuilder builder(&mergeBlock->front());
- auto control = builder.getI32IntegerAttr(loopControl);
+ auto control = static_cast<spirv::LoopControl>(loopControl);
auto loopOp = builder.create<spirv::LoopOp>(location, control);
loopOp.addEntryAndMergeBlock();
}
// Test using multi-result op as a whole
-def : Pat<(ThreeResultOp MultiResultOpKind1),
- (AnotherThreeResultOp MultiResultOpKind1)>;
+def : Pat<(ThreeResultOp MultiResultOpKind1:$kind),
+ (AnotherThreeResultOp $kind)>;
// Test using multi-result op as a whole for partial replacement
-def : Pattern<(ThreeResultOp MultiResultOpKind2),
- [(TwoResultOp MultiResultOpKind2),
- (OneResultOp1 MultiResultOpKind2)]>;
-def : Pattern<(ThreeResultOp MultiResultOpKind3),
- [(OneResultOp2 MultiResultOpKind3),
- (AnotherTwoResultOp MultiResultOpKind3)]>;
+def : Pattern<(ThreeResultOp MultiResultOpKind2:$kind),
+ [(TwoResultOp $kind),
+ (OneResultOp1 $kind)]>;
+def : Pattern<(ThreeResultOp MultiResultOpKind3:$kind),
+ [(OneResultOp2 $kind),
+ (AnotherTwoResultOp $kind)]>;
// Test using results separately in a multi-result op
-def : Pattern<(ThreeResultOp MultiResultOpKind4),
- [(TwoResultOp:$res1__0 MultiResultOpKind4),
- (OneResultOp1 MultiResultOpKind4),
- (TwoResultOp:$res2__1 MultiResultOpKind4)]>;
+def : Pattern<(ThreeResultOp MultiResultOpKind4:$kind),
+ [(TwoResultOp:$res1__0 $kind),
+ (OneResultOp1 $kind),
+ (TwoResultOp:$res2__1 $kind)]>;
// Test referencing a single value in the value pack
// This rule only matches TwoResultOp if its second result has no use.
-def : Pattern<(TwoResultOp:$res MultiResultOpKind5),
- [(OneResultOp2 MultiResultOpKind5),
- (OneResultOp1 MultiResultOpKind5)],
+def : Pattern<(TwoResultOp:$res MultiResultOpKind5:$kind),
+ [(OneResultOp2 $kind),
+ (OneResultOp1 $kind)],
[(HasNoUseOf:$res__1)]>;
// Test using auxiliary ops for replacing multi-result op
def : Pattern<
- (ThreeResultOp MultiResultOpKind6), [
+ (ThreeResultOp MultiResultOpKind6:$kind), [
// Auxiliary op generated to help building the final result but not
// directly used to replace the source op's results.
- (TwoResultOp:$interm MultiResultOpKind6),
+ (TwoResultOp:$interm $kind),
(OneResultOp3 $interm__1),
- (AnotherTwoResultOp MultiResultOpKind6)
+ (AnotherTwoResultOp $kind)
]>;
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/Attribute.h"
+#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
using llvm::formatv;
using llvm::isDigit;
+using llvm::PrintFatalError;
using llvm::raw_ostream;
using llvm::Record;
using llvm::RecordKeeper;
using llvm::StringRef;
+using mlir::tblgen::Attribute;
using mlir::tblgen::EnumAttr;
using mlir::tblgen::EnumAttrCase;
+using mlir::tblgen::FmtContext;
+using mlir::tblgen::tgfmt;
static std::string makeIdentifier(StringRef str) {
if (!str.empty() && isDigit(static_cast<unsigned char>(str.front()))) {
<< "}\n\n";
}
+static void emitSpecializedAttrDef(const Record &enumDef, raw_ostream &os) {
+ EnumAttr enumAttr(enumDef);
+ StringRef enumName = enumAttr.getEnumClassName();
+ StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
+ StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
+ StringRef attrClassName = enumAttr.getSpecializedAttrClassName();
+ llvm::Record *baseAttrDef = enumAttr.getBaseAttrClass();
+ Attribute baseAttr(baseAttrDef);
+
+ // Emit classof method
+
+ os << formatv("bool {0}::classof(::mlir::Attribute attr) {{\n",
+ attrClassName);
+
+ mlir::tblgen::Pred baseAttrPred = baseAttr.getPredicate();
+ if (baseAttrPred.isNull())
+ PrintFatalError("ERROR: baseAttrClass for EnumAttr has no Predicate\n");
+
+ std::string condition = baseAttrPred.getCondition();
+ FmtContext verifyCtx;
+ verifyCtx.withSelf("attr");
+ os << tgfmt(" return $0;\n", /*ctx=*/nullptr, tgfmt(condition, &verifyCtx));
+
+ os << "}\n";
+
+ // Emit get method
+
+ os << formatv("{0} {0}::get(::mlir::MLIRContext *context, {1} val) {{\n",
+ attrClassName, enumName);
+
+ if (enumAttr.isSubClassOf("StrEnumAttr")) {
+ os << formatv(" ::mlir::StringAttr baseAttr = "
+ "::mlir::StringAttr::get(context, {0}(val));\n",
+ symToStrFnName);
+ } else {
+ StringRef underlyingType = enumAttr.getUnderlyingType();
+
+ // Assuming that it is IntegerAttr constraint
+ int64_t bitwidth = 64;
+ if (baseAttrDef->getValue("valueType")) {
+ auto *valueTypeDef = baseAttrDef->getValueAsDef("valueType");
+ if (valueTypeDef->getValue("bitwidth"))
+ bitwidth = valueTypeDef->getValueAsInt("bitwidth");
+ }
+
+ os << formatv(" ::mlir::IntegerType intType = "
+ "::mlir::IntegerType::get(context, {0});\n",
+ bitwidth);
+ os << formatv(" ::mlir::IntegerAttr baseAttr = "
+ "::mlir::IntegerAttr::get(intType, static_cast<{0}>(val));\n",
+ underlyingType);
+ }
+ os << formatv(" return baseAttr.cast<{0}>();\n", attrClassName);
+
+ os << "}\n";
+
+ // Emit getValue method
+
+ os << formatv("{0} {1}::getValue() const {{\n", enumName, attrClassName);
+
+ if (enumAttr.isSubClassOf("StrEnumAttr")) {
+ os << formatv(" const auto res = {0}(::mlir::StringAttr::getValue());\n",
+ strToSymFnName);
+ os << " return res.getValue();\n";
+ } else {
+ os << formatv(" return static_cast<{0}>(::mlir::IntegerAttr::getInt());\n",
+ enumName);
+ }
+
+ os << "}\n";
+}
+
static void emitUnderlyingToSymFnForBitEnum(const Record &enumDef,
raw_ostream &os) {
EnumAttr enumAttr(enumDef);
)";
os << formatv(symbolizeEnumStr, enumName, strToSymFnName);
+ const char *const attrClassDecl = R"(
+class {1} : public ::mlir::{2} {
+public:
+ using ValueType = {0};
+ using ::mlir::{2}::{2};
+ static bool classof(::mlir::Attribute attr);
+ static {1} get(::mlir::MLIRContext *context, {0} val);
+ {0} getValue() const;
+};
+)";
+ if (enumAttr.genSpecializedAttr()) {
+ StringRef attrClassName = enumAttr.getSpecializedAttrClassName();
+ StringRef baseAttrClassName =
+ enumAttr.isSubClassOf("StrEnumAttr") ? "StringAttr" : "IntegerAttr";
+ os << formatv(attrClassDecl, enumName, attrClassName, baseAttrClassName);
+ }
+
for (auto ns : llvm::reverse(namespaces))
os << "} // namespace " << ns << "\n";
emitUnderlyingToSymFnForIntEnum(enumDef, os);
}
+ if (enumAttr.genSpecializedAttr())
+ emitSpecializedAttrDef(enumDef, os);
+
for (auto ns : llvm::reverse(namespaces))
os << "} // namespace " << ns << "\n";
os << "\n";
//
//===----------------------------------------------------------------------===//
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
#include "mlir/Support/LLVM.h"
+
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSwitch.h"
+
#include "gmock/gmock.h"
+
#include <type_traits>
/// Pull in generated enum utility declarations and definitions.
#include "EnumsGenTest.h.inc"
+
#include "EnumsGenTest.cpp.inc"
/// Test namespaces and enum class/utility names.
using Outer::Inner::ConvertToEnum;
using Outer::Inner::ConvertToString;
using Outer::Inner::StrEnum;
+using Outer::Inner::StrEnumAttr;
TEST(EnumsGenTest, GeneratedStrEnumDefinition) {
EXPECT_EQ(0u, static_cast<uint64_t>(StrEnum::CaseA));
auto none = symbolizePrettyIntEnum("Case1");
EXPECT_FALSE(none);
}
+
+TEST(EnumsGenTest, GeneratedIntAttributeClass) {
+ mlir::MLIRContext ctx;
+ I32Enum rawVal = I32Enum::Case5;
+
+ I32EnumAttr enumAttr = I32EnumAttr::get(&ctx, rawVal);
+ EXPECT_NE(enumAttr, nullptr);
+ EXPECT_EQ(enumAttr.getValue(), rawVal);
+
+ mlir::Type intType = mlir::IntegerType::get(&ctx, 32);
+ mlir::Attribute intAttr = mlir::IntegerAttr::get(intType, 5);
+ EXPECT_TRUE(intAttr.isa<I32EnumAttr>());
+ EXPECT_EQ(intAttr, enumAttr);
+}
+
+TEST(EnumsGenTest, GeneratedStringAttributeClass) {
+ mlir::MLIRContext ctx;
+ StrEnum rawVal = StrEnum::CaseA;
+
+ StrEnumAttr enumAttr = StrEnumAttr::get(&ctx, rawVal);
+ EXPECT_NE(enumAttr, nullptr);
+ EXPECT_EQ(enumAttr.getValue(), rawVal);
+
+ mlir::Attribute strAttr = mlir::StringAttr::get(&ctx, "CaseA");
+ EXPECT_TRUE(strAttr.isa<StrEnumAttr>());
+ EXPECT_EQ(strAttr, enumAttr);
+}
+
+TEST(EnumsGenTest, GeneratedBitAttributeClass) {
+ mlir::MLIRContext ctx;
+
+ mlir::Type intType = mlir::IntegerType::get(&ctx, 32);
+ mlir::Attribute intAttr = mlir::IntegerAttr::get(
+ intType,
+ static_cast<uint32_t>(BitEnumWithNone::Bit1 | BitEnumWithNone::Bit3));
+ EXPECT_TRUE(intAttr.isa<BitEnumWithNoneAttr>());
+ EXPECT_TRUE(intAttr.isa<BitEnumWithoutNoneAttr>());
+}
let cppNamespace = "Outer::Inner";
let stringToSymbolFnName = "ConvertToEnum";
let symbolToStringFnName = "ConvertToString";
+ let genSpecializedAttr = 1;
}
def Case5: I32EnumAttrCase<"Case5", 5>;