[mlir][memref] Fold no-op subview(subview(x)) ops
authorMatthias Springer <springerm@google.com>
Wed, 14 Dec 2022 11:26:41 +0000 (12:26 +0100)
committerMatthias Springer <springerm@google.com>
Wed, 14 Dec 2022 11:47:00 +0000 (12:47 +0100)
Differential Revision: https://reviews.llvm.org/D140008

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/test/Dialect/MemRef/canonicalize.mlir

index 4b60eb9..de9690f 100644 (file)
@@ -3065,6 +3065,24 @@ OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
     return getViewSource();
   }
 
+  // Fold subview(subview(x)), where both subviews have the same size and the
+  // second subview's offsets are all zero. (I.e., the second subview is a
+  // no-op.)
+  if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
+    auto srcSizes = srcSubview.getMixedSizes();
+    auto sizes = getMixedSizes();
+    auto offsets = getMixedOffsets();
+    bool allOffsetsZero = llvm::all_of(
+        offsets, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); });
+    auto strides = getMixedStrides();
+    bool allStridesOne = llvm::all_of(
+        strides, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
+    bool allSizesSame = llvm::equal(sizes, srcSizes);
+    if (allOffsetsZero && allStridesOne && allSizesSame &&
+        resultShapedType == sourceShapedType)
+      return getViewSource();
+  }
+
   return {};
 }
 
index d9710b7..88d9155 100644 (file)
@@ -875,3 +875,22 @@ func.func @collapse_expand_fold_to_cast(%m: memref<?xf32, strided<[1]>, 3>)
       : memref<1x?xf32, 3> into memref<?xf32, 3>
   return %1 : memref<?xf32, 3>
 }
+
+// -----
+
+// CHECK-LABEL: func @fold_trivial_subviews(
+//  CHECK-SAME:     %[[m:.*]]: memref<?xf32, strided<[?], offset: ?>>
+//       CHECK:   %[[subview:.*]] = memref.subview %[[m]][5]
+//       CHECK:   return %[[subview]]
+func.func @fold_trivial_subviews(%m: memref<?xf32, strided<[?], offset: ?>>,
+                                 %sz: index)
+    -> memref<?xf32, strided<[?], offset: ?>>
+{
+  %0 = memref.subview %m[5] [%sz] [1]
+      : memref<?xf32, strided<[?], offset: ?>>
+        to memref<?xf32, strided<[?], offset: ?>>
+  %1 = memref.subview %0[0] [%sz] [1]
+      : memref<?xf32, strided<[?], offset: ?>>
+        to memref<?xf32, strided<[?], offset: ?>>
+  return %1 : memref<?xf32, strided<[?], offset: ?>>
+}