[mlir][ODS] Support specialized Attribute class for Enums
authorVladislav Vinogradov <vlad.vinogradov@intel.com>
Sat, 27 Feb 2021 12:21:00 +0000 (15:21 +0300)
committerVladislav Vinogradov <vlad.vinogradov@intel.com>
Wed, 17 Mar 2021 13:44:24 +0000 (16:44 +0300)
Add a feature to `EnumAttr` definition to generate
specialized Attribute class for the particular enumeration.

This class will inherit `StringAttr` or `IntegerAttr` and
will override `classof` and `getValue` methods.

With this class the enumeration predicate can be checked with simple
RTTI calls (`isa`, `dyn_cast`) and it will return the typed enumeration
directly instead of raw string/integer.

Based on the following discussion:
https://llvm.discourse.group/t/rfc-add-enum-attribute-decorator-class/2252

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D97836

20 files changed:
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVEnums.h
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/TableGen/Attribute.h
mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
mlir/lib/Dialect/GPU/Transforms/ParallelLoopMapper.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVEnums.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/lib/TableGen/Attribute.cpp
mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/tools/mlir-tblgen/EnumsGen.cpp
mlir/unittests/TableGen/EnumsGenTest.cpp
mlir/unittests/TableGen/enums.td

index ddc0ed3..3623565 100644 (file)
@@ -200,7 +200,7 @@ def LLVM_ICmpOp : LLVM_Op<"icmp", [NoSideEffect]> {
     OpBuilder<(ins "ICmpPredicate":$predicate, "Value":$lhs, "Value":$rhs),
     [{
       build($_builder, $_state, IntegerType::get(lhs.getType().getContext(), 1),
-            $_builder.getI64IntegerAttr(static_cast<int64_t>(predicate)), lhs, rhs);
+            predicate, lhs, rhs);
     }]>];
   let parser = [{ return parseCmpOp<ICmpPredicate>(parser, result); }];
   let printer = [{ printICmpOp(p, *this); }];
@@ -246,14 +246,6 @@ def LLVM_FCmpOp : LLVM_Op<"fcmp", [
   let llvmBuilder = [{
     $res = builder.CreateFCmp(getLLVMCmpPredicate($predicate), $lhs, $rhs);
   }];
-  let builders = [
-    OpBuilder<(ins "FCmpPredicate":$predicate, "Value":$lhs, "Value":$rhs,
-                  CArg<"FastmathFlags", "{}">:$fmf),
-    [{
-      build($_builder, $_state, IntegerType::get(lhs.getType().getContext(), 1),
-            $_builder.getI64IntegerAttr(static_cast<int64_t>(predicate)), lhs, rhs,
-            ::mlir::LLVM::FMFAttr::get($_builder.getContext(), fmf));
-    }]>];
   let parser = [{ return parseCmpOp<FCmpPredicate>(parser, result); }];
   let printer = [{ printFCmpOp(p, *this); }];
 }
index e5838ef..ac128ac 100644 (file)
@@ -13,6 +13,7 @@
 #ifndef MLIR_DIALECT_SPIRV_IR_SPIRVENUMS_H_
 #define MLIR_DIALECT_SPIRV_IR_SPIRVENUMS_H_
 
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/DenseMapInfo.h"
 #include "llvm/ADT/StringRef.h"
index 0cffdb5..77fa63f 100644 (file)
@@ -184,7 +184,7 @@ def SPV_LoadOp : SPV_Op<"Load", []> {
 
   let builders = [
     OpBuilder<(ins "Value":$basePtr,
-      CArg<"IntegerAttr", "{}">:$memory_access,
+      CArg<"MemoryAccessAttr", "{}">:$memory_access,
       CArg<"IntegerAttr", "{}">:$alignment)>
   ];
 }
index 69fb073..f26a771 100644 (file)
@@ -53,6 +53,7 @@ def CombiningKind : BitEnumAttr<
      COMBINING_KIND_MAX, COMBINING_KIND_AND, COMBINING_KIND_OR,
      COMBINING_KIND_XOR]> {
   let cppNamespace = "::mlir::vector";
+  let genSpecializedAttr = 0;
 }
 
 def Vector_CombiningKindAttr : DialectAttr<
index 268056d..bdae05f 100644 (file)
@@ -1142,7 +1142,9 @@ class BitEnumAttrCase<string sym, int val, string str = sym> :
 }
 
 // Additional information for an enum attribute.
-class EnumAttrInfo<string name, list<EnumAttrCaseInfo> cases> {
+class EnumAttrInfo<
+    string name, list<EnumAttrCaseInfo> cases, Attr baseClass> :
+      Attr<baseClass.predicate, baseClass.summary> {
   // The C++ enum class name
   string className = name;
 
@@ -1188,6 +1190,28 @@ class EnumAttrInfo<string name, list<EnumAttrCaseInfo> cases> {
   // static constexpr unsigned <fn-name>();
   // ```
   string maxEnumValFnName = "getMaxEnumValFor" # name;
+
+  // Generate specialized Attribute class
+  bit genSpecializedAttr = 1;
+  // The underlying Attribute class, which holds the enum value
+  Attr baseAttrClass = baseClass;
+  // The name of specialized Enum Attribute class
+  string specializedAttrClassName = name # Attr;
+
+  // Override Attr class fields for specialized class
+  let predicate = !if(genSpecializedAttr,
+    CPred<"$_self.isa<" # cppNamespace # "::" # specializedAttrClassName # ">()">,
+    baseAttrClass.predicate);
+  let storageType = !if(genSpecializedAttr,
+    cppNamespace # "::" # specializedAttrClassName,
+    baseAttrClass.storageType);
+  let returnType = !if(genSpecializedAttr,
+    cppNamespace # "::" # className,
+    baseAttrClass.returnType);
+  let constBuilderCall = !if(genSpecializedAttr,
+    cppNamespace # "::" # specializedAttrClassName # "::get($_builder.getContext(), $0)",
+    baseAttrClass.constBuilderCall);
+  let valueType = baseAttrClass.valueType;
 }
 
 // An enum attribute backed by StringAttr.
@@ -1195,47 +1219,44 @@ class EnumAttrInfo<string name, list<EnumAttrCaseInfo> cases> {
 // Op attributes of this kind are stored as StringAttr. Extra verification will
 // be generated on the string though: only the symbols of the allowed cases are
 // permitted as the string value.
-class StrEnumAttr<string name, string summary, list<StrEnumAttrCase> cases>
-  : EnumAttrInfo<name, cases>,
+class StrEnumAttr<string name, string summary, list<StrEnumAttrCase> cases> :
+  EnumAttrInfo<name, cases,
     StringBasedAttr<
       And<[StrAttr.predicate, Or<!foreach(case, cases, case.predicate)>]>,
       !if(!empty(summary), "allowed string cases: " #
           !interleave(!foreach(case, cases, "'" # case.symbol # "'"), ", "),
-          summary)>;
+          summary)>> {
+  // Disable specialized Attribute class for `StringAttr` backend by default.
+  let genSpecializedAttr = 0;
+}
 
 // An enum attribute backed by IntegerAttr.
 //
 // Op attributes of this kind are stored as IntegerAttr. Extra verification will
 // be generated on the integer though: only the values of the allowed cases are
 // permitted as the integer value.
-class IntEnumAttr<I intType, string name, string summary,
-                  list<IntEnumAttrCaseBase> cases> :
-    EnumAttrInfo<name, cases>,
-    SignlessIntegerAttrBase<intType,
-      !if(!empty(summary), "allowed " # intType.summary # " cases: " #
-          !interleave(!foreach(case, cases, case.value), ", "), summary)> {
+class IntEnumAttrBase<I intType, list<IntEnumAttrCaseBase> cases, string summary> :
+    SignlessIntegerAttrBase<intType, summary> {
   let predicate = And<[
-    SignlessIntegerAttrBase<intType, "">.predicate,
+    SignlessIntegerAttrBase<intType, summary>.predicate,
     Or<!foreach(case, cases, case.predicate)>]>;
 }
 
-class I32EnumAttr<string name, string summary,
-                  list<I32EnumAttrCase> cases> :
+class IntEnumAttr<I intType, string name, string summary,
+                  list<IntEnumAttrCaseBase> cases> :
+  EnumAttrInfo<name, cases,
+    IntEnumAttrBase<intType, cases,
+      !if(!empty(summary), "allowed " # intType.summary # " cases: " #
+          !interleave(!foreach(case, cases, case.value), ", "),
+          summary)>>;
+
+class I32EnumAttr<string name, string summary, list<I32EnumAttrCase> cases> :
     IntEnumAttr<I32, name, summary, cases> {
-  let returnType = cppNamespace # "::" # name;
   let underlyingType = "uint32_t";
-  let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())";
-  let constBuilderCall =
-          "$_builder.getI32IntegerAttr(static_cast<int32_t>($0))";
 }
-class I64EnumAttr<string name, string summary,
-                  list<I64EnumAttrCase> cases> :
+class I64EnumAttr<string name, string summary, list<I64EnumAttrCase> cases> :
     IntEnumAttr<I64, name, summary, cases> {
-  let returnType = cppNamespace # "::" # name;
   let underlyingType = "uint64_t";
-  let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())";
-  let constBuilderCall =
-          "$_builder.getI64IntegerAttr(static_cast<int64_t>($0))";
 }
 
 // A bit enum stored with 32-bit IntegerAttr.
@@ -1244,9 +1265,8 @@ class I64EnumAttr<string name, string summary,
 // be generated on the integer to make sure only allowed bit are set. Besides,
 // helper methods are generated to parse a string separated with a specified
 // delimiter to a symbol and vice versa.
-class BitEnumAttr<string name, string summary,
-                  list<BitEnumAttrCase> cases> :
-    EnumAttrInfo<name, cases>, SignlessIntegerAttrBase<I32, summary> {
+class BitEnumAttrBase<list<BitEnumAttrCase> cases, string summary> :
+    SignlessIntegerAttrBase<I32, summary> {
   let predicate = And<[
     I32Attr.predicate,
     // Make sure we don't have unknown bit set.
@@ -1254,12 +1274,11 @@ class BitEnumAttr<string name, string summary,
           # !interleave(!foreach(case, cases, case.value # "u"), "|") #
           ")))">
   ]>;
+}
 
-  let returnType = cppNamespace # "::" # name;
+class BitEnumAttr<string name, string summary, list<BitEnumAttrCase> cases> :
+    EnumAttrInfo<name, cases, BitEnumAttrBase<cases, summary>> {
   let underlyingType = "uint32_t";
-  let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())";
-  let constBuilderCall =
-          "$_builder.getI32IntegerAttr(static_cast<int32_t>($0))";
 
   // We need to return a string because we may concatenate symbols for multiple
   // bits together.
index dc6c969..a8292a9 100644 (file)
@@ -202,6 +202,10 @@ public:
 
   // Returns all allowed cases for this enum attribute.
   std::vector<EnumAttrCase> getAllCases() const;
+
+  bool genSpecializedAttr() const;
+  llvm::Record *getBaseAttrClass() const;
+  StringRef getSpecializedAttrClassName() const;
 };
 
 class StructFieldAttr {
index 8e7540f..19837fe 100644 (file)
@@ -155,9 +155,7 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
   // header to merge.
   scf::ForOpAdaptor forOperands(operands);
   auto loc = forOp.getLoc();
-  auto loopControl = rewriter.getI32IntegerAttr(
-      static_cast<uint32_t>(spirv::LoopControl::None));
-  auto loopOp = rewriter.create<spirv::LoopOp>(loc, loopControl);
+  auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
   loopOp.addEntryAndMergeBlock();
 
   OpBuilder::InsertionGuard guard(rewriter);
@@ -238,11 +236,9 @@ IfOpConversion::matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands,
   scf::IfOpAdaptor ifOperands(operands);
   auto loc = ifOp.getLoc();
 
-  // Create `spv.mlir.selection` operation, selection header block and merge
-  // block.
-  auto selectionControl = rewriter.getI32IntegerAttr(
-      static_cast<uint32_t>(spirv::SelectionControl::None));
-  auto selectionOp = rewriter.create<spirv::SelectionOp>(loc, selectionControl);
+  // Create `spv.selection` operation, selection header block and merge block.
+  auto selectionOp =
+      rewriter.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
   auto *mergeBlock =
       rewriter.createBlock(&selectionOp.body(), selectionOp.body().end());
   rewriter.create<spirv::MergeOp>(loc);
index 871f54b..3a139b4 100644 (file)
@@ -826,10 +826,8 @@ public:
       return failure();
 
     rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
-        operation, dstType,
-        rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate)),
-        operation.operand1(), operation.operand2(),
-        LLVM::FMFAttr::get(operation.getContext(), {}));
+        operation, dstType, predicate, operation.operand1(),
+        operation.operand2());
     return success();
   }
 };
@@ -849,9 +847,8 @@ public:
       return failure();
 
     rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
-        operation, dstType,
-        rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate)),
-        operation.operand1(), operation.operand2());
+        operation, dstType, predicate, operation.operand1(),
+        operation.operand2());
     return success();
   }
 };
index 91e520e..2490f35 100644 (file)
@@ -3069,8 +3069,7 @@ struct CmpIOpLowering : public ConvertOpToLLVMPattern<CmpIOp> {
 
     rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
         cmpiOp, typeConverter->convertType(cmpiOp.getResult().getType()),
-        rewriter.getI64IntegerAttr(static_cast<int64_t>(
-            convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()))),
+        convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()),
         transformed.lhs(), transformed.rhs());
 
     return success();
@@ -3085,12 +3084,10 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<CmpFOp> {
                   ConversionPatternRewriter &rewriter) const override {
     CmpFOpAdaptor transformed(operands);
 
-    auto fmf = LLVM::FMFAttr::get(cmpfOp.getContext(), {});
     rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
         cmpfOp, typeConverter->convertType(cmpfOp.getResult().getType()),
-        rewriter.getI64IntegerAttr(static_cast<int64_t>(
-            convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()))),
-        transformed.lhs(), transformed.rhs(), fmf);
+        convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()),
+        transformed.lhs(), transformed.rhs());
 
     return success();
   }
index ed1b72c..025029a 100644 (file)
@@ -1017,7 +1017,7 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp,
                                                    srcBits, dstBits, rewriter);
   Value spvLoadOp = rewriter.create<spirv::LoadOp>(
       loc, dstType, adjustedPtr,
-      loadOp->getAttrOfType<IntegerAttr>(
+      loadOp->getAttrOfType<spirv::MemoryAccessAttr>(
           spirv::attributeName<spirv::MemoryAccess>()),
       loadOp->getAttrOfType<IntegerAttr>("alignment"));
 
index 6ccb59a..b032169 100644 (file)
@@ -36,7 +36,7 @@ ParallelLoopDimMapping getParallelLoopDimMappingAttr(Processor processor,
   MLIRContext *context = map.getContext();
   OpBuilder builder(context);
   return ParallelLoopDimMapping::get(
-      builder.getI64IntegerAttr(static_cast<int32_t>(processor)),
+      ProcessorAttr::get(builder.getContext(), processor),
       AffineMapAttr::get(map), AffineMapAttr::get(bound), context);
 }
 
index a289d9d..a851906 100644 (file)
@@ -12,6 +12,8 @@
 
 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 
+#include "mlir/IR/BuiltinTypes.h"
+
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringRef.h"
index dd2dc3d..21bcfe4 100644 (file)
@@ -1659,7 +1659,7 @@ void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
                                 spirv::FuncOp function,
                                 ArrayRef<Attribute> interfaceVars) {
   build(builder, state,
-        builder.getI32IntegerAttr(static_cast<int32_t>(executionModel)),
+        spirv::ExecutionModelAttr::get(builder.getContext(), executionModel),
         builder.getSymbolRefAttr(function),
         builder.getArrayAttr(interfaceVars));
 }
@@ -1721,7 +1721,7 @@ void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
                                    spirv::ExecutionMode executionMode,
                                    ArrayRef<int32_t> params) {
   build(builder, state, builder.getSymbolRefAttr(function),
-        builder.getI32IntegerAttr(static_cast<int32_t>(executionMode)),
+        spirv::ExecutionModeAttr::get(builder.getContext(), executionMode),
         builder.getI32ArrayAttr(params));
 }
 
@@ -2243,10 +2243,10 @@ static LogicalResult verify(spirv::GroupNonUniformElectOp groupOp) {
 //===----------------------------------------------------------------------===//
 
 void spirv::LoadOp::build(OpBuilder &builder, OperationState &state,
-                          Value basePtr, IntegerAttr memory_access,
+                          Value basePtr, MemoryAccessAttr memoryAccess,
                           IntegerAttr alignment) {
   auto ptrType = basePtr.getType().cast<spirv::PointerType>();
-  build(builder, state, ptrType.getPointeeType(), basePtr, memory_access,
+  build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess,
         alignment);
 }
 
@@ -2784,9 +2784,8 @@ void spirv::SelectionOp::addMergeBlock() {
 spirv::SelectionOp spirv::SelectionOp::createIfThen(
     Location loc, Value condition,
     function_ref<void(OpBuilder &builder)> thenBody, OpBuilder &builder) {
-  auto selectionControl = builder.getI32IntegerAttr(
-      static_cast<uint32_t>(spirv::SelectionControl::None));
-  auto selectionOp = builder.create<spirv::SelectionOp>(loc, selectionControl);
+  auto selectionOp =
+      builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
 
   selectionOp.addMergeBlock();
   Block *mergeBlock = selectionOp.getMergeBlock();
index 99d9d8a..3b949b0 100644 (file)
@@ -231,6 +231,18 @@ std::vector<EnumAttrCase> EnumAttr::getAllCases() const {
   return cases;
 }
 
+bool EnumAttr::genSpecializedAttr() const {
+  return def->getValueAsBit("genSpecializedAttr");
+}
+
+llvm::Record *EnumAttr::getBaseAttrClass() const {
+  return def->getValueAsDef("baseAttrClass");
+}
+
+StringRef EnumAttr::getSpecializedAttrClassName() const {
+  return def->getValueAsString("specializedAttrClassName");
+}
+
 StructFieldAttr::StructFieldAttr(const llvm::Record *record) : def(record) {
   assert(def->isSubClassOf("StructFieldAttr") &&
          "must be subclass of TableGen 'StructFieldAttr' class");
index 06e7f81..6137fee 100644 (file)
@@ -331,7 +331,8 @@ Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) {
     return emitError(unknownLoc,
                      "missing Execution Model specification in OpEntryPoint");
   }
-  auto execModel = opBuilder.getI32IntegerAttr(words[wordIndex++]);
+  auto execModel = spirv::ExecutionModelAttr::get(
+      context, static_cast<spirv::ExecutionModel>(words[wordIndex++]));
   if (wordIndex >= words.size()) {
     return emitError(unknownLoc, "missing <id> in OpEntryPoint");
   }
@@ -383,7 +384,8 @@ Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) {
   if (wordIndex >= words.size()) {
     return emitError(unknownLoc, "missing Execution Mode in OpExecutionMode");
   }
-  auto execMode = opBuilder.getI32IntegerAttr(words[wordIndex++]);
+  auto execMode = spirv::ExecutionModeAttr::get(
+      context, static_cast<spirv::ExecutionMode>(words[wordIndex++]));
 
   // Get the values
   SmallVector<Attribute, 4> attrListElems;
@@ -417,8 +419,11 @@ Deserializer::processOp<spirv::ControlBarrierOp>(ArrayRef<uint32_t> operands) {
     argAttrs.push_back(argAttr);
   }
 
-  opBuilder.create<spirv::ControlBarrierOp>(unknownLoc, argAttrs[0],
-                                            argAttrs[1], argAttrs[2]);
+  opBuilder.create<spirv::ControlBarrierOp>(
+      unknownLoc, argAttrs[0].cast<spirv::ScopeAttr>(),
+      argAttrs[1].cast<spirv::ScopeAttr>(),
+      argAttrs[2].cast<spirv::MemorySemanticsAttr>());
+
   return success();
 }
 
@@ -483,8 +488,9 @@ Deserializer::processOp<spirv::MemoryBarrierOp>(ArrayRef<uint32_t> operands) {
     argAttrs.push_back(argAttr);
   }
 
-  opBuilder.create<spirv::MemoryBarrierOp>(unknownLoc, argAttrs[0],
-                                           argAttrs[1]);
+  opBuilder.create<spirv::MemoryBarrierOp>(
+      unknownLoc, argAttrs[0].cast<spirv::ScopeAttr>(),
+      argAttrs[1].cast<spirv::MemorySemanticsAttr>());
   return success();
 }
 
index 171d9b7..c54c168 100644 (file)
@@ -1640,7 +1640,7 @@ ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
   // merge block so that the newly created SelectionOp will be inserted there.
   OpBuilder builder(&mergeBlock->front());
 
-  auto control = builder.getI32IntegerAttr(selectionControl);
+  auto control = static_cast<spirv::SelectionControl>(selectionControl);
   auto selectionOp = builder.create<spirv::SelectionOp>(location, control);
   selectionOp.addMergeBlock();
 
@@ -1652,7 +1652,7 @@ spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
   // merge block so that the newly created LoopOp will be inserted there.
   OpBuilder builder(&mergeBlock->front());
 
-  auto control = builder.getI32IntegerAttr(loopControl);
+  auto control = static_cast<spirv::LoopControl>(loopControl);
   auto loopOp = builder.create<spirv::LoopOp>(location, control);
   loopOp.addEntryAndMergeBlock();
 
index 1968ebd..b8956e4 100644 (file)
@@ -1052,39 +1052,39 @@ def OneResultOp3 : TEST_Op<"one_result3"> {
 }
 
 // Test using multi-result op as a whole
-def : Pat<(ThreeResultOp MultiResultOpKind1),
-          (AnotherThreeResultOp MultiResultOpKind1)>;
+def : Pat<(ThreeResultOp MultiResultOpKind1:$kind),
+          (AnotherThreeResultOp $kind)>;
 
 // Test using multi-result op as a whole for partial replacement
-def : Pattern<(ThreeResultOp MultiResultOpKind2),
-              [(TwoResultOp MultiResultOpKind2),
-               (OneResultOp1 MultiResultOpKind2)]>;
-def : Pattern<(ThreeResultOp MultiResultOpKind3),
-              [(OneResultOp2 MultiResultOpKind3),
-               (AnotherTwoResultOp MultiResultOpKind3)]>;
+def : Pattern<(ThreeResultOp MultiResultOpKind2:$kind),
+              [(TwoResultOp $kind),
+               (OneResultOp1 $kind)]>;
+def : Pattern<(ThreeResultOp MultiResultOpKind3:$kind),
+              [(OneResultOp2 $kind),
+               (AnotherTwoResultOp $kind)]>;
 
 // Test using results separately in a multi-result op
-def : Pattern<(ThreeResultOp MultiResultOpKind4),
-              [(TwoResultOp:$res1__0 MultiResultOpKind4),
-               (OneResultOp1 MultiResultOpKind4),
-               (TwoResultOp:$res2__1 MultiResultOpKind4)]>;
+def : Pattern<(ThreeResultOp MultiResultOpKind4:$kind),
+              [(TwoResultOp:$res1__0 $kind),
+               (OneResultOp1 $kind),
+               (TwoResultOp:$res2__1 $kind)]>;
 
 // Test referencing a single value in the value pack
 // This rule only matches TwoResultOp if its second result has no use.
-def : Pattern<(TwoResultOp:$res MultiResultOpKind5),
-              [(OneResultOp2 MultiResultOpKind5),
-               (OneResultOp1 MultiResultOpKind5)],
+def : Pattern<(TwoResultOp:$res MultiResultOpKind5:$kind),
+              [(OneResultOp2 $kind),
+               (OneResultOp1 $kind)],
               [(HasNoUseOf:$res__1)]>;
 
 // Test using auxiliary ops for replacing multi-result op
 def : Pattern<
-    (ThreeResultOp MultiResultOpKind6), [
+    (ThreeResultOp MultiResultOpKind6:$kind), [
         // Auxiliary op generated to help building the final result but not
         // directly used to replace the source op's results.
-        (TwoResultOp:$interm MultiResultOpKind6),
+        (TwoResultOp:$interm $kind),
 
         (OneResultOp3 $interm__1),
-        (AnotherTwoResultOp MultiResultOpKind6)
+        (AnotherTwoResultOp $kind)
     ]>;
 
 //===----------------------------------------------------------------------===//
index e207e31..aa8841a 100644 (file)
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/TableGen/Attribute.h"
+#include "mlir/TableGen/Format.h"
 #include "mlir/TableGen/GenInfo.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"
 
 using llvm::formatv;
 using llvm::isDigit;
+using llvm::PrintFatalError;
 using llvm::raw_ostream;
 using llvm::Record;
 using llvm::RecordKeeper;
 using llvm::StringRef;
+using mlir::tblgen::Attribute;
 using mlir::tblgen::EnumAttr;
 using mlir::tblgen::EnumAttrCase;
+using mlir::tblgen::FmtContext;
+using mlir::tblgen::tgfmt;
 
 static std::string makeIdentifier(StringRef str) {
   if (!str.empty() && isDigit(static_cast<unsigned char>(str.front()))) {
@@ -303,6 +308,78 @@ static void emitUnderlyingToSymFnForIntEnum(const Record &enumDef,
      << "}\n\n";
 }
 
+static void emitSpecializedAttrDef(const Record &enumDef, raw_ostream &os) {
+  EnumAttr enumAttr(enumDef);
+  StringRef enumName = enumAttr.getEnumClassName();
+  StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
+  StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
+  StringRef attrClassName = enumAttr.getSpecializedAttrClassName();
+  llvm::Record *baseAttrDef = enumAttr.getBaseAttrClass();
+  Attribute baseAttr(baseAttrDef);
+
+  // Emit classof method
+
+  os << formatv("bool {0}::classof(::mlir::Attribute attr) {{\n",
+                attrClassName);
+
+  mlir::tblgen::Pred baseAttrPred = baseAttr.getPredicate();
+  if (baseAttrPred.isNull())
+    PrintFatalError("ERROR: baseAttrClass for EnumAttr has no Predicate\n");
+
+  std::string condition = baseAttrPred.getCondition();
+  FmtContext verifyCtx;
+  verifyCtx.withSelf("attr");
+  os << tgfmt("  return $0;\n", /*ctx=*/nullptr, tgfmt(condition, &verifyCtx));
+
+  os << "}\n";
+
+  // Emit get method
+
+  os << formatv("{0} {0}::get(::mlir::MLIRContext *context, {1} val) {{\n",
+                attrClassName, enumName);
+
+  if (enumAttr.isSubClassOf("StrEnumAttr")) {
+    os << formatv("  ::mlir::StringAttr baseAttr = "
+                  "::mlir::StringAttr::get(context, {0}(val));\n",
+                  symToStrFnName);
+  } else {
+    StringRef underlyingType = enumAttr.getUnderlyingType();
+
+    // Assuming that it is IntegerAttr constraint
+    int64_t bitwidth = 64;
+    if (baseAttrDef->getValue("valueType")) {
+      auto *valueTypeDef = baseAttrDef->getValueAsDef("valueType");
+      if (valueTypeDef->getValue("bitwidth"))
+        bitwidth = valueTypeDef->getValueAsInt("bitwidth");
+    }
+
+    os << formatv("  ::mlir::IntegerType intType = "
+                  "::mlir::IntegerType::get(context, {0});\n",
+                  bitwidth);
+    os << formatv("  ::mlir::IntegerAttr baseAttr = "
+                  "::mlir::IntegerAttr::get(intType, static_cast<{0}>(val));\n",
+                  underlyingType);
+  }
+  os << formatv("  return baseAttr.cast<{0}>();\n", attrClassName);
+
+  os << "}\n";
+
+  // Emit getValue method
+
+  os << formatv("{0} {1}::getValue() const {{\n", enumName, attrClassName);
+
+  if (enumAttr.isSubClassOf("StrEnumAttr")) {
+    os << formatv("  const auto res = {0}(::mlir::StringAttr::getValue());\n",
+                  strToSymFnName);
+    os << "  return res.getValue();\n";
+  } else {
+    os << formatv("  return static_cast<{0}>(::mlir::IntegerAttr::getInt());\n",
+                  enumName);
+  }
+
+  os << "}\n";
+}
+
 static void emitUnderlyingToSymFnForBitEnum(const Record &enumDef,
                                             raw_ostream &os) {
   EnumAttr enumAttr(enumDef);
@@ -391,6 +468,23 @@ inline ::llvm::Optional<{0}> symbolizeEnum<{0}>(::llvm::StringRef str) {
 )";
   os << formatv(symbolizeEnumStr, enumName, strToSymFnName);
 
+  const char *const attrClassDecl = R"(
+class {1} : public ::mlir::{2} {
+public:
+  using ValueType = {0};
+  using ::mlir::{2}::{2};
+  static bool classof(::mlir::Attribute attr);
+  static {1} get(::mlir::MLIRContext *context, {0} val);
+  {0} getValue() const;
+};
+)";
+  if (enumAttr.genSpecializedAttr()) {
+    StringRef attrClassName = enumAttr.getSpecializedAttrClassName();
+    StringRef baseAttrClassName =
+        enumAttr.isSubClassOf("StrEnumAttr") ? "StringAttr" : "IntegerAttr";
+    os << formatv(attrClassDecl, enumName, attrClassName, baseAttrClassName);
+  }
+
   for (auto ns : llvm::reverse(namespaces))
     os << "} // namespace " << ns << "\n";
 
@@ -428,6 +522,9 @@ static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
     emitUnderlyingToSymFnForIntEnum(enumDef, os);
   }
 
+  if (enumAttr.genSpecializedAttr())
+    emitSpecializedAttrDef(enumDef, os);
+
   for (auto ns : llvm::reverse(namespaces))
     os << "} // namespace " << ns << "\n";
   os << "\n";
index a558019..a873658 100644 (file)
@@ -6,21 +6,29 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
 #include "mlir/Support/LLVM.h"
+
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringSwitch.h"
+
 #include "gmock/gmock.h"
+
 #include <type_traits>
 
 /// Pull in generated enum utility declarations and definitions.
 #include "EnumsGenTest.h.inc"
+
 #include "EnumsGenTest.cpp.inc"
 
 /// Test namespaces and enum class/utility names.
 using Outer::Inner::ConvertToEnum;
 using Outer::Inner::ConvertToString;
 using Outer::Inner::StrEnum;
+using Outer::Inner::StrEnumAttr;
 
 TEST(EnumsGenTest, GeneratedStrEnumDefinition) {
   EXPECT_EQ(0u, static_cast<uint64_t>(StrEnum::CaseA));
@@ -110,3 +118,41 @@ TEST(EnumsGenTest, GeneratedCustomStringToSymbolFn) {
   auto none = symbolizePrettyIntEnum("Case1");
   EXPECT_FALSE(none);
 }
+
+TEST(EnumsGenTest, GeneratedIntAttributeClass) {
+  mlir::MLIRContext ctx;
+  I32Enum rawVal = I32Enum::Case5;
+
+  I32EnumAttr enumAttr = I32EnumAttr::get(&ctx, rawVal);
+  EXPECT_NE(enumAttr, nullptr);
+  EXPECT_EQ(enumAttr.getValue(), rawVal);
+
+  mlir::Type intType = mlir::IntegerType::get(&ctx, 32);
+  mlir::Attribute intAttr = mlir::IntegerAttr::get(intType, 5);
+  EXPECT_TRUE(intAttr.isa<I32EnumAttr>());
+  EXPECT_EQ(intAttr, enumAttr);
+}
+
+TEST(EnumsGenTest, GeneratedStringAttributeClass) {
+  mlir::MLIRContext ctx;
+  StrEnum rawVal = StrEnum::CaseA;
+
+  StrEnumAttr enumAttr = StrEnumAttr::get(&ctx, rawVal);
+  EXPECT_NE(enumAttr, nullptr);
+  EXPECT_EQ(enumAttr.getValue(), rawVal);
+
+  mlir::Attribute strAttr = mlir::StringAttr::get(&ctx, "CaseA");
+  EXPECT_TRUE(strAttr.isa<StrEnumAttr>());
+  EXPECT_EQ(strAttr, enumAttr);
+}
+
+TEST(EnumsGenTest, GeneratedBitAttributeClass) {
+  mlir::MLIRContext ctx;
+
+  mlir::Type intType = mlir::IntegerType::get(&ctx, 32);
+  mlir::Attribute intAttr = mlir::IntegerAttr::get(
+      intType,
+      static_cast<uint32_t>(BitEnumWithNone::Bit1 | BitEnumWithNone::Bit3));
+  EXPECT_TRUE(intAttr.isa<BitEnumWithNoneAttr>());
+  EXPECT_TRUE(intAttr.isa<BitEnumWithoutNoneAttr>());
+}
index b2c8f6f..cdcc182 100644 (file)
@@ -15,6 +15,7 @@ def StrEnum: StrEnumAttr<"StrEnum", "A test enum", [CaseA, CaseB]> {
   let cppNamespace = "Outer::Inner";
   let stringToSymbolFnName = "ConvertToEnum";
   let symbolToStringFnName = "ConvertToString";
+  let genSpecializedAttr = 1;
 }
 
 def Case5: I32EnumAttrCase<"Case5", 5>;