return Value();
}
+/// Fold an ExtractOp whose vector operand is produced by a BroadcastOp.
+///
+/// Two cases are handled:
+///  * The extracted type is exactly the broadcast source type: the broadcast
+///    source is returned directly.
+///  * The extracted value has a lower rank than the broadcast source: the
+///    extract is updated in place to read from the broadcast source, keeping
+///    only the trailing positions, i.e. the ones that index into the source.
+///    Positions that address a source dimension of size 1 are reset to 0,
+///    since such a dimension may have been stretched by the broadcast and the
+///    original position could be out of bounds in the source.
+///
+/// Returns a null Value when no fold applies.
+static Value foldExtractFromBroadcast(ExtractOp extractOp) {
+  auto broadcastOp = extractOp.vector().getDefiningOp<vector::BroadcastOp>();
+  if (!broadcastOp)
+    return Value();
+  if (extractOp.getType() == broadcastOp.getSourceType())
+    return broadcastOp.source();
+  auto getRank = [](Type type) {
+    return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
+  };
+  unsigned broadcastSrcRank = getRank(broadcastOp.getSourceType());
+  unsigned extractResultRank = getRank(extractOp.getType());
+  if (extractResultRank < broadcastSrcRank) {
+    // The source rank is strictly positive here, so the source is a vector.
+    ArrayRef<int64_t> srcShape =
+        broadcastOp.getSourceType().cast<VectorType>().getShape();
+    // The dimensions that remain in the extract result must not have been
+    // stretched by the broadcast, otherwise extracting from the source would
+    // produce a value of a different type than the current result.
+    if (auto resultVecType = extractOp.getType().dyn_cast<VectorType>())
+      if (resultVecType.getShape() != srcShape.take_back(extractResultRank))
+        return Value();
+    auto extractPos = extractVector<int64_t>(extractOp.position());
+    unsigned rankDiff = broadcastSrcRank - extractResultRank;
+    // Drop the leading positions: they index dimensions that only exist in
+    // the broadcast result. `rankDiff` positions remain.
+    extractPos.erase(
+        extractPos.begin(),
+        std::next(extractPos.begin(), extractPos.size() - rankDiff));
+    // A unit source dimension may have been stretched; its only valid index
+    // in the source is 0.
+    for (unsigned i = 0, e = extractPos.size(); i < e; ++i)
+      if (srcShape[i] == 1)
+        extractPos[i] = 0;
+    extractOp.setOperand(broadcastOp.source());
+    // OpBuilder is only used as a helper to build an I64ArrayAttr.
+    OpBuilder b(extractOp.getContext());
+    extractOp.setAttr(ExtractOp::getPositionAttrName(),
+                      b.getI64ArrayAttr(extractPos));
+    return extractOp.getResult();
+  }
+  // TODO: In case the rank of the broadcast source is lower than the rank of
+  // the extract result this can be combined into a new broadcast op. This
+  // needs to be added as a canonicalization pattern if needed.
+  return Value();
+}
+
OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
if (succeeded(foldExtractOpFromExtractChain(*this)))
return getResult();
return getResult();
if (auto val = foldExtractOpFromInsertChainAndTranspose(*this))
return val;
+ if (auto val = foldExtractFromBroadcast(*this))
+ return val;
return OpFoldResult();
}
// -----
+// Extracting a single element from a broadcast scalar folds to that scalar.
+// CHECK-LABEL: fold_extract_broadcast
+//  CHECK-SAME:   %[[SRC:.*]]: f32
+//       CHECK:   return %[[SRC]] : f32
+func @fold_extract_broadcast(%src : f32) -> f32 {
+  %bcast = vector.broadcast %src : f32 to vector<1x2x4xf32>
+  %elem = vector.extract %bcast[0, 1, 2] : vector<1x2x4xf32>
+  return %elem : f32
+}
+
+// -----
+
+// Extracting a sub-vector whose type equals the broadcast source type folds
+// to the broadcast source.
+// CHECK-LABEL: fold_extract_broadcast_vector
+//  CHECK-SAME:   %[[SRC:.*]]: vector<4xf32>
+//       CHECK:   return %[[SRC]] : vector<4xf32>
+func @fold_extract_broadcast_vector(%src : vector<4xf32>) -> vector<4xf32> {
+  %bcast = vector.broadcast %src : vector<4xf32> to vector<1x2x4xf32>
+  %slice = vector.extract %bcast[0, 1] : vector<1x2x4xf32>
+  return %slice : vector<4xf32>
+}
+
+// -----
+
+// Extracting a scalar from a broadcast vector is rewritten into an extract
+// from the broadcast source, keeping only the position inside the source.
+// CHECK-LABEL: fold_extract_broadcast_scalar_from_vector
+// CHECK-SAME: %[[A:.*]]: vector<4xf32>
+// CHECK: %[[R:.*]] = vector.extract %[[A]][2] : vector<4xf32>
+// CHECK: return %[[R]] : f32
+func @fold_extract_broadcast_scalar_from_vector(%a : vector<4xf32>) -> f32 {
+  %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
+  %r = vector.extract %b[0, 1, 2] : vector<1x2x4xf32>
+  return %r : f32
+}
+
+// -----
+
+// Negative test: the extracted type (vector<4xf32>) differs from the
+// broadcast source type (f32) and has a higher rank, so the extract must not
+// be folded.
+// CHECK-LABEL: fold_extract_broadcast_negative
+//       CHECK:   %[[BCAST:.*]] = vector.broadcast %{{.*}} : f32 to vector<1x2x4xf32>
+//       CHECK:   %[[SLICE:.*]] = vector.extract %[[BCAST]][0, 1] : vector<1x2x4xf32>
+//       CHECK:   return %[[SLICE]] : vector<4xf32>
+func @fold_extract_broadcast_negative(%src : f32) -> vector<4xf32> {
+  %bcast = vector.broadcast %src : f32 to vector<1x2x4xf32>
+  %slice = vector.extract %bcast[0, 1] : vector<1x2x4xf32>
+  return %slice : vector<4xf32>
+}
+
+// -----
+
// CHECK-LABEL: fold_vector_transfers
func @fold_vector_transfers(%A: memref<?x8xf32>) -> (vector<4x8xf32>, vector<4x9xf32>) {
%c0 = constant 0 : index