// folding patterns.
if (extractResultRank < broadcastSrcRank)
return failure();
+
+ // Special case if broadcast src is a 0D vector.
+ if (extractResultRank == 0) {
+ assert(broadcastSrcRank == 0 && source.getType().isa<VectorType>());
+ rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(extractOp, source);
+ return success();
+ }
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
extractOp, extractOp.getType(), source);
return success();
// -----
+// CHECK-LABEL: fold_extract_broadcast_0dvec
+// CHECK-SAME: %[[A:.*]]: vector<f32>
+// CHECK: %[[B:.+]] = vector.extractelement %[[A]][] : vector<f32>
+// CHECK: return %[[B]] : f32
+func.func @fold_extract_broadcast_0dvec(%a : vector<f32>) -> f32 {
+ %b = vector.broadcast %a : vector<f32> to vector<1x2x4xf32>
+ %r = vector.extract %b[0, 1, 2] : vector<1x2x4xf32>
+ return %r : f32
+}
+
+// -----
+
// CHECK-LABEL: fold_extract_broadcast_negative
// CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32>
// CHECK: vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32>