auto targetShape = linalgOp.getStaticLoopRanges();
auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
- // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
+ // 0.1 Is this a 0-D vector? If yes then this is a scalar broadcast.
if (inputShape.getShape().empty())
return VectorMemoryAccessKind::ScalarBroadcast;
+ // 0.2 In the case of dynamic shapes just bail-out and assume that it's a
+ // gather load.
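+ // (E.g. any op with a dynamically shaped operand, such as `tensor<1x?x8xf32>`.)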
+ // TODO: Relax this condition.
+ if (linalgOp.hasDynamicShape())
+ return VectorMemoryAccessKind::Gather;
// 1. Assume that it's a gather load when reading _into_:
// * an n-D vector, like `tensor<1x2x4xi32>` or `tensor<2x1x4xi32>`, or
// * a 1-D vector with the trailing dim equal 1, e.g. `tensor<1x4x1xi32>`.
// TODO: Relax these conditions.
+ // FIXME: This condition assumes non-dynamic sizes.
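+ // (I.e. a gather load is assumed unless the trailing dim is the only dim
+ // greater than 1.)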
if ((llvm::count_if(targetShape,
[](int64_t dimSize) { return dimSize > 1; }) != 1) ||
targetShape.back() == 1)
return VectorMemoryAccessKind::Gather;
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.structured.masked_vectorize %0 vector_sizes [3, 3] { vectorize_nd_extract } : !transform.any_op
}
+
+// -----
+
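+// The output operand has a dynamic (middle) dim, so the vectorizer should
+// conservatively lower the tensor.extract to a (masked) vector.gather.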
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @tensor_extract_dynamic_shape(%arg1: tensor<123x321xf32>, %arg2: tensor<1x?x8xf32>) -> tensor<1x?x8xf32> {
+ %c0 = arith.constant 1 : index
+ %c1 = arith.constant 2 : index
+ %2 = linalg.generic {
+ indexing_maps = [#map1],
+ iterator_types = ["parallel", "parallel", "parallel"]
+ } outs(%arg2 : tensor<1x?x8xf32>)
+ {
+ ^bb0(%arg3: f32):
+ %idx_0 = linalg.index 0 : index
+ %idx_1 = linalg.index 1 : index
+ %idx = arith.addi %idx_0, %idx_1 : index
+ %7 = tensor.extract %arg1[%c0, %idx] : tensor<123x321xf32>
+ linalg.yield %7 : f32
+ } -> tensor<1x?x8xf32>
+ return %2 : tensor<1x?x8xf32>
+}
+
+// TODO: Make sure that this is vectorized as "scalar broadcast" when only
+// the 2nd dimension is vectorized.
+// CHECK-LABEL: func.func @tensor_extract_dynamic_shape(
+// CHECK-SAME: %[[ARG_1:.*]]: tensor<123x321xf32>,
+// CHECK-SAME: %[[ARG_2:.*]]: tensor<1x?x8xf32>) -> tensor<1x?x8xf32> {
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[C1_1:.*]] = arith.constant 1 : index
+// CHECK: %[[C1_2:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM:.*]] = tensor.dim %[[ARG_2]], %[[C1_2]] : tensor<1x?x8xf32>
+// CHECK: %[[C8:.*]] = arith.constant 8 : index
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[C1_1]], %[[DIM]], %[[C8]] : vector<1x3x8xi1>
+// CHECK: %[[MASK_2:.*]] = arith.constant dense<true> : vector<1x3x8xi1>
+// CHECK: %[[FALLTHROUGH:.*]] = arith.constant dense<0.000000e+00> : vector<1x3x8xf32>
+// CHECK: %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK: vector.mask %[[MASK]] { vector.gather %[[ARG_1]][%[[C0_1]], %[[C0_1]]] [%{{.*}}], %[[MASK_2]], %[[FALLTHROUGH]] : tensor<123x321xf32>, vector<1x3x8xindex>, vector<1x3x8xi1>, vector<1x3x8xf32> into vector<1x3x8xf32> } : vector<1x3x8xi1> -> vector<1x3x8xf32>
+
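+// Note: the requested vector sizes are [1, 3, 8]; the middle size covers the
+// dynamic dim, which is why the expected output builds a mask from
+// `tensor.dim` (see the CHECK lines above).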
+transform.sequence failures(propagate) {
+ ^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.masked_vectorize %0 vector_sizes [1, 3, 8] { vectorize_nd_extract } : !transform.any_op
+}