[mlir][vector] Fold extract(broadcast) of same rank
authorLei Zhang <antiagainst@google.com>
Thu, 7 Apr 2022 16:59:09 +0000 (12:59 -0400)
committerLei Zhang <antiagainst@google.com>
Thu, 7 Apr 2022 16:59:54 +0000 (12:59 -0400)
This case is handled in neither the folding or canonicalization
patterns. The folding pattern cannot generate new broadcast ops,
so it should be handled by the canonicalization pattern.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D123307

mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir

index eda7739..07546c0 100644 (file)
@@ -1496,6 +1496,7 @@ public:
     Operation *defOp = extractOp.getVector().getDefiningOp();
     if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
       return failure();
+
     Value source = defOp->getOperand(0);
     if (extractOp.getType() == source.getType())
       return failure();
@@ -1504,10 +1505,10 @@ public:
     };
     unsigned broadcastSrcRank = getRank(source.getType());
     unsigned extractResultRank = getRank(extractOp.getType());
-    // We only consider the case where the rank of the source is smaller than
-    // the rank of the extract dst. The other cases are handled in the folding
-    // patterns.
-    if (extractResultRank <= broadcastSrcRank)
+    // We only consider the case where the rank of the source is less than or
+    // equal to the rank of the extract dst. The other cases are handled in the
+    // folding patterns.
+    if (extractResultRank < broadcastSrcRank)
       return failure();
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
         extractOp, extractOp.getType(), source);
index a083851..8b6640b 100644 (file)
@@ -566,6 +566,18 @@ func @fold_extract_broadcast(%a : f32) -> vector<4xf32> {
 
 // -----
 
+// CHECK-LABEL: fold_extract_broadcast
+//  CHECK-SAME:   %[[A:.*]]: vector<1xf32>
+//       CHECK:   %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<8xf32>
+//       CHECK:   return %[[R]] : vector<8xf32>
+func @fold_extract_broadcast(%a : vector<1xf32>) -> vector<8xf32> {
+  %b = vector.broadcast %a : vector<1xf32> to vector<1x8xf32>
+  %r = vector.extract %b[0] : vector<1x8xf32>
+  return %r : vector<8xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_extract_shapecast
 //  CHECK-SAME: (%[[A0:.*]]: vector<5x1x3x2xf32>, %[[A1:.*]]: vector<8x4x2xf32>
 //       CHECK:   %[[R0:.*]] = vector.extract %[[A0]][1, 0, 1, 1] : vector<5x1x3x2xf32>