if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
return failure();
- if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
- return failure();
+ // Check the index type, but only for non-0-d tensors (which do need
+ // access indices).
+ if (not extractOp.getIndices().empty()) {
+ if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
+ return failure();
+ }
if (llvm::any_of(extractOp->getResultTypes(), [](Type type) {
return !VectorType::isValidElementType(type);
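For context on the guarded check added above: a 0-d `tensor.extract` takes no indices at all, so `getIndices()` is empty and peeking at `getIndices()[0]` would be invalid. A minimal sketch of the two shapes of op involved (the function and value names below are invented for illustration):

    func.func @extract_examples(%src: tensor<8xf32>, %scalar: tensor<f32>, %i: index) -> (f32, f32) {
      // 1-D (or n-D) extract: carries indices whose types the precondition checks.
      %x = tensor.extract %src[%i] : tensor<8xf32>
      // 0-D extract: no indices, so the index-type check has to be skipped.
      %y = tensor.extract %scalar[] : tensor<f32>
      return %x, %y : f32, f32
    }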
LinalgOp &linalgOp) {
auto targetShape = linalgOp.getStaticLoopRanges();
+ auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
+
+ // 0. Is the source a 0-D tensor? If so, this is a scalar broadcast.
+ if (inputShape.getShape().empty())
+ return VectorMemoryAccessKind::ScalarBroadcast;
+
// 1. Assume that it's a gather load when reading _into_:
// * an n-D vector, like `tensor<1x2x4xi32>` or `tensor<2x1x4xi32>`, or
targetShape.back() == 1)
return VectorMemoryAccessKind::Gather;
- auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
// 2. Assume that it's a gather load when reading _from_ a tensor for which
// the trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
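To contrast with the new 0-D early return, here is a sketch (invented function and value names, not part of the patch) of the kind of op that heuristic 1 above would most likely treat as a gather load: the extract feeds a genuinely multi-dimensional iteration space, so each result element may need its own index:

    #map = affine_map<(d0, d1) -> (d0, d1)>
    func.func @gather_like(%src: tensor<2x4xf32>, %out: tensor<2x4xf32>) -> tensor<2x4xf32> {
      %0 = linalg.generic {
        indexing_maps = [#map],
        iterator_types = ["parallel", "parallel"]
      } outs(%out : tensor<2x4xf32>) {
      ^bb0(%o: f32):
        %i = linalg.index 0 : index
        %j = linalg.index 1 : index
        %v = tensor.extract %src[%i, %j] : tensor<2x4xf32>
        linalg.yield %v : f32
      } -> tensor<2x4xf32>
      return %0 : tensor<2x4xf32>
    }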
%1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op
}
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @vectorize_0d_tensor_extract(%arg0: tensor<f32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
+ %2 = linalg.generic {
+ indexing_maps = [#map1],
+ iterator_types = ["parallel", "parallel", "parallel"]
+ } outs(%arg2 : tensor<1x1x3xf32>) {
+ ^bb0(%arg4: f32):
+ %7 = tensor.extract %arg0[] : tensor<f32>
+ linalg.yield %7 : f32
+ } -> tensor<1x1x3xf32>
+ return %2 : tensor<1x1x3xf32>
+}
+
+// CHECK-LABEL: func.func @vectorize_0d_tensor_extract(
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<f32>
+// CHECK: %[[EXTRACT:.*]] = tensor.extract %[[ARG_0]][] : tensor<f32>
+// CHECK: vector.broadcast %[[EXTRACT]] : f32 to vector<1x1x3xf32>
+
+transform.sequence failures(propagate) {
+ ^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op
+ }