return getViewSource();
}
+ // Fold subview(subview(x)), where both subviews have the same size and the
+ // second subview's offsets are all zero. (I.e., the second subview is a
+ // no-op.)
+ if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
+ auto srcSizes = srcSubview.getMixedSizes();
+ auto sizes = getMixedSizes();
+ auto offsets = getMixedOffsets();
+ bool allOffsetsZero = llvm::all_of(
+ offsets, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); });
+ auto strides = getMixedStrides();
+ bool allStridesOne = llvm::all_of(
+ strides, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
+ bool allSizesSame = llvm::equal(sizes, srcSizes);
+ if (allOffsetsZero && allStridesOne && allSizesSame &&
+ resultShapedType == sourceShapedType)
+ return getViewSource();
+ }
+
return {};
}
: memref<1x?xf32, 3> into memref<?xf32, 3>
return %1 : memref<?xf32, 3>
}
+
+// -----
+
+// CHECK-LABEL: func @fold_trivial_subviews(
+// CHECK-SAME: %[[m:.*]]: memref<?xf32, strided<[?], offset: ?>>
+// CHECK: %[[subview:.*]] = memref.subview %[[m]][5]
+// CHECK: return %[[subview]]
+func.func @fold_trivial_subviews(%m: memref<?xf32, strided<[?], offset: ?>>,
+ %sz: index)
+ -> memref<?xf32, strided<[?], offset: ?>>
+{
+ %0 = memref.subview %m[5] [%sz] [1]
+ : memref<?xf32, strided<[?], offset: ?>>
+ to memref<?xf32, strided<[?], offset: ?>>
+ %1 = memref.subview %0[0] [%sz] [1]
+ : memref<?xf32, strided<[?], offset: ?>>
+ to memref<?xf32, strided<[?], offset: ?>>
+ return %1 : memref<?xf32, strided<[?], offset: ?>>
+}