[mlir][linalg] Add getIteratorTypesArray to LinalgInterface.
authorOleg Shyshkov <shyshkov@google.com>
Tue, 27 Sep 2022 14:26:56 +0000 (14:26 +0000)
committerOleg Shyshkov <shyshkov@google.com>
Tue, 27 Sep 2022 14:30:50 +0000 (14:30 +0000)
Summary:
Most of the code that gets `iterator_types` from LinalgInterface is forced to
extract values from an `Attribute`. As a result, the usage pattern looks like
this:

```
SmallVector<StringRef> iterators = llvm::to_vector<4>(linalgOp.iterator_types().getAsValueRange<StringAttr>());
```

It also forces all operations that implement LinalgOp interface to have
`iterator_types` attribute even when the information can be easily infered from
other parameters.

In perfect future, `getIteratorTypeArray` should be the only method to get
iterator types from the interface. The default implementation can rely on
`iterator_types` attribute though.

The name `getIteratorTypeArray` was picked to be consistent with existing
`getIndexingMapsArray`.

This patch add a few sample usages. More cleanups will follow.

Differential Revision: https://reviews.llvm.org/D134729

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp

index 64c2bd1..3803170 100644 (file)
@@ -493,6 +493,19 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
     >,
     InterfaceMethod<
       /*desc=*/[{
+        Return iterator types in the current operation.
+      }],
+      /*retTy=*/"SmallVector<StringRef>",
+      /*methodName=*/"getIteratorTypesArray",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        auto range = $_op.iterator_types().template getAsValueRange<StringAttr>();
+        return {range.begin(), range.end()};
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
         Return true if the indexing map is depending on the current op instance.
         This means that the indexing map is dynamically synthesized by using the
         op instance's concrete attributes, instead of being static for all
index 6e4c2fc..032578f 100644 (file)
@@ -297,8 +297,7 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
       !indexingMaps.back().isProjectedPermutation())
     return MatchConvolutionResult::NotProjectedPermutations;
 
-  auto iteratorTypesRange =
-      linalgOp.iterator_types().getAsValueRange<StringAttr>();
+  auto iteratorTypesRange = linalgOp.getIteratorTypesArray();
 
   llvm::SmallDenseSet<unsigned> outputDims =
       getPreservedDims(indexingMaps.back());
index d58dc4a..15ba43c 100644 (file)
@@ -438,8 +438,7 @@ struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
       resultTypes.push_back(newInputOutputTypes[i + genericOp.getNumInputs()]);
     GenericOp replacementOp = rewriter.create<GenericOp>(
         loc, resultTypes, newInputs, newOutputs, newIndexingMaps,
-        llvm::to_vector<4>(genericOp.getIteratorTypes()
-                               .template getAsValueRange<StringAttr>()));
+        genericOp.getIteratorTypesArray());
     rewriter.inlineRegionBefore(genericOp.getRegion(),
                                 replacementOp.getRegion(),
                                 replacementOp.getRegion().begin());
index 97dd0c4..94526e1 100644 (file)
@@ -467,9 +467,8 @@ static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
                             .isProjectedPermutation();
                       }) &&
          genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() > 0 &&
-         llvm::all_of(genericOp.getIteratorTypes(), [](Attribute attr) {
-           return attr.cast<StringAttr>().getValue() ==
-                  getParallelIteratorTypeName();
+         llvm::all_of(genericOp.getIteratorTypesArray(), [](StringRef it) {
+           return it == getParallelIteratorTypeName();
          });
 }
 
index 4f384c3..97eee8b 100644 (file)
@@ -53,8 +53,7 @@ FailureOr<GenericOp> mlir::linalg::generalizeNamedOp(RewriterBase &rewriter,
   SmallVector<Value> inputOperands = linalgOp.getInputOperands();
   SmallVector<Value> outputOperands = linalgOp.getOutputOperands();
   SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
-  SmallVector<StringRef> iterators = llvm::to_vector<4>(
-      linalgOp.iterator_types().getAsValueRange<StringAttr>());
+  SmallVector<StringRef> iterators = linalgOp.getIteratorTypesArray();
   SmallVector<RankedTensorType> resultTypes = linalgOp.getOutputTensorTypes();
   SmallVector<Type> types(resultTypes.begin(), resultTypes.end());
 
index e8f5372..a53e7c5 100644 (file)
@@ -61,9 +61,7 @@ struct InlineScalarOperands : public OpRewritePattern<GenericOp> {
     SmallVector<Value> outputOperands = genericOp.getOutputOperands();
     auto newOp = rewriter.create<GenericOp>(
         loc, genericOp->getResultTypes(), newOperands, outputOperands,
-        newIndexingMaps,
-        llvm::to_vector<4>(genericOp.getIteratorTypes()
-                               .template getAsValueRange<StringAttr>()));
+        newIndexingMaps, genericOp.getIteratorTypesArray());
     rewriter.cloneRegionBefore(genericOp.getRegion(), newOp.getRegion(),
                                newOp.getRegion().begin());