}
OpFoldResult MemRefCastOp::fold(ArrayRef<Attribute> operands) {
- return impl::foldCastOp(*this);
+ if (Value folded = impl::foldCastOp(*this))
+ return folded;
+ return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}
//===----------------------------------------------------------------------===//
return %1, %2 : f32, f32
}
+// CHECK-LABEL: @fold_memref_cast_in_memref_cast
+// CHECK-SAME: (%[[ARG0:.*]]: memref<42x42xf64>)
+func @fold_memref_cast_in_memref_cast(%0: memref<42x42xf64>) {
+ // CHECK: %[[folded:.*]] = memref_cast %[[ARG0]] : memref<42x42xf64> to memref<?x?xf64>
+ %4 = memref_cast %0 : memref<42x42xf64> to memref<?x42xf64>
+ // CHECK-NOT: memref_cast
+ %5 = memref_cast %4 : memref<?x42xf64> to memref<?x?xf64>
+ // CHECK: "test.user"(%[[folded]])
+ "test.user"(%5) : (memref<?x?xf64>) -> ()
+ return
+}
+
+// CHECK-LABEL: @fold_memref_cast_chain
+// CHECK-SAME: (%[[ARG0:.*]]: memref<42x42xf64>)
+func @fold_memref_cast_chain(%0: memref<42x42xf64>) {
+ // CHECK-NOT: memref_cast
+ %4 = memref_cast %0 : memref<42x42xf64> to memref<?x42xf64>
+ %5 = memref_cast %4 : memref<?x42xf64> to memref<42x42xf64>
+ // CHECK: "test.user"(%[[ARG0]])
+ "test.user"(%5) : (memref<42x42xf64>) -> ()
+ return
+}
+
// CHECK-LABEL: func @alloc_const_fold
func @alloc_const_fold() -> memref<?xf32> {
// CHECK-NEXT: %0 = alloc() : memref<4xf32>