return Value();
}
+// Fold extractOp with source coming from ShapeCast op.
+static Value foldExtractFromShapeCast(ExtractOp extractOp) {
+ auto shapeCastOp = extractOp.vector().getDefiningOp<vector::ShapeCastOp>();
+ if (!shapeCastOp)
+ return Value();
+ // Get the n-th dimension size starting from the lowest (innermost) dimension.
+ auto getDimReverse = [](VectorType type, int64_t n) {
+ return type.getDimSize(type.getRank() - n - 1);
+ };
+ int64_t destinationRank =
+ extractOp.getVectorType().getRank() - extractOp.position().size();
+ if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
+ return Value();
+ if (destinationRank > 0) {
+ auto destinationType = extractOp.getResult().getType().cast<VectorType>();
+ for (int64_t i = 0; i < destinationRank; i++) {
+ // The lowest dimensions of the destination must match the lowest
+ // dimensions of the shapecast op source.
+ if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
+ getDimReverse(destinationType, i))
+ return Value();
+ }
+ }
+ // Extract the strides associated with the extract op vector source, then use
+ // them to calculate a linearized position for the extract.
+ auto extractedPos = extractVector<int64_t>(extractOp.position());
+ std::reverse(extractedPos.begin(), extractedPos.end());
+ SmallVector<int64_t, 4> strides;
+ int64_t stride = 1;
+ for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
+ strides.push_back(stride);
+ stride *= getDimReverse(extractOp.getVectorType(), i + destinationRank);
+ }
+
+ int64_t position = linearize(extractedPos, strides);
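+ // For example, for the first case of the canonicalization test added with
+ // this patch (extract [4, 1] from a vector<15x2xf32> produced by a shape_cast
+ // of vector<5x1x3x2xf32>): destinationRank == 0, and the reversed position
+ // {1, 4} linearizes with strides {1, 2} to position == 9.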
+ // Next, extract the strides associated with the shapeCast op vector source
+ // and delinearize the position using those strides.
+ SmallVector<int64_t, 4> newStrides;
+ int64_t numDimension =
+ shapeCastOp.getSourceVectorType().getRank() - destinationRank;
+ stride = 1;
+ for (int64_t i = 0; i < numDimension; i++) {
+ newStrides.push_back(stride);
+ stride *=
+ getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
+ }
+ std::reverse(newStrides.begin(), newStrides.end());
+ SmallVector<int64_t, 4> newPosition = delinearize(newStrides, position);
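+ // Continuing the example above: newStrides (after the reverse) == {6, 6, 2, 1},
+ // so position 9 delinearizes to newPosition == {1, 0, 1, 1}, which is exactly
+ // what the test checks for.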
+ OpBuilder b(extractOp.getContext());
+ extractOp.setAttr(ExtractOp::getPositionAttrName(),
+ b.getI64ArrayAttr(newPosition));
+ extractOp.setOperand(shapeCastOp.source());
+ return extractOp.getResult();
+}
+
OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
if (succeeded(foldExtractOpFromExtractChain(*this)))
return getResult();
if (auto val = foldExtractFromBroadcast(*this))
return val;
+ if (auto val = foldExtractFromShapeCast(*this))
+ return val;
return OpFoldResult();
}
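To make the stride arithmetic above easy to check by hand, here is a small standalone C++ sketch (illustrative only, not part of the patch; remapPosition is a hypothetical helper that redoes the row-major linearize/delinearize steps with plain loops) which reproduces the rewritten positions checked by the fold_extract_shapecast test below.

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: given the shape_cast result shape, the shape_cast
// source shape and an extract position on the result, compute the equivalent
// position on the source the same way the fold does: linearize over the
// extracted dimensions, then delinearize over the source's leading dimensions.
static std::vector<int64_t> remapPosition(const std::vector<int64_t> &castShape,
                                          const std::vector<int64_t> &srcShape,
                                          const std::vector<int64_t> &pos) {
  // Rank of the vector (or scalar) kept by the extract.
  int64_t destRank = static_cast<int64_t>(castShape.size()) -
                     static_cast<int64_t>(pos.size());
  // Linearize the position using row-major strides of the extracted dims.
  int64_t linear = 0;
  int64_t stride = 1;
  for (int64_t i = static_cast<int64_t>(pos.size()) - 1; i >= 0; --i) {
    linear += pos[i] * stride;
    stride *= castShape[i];
  }
  // Delinearize against row-major strides of the source's leading dims.
  int64_t newRank = static_cast<int64_t>(srcShape.size()) - destRank;
  std::vector<int64_t> strides(newRank);
  stride = 1;
  for (int64_t i = newRank - 1; i >= 0; --i) {
    strides[i] = stride;
    stride *= srcShape[i];
  }
  std::vector<int64_t> newPos(newRank);
  for (int64_t i = 0; i < newRank; ++i) {
    newPos[i] = linear / strides[i];
    linear %= strides[i];
  }
  return newPos;
}

int main() {
  // %r1 = vector.extract %0[4, 1], with %0 = shape_cast vector<5x1x3x2xf32>
  // to vector<15x2xf32>  ==>  new position [1, 0, 1, 1].
  assert((remapPosition({15, 2}, {5, 1, 3, 2}, {4, 1}) ==
          std::vector<int64_t>{1, 0, 1, 1}));
  // %r2 = vector.extract %0[5]  ==>  new position [1, 0, 2].
  assert((remapPosition({15, 2}, {5, 1, 3, 2}, {5}) ==
          std::vector<int64_t>{1, 0, 2}));
  // %r3 = vector.extract %1[3, 1], with %1 = shape_cast vector<8x4x2xf32>
  // to vector<4x2x4x2xf32>  ==>  new position [7].
  assert((remapPosition({4, 2, 4, 2}, {8, 4, 2}, {3, 1}) ==
          std::vector<int64_t>{7}));
  return 0;
}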
// -----
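+// The three extracts below exercise the new fold: a scalar extract, an
+// extract that keeps the innermost vector<2xf32> dimension, and an extract
+// through a rank-expanding shape_cast.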
+// CHECK-LABEL: func @fold_extract_shapecast
+// CHECK-SAME: (%[[A0:.*]]: vector<5x1x3x2xf32>, %[[A1:.*]]: vector<8x4x2xf32>
+// CHECK: %[[R0:.*]] = vector.extract %[[A0]][1, 0, 1, 1] : vector<5x1x3x2xf32>
+// CHECK: %[[R1:.*]] = vector.extract %[[A0]][1, 0, 2] : vector<5x1x3x2xf32>
+// CHECK: %[[R2:.*]] = vector.extract %[[A1]][7] : vector<8x4x2xf32>
+// CHECK: return %[[R0]], %[[R1]], %[[R2]] : f32, vector<2xf32>, vector<4x2xf32>
+func @fold_extract_shapecast(%arg0 : vector<5x1x3x2xf32>,
+ %arg1 : vector<8x4x2xf32>)
+ -> (f32, vector<2xf32>, vector<4x2xf32>) {
+ %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xf32>
+ %1 = vector.shape_cast %arg1 : vector<8x4x2xf32> to vector<4x2x4x2xf32>
+ %r1 = vector.extract %0[4, 1] : vector<15x2xf32>
+ %r2 = vector.extract %0[5] : vector<15x2xf32>
+ %r3 = vector.extract %1[3, 1] : vector<4x2x4x2xf32>
+ return %r1, %r2, %r3 : f32, vector<2xf32>, vector<4x2xf32>
+}
+
+// -----
+
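+// Here the fold must bail out: the extract keeps a rank-2 result while the
+// shape_cast source vector<16xf32> only has rank 1, so the rank check in
+// foldExtractFromShapeCast rejects it.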
+// CHECK-LABEL: fold_extract_shapecast_negative
+// CHECK: %[[V:.*]] = vector.shape_cast %{{.*}} : vector<16xf32> to vector<2x4x2xf32>
+// CHECK: %[[R:.*]] = vector.extract %[[V]][1] : vector<2x4x2xf32>
+// CHECK: return %[[R]] : vector<4x2xf32>
+func @fold_extract_shapecast_negative(%arg0 : vector<16xf32>,
+ %arg1 : vector<8x4x2xf32>) -> vector<4x2xf32> {
+ %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<2x4x2xf32>
+ %r = vector.extract %0[1] : vector<2x4x2xf32>
+ return %r : vector<4x2xf32>
+}
+
+// -----
+
// CHECK-LABEL: fold_vector_transfers
func @fold_vector_transfers(%A: memref<?x8xf32>) -> (vector<4x8xf32>, vector<4x9xf32>) {
%c0 = constant 0 : index