llvm::SmallDenseSet<unsigned> &dimsToProject);
/// Pattern to rewrite a subview op with constant arguments.
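+/// For example (illustrative, with hypothetical SSA names): the pattern
+/// rewrites
+///   %0 = memref.subview %src[%c0, %i, %c1] [%c4, %c1, %sz] [%c1, %c1, %c1]
+/// into a subview whose constant operands are inlined as static attributes,
+///   %1 = memref.subview %src[0, %i, 1] [4, 1, %sz] [1, 1, 1]
+/// followed by a cast back to the original result type.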
-template <typename OpType, typename CastOpFunc>
+template <typename OpType, typename ResultTypeFunc, typename CastOpFunc>
class OpWithOffsetSizesAndStridesConstantArgumentFolder final
: public OpRewritePattern<OpType> {
public:
canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
// Create the new op in canonical form.
- auto newOp = rewriter.create<OpType>(op.getLoc(), op.source(), mixedOffsets,
- mixedSizes, mixedStrides);
+ ResultTypeFunc resultTypeFunc;
+ auto resultType =
+ resultTypeFunc(op, mixedOffsets, mixedSizes, mixedStrides);
+ auto newOp =
+ rewriter.create<OpType>(op.getLoc(), resultType, op.source(),
+ mixedOffsets, mixedSizes, mixedStrides);
CastOpFunc func;
func(rewriter, op, newOp);
    return success();
}
+/// Infer the canonical type of the result of a subview operation. Returns a
+/// type with rank `resultRank` that is either the rank-reduced type, or the
+/// non-rank-reduced type if rank reduction does not produce that rank.
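+///
+/// For example (illustrative): for a `memref<?x?x?xf32>` source and sizes
+/// `[4, 1, %sz]`, the rank-reduced inference yields a `memref<4x?xf32>` with
+/// an inferred strided layout when `resultRank` is 2, while the fallback
+/// yields a rank-3 `memref<4x1x?xf32>` with such a layout when `resultRank`
+/// is 3.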
+static MemRefType
+getCanonicalSubViewResultType(unsigned resultRank, MemRefType sourceType,
+ ArrayRef<OpFoldResult> mixedOffsets,
+ ArrayRef<OpFoldResult> mixedSizes,
+ ArrayRef<OpFoldResult> mixedStrides) {
+ auto resultType =
+ SubViewOp::inferRankReducedResultType(
+ resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
+ .cast<MemRefType>();
+ if (resultType.getRank() != resultRank) {
+ resultType = SubViewOp::inferResultType(sourceType, mixedOffsets,
+ mixedSizes, mixedStrides)
+ .cast<MemRefType>();
+ }
+ return resultType;
+}
+
namespace {
/// Pattern to rewrite a subview op with MemRefCast arguments.
/// This essentially pushes memref.cast past its consuming subview when
-    /// Deduce the resultType of the SubViewOp using `inferSubViewResultType` on
-    /// the cast source operand type and the SubViewOp static information. This
-    /// is the resulting type if the MemRefCastOp were folded.
-    auto resultType = SubViewOp::inferRankReducedResultType(
+    /// Deduce the result type of the SubViewOp using
+    /// `getCanonicalSubViewResultType` on the cast source operand type and the
+    /// SubViewOp static information. This is the resulting type if the
+    /// MemRefCastOp were folded.
+    auto resultType = getCanonicalSubViewResultType(
subViewOp.getType().getRank(),
castOp.source().getType().cast<MemRefType>(),
subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
};
} // namespace
+/// Return the canonical type of the result of a subview.
+struct SubViewReturnTypeCanonicalizer {
+ MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
+ ArrayRef<OpFoldResult> mixedSizes,
+ ArrayRef<OpFoldResult> mixedStrides) {
+ return getCanonicalSubViewResultType(op.getType().getRank(),
+ op.getSourceType(), mixedOffsets,
+ mixedSizes, mixedStrides);
+ }
+};
+
/// A canonicalizer wrapper to replace SubViewOps.
struct SubViewCanonicalizer {
void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
- SubViewOp, SubViewCanonicalizer>,
- SubViewOpMemRefCastFolder>(context);
+ results
+ .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
+ SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
+ SubViewOpMemRefCastFolder>(context);
}
OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
// the subview op with load even if the offsets have been canonicalized
// away.
SmallVector<Range, 4> opRanges = subViewOp.getOrCreateRanges(rewriter, loc);
+    if (opRanges.size() != indices.size()) {
+      // For the rank-reduced cases, the folding can only be handled when each
+      // dropped dimension has offset 0, size 1, and stride 1, so bail out
+      // here.
+      return failure();
+    }
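+    // Illustrative rank-reduced case (hypothetical names): a load with a
+    // single index from
+    //   %v = memref.subview %src[0, %i, 0] [1, %sz, 1] [1, 1, 1]
+    //       : memref<?x?x?xf32> to memref<?xf32, #map>
+    // carries fewer indices than the subview has ranges and is rejected by
+    // the check above.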
auto opOffsets = llvm::map_range(opRanges, [](Range r) { return r.offset; });
auto opStrides = llvm::map_range(opRanges, [](Range r) { return r.stride; });
- assert(opRanges.size() == indices.size() &&
- "expected as many indices as rank of subview op result type");
// New indices for the load are the current indices * subview_stride +
// subview_offset.
return produceSubTensorErrorMsg(result, op, expectedType);
}
+/// Infer the canonical type of the result of a subtensor operation. Returns a
+/// type with rank `resultRank` that is either the rank-reduced type, or the
+/// non-rank-reduced type if rank reduction does not produce that rank.
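+///
+/// For example (illustrative): for a `tensor<?x?x?xf32>` source and sizes
+/// `[4, 1, %sz]`, the rank-reduced inference yields `tensor<4x?xf32>` when
+/// `resultRank` is 2, while the fallback yields `tensor<4x1x?xf32>` when
+/// `resultRank` is 3.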
+static RankedTensorType getCanonicalSubTensorResultType(
+ unsigned resultRank, RankedTensorType sourceType,
+ ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes,
+ ArrayRef<OpFoldResult> mixedStrides) {
+ auto resultType =
+ SubTensorOp::inferRankReducedResultType(
+ resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
+ .cast<RankedTensorType>();
+ if (resultType.getRank() != resultRank) {
+ resultType = SubTensorOp::inferResultType(sourceType, mixedOffsets,
+ mixedSizes, mixedStrides)
+ .cast<RankedTensorType>();
+ }
+ return resultType;
+}
+
namespace {
/// Pattern to rewrite a subtensor op with tensor::Cast arguments.
-/// This essentially pushes memref_cast past its consuming subtensor when
+/// This essentially pushes tensor.cast past its consuming subtensor when
if (!canFoldIntoConsumerOp(castOp))
return failure();
- /// Deduce the resultType of SubTensorOp with `inferRankReducedResultType`
- /// on the cast source operand type and the SubTensorOp static information.
- /// This is the resulting type if the tensor::CastOp were folded and
- /// rank-reduced to the desired result rank.
- auto resultType = SubTensorOp::inferRankReducedResultType(
- subTensorOp.getType().getRank(),
- castOp.source().getType().cast<RankedTensorType>(),
+ /// Deduce the type of the result to use for the canonicalized operation.
+ RankedTensorType resultType = getCanonicalSubTensorResultType(
+ subTensorOp.getType().getRank(), subTensorOp.getSourceType(),
subTensorOp.getMixedOffsets(), subTensorOp.getMixedSizes(),
subTensorOp.getMixedStrides());
Value newSubTensor = rewriter.create<SubTensorOp>(
};
} // namespace
+/// Return the canonical type of the result of a subtensor.
+struct SubTensorReturnTypeCanonicalizer {
+ RankedTensorType operator()(SubTensorOp op,
+ ArrayRef<OpFoldResult> mixedOffsets,
+ ArrayRef<OpFoldResult> mixedSizes,
+ ArrayRef<OpFoldResult> mixedStrides) {
+ return getCanonicalSubTensorResultType(op.getType().getRank(),
+ op.getSourceType(), mixedOffsets,
+ mixedSizes, mixedStrides);
+ }
+};
+
/// A canonicalizer wrapper to replace SubTensorOps.
struct SubTensorCanonicalizer {
void operator()(PatternRewriter &rewriter, SubTensorOp op,
void SubTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
- SubTensorOp, SubTensorCanonicalizer>,
+ SubTensorOp, SubTensorReturnTypeCanonicalizer,
+ SubTensorCanonicalizer>,
SubTensorOpCastFolder>(context);
}
canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
// Create the new op in canonical form.
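+    // For example (illustrative, with hypothetical SSA names): once the
+    // constant operands are folded,
+    //   subtensor_insert %s into %d[%c0, %i, %c1] [%c4, %c1, %sz] [%c1, %c1, %c1]
+    // is recreated as
+    //   subtensor_insert %s into %d[0, %i, 1] [4, 1, %sz] [1, 1, 1]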
- Value source = subTensorInsertOp.source();
- RankedTensorType sourceType = source.getType().cast<RankedTensorType>();
- SmallVector<int64_t, 4> shape = llvm::to_vector<4>(
- llvm::map_range(mixedSizes, [](OpFoldResult valueOrAttr) -> int64_t {
- if (auto attr = valueOrAttr.dyn_cast<Attribute>())
- return attr.cast<IntegerAttr>().getInt();
- return ShapedType::kDynamicSize;
- }));
- RankedTensorType newSourceType =
- RankedTensorType::get(shape, sourceType.getElementType());
- Location loc = subTensorInsertOp.getLoc();
- if (sourceType != newSourceType)
- source = rewriter.create<tensor::CastOp>(loc, newSourceType, source);
rewriter.replaceOpWithNewOp<SubTensorInsertOp>(
- subTensorInsertOp, source, subTensorInsertOp.dest(), mixedOffsets,
- mixedSizes, mixedStrides);
+ subTensorInsertOp, subTensorInsertOp.source(), subTensorInsertOp.dest(),
+ mixedOffsets, mixedSizes, mixedStrides);
return success();
}
};
SmallVectorImpl<OpAsmParser::OperandType> &caseOperands,
SmallVectorImpl<Type> &caseOperandTypes,
DenseIntElementsAttr &caseOperandOffsets) {
-
if (failed(parser.parseKeyword("default")) || failed(parser.parseColon()) ||
failed(parser.parseSuccessor(defaultDestination)))
return failure();
/// ]
static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
PatternRewriter &rewriter) {
-
SmallVector<Block *> newCaseDests;
SmallVector<ValueRange> newCaseOperands;
SmallVector<SmallVector<Value>> argStorage;
%1 = memref.buffer_cast %0 : memref<?xf32, offset: ?, strides: [1]>
return %1 : memref<?xf32, offset: ?, strides: [1]>
}
+
+// -----
+
+// CHECK-LABEL: func @subview_of_memcast
+// CHECK-SAME: %[[ARG0:[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
+// CHECK: %[[S:.+]] = memref.subview %[[ARG0]][0, 1, 0] [1, 1, 16] [1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>
+// CHECK: %[[M:.+]] = memref.cast %[[S]] : memref<16x32xi8, #{{.*}}> to memref<16x32xi8, #{{.*}}>
+// CHECK: return %[[M]] : memref<16x32xi8, #{{.*}}>
+func @subview_of_memcast(%arg : memref<4x6x16x32xi8>) ->
+ memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>{
+ %0 = memref.cast %arg : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
+ %1 = memref.subview %0[0, 1, 0] [1, 1, 16] [1, 1, 1] :
+ memref<?x?x16x32xi8> to
+ memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
+ return %1 : memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
+}
+
+// -----
+
+// CHECK-LABEL: func @subview_of_static_full_size
+// CHECK-SAME: %[[ARG0:.+]]: memref<4x6x16x32xi8>
+// CHECK-NOT: memref.subview
+// CHECK: return %[[ARG0]] : memref<4x6x16x32xi8>
+func @subview_of_static_full_size(%arg0 : memref<4x6x16x32xi8>) -> memref<4x6x16x32xi8> {
+ %0 = memref.subview %arg0[0, 0, 0, 0] [4, 6, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<4x6x16x32xi8>
+ return %0 : memref<4x6x16x32xi8>
+}
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>
+func @subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
+ %arg2 : index) -> memref<?x?x?xf32, #map0>
+{
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c4 = constant 4 : index
+ %0 = memref.subview %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : memref<?x?x?xf32> to memref<?x?x?xf32, #map0>
+ return %0 : memref<?x?x?xf32, #map0>
+}
+// CHECK-LABEL: func @subview_canonicalize
+// CHECK-SAME: %[[ARG0:.+]]: memref<?x?x?xf32>
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
+// CHECK-SAME: [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
+// CHECK-SAME: : memref<?x?x?xf32> to memref<4x1x?xf32
+// CHECK: %[[RESULT:.+]] = memref.cast %[[SUBVIEW]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+#map0 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+func @rank_reducing_subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
+ %arg2 : index) -> memref<?x?xf32, #map0>
+{
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c4 = constant 4 : index
+ %0 = memref.subview %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : memref<?x?x?xf32> to memref<?x?xf32, #map0>
+ return %0 : memref<?x?xf32, #map0>
+}
+// CHECK-LABEL: func @rank_reducing_subview_canonicalize
+// CHECK-SAME: %[[ARG0:.+]]: memref<?x?x?xf32>
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
+// CHECK-SAME: [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
+// CHECK-SAME: : memref<?x?x?xf32> to memref<4x?xf32
+// CHECK: %[[RESULT:.+]] = memref.cast %[[SUBVIEW]]
+// CHECK: return %[[RESULT]]
// -----
-// CHECK-LABEL: func @subview_of_memcast
-// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
-// CHECK: %[[S:.+]] = memref.subview %arg0[0, 1, 0] [1, 1, 16] [1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>
-// CHECK: %[[M:.+]] = memref.cast %[[S]] : memref<16x32xi8, #{{.*}}> to memref<16x32xi8, #{{.*}}>
-// CHECK: return %[[M]] : memref<16x32xi8, #{{.*}}>
-func @subview_of_memcast(%arg : memref<4x6x16x32xi8>) ->
- memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>{
- %0 = memref.cast %arg : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
- %1 = memref.subview %0[0, 1, 0] [1, 1, 16] [1, 1, 1] :
- memref<?x?x16x32xi8> to
- memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
- return %1 : memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
+func @subtensor_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
+ %arg2 : index) -> tensor<?x?x?xf32>
+{
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c4 = constant 4 : index
+ %0 = subtensor %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
}
+// CHECK-LABEL: func @subtensor_canonicalize
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
+// CHECK: %[[SUBTENSOR:.+]] = subtensor %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
+// CHECK-SAME: [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
+// CHECK-SAME: : tensor<?x?x?xf32> to tensor<4x1x?xf32>
+// CHECK: %[[RESULT:.+]] = tensor.cast %[[SUBTENSOR]]
+// CHECK: return %[[RESULT]]
// -----
-// CHECK-LABEL: func @subview_of_static_full_size
-// CHECK-SAME: %[[ARG0:.+]]: memref<4x6x16x32xi8>
-// CHECK-NOT: memref.subview
-// CHECK: return %[[ARG0]] : memref<4x6x16x32xi8>
-func @subview_of_static_full_size(%arg0 : memref<4x6x16x32xi8>) -> memref<4x6x16x32xi8> {
- %0 = memref.subview %arg0[0, 0, 0, 0] [4, 6, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<4x6x16x32xi8>
- return %0 : memref<4x6x16x32xi8>
+func @rank_reducing_subtensor_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
+ %arg2 : index) -> tensor<?x?xf32>
+{
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c4 = constant 4 : index
+ %0 = subtensor %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
}
+// CHECK-LABEL: func @rank_reducing_subtensor_canonicalize
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
+// CHECK: %[[SUBTENSOR:.+]] = subtensor %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
+// CHECK-SAME: [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
+// CHECK-SAME: : tensor<?x?x?xf32> to tensor<4x?xf32>
+// CHECK: %[[RESULT:.+]] = tensor.cast %[[SUBTENSOR]]
+// CHECK: return %[[RESULT]]
// -----
// -----
-func @subtensor_canonicalize(%arg0 : tensor<2x?xi32>, %arg1 : tensor<i32>,
+func @subtensor_insert_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
+ %arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+{
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c4 = constant 4 : index
+ %0 = subtensor_insert %arg0 into %arg3[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> into tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: func @subtensor_insert_canonicalize
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK: %[[RESULT:.+]] = subtensor_insert %[[ARG0]]
+// CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
+// CHECK-SAME: : tensor<?x?x?xf32> into tensor<?x?x?xf32>
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @subtensor_to_subtensor_insert_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
+ %arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+{
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c4 = constant 4 : index
+ %0 = subtensor %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+ %1 = subtensor_insert %0 into %arg3[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> into tensor<?x?x?xf32>
+ return %1 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: func @subtensor_to_subtensor_insert_canonicalize
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK: %[[SUBTENSOR:.+]] = subtensor %[[ARG0]]
+// CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
+// CHECK-SAME: : tensor<?x?x?xf32> to tensor<4x1x?xf32>
+// CHECK: %[[RESULT:.+]] = subtensor_insert %[[SUBTENSOR]] into %[[ARG3]]
+// CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
+// CHECK-SAME: : tensor<4x1x?xf32> into tensor<?x?x?xf32>
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @rank_reducing_subtensor_insert_canonicalize(%arg0 : tensor<?x?xf32>, %arg1 : index,
+ %arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+{
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c4 = constant 4 : index
+ %0 = subtensor_insert %arg0 into %arg3[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?xf32> into tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: func @rank_reducing_subtensor_insert_canonicalize
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
+// CHECK: %[[RESULT:.+]] = subtensor_insert %[[ARG0]]
+// CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
+// CHECK-SAME: : tensor<?x?xf32> into tensor<?x?x?xf32>
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @rank_reducing_subtensor_to_subtensor_insert_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
+ %arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+{
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c4 = constant 4 : index
+ %0 = subtensor %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?xf32>
+ %1 = subtensor_insert %0 into %arg3[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?xf32> into tensor<?x?x?xf32>
+ return %1 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: func @rank_reducing_subtensor_to_subtensor_insert_canonicalize
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK: %[[SUBTENSOR:.+]] = subtensor %[[ARG0]]
+// CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
+// CHECK-SAME: : tensor<?x?x?xf32> to tensor<4x?xf32>
+// CHECK: %[[RESULT:.+]] = subtensor_insert %[[SUBTENSOR]] into %[[ARG3]]
+// CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
+// CHECK-SAME: : tensor<4x?xf32> into tensor<?x?x?xf32>
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @subtensor_insert_propagate_dest_cast(%arg0 : tensor<2x?xi32>, %arg1 : tensor<i32>,
%arg2 : index, %arg3 : index) -> tensor<?x?xi32> {
%c0 = constant 0 : index
%c1 = constant 1 : index
%3 = subtensor_insert %arg0 into %2[%c0, %arg3] [%c2, %0] [%c1, %c1] : tensor<2x?xi32> into tensor<?x?xi32>
return %3 : tensor<?x?xi32>
}
-// CHECK-LABEL: func @subtensor_canonicalize
+// CHECK-LABEL: func @subtensor_insert_propagate_dest_cast
// CHECK: %[[UPDATED:.+]] = subtensor_insert %{{.+}} into %{{.+}}[0, %{{.+}}] [2, %{{.+}}] [1, 1]
// CHECK-SAME: tensor<2x?xi32> into tensor<?x8xi32>
// CHECK: %[[CAST:.+]] = tensor.cast %[[UPDATED]]