Operation *defOp = extractOp.getVector().getDefiningOp();
if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
return failure();
+
Value source = defOp->getOperand(0);
if (extractOp.getType() == source.getType())
return failure();
};
unsigned broadcastSrcRank = getRank(source.getType());
unsigned extractResultRank = getRank(extractOp.getType());
- // We only consider the case where the rank of the source is smaller than
- // the rank of the extract dst. The other cases are handled in the folding
- // patterns.
- if (extractResultRank <= broadcastSrcRank)
+ // We only consider the case where the rank of the source is less than or
+ // equal to the rank of the extract dst. The other cases are handled in the
+ // folding patterns.
+ if (extractResultRank < broadcastSrcRank)
return failure();
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
extractOp, extractOp.getType(), source);
// -----
+// CHECK-LABEL: fold_extract_broadcast
+// CHECK-SAME: %[[A:.*]]: vector<1xf32>
+// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<8xf32>
+// CHECK: return %[[R]] : vector<8xf32>
+func @fold_extract_broadcast(%a : vector<1xf32>) -> vector<8xf32> {
+ %b = vector.broadcast %a : vector<1xf32> to vector<1x8xf32>
+ %r = vector.extract %b[0] : vector<1x8xf32>
+ return %r : vector<8xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @fold_extract_shapecast
// CHECK-SAME: (%[[A0:.*]]: vector<5x1x3x2xf32>, %[[A1:.*]]: vector<8x4x2xf32>
// CHECK: %[[R0:.*]] = vector.extract %[[A0]][1, 0, 1, 1] : vector<5x1x3x2xf32>