From 1e7a6c0874f57014132b857b2d201d6aaa75feee Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Tue, 27 Sep 2022 14:26:56 +0000 Subject: [PATCH] [mlir][linalg] Add getIteratorTypesArray to LinalgInterface. Summary: Most of the code that gets `iterator_types` from LinalgInterface is forced to extract values from an `Attribute`. As a result, the usage pattern looks like this: ``` SmallVector iterators = llvm::to_vector<4>(linalgOp.iterator_types().getAsValueRange()); ``` It also forces all operations that implement LinalgOp interface to have `iterator_types` attribute even when the information can be easily infered from other parameters. In perfect future, `getIteratorTypeArray` should be the only method to get iterator types from the interface. The default implementation can rely on `iterator_types` attribute though. The name `getIteratorTypeArray` was picked to be consistent with existing `getIndexingMapsArray`. This patch add a few sample usages. More cleanups will follow. Differential Revision: https://reviews.llvm.org/D134729 --- mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td | 13 +++++++++++++ mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp | 3 +-- mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp | 3 +-- mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp | 5 ++--- mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp | 3 +-- mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp | 4 +--- 6 files changed, 19 insertions(+), 12 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td index 64c2bd1..3803170 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -493,6 +493,19 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { >, InterfaceMethod< /*desc=*/[{ + Return iterator types in the current operation. + }], + /*retTy=*/"SmallVector", + /*methodName=*/"getIteratorTypesArray", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + auto range = $_op.iterator_types().template getAsValueRange(); + return {range.begin(), range.end()}; + }] + >, + InterfaceMethod< + /*desc=*/[{ Return true if the indexing map is depending on the current op instance. This means that the indexing map is dynamically synthesized by using the op instance's concrete attributes, instead of being static for all diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp index 6e4c2fc..032578f 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -297,8 +297,7 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) { !indexingMaps.back().isProjectedPermutation()) return MatchConvolutionResult::NotProjectedPermutations; - auto iteratorTypesRange = - linalgOp.iterator_types().getAsValueRange(); + auto iteratorTypesRange = linalgOp.getIteratorTypesArray(); llvm::SmallDenseSet outputDims = getPreservedDims(indexingMaps.back()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index d58dc4a..15ba43c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -438,8 +438,7 @@ struct ReplaceUnitExtents : public OpRewritePattern { resultTypes.push_back(newInputOutputTypes[i + genericOp.getNumInputs()]); GenericOp replacementOp = rewriter.create( loc, resultTypes, newInputs, newOutputs, newIndexingMaps, - llvm::to_vector<4>(genericOp.getIteratorTypes() - .template getAsValueRange())); + genericOp.getIteratorTypesArray()); rewriter.inlineRegionBefore(genericOp.getRegion(), replacementOp.getRegion(), replacementOp.getRegion().begin()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 97dd0c4..94526e1 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -467,9 +467,8 @@ static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp, .isProjectedPermutation(); }) && genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() > 0 && - llvm::all_of(genericOp.getIteratorTypes(), [](Attribute attr) { - return attr.cast().getValue() == - getParallelIteratorTypeName(); + llvm::all_of(genericOp.getIteratorTypesArray(), [](StringRef it) { + return it == getParallelIteratorTypeName(); }); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp index 4f384c3..97eee8b 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp @@ -53,8 +53,7 @@ FailureOr mlir::linalg::generalizeNamedOp(RewriterBase &rewriter, SmallVector inputOperands = linalgOp.getInputOperands(); SmallVector outputOperands = linalgOp.getOutputOperands(); SmallVector indexingMaps = linalgOp.getIndexingMapsArray(); - SmallVector iterators = llvm::to_vector<4>( - linalgOp.iterator_types().getAsValueRange()); + SmallVector iterators = linalgOp.getIteratorTypesArray(); SmallVector resultTypes = linalgOp.getOutputTensorTypes(); SmallVector types(resultTypes.begin(), resultTypes.end()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp index e8f5372..a53e7c5 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp @@ -61,9 +61,7 @@ struct InlineScalarOperands : public OpRewritePattern { SmallVector outputOperands = genericOp.getOutputOperands(); auto newOp = rewriter.create( loc, genericOp->getResultTypes(), newOperands, outputOperands, - newIndexingMaps, - llvm::to_vector<4>(genericOp.getIteratorTypes() - .template getAsValueRange())); + newIndexingMaps, genericOp.getIteratorTypesArray()); rewriter.cloneRegionBefore(genericOp.getRegion(), newOp.getRegion(), newOp.getRegion().begin()); -- 2.7.4