}
};
+// Trims leading one dimensions from `oldType` and returns the result type.
+// Returns `vector<1xT>` if `oldType` only has one element.
+static VectorType trimLeadingOneDims(VectorType oldType) {
+ ArrayRef<int64_t> oldShape = oldType.getShape();
+ ArrayRef<int64_t> newShape =
+ oldShape.drop_while([](int64_t dim) { return dim == 1; });
+ // Make sure we have at least 1 dimension per vector type requirements.
+ if (newShape.empty())
+ newShape = oldShape.take_back();
+ return VectorType::get(newShape, oldType.getElementType());
+}
+
+// Casts away leading one dimensions in vector.extract_strided_slice's vector
+// input by inserting vector.shape_cast.
+struct CastAwayExtractStridedSliceLeadingOneDim
+ : public OpRewritePattern<vector::ExtractStridedSliceOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
+ PatternRewriter &rewriter) const override {
+ // vector.extract_strided_slice requires the input and output vector to have
+ // the same rank. Here we drop leading one dimensions from the input vector
+ // type to make sure we don't cause mismatch.
+ VectorType oldSrcType = extractOp.getVectorType();
+ VectorType newSrcType = trimLeadingOneDims(oldSrcType);
+
+ if (newSrcType.getRank() == oldSrcType.getRank())
+ return failure();
+
+ int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank();
+
+ VectorType oldDstType = extractOp.getType();
+ VectorType newDstType =
+ VectorType::get(oldDstType.getShape().drop_front(dropCount),
+ oldDstType.getElementType());
+
+ Location loc = extractOp.getLoc();
+
+ Value newSrcVector = rewriter.create<vector::ShapeCastOp>(
+ loc, newSrcType, extractOp.vector());
+
+ // The offsets/sizes/strides attribute can have a less number of elements
+ // than the input vector's rank: it is meant for the leading dimensions.
+ auto newOffsets = rewriter.getArrayAttr(
+ extractOp.offsets().getValue().drop_front(dropCount));
+ auto newSizes = rewriter.getArrayAttr(
+ extractOp.sizes().getValue().drop_front(dropCount));
+ auto newStrides = rewriter.getArrayAttr(
+ extractOp.strides().getValue().drop_front(dropCount));
+
+ auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);
+
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, oldDstType,
+ newExtractOp);
+
+ return success();
+ }
+};
+
+// Casts away leading one dimensions in vector.extract_strided_slice's vector
+// inputs by inserting vector.shape_cast.
+struct CastAwayInsertStridedSliceLeadingOneDim
+ : public OpRewritePattern<vector::InsertStridedSliceOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
+ PatternRewriter &rewriter) const override {
+ VectorType oldSrcType = insertOp.getSourceVectorType();
+ VectorType newSrcType = trimLeadingOneDims(oldSrcType);
+ VectorType oldDstType = insertOp.getDestVectorType();
+ VectorType newDstType = trimLeadingOneDims(oldDstType);
+
+ if (newSrcType.getRank() == oldSrcType.getRank() &&
+ newDstType.getRank() == oldDstType.getRank())
+ return failure();
+
+ // Trim leading one dimensions from both operands.
+ Location loc = insertOp.getLoc();
+
+ Value newSrcVector = rewriter.create<vector::ShapeCastOp>(
+ loc, newSrcType, insertOp.source());
+ Value newDstVector =
+ rewriter.create<vector::ShapeCastOp>(loc, newDstType, insertOp.dest());
+
+ auto newOffsets = rewriter.getArrayAttr(
+ insertOp.offsets().getValue().take_back(newDstType.getRank()));
+ auto newStrides = rewriter.getArrayAttr(
+ insertOp.strides().getValue().take_back(newSrcType.getRank()));
+
+ auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>(
+ loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);
+
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(insertOp, oldDstType,
+ newInsertOp);
+
+ return success();
+ }
+};
+
+// Turns vector.transfer_read on vector with leading 1 dimensions into
+// vector.shape_cast followed by vector.transfer_read on vector without leading
+// 1 dimensions.
+struct CastAwayTransferReadLeadingOneDim
+ : public OpRewritePattern<vector::TransferReadOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::TransferReadOp read,
+ PatternRewriter &rewriter) const override {
+ auto shapedType = read.source().getType().cast<ShapedType>();
+ if (shapedType.getElementType() != read.getVectorType().getElementType())
+ return failure();
+
+ VectorType oldType = read.getVectorType();
+ VectorType newType = trimLeadingOneDims(oldType);
+
+ if (newType == oldType)
+ return failure();
+
+ AffineMap oldMap = read.permutation_map();
+ ArrayRef<AffineExpr> newResults =
+ oldMap.getResults().take_back(newType.getRank());
+ AffineMap newMap =
+ AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
+ rewriter.getContext());
+
+ ArrayAttr mask;
+ if (read.masked())
+ mask = rewriter.getArrayAttr(
+ read.maskedAttr().getValue().take_back(newType.getRank()));
+
+ auto newRead = rewriter.create<vector::TransferReadOp>(
+ read.getLoc(), newType, read.source(), read.indices(), newMap,
+ read.padding(), mask);
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(read, oldType, newRead);
+
+ return success();
+ }
+};
+
+// Turns vector.transfer_write on vector with leading 1 dimensions into
+// vector.shape_cast followed by vector.transfer_write on vector without leading
+// 1 dimensions.
+struct CastAwayTransferWriteLeadingOneDim
+ : public OpRewritePattern<vector::TransferWriteOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::TransferWriteOp write,
+ PatternRewriter &rewriter) const override {
+ auto shapedType = write.source().getType().dyn_cast<ShapedType>();
+ if (shapedType.getElementType() != write.getVectorType().getElementType())
+ return failure();
+
+ VectorType oldType = write.getVectorType();
+ VectorType newType = trimLeadingOneDims(oldType);
+
+ if (newType == oldType)
+ return failure();
+
+ AffineMap oldMap = write.permutation_map();
+ ArrayRef<AffineExpr> newResults =
+ oldMap.getResults().take_back(newType.getRank());
+ AffineMap newMap =
+ AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
+ rewriter.getContext());
+
+ ArrayAttr mask;
+ if (write.masked())
+ mask = rewriter.getArrayAttr(
+ write.maskedAttr().getValue().take_back(newType.getRank()));
+
+ auto newVector = rewriter.create<vector::ShapeCastOp>(
+ write.getLoc(), newType, write.vector());
+ rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+ write, newVector, write.source(), write.indices(), newMap, mask);
+
+ return success();
+ }
+};
+
// TODO: Add pattern to rewrite ExtractSlices(ConstantMaskOp).
// TODO: Add this as DRR pattern.
void mlir::vector::populateVectorToVectorTransformationPatterns(
// clang-format on
}
+void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
+ OwningRewritePatternList &patterns, MLIRContext *context) {
+ patterns.insert<CastAwayExtractStridedSliceLeadingOneDim,
+ CastAwayInsertStridedSliceLeadingOneDim,
+ CastAwayTransferReadLeadingOneDim,
+ CastAwayTransferWriteLeadingOneDim, ShapeCastOpFolder>(
+ context);
+}
+
void mlir::vector::populateVectorSlicesLoweringPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
patterns.insert<ExtractSlicesOpLowering, InsertSlicesOpLowering>(context);
: vector<4x4xf32>, tensor<4x4xf32>
return %r : tensor<4x4xf32>
}
+
+// CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims
+func @cast_away_extract_strided_slice_leading_one_dims(%arg0: vector<1x8x8xf16>) -> vector<1x1x8xf16> {
+ // CHECK: %[[SRC:.+]] = vector.shape_cast %{{.*}} : vector<1x8x8xf16> to vector<8x8xf16>
+ // CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [4], sizes = [1], strides = [1]} : vector<8x8xf16> to vector<1x8xf16>
+ %0 = vector.extract_strided_slice %arg0 {offsets = [0, 4], sizes = [1, 1], strides = [1, 1]} : vector<1x8x8xf16> to vector<1x1x8xf16>
+ // CHECK: %[[RET:.+]] = vector.shape_cast %[[EXTRACT]] : vector<1x8xf16> to vector<1x1x8xf16>
+ // CHECK: return %[[RET]]
+ return %0: vector<1x1x8xf16>
+}
+
+// CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims
+func @cast_away_insert_strided_slice_leading_one_dims(%arg0: vector<1x8xf16>, %arg1: vector<1x8x8xf16>) -> vector<1x8x8xf16> {
+ // CHECK: %[[SRC:.+]] = vector.shape_cast %{{.*}} : vector<1x8xf16> to vector<8xf16>
+ // CHECK: %[[DST:.+]] = vector.shape_cast %{{.*}} : vector<1x8x8xf16> to vector<8x8xf16>
+ // CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[SRC]], %[[DST]] {offsets = [0, 0], strides = [1]} : vector<8xf16> into vector<8x8xf16>
+ %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x8xf16> into vector<1x8x8xf16>
+ // CHECK: %[[RET:.+]] = vector.shape_cast %[[INSERT]] : vector<8x8xf16> to vector<1x8x8xf16>
+ // CHECK: return %[[RET]]
+ return %0: vector<1x8x8xf16>
+}
+
+// CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims_one_element
+func @cast_away_insert_strided_slice_leading_one_dims_one_element(%arg0: vector<1x1xf16>, %arg1: vector<1x1x1xf16>) -> vector<1x1x1xf16> {
+ // CHECK: vector.shape_cast %{{.+}} : vector<1x1xf16> to vector<1xf16>
+ // CHECK: vector.shape_cast %{{.+}} : vector<1x1x1xf16> to vector<1xf16>
+ %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x1xf16> into vector<1x1x1xf16>
+ return %0: vector<1x1x1xf16>
+}
+
+// CHECK-LABEL: func @cast_away_transfer_read_leading_one_dims
+func @cast_away_transfer_read_leading_one_dims(%arg0: memref<1x4x8x16xf16>) -> vector<1x4xf16> {
+ // CHECK: %[[C0:.+]] = constant 0 : index
+ %c0 = constant 0 : index
+ // CHECK: %[[F0:.+]] = constant 0.000000e+00 : f16
+ %f0 = constant 0. : f16
+ // CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]] {masked = [false]} : memref<1x4x8x16xf16>, vector<4xf16>
+ // CHECK: %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<4xf16> to vector<1x4xf16>
+ %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0 {masked = [false, false]} : memref<1x4x8x16xf16>, vector<1x4xf16>
+ // CHECK: return %[[CAST]]
+ return %0: vector<1x4xf16>
+}
+
+// CHECK-LABEL: func @cast_away_transfer_read_leading_one_dims_one_element
+func @cast_away_transfer_read_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>) -> vector<1x1xf16> {
+ %c0 = constant 0 : index
+ %f0 = constant 0. : f16
+ // CHECK: vector.shape_cast %{{.+}} : vector<1xf16> to vector<1x1xf16>
+ %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0 {masked = [false, false]} : memref<1x1x1x1xf16>, vector<1x1xf16>
+ return %0: vector<1x1xf16>
+}
+
+// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims
+func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>) {
+ // CHECK: %[[C0:.+]] = constant 0 : index
+ %c0 = constant 0 : index
+ // CHECK: %[[CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4xf16> to vector<4xf16>
+ // CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {masked = [false]} : vector<4xf16>, memref<1x4x8x16xf16>
+
+ vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {masked = [false, false]} : vector<1x4xf16>, memref<1x4x8x16xf16>
+ return
+}
+
+// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims_one_element
+func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>, %arg1: vector<1x1xf16>) {
+ %c0 = constant 0 : index
+ // CHECK: vector.shape_cast %{{.+}} : vector<1x1xf16> to vector<1xf16>
+ vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {masked = [false, false]} : vector<1x1xf16>, memref<1x1x1x1xf16>
+ return
+}