[mlir][Linalg] Generalize the logic to compute reassociation maps while folding tensor_reshape op.

author    Mahesh Ravishankar <ravishankarm@google.com>
          Tue, 29 Sep 2020 23:14:49 +0000 (16:14 -0700)
committer MaheshRavishankar <ravishankarm@google.com>
          Wed, 30 Sep 2020 14:58:06 +0000 (07:58 -0700)

While folding reshapes that introduce unit-extent dims, the logic to
compute the reassociation maps can be generalized to handle some
corner cases, for example when the folded shape still has unit-extent
dims that correspond to folded unit-extent dims of the expanded shape.

Differential Revision: https://reviews.llvm.org/D88521
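
For reference, a minimal standalone sketch of the walk the new code performs,
assuming static shapes. The helper name and the plain std::vector types are
illustrative stand-ins for the MLIR ArrayRef/AffineExpr machinery, not part
of the patch:

  #include <cstdint>
  #include <optional>
  #include <vector>

  using Group = std::vector<unsigned>;

  // Greedily group dims of `expanded` so that each group's product equals
  // the corresponding dim of `folded`; trailing unit dims of `expanded` are
  // absorbed into the current group unless the next folded dim is also 1.
  static std::optional<std::vector<Group>>
  computeReassociation(const std::vector<int64_t> &expanded,
                       const std::vector<int64_t> &folded) {
    std::vector<Group> groups(folded.size());
    unsigned e = 0, f = 0;
    while (e < expanded.size() && f < folded.size()) {
      int64_t dstSize = folded[f];
      int64_t srcSize = expanded[e];
      // Consume expanded dims until their product reaches the folded dim.
      while (srcSize < dstSize && e + 1 < expanded.size()) {
        groups[f].push_back(e++);
        srcSize *= expanded[e];
      }
      if (srcSize != dstSize)
        return std::nullopt; // Shapes cannot be reconciled.
      groups[f].push_back(e++);
      // Absorb trailing unit dims unless the next folded dim is itself 1.
      if (f + 1 == folded.size() || folded[f + 1] != 1)
        while (e < expanded.size() && expanded[e] == 1)
          groups[f].push_back(e++);
      ++f;
    }
    if (e != expanded.size() || f != folded.size())
      return std::nullopt; // Leftover dims on either side.
    return groups;
  }

On the new test below, expanded = {2, 1} and folded = {2} (the expanding
case), and the walk yields the single group {0, 1}, i.e. the
(d0, d1) -> (d0, d1) reassociation map the CHECK lines expect.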

mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir

index 08e7e35..611c938 100644
@@ -403,61 +403,58 @@ struct FoldReshapeOpWithUnitExtent : OpRewritePattern<TensorReshapeOp> {
         srcType.getRank() < dstType.getRank() ||
         parentSrcType.getRank() == dstType.getRank())
       return failure();
+
     // Check whether the result tensor_reshape, after the reshapeOp and
     // parentReshapeOp are combined, is folding or expanding.
     // If the final tensor_reshape is folding, the parentReshapeOp is
     // introducing unit-dims, and the reshapeOp does an actual reshape.
-    // If the final tensor_reshape op is expanding, the reshapeOp is introducing
-    // unit-dims, and the parentReshapeOp does an actual reshape.
+    // If the final tensor_reshape op is expanding, the reshapeOp is
+    // introducing unit-dims, and the parentReshapeOp does an actual reshape.
     bool isFoldingPattern = parentSrcType.getRank() > dstType.getRank();
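+    // E.g., with tensor<2xf32> reshaped to tensor<2x1x1xf32> and then to
+    // tensor<2x1xf32> (the new test below), parentSrcType has rank 1 and
+    // dstType has rank 2, so the combined reshape is expanding.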
-    auto reassociationMaps = isFoldingPattern
-                                 ? reshapeOp.getReassociationMaps()
-                                 : parentReshapeOp.getReassociationMaps();
-    DenseSet<unsigned> conservedDimensions;
-    for (auto &map : reassociationMaps) {
-      if (map.getNumResults() == 1) {
-        conservedDimensions.insert(
-            map.getResult(0).cast<AffineDimExpr>().getPosition());
-      }
-    }
-
-    // Find positions at which the unit-dims exist.
-    int64_t nonUnitDimPos = 0;
-    DenseMap<unsigned, unsigned> nonUnitSrcDims;
-    ArrayRef<int64_t> nonUnitShape =
+    ArrayRef<int64_t> expandedShape =
         isFoldingPattern ? parentSrcType.getShape() : dstType.getShape();
-    for (auto shape : enumerate(srcType.getShape())) {
-      // Case 1 : It is a conserved dimension.
-      if (conservedDimensions.count(shape.index())) {
-        nonUnitSrcDims[shape.index()] = nonUnitDimPos++;
-        continue;
+    ArrayRef<int64_t> foldedShape =
+        isFoldingPattern ? dstType.getShape() : parentSrcType.getShape();
+
+    unsigned expandedDim = 0, foldedDim = 0;
+    SmallVector<SmallVector<AffineExpr, 4>, 4> reassociationExprs(
+        foldedShape.size());
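+    // Walk the two shapes in lockstep: for each folded dim, greedily group
+    // expanded dims until their product matches it, then absorb trailing
+    // unit dims of the expanded shape into the same group.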
+    while (expandedDim < expandedShape.size() &&
+           foldedDim < foldedShape.size()) {
+      int64_t dstSize = foldedShape[foldedDim];
+      int64_t srcSize = expandedShape[expandedDim];
+      while (srcSize < dstSize && expandedDim + 1 < expandedShape.size()) {
+        reassociationExprs[foldedDim].push_back(
+            rewriter.getAffineDimExpr(expandedDim++));
+        srcSize *= expandedShape[expandedDim];
       }
-      // Case 2 : Dimensions dont match but the intermediate tensor is unit-dim.
-      if (shape.value() == 1)
-        continue;
-      // Case 3 : Dimensions match, treat it as a non-unit src dim.
-      if (nonUnitDimPos < static_cast<int64_t>(nonUnitShape.size()) &&
-          nonUnitShape[nonUnitDimPos] == shape.value()) {
-        nonUnitSrcDims[shape.index()] = nonUnitDimPos++;
-        continue;
+      if (srcSize == dstSize) {
+        reassociationExprs[foldedDim].push_back(
+            rewriter.getAffineDimExpr(expandedDim++));
+        // If the next dim in foldedShape is not 1, treat subsequent dims in
+        // expandedShape which are 1 to be collapsed.
+        if (foldedDim == foldedShape.size() - 1 ||
+            foldedShape[foldedDim + 1] != 1) {
+          while (expandedDim < expandedShape.size() &&
+                 expandedShape[expandedDim] == 1) {
+            reassociationExprs[foldedDim].push_back(
+                rewriter.getAffineDimExpr(expandedDim++));
+          }
+        }
+      } else {
+        return failure();
       }
-      return failure();
+      foldedDim++;
     }
+    if (expandedDim != expandedShape.size() ||
+        foldedDim != foldedShape.size())
+      return failure();
 
-    // Compute reassociation maps for the final operation. Use the reassociation
-    // maps that is actually doing a reshape (and not just introducing
-    // unit-dims). From these maps, prune the unit-extent dimensions.
-    for (AffineMap &map : reassociationMaps) {
-      SmallVector<AffineExpr, 4> exprs;
-      exprs.reserve(nonUnitSrcDims.size());
-      for (auto result : map.getResults()) {
-        unsigned dim = result.cast<AffineDimExpr>().getPosition();
-        if (nonUnitSrcDims.count(dim))
-          exprs.push_back(rewriter.getAffineDimExpr(nonUnitSrcDims[dim]));
-      }
-      map = AffineMap::get(nonUnitSrcDims.size(), 0, exprs,
-                           rewriter.getContext());
-    }
+    SmallVector<AffineMap, 4> reassociationMaps =
+        llvm::to_vector<4>(llvm::map_range(
+            reassociationExprs, [&](ArrayRef<AffineExpr> exprs) -> AffineMap {
+              return AffineMap::get(expandedShape.size(), 0, exprs,
+                                    rewriter.getContext());
+            }));
     rewriter.replaceOpWithNewOp<TensorReshapeOp>(
         reshapeOp, dstType, parentReshapeOp.src(),
         rewriter.getAffineMapArrayAttr(reassociationMaps));
index 06e56c5..1793d2b 100644
@@ -240,3 +240,19 @@ func @fold_reshape(%arg0 : tensor<2048x1x2048xf32>) -> tensor<4x512x1x512x4xf32>
     : tensor<1x4x1x512x1x1x512x1x4xf32> into tensor<4x512x1x512x4xf32>
   return %1 : tensor<4x512x1x512x4xf32>
 }
+
+// -----
+
+//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+//       CHECK: func @fold_reshape
+//       CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]]
+//  CHECK-SAME:   tensor<2xf32> into tensor<2x1xf32>
+func @fold_reshape(%arg0: tensor<2xf32>) -> tensor<2x1xf32>
+{
+  %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>] : tensor<2xf32> into tensor<2x1x1xf32>
+  %1 = linalg.tensor_reshape %0
+  [affine_map<(d0, d1, d2) -> (d0)>,
+   affine_map<(d0, d1, d2) -> (d1, d2)>
+  ] : tensor<2x1x1xf32> into tensor<2x1xf32>
+  return %1 : tensor<2x1xf32>
+}