[mlir][Linalg] Avoid changing the rank of the result in canonicalizations of subtensor.
authorMaheshRavishankar <ravishankarm@google.com>
Wed, 28 Apr 2021 18:01:22 +0000 (11:01 -0700)
committerMaheshRavishankar <ravishankarm@google.com>
Wed, 28 Apr 2021 18:33:26 +0000 (11:33 -0700)
Canonicalizations for subtensor operations defaulted to use the
rank-reduced version of the operation, but the cast inserted to get
back the original type would be illegal if the rank was actually
reduced. Instead make the canonicalization not reduce the rank of the
operation.

Differential Revision: https://reviews.llvm.org/D101258

mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Dialect/MemRef/canonicalize.mlir
mlir/test/Dialect/Standard/canonicalize.mlir

index 79d8f55..c9a2f5a 100644 (file)
@@ -35,7 +35,7 @@ void getPositionsOfShapeOne(unsigned rank, ArrayRef<int64_t> shape,
                             llvm::SmallDenseSet<unsigned> &dimsToProject);
 
 /// Pattern to rewrite a subview op with constant arguments.
-template <typename OpType, typename CastOpFunc>
+template <typename OpType, typename ResultTypeFunc, typename CastOpFunc>
 class OpWithOffsetSizesAndStridesConstantArgumentFolder final
     : public OpRewritePattern<OpType> {
 public:
@@ -59,8 +59,12 @@ public:
     canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
 
     // Create the new op in canonical form.
-    auto newOp = rewriter.create<OpType>(op.getLoc(), op.source(), mixedOffsets,
-                                         mixedSizes, mixedStrides);
+    ResultTypeFunc resultTypeFunc;
+    auto resultType =
+        resultTypeFunc(op, mixedOffsets, mixedSizes, mixedStrides);
+    auto newOp =
+        rewriter.create<OpType>(op.getLoc(), resultType, op.source(),
+                                mixedOffsets, mixedSizes, mixedStrides);
     CastOpFunc func;
     func(rewriter, op, newOp);
 
index 1ac0002..57c1b15 100644 (file)
@@ -1859,6 +1859,26 @@ SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
   return res;
 }
 
+/// Infer the canonical type of the result of a subview operation. Returns a
+/// type with rank `resultRank` that is either the rank of the rank-reduced
+/// type, or the non-rank-reduced type.
+static MemRefType
+getCanonicalSubViewResultType(unsigned resultRank, MemRefType sourceType,
+                              ArrayRef<OpFoldResult> mixedOffsets,
+                              ArrayRef<OpFoldResult> mixedSizes,
+                              ArrayRef<OpFoldResult> mixedStrides) {
+  auto resultType =
+      SubViewOp::inferRankReducedResultType(
+          resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
+          .cast<MemRefType>();
+  if (resultType.getRank() != resultRank) {
+    resultType = SubViewOp::inferResultType(sourceType, mixedOffsets,
+                                            mixedSizes, mixedStrides)
+                     .cast<MemRefType>();
+  }
+  return resultType;
+}
+
 namespace {
 /// Pattern to rewrite a subview op with MemRefCast arguments.
 /// This essentially pushes memref.cast past its consuming subview when
@@ -1898,7 +1918,7 @@ public:
     /// Deduce the resultType of the SubViewOp using `inferSubViewResultType` on
     /// the cast source operand type and the SubViewOp static information. This
     /// is the resulting type if the MemRefCastOp were folded.
-    auto resultType = SubViewOp::inferRankReducedResultType(
+    auto resultType = getCanonicalSubViewResultType(
         subViewOp.getType().getRank(),
         castOp.source().getType().cast<MemRefType>(),
         subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
@@ -1914,6 +1934,17 @@ public:
 };
 } // namespace
 
+/// Return the canonical type of the result of a subview.
+struct SubViewReturnTypeCanonicalizer {
+  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
+                        ArrayRef<OpFoldResult> mixedSizes,
+                        ArrayRef<OpFoldResult> mixedStrides) {
+    return getCanonicalSubViewResultType(op.getType().getRank(),
+                                         op.getSourceType(), mixedOffsets,
+                                         mixedSizes, mixedStrides);
+  }
+};
+
 /// A canonicalizer wrapper to replace SubViewOps.
 struct SubViewCanonicalizer {
   void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
@@ -1923,9 +1954,10 @@ struct SubViewCanonicalizer {
 
 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
-                  SubViewOp, SubViewCanonicalizer>,
-              SubViewOpMemRefCastFolder>(context);
+  results
+      .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
+               SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
+           SubViewOpMemRefCastFolder>(context);
 }
 
 OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
index bd1fef0..ae76966 100644 (file)
@@ -45,10 +45,13 @@ resolveSourceIndices(Location loc, PatternRewriter &rewriter,
   // the subview op with load even if the offsets have been canonicalized
   // away.
   SmallVector<Range, 4> opRanges = subViewOp.getOrCreateRanges(rewriter, loc);
+  if (opRanges.size() != indices.size()) {
+    // For the rank-reduced cases, we can only handle the folding when the
+    // offset is zero, size is 1 and stride is 1.
+    return failure();
+  }
   auto opOffsets = llvm::map_range(opRanges, [](Range r) { return r.offset; });
   auto opStrides = llvm::map_range(opRanges, [](Range r) { return r.stride; });
-  assert(opRanges.size() == indices.size() &&
-         "expected as many indices as rank of subview op result type");
 
   // New indices for the load are the current indices * subview_stride +
   // subview_offset.
index b538cba..9260e27 100644 (file)
@@ -1917,6 +1917,25 @@ static LogicalResult verify(SubTensorOp op) {
   return produceSubTensorErrorMsg(result, op, expectedType);
 }
 
+/// Infer the canonical type of the result of a subtensor operation. Returns a
+/// type with rank `resultRank` that is either the rank of the rank-reduced
+/// type, or the non-rank-reduced type.
+static RankedTensorType getCanonicalSubTensorResultType(
+    unsigned resultRank, RankedTensorType sourceType,
+    ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes,
+    ArrayRef<OpFoldResult> mixedStrides) {
+  auto resultType =
+      SubTensorOp::inferRankReducedResultType(
+          resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
+          .cast<RankedTensorType>();
+  if (resultType.getRank() != resultRank) {
+    resultType = SubTensorOp::inferResultType(sourceType, mixedOffsets,
+                                              mixedSizes, mixedStrides)
+                     .cast<RankedTensorType>();
+  }
+  return resultType;
+}
+
 namespace {
 /// Pattern to rewrite a subtensor op with tensor::Cast arguments.
 /// This essentially pushes memref_cast past its consuming subtensor when
@@ -1951,13 +1970,9 @@ public:
     if (!canFoldIntoConsumerOp(castOp))
       return failure();
 
-    /// Deduce the resultType of SubTensorOp with `inferRankReducedResultType`
-    /// on the cast source operand type and the SubTensorOp static information.
-    /// This is the resulting type if the tensor::CastOp were folded and
-    /// rank-reduced to the desired result rank.
-    auto resultType = SubTensorOp::inferRankReducedResultType(
-        subTensorOp.getType().getRank(),
-        castOp.source().getType().cast<RankedTensorType>(),
+    /// Deduce the type of the result to use for the canonicalized operation.
+    RankedTensorType resultType = getCanonicalSubTensorResultType(
+        subTensorOp.getType().getRank(), subTensorOp.getSourceType(),
         subTensorOp.getMixedOffsets(), subTensorOp.getMixedSizes(),
         subTensorOp.getMixedStrides());
     Value newSubTensor = rewriter.create<SubTensorOp>(
@@ -1972,6 +1987,18 @@ public:
 };
 } // namespace
 
+/// Return the canonical type of the result of a subtensor.
+struct SubTensorReturnTypeCanonicalizer {
+  RankedTensorType operator()(SubTensorOp op,
+                              ArrayRef<OpFoldResult> mixedOffsets,
+                              ArrayRef<OpFoldResult> mixedSizes,
+                              ArrayRef<OpFoldResult> mixedStrides) {
+    return getCanonicalSubTensorResultType(op.getType().getRank(),
+                                           op.getSourceType(), mixedOffsets,
+                                           mixedSizes, mixedStrides);
+  }
+};
+
 /// A canonicalizer wrapper to replace SubTensorOps.
 struct SubTensorCanonicalizer {
   void operator()(PatternRewriter &rewriter, SubTensorOp op,
@@ -1987,7 +2014,8 @@ struct SubTensorCanonicalizer {
 void SubTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
   results.add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
-                  SubTensorOp, SubTensorCanonicalizer>,
+                  SubTensorOp, SubTensorReturnTypeCanonicalizer,
+                  SubTensorCanonicalizer>,
               SubTensorOpCastFolder>(context);
 }
 
@@ -2093,22 +2121,9 @@ public:
     canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
 
     // Create the new op in canonical form.
-    Value source = subTensorInsertOp.source();
-    RankedTensorType sourceType = source.getType().cast<RankedTensorType>();
-    SmallVector<int64_t, 4> shape = llvm::to_vector<4>(
-        llvm::map_range(mixedSizes, [](OpFoldResult valueOrAttr) -> int64_t {
-          if (auto attr = valueOrAttr.dyn_cast<Attribute>())
-            return attr.cast<IntegerAttr>().getInt();
-          return ShapedType::kDynamicSize;
-        }));
-    RankedTensorType newSourceType =
-        RankedTensorType::get(shape, sourceType.getElementType());
-    Location loc = subTensorInsertOp.getLoc();
-    if (sourceType != newSourceType)
-      source = rewriter.create<tensor::CastOp>(loc, newSourceType, source);
     rewriter.replaceOpWithNewOp<SubTensorInsertOp>(
-        subTensorInsertOp, source, subTensorInsertOp.dest(), mixedOffsets,
-        mixedSizes, mixedStrides);
+        subTensorInsertOp, subTensorInsertOp.source(), subTensorInsertOp.dest(),
+        mixedOffsets, mixedSizes, mixedStrides);
     return success();
   }
 };
@@ -2213,7 +2228,6 @@ parseSwitchOpCases(OpAsmParser &parser, Type &flagType,
                    SmallVectorImpl<OpAsmParser::OperandType> &caseOperands,
                    SmallVectorImpl<Type> &caseOperandTypes,
                    DenseIntElementsAttr &caseOperandOffsets) {
-
   if (failed(parser.parseKeyword("default")) || failed(parser.parseColon()) ||
       failed(parser.parseSuccessor(defaultDestination)))
     return failure();
@@ -2457,7 +2471,6 @@ static LogicalResult simplifyConstSwitchValue(SwitchOp op,
 /// ]
 static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
                                                PatternRewriter &rewriter) {
-
   SmallVector<Block *> newCaseDests;
   SmallVector<ValueRange> newCaseOperands;
   SmallVector<SmallVector<Value>> argStorage;
index 7ff5c6f..0b0308f 100644 (file)
@@ -62,3 +62,70 @@ func @canonicalize_buffer_cast_of_tensor_load(%arg0: memref<?xf32, offset: 3, st
   %1 = memref.buffer_cast %0 : memref<?xf32, offset: ?, strides: [1]>
   return %1 : memref<?xf32, offset: ?, strides: [1]>
 }
+
+// -----
+
+// CHECK-LABEL: func @subview_of_memcast
+//  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
+//       CHECK:   %[[S:.+]] = memref.subview %arg0[0, 1, 0] [1, 1, 16] [1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>
+//       CHECK:   %[[M:.+]] = memref.cast %[[S]] : memref<16x32xi8, #{{.*}}> to memref<16x32xi8, #{{.*}}>
+//       CHECK:   return %[[M]] : memref<16x32xi8, #{{.*}}>
+func @subview_of_memcast(%arg : memref<4x6x16x32xi8>) ->
+  memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>{
+  %0 = memref.cast %arg : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
+  %1 = memref.subview %0[0, 1, 0] [1, 1, 16] [1, 1, 1] :
+    memref<?x?x16x32xi8> to
+    memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
+  return %1 : memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
+}
+
+// -----
+
+// CHECK-LABEL: func @subview_of_static_full_size
+// CHECK-SAME: %[[ARG0:.+]]: memref<4x6x16x32xi8>
+// CHECK-NOT: memref.subview
+// CHECK: return %[[ARG0]] : memref<4x6x16x32xi8>
+func @subview_of_static_full_size(%arg0 : memref<4x6x16x32xi8>) -> memref<4x6x16x32xi8> {
+  %0 = memref.subview %arg0[0, 0, 0, 0] [4, 6, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<4x6x16x32xi8>
+  return %0 : memref<4x6x16x32xi8>
+}
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>
+func @subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
+    %arg2 : index) -> memref<?x?x?xf32, #map0>
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %0 = memref.subview %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : memref<?x?x?xf32> to memref<?x?x?xf32, #map0>
+  return %0 : memref<?x?x?xf32, #map0>
+}
+// CHECK-LABEL: func @subview_canonicalize
+//  CHECK-SAME:   %[[ARG0:.+]]: memref<?x?x?xf32>
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
+//  CHECK-SAME:      [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
+//  CHECK-SAME:      : memref<?x?x?xf32> to memref<4x1x?xf32
+//       CHECK:   %[[RESULT:.+]] = memref.cast %[[SUBVIEW]]
+//       CHEKC:   return %[[RESULT]]
+
+// -----
+
+#map0 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+func @rank_reducing_subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
+    %arg2 : index) -> memref<?x?xf32, #map0>
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %0 = memref.subview %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : memref<?x?x?xf32> to memref<?x?xf32, #map0>
+  return %0 : memref<?x?xf32, #map0>
+}
+// CHECK-LABEL: func @rank_reducing_subview_canonicalize
+//  CHECK-SAME:   %[[ARG0:.+]]: memref<?x?x?xf32>
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
+//  CHECK-SAME:      [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
+//  CHECK-SAME:      : memref<?x?x?xf32> to memref<4x?xf32
+//       CHECK:   %[[RESULT:.+]] = memref.cast %[[SUBVIEW]]
+//       CHEKC:   return %[[RESULT]]
index 908814b..e2b5e7b 100644 (file)
@@ -154,30 +154,41 @@ func @tensor_cast_to_memref(%arg0 : tensor<4x6x16x32xi8>) ->
 
 // -----
 
-// CHECK-LABEL: func @subview_of_memcast
-//  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
-//       CHECK:   %[[S:.+]] = memref.subview %arg0[0, 1, 0] [1, 1, 16] [1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>
-//       CHECK:   %[[M:.+]] = memref.cast %[[S]] : memref<16x32xi8, #{{.*}}> to memref<16x32xi8, #{{.*}}>
-//       CHECK:   return %[[M]] : memref<16x32xi8, #{{.*}}>
-func @subview_of_memcast(%arg : memref<4x6x16x32xi8>) ->
-  memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>{
-  %0 = memref.cast %arg : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
-  %1 = memref.subview %0[0, 1, 0] [1, 1, 16] [1, 1, 1] :
-    memref<?x?x16x32xi8> to
-    memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
-  return %1 : memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
+func @subtensor_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
+    %arg2 : index) -> tensor<?x?x?xf32>
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %0 = subtensor %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
 }
+// CHECK-LABEL: func @subtensor_canonicalize
+//  CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?xf32>
+//       CHECK:   %[[SUBTENSOR:.+]] = subtensor %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
+//  CHECK-SAME:      [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
+//  CHECK-SAME:      : tensor<?x?x?xf32> to tensor<4x1x?xf32>
+//       CHECK:   %[[RESULT:.+]] = tensor.cast %[[SUBTENSOR]]
+//       CHEKC:   return %[[RESULT]]
 
 // -----
 
-// CHECK-LABEL: func @subview_of_static_full_size
-// CHECK-SAME: %[[ARG0:.+]]: memref<4x6x16x32xi8>
-// CHECK-NOT: memref.subview
-// CHECK: return %[[ARG0]] : memref<4x6x16x32xi8>
-func @subview_of_static_full_size(%arg0 : memref<4x6x16x32xi8>) -> memref<4x6x16x32xi8> {
-  %0 = memref.subview %arg0[0, 0, 0, 0] [4, 6, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<4x6x16x32xi8>
-  return %0 : memref<4x6x16x32xi8>
+func @rank_reducing_subtensor_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
+    %arg2 : index) -> tensor<?x?xf32>
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %0 = subtensor %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
 }
+// CHECK-LABEL: func @rank_reducing_subtensor_canonicalize
+//  CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?xf32>
+//       CHECK:   %[[SUBTENSOR:.+]] = subtensor %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
+//  CHECK-SAME:      [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
+//  CHECK-SAME:      : tensor<?x?x?xf32> to tensor<4x?xf32>
+//       CHECK:   %[[RESULT:.+]] = tensor.cast %[[SUBTENSOR]]
+//       CHEKC:   return %[[RESULT]]
 
 // -----
 
@@ -232,7 +243,89 @@ func @rank_reducing_subtensor_insert_of_cast(%a : tensor<16x32xi8>, %b : tensor<
 
 // -----
 
-func @subtensor_canonicalize(%arg0 : tensor<2x?xi32>, %arg1 : tensor<i32>,
+func @subtensor_insert_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
+    %arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %0 = subtensor_insert %arg0 into %arg3[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> into tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: func @subtensor_insert_canonicalize
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+//       CHECK:   %[[RESULT:.+]] = subtensor_insert %[[ARG0]]
+//  CHECK-SAME:      [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
+//  CHECK-SAME:      : tensor<?x?x?xf32> into tensor<?x?x?xf32>
+//       CHEKC:   return %[[RESULT]]
+
+// -----
+
+func @subtensor_to_subtensor_insert_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
+    %arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %0 = subtensor %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+  %1 = subtensor_insert %0 into %arg3[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> into tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: func @subtensor_to_subtensor_insert_canonicalize
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+//       CHECK:   %[[SUBTENSOR:.+]] = subtensor %[[ARG0]]
+//  CHECK-SAME:      [0, %{{.+}}, 1] [4, 1, %{{.+}} [1, 1, 1]
+//  CHECK-SAME:      : tensor<?x?x?xf32> to tensor<4x1x?xf32>
+//       CHECK:   %[[RESULT:.+]] = subtensor_insert %[[SUBTENSOR]]
+//  CHECK-SAME:      [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
+//  CHECK-SAME:      : tensor<4x1x?xf32> into tensor<?x?x?xf32>
+//       CHEKC:   return %[[RESULT]]
+
+// -----
+
+func @rank_reducing_subtensor_insert_canonicalize(%arg0 : tensor<?x?xf32>, %arg1 : index,
+    %arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %0 = subtensor_insert %arg0 into %arg3[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?xf32> into tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: func @rank_reducing_subtensor_insert_canonicalize
+//  CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?xf32>
+//       CHECK:   %[[RESULT:.+]] = subtensor_insert %[[ARG0]]
+//  CHECK-SAME:      [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
+//  CHECK-SAME:      : tensor<?x?xf32> into tensor<?x?x?xf32>
+//       CHEKC:   return %[[RESULT]]
+
+// -----
+
+func @rank_reducing_subtensor_to_subtensor_insert_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
+    %arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %0 = subtensor %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?xf32>
+  %1 = subtensor_insert %0 into %arg3[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?xf32> into tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: func @rank_reducing_subtensor_to_subtensor_insert_canonicalize
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+//       CHECK:   %[[SUBTENSOR:.+]] = subtensor %[[ARG0]]
+//  CHECK-SAME:     [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
+//  CHECK-SAME:     : tensor<?x?x?xf32> to tensor<4x?xf32>
+//       CHECK:   %[[RESULT:.+]] = subtensor_insert %[[SUBTENSOR]] into %[[ARG3]]
+//  CHECK-SAME:      [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
+//  CHECK-SAME:      : tensor<4x?xf32> into tensor<?x?x?xf32>
+//       CHEKC:   return %[[RESULT]]
+
+// -----
+
+func @subtensor_insert_propagate_dest_cast(%arg0 : tensor<2x?xi32>, %arg1 : tensor<i32>,
     %arg2 : index, %arg3 : index) -> tensor<?x?xi32> {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
@@ -247,7 +340,7 @@ func @subtensor_canonicalize(%arg0 : tensor<2x?xi32>, %arg1 : tensor<i32>,
   %3 = subtensor_insert %arg0 into %2[%c0, %arg3] [%c2, %0] [%c1, %c1] : tensor<2x?xi32> into tensor<?x?xi32>
   return %3 : tensor<?x?xi32>
 }
-// CHECK-LABEL: func @subtensor_canonicalize
+// CHECK-LABEL: func @subtensor_insert_propagate_dest_cast
 //       CHECK:   %[[UPDATED:.+]] = subtensor_insert %{{.+}} into %{{.+}}[0, %{{.+}}] [2, %{{.+}}] [1, 1]
 //  CHECK-SAME:     tensor<2x?xi32> into tensor<?x8xi32>
 //       CHECK:   %[[CAST:.+]] = tensor.cast %[[UPDATED]]