[mlir] Add folding of memref_cast inside another memref_cast
authorAlex Zinenko <zinenko@google.com>
Fri, 6 Nov 2020 09:20:08 +0000 (10:20 +0100)
committerAlex Zinenko <zinenko@google.com>
Fri, 6 Nov 2020 09:42:40 +0000 (10:42 +0100)
There exists a generic folding facility that folds the operand of a memref_cast
into users of memref_cast that support this. However, it was not used for the
memref_cast itself. Fix it to enable elimination of memref_cast chains such as

  %1 = memref_cast %0 : A to B
  %2 = memref_cast %1 : B to A

that is achieved by combining the folding with the existing "A to A" cast
elimination.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D90910

mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Transforms/canonicalize.mlir

index 9b5875e..d333ddc 100644 (file)
@@ -2386,7 +2386,9 @@ bool MemRefCastOp::areCastCompatible(Type a, Type b) {
 }
 
 OpFoldResult MemRefCastOp::fold(ArrayRef<Attribute> operands) {
-  return impl::foldCastOp(*this);
+  if (Value folded = impl::foldCastOp(*this))
+    return folded;
+  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
 }
 
 //===----------------------------------------------------------------------===//
index 7b8c45c..08f3ac7 100644 (file)
@@ -334,6 +334,29 @@ func @memref_cast_folding(%arg0: memref<4 x f32>, %arg1: f32) -> (f32, f32) {
   return %1, %2 : f32, f32
 }
 
+// CHECK-LABEL: @fold_memref_cast_in_memref_cast
+// CHECK-SAME: (%[[ARG0:.*]]: memref<42x42xf64>)
+func @fold_memref_cast_in_memref_cast(%0: memref<42x42xf64>) {
+  // CHECK: %[[folded:.*]] = memref_cast %[[ARG0]] : memref<42x42xf64> to memref<?x?xf64>
+  %4 = memref_cast %0 : memref<42x42xf64> to memref<?x42xf64>
+  // CHECK-NOT: memref_cast
+  %5 = memref_cast %4 : memref<?x42xf64> to memref<?x?xf64>
+  // CHECK: "test.user"(%[[folded]])
+  "test.user"(%5) : (memref<?x?xf64>) -> ()
+  return
+}
+
+// CHECK-LABEL: @fold_memref_cast_chain
+// CHECK-SAME: (%[[ARG0:.*]]: memref<42x42xf64>)
+func @fold_memref_cast_chain(%0: memref<42x42xf64>) {
+  // CHECK-NOT: memref_cast
+  %4 = memref_cast %0 : memref<42x42xf64> to memref<?x42xf64>
+  %5 = memref_cast %4 : memref<?x42xf64> to memref<42x42xf64>
+  // CHECK: "test.user"(%[[ARG0]])
+  "test.user"(%5) : (memref<42x42xf64>) -> ()
+  return
+}
+
 // CHECK-LABEL: func @alloc_const_fold
 func @alloc_const_fold() -> memref<?xf32> {
   // CHECK-NEXT: %0 = alloc() : memref<4xf32>