b.getI64ArrayAttr(extractPos));
return extractOp.getResult();
}
- // TODO: In case the rank of the broadcast source is greater than the rank of
- // the extract result this can be combined into a new broadcast op. This needs
- // to be added a canonicalization pattern if needed.
return Value();
}
namespace {
+// Pattern to rewrite an ExtractOp(Broadcast) -> Broadcast.
+//
+// Matches an ExtractOp whose vector operand is defined by a
+// vector::BroadcastOp or a SplatOp and whose result has a strictly higher
+// rank than the broadcast source. The extract is then replaced by a
+// vector::BroadcastOp from that source directly to the extract result type.
+class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
+public:
+  using OpRewritePattern<ExtractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    Operation *defOp = extractOp.vector().getDefiningOp();
+    if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
+      return failure();
+    Value source = defOp->getOperand(0);
+    // Nothing to broadcast when the extract already yields the source type.
+    if (extractOp.getType() == source.getType())
+      return failure();
+    // Non-vector types are treated as rank 0. The rank is kept as int64_t so
+    // the value returned by getRank() is not implicitly narrowed.
+    auto getRank = [](Type type) -> int64_t {
+      if (auto vecType = type.dyn_cast<VectorType>())
+        return vecType.getRank();
+      return 0;
+    };
+    int64_t broadcastSrcRank = getRank(source.getType());
+    int64_t extractResultRank = getRank(extractOp.getType());
+    // We only consider the case where the rank of the source is smaller than
+    // the rank of the extract dst. The other cases are handled in the folding
+    // patterns.
+    if (extractResultRank <= broadcastSrcRank)
+      return failure();
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+        extractOp, extractOp.getType(), source);
+    return success();
+  }
+};
+
+// Pattern to rewrite an ExtractOp(splat ConstantOp) -> ConstantOp.
+class ExtractOpConstantFolder final : public OpRewritePattern<ExtractOp> {
+public:
+  using OpRewritePattern<ExtractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    // Bail out unless the extracted-from vector is defined by an
+    // arith::ConstantOp.
+    auto cstOp = extractOp.vector().getDefiningOp<arith::ConstantOp>();
+    if (!cstOp)
+      return failure();
+    // Only constants whose value is a SplatElementsAttr are rewritten.
+    auto splatAttr = cstOp.getValue().dyn_cast<SplatElementsAttr>();
+    if (!splatAttr)
+      return failure();
+    // Start from the splat element itself; when the extract result is a
+    // vector, wrap that element into a dense attribute of the result type.
+    Attribute foldedAttr = splatAttr.getSplatValue<Attribute>();
+    if (auto resultVecType = extractOp.getType().dyn_cast<VectorType>())
+      foldedAttr = DenseElementsAttr::get(resultVecType, foldedAttr);
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, foldedAttr);
+    return success();
+  }
+};
+
} // namespace
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- // ExtractToShapeCast is not a default canonicalization, it is opt-in by
- // calling `populateCastAwayVectorLeadingOneDimPatterns`
+  // Default ExtractOp canonicalizations, registered one pattern at a time in
+  // the same order as before.
+  results.add<ExtractOpConstantFolder>(context);
+  results.add<ExtractOpFromBroadcast>(context);
}
static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
return success();
}
+namespace {
+
+// Rewrites an InsertOp into a BroadcastOp when the inserted vector holds as
+// many elements as the destination vector, i.e. the insert only adds unit
+// dimensions.
+class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
+public:
+  using OpRewritePattern<InsertOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InsertOp insertOp,
+                                PatternRewriter &rewriter) const override {
+    // Non-vector sources are not handled by this pattern.
+    auto sourceVecTy = insertOp.getSourceType().dyn_cast<VectorType>();
+    if (!sourceVecTy)
+      return failure();
+    // The element counts must agree for the insert to be a pure reshape.
+    VectorType destVecTy = insertOp.getDestVectorType();
+    if (destVecTy.getNumElements() != sourceVecTy.getNumElements())
+      return failure();
+    rewriter.replaceOpWithNewOp<BroadcastOp>(insertOp, destVecTy,
+                                             insertOp.source());
+    return success();
+  }
+};
+
+} // namespace
+
void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- // InsertToShapeCast is not a default canonicalization, it is opt-in by
- // calling `populateCastAwayVectorLeadingOneDimPatterns`
+  // Default InsertOp canonicalizations, registered one pattern at a time in
+  // the same order as before.
+  results.add<InsertToBroadcast>(context);
+  results.add<BroadcastFolder>(context);
}
// Eliminates insert operations that produce values identical to their source
return VectorType::get(newShape, oldType.getElementType());
}
+/// Build a SmallVector holding `rank` entries, each equal to zero.
+static SmallVector<int64_t> splatZero(int64_t rank) {
+  SmallVector<int64_t> zeros(rank, 0);
+  return zeros;
+}
+
// Casts away leading one dimensions in vector.extract_strided_slice's vector
// input by inserting vector.shape_cast.
struct CastAwayExtractStridedSliceLeadingOneDim
Location loc = extractOp.getLoc();
- Value newSrcVector = rewriter.create<vector::ShapeCastOp>(
- loc, newSrcType, extractOp.vector());
+ Value newSrcVector = rewriter.create<vector::ExtractOp>(
+ loc, extractOp.vector(), splatZero(dropCount));
// The offsets/sizes/strides attribute can have a less number of elements
// than the input vector's rank: it is meant for the leading dimensions.
auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, oldDstType,
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType,
newExtractOp);
return success();
VectorType oldDstType = insertOp.getDestVectorType();
VectorType newDstType = trimLeadingOneDims(oldDstType);
- if (newSrcType.getRank() == oldSrcType.getRank() &&
- newDstType.getRank() == oldDstType.getRank())
+ int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank();
+ int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
+ if (srcDropCount == 0 && dstDropCount == 0)
return failure();
// Trim leading one dimensions from both operands.
Location loc = insertOp.getLoc();
- Value newSrcVector = rewriter.create<vector::ShapeCastOp>(
- loc, newSrcType, insertOp.source());
- Value newDstVector =
- rewriter.create<vector::ShapeCastOp>(loc, newDstType, insertOp.dest());
+ Value newSrcVector = rewriter.create<vector::ExtractOp>(
+ loc, insertOp.source(), splatZero(srcDropCount));
+ Value newDstVector = rewriter.create<vector::ExtractOp>(
+ loc, insertOp.dest(), splatZero(dstDropCount));
auto newOffsets = rewriter.getArrayAttr(
insertOp.offsets().getValue().take_back(newDstType.getRank()));
auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>(
loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(insertOp, oldDstType,
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
newInsertOp);
return success();
auto newRead = rewriter.create<vector::TransferReadOp>(
read.getLoc(), newType, read.source(), read.indices(), newMap,
read.padding(), inBounds);
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(read, oldType, newRead);
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);
return success();
}
VectorType oldType = write.getVectorType();
VectorType newType = trimLeadingOneDims(oldType);
-
if (newType == oldType)
return failure();
+ int64_t dropDim = oldType.getRank() - newType.getRank();
AffineMap oldMap = write.permutation_map();
ArrayRef<AffineExpr> newResults =
inBounds = rewriter.getArrayAttr(
write.in_boundsAttr().getValue().take_back(newType.getRank()));
- auto newVector = rewriter.create<vector::ShapeCastOp>(
- write.getLoc(), newType, write.vector());
+ auto newVector = rewriter.create<vector::ExtractOp>(
+ write.getLoc(), write.vector(), splatZero(dropDim));
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
write, newVector, write.source(), write.indices(), newMap, inBounds);
}
};
-template <typename BroadCastType>
-struct CastAwayBroadcastLeadingOneDim : public OpRewritePattern<BroadCastType> {
- using OpRewritePattern<BroadCastType>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(BroadCastType broadcastOp,
- PatternRewriter &rewriter) const override {
- VectorType dstType =
- broadcastOp.getResult().getType().template dyn_cast<VectorType>();
- if (!dstType)
- return failure();
- VectorType newDstType = trimLeadingOneDims(dstType);
- if (newDstType == dstType)
- return failure();
- Location loc = broadcastOp.getLoc();
- Value source = broadcastOp->getOperand(0);
- VectorType srcVecType = source.getType().template dyn_cast<VectorType>();
- if (srcVecType)
- srcVecType = trimLeadingOneDims(srcVecType);
- if (srcVecType && srcVecType != source.getType()) {
- source = rewriter.create<vector::ShapeCastOp>(loc, srcVecType, source);
- }
- Value newBroadcastOp =
- rewriter.create<BroadCastType>(loc, newDstType, source);
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcastOp, dstType,
- newBroadcastOp);
- return success();
- }
-};
-
class CastAwayElementwiseLeadingOneDim : public RewritePattern {
public:
CastAwayElementwiseLeadingOneDim(MLIRContext *context)
VectorType newVecType = trimLeadingOneDims(vecType);
if (newVecType == vecType)
return failure();
-
+ int64_t dropDim = vecType.getRank() - newVecType.getRank();
SmallVector<Value, 4> newOperands;
for (Value operand : op->getOperands()) {
if (auto opVecType = operand.getType().dyn_cast<VectorType>()) {
- auto newType =
- VectorType::get(newVecType.getShape(), opVecType.getElementType());
- newOperands.push_back(rewriter.create<vector::ShapeCastOp>(
- op->getLoc(), newType, operand));
+ newOperands.push_back(rewriter.create<vector::ExtractOp>(
+ op->getLoc(), operand, splatZero(dropDim)));
} else {
newOperands.push_back(operand);
}
state.addOperands(newOperands);
state.addTypes(newVecType);
Operation *newOp = rewriter.createOperation(state);
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, vecType,
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType,
newOp->getResult(0));
return success();
}
};
-// If extractOp is only removing unit dimensions it can be transformed to a
-// shapecast.
-class ExtractToShapeCast final : public OpRewritePattern<ExtractOp> {
-public:
- using OpRewritePattern<ExtractOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(ExtractOp extractOp,
- PatternRewriter &rewriter) const override {
- auto dstVecType = extractOp.getResult().getType().dyn_cast<VectorType>();
- if (!dstVecType || extractOp.getVectorType().getNumElements() !=
- dstVecType.getNumElements())
- return failure();
- rewriter.replaceOpWithNewOp<ShapeCastOp>(extractOp, dstVecType,
- extractOp.vector());
- return success();
- }
-};
-
-// If insertOp is only inserting unit dimensions it can be transformed to a
-// shapecast.
-class InsertToShapeCast final : public OpRewritePattern<InsertOp> {
-public:
- using OpRewritePattern<InsertOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(InsertOp insertOp,
- PatternRewriter &rewriter) const override {
- auto srcVecType = insertOp.getSourceType().dyn_cast<VectorType>();
- if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
- srcVecType.getNumElements())
- return failure();
- rewriter.replaceOpWithNewOp<ShapeCastOp>(
- insertOp, insertOp.getDestVectorType(), insertOp.source());
- return success();
- }
-};
-
-// BroadcastOp can only add dimensions or broadcast a dimension from 1 to N. In
-// the degenerated case where the broadcast only adds dimensions of size 1 it
-// can be replaced by a ShapeCastOp. This canonicalization checks if the total
-// number of elements is the same before and after the broadcast to detect if
-// the only change in the vector type are new dimensions of size 1.
-class BroadcastToShapeCast final : public OpRewritePattern<BroadcastOp> {
-public:
- using OpRewritePattern<BroadcastOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
- PatternRewriter &rewriter) const override {
- auto srcVecType = broadcastOp.getSourceType().dyn_cast<VectorType>();
- if (!srcVecType || broadcastOp.getVectorType().getNumElements() !=
- srcVecType.getNumElements())
- return failure();
- rewriter.replaceOpWithNewOp<ShapeCastOp>(
- broadcastOp, broadcastOp.getVectorType(), broadcastOp.source());
- return success();
- }
-};
-
// Returns the values in `arrayAttr` as an integer vector.
static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) {
return llvm::to_vector<4>(
void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
RewritePatternSet &patterns) {
- patterns.add<
- BroadcastToShapeCast, CastAwayExtractStridedSliceLeadingOneDim,
- CastAwayInsertStridedSliceLeadingOneDim,
- CastAwayTransferReadLeadingOneDim, CastAwayTransferWriteLeadingOneDim,
- CastAwayBroadcastLeadingOneDim<vector::BroadcastOp>,
- CastAwayBroadcastLeadingOneDim<SplatOp>, CastAwayElementwiseLeadingOneDim,
- ExtractToShapeCast, InsertToShapeCast>(patterns.getContext());
+  // Register the CastAway*LeadingOneDim patterns in the same order as a
+  // single add<> call would, grouped by the kind of op each one matches.
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<CastAwayExtractStridedSliceLeadingOneDim,
+               CastAwayInsertStridedSliceLeadingOneDim>(ctx);
+  patterns.add<CastAwayTransferReadLeadingOneDim,
+               CastAwayTransferWriteLeadingOneDim>(ctx);
+  patterns.add<CastAwayElementwiseLeadingOneDim>(ctx);
populateShapeCastFoldingPatterns(patterns);
}
// -----
-// Negative test for extract_op folding when the type of broadcast source
-// doesn't match the type of vector.extract.
-// CHECK-LABEL: fold_extract_broadcast_negative
-// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<1x2x4xf32>
-// CHECK: %[[R:.*]] = vector.extract %[[B]][0, 1] : vector<1x2x4xf32>
-// CHECK: return %[[R]] : vector<4xf32>
-func @fold_extract_broadcast_negative(%a : f32) -> vector<4xf32> {
+// CHECK-LABEL: fold_extract_broadcast
+// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
+// CHECK: return %[[B]] : vector<4xf32>
+func @fold_extract_broadcast(%a : f32) -> vector<4xf32> {
%b = vector.broadcast %a : f32 to vector<1x2x4xf32>
%r = vector.extract %b[0, 1] : vector<1x2x4xf32>
return %r : vector<4xf32>
vector<16x4xf16> to vector<2x4xf16>
return %1 : vector<2x4xf16>
}
+
+// -----
+
+// A vector.insert of vector<4xf32> into vector<1x1x4xf32> only adds unit
+// dimensions and is expected to become a vector.broadcast, while the matching
+// vector.extract is expected to stay a vector.extract.
+// CHECK-LABEL: func @insert_extract_to_broadcast
+// CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>)
+// CHECK: %[[V0:.*]] = vector.extract %[[ARG0]][0, 0] : vector<1x1x4xf32>
+// CHECK: %[[V1:.*]] = vector.broadcast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
+// CHECK: return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32>
+func @insert_extract_to_broadcast(%arg0 : vector<1x1x4xf32>,
+ %arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) {
+ %0 = vector.extract %arg0[0, 0] : vector<1x1x4xf32>
+ %1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32>
+ return %0, %1 : vector<4xf32>, vector<1x1x4xf32>
+}
+
+// -----
+
+// vector.extract applied to a splat arith.constant is expected to be replaced
+// by a constant of the extracted type: a dense vector<7xf32> constant for the
+// partial extract and a scalar i32 constant for the full-rank extract.
+// CHECK-LABEL: extract_constant
+// CHECK-DAG: %[[CST1:.*]] = arith.constant 1 : i32
+// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<7xf32>
+// CHECK: return %[[CST0]], %[[CST1]] : vector<7xf32>, i32
+func @extract_constant() -> (vector<7xf32>, i32) {
+ %cst = arith.constant dense<2.000000e+00> : vector<29x7xf32>
+ %cst_1 = arith.constant dense<1> : vector<4x37x9xi32>
+ %0 = vector.extract %cst[2] : vector<29x7xf32>
+ %1 = vector.extract %cst_1[1, 4, 5] : vector<4x37x9xi32>
+ return %0, %1 : vector<7xf32>, i32
+}
+++ /dev/null
-// RUN: mlir-opt %s -test-vector-to-vector-lowering | FileCheck %s
-
-// CHECK-LABEL: broadcast_to_shapecast
-// CHECK: %[[C:.*]] = vector.shape_cast %{{.*}} : vector<4x4xf16> to vector<1x4x4xf16>
-// CHECK-NEXT: return %[[C]] : vector<1x4x4xf16>
-func @broadcast_to_shapecast(%arg0: vector<4x4xf16>) -> vector<1x4x4xf16> {
- %0 = vector.broadcast %arg0 : vector<4x4xf16> to vector<1x4x4xf16>
- return %0 : vector<1x4x4xf16>
-}
-
-// -----
-
-// CHECK-LABEL: func @insert_extract_to_shapecast
-// CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>)
-// CHECK: %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32>
-// CHECK: %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
-// CHECK: return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32>
-func @insert_extract_to_shapecast(%arg0 : vector<1x1x4xf32>,
- %arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) {
- %0 = vector.extract %arg0[0, 0] : vector<1x1x4xf32>
- %1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32>
- return %0, %1 : vector<4xf32>, vector<1x1x4xf32>
-}
// CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims
func @cast_away_extract_strided_slice_leading_one_dims(%arg0: vector<1x8x8xf16>) -> vector<1x1x8xf16> {
- // CHECK: %[[SRC:.+]] = vector.shape_cast %{{.*}} : vector<1x8x8xf16> to vector<8x8xf16>
+ // CHECK: %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<1x8x8xf16>
// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [4], sizes = [1], strides = [1]} : vector<8x8xf16> to vector<1x8xf16>
%0 = vector.extract_strided_slice %arg0 {offsets = [0, 4], sizes = [1, 1], strides = [1, 1]} : vector<1x8x8xf16> to vector<1x1x8xf16>
- // CHECK: %[[RET:.+]] = vector.shape_cast %[[EXTRACT]] : vector<1x8xf16> to vector<1x1x8xf16>
+ // CHECK: %[[RET:.+]] = vector.broadcast %[[EXTRACT]] : vector<1x8xf16> to vector<1x1x8xf16>
// CHECK: return %[[RET]]
return %0: vector<1x1x8xf16>
}
// CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims
func @cast_away_insert_strided_slice_leading_one_dims(%arg0: vector<1x8xf16>, %arg1: vector<1x8x8xf16>) -> vector<1x8x8xf16> {
- // CHECK: %[[SRC:.+]] = vector.shape_cast %{{.*}} : vector<1x8xf16> to vector<8xf16>
- // CHECK: %[[DST:.+]] = vector.shape_cast %{{.*}} : vector<1x8x8xf16> to vector<8x8xf16>
+ // CHECK: %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<1x8xf16>
+ // CHECK: %[[DST:.+]] = vector.extract %{{.*}}[0] : vector<1x8x8xf16>
// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[SRC]], %[[DST]] {offsets = [0, 0], strides = [1]} : vector<8xf16> into vector<8x8xf16>
%0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x8xf16> into vector<1x8x8xf16>
- // CHECK: %[[RET:.+]] = vector.shape_cast %[[INSERT]] : vector<8x8xf16> to vector<1x8x8xf16>
+ // CHECK: %[[RET:.+]] = vector.broadcast %[[INSERT]] : vector<8x8xf16> to vector<1x8x8xf16>
// CHECK: return %[[RET]]
return %0: vector<1x8x8xf16>
}
// CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims_one_element
// CHECK-SAME: %[[ARG0:.+]]: vector<1x1xf16>, %{{.+}}: vector<1x1x1xf16>
func @cast_away_insert_strided_slice_leading_one_dims_one_element(%arg0: vector<1x1xf16>, %arg1: vector<1x1x1xf16>) -> vector<1x1x1xf16> {
- // CHECK: %[[CAST:.+]] = vector.shape_cast %[[ARG0]] : vector<1x1xf16> to vector<1x1x1xf16>
+ // CHECK: %[[EXT:.+]] = vector.extract %{{.*}}[0] : vector<1x1xf16>
+ // CHECK: %[[B:.+]] = vector.broadcast %[[EXT]] : vector<1xf16> to vector<1x1x1xf16>
%0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x1xf16> into vector<1x1x1xf16>
- // CHECK: return %[[CAST]]
+ // CHECK: return %[[B]]
return %0: vector<1x1x1xf16>
}
// CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16
%f0 = arith.constant 0. : f16
// CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]] {in_bounds = [true]} : memref<1x4x8x16xf16>, vector<4xf16>
- // CHECK: %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<4xf16> to vector<1x4xf16>
+ // CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x4xf16>
%0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0 {in_bounds = [true, true]} : memref<1x4x8x16xf16>, vector<1x4xf16>
// CHECK: return %[[CAST]]
return %0: vector<1x4xf16>
func @cast_away_transfer_read_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>) -> vector<1x1xf16> {
%c0 = arith.constant 0 : index
%f0 = arith.constant 0. : f16
- // CHECK: vector.shape_cast %{{.+}} : vector<1xf16> to vector<1x1xf16>
+ // CHECK: vector.broadcast %{{.+}} : vector<1xf16> to vector<1x1xf16>
%0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0 {in_bounds = [true, true]} : memref<1x1x1x1xf16>, vector<1x1xf16>
return %0: vector<1x1xf16>
}
func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>) {
// CHECK: %[[C0:.+]] = arith.constant 0 : index
%c0 = arith.constant 0 : index
- // CHECK: %[[CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4xf16> to vector<4xf16>
+ // CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0] : vector<1x4xf16>
// CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]} : vector<4xf16>, memref<1x4x8x16xf16>
vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {in_bounds = [true, true]} : vector<1x4xf16>, memref<1x4x8x16xf16>
// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims_one_element
func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>, %arg1: vector<1x1xf16>) {
%c0 = arith.constant 0 : index
- // CHECK: vector.shape_cast %{{.+}} : vector<1x1xf16> to vector<1xf16>
+ // CHECK: vector.extract %{{.+}}[0] : vector<1x1xf16>
vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {in_bounds = [true, true]} : vector<1x1xf16>, memref<1x1x1x1xf16>
return
}
-// CHECK-LABEL: func @cast_away_broadcast_leading_one_dims
-func @cast_away_broadcast_leading_one_dims(
- %arg0: vector<8xf32>, %arg1: f32, %arg2: vector<1x4xf32>) ->
- (vector<1x1x8xf32>, vector<1x1x4xf32>, vector<1x3x4xf32>, vector<1x1x4xf32>) {
- // CHECK: vector.shape_cast %{{.*}} : vector<8xf32> to vector<1x1x8xf32>
- %0 = vector.broadcast %arg0 : vector<8xf32> to vector<1x1x8xf32>
- // CHECK: vector.broadcast %{{.*}} : f32 to vector<4xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<4xf32> to vector<1x1x4xf32>
- %1 = vector.broadcast %arg1 : f32 to vector<1x1x4xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
- // CHECK: vector.broadcast %{{.*}} : vector<4xf32> to vector<3x4xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<3x4xf32> to vector<1x3x4xf32>
- %2 = vector.broadcast %arg2 : vector<1x4xf32> to vector<1x3x4xf32>
- // CHECK: splat %{{.*}} : vector<4xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<4xf32> to vector<1x1x4xf32>
- %3 = splat %arg1 : vector<1x1x4xf32>
- return %0, %1, %2, %3: vector<1x1x8xf32>, vector<1x1x4xf32>, vector<1x3x4xf32>, vector<1x1x4xf32>
-}
-
// CHECK-LABEL: func @cast_away_elementwise_leading_one_dims
func @cast_away_elementwise_leading_one_dims(
%arg0: vector<1x1x8xf32>, %arg1: f32, %arg2: vector<1x4xf32>,
%arg3: vector<1x4xf32>, %arg4: i1) ->
(vector<1x1x8xf32>, vector<1x4xi1>, vector<1x4xf32>, vector<1x4xf32>) {
- // CHECK: vector.shape_cast %{{.*}} : vector<1x1x8xf32> to vector<8xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<1x1x8xf32> to vector<8xf32>
+ // CHECK: vector.extract %{{.*}}[0, 0] : vector<1x1x8xf32>
+ // CHECK: vector.extract %{{.*}}[0, 0] : vector<1x1x8xf32>
// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<8xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<8xf32> to vector<1x1x8xf32>
+ // CHECK: vector.broadcast %{{.*}} : vector<8xf32> to vector<1x1x8xf32>
%0 = arith.addf %arg0, %arg0 : vector<1x1x8xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
+ // CHECK: vector.extract %{{.*}}[0] : vector<1x4xf32>
+ // CHECK: vector.extract %{{.*}}[0] : vector<1x4xf32>
// CHECK: arith.cmpf ogt, %{{.*}}, %{{.*}} : vector<4xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<4xi1> to vector<1x4xi1>
+ // CHECK: vector.broadcast %{{.*}} : vector<4xi1> to vector<1x4xi1>
%1 = arith.cmpf ogt, %arg2, %arg3 : vector<1x4xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
+ // CHECK: vector.extract %{{.*}}[0] : vector<1x4xf32>
+ // CHECK: vector.extract %{{.*}}[0] : vector<1x4xf32>
// CHECK: select %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi1>, vector<4xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<4xf32> to vector<1x4xf32>
+ // CHECK: vector.broadcast %{{.*}} : vector<4xf32> to vector<1x4xf32>
%2 = select %1, %arg3, %arg2 : vector<1x4xi1>, vector<1x4xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
+ // CHECK: vector.extract %{{.*}}[0] : vector<1x4xf32>
+ // CHECK: vector.extract %{{.*}}[0] : vector<1x4xf32>
// CHECK: select %arg4, %12, %{{.*}} : vector<4xf32>
- // CHECK: vector.shape_cast %{{.*}} : vector<4xf32> to vector<1x4xf32>
+ // CHECK: vector.broadcast %{{.*}} : vector<4xf32> to vector<1x4xf32>
%3 = select %arg4, %arg3, %arg2 : vector<1x4xf32>
return %0, %1, %2, %3: vector<1x1x8xf32>, vector<1x4xi1>, vector<1x4xf32>, vector<1x4xf32>
}