[mlir][linalg] Fixed issue generating reassociation map with Rank-0 types
authorRob Suderman <rob.suderman@gmail.com>
Tue, 11 May 2021 22:31:20 +0000 (15:31 -0700)
committerRob Suderman <rob.suderman@gmail.com>
Wed, 12 May 2021 18:00:51 +0000 (11:00 -0700)
Rank-0 case causes a graph during linalg reshape operation.

Differential Revision: https://reviews.llvm.org/D102282

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir

index 5f635c7..1801f27 100644 (file)
@@ -1134,10 +1134,12 @@ mlir::linalg::getReassociationIndicesForReshape(ShapedType sourceType,
       return llvm::None;
 
     currIndices.push_back(sourceDim++);
-    // If there are no dimensions in the target to match, then append the
-    // `currIndices` to the last element of the reassociationMap.
+    // If the reassociation is empty but the currIndices is not, this by
+    // definition is folding unit-dimensions with the result being scalar type.
+    // So only append the `currIndices` if reassociation map is not empty.
     if (targetDim == targetShape.size()) {
-      reassociationMap.back().append(currIndices.begin(), currIndices.end());
+      if (!reassociationMap.empty() && !currIndices.empty())
+        reassociationMap.back().append(currIndices.begin(), currIndices.end());
       // Break out of the loops. We should be done here.
       break;
     }
index 536e361..cb5f54a 100644 (file)
@@ -43,6 +43,17 @@ func @memref_cast_into_tiled_loop(%arg0: memref<192xf32>)  {
 
 // -----
 
+// CHECK-LABEL: zero_rank_reshape_multi
+func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
+  // CHECK: return %arg0
+  %0 = linalg.tensor_reshape %arg0 [] : tensor<f32> into tensor<1xf32>
+  %1 = linalg.tensor_reshape %0 [[0, 1]] : tensor<1xf32> into tensor<1x1xf32>
+  %2 = linalg.tensor_reshape %1 [] : tensor<1x1xf32> into tensor<f32>
+  return %2 : tensor<f32>
+}
+
+// -----
+
 func @collapsing_tensor_reshapes(%arg0 : tensor<?x?x?x?x?xf32>) -> tensor<?x?xf32>
 {
   %0 = linalg.tensor_reshape %arg0 [[0, 1], [2], [3, 4]]