From c38d9cf20e7468a2618dc23fcdc66e79c925aff5 Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Thu, 13 Oct 2022 07:28:46 +0000 Subject: [PATCH] [mlir] Remove iterator_types() method from LinalgStructuredInterface. `getIteratorTypesArray` should be used instead. It's a better substitute for all the current usages of the interface. The current `ArrayAttr iterator_types()` has a few problems: * It creates an assumption operation has iterators types as an attribute, but it's not always the case. Sometime iterator types can be inferred from other attribute, or they're just static. * ArrayAttr is an obscure contained and required extracting values in the client code. * Makes it hard to migrate iterator types from strings to enums ([RFC](https://discourse.llvm.org/t/rfc-enumattr-for-iterator-types-in-linalg/64535/9)). Concrete ops, like `linalg.generic` will still have iterator types as an attribute if needed. As a side effect, this change helps a bit with migration to prefixed accessors. Differential Revision: https://reviews.llvm.org/D135765 --- .../mlir/Dialect/Linalg/IR/LinalgInterfaces.td | 22 ++++++---------------- .../mlir/Dialect/Linalg/IR/LinalgStructuredOps.td | 4 ++-- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 10 ++++------ .../Linalg/Transforms/TilingInterfaceImpl.cpp | 8 +++----- .../Transforms/SparseTensorRewriting.cpp | 2 +- .../SparseTensor/Transforms/Sparsification.cpp | 2 +- mlir/test/lib/Dialect/Test/TestOps.td | 4 ++-- .../test-linalg-ods-yaml-gen.yaml | 2 +- .../mlir-linalg-ods-yaml-gen.cpp | 16 ++++++++-------- 9 files changed, 28 insertions(+), 42 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td index 268587e..69871fa 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -499,26 +499,19 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { >, InterfaceMethod< /*desc=*/[{ - Return the iterator types attribute within the current operation. - }], - /*retTy=*/"ArrayAttr", - /*methodName=*/"iterator_types", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - return $_op.getIteratorTypes(); - }] - >, - InterfaceMethod< - /*desc=*/[{ Return iterator types in the current operation. + + Default implementation assumes that the operation has an attribute + `iterator_types`, but it's not always the case. Sometimes iterator types + can be infered from other parameters and in such cases default + getIteratorTypesArray should be overriden. }], /*retTy=*/"SmallVector", /*methodName=*/"getIteratorTypesArray", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - auto range = $_op.iterator_types().template getAsValueRange(); + auto range = $_op.getIteratorTypes().template getAsValueRange(); return {range.begin(), range.end()}; }] >, @@ -773,9 +766,6 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { LogicalResult reifyResultShapes(OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes); - // TODO: Remove once prefixing is flipped. - ArrayAttr getIteratorTypes() { return iterator_types(); } - SmallVector getIteratorTypeNames() { return getIteratorTypesArray(); } diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index 6bcf509..2619ad1 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -264,7 +264,7 @@ def MapOp : LinalgStructuredBase_Op<"map", [ let extraClassDeclaration = structuredOpsBaseDecls # [{ // Implement functions necessary for LinalgStructuredInterface. - ArrayAttr getIteratorTypes(); + SmallVector getIteratorTypesArray(); ArrayAttr getIndexingMaps(); std::string getLibraryCallName() { return "op_has_no_registered_library_name"; @@ -334,7 +334,7 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [ let extraClassDeclaration = structuredOpsBaseDecls # [{ // Declare functions necessary for LinalgStructuredInterface. - ArrayAttr getIteratorTypes(); + SmallVector getIteratorTypesArray(); ArrayAttr getIndexingMaps(); std::string getLibraryCallName() { return "op_has_no_registered_library_name"; diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 61b9386..c2705a3 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1393,11 +1393,9 @@ LogicalResult MapOp::verify() { return success(); } -ArrayAttr MapOp::getIteratorTypes() { +SmallVector MapOp::getIteratorTypesArray() { int64_t rank = getInit().getType().getRank(); - return Builder(getContext()) - .getStrArrayAttr( - SmallVector(rank, getParallelIteratorTypeName())); + return SmallVector(rank, getParallelIteratorTypeName()); } ArrayAttr MapOp::getIndexingMaps() { @@ -1435,13 +1433,13 @@ void ReduceOp::getAsmResultNames( setNameFn(getResults().front(), "reduced"); } -ArrayAttr ReduceOp::getIteratorTypes() { +SmallVector ReduceOp::getIteratorTypesArray() { int64_t inputRank = getInputs()[0].getType().cast().getRank(); SmallVector iteratorTypes(inputRank, getParallelIteratorTypeName()); for (int64_t reductionDim : getDimensions()) iteratorTypes[reductionDim] = getReductionIteratorTypeName(); - return Builder(getContext()).getStrArrayAttr(iteratorTypes); + return iteratorTypes; } ArrayAttr ReduceOp::getIndexingMaps() { diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index a9113f0..66d55dc 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -92,11 +92,9 @@ struct LinalgOpTilingInterface /// Return the loop iterator type. SmallVector getLoopIteratorTypes(Operation *op) const { LinalgOpTy concreteOp = cast(op); - return llvm::to_vector( - llvm::map_range(concreteOp.iterator_types(), [](Attribute strAttr) { - return utils::symbolizeIteratorType( - strAttr.cast().getValue()) - .value(); + return llvm::to_vector(llvm::map_range( + concreteOp.getIteratorTypesArray(), [](StringRef iteratorType) { + return utils::symbolizeIteratorType(iteratorType).value(); })); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp index a28aa3b..2458dab 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -250,7 +250,7 @@ public: // Fuse producer and consumer into a new generic op. auto fusedOp = rewriter.create( loc, op.getResult(0).getType(), inputOps, outputOps, - rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.iterator_types(), + rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.getIteratorTypes(), /*doc=*/nullptr, /*library_call=*/nullptr); Block &prodBlock = prod.getRegion().front(); Block &consBlock = op.getRegion().front(); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp index 21ee7b5..1418ed4 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -1857,7 +1857,7 @@ public: if (op.getNumOutputs() != 1) return failure(); unsigned numTensors = op.getNumInputsAndOutputs(); - unsigned numLoops = op.iterator_types().getValue().size(); + unsigned numLoops = op.getNumLoops(); Merger merger(numTensors, numLoops); if (!findSparseAnnotations(merger, op)) return failure(); diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 1fca43d..7f32912 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2816,7 +2816,7 @@ def TestLinalgConvOp : return ®ionBuilder; } - mlir::ArrayAttr iterator_types() { + mlir::ArrayAttr getIteratorTypes() { return getOperation()->getAttrOfType("iterator_types"); } @@ -2875,7 +2875,7 @@ def TestLinalgFillOp : return ®ionBuilder; } - mlir::ArrayAttr iterator_types() { + mlir::ArrayAttr getIteratorTypes() { return getOperation()->getAttrOfType("iterator_types"); } diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml index 205cde2..51557f5 100644 --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml @@ -235,7 +235,7 @@ structured_op: !LinalgStructuredOpConfig - !ScalarExpression scalar_arg: value -# IMPL: Test3Op::iterator_types() { +# IMPL: Test3Op::getIteratorTypesArray() { # IMPL-NEXT: int64_t rank = getRank(getOutputOperand(0)); # IMPL: Test3Op::getIndexingMaps() { diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp index 1bea98f..8156bb9 100644 --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -553,7 +553,7 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments], let extraClassDeclaration = structuredOpsBaseDecls # [{{ // Auto-generated. - ArrayAttr iterator_types(); + SmallVector getIteratorTypesArray(); ArrayAttr getIndexingMaps(); static void regionBuilder(ImplicitLocOpBuilder &b, Block &block, ArrayRef attrs); @@ -587,24 +587,24 @@ static const char structuredOpBuilderFormat[] = R"FMT( }]> )FMT"; -// The iterator_types() method for structured ops. Parameters: +// The getIteratorTypesArray() method for structured ops. Parameters: // {0}: Class name // {1}: Comma interleaved iterator type names. static const char structuredOpIteratorTypesFormat[] = R"FMT( -ArrayAttr {0}::iterator_types() {{ - return Builder(getContext()).getStrArrayAttr(SmallVector{{ {1} }); +SmallVector {0}::getIteratorTypesArray() {{ + return SmallVector{{ {1} }; } )FMT"; -// The iterator_types() method for rank polymorphic structured ops. Parameters: +// The getIteratorTypesArray() method for rank polymorphic structured ops. +// Parameters: // {0}: Class name static const char rankPolyStructuredOpIteratorTypesFormat[] = R"FMT( -ArrayAttr {0}::iterator_types() {{ +SmallVector {0}::getIteratorTypesArray() {{ int64_t rank = getRank(getOutputOperand(0)); - return Builder(getContext()).getStrArrayAttr( - SmallVector(rank, getParallelIteratorTypeName())); + return SmallVector(rank, getParallelIteratorTypeName()); } )FMT"; -- 2.7.4