[mlir] Handle an edge case when folding reshapes with multiple trailing 1 dimensions
authorBenjamin Kramer <benny.kra@googlemail.com>
Mon, 29 Nov 2021 15:52:37 +0000 (16:52 +0100)
committerBenjamin Kramer <benny.kra@googlemail.com>
Mon, 29 Nov 2021 17:31:43 +0000 (18:31 +0100)
We would exit early and miss this case.

Differential Revision: https://reviews.llvm.org/D114711

mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir

index 919415a..f1acfb3 100644 (file)
@@ -73,6 +73,8 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
     // definition is folding unit-dimensions with the result being scalar type.
     // So only append the `currIndices` if reassociation map is not empty.
     if (targetDim == targetShape.size()) {
+      while (sourceDim < sourceShape.size())
+        currIndices.push_back(sourceDim++);
       if (!reassociationMap.empty() && !currIndices.empty())
         reassociationMap.back().append(currIndices.begin(), currIndices.end());
       // Break out of the loops. We should be done here.
index e938f8c..a64bd64 100644 (file)
@@ -279,6 +279,20 @@ func @fold_reshape_trailing_unit_dims_dynamic(%arg0: tensor<1x1x?x1x1x1xf32>) ->
 
 // -----
 
+func @fold_reshape_trailing_unit_dims(%arg0: tensor<12x42x1x1xf32>) -> tensor<12x42xf32>
+{
+  %0 = linalg.tensor_expand_shape %arg0 [[0], [1], [2], [3, 4]]
+      : tensor<12x42x1x1xf32> into tensor<12x42x1x1x1xf32>
+  %1 = linalg.tensor_collapse_shape %0 [[0], [1, 2, 3, 4]]
+      : tensor<12x42x1x1x1xf32> into tensor<12x42xf32>
+  return %1 : tensor<12x42xf32>
+}
+//       CHECK: func @fold_reshape_trailing_unit_dims
+//       CHECK: linalg.tensor_collapse_shape %{{.*}} {{\[}}[0], [1, 2, 3]]
+//  CHECK-SAME:   tensor<12x42x1x1xf32> into tensor<12x42xf32>
+
+// -----
+
 func @no_fold_reshapes(%arg0 : tensor<?x?x?xf32>) -> tensor<?x?xf32>
 {
   %0 = linalg.tensor_expand_shape %arg0 [[0], [1], [2, 3]]