[vector] When trimming leading insertion dimensions, base the final result on the...
authorBenjamin Kramer <benny.kra@googlemail.com>
Tue, 18 Apr 2023 14:53:17 +0000 (16:53 +0200)
committerBenjamin Kramer <benny.kra@googlemail.com>
Tue, 18 Apr 2023 16:49:29 +0000 (18:49 +0200)
This was incorrect when the number of dropped source dims was smaller
than the number of dropped dst dims. We still need to insert zeros if
there is anything dropped from the src.

Differential Revision: https://reviews.llvm.org/D148636

mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir

index 58fe63687b4fdd37c95824f99f551a0a81c39de7..849e0442bc7e1d51f51766133dc38a5c4b89a438 100644 (file)
@@ -169,10 +169,8 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
     unsigned newPosRank = std::max<int64_t>(0, oldPosRank - dstDropCount);
     SmallVector<Attribute> newPositions = llvm::to_vector(
         insertOp.getPosition().getValue().take_back(newPosRank));
-    if (srcDropCount >= dstDropCount) {
-      auto zeroAttr = rewriter.getZeroAttr(rewriter.getI64Type());
-      newPositions.resize(newPosRank + srcDropCount, zeroAttr);
-    }
+    newPositions.resize(newDstType.getRank() - newSrcRank,
+                        rewriter.getI64IntegerAttr(0));
 
     auto newInsertOp = rewriter.create<vector::InsertOp>(
         loc, newDstType, newSrcVector, newDstVector,
index 587633ac203491cf17a7ba5f62b1ff16d699545c..0ee006e3df632efd00420ad75fee8b8daa4389b3 100644 (file)
@@ -316,3 +316,15 @@ func.func @cast_away_insert_leading_one_dims_non_one_dest(%s: vector<1x4xf32>, %
   %0 = vector.insert %s, %v [5] : vector<1x4xf32> into vector<8x1x4xf32>
   return %0: vector<8x1x4xf32>
 }
+
+// CHECK-LABEL: func @cast_away_insert_leading_one_dims_one_two_dest
+//  CHECK-SAME: (%[[S:.+]]: vector<1x8xi1>, %[[V:.+]]: vector<1x1x8x1x8xi1>)
+//       CHECK:   %[[EXTRACTS:.+]] = vector.extract %[[S]][0] : vector<1x8xi1>
+//       CHECK:   %[[EXTRACTV:.+]] = vector.extract %[[V]][0, 0] : vector<1x1x8x1x8xi1>
+//       CHECK:   %[[INSERT:.+]] = vector.insert %[[EXTRACTS]], %[[EXTRACTV]] [7, 0] : vector<8xi1> into vector<8x1x8xi1>
+//       CHECK:   %[[BCAST:.+]] = vector.broadcast %[[INSERT]] : vector<8x1x8xi1> to vector<1x1x8x1x8xi1>
+//       CHECK:   return %[[BCAST]]
+func.func @cast_away_insert_leading_one_dims_one_two_dest(%s: vector<1x8xi1>, %v: vector<1x1x8x1x8xi1>) -> vector<1x1x8x1x8xi1> {
+  %0 = vector.insert %s, %v [0, 0, 7] : vector<1x8xi1> into vector<1x1x8x1x8xi1>
+  return %0: vector<1x1x8x1x8xi1>
+}