const size_t numIndices = extractOp.getIndices().size();
for (size_t i = 1; i < numIndices; i++) {
- auto dimSizeBcast = b.create<vector::BroadcastOp>(
- loc, indexVecType,
+ auto dimSize = broadcastIfNeeded(
+ b,
b.create<arith::ConstantIndexOp>(
loc,
- extractOp.getTensor().getType().cast<ShapedType>().getDimSize(i)));
- offset = b.create<arith::MulIOp>(loc, offset, dimSizeBcast);
-
- auto originalIndexBcast = bvm.lookup(extractOp.getIndices()[i]);
- if (i == numIndices - 1) {
- // We only need an additional broadcast for the trailing index. All other
- // indices have already been broadcast by `vectorizeLinalgIndex` to match
- // the output size.
- originalIndexBcast = b.create<vector::BroadcastOp>(
- loc, indexVecType, bvm.lookup(extractOp.getIndices()[i]));
- }
+ extractOp.getTensor().getType().cast<ShapedType>().getDimSize(i)),
+ indexVecType.getShape());
+
+ offset = b.create<arith::MulIOp>(loc, offset, dimSize);
+
+ auto extractOpIndex = broadcastIfNeeded(
+ b, bvm.lookup(extractOp.getIndices()[i]), indexVecType.getShape());
- offset = b.create<arith::AddIOp>(loc, originalIndexBcast, offset);
+ offset = b.create<arith::AddIOp>(loc, extractOpIndex, offset);
}
return offset;
// -----
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// Test: tensor.extract whose position is given entirely by constants
+// (%arg0[%c0, %c1] with %c0 = 1, %c1 = 2) inside an all-parallel
+// linalg.generic. The extract does not use the loop indices, so after
+// vectorization the gather's index vector should fold to a constant
+// (see the CHECK lines that follow).
+// NOTE(review): the SSA names %c0/%c1 do not match the values they hold
+// (1 and 2); the linalg.index result %3 is also unused by the payload —
+// confirm both are intentional.
+func.func @vectorize_nd_tensor_extract_constant_idx(%arg0: tensor<3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
+  %c0 = arith.constant 1 : index
+  %c1 = arith.constant 2 : index
+  %2 = linalg.generic {
+    indexing_maps = [#map1],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } outs(%arg2 : tensor<1x1x3xf32>) {
+  ^bb0(%arg4: f32):
+    %3 = linalg.index 2 : index
+    %7 = tensor.extract %arg0[%c0, %c1] : tensor<3x3xf32>
+    linalg.yield %7 : f32
+  } -> tensor<1x1x3xf32>
+  return %2 : tensor<1x1x3xf32>
+}
+
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_constant_idx
+// CHECK-SAME: %[[ARG0:.*]]: tensor<3x3xf32>
+// CHECK-SAME: %[[ARG1:.*]]: tensor<1x1x3xf32>
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x3xi1>
+// CHECK: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x3xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// Magic "5" below comes from (1 * 3 + 2) — 1 is the index into dim 0 and 2
+// the index into dim 1 of the 3x3 source tensor (MLIR dims are 0-based).
+// CHECK: %[[IDX:.*]] = arith.constant dense<5> : vector<1x1x3xindex>
+// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]]] [%[[IDX]]], %[[MASK]], %[[PASSTHRU]] : tensor<3x3xf32>, vector<1x1x3xindex>, vector<1x1x3xi1>, vector<1x1x3xf32> into vector<1x1x3xf32>
+// CHECK: vector.transfer_write %[[GATHER]]
+// CHECK: }
+
+// Driver: match the linalg.generic above and vectorize its enclosing
+// function with n-D tensor.extract lowering enabled (vectorize_nd_extract).
+transform.sequence failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+    %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
+  }
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// Test: tensor.extract whose every index is a linalg.index result (the
+// iteration variables themselves). Vectorization must build the gather's
+// index vector from the (broadcast) iteration indices rather than from
+// constants — see the CHECK lines that follow.
+func.func @vectorize_nd_tensor_extract_idx_from_iteration_index(%arg0: tensor<3x3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
+  %1 = linalg.generic {
+    indexing_maps = [#map1],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } outs(%arg2 : tensor<1x1x3xf32>) {
+  ^bb0(%arg4: f32):
+    %2 = linalg.index 0 : index
+    %3 = linalg.index 1 : index
+    %4 = linalg.index 2 : index
+    %5 = tensor.extract %arg0[%2, %3, %4] : tensor<3x3x3xf32>
+    linalg.yield %5 : f32
+  } -> tensor<1x1x3xf32>
+  return %1 : tensor<1x1x3xf32>
+}
+
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_idx_from_iteration_index
+// CHECK-SAME: %[[ARG0:.*]]: tensor<3x3x3xf32>
+// CHECK-SAME: %[[ARG1:.*]]: tensor<1x1x3xf32>
+// CHECK: %[[INDICES:.*]] = arith.constant dense<[0, 1, 2]> : vector<3xindex>
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x3xi1>
+// CHECK: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x3xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[B:.*]] = vector.broadcast %[[INDICES]] : vector<3xindex> to vector<1x1x3xindex>
+// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[B]]], %[[MASK]], %[[PASSTHRU]] : tensor<3x3x3xf32>, vector<1x1x3xindex>, vector<1x1x3xi1>, vector<1x1x3xf32> into vector<1x1x3xf32>
+// CHECK: vector.transfer_write %[[GATHER]]
+
+// Driver: match the linalg.generic above and vectorize its enclosing
+// function with n-D tensor.extract lowering enabled (vectorize_nd_extract).
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+  %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
+}
+
+// -----
+
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func.func @vectorize_nd_tensor_extract(%arg0: tensor<3x3xf32>, %arg1: tensor<4x3xi32>, %arg2: tensor<4x3xi32>, %arg3: tensor<4x7x2xf32>, %arg4: tensor<4x7x3x2xf32>) -> tensor<4x7x3x2xf32> {
+func.func @vectorize_nd_tensor_extract_index_from_tensor(%arg0: tensor<3x3xf32>, %arg1: tensor<4x3xi32>, %arg2: tensor<4x3xi32>, %arg3: tensor<4x7x2xf32>, %arg4: tensor<4x7x3x2xf32>) -> tensor<4x7x3x2xf32> {
%2 = linalg.generic {
indexing_maps = [#map0, #map0, #map1, #map2],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]
} -> tensor<4x7x3x2xf32>
return %2 : tensor<4x7x3x2xf32>
}
-// CHECK-LABEL: func.func @vectorize_nd_tensor_extract
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_index_from_tensor
// CHECK-SAME: %[[ARG0:.*]]: tensor<3x3xf32>
// CHECK-SAME: %[[ARG1:arg1]]: tensor<4x3xi32>
// CHECK-SAME: %[[ARG2:arg2]]: tensor<4x3xi32>