return Value();
// Get the size of the nth dimension counting from the lowest dimension
// (n = 0 is the last dimension of the shape).
auto getDimReverse = [](VectorType type, int64_t n) {
- return type.getDimSize(type.getRank() - n - 1);
+ return type.getShape().take_back(n+1).front();
};
int64_t destinationRank =
- extractOp.getVectorType().getRank() - extractOp.position().size();
+ extractOp.getType().isa<VectorType>()
+ ? extractOp.getType().cast<VectorType>().getRank()
+ : 0;
if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
return Value();
if (destinationRank > 0) {
for (int64_t i = 0; i < destinationRank; i++) {
// The lowest dimension of the destination must match the lowest
// dimension of the shapecast op source.
+ // TODO: This case could be supported in a canonicalization pattern.
if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
getDimReverse(destinationType, i))
return Value();
}
std::reverse(newStrides.begin(), newStrides.end());
SmallVector<int64_t, 4> newPosition = delinearize(newStrides, position);
+ // OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(extractOp.getContext());
extractOp.setAttr(ExtractOp::getPositionAttrName(),
b.getI64ArrayAttr(newPosition));
}
// When the source of ExtractStrided comes from a chain of InsertStrided ops try
-// to use the source o the InsertStrided ops if we can detect that the extracted
-// vector is a subset of one of the vector inserted.
+// to use the source of the InsertStrided ops if we can detect that the
+// extracted vector is a subset of one of the vectors inserted.
static LogicalResult
foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
// Helper to extract integer out of ArrayAttr.
// Case where we need to go through 2 levels of insert_strided_slice ops.
// CHECK-LABEL: extract_strided_fold_insert
-// CHECK-SAME: (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<1x4xf32>,
+// CHECK-SAME: (%[[ARG0:.*]]: vector<2x8xf32>, %[[ARG1:.*]]: vector<1x4xf32>,
// CHECK-NEXT: %[[EXT:.*]] = vector.extract_strided_slice %[[ARG1]]
-// CHECK-SAME: {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]}
+// CHECK-SAME: {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]}
// CHECK-SAME: : vector<1x4xf32> to vector<1x1xf32>
// CHECK-NEXT: return %[[EXT]] : vector<1x1xf32>
-func @extract_strided_fold_insert(%a: vector<2x4xf32>, %b: vector<1x4xf32>,
+func @extract_strided_fold_insert(%a: vector<2x8xf32>, %b: vector<1x4xf32>,
%c : vector<1x4xf32>) -> (vector<1x1xf32>) {
- %0 = vector.insert_strided_slice %b, %a {offsets = [0, 0], strides = [1, 1]}
- : vector<1x4xf32> into vector<2x4xf32>
+ %0 = vector.insert_strided_slice %b, %a {offsets = [0, 1], strides = [1, 1]}
+ : vector<1x4xf32> into vector<2x8xf32>
%1 = vector.insert_strided_slice %c, %0 {offsets = [1, 0], strides = [1, 1]}
- : vector<1x4xf32> into vector<2x4xf32>
+ : vector<1x4xf32> into vector<2x8xf32>
%2 = vector.extract_strided_slice %1
{offsets = [0, 1], sizes = [1, 1], strides = [1, 1]}
- : vector<2x4xf32> to vector<1x1xf32>
+ : vector<2x8xf32> to vector<1x1xf32>
return %2 : vector<1x1xf32>
}