[mlir][linalg] Remove IndexedGenericOp support from DropUnitDims...
authorTobias Gysi <gysit@google.com>
Thu, 13 May 2021 13:14:47 +0000 (13:14 +0000)
committerTobias Gysi <gysit@google.com>
Thu, 13 May 2021 14:18:59 +0000 (14:18 +0000)
after introducing the IndexedGenericOp to GenericOp canonicalization (https://reviews.llvm.org/D101612).

Differential Revision: https://reviews.llvm.org/D102235

mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir

index 9c4d8af..623c824 100644 (file)
@@ -146,13 +146,13 @@ static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
 }
 
 /// Update the index accesses of linalg operations having index semantics.
-template <typename GenericOpTy>
-static void replaceUnitDimIndexOps(GenericOpTy op,
+static void replaceUnitDimIndexOps(GenericOp genericOp,
                                    const DenseSet<unsigned> &unitDims,
                                    PatternRewriter &rewriter) {
-  assert(op->getNumRegions() == 1 && op->getRegion(0).getBlocks().size() == 1 &&
+  assert(genericOp->getNumRegions() == 1 &&
+         genericOp->getRegion(0).getBlocks().size() == 1 &&
          "expected generic operation to have one block.");
-  Block &block = op->getRegion(0).front();
+  Block &block = genericOp->getRegion(0).front();
 
   for (IndexOp indexOp : llvm::make_early_inc_range(block.getOps<IndexOp>())) {
     OpBuilder::InsertionGuard guard(rewriter);
@@ -170,39 +170,13 @@ static void replaceUnitDimIndexOps(GenericOpTy op,
   }
 }
 
-/// Modify the region of indexed generic op to drop arguments corresponding to
-/// loops that are unit trip count.
-template <typename OpTy>
-static LogicalResult
-replaceBlockArgForUnitDimLoops(OpTy op, const DenseSet<unsigned> &unitDims,
-                               PatternRewriter &rewriterp) {
-  return success();
-}
-
-template <>
-LogicalResult replaceBlockArgForUnitDimLoops<IndexedGenericOp>(
-    IndexedGenericOp op, const DenseSet<unsigned> &unitDims,
-    PatternRewriter &rewriter) {
-  OpBuilder::InsertionGuard guard(rewriter);
-  Block *entryBlock = &op->getRegion(0).front();
-  rewriter.setInsertionPointToStart(entryBlock);
-  Value zero = rewriter.create<ConstantIndexOp>(op.getLoc(), 0);
-  for (unsigned unitDimLoop : unitDims) {
-    entryBlock->getArgument(unitDimLoop).replaceAllUsesWith(zero);
-  }
-  SmallVector<unsigned, 8> unitDimsToErase(unitDims.begin(), unitDims.end());
-  entryBlock->eraseArguments(unitDimsToErase);
-  return success();
-}
-
 namespace {
 /// Pattern to fold unit-trip count loops in GenericOps.
-template <typename GenericOpTy>
-struct FoldUnitDimLoops : public OpRewritePattern<GenericOpTy> {
-  using OpRewritePattern<GenericOpTy>::OpRewritePattern;
-  LogicalResult matchAndRewrite(GenericOpTy op,
+struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
-    SmallVector<AffineMap, 4> indexingMaps = op.getIndexingMaps();
+    SmallVector<AffineMap, 4> indexingMaps = genericOp.getIndexingMaps();
     if (indexingMaps.empty())
       return failure();
 
@@ -213,7 +187,7 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOpTy> {
     if (!invertedMap)
       return failure();
     SmallVector<int64_t, 4> dims;
-    for (ShapedType shapedType : op.getShapedOperandTypes())
+    for (ShapedType shapedType : genericOp.getShapedOperandTypes())
       dims.append(shapedType.getShape().begin(), shapedType.getShape().end());
 
     // Find all the reduction iterators. Those need some special consideration
@@ -221,7 +195,7 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOpTy> {
     auto getLoopDimsOfType =
         [&](StringRef iteratorTypeName) -> SmallVector<unsigned, 4> {
       SmallVector<AffineExpr> dimExprs;
-      getDimsOfType(op, iteratorTypeName, dimExprs);
+      getDimsOfType(genericOp, iteratorTypeName, dimExprs);
       return llvm::to_vector<4>(llvm::map_range(dimExprs, [](AffineExpr expr) {
         return expr.cast<AffineDimExpr>().getPosition();
       }));
@@ -230,7 +204,7 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOpTy> {
 
     DenseSet<unsigned> unitDims;
     SmallVector<unsigned, 4> unitDimsReductionLoops;
-    ArrayAttr iteratorTypes = op.iterator_types();
+    ArrayAttr iteratorTypes = genericOp.iterator_types();
     for (auto expr : enumerate(invertedMap.getResults())) {
       if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
         if (dims[dimExpr.getPosition()] == 1) {
@@ -260,7 +234,7 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOpTy> {
     ArrayAttr newIndexingMapAttr =
         replaceUnitDims(unitDims, indexingMaps, context);
     if (!newIndexingMapAttr)
-      return op.emitError("unable to compute modified indexing_maps");
+      return genericOp.emitError("unable to compute modified indexing_maps");
 
     // Compute the iterator types of the modified op by dropping the one-trip
     // count loops.
@@ -270,12 +244,11 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOpTy> {
         newIteratorTypes.push_back(attr.value());
     }
 
-    rewriter.startRootUpdate(op);
-    op.indexing_mapsAttr(newIndexingMapAttr);
-    op.iterator_typesAttr(ArrayAttr::get(context, newIteratorTypes));
-    (void)replaceBlockArgForUnitDimLoops(op, unitDims, rewriter);
-    replaceUnitDimIndexOps(op, unitDims, rewriter);
-    rewriter.finalizeRootUpdate(op);
+    rewriter.startRootUpdate(genericOp);
+    genericOp.indexing_mapsAttr(newIndexingMapAttr);
+    genericOp.iterator_typesAttr(ArrayAttr::get(context, newIteratorTypes));
+    replaceUnitDimIndexOps(genericOp, unitDims, rewriter);
+    rewriter.finalizeRootUpdate(genericOp);
     return success();
   }
 };
@@ -351,23 +324,22 @@ convertAffineMapArrayToExprs(ArrayAttr affineMapArrayAttr) {
 }
 
 /// Pattern to replace tensors operands/results that are unit extents.
-template <typename GenericOpTy>
-struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOpTy> {
-  using OpRewritePattern<GenericOpTy>::OpRewritePattern;
-  LogicalResult matchAndRewrite(GenericOpTy op,
+struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
-    if (!op.hasTensorSemantics())
+    if (!genericOp.hasTensorSemantics())
       return failure();
 
     MLIRContext *context = rewriter.getContext();
-    Location loc = op.getLoc();
+    Location loc = genericOp.getLoc();
 
     SmallVector<AffineMap, 4> newIndexingMaps;
     SmallVector<ArrayAttr, 4> reassociationMaps;
     SmallVector<ShapedType, 4> newInputOutputTypes;
     bool doCanonicalization = false;
-    for (auto it :
-         llvm::zip(op.getIndexingMaps(), op.getShapedOperandTypes())) {
+    for (auto it : llvm::zip(genericOp.getIndexingMaps(),
+                             genericOp.getShapedOperandTypes())) {
       auto replacementInfo = replaceUnitExtents(
           std::get<0>(it), std::get<1>(it).template cast<RankedTensorType>(),
           context);
@@ -402,20 +374,20 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOpTy> {
       return res;
     };
 
-    SmallVector<Value, 4> newInputs = insertReshapes(op.inputs());
-    SmallVector<Value, 4> newOutputs = insertReshapes(op.outputs());
+    SmallVector<Value, 4> newInputs = insertReshapes(genericOp.inputs());
+    SmallVector<Value, 4> newOutputs = insertReshapes(genericOp.outputs());
 
     // If any result type changes, insert a reshape to convert from the original
     // type to the new type.
     SmallVector<Type, 4> resultTypes;
-    resultTypes.reserve(op.getNumResults());
-    for (unsigned i : llvm::seq<unsigned>(0, op.getNumResults()))
-      resultTypes.push_back(newInputOutputTypes[i + op.getNumInputs()]);
-    GenericOpTy replacementOp = rewriter.create<GenericOpTy>(
+    resultTypes.reserve(genericOp.getNumResults());
+    for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
+      resultTypes.push_back(newInputOutputTypes[i + genericOp.getNumInputs()]);
+    GenericOp replacementOp = rewriter.create<GenericOp>(
         loc, resultTypes, newInputs, newOutputs, newIndexingMaps,
         llvm::to_vector<4>(
-            op.iterator_types().template getAsValueRange<StringAttr>()));
-    rewriter.inlineRegionBefore(op.region(), replacementOp.region(),
+            genericOp.iterator_types().template getAsValueRange<StringAttr>()));
+    rewriter.inlineRegionBefore(genericOp.region(), replacementOp.region(),
                                 replacementOp.region().begin());
 
     // If any result tensor has a modified shape, then add reshape to recover
@@ -423,7 +395,7 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOpTy> {
     SmallVector<Value, 4> resultReplacements;
     for (auto result : llvm::enumerate(replacementOp.getResults())) {
       unsigned index = result.index() + replacementOp.getNumInputs();
-      RankedTensorType origResultType = op.getResult(result.index())
+      RankedTensorType origResultType = genericOp.getResult(result.index())
                                             .getType()
                                             .template cast<RankedTensorType>();
       if (origResultType != result.value().getType())
@@ -433,7 +405,7 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOpTy> {
       else
         resultReplacements.push_back(result.value());
     }
-    rewriter.replaceOp(op, resultReplacements);
+    rewriter.replaceOp(genericOp, resultReplacements);
     return success();
   }
 };
@@ -528,9 +500,7 @@ struct UseRankReducedSubTensorInsertOp
 void mlir::linalg::populateFoldUnitExtentDimsPatterns(
     RewritePatternSet &patterns) {
   auto *context = patterns.getContext();
-  patterns.add<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>,
-               ReplaceUnitExtentTensors<GenericOp>,
-               ReplaceUnitExtentTensors<IndexedGenericOp>,
+  patterns.add<FoldUnitDimLoops, ReplaceUnitExtentTensors,
                UseRankReducedSubTensorOp, UseRankReducedSubTensorInsertOp>(
       context);
   TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
@@ -545,9 +515,7 @@ struct LinalgFoldUnitExtentDimsPass
     MLIRContext *context = funcOp.getContext();
     RewritePatternSet patterns(context);
     if (foldOneTripLoopsOnly)
-      patterns
-          .add<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>>(
-              context);
+      patterns.add<FoldUnitDimLoops>(context);
     else
       populateFoldUnitExtentDimsPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
index 808622b..5bc11f2 100644 (file)
@@ -42,48 +42,6 @@ func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>, %shape: tensor<?x1x?x1x?xf3
   library_call = "some_external_func"
 }
 
-func @drop_one_trip_loops_indexed_generic
-  (%arg0 : tensor<?x1x?xi32>, %shape: tensor<?x1x?x1x?xi32>) -> tensor<?x1x?x1x?xi32>
-{
-  %0 = linalg.indexed_generic #trait
-     ins(%arg0 : tensor<?x1x?xi32>)
-    outs(%shape: tensor<?x1x?x1x?xi32>) {
-       ^bb0(%arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index,
-            %arg5 : index, %arg6 : i32, %arg7 : i32) :
-         %1 = addi %arg1, %arg2 : index
-         %2 = addi %1, %arg3 : index
-         %3 = addi %2, %arg4 : index
-         %4 = addi %3, %arg5 : index
-         %5 = index_cast %4 : index to i32
-         %6 = addi %5, %arg6 : i32
-         linalg.yield %6 : i32
-       } -> tensor<?x1x?x1x?xi32>
-  return %0 : tensor<?x1x?x1x?xi32>
-}
-// CHECK-LABEL: func @drop_one_trip_loops_indexed_generic
-//       CHECK:   linalg.indexed_generic
-//       CHECK:   ^{{.+}}(
-//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index, %[[ARG2:[a-zA-Z0-9]+]]: index
-//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index, %[[ARG4:[a-zA-Z0-9]+]]: i32, %{{.*}}: i32)
-//       CHECK:     %[[T3:.+]] = addi %[[ARG1]], %[[ARG2]]
-//       CHECK:     %[[T4:.+]] = addi %[[T3]], %[[ARG3]]
-//       CHECK:     %[[T5:.+]] = index_cast %[[T4]] : index to i32
-//       CHECK:     %[[T6:.+]] = addi %[[T5]], %[[ARG4]] : i32
-//       CHECK:     linalg.yield %[[T6]] : i32
-
-// -----
-
-#accesses = [
-  affine_map<(i, j, k, l, m) -> (i, k, m)>,
-  affine_map<(i, j, k, l, m) -> (i, k, j, l, m)>
-]
-
-#trait = {
-  iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
-  indexing_maps = #accesses,
-  library_call = "some_external_func"
-}
-
 func @drop_one_trip_loops_indexed
   (%arg0 : tensor<?x1x?xi32>, %shape: tensor<?x1x?x1x?xi32>) -> tensor<?x1x?x1x?xi32>
 {
@@ -158,35 +116,6 @@ func @drop_all_loops(%arg0 : tensor<1x1xf32>) -> tensor<1x1xf32>
   library_call = "some_external_func"
 }
 
-func @drop_all_loops_indexed_generic
-  (%arg0 : tensor<1x1xi32>) -> tensor<1x1xi32>{
-  %0 = linalg.indexed_generic #trait
-     ins(%arg0 : tensor<1x1xi32>)
-    outs(%arg0 : tensor<1x1xi32>) {
-       ^bb0(%arg1 : index, %arg2 : index, %arg3: i32, %arg4: i32) :
-         %1 = addi %arg1, %arg2 : index
-         %2 = index_cast %1 : index to i32
-         %3 = addi %2, %arg3 : i32
-         linalg.yield %3 : i32
-       } -> tensor<1x1xi32>
-  return %0 : tensor<1x1xi32>
-}
-
-// CHECK-LABEL: func @drop_all_loops_indexed_generic
-//       CHECK:   linalg.indexed_generic
-//       CHECK:   ^{{.+}}(%[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32)
-//       CHECK:     linalg.yield %[[ARG1]] : i32
-
-// -----
-
-#map0 = affine_map<(i, j) -> (i, j)>
-#access = [#map0, #map0]
-#trait = {
-  iterator_types = ["parallel", "parallel"],
-  indexing_maps = #access,
-  library_call = "some_external_func"
-}
-
 func @drop_all_loops_indexed
   (%arg0 : tensor<1x1xi32>) -> tensor<1x1xi32>{
   %0 = linalg.generic #trait