From 41280908e43d47903960c66237ab49caa5641b4d Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Wed, 9 Nov 2022 15:59:54 +0100 Subject: [PATCH] Revert "[mlir][linalg] Replace "string" iterator_types attr with enums in LinalgInterface." Breaks linalg python tests. Would need to also update python/mlir/dialects/linalg/opdsl. This reverts commit b809d73973bb5aeedeb6a18cac2a7b3111d0c8d2. --- mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td | 7 -- .../mlir/Dialect/Linalg/IR/LinalgInterfaces.h | 1 - .../mlir/Dialect/Linalg/IR/LinalgInterfaces.td | 52 +++++++++--- .../mlir/Dialect/Linalg/IR/LinalgStructuredOps.td | 18 ++-- mlir/include/mlir/Dialect/Linalg/Utils/Utils.h | 7 +- .../mlir/Dialect/Tosa/Utils/CoversionUtils.h | 3 +- .../mlir/Dialect/Utils/StructuredOpsUtils.h | 59 ++++++++++--- mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 11 +-- mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp | 15 ++-- mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp | 21 +++-- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 82 ++++-------------- .../Linalg/Transforms/ElementwiseOpFusion.cpp | 27 +++--- .../Linalg/Transforms/ElementwiseToLinalg.cpp | 4 +- .../Dialect/Linalg/Transforms/Generalization.cpp | 2 +- .../Dialect/Linalg/Transforms/SplitReduction.cpp | 20 ++--- mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 2 +- .../Linalg/Transforms/TilingInterfaceImpl.cpp | 16 ++-- mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 6 +- .../Dialect/Linalg/Transforms/Vectorization.cpp | 7 +- mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 32 +++---- .../SparseTensor/Transforms/Sparsification.cpp | 3 +- mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp | 5 +- .../Dialect/Vector/Transforms/VectorTransforms.cpp | 25 ++++-- .../Dialect/Linalg/conv-interface-invalid.mlir | 32 +++---- mlir/test/Dialect/Linalg/invalid.mlir | 2 +- mlir/test/Dialect/Linalg/transform-op-match.mlir | 10 +-- mlir/test/lib/Dialect/Test/CMakeLists.txt | 8 +- mlir/test/lib/Dialect/Test/TestAttrDefs.td | 19 ++--- mlir/test/lib/Dialect/Test/TestAttributes.h | 1 - mlir/test/lib/Dialect/Test/TestEnumDefs.td | 97 ---------------------- mlir/test/lib/Dialect/Test/TestOps.td | 75 +++++++++++++++-- .../mlir-linalg-ods-yaml-gen.cpp | 14 ++-- .../llvm-project-overlay/mlir/test/BUILD.bazel | 31 ++----- 33 files changed, 329 insertions(+), 385 deletions(-) delete mode 100644 mlir/test/lib/Dialect/Test/TestEnumDefs.td diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td index 1ee9d2a..36cf41d 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td @@ -13,7 +13,6 @@ #ifndef LINALG_BASE #define LINALG_BASE -include "mlir/Dialect/Utils/StructuredOpsUtils.td" include "mlir/Dialect/Linalg/IR/LinalgEnums.td" include "mlir/IR/EnumAttr.td" include "mlir/IR/OpBase.td" @@ -72,10 +71,4 @@ def TypeFnAttr : EnumAttr { let assemblyFormat = "`<` $value `>`"; } -def IteratorTypeEnum : EnumAttr { - let assemblyFormat = "`<` $value `>`"; -} -def IteratorTypeArrayAttr : TypedArrayAttrBase; - #endif // LINALG_BASE diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h index 45d63d8..8e3df10 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h @@ -25,7 +25,6 @@ namespace mlir { namespace linalg { -class IteratorTypeAttr; class LinalgOp; namespace detail { diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td index 17f6bf4..533a52f 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -193,8 +193,8 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return llvm::count($_op.getIteratorTypesArray(), - utils::IteratorType::parallel); + return getNumIterators(getParallelIteratorTypeName(), + $_op.getIteratorTypesArray()); }] >, InterfaceMethod< @@ -207,7 +207,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { /*methodBody=*/"", /*defaultImplementation=*/[{ return findPositionsOfType($_op.getIteratorTypesArray(), - utils::IteratorType::parallel, res); + getParallelIteratorTypeName(), res); }] >, InterfaceMethod< @@ -219,8 +219,8 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return llvm::count($_op.getIteratorTypesArray(), - utils::IteratorType::reduction); + return getNumIterators(getReductionIteratorTypeName(), + $_op.getIteratorTypesArray()); }] >, InterfaceMethod< @@ -233,7 +233,33 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { /*methodBody=*/"", /*defaultImplementation=*/[{ return findPositionsOfType($_op.getIteratorTypesArray(), - utils::IteratorType::reduction, res); + getReductionIteratorTypeName(), res); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return the number of window loops. + }], + /*retTy=*/"unsigned", + /*methodName=*/"getNumWindowLoops", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return getNumIterators(getWindowIteratorTypeName(), + $_op.getIteratorTypesArray()); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return the dims that are window loops. + }], + /*retTy=*/"void", + /*methodName=*/"getWindowDims", + /*args=*/(ins "SmallVectorImpl &":$res), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return findPositionsOfType($_op.getIteratorTypesArray(), + getWindowIteratorTypeName(), res); }] >, InterfaceMethod< @@ -245,7 +271,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return $_op.getIteratorTypesArray().size(); + return getNumIterators($_op.getIteratorTypesArray()); }] >, InterfaceMethod< @@ -260,7 +286,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { /*defaultImplementation=*/[{ auto iters = $_op.getIteratorTypesArray(); return iters.size() == 1 && - llvm::count(iters, utils::IteratorType::reduction) == 1; + getNumIterators(getReductionIteratorTypeName(), iters) == 1; }]>, //===------------------------------------------------------------------===// // Input and Init arguments handling. @@ -480,14 +506,12 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { can be infered from other parameters and in such cases default getIteratorTypesArray should be overriden. }], - /*retTy=*/"SmallVector", + /*retTy=*/"SmallVector", /*methodName=*/"getIteratorTypesArray", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - auto range = $_op.getIteratorTypes() - .template getAsValueRange(); + auto range = $_op.getIteratorTypes().template getAsValueRange(); return {range.begin(), range.end()}; }] >, @@ -743,6 +767,10 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { LogicalResult reifyResultShapes(OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes); + SmallVector getIteratorTypeNames() { + return getIteratorTypesArray(); + } + //========================================================================// // Forwarding functions to access interface methods from the // DestinationStyleOpInterface. diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index e822435..9866620 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -163,7 +163,7 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [ let arguments = (ins Variadic:$inputs, Variadic:$outputs, AffineMapArrayAttr:$indexing_maps, - IteratorTypeArrayAttr:$iterator_types, + ArrayAttr:$iterator_types, OptionalAttr:$doc, OptionalAttr:$library_call); let results = (outs Variadic:$result_tensors); @@ -178,22 +178,22 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [ CArg<"ArrayRef", "{}">:$attributes)>, OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, "ValueRange":$outputs, "ArrayRef":$indexingMaps, - "ArrayRef":$iteratorTypes, "StringRef":$doc, + "ArrayRef":$iteratorTypes, "StringRef":$doc, "StringRef":$libraryCall, CArg<"function_ref", "nullptr">, CArg<"ArrayRef", "{}">:$attributes)>, OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers, - "ArrayRef":$indexingMaps, "ArrayRef":$iteratorTypes, + "ArrayRef":$indexingMaps, "ArrayRef":$iteratorTypes, "StringRef":$doc, "StringRef":$libraryCall, CArg<"function_ref", "nullptr">, CArg<"ArrayRef", "{}">:$attributes)>, OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, "ValueRange":$outputs, "ArrayRef":$indexingMaps, - "ArrayRef":$iteratorTypes, + "ArrayRef":$iteratorTypes, CArg<"function_ref", "nullptr">, CArg<"ArrayRef", "{}">:$attributes)>, OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers, - "ArrayRef":$indexingMaps, "ArrayRef":$iteratorTypes, + "ArrayRef":$indexingMaps, "ArrayRef":$iteratorTypes, CArg<"function_ref", "nullptr">, CArg<"ArrayRef", "{}">:$attributes)> ]; @@ -275,7 +275,7 @@ def MapOp : LinalgStructuredBase_Op<"map", [ let extraClassDeclaration = structuredOpsBaseDecls # [{ // Implement functions necessary for LinalgStructuredInterface. - SmallVector getIteratorTypesArray(); + SmallVector getIteratorTypesArray(); ArrayAttr getIndexingMaps(); std::string getLibraryCallName() { return "op_has_no_registered_library_name"; @@ -356,7 +356,7 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [ let extraClassDeclaration = structuredOpsBaseDecls # [{ // Declare functions necessary for LinalgStructuredInterface. - SmallVector getIteratorTypesArray(); + SmallVector getIteratorTypesArray(); ArrayAttr getIndexingMaps(); std::string getLibraryCallName() { return "op_has_no_registered_library_name"; @@ -426,7 +426,7 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [ let extraClassDeclaration = structuredOpsBaseDecls # [{ // Declare functions necessary for LinalgStructuredInterface. - SmallVector getIteratorTypesArray(); + SmallVector getIteratorTypesArray(); ArrayAttr getIndexingMaps(); std::string getLibraryCallName() { return "op_has_no_registered_library_name"; @@ -502,7 +502,7 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [ let extraClassDeclaration = structuredOpsBaseDecls # [{ // Declare functions necessary for LinalgStructuredInterface. - SmallVector getIteratorTypesArray(); + SmallVector getIteratorTypesArray(); ArrayAttr getIndexingMaps(); std::string getLibraryCallName() { return "op_has_no_registered_library_name"; diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index 4f9dd71..5fc7938 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -42,10 +42,10 @@ bool hasOnlyScalarElementwiseOp(Region &r); bool isElementwise(LinalgOp op); /// Check if iterator type has "parallel" semantics. -bool isParallelIterator(utils::IteratorType iteratorType); +bool isParallelIterator(StringRef iteratorType); /// Check if iterator type has "reduction" semantics. -bool isReductionIterator(utils::IteratorType iteratorType); +bool isReductionIterator(StringRef iteratorType); /// Helper function that creates a memref::DimOp or tensor::DimOp depending on /// the type of `source`. @@ -480,8 +480,7 @@ struct RegionMatcher { template struct GenerateLoopNest { static void doit(OpBuilder &b, Location loc, ArrayRef loopRanges, - LinalgOp linalgOp, - ArrayRef iteratorTypes, + LinalgOp linalgOp, ArrayRef iteratorTypes, function_ref bodyBuilderFn, diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h index 6b2104f..b2b7b24 100644 --- a/mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h +++ b/mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h @@ -22,8 +22,7 @@ namespace mlir { namespace tosa { // Creates a SmallVector of Stringrefs for N parallel loops -SmallVector -getNParallelLoopsAttrs(unsigned nParallelLoops); +SmallVector getNParallelLoopsAttrs(unsigned nParallelLoops); // Takes a vector of values and condenses them to a vector with no gaps. SmallVector condenseValues(const SmallVector &values); diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h index cb509fe..6fcfcb1 100644 --- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h @@ -21,6 +21,7 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Location.h" #include "mlir/Support/LLVM.h" +#include "llvm/ADT/StringRef.h" // Pull in all enum type definitions and utility function declarations. #include "mlir/Dialect/Utils/DialectUtilsEnums.h.inc" @@ -47,9 +48,42 @@ bool isColumnMajorMatmul(ArrayAttr indexingMaps); /// the reduction. bool isRowMajorBatchMatmul(ArrayAttr indexingMaps); +/// Use to encode that a particular iterator type has parallel semantics. +constexpr StringRef getParallelIteratorTypeName() { return "parallel"; } + +/// Use to encode that a particular iterator type has reduction semantics. +constexpr StringRef getReductionIteratorTypeName() { return "reduction"; } + +/// Use to encode that a particular iterator type has window semantics. +constexpr StringRef getWindowIteratorTypeName() { return "window"; } + +/// Use to encode that a particular iterator type has window semantics. +inline ArrayRef getAllIteratorTypeNames() { + static constexpr StringRef names[3] = {getParallelIteratorTypeName(), + getReductionIteratorTypeName(), + getWindowIteratorTypeName()}; + return llvm::makeArrayRef(names); +} + +/// Returns the iterator of a certain type. +inline unsigned getNumIterators(StringRef name, + ArrayRef iteratorTypes) { + auto names = getAllIteratorTypeNames(); + (void)names; + assert(llvm::is_contained(names, name)); + return llvm::count(iteratorTypes, name); +} + +inline unsigned getNumIterators(ArrayRef iteratorTypes) { + unsigned res = 0; + for (auto n : getAllIteratorTypeNames()) + res += getNumIterators(n, iteratorTypes); + return res; +} + /// Return positions in `iteratorTypes` that match `iteratorTypeName`. -inline void findPositionsOfType(ArrayRef iteratorTypes, - utils::IteratorType iteratorTypeName, +inline void findPositionsOfType(ArrayRef iteratorTypes, + StringRef iteratorTypeName, SmallVectorImpl &res) { for (const auto &en : llvm::enumerate(iteratorTypes)) { if (en.value() == iteratorTypeName) @@ -60,28 +94,29 @@ inline void findPositionsOfType(ArrayRef iteratorTypes, /// Helper StructuredGenerator class to manipulate and rewrite ops with /// `StructuredOpInterface`. This is templated for now because VectorOps do not /// yet implement the StructuredOpInterface itself. -template +template class StructuredGenerator { public: using MapList = ArrayRef>; struct IteratorType { - IteratorType(IteratorTypeT iter) : iter(iter) {} - bool isOfType(IteratorTypeT expectedIter) const { - return expectedIter == iter; - } - IteratorTypeT iter; + IteratorType(StringRef strRef) : strRef(strRef) {} + bool isOfType(StringRef typeName) const { return typeName == strRef; } + StringRef strRef; }; struct Par : public IteratorType { - Par() : IteratorType(IteratorTypeT::parallel) {} + Par() : IteratorType(getParallelIteratorTypeName()) {} }; struct Red : public IteratorType { - Red() : IteratorType(IteratorTypeT::reduction) {} + Red() : IteratorType(getReductionIteratorTypeName()) {} + }; + struct Win : public IteratorType { + Win() : IteratorType(getWindowIteratorTypeName()) {} }; StructuredGenerator(OpBuilder &builder, StructuredOpInterface op) : builder(builder), ctx(op.getContext()), loc(op.getLoc()), - iterators(op.getIteratorTypesArray()), maps(op.getIndexingMapsArray()), + iterators(op.getIteratorTypeNames()), maps(op.getIndexingMapsArray()), op(op) {} bool iters(ArrayRef its) { @@ -103,7 +138,7 @@ protected: OpBuilder &builder; MLIRContext *ctx; Location loc; - SmallVector iterators; + SmallVector iterators; SmallVector maps; Operation *op; }; diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 5060d8c..758e7c1 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -269,11 +269,12 @@ def Vector_ContractionOp : return CombiningKind::ADD; } - SmallVector getIteratorTypesArray() { - auto range = - getIteratorTypes() - .template getAsValueRange(); - return {range.begin(), range.end()}; + // Returns iterator types in string format. + SmallVector getIteratorTypeNames() { + return llvm::to_vector( + llvm::map_range(getIteratorTypes(), [](Attribute a) { + return stringifyIteratorType(a.cast().getValue()); + })); } }]; diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index f56162b..04cd00f 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -791,12 +791,12 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis, SmallVector srcExprs; SmallVector dstExprs; - SmallVector iteratorTypes; + SmallVector iteratorTypes; for (unsigned int i = 0, rank = inputTy.getRank(); i != rank; ++i) { srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext())); - iteratorTypes.push_back(axis == i ? utils::IteratorType::reduction - : utils::IteratorType::parallel); + iteratorTypes.push_back(axis == i ? getReductionIteratorTypeName() + : getParallelIteratorTypeName()); if (axis != i) dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext())); } @@ -1383,8 +1383,7 @@ public: auto inputMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, inputExprs, builder.getContext()); auto resultMap = rewriter.getMultiDimIdentityMap(resultTy.getRank()); - SmallVector iterators(4, - utils::IteratorType::parallel); + SmallVector iterators(4, getParallelIteratorTypeName()); Value empty = builder.create( resultTy.getShape(), resultTy.getElementType(), outputDynSize); @@ -2084,9 +2083,9 @@ public: // We need to reduce along the arg-max axis, with parallel operations along // the rest. - SmallVector iteratorTypes; - iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel); - iteratorTypes[axis] = utils::IteratorType::reduction; + SmallVector iteratorTypes; + iteratorTypes.resize(inputTy.getRank(), getParallelIteratorTypeName()); + iteratorTypes[axis] = getReductionIteratorTypeName(); SmallVector srcExprs; SmallVector dstExprs; diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp index f812621..78a29f4 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -321,7 +321,7 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) { if (inputExprWalker.unConvolvedDims.count(outputDim) && !filterDims.count(outputDim)) { // Batch dimension. - if (iteratorTypes[outputDim] != utils::IteratorType::parallel) + if (iteratorTypes[outputDim] != getParallelIteratorTypeName()) return MatchConvolutionResult::OutputDimsNotParallel; allLoopDims.insert(outputDim); continue; @@ -329,7 +329,7 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) { if (inputExprWalker.convolvedDims.count(outputDim) && !filterDims.count(outputDim)) { // Output image Loop dimension. - if (iteratorTypes[outputDim] != utils::IteratorType::parallel) + if (iteratorTypes[outputDim] != getParallelIteratorTypeName()) return MatchConvolutionResult::OutputDimsNotParallel; allLoopDims.insert(outputDim); continue; @@ -338,7 +338,7 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) { !inputExprWalker.unConvolvedDims.count(outputDim) && filterDims.count(outputDim)) { // Output channel dimension. - if (iteratorTypes[outputDim] != utils::IteratorType::parallel) + if (iteratorTypes[outputDim] != getParallelIteratorTypeName()) return MatchConvolutionResult::OutputDimsNotParallel; allLoopDims.insert(outputDim); continue; @@ -346,7 +346,7 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) { if (inputExprWalker.unConvolvedDims.count(outputDim) && filterDims.count(outputDim)) { // Depth multiplier. - if (iteratorTypes[outputDim] != utils::IteratorType::parallel) + if (iteratorTypes[outputDim] != getParallelIteratorTypeName()) return MatchConvolutionResult::OutputDimsNotParallel; allLoopDims.insert(outputDim); continue; @@ -364,7 +364,7 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) { if (inputExprWalker.convolvedDims.count(filterDim) && !outputDims.count(filterDim)) { // Filter loop dimension. - if (iteratorTypes[filterDim] != utils::IteratorType::reduction) + if (iteratorTypes[filterDim] != getReductionIteratorTypeName()) return MatchConvolutionResult::NonOutputDimNotReduction; if (allLoopDims.count(filterDim)) return MatchConvolutionResult::NonConvolutionLoop; @@ -374,7 +374,7 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) { if (inputExprWalker.unConvolvedDims.count(filterDim) && !outputDims.count(filterDim)) { // Input channel dimension. - if (iteratorTypes[filterDim] != utils::IteratorType::reduction) + if (iteratorTypes[filterDim] != getReductionIteratorTypeName()) return MatchConvolutionResult::NonOutputDimNotReduction; if (allLoopDims.count(filterDim)) return MatchConvolutionResult::NonConvolutionLoop; @@ -619,6 +619,15 @@ LinalgOp::reifyResultShapes(OpBuilder &b, LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) { LinalgOp linalgOp = cast(op); + // Check all iterator types are known. + auto iteratorTypesRange = linalgOp.getIteratorTypesArray(); + for (StringRef iteratorType : iteratorTypesRange) { + if (!llvm::is_contained(getAllIteratorTypeNames(), iteratorType) || + !utils::symbolizeIteratorType(iteratorType).has_value()) + return op->emitOpError("unexpected iterator_type (") + << iteratorType << ")"; + } + // Before checking indexing maps, we need to make sure the attributes // referenced by it are valid. if (linalgOp.hasDynamicIndexingMaps()) diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 52ef33b..8ce1ad0 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -705,17 +705,12 @@ void GenericOp::build( void GenericOp::build( OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef indexingMaps, - ArrayRef iteratorTypes, StringRef doc, - StringRef libraryCall, + ArrayRef iteratorTypes, StringRef doc, StringRef libraryCall, function_ref bodyBuild, ArrayRef attributes) { build(builder, result, resultTensorTypes, inputs, outputs, builder.getAffineMapArrayAttr(indexingMaps), - builder.getArrayAttr(llvm::to_vector(llvm::map_range( - iteratorTypes, - [&](utils::IteratorType iter) -> mlir::Attribute { - return IteratorTypeAttr::get(builder.getContext(), iter); - }))), + builder.getStrArrayAttr(iteratorTypes), doc.empty() ? StringAttr() : builder.getStringAttr(doc), libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall), bodyBuild, attributes); @@ -724,8 +719,7 @@ void GenericOp::build( void GenericOp::build( OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef indexingMaps, - ArrayRef iteratorTypes, StringRef doc, - StringRef libraryCall, + ArrayRef iteratorTypes, StringRef doc, StringRef libraryCall, function_ref bodyBuild, ArrayRef attributes) { build(builder, result, TypeRange{}, inputs, outputs, indexingMaps, @@ -735,7 +729,7 @@ void GenericOp::build( void GenericOp::build( OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef indexingMaps, - ArrayRef iteratorTypes, + ArrayRef iteratorTypes, function_ref bodyBuild, ArrayRef attributes) { build(builder, result, inputs, outputs, indexingMaps, iteratorTypes, @@ -746,7 +740,7 @@ void GenericOp::build( void GenericOp::build( OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef indexingMaps, - ArrayRef iteratorTypes, + ArrayRef iteratorTypes, function_ref bodyBuild, ArrayRef attributes) { build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps, @@ -764,29 +758,9 @@ void GenericOp::print(OpAsmPrinter &p) { llvm::StringSet<> genericAttrNamesSet; genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end()); SmallVector genericAttrs; - for (auto attr : (*this)->getAttrs()) { - if (attr.getName() == getIteratorTypesAttrName()) { - auto iteratorTypes = - attr.getValue() - .cast() - .getAsValueRange(); - // Convert IteratorType enums into the string representation. This is - // needed, because tests still use the old format when 'iterator_types' - // attribute is represented as an array of strings. - // TODO: Remove this conversion once tests are fixed. - SmallVector iteratorTypeNames = - llvm::to_vector(llvm::map_range( - iteratorTypes, [&](utils::IteratorType t) -> Attribute { - return StringAttr::get(getContext(), stringifyIteratorType(t)); - })); - - genericAttrs.emplace_back( - getIteratorTypesAttrName(), - ArrayAttr::get(getContext(), iteratorTypeNames)); - } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) { + for (auto attr : (*this)->getAttrs()) + if (genericAttrNamesSet.count(attr.getName().strref()) > 0) genericAttrs.push_back(attr); - } - } if (!genericAttrs.empty()) { auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs); p << genericDictAttr; @@ -831,28 +805,6 @@ ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) { result.attributes.assign(dictAttr.getValue().begin(), dictAttr.getValue().end()); - // Convert array of string into an array of IteratyType enums. This is needed, - // because tests still use the old format when 'iterator_types' attribute is - // represented as an array of strings. - // TODO: Remove this conversion once tests are fixed. - ArrayAttr iteratorTypes = - result.attributes.get(getIteratorTypesAttrName(result.name)) - .cast(); - - SmallVector iteratorTypeAttrs; - - for (StringRef s : iteratorTypes.getAsValueRange()) { - auto maybeIteratorType = utils::symbolizeIteratorType(s); - if (!maybeIteratorType.has_value()) - return parser.emitError(parser.getCurrentLocation()) - << "unexpected iterator_type (" << s << ")"; - - iteratorTypeAttrs.push_back( - IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value())); - } - result.attributes.set(getIteratorTypesAttrName(result.name), - parser.getBuilder().getArrayAttr(iteratorTypeAttrs)); - // Parsing is shared with named ops, except for the region. SmallVector inputTypes, outputTypes; if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes)) @@ -1466,9 +1418,9 @@ LogicalResult MapOp::verify() { return success(); } -SmallVector MapOp::getIteratorTypesArray() { +SmallVector MapOp::getIteratorTypesArray() { int64_t rank = getInit().getType().getRank(); - return SmallVector(rank, utils::IteratorType::parallel); + return SmallVector(rank, getParallelIteratorTypeName()); } ArrayAttr MapOp::getIndexingMaps() { @@ -1524,12 +1476,12 @@ void ReduceOp::build( inputs, inits, bodyBuild); } -SmallVector ReduceOp::getIteratorTypesArray() { +SmallVector ReduceOp::getIteratorTypesArray() { int64_t inputRank = getInputs()[0].getType().cast().getRank(); - SmallVector iteratorTypes(inputRank, - utils::IteratorType::parallel); + SmallVector iteratorTypes(inputRank, + getParallelIteratorTypeName()); for (int64_t reductionDim : getDimensions()) - iteratorTypes[reductionDim] = utils::IteratorType::reduction; + iteratorTypes[reductionDim] = getReductionIteratorTypeName(); return iteratorTypes; } @@ -1801,9 +1753,9 @@ LogicalResult TransposeOp::verify() { return success(); } -SmallVector TransposeOp::getIteratorTypesArray() { +SmallVector TransposeOp::getIteratorTypesArray() { int64_t rank = getInit().getType().getRank(); - return SmallVector(rank, utils::IteratorType::parallel); + return SmallVector(rank, getParallelIteratorTypeName()); } ArrayAttr TransposeOp::getIndexingMaps() { @@ -1939,9 +1891,9 @@ LogicalResult BroadcastOp::verify() { return success(); } -SmallVector BroadcastOp::getIteratorTypesArray() { +SmallVector BroadcastOp::getIteratorTypesArray() { int64_t rank = getInit().getType().getRank(); - return SmallVector(rank, utils::IteratorType::parallel); + return SmallVector(rank, getParallelIteratorTypeName()); } ArrayAttr BroadcastOp::getIndexingMaps() { diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 7fd5a5e..6a9c4e3 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -470,9 +470,10 @@ static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp, .getValue() .isProjectedPermutation(); }) && - genericOp.getMatchingIndexingMap(fusableOpOperand).getNumResults() > - 0 && - llvm::all_of(genericOp.getIteratorTypesArray(), isParallelIterator); + genericOp.getMatchingIndexingMap(fusableOpOperand).getNumResults() > 0 && + llvm::all_of(genericOp.getIteratorTypesArray(), [](StringRef it) { + return it == getParallelIteratorTypeName(); + }); } namespace { @@ -782,8 +783,8 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp, } // The iterator types of the expanded op are all parallel. - SmallVector iteratorTypes( - expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel); + SmallVector iteratorTypes(expansionInfo.getExpandedOpNumDims(), + getParallelIteratorTypeName()); TypeRange resultTypes = ValueRange(outputs).getTypes(); auto fusedOp = @@ -1082,8 +1083,7 @@ getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand, continue; // Check that all folded iterator types are all parallel or all reductions. - utils::IteratorType startIteratorType = - iteratorTypes[foldedIterationSpaceDims[0]]; + StringRef startIteratorType = iteratorTypes[foldedIterationSpaceDims[0]]; if (!isParallelIterator(startIteratorType) && !isReductionIterator(startIteratorType)) continue; @@ -1235,10 +1235,10 @@ private: /// Get the iterator types for the collapsed operation given the original /// iterator types and collapsed dimensions. -static SmallVector -getCollapsedOpIteratorTypes(ArrayRef iteratorTypes, +static SmallVector +getCollapsedOpIteratorTypes(ArrayRef iteratorTypes, const CollapsingInfo &collapsingInfo) { - SmallVector collapsedIteratorTypes; + SmallVector collapsedIteratorTypes; for (ReassociationIndicesRef foldedIterDims : collapsingInfo.getCollapsedOpToOrigOpMapping()) { assert(!foldedIterDims.empty() && @@ -1246,7 +1246,8 @@ getCollapsedOpIteratorTypes(ArrayRef iteratorTypes, // Just pick the iterator type of the first folded dim. Pre-condition checks // expected to have checked that iterator types of all folded dimensions are // the same. - collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]); + collapsedIteratorTypes.push_back( + iteratorTypes[foldedIterDims[0]].cast().getValue()); } return collapsedIteratorTypes; } @@ -1405,8 +1406,8 @@ static FailureOr> collapseGenericOpIterationDims( } // Get the iterator types for the operand. - SmallVector iteratorTypes = getCollapsedOpIteratorTypes( - genericOp.getIteratorTypesArray(), collapsingInfo); + SmallVector iteratorTypes = getCollapsedOpIteratorTypes( + genericOp.getIteratorTypes().getValue(), collapsingInfo); // Get the indexing maps. auto indexingMaps = llvm::to_vector( diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp index 52287f1..3740633 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp @@ -91,8 +91,8 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern { SmallVector indexingMaps( op->getNumResults() + op->getNumOperands(), rewriter.getMultiDimIdentityMap(rank)); - SmallVector iteratorTypes( - rank, utils::IteratorType::parallel); + SmallVector iteratorTypes(rank, + getParallelIteratorTypeName()); auto outputs = getOrCreateOperandsMatchingResultTypes(rewriter, op); rewriter.replaceOpWithNewOp( op, /*resultTensorTypes=*/op->getResultTypes(), diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp index 4755fa3..da43b49 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp @@ -53,7 +53,7 @@ FailureOr mlir::linalg::generalizeNamedOp(RewriterBase &rewriter, SmallVector inputs = linalgOp.getDpsInputOperands(); SmallVector outputs = linalgOp.getDpsInitOperands(); SmallVector indexingMaps = linalgOp.getIndexingMapsArray(); - SmallVector iterators = linalgOp.getIteratorTypesArray(); + SmallVector iterators = linalgOp.getIteratorTypesArray(); SmallVector resultTypes = linalgOp.hasTensorSemantics() ? TypeRange(ValueRange(outputs)) : TypeRange{}; diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp index 2fb550b..0608c36 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp @@ -162,13 +162,13 @@ FailureOr mlir::linalg::splitReduction( newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr, op.getContext())); - SmallVector newIteratorTypes; + SmallVector newIteratorTypes; for (auto &it : llvm::enumerate(op.getIteratorTypesArray())) { if (insertSplitDimension == it.index() && !control.innerParallel) - newIteratorTypes.push_back(utils::IteratorType::parallel); + newIteratorTypes.push_back(getParallelIteratorTypeName()); newIteratorTypes.push_back(it.value()); if (insertSplitDimension == it.index() && control.innerParallel) - newIteratorTypes.push_back(utils::IteratorType::parallel); + newIteratorTypes.push_back(getParallelIteratorTypeName()); } // Create the new op matching the original op with an extra parallel // dimension. @@ -182,14 +182,14 @@ FailureOr mlir::linalg::splitReduction( // from the previous op. unsigned intermRank = newOutputShape.size(); AffineMap inputMap = b.getMultiDimIdentityMap(intermRank); - SmallVector reductionIteratorTypes; + SmallVector reductionIteratorTypes; SmallVector exprs; for (unsigned i : llvm::seq(0, intermRank)) { if (insertSplitDimension == i) { - reductionIteratorTypes.push_back(utils::IteratorType::reduction); + reductionIteratorTypes.push_back(getReductionIteratorTypeName()); } else { exprs.push_back(b.getAffineDimExpr(i)); - reductionIteratorTypes.push_back(utils::IteratorType::parallel); + reductionIteratorTypes.push_back(getParallelIteratorTypeName()); } } AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext()); @@ -367,7 +367,7 @@ FailureOr mlir::linalg::splitReductionByScaling( // dimension. auto iteratorTypes = op.getIteratorTypesArray(); iteratorTypes.insert(iteratorTypes.begin() + reductionDimPos, - utils::IteratorType::parallel); + getParallelIteratorTypeName()); GenericOp genericOp = b.create(loc, ValueRange(newOutputs).getTypes(), newInputs, newOutputs, newMaps, iteratorTypes); @@ -394,10 +394,10 @@ FailureOr mlir::linalg::splitReductionByScaling( AffineMap map = b.getMultiDimIdentityMap(originalOutputType.getRank() + 1); SmallVector indexingMaps = { map, map.dropResult(insertSplitDimension)}; - SmallVector reductionIteratorTypes( - originalOutputType.getRank() + 1, utils::IteratorType::parallel); + SmallVector reductionIteratorTypes( + originalOutputType.getRank() + 1, getParallelIteratorTypeName()); reductionIteratorTypes[insertSplitDimension] = - utils::IteratorType::reduction; + getReductionIteratorTypeName(); // clang-format off auto reductionOp = b.create( diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index e1d4616..5937da3 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -431,7 +431,7 @@ tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef tileSizes, auto [loopRanges, loopIndexToRangeIndex] = makeTiledLoopRanges( b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes); - SmallVector iteratorTypes; + SmallVector iteratorTypes; for (const auto &attr : enumerate(op.getIteratorTypesArray())) { if (loopIndexToRangeIndex.count(attr.index())) iteratorTypes.push_back(attr.value()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index 02f4e9d..d1fcc01 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -88,7 +88,10 @@ struct LinalgOpTilingInterface /// Return the loop iterator type. SmallVector getLoopIteratorTypes(Operation *op) const { LinalgOpTy concreteOp = cast(op); - return concreteOp.getIteratorTypesArray(); + return llvm::to_vector(llvm::map_range( + concreteOp.getIteratorTypesArray(), [](StringRef iteratorType) { + return utils::symbolizeIteratorType(iteratorType).value(); + })); } /// Return the iteration domain range. @@ -336,9 +339,8 @@ struct LinalgOpPartialReductionInterface // Step3. create a generic op where the reduction dimension is replaced by a // parallel dimension of the size of reduction. - SmallVector newIteratorTypes = - linalgOp.getIteratorTypesArray(); - newIteratorTypes[reductionDims[0]] = utils::IteratorType::parallel; + SmallVector newIteratorTypes = linalgOp.getIteratorTypesArray(); + newIteratorTypes[reductionDims[0]] = getParallelIteratorTypeName(); SmallVector newMaps = linalgOp.getIndexingMapsArray(); newMaps.back() = AffineMap::get(newMaps.back().getNumDims(), 0, outputExpr, linalgOp.getContext()); @@ -364,14 +366,14 @@ struct LinalgOpPartialReductionInterface int64_t intermRank = partialReduce[0].getType().cast().getRank(); AffineMap inputMap = b.getMultiDimIdentityMap(intermRank); - SmallVector reductionIteratorTypes; + SmallVector reductionIteratorTypes; SmallVector exprs; for (int64_t i : llvm::seq(0, intermRank)) { if (dimToMerge == i) { - reductionIteratorTypes.push_back(utils::IteratorType::reduction); + reductionIteratorTypes.push_back(getReductionIteratorTypeName()); } else { exprs.push_back(b.getAffineDimExpr(i)); - reductionIteratorTypes.push_back(utils::IteratorType::parallel); + reductionIteratorTypes.push_back(getParallelIteratorTypeName()); } } AffineMap outputMap = diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 11ee55c..1034e8e 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -297,10 +297,8 @@ LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite( return vectorizeCopy(rewriter, copyOp); } -static SmallVector -getNParallelLoopsAttrs(unsigned nParallelLoops) { - return SmallVector(nParallelLoops, - utils::IteratorType::parallel); +static SmallVector getNParallelLoopsAttrs(unsigned nParallelLoops) { + return SmallVector(nParallelLoops, getParallelIteratorTypeName()); } /// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp (to diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index a643713..2cf74a6 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -1420,12 +1420,11 @@ namespace { /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}} /// ``` /// kw is unrolled, w is unrolled iff dilationW > 1. -struct Conv1DGenerator - : public StructuredGenerator { +struct Conv1DGenerator : public StructuredGenerator { Conv1DGenerator(OpBuilder &builder, LinalgOp linalgOp, int strideW, int dilationW) - : StructuredGenerator(builder, linalgOp), - strideW(strideW), dilationW(dilationW) { + : StructuredGenerator(builder, linalgOp), strideW(strideW), + dilationW(dilationW) { // Determine whether `linalgOp` can be generated with this generator if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1) return; diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index fc34353..ccf7cdc 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -186,12 +186,12 @@ bool isElementwise(LinalgOp op) { return hasOnlyScalarElementwiseOp(op->getRegion(0)); } -bool isParallelIterator(utils::IteratorType iteratorType) { - return iteratorType == utils::IteratorType::parallel; +bool isParallelIterator(StringRef iteratorType) { + return iteratorType == getParallelIteratorTypeName(); } -bool isReductionIterator(utils::IteratorType iteratorType) { - return iteratorType == utils::IteratorType::reduction; +bool isReductionIterator(StringRef iteratorType) { + return iteratorType == getReductionIteratorTypeName(); } /// Helper function that creates a memref::DimOp or tensor::DimOp depending on @@ -422,13 +422,15 @@ GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor, b.getContext())), AffineMap::getMultiDimIdentityMap(transposeVector.size(), b.getContext())}; - SmallVector iteratorTypes(transposeVector.size(), - utils::IteratorType::parallel); + SmallVector iteratorTypes(transposeVector.size(), + getParallelIteratorTypeName()); // Create a GenericOp to transpose `inputTensor` into `outputTensor`. - auto transposeOp = - b.create(loc, resultTensorType, inputTensor, outputTensor, - indexingMaps, iteratorTypes); + auto transposeOp = b.create( + loc, resultTensorType, inputTensor, outputTensor, + b.getAffineMapArrayAttr(indexingMaps), b.getStrArrayAttr(iteratorTypes), + /*doc=*/nullptr, + /*library_call=*/nullptr); Region &body = transposeOp.getRegion(); body.push_back(new Block()); body.front().addArguments({elementType, elementType}, {loc, loc}); @@ -450,8 +452,8 @@ GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) { AffineMap id = AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext()); - SmallVector iteratorTypes(memrefTypeTo.getRank(), - utils::IteratorType::parallel); + SmallVector iteratorTypes(memrefTypeTo.getRank(), + getParallelIteratorTypeName()); return b.create( loc, /*inputs=*/from, @@ -467,7 +469,7 @@ GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) { template <> void GenerateLoopNest::doit( OpBuilder &b, Location loc, ArrayRef loopRanges, LinalgOp linalgOp, - ArrayRef iteratorTypes, + ArrayRef iteratorTypes, function_ref bodyBuilderFn, @@ -511,7 +513,7 @@ void GenerateLoopNest::doit( template <> void GenerateLoopNest::doit( OpBuilder &b, Location loc, ArrayRef loopRanges, LinalgOp linalgOp, - ArrayRef iteratorTypes, + ArrayRef iteratorTypes, function_ref bodyBuilderFn, @@ -562,7 +564,7 @@ void updateBoundsForCyclicDistribution(OpBuilder &b, Location loc, Value procId, // exceeds 10. static void generateParallelLoopNest( OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs, - ValueRange steps, ArrayRef iteratorTypes, + ValueRange steps, ArrayRef iteratorTypes, ArrayRef procInfo, function_ref bodyBuilderFn, SmallVectorImpl &ivStorage) { @@ -677,7 +679,7 @@ static void generateParallelLoopNest( template <> void GenerateLoopNest::doit( OpBuilder &b, Location loc, ArrayRef loopRanges, LinalgOp linalgOp, - ArrayRef iteratorTypes, + ArrayRef iteratorTypes, function_ref bodyBuilderFn, diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp index 3f2ee1b..533d31f 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -178,8 +178,7 @@ static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) { /// as we use adj matrix for the graph. /// The sorted result will put the first Reduction iterator to the /// latest possible index. -static bool topSortOptimal(unsigned n, - ArrayRef iteratorTypes, +static bool topSortOptimal(unsigned n, ArrayRef iteratorTypes, std::vector &topSort, std::vector &inDegree, std::vector> &adjM) { diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp index a9c77c6..7f2e970 100644 --- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp +++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp @@ -15,10 +15,9 @@ using namespace mlir; using namespace mlir::tosa; -SmallVector +SmallVector mlir::tosa::getNParallelLoopsAttrs(unsigned nParallelLoops) { - return SmallVector(nParallelLoops, - utils::IteratorType::parallel); + return SmallVector(nParallelLoops, getParallelIteratorTypeName()); } SmallVector diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index 47aefa1..0bdaf7b 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -1518,14 +1518,27 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op, } namespace { +struct IteratorType { + IteratorType(StringRef strRef) : strRef(strRef) {} + bool isOfType(Attribute attr) const { + auto sAttr = attr.dyn_cast(); + return sAttr && sAttr.getValue() == strRef; + } + StringRef strRef; +}; +struct Par : public IteratorType { + Par() : IteratorType(getParallelIteratorTypeName()) {} +}; +struct Red : public IteratorType { + Red() : IteratorType(getReductionIteratorTypeName()) {} +}; /// Generate a vector implementation for matmat, matvec and tmatvec. /// This unrolls outer-products along the reduction dimension. struct UnrolledOuterProductGenerator - : public StructuredGenerator { + : public StructuredGenerator { UnrolledOuterProductGenerator(OpBuilder &builder, vector::ContractionOp op) - : StructuredGenerator( - builder, op), + : StructuredGenerator(builder, op), kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()), res(op.getAcc()), lhsType(op.getLhsType()) {} @@ -2706,10 +2719,8 @@ class DropInnerMostUnitDims : public OpRewritePattern { } else { MemRefLayoutAttrInterface updatedLayout; if (auto strided = layout.dyn_cast()) { - auto strides = - llvm::to_vector(strided.getStrides().drop_back(dimsToDrop)); - updatedLayout = StridedLayoutAttr::get(strided.getContext(), - strided.getOffset(), strides); + auto strides = llvm::to_vector(strided.getStrides().drop_back(dimsToDrop)); + updatedLayout = StridedLayoutAttr::get(strided.getContext(), strided.getOffset(), strides); } else { AffineMap map = srcType.getLayout().getAffineMap(); int numSymbols = map.getNumSymbols(); diff --git a/mlir/test/Dialect/Linalg/conv-interface-invalid.mlir b/mlir/test/Dialect/Linalg/conv-interface-invalid.mlir index e7d6dd9..f9b20d3 100644 --- a/mlir/test/Dialect/Linalg/conv-interface-invalid.mlir +++ b/mlir/test/Dialect/Linalg/conv-interface-invalid.mlir @@ -17,7 +17,7 @@ func.func @test_conv_op_wrong_num_operands(%arg0 : tensor, // expected-error @+1 {{expected op with 2 inputs and 1 output}} %0 = test.linalg_conv_op { indexing_maps = [#map, #map], - iterator_types = [#test.iterator_type]} + iterator_types = ["parallel"]} ins(%arg0 : tensor) outs(%arg1 : tensor) { ^bb0(%arg2 : f32, %arg3 : f32): linalg.yield %arg3 : f32 @@ -34,8 +34,7 @@ func.func @test_conv_op_wrong_input_indexing_map1(%arg0 : tensor, indexing_maps = [affine_map<(d0, d1) -> (d0 * 2)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0)>], - iterator_types = [#test.iterator_type, - #test.iterator_type]} + iterator_types = ["parallel", "reduction"]} ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) { ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32): @@ -53,8 +52,7 @@ func.func @test_conv_op_wrong_input_indexing_map2(%arg0 : tensor, indexing_maps = [affine_map<(d0, d1) -> (d0 + d1, d0)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0)>], - iterator_types = [#test.iterator_type, - #test.iterator_type]} + iterator_types = ["parallel", "reduction"]} ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) { ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32): @@ -72,8 +70,7 @@ func.func @test_conv_op_filter_index_map_not_projection(%arg0 : tensor, indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1 + d0)>, affine_map<(d0, d1) -> (d0)>], - iterator_types = [#test.iterator_type, - #test.iterator_type]} + iterator_types = ["parallel", "reduction"]} ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) { ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32): @@ -91,8 +88,7 @@ func.func @test_conv_op_output_index_map_not_projection(%arg0 : tensor, indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0 + d1)>], - iterator_types = [#test.iterator_type, - #test.iterator_type]} + iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) { ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32): @@ -112,8 +108,7 @@ func.func @test_conv_op_output_filter_convolved(%arg0 : tensor, indexing_maps = [affine_map<(d0, d1) -> (d0 + d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = [#test.iterator_type, - #test.iterator_type]} + iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) { ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32): @@ -132,9 +127,7 @@ func.func @test_conv_op_output_only_dim(%arg0 : tensor, indexing_maps = [affine_map<(d0, d1, d2) -> (d0 + d1)>, affine_map<(d0, d1, d2) -> (d1)>, affine_map<(d0, d1, d2) -> (d0, d2)>], - iterator_types = [#test.iterator_type, - #test.iterator_type, - #test.iterator_type]} + iterator_types = ["parallel", "reduction", "parallel"]} ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) { ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32): @@ -153,9 +146,7 @@ func.func @test_conv_op_filter_only_dim(%arg0 : tensor, indexing_maps = [affine_map<(d0, d1, d2) -> (d0 + d1)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0)>], - iterator_types = [#test.iterator_type, - #test.iterator_type, - #test.iterator_type]} + iterator_types = ["parallel", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) { ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32): @@ -174,9 +165,7 @@ func.func @test_conv_op_input_only_dim(%arg0 : tensor, indexing_maps = [affine_map<(d0, d1, d2) -> (d0 + d1, d2)>, affine_map<(d0, d1, d2) -> (d1)>, affine_map<(d0, d1, d2) -> (d0)>], - iterator_types = [#test.iterator_type, - #test.iterator_type, - #test.iterator_type]} + iterator_types = ["parallel", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) { ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32): @@ -195,8 +184,7 @@ func.func @test_conv_op_non_output_access_loop_parallel(%arg0 : tensor, indexing_maps = [affine_map<(d0, d1) -> (d0 + d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0)>], - iterator_types = [#test.iterator_type, - #test.iterator_type]} + iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) { ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32): diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir index ebce71f..5a1c2af 100644 --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -96,7 +96,7 @@ func.func @generic_wrong_dim_in_map(%arg0: memref<1xi32>) { // ----- func.func @generic_wrong_iterator(%arg0: memref<1xi32>) { - // expected-error @+4 {{unexpected iterator_type (random)}} + // expected-error @+1 {{op unexpected iterator_type (random)}} linalg.generic { indexing_maps = [ affine_map<(i) -> (i)> ], iterator_types = ["random"]} diff --git a/mlir/test/Dialect/Linalg/transform-op-match.mlir b/mlir/test/Dialect/Linalg/transform-op-match.mlir index b4ad820..4a92c07 100644 --- a/mlir/test/Dialect/Linalg/transform-op-match.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-match.mlir @@ -59,19 +59,13 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation): %match_attr = transform.structured.match ops{["linalg.generic"]} - attributes{iterator_types = [ - #linalg.iterator_type, - #linalg.iterator_type, - #linalg.iterator_type]} + attributes{iterator_types = ["parallel", "parallel", "parallel"]} in %arg1 transform.test_print_remark_at_operand %match_attr, "matched complex attr" : !pdl.operation transform.test_consume_operand %match_attr %no_match = transform.structured.match - attributes{iterator_types = [ - #linalg.iterator_type, - #linalg.iterator_type, - #linalg.iterator_type]} + attributes{iterator_types = ["parallel", "parallel", "reduction"]} in %arg1 // expected-remark @below {{0}} transform.test_print_number_of_associated_payload_ir_ops %no_match diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt index 2d1a1df..2c8719d 100644 --- a/mlir/test/lib/Dialect/Test/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt @@ -23,16 +23,13 @@ mlir_tablegen(TestTypeDefs.h.inc -gen-typedef-decls -typedefs-dialect=test) mlir_tablegen(TestTypeDefs.cpp.inc -gen-typedef-defs -typedefs-dialect=test) add_public_tablegen_target(MLIRTestTypeDefIncGen) -set(LLVM_TARGET_DEFINITIONS TestEnumDefs.td) -mlir_tablegen(TestOpEnums.h.inc -gen-enum-decls) -mlir_tablegen(TestOpEnums.cpp.inc -gen-enum-defs) -add_public_tablegen_target(MLIRTestEnumDefIncGen) - set(LLVM_TARGET_DEFINITIONS TestOps.td) mlir_tablegen(TestOps.h.inc -gen-op-decls) mlir_tablegen(TestOps.cpp.inc -gen-op-defs) mlir_tablegen(TestOpsDialect.h.inc -gen-dialect-decls -dialect=test) mlir_tablegen(TestOpsDialect.cpp.inc -gen-dialect-defs -dialect=test) +mlir_tablegen(TestOpEnums.h.inc -gen-enum-decls) +mlir_tablegen(TestOpEnums.cpp.inc -gen-enum-defs) mlir_tablegen(TestPatterns.inc -gen-rewriters) add_public_tablegen_target(MLIRTestOpsIncGen) @@ -49,7 +46,6 @@ add_mlir_library(MLIRTestDialect DEPENDS MLIRTestAttrDefIncGen - MLIRTestEnumDefIncGen MLIRTestInterfaceIncGen MLIRTestTypeDefIncGen MLIRTestOpsIncGen diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td index c4996ab..0c35f81 100644 --- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td +++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td @@ -15,8 +15,6 @@ // To get the test dialect definition. include "TestDialect.td" -include "TestEnumDefs.td" -include "mlir/Dialect/Utils/StructuredOpsUtils.td" include "mlir/IR/AttrTypeBase.td" include "mlir/IR/BuiltinAttributeInterfaces.td" include "mlir/IR/EnumAttr.td" @@ -279,6 +277,13 @@ def TestArrayOfInts : ArrayOfAttr; // An array of enum attributes. +def TestSimpleEnum : I32EnumAttr<"SimpleEnum", "", [ + I32EnumAttrCase<"a", 0>, + I32EnumAttrCase<"b", 1> + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "::test"; +} def TestSimpleEnumAttr : EnumAttr { let assemblyFormat = "`` $value"; } @@ -292,14 +297,4 @@ def TestCustomAnchor : Test_Attr<"TestCustomAnchor"> { let assemblyFormat = "`<` $a (`>`) : (`,` ` ` custom($b)^ `>`)?"; } -def Test_IteratorTypeEnum - : EnumAttr { - let assemblyFormat = "`<` $value `>`"; -} - -def Test_IteratorTypeArrayAttr - : TypedArrayAttrBase; - - #endif // TEST_ATTRDEFS diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.h b/mlir/test/lib/Dialect/Test/TestAttributes.h index cc73e07..4cb4d61 100644 --- a/mlir/test/lib/Dialect/Test/TestAttributes.h +++ b/mlir/test/lib/Dialect/Test/TestAttributes.h @@ -17,7 +17,6 @@ #include #include "TestTraits.h" -#include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" diff --git a/mlir/test/lib/Dialect/Test/TestEnumDefs.td b/mlir/test/lib/Dialect/Test/TestEnumDefs.td deleted file mode 100644 index 1ddfca0..0000000 --- a/mlir/test/lib/Dialect/Test/TestEnumDefs.td +++ /dev/null @@ -1,97 +0,0 @@ -//===-- TestEnumDefs.td - Test dialect enum definitions ----*- tablegen -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// TableGen enum definitions for Test dialect. -// -//===----------------------------------------------------------------------===// - -#ifndef TEST_ENUMDEFS -#define TEST_ENUMDEFS - -include "mlir/IR/EnumAttr.td" - -def I32Case5: I32EnumAttrCase<"case5", 5>; -def I32Case10: I32EnumAttrCase<"case10", 10>; - -def SomeI32Enum: I32EnumAttr< - "SomeI32Enum", "", [I32Case5, I32Case10]>; - -def I64Case5: I64EnumAttrCase<"case5", 5>; -def I64Case10: I64EnumAttrCase<"case10", 10>; - -def SomeI64Enum: I64EnumAttr< - "SomeI64Enum", "", [I64Case5, I64Case10]>; - -//===----------------------------------------------------------------------===// -// Test Enum -//===----------------------------------------------------------------------===// - -// Define the C++ enum. -def TestEnum - : I32EnumAttr<"TestEnum", "a test enum", [ - I32EnumAttrCase<"First", 0, "first">, - I32EnumAttrCase<"Second", 1, "second">, - I32EnumAttrCase<"Third", 2, "third">, - ]> { - let genSpecializedAttr = 0; - let cppNamespace = "test"; -} - -def TestSimpleEnum : I32EnumAttr<"SimpleEnum", "", [ - I32EnumAttrCase<"a", 0>, - I32EnumAttrCase<"b", 1> - ]> { - let genSpecializedAttr = 0; - let cppNamespace = "::test"; -} - -//===----------------------------------------------------------------------===// -// Test Bit Enum -//===----------------------------------------------------------------------===// - -// Define the C++ enum. -def TestBitEnum - : I32BitEnumAttr<"TestBitEnum", "a test bit enum", [ - I32BitEnumAttrCaseBit<"Read", 0, "read">, - I32BitEnumAttrCaseBit<"Write", 1, "write">, - I32BitEnumAttrCaseBit<"Execute", 2, "execute">, - ]> { - let genSpecializedAttr = 0; - let cppNamespace = "test"; - let separator = ", "; -} - -// Define an enum with a different separator -def TestBitEnumVerticalBar - : I32BitEnumAttr<"TestBitEnumVerticalBar", "another test bit enum", [ - I32BitEnumAttrCaseBit<"User", 0, "user">, - I32BitEnumAttrCaseBit<"Group", 1, "group">, - I32BitEnumAttrCaseBit<"Other", 2, "other">, - ]> { - let genSpecializedAttr = 0; - let cppNamespace = "test"; - let separator = " | "; -} - -//===----------------------------------------------------------------------===// -// Test Patterns (Multi-result Ops) - -def MultiResultOpKind1: I64EnumAttrCase<"kind1", 1>; -def MultiResultOpKind2: I64EnumAttrCase<"kind2", 2>; -def MultiResultOpKind3: I64EnumAttrCase<"kind3", 3>; -def MultiResultOpKind4: I64EnumAttrCase<"kind4", 4>; -def MultiResultOpKind5: I64EnumAttrCase<"kind5", 5>; -def MultiResultOpKind6: I64EnumAttrCase<"kind6", 6>; - -def MultiResultOpEnum: I64EnumAttr< - "MultiResultOpEnum", "Multi-result op kinds", [ - MultiResultOpKind1, MultiResultOpKind2, MultiResultOpKind3, - MultiResultOpKind4, MultiResultOpKind5, MultiResultOpKind6 - ]>; - -#endif // TEST_ENUMDEFS diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 84dd37f..84bbe24 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -202,11 +202,23 @@ def FloatAttrOp : TEST_Op<"float_attrs"> { ); } +def I32Case5: I32EnumAttrCase<"case5", 5>; +def I32Case10: I32EnumAttrCase<"case10", 10>; + +def SomeI32Enum: I32EnumAttr< + "SomeI32Enum", "", [I32Case5, I32Case10]>; + def I32EnumAttrOp : TEST_Op<"i32_enum_attr"> { let arguments = (ins SomeI32Enum:$attr); let results = (outs I32:$val); } +def I64Case5: I64EnumAttrCase<"case5", 5>; +def I64Case10: I64EnumAttrCase<"case10", 10>; + +def SomeI64Enum: I64EnumAttr< + "SomeI64Enum", "", [I64Case5, I64Case10]>; + def I64EnumAttrOp : TEST_Op<"i64_enum_attr"> { let arguments = (ins SomeI64Enum:$attr); let results = (outs I32:$val); @@ -307,6 +319,17 @@ def ConfinedDenseArrayAttrOp : TEST_Op<"confined_dense_array_attr"> { // Test Enum Attributes //===----------------------------------------------------------------------===// +// Define the C++ enum. +def TestEnum + : I32EnumAttr<"TestEnum", "a test enum", [ + I32EnumAttrCase<"First", 0, "first">, + I32EnumAttrCase<"Second", 1, "second">, + I32EnumAttrCase<"Third", 2, "third">, + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "test"; +} + // Define the enum attribute. def TestEnumAttr : EnumAttr; @@ -328,6 +351,18 @@ def : Pat<(OpWithEnum ConstantAttr, + I32BitEnumAttrCaseBit<"Write", 1, "write">, + I32BitEnumAttrCaseBit<"Execute", 2, "execute">, + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "test"; + let separator = ", "; +} + // Define the enum attribute. def TestBitEnumAttr : EnumAttr { let assemblyFormat = "`<` $value `>`"; @@ -339,6 +374,18 @@ def OpWithBitEnum : TEST_Op<"op_with_bit_enum"> { let assemblyFormat = "$value (`tag` $tag^)? attr-dict"; } +// Define an enum with a different separator +def TestBitEnumVerticalBar + : I32BitEnumAttr<"TestBitEnumVerticalBar", "another test bit enum", [ + I32BitEnumAttrCaseBit<"User", 0, "user">, + I32BitEnumAttrCaseBit<"Group", 1, "group">, + I32BitEnumAttrCaseBit<"Other", 2, "other">, + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "test"; + let separator = " | "; +} + def TestBitEnumVerticalBarAttr : EnumAttr { let assemblyFormat = "`<` $value `>`"; @@ -1345,6 +1392,22 @@ def : Pat<(OpC $input), (OpB $input, ConstantAttr:$attr)>; def : Pat<(I32EnumAttrOp I32Case5), (I32EnumAttrOp I32Case10)>; def : Pat<(I64EnumAttrOp I64Case5), (I64EnumAttrOp I64Case10)>; +//===----------------------------------------------------------------------===// +// Test Patterns (Multi-result Ops) + +def MultiResultOpKind1: I64EnumAttrCase<"kind1", 1>; +def MultiResultOpKind2: I64EnumAttrCase<"kind2", 2>; +def MultiResultOpKind3: I64EnumAttrCase<"kind3", 3>; +def MultiResultOpKind4: I64EnumAttrCase<"kind4", 4>; +def MultiResultOpKind5: I64EnumAttrCase<"kind5", 5>; +def MultiResultOpKind6: I64EnumAttrCase<"kind6", 6>; + +def MultiResultOpEnum: I64EnumAttr< + "MultiResultOpEnum", "Multi-result op kinds", [ + MultiResultOpKind1, MultiResultOpKind2, MultiResultOpKind3, + MultiResultOpKind4, MultiResultOpKind5, MultiResultOpKind6 + ]>; + def ThreeResultOp : TEST_Op<"three_result"> { let arguments = (ins MultiResultOpEnum:$kind); let results = (outs I32:$result1, F32:$result2, F32:$result3); @@ -2761,10 +2824,8 @@ def TestLinalgConvOp : return ®ionBuilder; } - llvm::SmallVector getIteratorTypesArray() { - auto attrs = getOperation()->getAttrOfType("iterator_types"); - auto range = attrs.getAsValueRange(); - return {range.begin(), range.end()}; + mlir::ArrayAttr getIteratorTypes() { + return getOperation()->getAttrOfType("iterator_types"); } mlir::ArrayAttr getIndexingMaps() { @@ -2823,10 +2884,8 @@ def TestLinalgFillOp : return ®ionBuilder; } - llvm::SmallVector getIteratorTypesArray() { - auto attrs = getOperation()->getAttrOfType("iterator_types"); - auto range = attrs.getAsValueRange(); - return {range.begin(), range.end()}; + mlir::ArrayAttr getIteratorTypes() { + return getOperation()->getAttrOfType("iterator_types"); } mlir::ArrayAttr getIndexingMaps() { diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp index 3595be8..0a482cc 100644 --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -553,7 +553,7 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments], let extraClassDeclaration = structuredOpsBaseDecls # [{{ // Auto-generated. - SmallVector getIteratorTypesArray(); + SmallVector getIteratorTypesArray(); ArrayAttr getIndexingMaps(); static void regionBuilder(ImplicitLocOpBuilder &b, Block &block, ArrayRef attrs); @@ -597,8 +597,8 @@ static const char structuredOpBuilderFormat[] = R"FMT( // {1}: Comma interleaved iterator type names. static const char structuredOpIteratorTypesFormat[] = R"FMT( -SmallVector {0}::getIteratorTypesArray() {{ - return SmallVector{{ {1} }; +SmallVector {0}::getIteratorTypesArray() {{ + return SmallVector{{ {1} }; } )FMT"; @@ -607,9 +607,9 @@ SmallVector {0}::getIteratorTypesArray() {{ // {0}: Class name static const char rankPolyStructuredOpIteratorTypesFormat[] = R"FMT( -SmallVector {0}::getIteratorTypesArray() {{ +SmallVector {0}::getIteratorTypesArray() {{ int64_t rank = getRank(getDpsInitOperand(0)); - return SmallVector(rank, utils::IteratorType::parallel); + return SmallVector(rank, getParallelIteratorTypeName()); } )FMT"; @@ -812,10 +812,10 @@ generateNamedGenericOpDefns(LinalgOpConfig &opConfig, [&](LinalgIteratorTypeDef it) { switch (it) { case LinalgIteratorTypeDef::parallel: - ss << "utils::IteratorType::parallel"; + ss << "getParallelIteratorTypeName()"; break; case LinalgIteratorTypeDef::reduction: - ss << "utils::IteratorType::reduction"; + ss << "getReductionIteratorTypeName()"; break; } }); diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel index 0a64646..d58a07f 100644 --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -132,6 +132,14 @@ gentbl_cc_library( "lib/Dialect/Test/TestOpsDialect.cpp.inc", ), ( + ["-gen-enum-decls"], + "lib/Dialect/Test/TestOpEnums.h.inc", + ), + ( + ["-gen-enum-defs"], + "lib/Dialect/Test/TestOpEnums.cpp.inc", + ), + ( ["-gen-rewriters"], "lib/Dialect/Test/TestPatterns.inc", ), @@ -204,27 +212,6 @@ gentbl_cc_library( ) gentbl_cc_library( - name = "TestEnumDefsIncGen", - strip_include_prefix = "lib/Dialect/Test", - tbl_outs = [ - ( - ["-gen-enum-decls"], - "lib/Dialect/Test/TestOpEnums.h.inc", - ), - ( - ["-gen-enum-defs"], - "lib/Dialect/Test/TestOpEnums.cpp.inc", - ), - ], - tblgen = "//mlir:mlir-tblgen", - td_file = "lib/Dialect/Test/TestEnumDefs.td", - test = True, - deps = [ - ":TestOpTdFiles", - ], -) - -gentbl_cc_library( name = "TestTypeDefsIncGen", strip_include_prefix = "lib/Dialect/Test", tbl_outs = [ @@ -331,7 +318,6 @@ cc_library( ], deps = [ ":TestAttrDefsIncGen", - ":TestEnumDefsIncGen", ":TestInterfacesIncGen", ":TestOpsIncGen", ":TestTypeDefsIncGen", @@ -344,7 +330,6 @@ cc_library( "//mlir:DerivedAttributeOpInterface", "//mlir:DestinationStyleOpInterface", "//mlir:Dialect", - "//mlir:DialectUtils", "//mlir:FuncDialect", "//mlir:FuncTransforms", "//mlir:IR", -- 2.7.4