From: MaheshRavishankar
Date: Wed, 28 Apr 2021 18:01:22 +0000 (-0700)
Subject: [mlir][Linalg] Avoid changing the rank of the result in canonicalizations of subtensor.
X-Git-Tag: llvmorg-14-init~8166
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=41849a91956755e15591240816d1d0c5ec402895;p=platform%2Fupstream%2Fllvm.git

[mlir][Linalg] Avoid changing the rank of the result in canonicalizations of subtensor.

Canonicalizations for subtensor operations defaulted to use the rank-reduced
version of the operation, but the cast inserted to get back the original type
would be illegal if the rank was actually reduced, since a cast cannot change
the rank. Instead, make the canonicalization not reduce the rank of the
operation.

Differential Revision: https://reviews.llvm.org/D101258
---

diff --git a/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h b/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h
index 79d8f55..c9a2f5a 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h
@@ -35,7 +35,7 @@ void getPositionsOfShapeOne(unsigned rank, ArrayRef shape,
                             llvm::SmallDenseSet &dimsToProject);

 /// Pattern to rewrite a subview op with constant arguments.
-template
+template
 class OpWithOffsetSizesAndStridesConstantArgumentFolder final
     : public OpRewritePattern {
 public:
@@ -59,8 +59,12 @@ public:
     canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);

     // Create the new op in canonical form.
-    auto newOp = rewriter.create(op.getLoc(), op.source(), mixedOffsets,
-                                 mixedSizes, mixedStrides);
+    ResultTypeFunc resultTypeFunc;
+    auto resultType =
+        resultTypeFunc(op, mixedOffsets, mixedSizes, mixedStrides);
+    auto newOp =
+        rewriter.create(op.getLoc(), resultType, op.source(),
+                        mixedOffsets, mixedSizes, mixedStrides);
     CastOpFunc func;
     func(rewriter, op, newOp);

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1ac0002..57c1b15 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1859,6 +1859,26 @@ SmallVector mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
   return res;
 }

+/// Infer the canonical type of the result of a subview operation. Returns a
+/// type with rank `resultRank` that is either the rank of the rank-reduced
+/// type, or the non-rank-reduced type.
+static MemRefType
+getCanonicalSubViewResultType(unsigned resultRank, MemRefType sourceType,
+                              ArrayRef mixedOffsets,
+                              ArrayRef mixedSizes,
+                              ArrayRef mixedStrides) {
+  auto resultType =
+      SubViewOp::inferRankReducedResultType(
+          resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
+          .cast();
+  if (resultType.getRank() != resultRank) {
+    resultType = SubViewOp::inferResultType(sourceType, mixedOffsets,
+                                            mixedSizes, mixedStrides)
+                     .cast();
+  }
+  return resultType;
+}
+
 namespace {
 /// Pattern to rewrite a subview op with MemRefCast arguments.
 /// This essentially pushes memref.cast past its consuming subview when
@@ -1898,7 +1918,7 @@ public:
     /// Deduce the resultType of the SubViewOp using `inferSubViewResultType` on
     /// the cast source operand type and the SubViewOp static information. This
     /// is the resulting type if the MemRefCastOp were folded.
-    auto resultType = SubViewOp::inferRankReducedResultType(
+    auto resultType = getCanonicalSubViewResultType(
         subViewOp.getType().getRank(),
         castOp.source().getType().cast(),
         subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
@@ -1914,6 +1934,17 @@ public:
 };
 } // namespace

+/// Return the canonical type of the result of a subview.
+struct SubViewReturnTypeCanonicalizer {
+  MemRefType operator()(SubViewOp op, ArrayRef mixedOffsets,
+                        ArrayRef mixedSizes,
+                        ArrayRef mixedStrides) {
+    return getCanonicalSubViewResultType(op.getType().getRank(),
+                                         op.getSourceType(), mixedOffsets,
+                                         mixedSizes, mixedStrides);
+  }
+};
+
 /// A canonicalizer wrapper to replace SubViewOps.
 struct SubViewCanonicalizer {
   void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
@@ -1923,9 +1954,10 @@ struct SubViewCanonicalizer {

 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add,
-              SubViewOpMemRefCastFolder>(context);
+  results
+      .add,
+           SubViewOpMemRefCastFolder>(context);
 }

 OpFoldResult SubViewOp::fold(ArrayRef operands) {

diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
index bd1fef0..ae76966 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
@@ -45,10 +45,13 @@ resolveSourceIndices(Location loc, PatternRewriter &rewriter,
   // the subview op with load even if the offsets have been canonicalized
   // away.
   SmallVector opRanges = subViewOp.getOrCreateRanges(rewriter, loc);
+  if (opRanges.size() != indices.size()) {
+    // For the rank-reduced cases, we can only handle the folding when the
+    // offset is zero, size is 1 and stride is 1.
+    return failure();
+  }
   auto opOffsets = llvm::map_range(opRanges, [](Range r) { return r.offset; });
   auto opStrides = llvm::map_range(opRanges, [](Range r) { return r.stride; });
-  assert(opRanges.size() == indices.size() &&
-         "expected as many indices as rank of subview op result type");

   // New indices for the load are the current indices * subview_stride +
   // subview_offset.

diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index b538cba..9260e27 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1917,6 +1917,25 @@ static LogicalResult verify(SubTensorOp op) {
   return produceSubTensorErrorMsg(result, op, expectedType);
 }

+/// Infer the canonical type of the result of a subtensor operation. Returns a
+/// type with rank `resultRank` that is either the rank of the rank-reduced
+/// type, or the non-rank-reduced type.
+static RankedTensorType getCanonicalSubTensorResultType(
+    unsigned resultRank, RankedTensorType sourceType,
+    ArrayRef mixedOffsets, ArrayRef mixedSizes,
+    ArrayRef mixedStrides) {
+  auto resultType =
+      SubTensorOp::inferRankReducedResultType(
+          resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
+          .cast();
+  if (resultType.getRank() != resultRank) {
+    resultType = SubTensorOp::inferResultType(sourceType, mixedOffsets,
+                                              mixedSizes, mixedStrides)
+                     .cast();
+  }
+  return resultType;
+}
+
 namespace {
 /// Pattern to rewrite a subtensor op with tensor::Cast arguments.
 /// This essentially pushes memref_cast past its consuming subtensor when
@@ -1951,13 +1970,9 @@ public:
     if (!canFoldIntoConsumerOp(castOp))
       return failure();

-    /// Deduce the resultType of SubTensorOp with `inferRankReducedResultType`
-    /// on the cast source operand type and the SubTensorOp static information.
-    /// This is the resulting type if the tensor::CastOp were folded and
-    /// rank-reduced to the desired result rank.
-    auto resultType = SubTensorOp::inferRankReducedResultType(
-        subTensorOp.getType().getRank(),
-        castOp.source().getType().cast(),
+    /// Deduce the type of the result to use for the canonicalized operation.
+    RankedTensorType resultType = getCanonicalSubTensorResultType(
+        subTensorOp.getType().getRank(), subTensorOp.getSourceType(),
         subTensorOp.getMixedOffsets(), subTensorOp.getMixedSizes(),
         subTensorOp.getMixedStrides());
     Value newSubTensor = rewriter.create(
@@ -1972,6 +1987,18 @@ public:
 };
 } // namespace

+/// Return the canonical type of the result of a subtensor.
+struct SubTensorReturnTypeCanonicalizer {
+  RankedTensorType operator()(SubTensorOp op,
+                              ArrayRef mixedOffsets,
+                              ArrayRef mixedSizes,
+                              ArrayRef mixedStrides) {
+    return getCanonicalSubTensorResultType(op.getType().getRank(),
+                                           op.getSourceType(), mixedOffsets,
+                                           mixedSizes, mixedStrides);
+  }
+};
+
 /// A canonicalizer wrapper to replace SubTensorOps.
 struct SubTensorCanonicalizer {
   void operator()(PatternRewriter &rewriter, SubTensorOp op,
@@ -1987,7 +2014,8 @@ struct SubTensorCanonicalizer {
 void SubTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
   results.add,
+                  SubTensorOp, SubTensorReturnTypeCanonicalizer,
+                  SubTensorCanonicalizer>,
               SubTensorOpCastFolder>(context);
 }

@@ -2093,22 +2121,9 @@ public:
     canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);

     // Create the new op in canonical form.
-      Value source = subTensorInsertOp.source();
-      RankedTensorType sourceType = source.getType().cast();
-      SmallVector shape = llvm::to_vector<4>(
-          llvm::map_range(mixedSizes, [](OpFoldResult valueOrAttr) -> int64_t {
-            if (auto attr = valueOrAttr.dyn_cast())
-              return attr.cast().getInt();
-            return ShapedType::kDynamicSize;
-          }));
-      RankedTensorType newSourceType =
-          RankedTensorType::get(shape, sourceType.getElementType());
-      Location loc = subTensorInsertOp.getLoc();
-      if (sourceType != newSourceType)
-        source = rewriter.create(loc, newSourceType, source);
       rewriter.replaceOpWithNewOp(
-          subTensorInsertOp, source, subTensorInsertOp.dest(), mixedOffsets,
-          mixedSizes, mixedStrides);
+          subTensorInsertOp, subTensorInsertOp.source(), subTensorInsertOp.dest(),
+          mixedOffsets, mixedSizes, mixedStrides);
       return success();
     }
   };
@@ -2213,7 +2228,6 @@ parseSwitchOpCases(OpAsmParser &parser, Type &flagType,
                    SmallVectorImpl &caseOperands,
                    SmallVectorImpl &caseOperandTypes,
                    DenseIntElementsAttr &caseOperandOffsets) {
-
   if (failed(parser.parseKeyword("default")) || failed(parser.parseColon()) ||
       failed(parser.parseSuccessor(defaultDestination)))
     return failure();
@@ -2457,7 +2471,6 @@ static LogicalResult simplifyConstSwitchValue(SwitchOp op,
 ///    ]
 static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
                                                PatternRewriter &rewriter) {
-
   SmallVector newCaseDests;
   SmallVector newCaseOperands;
   SmallVector> argStorage;

diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 7ff5c6f..0b0308f 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -62,3 +62,70 @@ func @canonicalize_buffer_cast_of_tensor_load(%arg0: memref
   return %1 : memref
 }
+
+// -----
+
+// CHECK-LABEL: func @subview_of_memcast
+// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
+// CHECK: %[[S:.+]] = memref.subview %arg0[0, 1, 0] [1, 1, 16] [1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>
+// CHECK: %[[M:.+]] = memref.cast %[[S]] : memref<16x32xi8, #{{.*}}> to memref<16x32xi8, #{{.*}}>
+// CHECK: return %[[M]] : memref<16x32xi8, #{{.*}}>
+func @subview_of_memcast(%arg : memref<4x6x16x32xi8>) ->
+  memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>{
+  %0 = memref.cast %arg : memref<4x6x16x32xi8> to memref
+  %1 = memref.subview %0[0, 1, 0] [1, 1, 16] [1, 1, 1] :
+    memref to
+    memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
+  return %1 : memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
+}
+
+// -----
+
+// CHECK-LABEL: func @subview_of_static_full_size
+// CHECK-SAME: %[[ARG0:.+]]: memref<4x6x16x32xi8>
+// CHECK-NOT: memref.subview
+// CHECK: return %[[ARG0]] : memref<4x6x16x32xi8>
+func @subview_of_static_full_size(%arg0 : memref<4x6x16x32xi8>) -> memref<4x6x16x32xi8> {
+  %0 = memref.subview %arg0[0, 0, 0, 0] [4, 6, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<4x6x16x32xi8>
+  return %0 : memref<4x6x16x32xi8>
+}
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>
+func @subview_canonicalize(%arg0 : memref, %arg1 : index,
+    %arg2 : index) -> memref
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %0 = memref.subview %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : memref to memref
+  return %0 : memref
+}
+// CHECK-LABEL: func @subview_canonicalize
+// CHECK-SAME: %[[ARG0:.+]]: memref
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
+// CHECK-SAME: [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
+// CHECK-SAME: : memref to memref<4x1x?xf32
+// CHECK: %[[RESULT:.+]] = memref.cast %[[SUBVIEW]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+#map0 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+func @rank_reducing_subview_canonicalize(%arg0 : memref, %arg1 : index,
+    %arg2 : index) -> memref
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %0 = memref.subview %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : memref to memref
+  return %0 : memref
+}
+// CHECK-LABEL: func @rank_reducing_subview_canonicalize
+// CHECK-SAME: %[[ARG0:.+]]: memref
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
+// CHECK-SAME: [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
+// CHECK-SAME: : memref to memref<4x?xf32
+// CHECK: %[[RESULT:.+]] = memref.cast %[[SUBVIEW]]
+// CHECK: return %[[RESULT]]

diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index 908814b..e2b5e7b 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -154,30 +154,41 @@ func @tensor_cast_to_memref(%arg0 : tensor<4x6x16x32xi8>) ->

 // -----

-// CHECK-LABEL: func @subview_of_memcast
-// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
-// CHECK: %[[S:.+]] = memref.subview %arg0[0, 1, 0] [1, 1, 16] [1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>
-// CHECK: %[[M:.+]] = memref.cast %[[S]] : memref<16x32xi8, #{{.*}}> to memref<16x32xi8, #{{.*}}>
-// CHECK: return %[[M]] : memref<16x32xi8, #{{.*}}>
-func @subview_of_memcast(%arg : memref<4x6x16x32xi8>) ->
-  memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>{
-  %0 = memref.cast %arg : memref<4x6x16x32xi8> to memref
-  %1 = memref.subview %0[0, 1, 0] [1, 1, 16] [1, 1, 1] :
-    memref to
-    memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
-  return %1 : memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
+func @subtensor_canonicalize(%arg0 : tensor, %arg1 : index,
+    %arg2 : index) -> tensor
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %0 = subtensor %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor to tensor
+  return %0 : tensor
 }
+// CHECK-LABEL: func @subtensor_canonicalize
+// CHECK-SAME: %[[ARG0:.+]]: tensor
+// CHECK: %[[SUBTENSOR:.+]] = subtensor %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
+// CHECK-SAME: [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
+// CHECK-SAME: : tensor to tensor<4x1x?xf32>
+// CHECK: %[[RESULT:.+]] = tensor.cast %[[SUBTENSOR]]
+// CHECK: return %[[RESULT]]

 // -----

-// CHECK-LABEL: func @subview_of_static_full_size
-// CHECK-SAME: %[[ARG0:.+]]: memref<4x6x16x32xi8>
-// CHECK-NOT: memref.subview
-// CHECK: return %[[ARG0]] : memref<4x6x16x32xi8>
-func @subview_of_static_full_size(%arg0 : memref<4x6x16x32xi8>) -> memref<4x6x16x32xi8> {
-  %0 = memref.subview %arg0[0, 0, 0, 0] [4, 6, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<4x6x16x32xi8>
-  return %0 : memref<4x6x16x32xi8>
+func @rank_reducing_subtensor_canonicalize(%arg0 : tensor, %arg1 : index,
+    %arg2 : index) -> tensor
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %0 = subtensor %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor to tensor
+  return %0 : tensor
 }
+// CHECK-LABEL: func @rank_reducing_subtensor_canonicalize
+// CHECK-SAME: %[[ARG0:.+]]: tensor
+// CHECK: %[[SUBTENSOR:.+]] = subtensor %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
+// CHECK-SAME: [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
+// CHECK-SAME: : tensor to tensor<4x?xf32>
+// CHECK: %[[RESULT:.+]] = tensor.cast %[[SUBTENSOR]]
+// CHECK: return %[[RESULT]]

 // -----

@@ -232,7 +243,89 @@ func @rank_reducing_subtensor_insert_of_cast(%a : tensor<16x32xi8>, %b : tensor<

 // -----

-func @subtensor_canonicalize(%arg0 : tensor<2x?xi32>, %arg1 : tensor,
+func @subtensor_insert_canonicalize(%arg0 : tensor, %arg1 : index,
+    %arg2 : index, %arg3 : tensor) -> tensor
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %0 = subtensor_insert %arg0 into %arg3[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor into tensor
+  return %0 : tensor
+}
+// CHECK-LABEL: func @subtensor_insert_canonicalize
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor
+// CHECK: %[[RESULT:.+]] = subtensor_insert %[[ARG0]]
+// CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
+// CHECK-SAME: : tensor into tensor
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @subtensor_to_subtensor_insert_canonicalize(%arg0 : tensor, %arg1 : index,
+    %arg2 : index, %arg3 : tensor) -> tensor
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %0 = subtensor %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor to tensor
+  %1 = subtensor_insert %0 into %arg3[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor into tensor
+  return %1 : tensor
+}
+// CHECK-LABEL: func @subtensor_to_subtensor_insert_canonicalize
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor
+// CHECK: %[[SUBTENSOR:.+]] = subtensor %[[ARG0]]
+// CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}} [1, 1, 1]
+// CHECK-SAME: : tensor to tensor<4x1x?xf32>
+// CHECK: %[[RESULT:.+]] = subtensor_insert %[[SUBTENSOR]]
+// CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
+// CHECK-SAME: : tensor<4x1x?xf32> into tensor
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @rank_reducing_subtensor_insert_canonicalize(%arg0 : tensor, %arg1 : index,
+    %arg2 : index, %arg3 : tensor) -> tensor
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %0 = subtensor_insert %arg0 into %arg3[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor into tensor
+  return %0 : tensor
+}
+// CHECK-LABEL: func @rank_reducing_subtensor_insert_canonicalize
+// CHECK-SAME: %[[ARG0:.+]]: tensor
+// CHECK: %[[RESULT:.+]] = subtensor_insert %[[ARG0]]
+// CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
+// CHECK-SAME: : tensor into tensor
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @rank_reducing_subtensor_to_subtensor_insert_canonicalize(%arg0 : tensor, %arg1 : index,
+    %arg2 : index, %arg3 : tensor) -> tensor
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %0 = subtensor %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor to tensor
+  %1 = subtensor_insert %0 into %arg3[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor into tensor
+  return %1 : tensor
+}
+// CHECK-LABEL: func @rank_reducing_subtensor_to_subtensor_insert_canonicalize
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor
+// CHECK: %[[SUBTENSOR:.+]] = subtensor %[[ARG0]]
+// CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
+// CHECK-SAME: : tensor to tensor<4x?xf32>
+// CHECK: %[[RESULT:.+]] = subtensor_insert %[[SUBTENSOR]] into %[[ARG3]]
+// CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
+// CHECK-SAME: : tensor<4x?xf32> into tensor
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @subtensor_insert_propagate_dest_cast(%arg0 : tensor<2x?xi32>, %arg1 : tensor,
     %arg2 : index, %arg3 : index) -> tensor {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
@@ -247,7 +340,7 @@ func @subtensor_canonicalize(%arg0 : tensor<2x?xi32>, %arg1 : tensor,
   %3 = subtensor_insert %arg0 into %2[%c0, %arg3] [%c2, %0] [%c1, %c1] : tensor<2x?xi32> into tensor
   return %3 : tensor
 }
-// CHECK-LABEL: func @subtensor_canonicalize
+// CHECK-LABEL: func @subtensor_insert_propagate_dest_cast
 // CHECK: %[[UPDATED:.+]] = subtensor_insert %{{.+}} into %{{.+}}[0, %{{.+}}] [2, %{{.+}}] [1, 1]
 // CHECK-SAME: tensor<2x?xi32> into tensor
 // CHECK: %[[CAST:.+]] = tensor.cast %[[UPDATED]]
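
As an illustration of the new behavior (a sketch only, assuming a fully dynamic
rank-3 source of type tensor<?x?x?xf32>, as in the subtensor_canonicalize test
above), canonicalizing

  %0 = subtensor %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1]
      : tensor<?x?x?xf32> to tensor<?x?x?xf32>

now folds the constant operands into static attributes while keeping the unit
dimension, and recovers the original result type with a cast instead of
producing a rank-reduced subtensor:

  %0 = subtensor %arg0[0, %arg1, 1] [4, 1, %arg2] [1, 1, 1]
      : tensor<?x?x?xf32> to tensor<4x1x?xf32>
  %1 = tensor.cast %0 : tensor<4x1x?xf32> to tensor<?x?x?xf32>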