[mlir][linalg] Replace "string" iterator_types attr with enums in LinalgInterface.
authorOleg Shyshkov <shyshkov@google.com>
Wed, 9 Nov 2022 14:42:40 +0000 (15:42 +0100)
committerOleg Shyshkov <shyshkov@google.com>
Wed, 9 Nov 2022 14:47:29 +0000 (15:47 +0100)
[RFC: EnumAttr for iterator types in Linalg](https://discourse.llvm.org/t/rfc-enumattr-for-iterator-types-in-linalg/64535)

This affect touches and probably breaks most of the code that creates `linalg.generic`. A fix would be to replace calls to `getParallelIteratorTypeName/getReductionIteratorTypeName` with `mlir::utils::IteratorType::parallel/reduction` and types from `StringRef` to `mlir::utils::IteratorType`.

Due to limitations of tablegen, shared C++ definition of IteratorType enum lives in StructuredOpsUtils.td, but each dialect should have it's own EnumAttr wrapper. To avoid conflict, all enums in a dialect are put into a separate file with a separate tablegen rule.

Test dialect td files are refactored a bit.

Printed format of `linalg.generic` temporarily remains unchanged to avoid breaking code and tests in the same change.

Differential Revision: https://reviews.llvm.org/D137658

33 files changed:
mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h
mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/test/Dialect/Linalg/conv-interface-invalid.mlir
mlir/test/Dialect/Linalg/invalid.mlir
mlir/test/Dialect/Linalg/transform-op-match.mlir
mlir/test/lib/Dialect/Test/CMakeLists.txt
mlir/test/lib/Dialect/Test/TestAttrDefs.td
mlir/test/lib/Dialect/Test/TestAttributes.h
mlir/test/lib/Dialect/Test/TestEnumDefs.td [new file with mode: 0644]
mlir/test/lib/Dialect/Test/TestOps.td
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel

index 36cf41d..1ee9d2a 100644 (file)
@@ -13,6 +13,7 @@
 #ifndef LINALG_BASE
 #define LINALG_BASE
 
+include "mlir/Dialect/Utils/StructuredOpsUtils.td"
 include "mlir/Dialect/Linalg/IR/LinalgEnums.td"
 include "mlir/IR/EnumAttr.td"
 include "mlir/IR/OpBase.td"
@@ -71,4 +72,10 @@ def TypeFnAttr : EnumAttr<Linalg_Dialect, TypeFn, "type_fn"> {
   let assemblyFormat = "`<` $value `>`";
 }
 
+def IteratorTypeEnum : EnumAttr<Linalg_Dialect, IteratorType, "iterator_type"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+def IteratorTypeArrayAttr : TypedArrayAttrBase<IteratorTypeEnum,
+  "Iterator type should be an enum.">;
+
 #endif // LINALG_BASE
index 8e3df10..45d63d8 100644 (file)
@@ -25,6 +25,7 @@
 
 namespace mlir {
 namespace linalg {
+class IteratorTypeAttr;
 class LinalgOp;
 
 namespace detail {
index 533a52f..17f6bf4 100644 (file)
@@ -193,8 +193,8 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return getNumIterators(getParallelIteratorTypeName(),
-                                $_op.getIteratorTypesArray());
+        return llvm::count($_op.getIteratorTypesArray(),
+                           utils::IteratorType::parallel);
       }]
     >,
     InterfaceMethod<
@@ -207,7 +207,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         return findPositionsOfType($_op.getIteratorTypesArray(),
-                                   getParallelIteratorTypeName(), res);
+                                   utils::IteratorType::parallel, res);
       }]
     >,
     InterfaceMethod<
@@ -219,8 +219,8 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return getNumIterators(getReductionIteratorTypeName(),
-                               $_op.getIteratorTypesArray());
+        return llvm::count($_op.getIteratorTypesArray(),
+                           utils::IteratorType::reduction);
       }]
     >,
     InterfaceMethod<
@@ -233,33 +233,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         return findPositionsOfType($_op.getIteratorTypesArray(),
-                                   getReductionIteratorTypeName(), res);
-      }]
-    >,
-    InterfaceMethod<
-      /*desc=*/[{
-        Return the number of window loops.
-      }],
-      /*retTy=*/"unsigned",
-      /*methodName=*/"getNumWindowLoops",
-      /*args=*/(ins),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        return getNumIterators(getWindowIteratorTypeName(),
-                               $_op.getIteratorTypesArray());
-      }]
-    >,
-    InterfaceMethod<
-      /*desc=*/[{
-        Return the dims that are window loops.
-      }],
-      /*retTy=*/"void",
-      /*methodName=*/"getWindowDims",
-      /*args=*/(ins "SmallVectorImpl<unsigned> &":$res),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        return findPositionsOfType($_op.getIteratorTypesArray(),
-                                   getWindowIteratorTypeName(), res);
+                                   utils::IteratorType::reduction, res);
       }]
     >,
     InterfaceMethod<
@@ -271,7 +245,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return getNumIterators($_op.getIteratorTypesArray());
+        return $_op.getIteratorTypesArray().size();
       }]
     >,
     InterfaceMethod<
@@ -286,7 +260,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*defaultImplementation=*/[{
         auto iters = $_op.getIteratorTypesArray();
         return iters.size() == 1 &&
-               getNumIterators(getReductionIteratorTypeName(), iters) == 1;
+               llvm::count(iters, utils::IteratorType::reduction) == 1;
       }]>,
     //===------------------------------------------------------------------===//
     // Input and Init arguments handling.
@@ -506,12 +480,14 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
         can be infered from other parameters and in such cases default
         getIteratorTypesArray should be overriden.
       }],
-      /*retTy=*/"SmallVector<StringRef>",
+      /*retTy=*/"SmallVector<utils::IteratorType>",
       /*methodName=*/"getIteratorTypesArray",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        auto range = $_op.getIteratorTypes().template getAsValueRange<StringAttr>();
+        auto range = $_op.getIteratorTypes()
+                         .template getAsValueRange<IteratorTypeAttr,
+                                                   utils::IteratorType>();
         return {range.begin(), range.end()};
       }]
     >,
@@ -767,10 +743,6 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
     LogicalResult reifyResultShapes(OpBuilder &b,
         ReifiedRankedShapedTypeDims &reifiedReturnShapes);
 
-    SmallVector<StringRef> getIteratorTypeNames() {
-      return getIteratorTypesArray();
-    }
-
     //========================================================================//
     // Forwarding functions to access interface methods from the
     // DestinationStyleOpInterface.
index 9866620..e822435 100644 (file)
@@ -163,7 +163,7 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [
   let arguments = (ins Variadic<AnyType>:$inputs,
                        Variadic<AnyShaped>:$outputs,
                        AffineMapArrayAttr:$indexing_maps,
-                       ArrayAttr:$iterator_types,
+                       IteratorTypeArrayAttr:$iterator_types,
                        OptionalAttr<StrAttr>:$doc,
                        OptionalAttr<StrAttr>:$library_call);
   let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
@@ -178,22 +178,22 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
     OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
       "ValueRange":$outputs, "ArrayRef<AffineMap>":$indexingMaps,
-      "ArrayRef<StringRef>":$iteratorTypes, "StringRef":$doc,
+      "ArrayRef<utils::IteratorType>":$iteratorTypes, "StringRef":$doc,
       "StringRef":$libraryCall,
       CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
     OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers,
-      "ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<StringRef>":$iteratorTypes,
+      "ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<utils::IteratorType>":$iteratorTypes,
       "StringRef":$doc, "StringRef":$libraryCall,
       CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
     OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
       "ValueRange":$outputs, "ArrayRef<AffineMap>":$indexingMaps,
-      "ArrayRef<StringRef>":$iteratorTypes,
+      "ArrayRef<utils::IteratorType>":$iteratorTypes,
       CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
     OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers,
-      "ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<StringRef>":$iteratorTypes,
+      "ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<utils::IteratorType>":$iteratorTypes,
       CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
   ];
@@ -275,7 +275,7 @@ def MapOp : LinalgStructuredBase_Op<"map", [
 
   let extraClassDeclaration = structuredOpsBaseDecls # [{
     // Implement functions necessary for LinalgStructuredInterface.
-    SmallVector<StringRef> getIteratorTypesArray();
+    SmallVector<utils::IteratorType> getIteratorTypesArray();
     ArrayAttr getIndexingMaps();
     std::string getLibraryCallName() {
       return "op_has_no_registered_library_name";
@@ -356,7 +356,7 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
 
   let extraClassDeclaration = structuredOpsBaseDecls # [{
     // Declare functions necessary for LinalgStructuredInterface.
-    SmallVector<StringRef> getIteratorTypesArray();
+    SmallVector<utils::IteratorType> getIteratorTypesArray();
     ArrayAttr getIndexingMaps();
     std::string getLibraryCallName() {
       return "op_has_no_registered_library_name";
@@ -426,7 +426,7 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [
 
   let extraClassDeclaration = structuredOpsBaseDecls # [{
     // Declare functions necessary for LinalgStructuredInterface.
-    SmallVector<StringRef> getIteratorTypesArray();
+    SmallVector<utils::IteratorType> getIteratorTypesArray();
     ArrayAttr getIndexingMaps();
     std::string getLibraryCallName() {
       return "op_has_no_registered_library_name";
@@ -502,7 +502,7 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
 
   let extraClassDeclaration = structuredOpsBaseDecls # [{
     // Declare functions necessary for LinalgStructuredInterface.
-    SmallVector<StringRef> getIteratorTypesArray();
+    SmallVector<utils::IteratorType> getIteratorTypesArray();
     ArrayAttr getIndexingMaps();
     std::string getLibraryCallName() {
       return "op_has_no_registered_library_name";
index 5fc7938..4f9dd71 100644 (file)
@@ -42,10 +42,10 @@ bool hasOnlyScalarElementwiseOp(Region &r);
 bool isElementwise(LinalgOp op);
 
 /// Check if iterator type has "parallel" semantics.
-bool isParallelIterator(StringRef iteratorType);
+bool isParallelIterator(utils::IteratorType iteratorType);
 
 /// Check if iterator type  has "reduction" semantics.
-bool isReductionIterator(StringRef iteratorType);
+bool isReductionIterator(utils::IteratorType iteratorType);
 
 /// Helper function that creates a memref::DimOp or tensor::DimOp depending on
 /// the type of `source`.
@@ -480,7 +480,8 @@ struct RegionMatcher {
 template <typename LoopTy>
 struct GenerateLoopNest {
   static void doit(OpBuilder &b, Location loc, ArrayRef<Range> loopRanges,
-                   LinalgOp linalgOp, ArrayRef<StringRef> iteratorTypes,
+                   LinalgOp linalgOp,
+                   ArrayRef<utils::IteratorType> iteratorTypes,
                    function_ref<scf::ValueVector(OpBuilder &, Location,
                                                  ValueRange, ValueRange)>
                        bodyBuilderFn,
index b2b7b24..6b2104f 100644 (file)
@@ -22,7 +22,8 @@ namespace mlir {
 namespace tosa {
 
 // Creates a SmallVector of Stringrefs for N parallel loops
-SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops);
+SmallVector<utils::IteratorType>
+getNParallelLoopsAttrs(unsigned nParallelLoops);
 
 // Takes a vector of values and condenses them to a vector with no gaps.
 SmallVector<Value> condenseValues(const SmallVector<Value> &values);
index 6fcfcb1..cb509fe 100644 (file)
@@ -21,7 +21,6 @@
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Location.h"
 #include "mlir/Support/LLVM.h"
-#include "llvm/ADT/StringRef.h"
 
 // Pull in all enum type definitions and utility function declarations.
 #include "mlir/Dialect/Utils/DialectUtilsEnums.h.inc"
@@ -48,42 +47,9 @@ bool isColumnMajorMatmul(ArrayAttr indexingMaps);
 /// the reduction.
 bool isRowMajorBatchMatmul(ArrayAttr indexingMaps);
 
-/// Use to encode that a particular iterator type has parallel semantics.
-constexpr StringRef getParallelIteratorTypeName() { return "parallel"; }
-
-/// Use to encode that a particular iterator type has reduction semantics.
-constexpr StringRef getReductionIteratorTypeName() { return "reduction"; }
-
-/// Use to encode that a particular iterator type has window semantics.
-constexpr StringRef getWindowIteratorTypeName() { return "window"; }
-
-/// Use to encode that a particular iterator type has window semantics.
-inline ArrayRef<StringRef> getAllIteratorTypeNames() {
-  static constexpr StringRef names[3] = {getParallelIteratorTypeName(),
-                                         getReductionIteratorTypeName(),
-                                         getWindowIteratorTypeName()};
-  return llvm::makeArrayRef(names);
-}
-
-/// Returns the iterator of a certain type.
-inline unsigned getNumIterators(StringRef name,
-                                ArrayRef<StringRef> iteratorTypes) {
-  auto names = getAllIteratorTypeNames();
-  (void)names;
-  assert(llvm::is_contained(names, name));
-  return llvm::count(iteratorTypes, name);
-}
-
-inline unsigned getNumIterators(ArrayRef<StringRef> iteratorTypes) {
-  unsigned res = 0;
-  for (auto n : getAllIteratorTypeNames())
-    res += getNumIterators(n, iteratorTypes);
-  return res;
-}
-
 /// Return positions in `iteratorTypes` that match `iteratorTypeName`.
-inline void findPositionsOfType(ArrayRef<StringRef> iteratorTypes,
-                                StringRef iteratorTypeName,
+inline void findPositionsOfType(ArrayRef<utils::IteratorType> iteratorTypes,
+                                utils::IteratorType iteratorTypeName,
                                 SmallVectorImpl<unsigned> &res) {
   for (const auto &en : llvm::enumerate(iteratorTypes)) {
     if (en.value() == iteratorTypeName)
@@ -94,29 +60,28 @@ inline void findPositionsOfType(ArrayRef<StringRef> iteratorTypes,
 /// Helper StructuredGenerator class to manipulate and rewrite ops with
 /// `StructuredOpInterface`. This is templated for now because VectorOps do not
 /// yet implement the StructuredOpInterface itself.
-template <typename StructuredOpInterface>
+template <typename StructuredOpInterface, typename IteratorTypeT>
 class StructuredGenerator {
 public:
   using MapList = ArrayRef<ArrayRef<AffineExpr>>;
 
   struct IteratorType {
-    IteratorType(StringRef strRef) : strRef(strRef) {}
-    bool isOfType(StringRef typeName) const { return typeName == strRef; }
-    StringRef strRef;
+    IteratorType(IteratorTypeT iter) : iter(iter) {}
+    bool isOfType(IteratorTypeT expectedIter) const {
+      return expectedIter == iter;
+    }
+    IteratorTypeT iter;
   };
   struct Par : public IteratorType {
-    Par() : IteratorType(getParallelIteratorTypeName()) {}
+    Par() : IteratorType(IteratorTypeT::parallel) {}
   };
   struct Red : public IteratorType {
-    Red() : IteratorType(getReductionIteratorTypeName()) {}
-  };
-  struct Win : public IteratorType {
-    Win() : IteratorType(getWindowIteratorTypeName()) {}
+    Red() : IteratorType(IteratorTypeT::reduction) {}
   };
 
   StructuredGenerator(OpBuilder &builder, StructuredOpInterface op)
       : builder(builder), ctx(op.getContext()), loc(op.getLoc()),
-        iterators(op.getIteratorTypeNames()), maps(op.getIndexingMapsArray()),
+        iterators(op.getIteratorTypesArray()), maps(op.getIndexingMapsArray()),
         op(op) {}
 
   bool iters(ArrayRef<IteratorType> its) {
@@ -138,7 +103,7 @@ protected:
   OpBuilder &builder;
   MLIRContext *ctx;
   Location loc;
-  SmallVector<StringRef> iterators;
+  SmallVector<IteratorTypeT> iterators;
   SmallVector<AffineMap, 4> maps;
   Operation *op;
 };
index 758e7c1..5060d8c 100644 (file)
@@ -269,12 +269,11 @@ def Vector_ContractionOp :
       return CombiningKind::ADD;
     }
 
-    // Returns iterator types in string format.
-    SmallVector<StringRef> getIteratorTypeNames() {
-      return llvm::to_vector(
-          llvm::map_range(getIteratorTypes(), [](Attribute a) {
-            return stringifyIteratorType(a.cast<IteratorTypeAttr>().getValue());
-          }));
+    SmallVector<IteratorType> getIteratorTypesArray() {
+      auto range =
+          getIteratorTypes()
+              .template getAsValueRange<IteratorTypeAttr, IteratorType>();
+      return {range.begin(), range.end()};
     }
   }];
 
index 04cd00f..f56162b 100644 (file)
@@ -791,12 +791,12 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
 
   SmallVector<AffineExpr, 2> srcExprs;
   SmallVector<AffineExpr, 2> dstExprs;
-  SmallVector<StringRef, 4> iteratorTypes;
+  SmallVector<utils::IteratorType, 4> iteratorTypes;
   for (unsigned int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
     srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
 
-    iteratorTypes.push_back(axis == i ? getReductionIteratorTypeName()
-                                      : getParallelIteratorTypeName());
+    iteratorTypes.push_back(axis == i ? utils::IteratorType::reduction
+                                      : utils::IteratorType::parallel);
     if (axis != i)
       dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
   }
@@ -1383,7 +1383,8 @@ public:
     auto inputMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0,
                                    inputExprs, builder.getContext());
     auto resultMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
-    SmallVector<StringRef> iterators(4, getParallelIteratorTypeName());
+    SmallVector<utils::IteratorType> iterators(4,
+                                               utils::IteratorType::parallel);
 
     Value empty = builder.create<tensor::EmptyOp>(
         resultTy.getShape(), resultTy.getElementType(), outputDynSize);
@@ -2083,9 +2084,9 @@ public:
 
     // We need to reduce along the arg-max axis, with parallel operations along
     // the rest.
-    SmallVector<StringRef, 4> iteratorTypes;
-    iteratorTypes.resize(inputTy.getRank(), getParallelIteratorTypeName());
-    iteratorTypes[axis] = getReductionIteratorTypeName();
+    SmallVector<utils::IteratorType, 4> iteratorTypes;
+    iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
+    iteratorTypes[axis] = utils::IteratorType::reduction;
 
     SmallVector<AffineExpr, 2> srcExprs;
     SmallVector<AffineExpr, 2> dstExprs;
index 78a29f4..f812621 100644 (file)
@@ -321,7 +321,7 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
     if (inputExprWalker.unConvolvedDims.count(outputDim) &&
         !filterDims.count(outputDim)) {
       // Batch dimension.
-      if (iteratorTypes[outputDim] != getParallelIteratorTypeName())
+      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
         return MatchConvolutionResult::OutputDimsNotParallel;
       allLoopDims.insert(outputDim);
       continue;
@@ -329,7 +329,7 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
     if (inputExprWalker.convolvedDims.count(outputDim) &&
         !filterDims.count(outputDim)) {
       // Output image Loop dimension.
-      if (iteratorTypes[outputDim] != getParallelIteratorTypeName())
+      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
         return MatchConvolutionResult::OutputDimsNotParallel;
       allLoopDims.insert(outputDim);
       continue;
@@ -338,7 +338,7 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
         !inputExprWalker.unConvolvedDims.count(outputDim) &&
         filterDims.count(outputDim)) {
       // Output channel dimension.
-      if (iteratorTypes[outputDim] != getParallelIteratorTypeName())
+      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
         return MatchConvolutionResult::OutputDimsNotParallel;
       allLoopDims.insert(outputDim);
       continue;
@@ -346,7 +346,7 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
     if (inputExprWalker.unConvolvedDims.count(outputDim) &&
         filterDims.count(outputDim)) {
       // Depth multiplier.
-      if (iteratorTypes[outputDim] != getParallelIteratorTypeName())
+      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
         return MatchConvolutionResult::OutputDimsNotParallel;
       allLoopDims.insert(outputDim);
       continue;
@@ -364,7 +364,7 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
     if (inputExprWalker.convolvedDims.count(filterDim) &&
         !outputDims.count(filterDim)) {
       // Filter loop dimension.
-      if (iteratorTypes[filterDim] != getReductionIteratorTypeName())
+      if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
         return MatchConvolutionResult::NonOutputDimNotReduction;
       if (allLoopDims.count(filterDim))
         return MatchConvolutionResult::NonConvolutionLoop;
@@ -374,7 +374,7 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
     if (inputExprWalker.unConvolvedDims.count(filterDim) &&
         !outputDims.count(filterDim)) {
       // Input channel dimension.
-      if (iteratorTypes[filterDim] != getReductionIteratorTypeName())
+      if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
         return MatchConvolutionResult::NonOutputDimNotReduction;
       if (allLoopDims.count(filterDim))
         return MatchConvolutionResult::NonConvolutionLoop;
@@ -619,15 +619,6 @@ LinalgOp::reifyResultShapes(OpBuilder &b,
 LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
   LinalgOp linalgOp = cast<LinalgOp>(op);
 
-  // Check all iterator types are known.
-  auto iteratorTypesRange = linalgOp.getIteratorTypesArray();
-  for (StringRef iteratorType : iteratorTypesRange) {
-    if (!llvm::is_contained(getAllIteratorTypeNames(), iteratorType) ||
-        !utils::symbolizeIteratorType(iteratorType).has_value())
-      return op->emitOpError("unexpected iterator_type (")
-             << iteratorType << ")";
-  }
-
   // Before checking indexing maps, we need to make sure the attributes
   // referenced by it are valid.
   if (linalgOp.hasDynamicIndexingMaps())
index 8ce1ad0..52ef33b 100644 (file)
@@ -705,12 +705,17 @@ void GenericOp::build(
 void GenericOp::build(
     OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
     ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
-    ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
+    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
+    StringRef libraryCall,
     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
     ArrayRef<NamedAttribute> attributes) {
   build(builder, result, resultTensorTypes, inputs, outputs,
         builder.getAffineMapArrayAttr(indexingMaps),
-        builder.getStrArrayAttr(iteratorTypes),
+        builder.getArrayAttr(llvm::to_vector(llvm::map_range(
+            iteratorTypes,
+            [&](utils::IteratorType iter) -> mlir::Attribute {
+              return IteratorTypeAttr::get(builder.getContext(), iter);
+            }))),
         doc.empty() ? StringAttr() : builder.getStringAttr(doc),
         libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
         bodyBuild, attributes);
@@ -719,7 +724,8 @@ void GenericOp::build(
 void GenericOp::build(
     OpBuilder &builder, OperationState &result, ValueRange inputs,
     ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
-    ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
+    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
+    StringRef libraryCall,
     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
     ArrayRef<NamedAttribute> attributes) {
   build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
@@ -729,7 +735,7 @@ void GenericOp::build(
 void GenericOp::build(
     OpBuilder &builder, OperationState &result, ValueRange inputs,
     ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
-    ArrayRef<StringRef> iteratorTypes,
+    ArrayRef<utils::IteratorType> iteratorTypes,
     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
     ArrayRef<NamedAttribute> attributes) {
   build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
@@ -740,7 +746,7 @@ void GenericOp::build(
 void GenericOp::build(
     OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
     ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
-    ArrayRef<StringRef> iteratorTypes,
+    ArrayRef<utils::IteratorType> iteratorTypes,
     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
     ArrayRef<NamedAttribute> attributes) {
   build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
@@ -758,9 +764,29 @@ void GenericOp::print(OpAsmPrinter &p) {
   llvm::StringSet<> genericAttrNamesSet;
   genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
   SmallVector<NamedAttribute, 8> genericAttrs;
-  for (auto attr : (*this)->getAttrs())
-    if (genericAttrNamesSet.count(attr.getName().strref()) > 0)
+  for (auto attr : (*this)->getAttrs()) {
+    if (attr.getName() == getIteratorTypesAttrName()) {
+      auto iteratorTypes =
+          attr.getValue()
+              .cast<ArrayAttr>()
+              .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
+      // Convert IteratorType enums into the string representation. This is
+      // needed, because tests still use the old format when 'iterator_types'
+      // attribute is represented as an array of strings.
+      // TODO: Remove this conversion once tests are fixed.
+      SmallVector<Attribute> iteratorTypeNames =
+          llvm::to_vector(llvm::map_range(
+              iteratorTypes, [&](utils::IteratorType t) -> Attribute {
+                return StringAttr::get(getContext(), stringifyIteratorType(t));
+              }));
+
+      genericAttrs.emplace_back(
+          getIteratorTypesAttrName(),
+          ArrayAttr::get(getContext(), iteratorTypeNames));
+    } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
       genericAttrs.push_back(attr);
+    }
+  }
   if (!genericAttrs.empty()) {
     auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
     p << genericDictAttr;
@@ -805,6 +831,28 @@ ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
   result.attributes.assign(dictAttr.getValue().begin(),
                            dictAttr.getValue().end());
 
+  // Convert array of string into an array of IteratyType enums. This is needed,
+  // because tests still use the old format when 'iterator_types' attribute is
+  // represented as an array of strings.
+  // TODO: Remove this conversion once tests are fixed.
+  ArrayAttr iteratorTypes =
+      result.attributes.get(getIteratorTypesAttrName(result.name))
+          .cast<ArrayAttr>();
+
+  SmallVector<Attribute> iteratorTypeAttrs;
+
+  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
+    auto maybeIteratorType = utils::symbolizeIteratorType(s);
+    if (!maybeIteratorType.has_value())
+      return parser.emitError(parser.getCurrentLocation())
+             << "unexpected iterator_type (" << s << ")";
+
+    iteratorTypeAttrs.push_back(
+        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
+  }
+  result.attributes.set(getIteratorTypesAttrName(result.name),
+                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
+
   // Parsing is shared with named ops, except for the region.
   SmallVector<Type, 1> inputTypes, outputTypes;
   if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
@@ -1418,9 +1466,9 @@ LogicalResult MapOp::verify() {
   return success();
 }
 
-SmallVector<StringRef> MapOp::getIteratorTypesArray() {
+SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
   int64_t rank = getInit().getType().getRank();
-  return SmallVector<StringRef>(rank, getParallelIteratorTypeName());
+  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
 }
 
 ArrayAttr MapOp::getIndexingMaps() {
@@ -1476,12 +1524,12 @@ void ReduceOp::build(
                        inputs, inits, bodyBuild);
 }
 
-SmallVector<StringRef> ReduceOp::getIteratorTypesArray() {
+SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
   int64_t inputRank = getInputs()[0].getType().cast<ShapedType>().getRank();
-  SmallVector<StringRef> iteratorTypes(inputRank,
-                                       getParallelIteratorTypeName());
+  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
+                                                 utils::IteratorType::parallel);
   for (int64_t reductionDim : getDimensions())
-    iteratorTypes[reductionDim] = getReductionIteratorTypeName();
+    iteratorTypes[reductionDim] = utils::IteratorType::reduction;
   return iteratorTypes;
 }
 
@@ -1753,9 +1801,9 @@ LogicalResult TransposeOp::verify() {
   return success();
 }
 
-SmallVector<StringRef> TransposeOp::getIteratorTypesArray() {
+SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
   int64_t rank = getInit().getType().getRank();
-  return SmallVector<StringRef>(rank, getParallelIteratorTypeName());
+  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
 }
 
 ArrayAttr TransposeOp::getIndexingMaps() {
@@ -1891,9 +1939,9 @@ LogicalResult BroadcastOp::verify() {
   return success();
 }
 
-SmallVector<StringRef> BroadcastOp::getIteratorTypesArray() {
+SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
   int64_t rank = getInit().getType().getRank();
-  return SmallVector<StringRef>(rank, getParallelIteratorTypeName());
+  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
 }
 
 ArrayAttr BroadcastOp::getIndexingMaps() {
index 6a9c4e3..7fd5a5e 100644 (file)
@@ -470,10 +470,9 @@ static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
                             .getValue()
                             .isProjectedPermutation();
                       }) &&
-         genericOp.getMatchingIndexingMap(fusableOpOperand).getNumResults() > 0 &&
-         llvm::all_of(genericOp.getIteratorTypesArray(), [](StringRef it) {
-           return it == getParallelIteratorTypeName();
-         });
+         genericOp.getMatchingIndexingMap(fusableOpOperand).getNumResults() >
+             0 &&
+         llvm::all_of(genericOp.getIteratorTypesArray(), isParallelIterator);
 }
 
 namespace {
@@ -783,8 +782,8 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
   }
 
   // The iterator types of the expanded op are all parallel.
-  SmallVector<StringRef> iteratorTypes(expansionInfo.getExpandedOpNumDims(),
-                                       getParallelIteratorTypeName());
+  SmallVector<utils::IteratorType> iteratorTypes(
+      expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
 
   TypeRange resultTypes = ValueRange(outputs).getTypes();
   auto fusedOp =
@@ -1083,7 +1082,8 @@ getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
       continue;
 
     // Check that all folded iterator types are all parallel or all reductions.
-    StringRef startIteratorType = iteratorTypes[foldedIterationSpaceDims[0]];
+    utils::IteratorType startIteratorType =
+        iteratorTypes[foldedIterationSpaceDims[0]];
     if (!isParallelIterator(startIteratorType) &&
         !isReductionIterator(startIteratorType))
       continue;
@@ -1235,10 +1235,10 @@ private:
 
 /// Get the iterator types for the collapsed operation given the original
 /// iterator types and collapsed dimensions.
-static SmallVector<StringRef>
-getCollapsedOpIteratorTypes(ArrayRef<Attribute> iteratorTypes,
+static SmallVector<utils::IteratorType>
+getCollapsedOpIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
                             const CollapsingInfo &collapsingInfo) {
-  SmallVector<StringRef> collapsedIteratorTypes;
+  SmallVector<utils::IteratorType> collapsedIteratorTypes;
   for (ReassociationIndicesRef foldedIterDims :
        collapsingInfo.getCollapsedOpToOrigOpMapping()) {
     assert(!foldedIterDims.empty() &&
@@ -1246,8 +1246,7 @@ getCollapsedOpIteratorTypes(ArrayRef<Attribute> iteratorTypes,
     // Just pick the iterator type of the first folded dim. Pre-condition checks
     // expected to have checked that iterator types of all folded dimensions are
     // the same.
-    collapsedIteratorTypes.push_back(
-        iteratorTypes[foldedIterDims[0]].cast<StringAttr>().getValue());
+    collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]);
   }
   return collapsedIteratorTypes;
 }
@@ -1406,8 +1405,8 @@ static FailureOr<SmallVector<Value>> collapseGenericOpIterationDims(
   }
 
   // Get the iterator types for the operand.
-  SmallVector<StringRef> iteratorTypes = getCollapsedOpIteratorTypes(
-      genericOp.getIteratorTypes().getValue(), collapsingInfo);
+  SmallVector<utils::IteratorType> iteratorTypes = getCollapsedOpIteratorTypes(
+      genericOp.getIteratorTypesArray(), collapsingInfo);
 
   // Get the indexing maps.
   auto indexingMaps = llvm::to_vector(
index 3740633..52287f1 100644 (file)
@@ -91,8 +91,8 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
     SmallVector<AffineMap, 3> indexingMaps(
         op->getNumResults() + op->getNumOperands(),
         rewriter.getMultiDimIdentityMap(rank));
-    SmallVector<StringRef, 6> iteratorTypes(rank,
-                                            getParallelIteratorTypeName());
+    SmallVector<utils::IteratorType, 6> iteratorTypes(
+        rank, utils::IteratorType::parallel);
     auto outputs = getOrCreateOperandsMatchingResultTypes(rewriter, op);
     rewriter.replaceOpWithNewOp<linalg::GenericOp>(
         op, /*resultTensorTypes=*/op->getResultTypes(),
index da43b49..4755fa3 100644 (file)
@@ -53,7 +53,7 @@ FailureOr<GenericOp> mlir::linalg::generalizeNamedOp(RewriterBase &rewriter,
   SmallVector<Value> inputs = linalgOp.getDpsInputOperands();
   SmallVector<Value> outputs = linalgOp.getDpsInitOperands();
   SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
-  SmallVector<StringRef> iterators = linalgOp.getIteratorTypesArray();
+  SmallVector<utils::IteratorType> iterators = linalgOp.getIteratorTypesArray();
   SmallVector<Type> resultTypes = linalgOp.hasTensorSemantics()
                                       ? TypeRange(ValueRange(outputs))
                                       : TypeRange{};
index 0608c36..2fb550b 100644 (file)
@@ -162,13 +162,13 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
 
   newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr,
                                    op.getContext()));
-  SmallVector<StringRef> newIteratorTypes;
+  SmallVector<utils::IteratorType> newIteratorTypes;
   for (auto &it : llvm::enumerate(op.getIteratorTypesArray())) {
     if (insertSplitDimension == it.index() && !control.innerParallel)
-      newIteratorTypes.push_back(getParallelIteratorTypeName());
+      newIteratorTypes.push_back(utils::IteratorType::parallel);
     newIteratorTypes.push_back(it.value());
     if (insertSplitDimension == it.index() && control.innerParallel)
-      newIteratorTypes.push_back(getParallelIteratorTypeName());
+      newIteratorTypes.push_back(utils::IteratorType::parallel);
   }
   // Create the new op matching the original op with an extra parallel
   // dimension.
@@ -182,14 +182,14 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
   // from the previous op.
   unsigned intermRank = newOutputShape.size();
   AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
-  SmallVector<StringRef> reductionIteratorTypes;
+  SmallVector<utils::IteratorType> reductionIteratorTypes;
   SmallVector<AffineExpr> exprs;
   for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
     if (insertSplitDimension == i) {
-      reductionIteratorTypes.push_back(getReductionIteratorTypeName());
+      reductionIteratorTypes.push_back(utils::IteratorType::reduction);
     } else {
       exprs.push_back(b.getAffineDimExpr(i));
-      reductionIteratorTypes.push_back(getParallelIteratorTypeName());
+      reductionIteratorTypes.push_back(utils::IteratorType::parallel);
     }
   }
   AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext());
@@ -367,7 +367,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
   // dimension.
   auto iteratorTypes = op.getIteratorTypesArray();
   iteratorTypes.insert(iteratorTypes.begin() + reductionDimPos,
-                       getParallelIteratorTypeName());
+                       utils::IteratorType::parallel);
   GenericOp genericOp =
       b.create<GenericOp>(loc, ValueRange(newOutputs).getTypes(), newInputs,
                           newOutputs, newMaps, iteratorTypes);
@@ -394,10 +394,10 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
     AffineMap map = b.getMultiDimIdentityMap(originalOutputType.getRank() + 1);
     SmallVector<AffineMap> indexingMaps = {
         map, map.dropResult(insertSplitDimension)};
-    SmallVector<StringRef> reductionIteratorTypes(
-        originalOutputType.getRank() + 1, getParallelIteratorTypeName());
+    SmallVector<utils::IteratorType> reductionIteratorTypes(
+        originalOutputType.getRank() + 1, utils::IteratorType::parallel);
     reductionIteratorTypes[insertSplitDimension] =
-        getReductionIteratorTypeName();
+        utils::IteratorType::reduction;
 
     // clang-format off
     auto reductionOp = b.create<GenericOp>(
index 5937da3..e1d4616 100644 (file)
@@ -431,7 +431,7 @@ tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,
   auto [loopRanges, loopIndexToRangeIndex] = makeTiledLoopRanges(
       b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes);
 
-  SmallVector<StringRef, 4> iteratorTypes;
+  SmallVector<utils::IteratorType, 4> iteratorTypes;
   for (const auto &attr : enumerate(op.getIteratorTypesArray())) {
     if (loopIndexToRangeIndex.count(attr.index()))
       iteratorTypes.push_back(attr.value());
index d1fcc01..02f4e9d 100644 (file)
@@ -88,10 +88,7 @@ struct LinalgOpTilingInterface
   /// Return the loop iterator type.
   SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
     LinalgOpTy concreteOp = cast<LinalgOpTy>(op);
-    return llvm::to_vector(llvm::map_range(
-        concreteOp.getIteratorTypesArray(), [](StringRef iteratorType) {
-          return utils::symbolizeIteratorType(iteratorType).value();
-        }));
+    return concreteOp.getIteratorTypesArray();
   }
 
   /// Return the iteration domain range.
@@ -339,8 +336,9 @@ struct LinalgOpPartialReductionInterface
 
     // Step3. create a generic op where the reduction dimension is replaced by a
     // parallel dimension of the size of reduction.
-    SmallVector<StringRef> newIteratorTypes = linalgOp.getIteratorTypesArray();
-    newIteratorTypes[reductionDims[0]] = getParallelIteratorTypeName();
+    SmallVector<utils::IteratorType> newIteratorTypes =
+        linalgOp.getIteratorTypesArray();
+    newIteratorTypes[reductionDims[0]] = utils::IteratorType::parallel;
     SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray();
     newMaps.back() = AffineMap::get(newMaps.back().getNumDims(), 0, outputExpr,
                                     linalgOp.getContext());
@@ -366,14 +364,14 @@ struct LinalgOpPartialReductionInterface
     int64_t intermRank =
         partialReduce[0].getType().cast<ShapedType>().getRank();
     AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
-    SmallVector<StringRef> reductionIteratorTypes;
+    SmallVector<utils::IteratorType> reductionIteratorTypes;
     SmallVector<AffineExpr> exprs;
     for (int64_t i : llvm::seq<int64_t>(0, intermRank)) {
       if (dimToMerge == i) {
-        reductionIteratorTypes.push_back(getReductionIteratorTypeName());
+        reductionIteratorTypes.push_back(utils::IteratorType::reduction);
       } else {
         exprs.push_back(b.getAffineDimExpr(i));
-        reductionIteratorTypes.push_back(getParallelIteratorTypeName());
+        reductionIteratorTypes.push_back(utils::IteratorType::parallel);
       }
     }
     AffineMap outputMap =
index 1034e8e..11ee55c 100644 (file)
@@ -297,8 +297,10 @@ LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
   return vectorizeCopy(rewriter, copyOp);
 }
 
-static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
-  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
+static SmallVector<utils::IteratorType>
+getNParallelLoopsAttrs(unsigned nParallelLoops) {
+  return SmallVector<utils::IteratorType>(nParallelLoops,
+                                          utils::IteratorType::parallel);
 }
 
 /// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp (to
index 2cf74a6..a643713 100644 (file)
@@ -1420,11 +1420,12 @@ namespace {
 ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
 /// ```
 /// kw is unrolled, w is unrolled iff dilationW > 1.
-struct Conv1DGenerator : public StructuredGenerator<LinalgOp> {
+struct Conv1DGenerator
+    : public StructuredGenerator<LinalgOp, utils::IteratorType> {
   Conv1DGenerator(OpBuilder &builder, LinalgOp linalgOp, int strideW,
                   int dilationW)
-      : StructuredGenerator<LinalgOp>(builder, linalgOp), strideW(strideW),
-        dilationW(dilationW) {
+      : StructuredGenerator<LinalgOp, utils::IteratorType>(builder, linalgOp),
+        strideW(strideW), dilationW(dilationW) {
     // Determine whether `linalgOp` can be generated with this generator
     if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
       return;
index ccf7cdc..fc34353 100644 (file)
@@ -186,12 +186,12 @@ bool isElementwise(LinalgOp op) {
   return hasOnlyScalarElementwiseOp(op->getRegion(0));
 }
 
-bool isParallelIterator(StringRef iteratorType) {
-  return iteratorType == getParallelIteratorTypeName();
+bool isParallelIterator(utils::IteratorType iteratorType) {
+  return iteratorType == utils::IteratorType::parallel;
 }
 
-bool isReductionIterator(StringRef iteratorType) {
-  return iteratorType == getReductionIteratorTypeName();
+bool isReductionIterator(utils::IteratorType iteratorType) {
+  return iteratorType == utils::IteratorType::reduction;
 }
 
 /// Helper function that creates a memref::DimOp or tensor::DimOp depending on
@@ -422,15 +422,13 @@ GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor,
           b.getContext())),
       AffineMap::getMultiDimIdentityMap(transposeVector.size(),
                                         b.getContext())};
-  SmallVector<llvm::StringRef> iteratorTypes(transposeVector.size(),
-                                             getParallelIteratorTypeName());
+  SmallVector<utils::IteratorType> iteratorTypes(transposeVector.size(),
+                                                 utils::IteratorType::parallel);
 
   // Create a GenericOp to transpose `inputTensor` into `outputTensor`.
-  auto transposeOp = b.create<GenericOp>(
-      loc, resultTensorType, inputTensor, outputTensor,
-      b.getAffineMapArrayAttr(indexingMaps), b.getStrArrayAttr(iteratorTypes),
-      /*doc=*/nullptr,
-      /*library_call=*/nullptr);
+  auto transposeOp =
+      b.create<GenericOp>(loc, resultTensorType, inputTensor, outputTensor,
+                          indexingMaps, iteratorTypes);
   Region &body = transposeOp.getRegion();
   body.push_back(new Block());
   body.front().addArguments({elementType, elementType}, {loc, loc});
@@ -452,8 +450,8 @@ GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) {
 
   AffineMap id =
       AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext());
-  SmallVector<StringRef> iteratorTypes(memrefTypeTo.getRank(),
-                                       getParallelIteratorTypeName());
+  SmallVector<utils::IteratorType> iteratorTypes(memrefTypeTo.getRank(),
+                                                 utils::IteratorType::parallel);
   return b.create<linalg::GenericOp>(
       loc,
       /*inputs=*/from,
@@ -469,7 +467,7 @@ GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) {
 template <>
 void GenerateLoopNest<scf::ForOp>::doit(
     OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
-    ArrayRef<StringRef> iteratorTypes,
+    ArrayRef<utils::IteratorType> iteratorTypes,
     function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                   ValueRange)>
         bodyBuilderFn,
@@ -513,7 +511,7 @@ void GenerateLoopNest<scf::ForOp>::doit(
 template <>
 void GenerateLoopNest<AffineForOp>::doit(
     OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
-    ArrayRef<StringRef> iteratorTypes,
+    ArrayRef<utils::IteratorType> iteratorTypes,
     function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                   ValueRange)>
         bodyBuilderFn,
@@ -564,7 +562,7 @@ void updateBoundsForCyclicDistribution(OpBuilder &b, Location loc, Value procId,
 // exceeds 10.
 static void generateParallelLoopNest(
     OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs,
-    ValueRange steps, ArrayRef<StringRef> iteratorTypes,
+    ValueRange steps, ArrayRef<utils::IteratorType> iteratorTypes,
     ArrayRef<linalg::ProcInfo> procInfo,
     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
     SmallVectorImpl<Value> &ivStorage) {
@@ -679,7 +677,7 @@ static void generateParallelLoopNest(
 template <>
 void GenerateLoopNest<scf::ParallelOp>::doit(
     OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
-    ArrayRef<StringRef> iteratorTypes,
+    ArrayRef<utils::IteratorType> iteratorTypes,
     function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                   ValueRange)>
         bodyBuilderFn,
index 533d31f..3f2ee1b 100644 (file)
@@ -178,7 +178,8 @@ static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
 /// as we use adj matrix for the graph.
 /// The sorted result will put the first Reduction iterator to the
 /// latest possible index.
-static bool topSortOptimal(unsigned n, ArrayRef<StringRef> iteratorTypes,
+static bool topSortOptimal(unsigned n,
+                           ArrayRef<utils::IteratorType> iteratorTypes,
                            std::vector<unsigned> &topSort,
                            std::vector<unsigned> &inDegree,
                            std::vector<std::vector<bool>> &adjM) {
index 7f2e970..a9c77c6 100644 (file)
 using namespace mlir;
 using namespace mlir::tosa;
 
-SmallVector<StringRef>
+SmallVector<utils::IteratorType>
 mlir::tosa::getNParallelLoopsAttrs(unsigned nParallelLoops) {
-  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
+  return SmallVector<utils::IteratorType>(nParallelLoops,
+                                          utils::IteratorType::parallel);
 }
 
 SmallVector<Value>
index 0bdaf7b..47aefa1 100644 (file)
@@ -1518,27 +1518,14 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
 }
 
 namespace {
-struct IteratorType {
-  IteratorType(StringRef strRef) : strRef(strRef) {}
-  bool isOfType(Attribute attr) const {
-    auto sAttr = attr.dyn_cast<StringAttr>();
-    return sAttr && sAttr.getValue() == strRef;
-  }
-  StringRef strRef;
-};
-struct Par : public IteratorType {
-  Par() : IteratorType(getParallelIteratorTypeName()) {}
-};
-struct Red : public IteratorType {
-  Red() : IteratorType(getReductionIteratorTypeName()) {}
-};
 
 /// Generate a vector implementation for matmat, matvec and tmatvec.
 /// This unrolls outer-products along the reduction dimension.
 struct UnrolledOuterProductGenerator
-    : public StructuredGenerator<vector::ContractionOp> {
+    : public StructuredGenerator<vector::ContractionOp, vector::IteratorType> {
   UnrolledOuterProductGenerator(OpBuilder &builder, vector::ContractionOp op)
-      : StructuredGenerator<vector::ContractionOp>(builder, op),
+      : StructuredGenerator<vector::ContractionOp, vector::IteratorType>(
+            builder, op),
         kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
         res(op.getAcc()), lhsType(op.getLhsType()) {}
 
@@ -2719,8 +2706,10 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
     } else {
       MemRefLayoutAttrInterface updatedLayout;
       if (auto strided = layout.dyn_cast<StridedLayoutAttr>()) {
-        auto strides = llvm::to_vector(strided.getStrides().drop_back(dimsToDrop));
-        updatedLayout = StridedLayoutAttr::get(strided.getContext(), strided.getOffset(), strides);
+        auto strides =
+            llvm::to_vector(strided.getStrides().drop_back(dimsToDrop));
+        updatedLayout = StridedLayoutAttr::get(strided.getContext(),
+                                               strided.getOffset(), strides);
       } else {
         AffineMap map = srcType.getLayout().getAffineMap();
         int numSymbols = map.getNumSymbols();
index f9b20d3..e7d6dd9 100644 (file)
@@ -17,7 +17,7 @@ func.func @test_conv_op_wrong_num_operands(%arg0 : tensor<?xf32>,
   // expected-error @+1 {{expected op with 2 inputs and 1 output}}
   %0 = test.linalg_conv_op {
       indexing_maps = [#map, #map],
-      iterator_types = ["parallel"]}
+      iterator_types = [#test.iterator_type<parallel>]}
       ins(%arg0 : tensor<?xf32>) outs(%arg1 : tensor<?xf32>) {
       ^bb0(%arg2 : f32, %arg3 : f32):
          linalg.yield  %arg3 : f32
@@ -34,7 +34,8 @@ func.func @test_conv_op_wrong_input_indexing_map1(%arg0 : tensor<?xf32>,
       indexing_maps = [affine_map<(d0, d1) -> (d0 * 2)>,
                        affine_map<(d0, d1) -> (d1)>,
                        affine_map<(d0, d1) -> (d0)>],
-      iterator_types = ["parallel", "reduction"]}
+      iterator_types = [#test.iterator_type<parallel>,
+                        #test.iterator_type<reduction>]}
       ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
       outs(%arg2 : tensor<?xf32>) {
       ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
@@ -52,7 +53,8 @@ func.func @test_conv_op_wrong_input_indexing_map2(%arg0 : tensor<?x?xf32>,
       indexing_maps = [affine_map<(d0, d1) -> (d0 + d1, d0)>,
                        affine_map<(d0, d1) -> (d1)>,
                        affine_map<(d0, d1) -> (d0)>],
-      iterator_types = ["parallel", "reduction"]}
+      iterator_types = [#test.iterator_type<parallel>,
+                        #test.iterator_type<reduction>]}
       ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?xf32>)
       outs(%arg2 : tensor<?xf32>) {
       ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
@@ -70,7 +72,8 @@ func.func @test_conv_op_filter_index_map_not_projection(%arg0 : tensor<?xf32>,
       indexing_maps = [affine_map<(d0, d1) -> (d1)>,
                        affine_map<(d0, d1) -> (d1 + d0)>,
                        affine_map<(d0, d1) -> (d0)>],
-      iterator_types = ["parallel", "reduction"]}
+      iterator_types = [#test.iterator_type<parallel>,
+                        #test.iterator_type<reduction>]}
       ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
       outs(%arg2 : tensor<?xf32>) {
       ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
@@ -88,7 +91,8 @@ func.func @test_conv_op_output_index_map_not_projection(%arg0 : tensor<?xf32>,
       indexing_maps = [affine_map<(d0, d1) -> (d0)>,
                        affine_map<(d0, d1) -> (d1)>,
                        affine_map<(d0, d1) -> (d0 + d1)>],
-      iterator_types = ["parallel", "parallel"]}
+      iterator_types = [#test.iterator_type<parallel>,
+                        #test.iterator_type<parallel>]}
       ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
       outs(%arg2 : tensor<?xf32>) {
       ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
@@ -108,7 +112,8 @@ func.func @test_conv_op_output_filter_convolved(%arg0 : tensor<?xf32>,
       indexing_maps = [affine_map<(d0, d1) -> (d0 + d1)>,
                        affine_map<(d0, d1) -> (d1)>,
                        affine_map<(d0, d1) -> (d0, d1)>],
-      iterator_types = ["parallel", "parallel"]}
+      iterator_types = [#test.iterator_type<parallel>,
+                        #test.iterator_type<parallel>]}
       ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
       outs(%arg2 : tensor<?x?xf32>) {
       ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
@@ -127,7 +132,9 @@ func.func @test_conv_op_output_only_dim(%arg0 : tensor<?xf32>,
       indexing_maps = [affine_map<(d0, d1, d2) -> (d0 + d1)>,
                        affine_map<(d0, d1, d2) -> (d1)>,
                        affine_map<(d0, d1, d2) -> (d0, d2)>],
-      iterator_types = ["parallel", "reduction", "parallel"]}
+      iterator_types = [#test.iterator_type<parallel>,
+                        #test.iterator_type<reduction>,
+                        #test.iterator_type<parallel>]}
       ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
       outs(%arg2 : tensor<?x?xf32>) {
       ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
@@ -146,7 +153,9 @@ func.func @test_conv_op_filter_only_dim(%arg0 : tensor<?xf32>,
       indexing_maps = [affine_map<(d0, d1, d2) -> (d0 + d1)>,
                        affine_map<(d0, d1, d2) -> (d1, d2)>,
                        affine_map<(d0, d1, d2) -> (d0)>],
-      iterator_types = ["parallel", "reduction", "reduction"]}
+      iterator_types = [#test.iterator_type<parallel>,
+                        #test.iterator_type<reduction>,
+                        #test.iterator_type<reduction>]}
       ins(%arg0, %arg1 : tensor<?xf32>, tensor<?x?xf32>)
       outs(%arg2 : tensor<?xf32>) {
       ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
@@ -165,7 +174,9 @@ func.func @test_conv_op_input_only_dim(%arg0 : tensor<?x?xf32>,
       indexing_maps = [affine_map<(d0, d1, d2) -> (d0 + d1, d2)>,
                        affine_map<(d0, d1, d2) -> (d1)>,
                        affine_map<(d0, d1, d2) -> (d0)>],
-      iterator_types = ["parallel", "reduction", "reduction"]}
+      iterator_types = [#test.iterator_type<parallel>,
+                        #test.iterator_type<reduction>,
+                        #test.iterator_type<reduction>]}
       ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?xf32>)
       outs(%arg2 : tensor<?xf32>) {
       ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
@@ -184,7 +195,8 @@ func.func @test_conv_op_non_output_access_loop_parallel(%arg0 : tensor<?xf32>,
       indexing_maps = [affine_map<(d0, d1) -> (d0 + d1)>,
                        affine_map<(d0, d1) -> (d1)>,
                        affine_map<(d0, d1) -> (d0)>],
-      iterator_types = ["parallel", "parallel"]}
+      iterator_types = [#test.iterator_type<parallel>,
+                        #test.iterator_type<parallel>]}
       ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
       outs(%arg2 : tensor<?xf32>) {
       ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
index 5a1c2af..ebce71f 100644 (file)
@@ -96,7 +96,7 @@ func.func @generic_wrong_dim_in_map(%arg0: memref<1xi32>) {
 // -----
 
 func.func @generic_wrong_iterator(%arg0: memref<1xi32>) {
-  // expected-error @+1 {{op unexpected iterator_type (random)}}
+  // expected-error @+4 {{unexpected iterator_type (random)}}
   linalg.generic {
     indexing_maps =  [ affine_map<(i) -> (i)> ],
     iterator_types = ["random"]}
index 4a92c07..b4ad820 100644 (file)
@@ -59,13 +59,19 @@ transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
   %match_attr = transform.structured.match
       ops{["linalg.generic"]}
-      attributes{iterator_types = ["parallel", "parallel", "parallel"]}
+      attributes{iterator_types = [
+        #linalg.iterator_type<parallel>,
+        #linalg.iterator_type<parallel>,
+        #linalg.iterator_type<parallel>]}
       in %arg1
   transform.test_print_remark_at_operand %match_attr, "matched complex attr" : !pdl.operation
   transform.test_consume_operand %match_attr
 
   %no_match = transform.structured.match
-      attributes{iterator_types = ["parallel", "parallel", "reduction"]}
+      attributes{iterator_types = [
+        #linalg.iterator_type<parallel>,
+        #linalg.iterator_type<parallel>,
+        #linalg.iterator_type<reduction>]}
       in %arg1
 // expected-remark @below {{0}}
   transform.test_print_number_of_associated_payload_ir_ops %no_match
index 2c8719d..2d1a1df 100644 (file)
@@ -23,13 +23,16 @@ mlir_tablegen(TestTypeDefs.h.inc -gen-typedef-decls -typedefs-dialect=test)
 mlir_tablegen(TestTypeDefs.cpp.inc -gen-typedef-defs -typedefs-dialect=test)
 add_public_tablegen_target(MLIRTestTypeDefIncGen)
 
+set(LLVM_TARGET_DEFINITIONS TestEnumDefs.td)
+mlir_tablegen(TestOpEnums.h.inc -gen-enum-decls)
+mlir_tablegen(TestOpEnums.cpp.inc -gen-enum-defs)
+add_public_tablegen_target(MLIRTestEnumDefIncGen)
+
 set(LLVM_TARGET_DEFINITIONS TestOps.td)
 mlir_tablegen(TestOps.h.inc -gen-op-decls)
 mlir_tablegen(TestOps.cpp.inc -gen-op-defs)
 mlir_tablegen(TestOpsDialect.h.inc -gen-dialect-decls -dialect=test)
 mlir_tablegen(TestOpsDialect.cpp.inc -gen-dialect-defs -dialect=test)
-mlir_tablegen(TestOpEnums.h.inc -gen-enum-decls)
-mlir_tablegen(TestOpEnums.cpp.inc -gen-enum-defs)
 mlir_tablegen(TestPatterns.inc -gen-rewriters)
 add_public_tablegen_target(MLIRTestOpsIncGen)
 
@@ -46,6 +49,7 @@ add_mlir_library(MLIRTestDialect
 
   DEPENDS
   MLIRTestAttrDefIncGen
+  MLIRTestEnumDefIncGen
   MLIRTestInterfaceIncGen
   MLIRTestTypeDefIncGen
   MLIRTestOpsIncGen
index 0c35f81..c4996ab 100644 (file)
@@ -15,6 +15,8 @@
 
 // To get the test dialect definition.
 include "TestDialect.td"
+include "TestEnumDefs.td"
+include "mlir/Dialect/Utils/StructuredOpsUtils.td"
 include "mlir/IR/AttrTypeBase.td"
 include "mlir/IR/BuiltinAttributeInterfaces.td"
 include "mlir/IR/EnumAttr.td"
@@ -277,13 +279,6 @@ def TestArrayOfInts : ArrayOfAttr<Test_Dialect, "ArrayOfInts",
     "array_of_ints", "int32_t">;
 
 // An array of enum attributes.
-def TestSimpleEnum : I32EnumAttr<"SimpleEnum", "", [
-    I32EnumAttrCase<"a", 0>,
-    I32EnumAttrCase<"b", 1>
-  ]> {
-  let genSpecializedAttr = 0;
-  let cppNamespace = "::test";
-}
 def TestSimpleEnumAttr : EnumAttr<Test_Dialect, TestSimpleEnum, "simple_enum"> {
   let assemblyFormat = "`` $value";
 }
@@ -297,4 +292,14 @@ def TestCustomAnchor : Test_Attr<"TestCustomAnchor"> {
   let assemblyFormat = "`<` $a (`>`) : (`,` ` ` custom<TrueFalse>($b)^ `>`)?";
 }
 
+def Test_IteratorTypeEnum
+    : EnumAttr<Test_Dialect, IteratorType, "iterator_type"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def Test_IteratorTypeArrayAttr
+    : TypedArrayAttrBase<Test_IteratorTypeEnum,
+  "Iterator type should be an enum.">;
+
+
 #endif // TEST_ATTRDEFS
index 4cb4d61..cc73e07 100644 (file)
@@ -17,6 +17,7 @@
 #include <tuple>
 
 #include "TestTraits.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dialect.h"
diff --git a/mlir/test/lib/Dialect/Test/TestEnumDefs.td b/mlir/test/lib/Dialect/Test/TestEnumDefs.td
new file mode 100644 (file)
index 0000000..1ddfca0
--- /dev/null
@@ -0,0 +1,97 @@
+//===-- TestEnumDefs.td - Test dialect enum definitions ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// TableGen enum definitions for Test dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TEST_ENUMDEFS
+#define TEST_ENUMDEFS
+
+include "mlir/IR/EnumAttr.td"
+
+def I32Case5:  I32EnumAttrCase<"case5", 5>;
+def I32Case10: I32EnumAttrCase<"case10", 10>;
+
+def SomeI32Enum: I32EnumAttr<
+  "SomeI32Enum", "", [I32Case5, I32Case10]>;
+
+def I64Case5:  I64EnumAttrCase<"case5", 5>;
+def I64Case10: I64EnumAttrCase<"case10", 10>;
+
+def SomeI64Enum: I64EnumAttr<
+  "SomeI64Enum", "", [I64Case5, I64Case10]>;
+
+//===----------------------------------------------------------------------===//
+// Test Enum
+//===----------------------------------------------------------------------===//
+
+// Define the C++ enum.
+def TestEnum
+    : I32EnumAttr<"TestEnum", "a test enum", [
+        I32EnumAttrCase<"First", 0, "first">,
+        I32EnumAttrCase<"Second", 1, "second">,
+        I32EnumAttrCase<"Third", 2, "third">,
+      ]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "test";
+}
+
+def TestSimpleEnum : I32EnumAttr<"SimpleEnum", "", [
+    I32EnumAttrCase<"a", 0>,
+    I32EnumAttrCase<"b", 1>
+  ]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::test";
+}
+
+//===----------------------------------------------------------------------===//
+// Test Bit Enum
+//===----------------------------------------------------------------------===//
+
+// Define the C++ enum.
+def TestBitEnum
+    : I32BitEnumAttr<"TestBitEnum", "a test bit enum", [
+        I32BitEnumAttrCaseBit<"Read", 0, "read">,
+        I32BitEnumAttrCaseBit<"Write", 1, "write">,
+        I32BitEnumAttrCaseBit<"Execute", 2, "execute">,
+      ]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "test";
+  let separator = ", ";
+}
+
+// Define an enum with a different separator
+def TestBitEnumVerticalBar
+    : I32BitEnumAttr<"TestBitEnumVerticalBar", "another test bit enum", [
+        I32BitEnumAttrCaseBit<"User", 0, "user">,
+        I32BitEnumAttrCaseBit<"Group", 1, "group">,
+        I32BitEnumAttrCaseBit<"Other", 2, "other">,
+      ]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "test";
+  let separator = " | ";
+}
+
+//===----------------------------------------------------------------------===//
+// Test Patterns (Multi-result Ops)
+
+def MultiResultOpKind1: I64EnumAttrCase<"kind1", 1>;
+def MultiResultOpKind2: I64EnumAttrCase<"kind2", 2>;
+def MultiResultOpKind3: I64EnumAttrCase<"kind3", 3>;
+def MultiResultOpKind4: I64EnumAttrCase<"kind4", 4>;
+def MultiResultOpKind5: I64EnumAttrCase<"kind5", 5>;
+def MultiResultOpKind6: I64EnumAttrCase<"kind6", 6>;
+
+def MultiResultOpEnum: I64EnumAttr<
+  "MultiResultOpEnum", "Multi-result op kinds", [
+    MultiResultOpKind1, MultiResultOpKind2, MultiResultOpKind3,
+    MultiResultOpKind4, MultiResultOpKind5, MultiResultOpKind6
+  ]>;
+
+#endif // TEST_ENUMDEFS
index 84bbe24..84dd37f 100644 (file)
@@ -202,23 +202,11 @@ def FloatAttrOp : TEST_Op<"float_attrs"> {
   );
 }
 
-def I32Case5:  I32EnumAttrCase<"case5", 5>;
-def I32Case10: I32EnumAttrCase<"case10", 10>;
-
-def SomeI32Enum: I32EnumAttr<
-  "SomeI32Enum", "", [I32Case5, I32Case10]>;
-
 def I32EnumAttrOp : TEST_Op<"i32_enum_attr"> {
   let arguments = (ins SomeI32Enum:$attr);
   let results = (outs I32:$val);
 }
 
-def I64Case5:  I64EnumAttrCase<"case5", 5>;
-def I64Case10: I64EnumAttrCase<"case10", 10>;
-
-def SomeI64Enum: I64EnumAttr<
-  "SomeI64Enum", "", [I64Case5, I64Case10]>;
-
 def I64EnumAttrOp : TEST_Op<"i64_enum_attr"> {
   let arguments = (ins SomeI64Enum:$attr);
   let results = (outs I32:$val);
@@ -319,17 +307,6 @@ def ConfinedDenseArrayAttrOp : TEST_Op<"confined_dense_array_attr"> {
 // Test Enum Attributes
 //===----------------------------------------------------------------------===//
 
-// Define the C++ enum.
-def TestEnum
-    : I32EnumAttr<"TestEnum", "a test enum", [
-        I32EnumAttrCase<"First", 0, "first">,
-        I32EnumAttrCase<"Second", 1, "second">,
-        I32EnumAttrCase<"Third", 2, "third">,
-      ]> {
-  let genSpecializedAttr = 0;
-  let cppNamespace = "test";
-}
-
 // Define the enum attribute.
 def TestEnumAttr : EnumAttr<Test_Dialect, TestEnum, "enum">;
 
@@ -351,18 +328,6 @@ def : Pat<(OpWithEnum ConstantAttr<TestEnumAttr,
 // Test Bit Enum Attributes
 //===----------------------------------------------------------------------===//
 
-// Define the C++ enum.
-def TestBitEnum
-    : I32BitEnumAttr<"TestBitEnum", "a test bit enum", [
-        I32BitEnumAttrCaseBit<"Read", 0, "read">,
-        I32BitEnumAttrCaseBit<"Write", 1, "write">,
-        I32BitEnumAttrCaseBit<"Execute", 2, "execute">,
-      ]> {
-  let genSpecializedAttr = 0;
-  let cppNamespace = "test";
-  let separator = ", ";
-}
-
 // Define the enum attribute.
 def TestBitEnumAttr : EnumAttr<Test_Dialect, TestBitEnum, "bit_enum"> {
   let assemblyFormat = "`<` $value `>`";
@@ -374,18 +339,6 @@ def OpWithBitEnum : TEST_Op<"op_with_bit_enum"> {
   let assemblyFormat = "$value (`tag` $tag^)? attr-dict";
 }
 
-// Define an enum with a different separator
-def TestBitEnumVerticalBar
-    : I32BitEnumAttr<"TestBitEnumVerticalBar", "another test bit enum", [
-        I32BitEnumAttrCaseBit<"User", 0, "user">,
-        I32BitEnumAttrCaseBit<"Group", 1, "group">,
-        I32BitEnumAttrCaseBit<"Other", 2, "other">,
-      ]> {
-  let genSpecializedAttr = 0;
-  let cppNamespace = "test";
-  let separator = " | ";
-}
-
 def TestBitEnumVerticalBarAttr
     : EnumAttr<Test_Dialect, TestBitEnumVerticalBar, "bit_enum_vbar"> {
   let assemblyFormat = "`<` $value `>`";
@@ -1392,22 +1345,6 @@ def : Pat<(OpC $input), (OpB $input, ConstantAttr<I32Attr, "17">:$attr)>;
 def : Pat<(I32EnumAttrOp I32Case5), (I32EnumAttrOp I32Case10)>;
 def : Pat<(I64EnumAttrOp I64Case5), (I64EnumAttrOp I64Case10)>;
 
-//===----------------------------------------------------------------------===//
-// Test Patterns (Multi-result Ops)
-
-def MultiResultOpKind1: I64EnumAttrCase<"kind1", 1>;
-def MultiResultOpKind2: I64EnumAttrCase<"kind2", 2>;
-def MultiResultOpKind3: I64EnumAttrCase<"kind3", 3>;
-def MultiResultOpKind4: I64EnumAttrCase<"kind4", 4>;
-def MultiResultOpKind5: I64EnumAttrCase<"kind5", 5>;
-def MultiResultOpKind6: I64EnumAttrCase<"kind6", 6>;
-
-def MultiResultOpEnum: I64EnumAttr<
-  "MultiResultOpEnum", "Multi-result op kinds", [
-    MultiResultOpKind1, MultiResultOpKind2, MultiResultOpKind3,
-    MultiResultOpKind4, MultiResultOpKind5, MultiResultOpKind6
-  ]>;
-
 def ThreeResultOp : TEST_Op<"three_result"> {
   let arguments = (ins MultiResultOpEnum:$kind);
   let results = (outs I32:$result1, F32:$result2, F32:$result3);
@@ -2824,8 +2761,10 @@ def TestLinalgConvOp :
       return &regionBuilder;
     }
 
-    mlir::ArrayAttr getIteratorTypes() {
-      return getOperation()->getAttrOfType<mlir::ArrayAttr>("iterator_types");
+    llvm::SmallVector<mlir::utils::IteratorType> getIteratorTypesArray() {
+      auto attrs = getOperation()->getAttrOfType<mlir::ArrayAttr>("iterator_types");
+      auto range = attrs.getAsValueRange<IteratorTypeAttr, mlir::utils::IteratorType>();
+      return {range.begin(), range.end()};
     }
 
     mlir::ArrayAttr getIndexingMaps() {
@@ -2884,8 +2823,10 @@ def TestLinalgFillOp :
       return &regionBuilder;
     }
 
-    mlir::ArrayAttr getIteratorTypes() {
-      return getOperation()->getAttrOfType<mlir::ArrayAttr>("iterator_types");
+    llvm::SmallVector<mlir::utils::IteratorType> getIteratorTypesArray() {
+      auto attrs = getOperation()->getAttrOfType<mlir::ArrayAttr>("iterator_types");
+      auto range = attrs.getAsValueRange<IteratorTypeAttr, mlir::utils::IteratorType>();
+      return {range.begin(), range.end()};
     }
 
     mlir::ArrayAttr getIndexingMaps() {
index 0a482cc..3595be8 100644 (file)
@@ -553,7 +553,7 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments],
 
     let extraClassDeclaration = structuredOpsBaseDecls # [{{
       // Auto-generated.
-      SmallVector<StringRef> getIteratorTypesArray();
+      SmallVector<utils::IteratorType> getIteratorTypesArray();
       ArrayAttr getIndexingMaps();
       static void regionBuilder(ImplicitLocOpBuilder &b,
                                 Block &block, ArrayRef<NamedAttribute> attrs);
@@ -597,8 +597,8 @@ static const char structuredOpBuilderFormat[] = R"FMT(
 // {1}: Comma interleaved iterator type names.
 static const char structuredOpIteratorTypesFormat[] =
     R"FMT(
-SmallVector<StringRef> {0}::getIteratorTypesArray() {{
-  return SmallVector<StringRef>{{ {1} };
+SmallVector<utils::IteratorType> {0}::getIteratorTypesArray() {{
+  return SmallVector<utils::IteratorType>{{ {1} };
 }
 )FMT";
 
@@ -607,9 +607,9 @@ SmallVector<StringRef> {0}::getIteratorTypesArray() {{
 // {0}: Class name
 static const char rankPolyStructuredOpIteratorTypesFormat[] =
     R"FMT(
-SmallVector<StringRef> {0}::getIteratorTypesArray() {{
+SmallVector<utils::IteratorType> {0}::getIteratorTypesArray() {{
   int64_t rank = getRank(getDpsInitOperand(0));
-  return SmallVector<StringRef>(rank, getParallelIteratorTypeName());
+  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
 }
 )FMT";
 
@@ -812,10 +812,10 @@ generateNamedGenericOpDefns(LinalgOpConfig &opConfig,
                           [&](LinalgIteratorTypeDef it) {
                             switch (it) {
                             case LinalgIteratorTypeDef::parallel:
-                              ss << "getParallelIteratorTypeName()";
+                              ss << "utils::IteratorType::parallel";
                               break;
                             case LinalgIteratorTypeDef::reduction:
-                              ss << "getReductionIteratorTypeName()";
+                              ss << "utils::IteratorType::reduction";
                               break;
                             }
                           });
index d58a07f..0a64646 100644 (file)
@@ -132,14 +132,6 @@ gentbl_cc_library(
             "lib/Dialect/Test/TestOpsDialect.cpp.inc",
         ),
         (
-            ["-gen-enum-decls"],
-            "lib/Dialect/Test/TestOpEnums.h.inc",
-        ),
-        (
-            ["-gen-enum-defs"],
-            "lib/Dialect/Test/TestOpEnums.cpp.inc",
-        ),
-        (
             ["-gen-rewriters"],
             "lib/Dialect/Test/TestPatterns.inc",
         ),
@@ -212,6 +204,27 @@ gentbl_cc_library(
 )
 
 gentbl_cc_library(
+    name = "TestEnumDefsIncGen",
+    strip_include_prefix = "lib/Dialect/Test",
+    tbl_outs = [
+        (
+            ["-gen-enum-decls"],
+            "lib/Dialect/Test/TestOpEnums.h.inc",
+        ),
+        (
+            ["-gen-enum-defs"],
+            "lib/Dialect/Test/TestOpEnums.cpp.inc",
+        ),
+    ],
+    tblgen = "//mlir:mlir-tblgen",
+    td_file = "lib/Dialect/Test/TestEnumDefs.td",
+    test = True,
+    deps = [
+        ":TestOpTdFiles",
+    ],
+)
+
+gentbl_cc_library(
     name = "TestTypeDefsIncGen",
     strip_include_prefix = "lib/Dialect/Test",
     tbl_outs = [
@@ -318,6 +331,7 @@ cc_library(
     ],
     deps = [
         ":TestAttrDefsIncGen",
+        ":TestEnumDefsIncGen",
         ":TestInterfacesIncGen",
         ":TestOpsIncGen",
         ":TestTypeDefsIncGen",
@@ -330,6 +344,7 @@ cc_library(
         "//mlir:DerivedAttributeOpInterface",
         "//mlir:DestinationStyleOpInterface",
         "//mlir:Dialect",
+        "//mlir:DialectUtils",
         "//mlir:FuncDialect",
         "//mlir:FuncTransforms",
         "//mlir:IR",