Revert "[mlir][linalg] Replace "string" iterator_types attr with enums in LinalgInter...
authorOleg Shyshkov <shyshkov@google.com>
Wed, 9 Nov 2022 14:59:54 +0000 (15:59 +0100)
committerOleg Shyshkov <shyshkov@google.com>
Wed, 9 Nov 2022 14:59:54 +0000 (15:59 +0100)
Breaks linalg python tests. Would need to also update python/mlir/dialects/linalg/opdsl.

This reverts commit b809d73973bb5aeedeb6a18cac2a7b3111d0c8d2.

33 files changed:
mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h
mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/test/Dialect/Linalg/conv-interface-invalid.mlir
mlir/test/Dialect/Linalg/invalid.mlir
mlir/test/Dialect/Linalg/transform-op-match.mlir
mlir/test/lib/Dialect/Test/CMakeLists.txt
mlir/test/lib/Dialect/Test/TestAttrDefs.td
mlir/test/lib/Dialect/Test/TestAttributes.h
mlir/test/lib/Dialect/Test/TestEnumDefs.td [deleted file]
mlir/test/lib/Dialect/Test/TestOps.td
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel

index 1ee9d2a..36cf41d 100644 (file)
@@ -13,7 +13,6 @@
 #ifndef LINALG_BASE
 #define LINALG_BASE
 
-include "mlir/Dialect/Utils/StructuredOpsUtils.td"
 include "mlir/Dialect/Linalg/IR/LinalgEnums.td"
 include "mlir/IR/EnumAttr.td"
 include "mlir/IR/OpBase.td"
@@ -72,10 +71,4 @@ def TypeFnAttr : EnumAttr<Linalg_Dialect, TypeFn, "type_fn"> {
   let assemblyFormat = "`<` $value `>`";
 }
 
-def IteratorTypeEnum : EnumAttr<Linalg_Dialect, IteratorType, "iterator_type"> {
-  let assemblyFormat = "`<` $value `>`";
-}
-def IteratorTypeArrayAttr : TypedArrayAttrBase<IteratorTypeEnum,
-  "Iterator type should be an enum.">;
-
 #endif // LINALG_BASE
index 45d63d8..8e3df10 100644 (file)
@@ -25,7 +25,6 @@
 
 namespace mlir {
 namespace linalg {
-class IteratorTypeAttr;
 class LinalgOp;
 
 namespace detail {
index 17f6bf4..533a52f 100644 (file)
@@ -193,8 +193,8 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return llvm::count($_op.getIteratorTypesArray(),
-                           utils::IteratorType::parallel);
+        return getNumIterators(getParallelIteratorTypeName(),
+                                $_op.getIteratorTypesArray());
       }]
     >,
     InterfaceMethod<
@@ -207,7 +207,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         return findPositionsOfType($_op.getIteratorTypesArray(),
-                                   utils::IteratorType::parallel, res);
+                                   getParallelIteratorTypeName(), res);
       }]
     >,
     InterfaceMethod<
@@ -219,8 +219,8 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return llvm::count($_op.getIteratorTypesArray(),
-                           utils::IteratorType::reduction);
+        return getNumIterators(getReductionIteratorTypeName(),
+                               $_op.getIteratorTypesArray());
       }]
     >,
     InterfaceMethod<
@@ -233,7 +233,33 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         return findPositionsOfType($_op.getIteratorTypesArray(),
-                                   utils::IteratorType::reduction, res);
+                                   getReductionIteratorTypeName(), res);
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the number of window loops.
+      }],
+      /*retTy=*/"unsigned",
+      /*methodName=*/"getNumWindowLoops",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return getNumIterators(getWindowIteratorTypeName(),
+                               $_op.getIteratorTypesArray());
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the dims that are window loops.
+      }],
+      /*retTy=*/"void",
+      /*methodName=*/"getWindowDims",
+      /*args=*/(ins "SmallVectorImpl<unsigned> &":$res),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return findPositionsOfType($_op.getIteratorTypesArray(),
+                                   getWindowIteratorTypeName(), res);
       }]
     >,
     InterfaceMethod<
@@ -245,7 +271,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return $_op.getIteratorTypesArray().size();
+        return getNumIterators($_op.getIteratorTypesArray());
       }]
     >,
     InterfaceMethod<
@@ -260,7 +286,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*defaultImplementation=*/[{
         auto iters = $_op.getIteratorTypesArray();
         return iters.size() == 1 &&
-               llvm::count(iters, utils::IteratorType::reduction) == 1;
+               getNumIterators(getReductionIteratorTypeName(), iters) == 1;
       }]>,
     //===------------------------------------------------------------------===//
     // Input and Init arguments handling.
@@ -480,14 +506,12 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
         can be infered from other parameters and in such cases default
         getIteratorTypesArray should be overriden.
       }],
-      /*retTy=*/"SmallVector<utils::IteratorType>",
+      /*retTy=*/"SmallVector<StringRef>",
       /*methodName=*/"getIteratorTypesArray",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        auto range = $_op.getIteratorTypes()
-                         .template getAsValueRange<IteratorTypeAttr,
-                                                   utils::IteratorType>();
+        auto range = $_op.getIteratorTypes().template getAsValueRange<StringAttr>();
         return {range.begin(), range.end()};
       }]
     >,
@@ -743,6 +767,10 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
     LogicalResult reifyResultShapes(OpBuilder &b,
         ReifiedRankedShapedTypeDims &reifiedReturnShapes);
 
+    SmallVector<StringRef> getIteratorTypeNames() {
+      return getIteratorTypesArray();
+    }
+
     //========================================================================//
     // Forwarding functions to access interface methods from the
     // DestinationStyleOpInterface.
index e822435..9866620 100644 (file)
@@ -163,7 +163,7 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [
   let arguments = (ins Variadic<AnyType>:$inputs,
                        Variadic<AnyShaped>:$outputs,
                        AffineMapArrayAttr:$indexing_maps,
-                       IteratorTypeArrayAttr:$iterator_types,
+                       ArrayAttr:$iterator_types,
                        OptionalAttr<StrAttr>:$doc,
                        OptionalAttr<StrAttr>:$library_call);
   let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
@@ -178,22 +178,22 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
     OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
       "ValueRange":$outputs, "ArrayRef<AffineMap>":$indexingMaps,
-      "ArrayRef<utils::IteratorType>":$iteratorTypes, "StringRef":$doc,
+      "ArrayRef<StringRef>":$iteratorTypes, "StringRef":$doc,
       "StringRef":$libraryCall,
       CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
     OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers,
-      "ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<utils::IteratorType>":$iteratorTypes,
+      "ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<StringRef>":$iteratorTypes,
       "StringRef":$doc, "StringRef":$libraryCall,
       CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
     OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
       "ValueRange":$outputs, "ArrayRef<AffineMap>":$indexingMaps,
-      "ArrayRef<utils::IteratorType>":$iteratorTypes,
+      "ArrayRef<StringRef>":$iteratorTypes,
       CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
     OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers,
-      "ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<utils::IteratorType>":$iteratorTypes,
+      "ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<StringRef>":$iteratorTypes,
       CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
   ];
@@ -275,7 +275,7 @@ def MapOp : LinalgStructuredBase_Op<"map", [
 
   let extraClassDeclaration = structuredOpsBaseDecls # [{
     // Implement functions necessary for LinalgStructuredInterface.
-    SmallVector<utils::IteratorType> getIteratorTypesArray();
+    SmallVector<StringRef> getIteratorTypesArray();
     ArrayAttr getIndexingMaps();
     std::string getLibraryCallName() {
       return "op_has_no_registered_library_name";
@@ -356,7 +356,7 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
 
   let extraClassDeclaration = structuredOpsBaseDecls # [{
     // Declare functions necessary for LinalgStructuredInterface.
-    SmallVector<utils::IteratorType> getIteratorTypesArray();
+    SmallVector<StringRef> getIteratorTypesArray();
     ArrayAttr getIndexingMaps();
     std::string getLibraryCallName() {
       return "op_has_no_registered_library_name";
@@ -426,7 +426,7 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [
 
   let extraClassDeclaration = structuredOpsBaseDecls # [{
     // Declare functions necessary for LinalgStructuredInterface.
-    SmallVector<utils::IteratorType> getIteratorTypesArray();
+    SmallVector<StringRef> getIteratorTypesArray();
     ArrayAttr getIndexingMaps();
     std::string getLibraryCallName() {
       return "op_has_no_registered_library_name";
@@ -502,7 +502,7 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
 
   let extraClassDeclaration = structuredOpsBaseDecls # [{
     // Declare functions necessary for LinalgStructuredInterface.
-    SmallVector<utils::IteratorType> getIteratorTypesArray();
+    SmallVector<StringRef> getIteratorTypesArray();
     ArrayAttr getIndexingMaps();
     std::string getLibraryCallName() {
       return "op_has_no_registered_library_name";
index 4f9dd71..5fc7938 100644 (file)
@@ -42,10 +42,10 @@ bool hasOnlyScalarElementwiseOp(Region &r);
 bool isElementwise(LinalgOp op);
 
 /// Check if iterator type has "parallel" semantics.
-bool isParallelIterator(utils::IteratorType iteratorType);
+bool isParallelIterator(StringRef iteratorType);
 
 /// Check if iterator type  has "reduction" semantics.
-bool isReductionIterator(utils::IteratorType iteratorType);
+bool isReductionIterator(StringRef iteratorType);
 
 /// Helper function that creates a memref::DimOp or tensor::DimOp depending on
 /// the type of `source`.
@@ -480,8 +480,7 @@ struct RegionMatcher {
 template <typename LoopTy>
 struct GenerateLoopNest {
   static void doit(OpBuilder &b, Location loc, ArrayRef<Range> loopRanges,
-                   LinalgOp linalgOp,
-                   ArrayRef<utils::IteratorType> iteratorTypes,
+                   LinalgOp linalgOp, ArrayRef<StringRef> iteratorTypes,
                    function_ref<scf::ValueVector(OpBuilder &, Location,
                                                  ValueRange, ValueRange)>
                        bodyBuilderFn,
index 6b2104f..b2b7b24 100644 (file)
@@ -22,8 +22,7 @@ namespace mlir {
 namespace tosa {
 
 // Creates a SmallVector of Stringrefs for N parallel loops
-SmallVector<utils::IteratorType>
-getNParallelLoopsAttrs(unsigned nParallelLoops);
+SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops);
 
 // Takes a vector of values and condenses them to a vector with no gaps.
 SmallVector<Value> condenseValues(const SmallVector<Value> &values);
index cb509fe..6fcfcb1 100644 (file)
@@ -21,6 +21,7 @@
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Location.h"
 #include "mlir/Support/LLVM.h"
+#include "llvm/ADT/StringRef.h"
 
 // Pull in all enum type definitions and utility function declarations.
 #include "mlir/Dialect/Utils/DialectUtilsEnums.h.inc"
@@ -47,9 +48,42 @@ bool isColumnMajorMatmul(ArrayAttr indexingMaps);
 /// the reduction.
 bool isRowMajorBatchMatmul(ArrayAttr indexingMaps);
 
+/// Use to encode that a particular iterator type has parallel semantics.
+constexpr StringRef getParallelIteratorTypeName() { return "parallel"; }
+
+/// Use to encode that a particular iterator type has reduction semantics.
+constexpr StringRef getReductionIteratorTypeName() { return "reduction"; }
+
+/// Use to encode that a particular iterator type has window semantics.
+constexpr StringRef getWindowIteratorTypeName() { return "window"; }
+
+/// Use to encode that a particular iterator type has window semantics.
+inline ArrayRef<StringRef> getAllIteratorTypeNames() {
+  static constexpr StringRef names[3] = {getParallelIteratorTypeName(),
+                                         getReductionIteratorTypeName(),
+                                         getWindowIteratorTypeName()};
+  return llvm::makeArrayRef(names);
+}
+
+/// Returns the iterator of a certain type.
+inline unsigned getNumIterators(StringRef name,
+                                ArrayRef<StringRef> iteratorTypes) {
+  auto names = getAllIteratorTypeNames();
+  (void)names;
+  assert(llvm::is_contained(names, name));
+  return llvm::count(iteratorTypes, name);
+}
+
+inline unsigned getNumIterators(ArrayRef<StringRef> iteratorTypes) {
+  unsigned res = 0;
+  for (auto n : getAllIteratorTypeNames())
+    res += getNumIterators(n, iteratorTypes);
+  return res;
+}
+
 /// Return positions in `iteratorTypes` that match `iteratorTypeName`.
-inline void findPositionsOfType(ArrayRef<utils::IteratorType> iteratorTypes,
-                                utils::IteratorType iteratorTypeName,
+inline void findPositionsOfType(ArrayRef<StringRef> iteratorTypes,
+                                StringRef iteratorTypeName,
                                 SmallVectorImpl<unsigned> &res) {
   for (const auto &en : llvm::enumerate(iteratorTypes)) {
     if (en.value() == iteratorTypeName)
@@ -60,28 +94,29 @@ inline void findPositionsOfType(ArrayRef<utils::IteratorType> iteratorTypes,
 /// Helper StructuredGenerator class to manipulate and rewrite ops with
 /// `StructuredOpInterface`. This is templated for now because VectorOps do not
 /// yet implement the StructuredOpInterface itself.
-template <typename StructuredOpInterface, typename IteratorTypeT>
+template <typename StructuredOpInterface>
 class StructuredGenerator {
 public:
   using MapList = ArrayRef<ArrayRef<AffineExpr>>;
 
   struct IteratorType {
-    IteratorType(IteratorTypeT iter) : iter(iter) {}
-    bool isOfType(IteratorTypeT expectedIter) const {
-      return expectedIter == iter;
-    }
-    IteratorTypeT iter;
+    IteratorType(StringRef strRef) : strRef(strRef) {}
+    bool isOfType(StringRef typeName) const { return typeName == strRef; }
+    StringRef strRef;
   };
   struct Par : public IteratorType {
-    Par() : IteratorType(IteratorTypeT::parallel) {}
+    Par() : IteratorType(getParallelIteratorTypeName()) {}
   };
   struct Red : public IteratorType {
-    Red() : IteratorType(IteratorTypeT::reduction) {}
+    Red() : IteratorType(getReductionIteratorTypeName()) {}
+  };
+  struct Win : public IteratorType {
+    Win() : IteratorType(getWindowIteratorTypeName()) {}
   };
 
   StructuredGenerator(OpBuilder &builder, StructuredOpInterface op)
       : builder(builder), ctx(op.getContext()), loc(op.getLoc()),
-        iterators(op.getIteratorTypesArray()), maps(op.getIndexingMapsArray()),
+        iterators(op.getIteratorTypeNames()), maps(op.getIndexingMapsArray()),
         op(op) {}
 
   bool iters(ArrayRef<IteratorType> its) {
@@ -103,7 +138,7 @@ protected:
   OpBuilder &builder;
   MLIRContext *ctx;
   Location loc;
-  SmallVector<IteratorTypeT> iterators;
+  SmallVector<StringRef> iterators;
   SmallVector<AffineMap, 4> maps;
   Operation *op;
 };
index 5060d8c..758e7c1 100644 (file)
@@ -269,11 +269,12 @@ def Vector_ContractionOp :
       return CombiningKind::ADD;
     }
 
-    SmallVector<IteratorType> getIteratorTypesArray() {
-      auto range =
-          getIteratorTypes()
-              .template getAsValueRange<IteratorTypeAttr, IteratorType>();
-      return {range.begin(), range.end()};
+    // Returns iterator types in string format.
+    SmallVector<StringRef> getIteratorTypeNames() {
+      return llvm::to_vector(
+          llvm::map_range(getIteratorTypes(), [](Attribute a) {
+            return stringifyIteratorType(a.cast<IteratorTypeAttr>().getValue());
+          }));
     }
   }];
 
index f56162b..04cd00f 100644 (file)
@@ -791,12 +791,12 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
 
   SmallVector<AffineExpr, 2> srcExprs;
   SmallVector<AffineExpr, 2> dstExprs;
-  SmallVector<utils::IteratorType, 4> iteratorTypes;
+  SmallVector<StringRef, 4> iteratorTypes;
   for (unsigned int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
     srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
 
-    iteratorTypes.push_back(axis == i ? utils::IteratorType::reduction
-                                      : utils::IteratorType::parallel);
+    iteratorTypes.push_back(axis == i ? getReductionIteratorTypeName()
+                                      : getParallelIteratorTypeName());
     if (axis != i)
       dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
   }
@@ -1383,8 +1383,7 @@ public:
     auto inputMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0,
                                    inputExprs, builder.getContext());
     auto resultMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
-    SmallVector<utils::IteratorType> iterators(4,
-                                               utils::IteratorType::parallel);
+    SmallVector<StringRef> iterators(4, getParallelIteratorTypeName());
 
     Value empty = builder.create<tensor::EmptyOp>(
         resultTy.getShape(), resultTy.getElementType(), outputDynSize);
@@ -2084,9 +2083,9 @@ public:
 
     // We need to reduce along the arg-max axis, with parallel operations along
     // the rest.
-    SmallVector<utils::IteratorType, 4> iteratorTypes;
-    iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
-    iteratorTypes[axis] = utils::IteratorType::reduction;
+    SmallVector<StringRef, 4> iteratorTypes;
+    iteratorTypes.resize(inputTy.getRank(), getParallelIteratorTypeName());
+    iteratorTypes[axis] = getReductionIteratorTypeName();
 
     SmallVector<AffineExpr, 2> srcExprs;
     SmallVector<AffineExpr, 2> dstExprs;
index f812621..78a29f4 100644 (file)
@@ -321,7 +321,7 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
     if (inputExprWalker.unConvolvedDims.count(outputDim) &&
         !filterDims.count(outputDim)) {
       // Batch dimension.
-      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
+      if (iteratorTypes[outputDim] != getParallelIteratorTypeName())
         return MatchConvolutionResult::OutputDimsNotParallel;
       allLoopDims.insert(outputDim);
       continue;
@@ -329,7 +329,7 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
     if (inputExprWalker.convolvedDims.count(outputDim) &&
         !filterDims.count(outputDim)) {
       // Output image Loop dimension.
-      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
+      if (iteratorTypes[outputDim] != getParallelIteratorTypeName())
         return MatchConvolutionResult::OutputDimsNotParallel;
       allLoopDims.insert(outputDim);
       continue;
@@ -338,7 +338,7 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
         !inputExprWalker.unConvolvedDims.count(outputDim) &&
         filterDims.count(outputDim)) {
       // Output channel dimension.
-      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
+      if (iteratorTypes[outputDim] != getParallelIteratorTypeName())
         return MatchConvolutionResult::OutputDimsNotParallel;
       allLoopDims.insert(outputDim);
       continue;
@@ -346,7 +346,7 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
     if (inputExprWalker.unConvolvedDims.count(outputDim) &&
         filterDims.count(outputDim)) {
       // Depth multiplier.
-      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
+      if (iteratorTypes[outputDim] != getParallelIteratorTypeName())
         return MatchConvolutionResult::OutputDimsNotParallel;
       allLoopDims.insert(outputDim);
       continue;
@@ -364,7 +364,7 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
     if (inputExprWalker.convolvedDims.count(filterDim) &&
         !outputDims.count(filterDim)) {
       // Filter loop dimension.
-      if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
+      if (iteratorTypes[filterDim] != getReductionIteratorTypeName())
         return MatchConvolutionResult::NonOutputDimNotReduction;
       if (allLoopDims.count(filterDim))
         return MatchConvolutionResult::NonConvolutionLoop;
@@ -374,7 +374,7 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
     if (inputExprWalker.unConvolvedDims.count(filterDim) &&
         !outputDims.count(filterDim)) {
       // Input channel dimension.
-      if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
+      if (iteratorTypes[filterDim] != getReductionIteratorTypeName())
         return MatchConvolutionResult::NonOutputDimNotReduction;
       if (allLoopDims.count(filterDim))
         return MatchConvolutionResult::NonConvolutionLoop;
@@ -619,6 +619,15 @@ LinalgOp::reifyResultShapes(OpBuilder &b,
 LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
   LinalgOp linalgOp = cast<LinalgOp>(op);
 
+  // Check all iterator types are known.
+  auto iteratorTypesRange = linalgOp.getIteratorTypesArray();
+  for (StringRef iteratorType : iteratorTypesRange) {
+    if (!llvm::is_contained(getAllIteratorTypeNames(), iteratorType) ||
+        !utils::symbolizeIteratorType(iteratorType).has_value())
+      return op->emitOpError("unexpected iterator_type (")
+             << iteratorType << ")";
+  }
+
   // Before checking indexing maps, we need to make sure the attributes
   // referenced by it are valid.
   if (linalgOp.hasDynamicIndexingMaps())
index 52ef33b..8ce1ad0 100644 (file)
@@ -705,17 +705,12 @@ void GenericOp::build(
 void GenericOp::build(
     OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
     ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
-    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
-    StringRef libraryCall,
+    ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
     ArrayRef<NamedAttribute> attributes) {
   build(builder, result, resultTensorTypes, inputs, outputs,
         builder.getAffineMapArrayAttr(indexingMaps),
-        builder.getArrayAttr(llvm::to_vector(llvm::map_range(
-            iteratorTypes,
-            [&](utils::IteratorType iter) -> mlir::Attribute {
-              return IteratorTypeAttr::get(builder.getContext(), iter);
-            }))),
+        builder.getStrArrayAttr(iteratorTypes),
         doc.empty() ? StringAttr() : builder.getStringAttr(doc),
         libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
         bodyBuild, attributes);
@@ -724,8 +719,7 @@ void GenericOp::build(
 void GenericOp::build(
     OpBuilder &builder, OperationState &result, ValueRange inputs,
     ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
-    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
-    StringRef libraryCall,
+    ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
     ArrayRef<NamedAttribute> attributes) {
   build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
@@ -735,7 +729,7 @@ void GenericOp::build(
 void GenericOp::build(
     OpBuilder &builder, OperationState &result, ValueRange inputs,
     ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
-    ArrayRef<utils::IteratorType> iteratorTypes,
+    ArrayRef<StringRef> iteratorTypes,
     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
     ArrayRef<NamedAttribute> attributes) {
   build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
@@ -746,7 +740,7 @@ void GenericOp::build(
 void GenericOp::build(
     OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
     ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
-    ArrayRef<utils::IteratorType> iteratorTypes,
+    ArrayRef<StringRef> iteratorTypes,
     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
     ArrayRef<NamedAttribute> attributes) {
   build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
@@ -764,29 +758,9 @@ void GenericOp::print(OpAsmPrinter &p) {
   llvm::StringSet<> genericAttrNamesSet;
   genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
   SmallVector<NamedAttribute, 8> genericAttrs;
-  for (auto attr : (*this)->getAttrs()) {
-    if (attr.getName() == getIteratorTypesAttrName()) {
-      auto iteratorTypes =
-          attr.getValue()
-              .cast<ArrayAttr>()
-              .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
-      // Convert IteratorType enums into the string representation. This is
-      // needed, because tests still use the old format when 'iterator_types'
-      // attribute is represented as an array of strings.
-      // TODO: Remove this conversion once tests are fixed.
-      SmallVector<Attribute> iteratorTypeNames =
-          llvm::to_vector(llvm::map_range(
-              iteratorTypes, [&](utils::IteratorType t) -> Attribute {
-                return StringAttr::get(getContext(), stringifyIteratorType(t));
-              }));
-
-      genericAttrs.emplace_back(
-          getIteratorTypesAttrName(),
-          ArrayAttr::get(getContext(), iteratorTypeNames));
-    } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
+  for (auto attr : (*this)->getAttrs())
+    if (genericAttrNamesSet.count(attr.getName().strref()) > 0)
       genericAttrs.push_back(attr);
-    }
-  }
   if (!genericAttrs.empty()) {
     auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
     p << genericDictAttr;
@@ -831,28 +805,6 @@ ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
   result.attributes.assign(dictAttr.getValue().begin(),
                            dictAttr.getValue().end());
 
-  // Convert array of string into an array of IteratyType enums. This is needed,
-  // because tests still use the old format when 'iterator_types' attribute is
-  // represented as an array of strings.
-  // TODO: Remove this conversion once tests are fixed.
-  ArrayAttr iteratorTypes =
-      result.attributes.get(getIteratorTypesAttrName(result.name))
-          .cast<ArrayAttr>();
-
-  SmallVector<Attribute> iteratorTypeAttrs;
-
-  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
-    auto maybeIteratorType = utils::symbolizeIteratorType(s);
-    if (!maybeIteratorType.has_value())
-      return parser.emitError(parser.getCurrentLocation())
-             << "unexpected iterator_type (" << s << ")";
-
-    iteratorTypeAttrs.push_back(
-        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
-  }
-  result.attributes.set(getIteratorTypesAttrName(result.name),
-                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
-
   // Parsing is shared with named ops, except for the region.
   SmallVector<Type, 1> inputTypes, outputTypes;
   if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
@@ -1466,9 +1418,9 @@ LogicalResult MapOp::verify() {
   return success();
 }
 
-SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
+SmallVector<StringRef> MapOp::getIteratorTypesArray() {
   int64_t rank = getInit().getType().getRank();
-  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
+  return SmallVector<StringRef>(rank, getParallelIteratorTypeName());
 }
 
 ArrayAttr MapOp::getIndexingMaps() {
@@ -1524,12 +1476,12 @@ void ReduceOp::build(
                        inputs, inits, bodyBuild);
 }
 
-SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
+SmallVector<StringRef> ReduceOp::getIteratorTypesArray() {
   int64_t inputRank = getInputs()[0].getType().cast<ShapedType>().getRank();
-  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
-                                                 utils::IteratorType::parallel);
+  SmallVector<StringRef> iteratorTypes(inputRank,
+                                       getParallelIteratorTypeName());
   for (int64_t reductionDim : getDimensions())
-    iteratorTypes[reductionDim] = utils::IteratorType::reduction;
+    iteratorTypes[reductionDim] = getReductionIteratorTypeName();
   return iteratorTypes;
 }
 
@@ -1801,9 +1753,9 @@ LogicalResult TransposeOp::verify() {
   return success();
 }
 
-SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
+SmallVector<StringRef> TransposeOp::getIteratorTypesArray() {
   int64_t rank = getInit().getType().getRank();
-  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
+  return SmallVector<StringRef>(rank, getParallelIteratorTypeName());
 }
 
 ArrayAttr TransposeOp::getIndexingMaps() {
@@ -1939,9 +1891,9 @@ LogicalResult BroadcastOp::verify() {
   return success();
 }
 
-SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
+SmallVector<StringRef> BroadcastOp::getIteratorTypesArray() {
   int64_t rank = getInit().getType().getRank();
-  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
+  return SmallVector<StringRef>(rank, getParallelIteratorTypeName());
 }
 
 ArrayAttr BroadcastOp::getIndexingMaps() {
index 7fd5a5e..6a9c4e3 100644 (file)
@@ -470,9 +470,10 @@ static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
                             .getValue()
                             .isProjectedPermutation();
                       }) &&
-         genericOp.getMatchingIndexingMap(fusableOpOperand).getNumResults() >
-             0 &&
-         llvm::all_of(genericOp.getIteratorTypesArray(), isParallelIterator);
+         genericOp.getMatchingIndexingMap(fusableOpOperand).getNumResults() > 0 &&
+         llvm::all_of(genericOp.getIteratorTypesArray(), [](StringRef it) {
+           return it == getParallelIteratorTypeName();
+         });
 }
 
 namespace {
@@ -782,8 +783,8 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
   }
 
   // The iterator types of the expanded op are all parallel.
-  SmallVector<utils::IteratorType> iteratorTypes(
-      expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
+  SmallVector<StringRef> iteratorTypes(expansionInfo.getExpandedOpNumDims(),
+                                       getParallelIteratorTypeName());
 
   TypeRange resultTypes = ValueRange(outputs).getTypes();
   auto fusedOp =
@@ -1082,8 +1083,7 @@ getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
       continue;
 
     // Check that all folded iterator types are all parallel or all reductions.
-    utils::IteratorType startIteratorType =
-        iteratorTypes[foldedIterationSpaceDims[0]];
+    StringRef startIteratorType = iteratorTypes[foldedIterationSpaceDims[0]];
     if (!isParallelIterator(startIteratorType) &&
         !isReductionIterator(startIteratorType))
       continue;
@@ -1235,10 +1235,10 @@ private:
 
 /// Get the iterator types for the collapsed operation given the original
 /// iterator types and collapsed dimensions.
-static SmallVector<utils::IteratorType>
-getCollapsedOpIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
+static SmallVector<StringRef>
+getCollapsedOpIteratorTypes(ArrayRef<Attribute> iteratorTypes,
                             const CollapsingInfo &collapsingInfo) {
-  SmallVector<utils::IteratorType> collapsedIteratorTypes;
+  SmallVector<StringRef> collapsedIteratorTypes;
   for (ReassociationIndicesRef foldedIterDims :
        collapsingInfo.getCollapsedOpToOrigOpMapping()) {
     assert(!foldedIterDims.empty() &&
@@ -1246,7 +1246,8 @@ getCollapsedOpIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
     // Just pick the iterator type of the first folded dim. Pre-condition checks
     // expected to have checked that iterator types of all folded dimensions are
     // the same.
-    collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]);
+    collapsedIteratorTypes.push_back(
+        iteratorTypes[foldedIterDims[0]].cast<StringAttr>().getValue());
   }
   return collapsedIteratorTypes;
 }
@@ -1405,8 +1406,8 @@ static FailureOr<SmallVector<Value>> collapseGenericOpIterationDims(
   }
 
   // Get the iterator types for the operand.
-  SmallVector<utils::IteratorType> iteratorTypes = getCollapsedOpIteratorTypes(
-      genericOp.getIteratorTypesArray(), collapsingInfo);
+  SmallVector<StringRef> iteratorTypes = getCollapsedOpIteratorTypes(
+      genericOp.getIteratorTypes().getValue(), collapsingInfo);
 
   // Get the indexing maps.
   auto indexingMaps = llvm::to_vector(
index 52287f1..3740633 100644 (file)
@@ -91,8 +91,8 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
     SmallVector<AffineMap, 3> indexingMaps(
         op->getNumResults() + op->getNumOperands(),
         rewriter.getMultiDimIdentityMap(rank));
-    SmallVector<utils::IteratorType, 6> iteratorTypes(
-        rank, utils::IteratorType::parallel);
+    SmallVector<StringRef, 6> iteratorTypes(rank,
+                                            getParallelIteratorTypeName());
     auto outputs = getOrCreateOperandsMatchingResultTypes(rewriter, op);
     rewriter.replaceOpWithNewOp<linalg::GenericOp>(
         op, /*resultTensorTypes=*/op->getResultTypes(),
index 4755fa3..da43b49 100644 (file)
@@ -53,7 +53,7 @@ FailureOr<GenericOp> mlir::linalg::generalizeNamedOp(RewriterBase &rewriter,
   SmallVector<Value> inputs = linalgOp.getDpsInputOperands();
   SmallVector<Value> outputs = linalgOp.getDpsInitOperands();
   SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
-  SmallVector<utils::IteratorType> iterators = linalgOp.getIteratorTypesArray();
+  SmallVector<StringRef> iterators = linalgOp.getIteratorTypesArray();
   SmallVector<Type> resultTypes = linalgOp.hasTensorSemantics()
                                       ? TypeRange(ValueRange(outputs))
                                       : TypeRange{};
index 2fb550b..0608c36 100644 (file)
@@ -162,13 +162,13 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
 
   newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr,
                                    op.getContext()));
-  SmallVector<utils::IteratorType> newIteratorTypes;
+  SmallVector<StringRef> newIteratorTypes;
   for (auto &it : llvm::enumerate(op.getIteratorTypesArray())) {
     if (insertSplitDimension == it.index() && !control.innerParallel)
-      newIteratorTypes.push_back(utils::IteratorType::parallel);
+      newIteratorTypes.push_back(getParallelIteratorTypeName());
     newIteratorTypes.push_back(it.value());
     if (insertSplitDimension == it.index() && control.innerParallel)
-      newIteratorTypes.push_back(utils::IteratorType::parallel);
+      newIteratorTypes.push_back(getParallelIteratorTypeName());
   }
   // Create the new op matching the original op with an extra parallel
   // dimension.
@@ -182,14 +182,14 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
   // from the previous op.
   unsigned intermRank = newOutputShape.size();
   AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
-  SmallVector<utils::IteratorType> reductionIteratorTypes;
+  SmallVector<StringRef> reductionIteratorTypes;
   SmallVector<AffineExpr> exprs;
   for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
     if (insertSplitDimension == i) {
-      reductionIteratorTypes.push_back(utils::IteratorType::reduction);
+      reductionIteratorTypes.push_back(getReductionIteratorTypeName());
     } else {
       exprs.push_back(b.getAffineDimExpr(i));
-      reductionIteratorTypes.push_back(utils::IteratorType::parallel);
+      reductionIteratorTypes.push_back(getParallelIteratorTypeName());
     }
   }
   AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext());
@@ -367,7 +367,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
   // dimension.
   auto iteratorTypes = op.getIteratorTypesArray();
   iteratorTypes.insert(iteratorTypes.begin() + reductionDimPos,
-                       utils::IteratorType::parallel);
+                       getParallelIteratorTypeName());
   GenericOp genericOp =
       b.create<GenericOp>(loc, ValueRange(newOutputs).getTypes(), newInputs,
                           newOutputs, newMaps, iteratorTypes);
@@ -394,10 +394,10 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
     AffineMap map = b.getMultiDimIdentityMap(originalOutputType.getRank() + 1);
     SmallVector<AffineMap> indexingMaps = {
         map, map.dropResult(insertSplitDimension)};
-    SmallVector<utils::IteratorType> reductionIteratorTypes(
-        originalOutputType.getRank() + 1, utils::IteratorType::parallel);
+    SmallVector<StringRef> reductionIteratorTypes(
+        originalOutputType.getRank() + 1, getParallelIteratorTypeName());
     reductionIteratorTypes[insertSplitDimension] =
-        utils::IteratorType::reduction;
+        getReductionIteratorTypeName();
 
     // clang-format off
     auto reductionOp = b.create<GenericOp>(
index e1d4616..5937da3 100644 (file)
@@ -431,7 +431,7 @@ tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,
   auto [loopRanges, loopIndexToRangeIndex] = makeTiledLoopRanges(
       b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes);
 
-  SmallVector<utils::IteratorType, 4> iteratorTypes;
+  SmallVector<StringRef, 4> iteratorTypes;
   for (const auto &attr : enumerate(op.getIteratorTypesArray())) {
     if (loopIndexToRangeIndex.count(attr.index()))
       iteratorTypes.push_back(attr.value());
index 02f4e9d..d1fcc01 100644 (file)
@@ -88,7 +88,10 @@ struct LinalgOpTilingInterface
   /// Return the loop iterator type.
   SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
     LinalgOpTy concreteOp = cast<LinalgOpTy>(op);
-    return concreteOp.getIteratorTypesArray();
+    return llvm::to_vector(llvm::map_range(
+        concreteOp.getIteratorTypesArray(), [](StringRef iteratorType) {
+          return utils::symbolizeIteratorType(iteratorType).value();
+        }));
   }
 
   /// Return the iteration domain range.
@@ -336,9 +339,8 @@ struct LinalgOpPartialReductionInterface
 
     // Step3. create a generic op where the reduction dimension is replaced by a
     // parallel dimension of the size of reduction.
-    SmallVector<utils::IteratorType> newIteratorTypes =
-        linalgOp.getIteratorTypesArray();
-    newIteratorTypes[reductionDims[0]] = utils::IteratorType::parallel;
+    SmallVector<StringRef> newIteratorTypes = linalgOp.getIteratorTypesArray();
+    newIteratorTypes[reductionDims[0]] = getParallelIteratorTypeName();
     SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray();
     newMaps.back() = AffineMap::get(newMaps.back().getNumDims(), 0, outputExpr,
                                     linalgOp.getContext());
@@ -364,14 +366,14 @@ struct LinalgOpPartialReductionInterface
     int64_t intermRank =
         partialReduce[0].getType().cast<ShapedType>().getRank();
     AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
-    SmallVector<utils::IteratorType> reductionIteratorTypes;
+    SmallVector<StringRef> reductionIteratorTypes;
     SmallVector<AffineExpr> exprs;
     for (int64_t i : llvm::seq<int64_t>(0, intermRank)) {
       if (dimToMerge == i) {
-        reductionIteratorTypes.push_back(utils::IteratorType::reduction);
+        reductionIteratorTypes.push_back(getReductionIteratorTypeName());
       } else {
         exprs.push_back(b.getAffineDimExpr(i));
-        reductionIteratorTypes.push_back(utils::IteratorType::parallel);
+        reductionIteratorTypes.push_back(getParallelIteratorTypeName());
       }
     }
     AffineMap outputMap =
index 11ee55c..1034e8e 100644 (file)
@@ -297,10 +297,8 @@ LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
   return vectorizeCopy(rewriter, copyOp);
 }
 
-static SmallVector<utils::IteratorType>
-getNParallelLoopsAttrs(unsigned nParallelLoops) {
-  return SmallVector<utils::IteratorType>(nParallelLoops,
-                                          utils::IteratorType::parallel);
+static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
+  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
 }
 
 /// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp (to
index a643713..2cf74a6 100644 (file)
@@ -1420,12 +1420,11 @@ namespace {
 ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
 /// ```
 /// kw is unrolled, w is unrolled iff dilationW > 1.
-struct Conv1DGenerator
-    : public StructuredGenerator<LinalgOp, utils::IteratorType> {
+struct Conv1DGenerator : public StructuredGenerator<LinalgOp> {
   Conv1DGenerator(OpBuilder &builder, LinalgOp linalgOp, int strideW,
                   int dilationW)
-      : StructuredGenerator<LinalgOp, utils::IteratorType>(builder, linalgOp),
-        strideW(strideW), dilationW(dilationW) {
+      : StructuredGenerator<LinalgOp>(builder, linalgOp), strideW(strideW),
+        dilationW(dilationW) {
     // Determine whether `linalgOp` can be generated with this generator
     if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
       return;
index fc34353..ccf7cdc 100644 (file)
@@ -186,12 +186,12 @@ bool isElementwise(LinalgOp op) {
   return hasOnlyScalarElementwiseOp(op->getRegion(0));
 }
 
-bool isParallelIterator(utils::IteratorType iteratorType) {
-  return iteratorType == utils::IteratorType::parallel;
+bool isParallelIterator(StringRef iteratorType) {
+  return iteratorType == getParallelIteratorTypeName();
 }
 
-bool isReductionIterator(utils::IteratorType iteratorType) {
-  return iteratorType == utils::IteratorType::reduction;
+bool isReductionIterator(StringRef iteratorType) {
+  return iteratorType == getReductionIteratorTypeName();
 }
 
 /// Helper function that creates a memref::DimOp or tensor::DimOp depending on
@@ -422,13 +422,15 @@ GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor,
           b.getContext())),
       AffineMap::getMultiDimIdentityMap(transposeVector.size(),
                                         b.getContext())};
-  SmallVector<utils::IteratorType> iteratorTypes(transposeVector.size(),
-                                                 utils::IteratorType::parallel);
+  SmallVector<llvm::StringRef> iteratorTypes(transposeVector.size(),
+                                             getParallelIteratorTypeName());
 
   // Create a GenericOp to transpose `inputTensor` into `outputTensor`.
-  auto transposeOp =
-      b.create<GenericOp>(loc, resultTensorType, inputTensor, outputTensor,
-                          indexingMaps, iteratorTypes);
+  auto transposeOp = b.create<GenericOp>(
+      loc, resultTensorType, inputTensor, outputTensor,
+      b.getAffineMapArrayAttr(indexingMaps), b.getStrArrayAttr(iteratorTypes),
+      /*doc=*/nullptr,
+      /*library_call=*/nullptr);
   Region &body = transposeOp.getRegion();
   body.push_back(new Block());
   body.front().addArguments({elementType, elementType}, {loc, loc});
@@ -450,8 +452,8 @@ GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) {
 
   AffineMap id =
       AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext());
-  SmallVector<utils::IteratorType> iteratorTypes(memrefTypeTo.getRank(),
-                                                 utils::IteratorType::parallel);
+  SmallVector<StringRef> iteratorTypes(memrefTypeTo.getRank(),
+                                       getParallelIteratorTypeName());
   return b.create<linalg::GenericOp>(
       loc,
       /*inputs=*/from,
@@ -467,7 +469,7 @@ GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) {
 template <>
 void GenerateLoopNest<scf::ForOp>::doit(
     OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
-    ArrayRef<utils::IteratorType> iteratorTypes,
+    ArrayRef<StringRef> iteratorTypes,
     function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                   ValueRange)>
         bodyBuilderFn,
@@ -511,7 +513,7 @@ void GenerateLoopNest<scf::ForOp>::doit(
 template <>
 void GenerateLoopNest<AffineForOp>::doit(
     OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
-    ArrayRef<utils::IteratorType> iteratorTypes,
+    ArrayRef<StringRef> iteratorTypes,
     function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                   ValueRange)>
         bodyBuilderFn,
@@ -562,7 +564,7 @@ void updateBoundsForCyclicDistribution(OpBuilder &b, Location loc, Value procId,
 // exceeds 10.
 static void generateParallelLoopNest(
     OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs,
-    ValueRange steps, ArrayRef<utils::IteratorType> iteratorTypes,
+    ValueRange steps, ArrayRef<StringRef> iteratorTypes,
     ArrayRef<linalg::ProcInfo> procInfo,
     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
     SmallVectorImpl<Value> &ivStorage) {
@@ -677,7 +679,7 @@ static void generateParallelLoopNest(
 template <>
 void GenerateLoopNest<scf::ParallelOp>::doit(
     OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
-    ArrayRef<utils::IteratorType> iteratorTypes,
+    ArrayRef<StringRef> iteratorTypes,
     function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                   ValueRange)>
         bodyBuilderFn,
index 3f2ee1b..533d31f 100644 (file)
@@ -178,8 +178,7 @@ static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
 /// as we use adj matrix for the graph.
 /// The sorted result will put the first Reduction iterator to the
 /// latest possible index.
-static bool topSortOptimal(unsigned n,
-                           ArrayRef<utils::IteratorType> iteratorTypes,
+static bool topSortOptimal(unsigned n, ArrayRef<StringRef> iteratorTypes,
                            std::vector<unsigned> &topSort,
                            std::vector<unsigned> &inDegree,
                            std::vector<std::vector<bool>> &adjM) {
index a9c77c6..7f2e970 100644 (file)
 using namespace mlir;
 using namespace mlir::tosa;
 
-SmallVector<utils::IteratorType>
+SmallVector<StringRef>
 mlir::tosa::getNParallelLoopsAttrs(unsigned nParallelLoops) {
-  return SmallVector<utils::IteratorType>(nParallelLoops,
-                                          utils::IteratorType::parallel);
+  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
 }
 
 SmallVector<Value>
index 47aefa1..0bdaf7b 100644 (file)
@@ -1518,14 +1518,27 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
 }
 
 namespace {
+struct IteratorType {
+  IteratorType(StringRef strRef) : strRef(strRef) {}
+  bool isOfType(Attribute attr) const {
+    auto sAttr = attr.dyn_cast<StringAttr>();
+    return sAttr && sAttr.getValue() == strRef;
+  }
+  StringRef strRef;
+};
+struct Par : public IteratorType {
+  Par() : IteratorType(getParallelIteratorTypeName()) {}
+};
+struct Red : public IteratorType {
+  Red() : IteratorType(getReductionIteratorTypeName()) {}
+};
 
 /// Generate a vector implementation for matmat, matvec and tmatvec.
 /// This unrolls outer-products along the reduction dimension.
 struct UnrolledOuterProductGenerator
-    : public StructuredGenerator<vector::ContractionOp, vector::IteratorType> {
+    : public StructuredGenerator<vector::ContractionOp> {
   UnrolledOuterProductGenerator(OpBuilder &builder, vector::ContractionOp op)
-      : StructuredGenerator<vector::ContractionOp, vector::IteratorType>(
-            builder, op),
+      : StructuredGenerator<vector::ContractionOp>(builder, op),
         kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
         res(op.getAcc()), lhsType(op.getLhsType()) {}
 
@@ -2706,10 +2719,8 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
     } else {
       MemRefLayoutAttrInterface updatedLayout;
       if (auto strided = layout.dyn_cast<StridedLayoutAttr>()) {
-        auto strides =
-            llvm::to_vector(strided.getStrides().drop_back(dimsToDrop));
-        updatedLayout = StridedLayoutAttr::get(strided.getContext(),
-                                               strided.getOffset(), strides);
+        auto strides = llvm::to_vector(strided.getStrides().drop_back(dimsToDrop));
+        updatedLayout = StridedLayoutAttr::get(strided.getContext(), strided.getOffset(), strides);
       } else {
         AffineMap map = srcType.getLayout().getAffineMap();
         int numSymbols = map.getNumSymbols();
index e7d6dd9..f9b20d3 100644 (file)
@@ -17,7 +17,7 @@ func.func @test_conv_op_wrong_num_operands(%arg0 : tensor<?xf32>,
   // expected-error @+1 {{expected op with 2 inputs and 1 output}}
   %0 = test.linalg_conv_op {
       indexing_maps = [#map, #map],
-      iterator_types = [#test.iterator_type<parallel>]}
+      iterator_types = ["parallel"]}
       ins(%arg0 : tensor<?xf32>) outs(%arg1 : tensor<?xf32>) {
       ^bb0(%arg2 : f32, %arg3 : f32):
          linalg.yield  %arg3 : f32
@@ -34,8 +34,7 @@ func.func @test_conv_op_wrong_input_indexing_map1(%arg0 : tensor<?xf32>,
       indexing_maps = [affine_map<(d0, d1) -> (d0 * 2)>,
                        affine_map<(d0, d1) -> (d1)>,
                        affine_map<(d0, d1) -> (d0)>],
-      iterator_types = [#test.iterator_type<parallel>,
-                        #test.iterator_type<reduction>]}
+      iterator_types = ["parallel", "reduction"]}
       ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
       outs(%arg2 : tensor<?xf32>) {
       ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
@@ -53,8 +52,7 @@ func.func @test_conv_op_wrong_input_indexing_map2(%arg0 : tensor<?x?xf32>,
       indexing_maps = [affine_map<(d0, d1) -> (d0 + d1, d0)>,
                        affine_map<(d0, d1) -> (d1)>,
                        affine_map<(d0, d1) -> (d0)>],
-      iterator_types = [#test.iterator_type<parallel>,
-                        #test.iterator_type<reduction>]}
+      iterator_types = ["parallel", "reduction"]}
       ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?xf32>)
       outs(%arg2 : tensor<?xf32>) {
       ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
@@ -72,8 +70,7 @@ func.func @test_conv_op_filter_index_map_not_projection(%arg0 : tensor<?xf32>,
       indexing_maps = [affine_map<(d0, d1) -> (d1)>,
                        affine_map<(d0, d1) -> (d1 + d0)>,
                        affine_map<(d0, d1) -> (d0)>],
-      iterator_types = [#test.iterator_type<parallel>,
-                        #test.iterator_type<reduction>]}
+      iterator_types = ["parallel", "reduction"]}
       ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
       outs(%arg2 : tensor<?xf32>) {
       ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
@@ -91,8 +88,7 @@ func.func @test_conv_op_output_index_map_not_projection(%arg0 : tensor<?xf32>,
       indexing_maps = [affine_map<(d0, d1) -> (d0)>,
                        affine_map<(d0, d1) -> (d1)>,
                        affine_map<(d0, d1) -> (d0 + d1)>],
-      iterator_types = [#test.iterator_type<parallel>,
-                        #test.iterator_type<parallel>]}
+      iterator_types = ["parallel", "parallel"]}
       ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
       outs(%arg2 : tensor<?xf32>) {
       ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
@@ -112,8 +108,7 @@ func.func @test_conv_op_output_filter_convolved(%arg0 : tensor<?xf32>,
       indexing_maps = [affine_map<(d0, d1) -> (d0 + d1)>,
                        affine_map<(d0, d1) -> (d1)>,
                        affine_map<(d0, d1) -> (d0, d1)>],
-      iterator_types = [#test.iterator_type<parallel>,
-                        #test.iterator_type<parallel>]}
+      iterator_types = ["parallel", "parallel"]}
       ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
       outs(%arg2 : tensor<?x?xf32>) {
       ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
@@ -132,9 +127,7 @@ func.func @test_conv_op_output_only_dim(%arg0 : tensor<?xf32>,
       indexing_maps = [affine_map<(d0, d1, d2) -> (d0 + d1)>,
                        affine_map<(d0, d1, d2) -> (d1)>,
                        affine_map<(d0, d1, d2) -> (d0, d2)>],
-      iterator_types = [#test.iterator_type<parallel>,
-                        #test.iterator_type<reduction>,
-                        #test.iterator_type<parallel>]}
+      iterator_types = ["parallel", "reduction", "parallel"]}
       ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
       outs(%arg2 : tensor<?x?xf32>) {
       ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
@@ -153,9 +146,7 @@ func.func @test_conv_op_filter_only_dim(%arg0 : tensor<?xf32>,
       indexing_maps = [affine_map<(d0, d1, d2) -> (d0 + d1)>,
                        affine_map<(d0, d1, d2) -> (d1, d2)>,
                        affine_map<(d0, d1, d2) -> (d0)>],
-      iterator_types = [#test.iterator_type<parallel>,
-                        #test.iterator_type<reduction>,
-                        #test.iterator_type<reduction>]}
+      iterator_types = ["parallel", "reduction", "reduction"]}
       ins(%arg0, %arg1 : tensor<?xf32>, tensor<?x?xf32>)
       outs(%arg2 : tensor<?xf32>) {
       ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
@@ -174,9 +165,7 @@ func.func @test_conv_op_input_only_dim(%arg0 : tensor<?x?xf32>,
       indexing_maps = [affine_map<(d0, d1, d2) -> (d0 + d1, d2)>,
                        affine_map<(d0, d1, d2) -> (d1)>,
                        affine_map<(d0, d1, d2) -> (d0)>],
-      iterator_types = [#test.iterator_type<parallel>,
-                        #test.iterator_type<reduction>,
-                        #test.iterator_type<reduction>]}
+      iterator_types = ["parallel", "reduction", "reduction"]}
       ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?xf32>)
       outs(%arg2 : tensor<?xf32>) {
       ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
@@ -195,8 +184,7 @@ func.func @test_conv_op_non_output_access_loop_parallel(%arg0 : tensor<?xf32>,
       indexing_maps = [affine_map<(d0, d1) -> (d0 + d1)>,
                        affine_map<(d0, d1) -> (d1)>,
                        affine_map<(d0, d1) -> (d0)>],
-      iterator_types = [#test.iterator_type<parallel>,
-                        #test.iterator_type<parallel>]}
+      iterator_types = ["parallel", "parallel"]}
       ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
       outs(%arg2 : tensor<?xf32>) {
       ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
index ebce71f..5a1c2af 100644 (file)
@@ -96,7 +96,7 @@ func.func @generic_wrong_dim_in_map(%arg0: memref<1xi32>) {
 // -----
 
 func.func @generic_wrong_iterator(%arg0: memref<1xi32>) {
-  // expected-error @+4 {{unexpected iterator_type (random)}}
+  // expected-error @+1 {{op unexpected iterator_type (random)}}
   linalg.generic {
     indexing_maps =  [ affine_map<(i) -> (i)> ],
     iterator_types = ["random"]}
index b4ad820..4a92c07 100644 (file)
@@ -59,19 +59,13 @@ transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
   %match_attr = transform.structured.match
       ops{["linalg.generic"]}
-      attributes{iterator_types = [
-        #linalg.iterator_type<parallel>,
-        #linalg.iterator_type<parallel>,
-        #linalg.iterator_type<parallel>]}
+      attributes{iterator_types = ["parallel", "parallel", "parallel"]}
       in %arg1
   transform.test_print_remark_at_operand %match_attr, "matched complex attr" : !pdl.operation
   transform.test_consume_operand %match_attr
 
   %no_match = transform.structured.match
-      attributes{iterator_types = [
-        #linalg.iterator_type<parallel>,
-        #linalg.iterator_type<parallel>,
-        #linalg.iterator_type<reduction>]}
+      attributes{iterator_types = ["parallel", "parallel", "reduction"]}
       in %arg1
 // expected-remark @below {{0}}
   transform.test_print_number_of_associated_payload_ir_ops %no_match
index 2d1a1df..2c8719d 100644 (file)
@@ -23,16 +23,13 @@ mlir_tablegen(TestTypeDefs.h.inc -gen-typedef-decls -typedefs-dialect=test)
 mlir_tablegen(TestTypeDefs.cpp.inc -gen-typedef-defs -typedefs-dialect=test)
 add_public_tablegen_target(MLIRTestTypeDefIncGen)
 
-set(LLVM_TARGET_DEFINITIONS TestEnumDefs.td)
-mlir_tablegen(TestOpEnums.h.inc -gen-enum-decls)
-mlir_tablegen(TestOpEnums.cpp.inc -gen-enum-defs)
-add_public_tablegen_target(MLIRTestEnumDefIncGen)
-
 set(LLVM_TARGET_DEFINITIONS TestOps.td)
 mlir_tablegen(TestOps.h.inc -gen-op-decls)
 mlir_tablegen(TestOps.cpp.inc -gen-op-defs)
 mlir_tablegen(TestOpsDialect.h.inc -gen-dialect-decls -dialect=test)
 mlir_tablegen(TestOpsDialect.cpp.inc -gen-dialect-defs -dialect=test)
+mlir_tablegen(TestOpEnums.h.inc -gen-enum-decls)
+mlir_tablegen(TestOpEnums.cpp.inc -gen-enum-defs)
 mlir_tablegen(TestPatterns.inc -gen-rewriters)
 add_public_tablegen_target(MLIRTestOpsIncGen)
 
@@ -49,7 +46,6 @@ add_mlir_library(MLIRTestDialect
 
   DEPENDS
   MLIRTestAttrDefIncGen
-  MLIRTestEnumDefIncGen
   MLIRTestInterfaceIncGen
   MLIRTestTypeDefIncGen
   MLIRTestOpsIncGen
index c4996ab..0c35f81 100644 (file)
@@ -15,8 +15,6 @@
 
 // To get the test dialect definition.
 include "TestDialect.td"
-include "TestEnumDefs.td"
-include "mlir/Dialect/Utils/StructuredOpsUtils.td"
 include "mlir/IR/AttrTypeBase.td"
 include "mlir/IR/BuiltinAttributeInterfaces.td"
 include "mlir/IR/EnumAttr.td"
@@ -279,6 +277,13 @@ def TestArrayOfInts : ArrayOfAttr<Test_Dialect, "ArrayOfInts",
     "array_of_ints", "int32_t">;
 
 // An array of enum attributes.
+def TestSimpleEnum : I32EnumAttr<"SimpleEnum", "", [
+    I32EnumAttrCase<"a", 0>,
+    I32EnumAttrCase<"b", 1>
+  ]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::test";
+}
 def TestSimpleEnumAttr : EnumAttr<Test_Dialect, TestSimpleEnum, "simple_enum"> {
   let assemblyFormat = "`` $value";
 }
@@ -292,14 +297,4 @@ def TestCustomAnchor : Test_Attr<"TestCustomAnchor"> {
   let assemblyFormat = "`<` $a (`>`) : (`,` ` ` custom<TrueFalse>($b)^ `>`)?";
 }
 
-def Test_IteratorTypeEnum
-    : EnumAttr<Test_Dialect, IteratorType, "iterator_type"> {
-  let assemblyFormat = "`<` $value `>`";
-}
-
-def Test_IteratorTypeArrayAttr
-    : TypedArrayAttrBase<Test_IteratorTypeEnum,
-  "Iterator type should be an enum.">;
-
-
 #endif // TEST_ATTRDEFS
index cc73e07..4cb4d61 100644 (file)
@@ -17,7 +17,6 @@
 #include <tuple>
 
 #include "TestTraits.h"
-#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dialect.h"
diff --git a/mlir/test/lib/Dialect/Test/TestEnumDefs.td b/mlir/test/lib/Dialect/Test/TestEnumDefs.td
deleted file mode 100644 (file)
index 1ddfca0..0000000
+++ /dev/null
@@ -1,97 +0,0 @@
-//===-- TestEnumDefs.td - Test dialect enum definitions ----*- tablegen -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// TableGen enum definitions for Test dialect.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef TEST_ENUMDEFS
-#define TEST_ENUMDEFS
-
-include "mlir/IR/EnumAttr.td"
-
-def I32Case5:  I32EnumAttrCase<"case5", 5>;
-def I32Case10: I32EnumAttrCase<"case10", 10>;
-
-def SomeI32Enum: I32EnumAttr<
-  "SomeI32Enum", "", [I32Case5, I32Case10]>;
-
-def I64Case5:  I64EnumAttrCase<"case5", 5>;
-def I64Case10: I64EnumAttrCase<"case10", 10>;
-
-def SomeI64Enum: I64EnumAttr<
-  "SomeI64Enum", "", [I64Case5, I64Case10]>;
-
-//===----------------------------------------------------------------------===//
-// Test Enum
-//===----------------------------------------------------------------------===//
-
-// Define the C++ enum.
-def TestEnum
-    : I32EnumAttr<"TestEnum", "a test enum", [
-        I32EnumAttrCase<"First", 0, "first">,
-        I32EnumAttrCase<"Second", 1, "second">,
-        I32EnumAttrCase<"Third", 2, "third">,
-      ]> {
-  let genSpecializedAttr = 0;
-  let cppNamespace = "test";
-}
-
-def TestSimpleEnum : I32EnumAttr<"SimpleEnum", "", [
-    I32EnumAttrCase<"a", 0>,
-    I32EnumAttrCase<"b", 1>
-  ]> {
-  let genSpecializedAttr = 0;
-  let cppNamespace = "::test";
-}
-
-//===----------------------------------------------------------------------===//
-// Test Bit Enum
-//===----------------------------------------------------------------------===//
-
-// Define the C++ enum.
-def TestBitEnum
-    : I32BitEnumAttr<"TestBitEnum", "a test bit enum", [
-        I32BitEnumAttrCaseBit<"Read", 0, "read">,
-        I32BitEnumAttrCaseBit<"Write", 1, "write">,
-        I32BitEnumAttrCaseBit<"Execute", 2, "execute">,
-      ]> {
-  let genSpecializedAttr = 0;
-  let cppNamespace = "test";
-  let separator = ", ";
-}
-
-// Define an enum with a different separator
-def TestBitEnumVerticalBar
-    : I32BitEnumAttr<"TestBitEnumVerticalBar", "another test bit enum", [
-        I32BitEnumAttrCaseBit<"User", 0, "user">,
-        I32BitEnumAttrCaseBit<"Group", 1, "group">,
-        I32BitEnumAttrCaseBit<"Other", 2, "other">,
-      ]> {
-  let genSpecializedAttr = 0;
-  let cppNamespace = "test";
-  let separator = " | ";
-}
-
-//===----------------------------------------------------------------------===//
-// Test Patterns (Multi-result Ops)
-
-def MultiResultOpKind1: I64EnumAttrCase<"kind1", 1>;
-def MultiResultOpKind2: I64EnumAttrCase<"kind2", 2>;
-def MultiResultOpKind3: I64EnumAttrCase<"kind3", 3>;
-def MultiResultOpKind4: I64EnumAttrCase<"kind4", 4>;
-def MultiResultOpKind5: I64EnumAttrCase<"kind5", 5>;
-def MultiResultOpKind6: I64EnumAttrCase<"kind6", 6>;
-
-def MultiResultOpEnum: I64EnumAttr<
-  "MultiResultOpEnum", "Multi-result op kinds", [
-    MultiResultOpKind1, MultiResultOpKind2, MultiResultOpKind3,
-    MultiResultOpKind4, MultiResultOpKind5, MultiResultOpKind6
-  ]>;
-
-#endif // TEST_ENUMDEFS
index 84dd37f..84bbe24 100644 (file)
@@ -202,11 +202,23 @@ def FloatAttrOp : TEST_Op<"float_attrs"> {
   );
 }
 
+def I32Case5:  I32EnumAttrCase<"case5", 5>;
+def I32Case10: I32EnumAttrCase<"case10", 10>;
+
+def SomeI32Enum: I32EnumAttr<
+  "SomeI32Enum", "", [I32Case5, I32Case10]>;
+
 def I32EnumAttrOp : TEST_Op<"i32_enum_attr"> {
   let arguments = (ins SomeI32Enum:$attr);
   let results = (outs I32:$val);
 }
 
+def I64Case5:  I64EnumAttrCase<"case5", 5>;
+def I64Case10: I64EnumAttrCase<"case10", 10>;
+
+def SomeI64Enum: I64EnumAttr<
+  "SomeI64Enum", "", [I64Case5, I64Case10]>;
+
 def I64EnumAttrOp : TEST_Op<"i64_enum_attr"> {
   let arguments = (ins SomeI64Enum:$attr);
   let results = (outs I32:$val);
@@ -307,6 +319,17 @@ def ConfinedDenseArrayAttrOp : TEST_Op<"confined_dense_array_attr"> {
 // Test Enum Attributes
 //===----------------------------------------------------------------------===//
 
+// Define the C++ enum.
+def TestEnum
+    : I32EnumAttr<"TestEnum", "a test enum", [
+        I32EnumAttrCase<"First", 0, "first">,
+        I32EnumAttrCase<"Second", 1, "second">,
+        I32EnumAttrCase<"Third", 2, "third">,
+      ]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "test";
+}
+
 // Define the enum attribute.
 def TestEnumAttr : EnumAttr<Test_Dialect, TestEnum, "enum">;
 
@@ -328,6 +351,18 @@ def : Pat<(OpWithEnum ConstantAttr<TestEnumAttr,
 // Test Bit Enum Attributes
 //===----------------------------------------------------------------------===//
 
+// Define the C++ enum.
+def TestBitEnum
+    : I32BitEnumAttr<"TestBitEnum", "a test bit enum", [
+        I32BitEnumAttrCaseBit<"Read", 0, "read">,
+        I32BitEnumAttrCaseBit<"Write", 1, "write">,
+        I32BitEnumAttrCaseBit<"Execute", 2, "execute">,
+      ]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "test";
+  let separator = ", ";
+}
+
 // Define the enum attribute.
 def TestBitEnumAttr : EnumAttr<Test_Dialect, TestBitEnum, "bit_enum"> {
   let assemblyFormat = "`<` $value `>`";
@@ -339,6 +374,18 @@ def OpWithBitEnum : TEST_Op<"op_with_bit_enum"> {
   let assemblyFormat = "$value (`tag` $tag^)? attr-dict";
 }
 
+// Define an enum with a different separator
+def TestBitEnumVerticalBar
+    : I32BitEnumAttr<"TestBitEnumVerticalBar", "another test bit enum", [
+        I32BitEnumAttrCaseBit<"User", 0, "user">,
+        I32BitEnumAttrCaseBit<"Group", 1, "group">,
+        I32BitEnumAttrCaseBit<"Other", 2, "other">,
+      ]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "test";
+  let separator = " | ";
+}
+
 def TestBitEnumVerticalBarAttr
     : EnumAttr<Test_Dialect, TestBitEnumVerticalBar, "bit_enum_vbar"> {
   let assemblyFormat = "`<` $value `>`";
@@ -1345,6 +1392,22 @@ def : Pat<(OpC $input), (OpB $input, ConstantAttr<I32Attr, "17">:$attr)>;
 def : Pat<(I32EnumAttrOp I32Case5), (I32EnumAttrOp I32Case10)>;
 def : Pat<(I64EnumAttrOp I64Case5), (I64EnumAttrOp I64Case10)>;
 
+//===----------------------------------------------------------------------===//
+// Test Patterns (Multi-result Ops)
+
+def MultiResultOpKind1: I64EnumAttrCase<"kind1", 1>;
+def MultiResultOpKind2: I64EnumAttrCase<"kind2", 2>;
+def MultiResultOpKind3: I64EnumAttrCase<"kind3", 3>;
+def MultiResultOpKind4: I64EnumAttrCase<"kind4", 4>;
+def MultiResultOpKind5: I64EnumAttrCase<"kind5", 5>;
+def MultiResultOpKind6: I64EnumAttrCase<"kind6", 6>;
+
+def MultiResultOpEnum: I64EnumAttr<
+  "MultiResultOpEnum", "Multi-result op kinds", [
+    MultiResultOpKind1, MultiResultOpKind2, MultiResultOpKind3,
+    MultiResultOpKind4, MultiResultOpKind5, MultiResultOpKind6
+  ]>;
+
 def ThreeResultOp : TEST_Op<"three_result"> {
   let arguments = (ins MultiResultOpEnum:$kind);
   let results = (outs I32:$result1, F32:$result2, F32:$result3);
@@ -2761,10 +2824,8 @@ def TestLinalgConvOp :
       return &regionBuilder;
     }
 
-    llvm::SmallVector<mlir::utils::IteratorType> getIteratorTypesArray() {
-      auto attrs = getOperation()->getAttrOfType<mlir::ArrayAttr>("iterator_types");
-      auto range = attrs.getAsValueRange<IteratorTypeAttr, mlir::utils::IteratorType>();
-      return {range.begin(), range.end()};
+    mlir::ArrayAttr getIteratorTypes() {
+      return getOperation()->getAttrOfType<mlir::ArrayAttr>("iterator_types");
     }
 
     mlir::ArrayAttr getIndexingMaps() {
@@ -2823,10 +2884,8 @@ def TestLinalgFillOp :
       return &regionBuilder;
     }
 
-    llvm::SmallVector<mlir::utils::IteratorType> getIteratorTypesArray() {
-      auto attrs = getOperation()->getAttrOfType<mlir::ArrayAttr>("iterator_types");
-      auto range = attrs.getAsValueRange<IteratorTypeAttr, mlir::utils::IteratorType>();
-      return {range.begin(), range.end()};
+    mlir::ArrayAttr getIteratorTypes() {
+      return getOperation()->getAttrOfType<mlir::ArrayAttr>("iterator_types");
     }
 
     mlir::ArrayAttr getIndexingMaps() {
index 3595be8..0a482cc 100644 (file)
@@ -553,7 +553,7 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments],
 
     let extraClassDeclaration = structuredOpsBaseDecls # [{{
       // Auto-generated.
-      SmallVector<utils::IteratorType> getIteratorTypesArray();
+      SmallVector<StringRef> getIteratorTypesArray();
       ArrayAttr getIndexingMaps();
       static void regionBuilder(ImplicitLocOpBuilder &b,
                                 Block &block, ArrayRef<NamedAttribute> attrs);
@@ -597,8 +597,8 @@ static const char structuredOpBuilderFormat[] = R"FMT(
 // {1}: Comma interleaved iterator type names.
 static const char structuredOpIteratorTypesFormat[] =
     R"FMT(
-SmallVector<utils::IteratorType> {0}::getIteratorTypesArray() {{
-  return SmallVector<utils::IteratorType>{{ {1} };
+SmallVector<StringRef> {0}::getIteratorTypesArray() {{
+  return SmallVector<StringRef>{{ {1} };
 }
 )FMT";
 
@@ -607,9 +607,9 @@ SmallVector<utils::IteratorType> {0}::getIteratorTypesArray() {{
 // {0}: Class name
 static const char rankPolyStructuredOpIteratorTypesFormat[] =
     R"FMT(
-SmallVector<utils::IteratorType> {0}::getIteratorTypesArray() {{
+SmallVector<StringRef> {0}::getIteratorTypesArray() {{
   int64_t rank = getRank(getDpsInitOperand(0));
-  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
+  return SmallVector<StringRef>(rank, getParallelIteratorTypeName());
 }
 )FMT";
 
@@ -812,10 +812,10 @@ generateNamedGenericOpDefns(LinalgOpConfig &opConfig,
                           [&](LinalgIteratorTypeDef it) {
                             switch (it) {
                             case LinalgIteratorTypeDef::parallel:
-                              ss << "utils::IteratorType::parallel";
+                              ss << "getParallelIteratorTypeName()";
                               break;
                             case LinalgIteratorTypeDef::reduction:
-                              ss << "utils::IteratorType::reduction";
+                              ss << "getReductionIteratorTypeName()";
                               break;
                             }
                           });
index 0a64646..d58a07f 100644 (file)
@@ -132,6 +132,14 @@ gentbl_cc_library(
             "lib/Dialect/Test/TestOpsDialect.cpp.inc",
         ),
         (
+            ["-gen-enum-decls"],
+            "lib/Dialect/Test/TestOpEnums.h.inc",
+        ),
+        (
+            ["-gen-enum-defs"],
+            "lib/Dialect/Test/TestOpEnums.cpp.inc",
+        ),
+        (
             ["-gen-rewriters"],
             "lib/Dialect/Test/TestPatterns.inc",
         ),
@@ -204,27 +212,6 @@ gentbl_cc_library(
 )
 
 gentbl_cc_library(
-    name = "TestEnumDefsIncGen",
-    strip_include_prefix = "lib/Dialect/Test",
-    tbl_outs = [
-        (
-            ["-gen-enum-decls"],
-            "lib/Dialect/Test/TestOpEnums.h.inc",
-        ),
-        (
-            ["-gen-enum-defs"],
-            "lib/Dialect/Test/TestOpEnums.cpp.inc",
-        ),
-    ],
-    tblgen = "//mlir:mlir-tblgen",
-    td_file = "lib/Dialect/Test/TestEnumDefs.td",
-    test = True,
-    deps = [
-        ":TestOpTdFiles",
-    ],
-)
-
-gentbl_cc_library(
     name = "TestTypeDefsIncGen",
     strip_include_prefix = "lib/Dialect/Test",
     tbl_outs = [
@@ -331,7 +318,6 @@ cc_library(
     ],
     deps = [
         ":TestAttrDefsIncGen",
-        ":TestEnumDefsIncGen",
         ":TestInterfacesIncGen",
         ":TestOpsIncGen",
         ":TestTypeDefsIncGen",
@@ -344,7 +330,6 @@ cc_library(
         "//mlir:DerivedAttributeOpInterface",
         "//mlir:DestinationStyleOpInterface",
         "//mlir:Dialect",
-        "//mlir:DialectUtils",
         "//mlir:FuncDialect",
         "//mlir:FuncTransforms",
         "//mlir:IR",