[MLIR][MemRef] Only allow fold of cast for the pointer operand, not the value
authorWilliam S. Moses <gh@wsmoses.com>
Mon, 7 Jun 2021 17:44:07 +0000 (13:44 -0400)
committerWilliam S. Moses <gh@wsmoses.com>
Tue, 8 Jun 2021 15:43:09 +0000 (11:43 -0400)
Currently canonicalizations of a store and a cast try to fold all casts into the store.

In the case where the operand being stored is itself a cast, this is illegal as the type of the value being stored
will change. This PR fixes this by not checking the value for folding with a cast.

Depends on https://reviews.llvm.org/D103828

Differential Revision: https://reviews.llvm.org/D103829

mlir/lib/Dialect/Affine/IR/AffineOps.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/test/Dialect/Affine/canonicalize.mlir
mlir/test/Dialect/MemRef/canonicalize.mlir

index ef990b70f3575900ae43d922d4ce33185ccd62d5..480b53811483fc8d0d2ba0451b0fe53087a161ed 100644 (file)
@@ -942,11 +942,12 @@ void AffineApplyOp::getCanonicalizationPatterns(RewritePatternSet &results,
 /// This is a common class used for patterns of the form
 /// "someop(memrefcast) -> someop".  It folds the source of any memref.cast
 /// into the root operation directly.
-static LogicalResult foldMemRefCast(Operation *op) {
+static LogicalResult foldMemRefCast(Operation *op, Value ignore = nullptr) {
   bool folded = false;
   for (OpOperand &operand : op->getOpOperands()) {
     auto cast = operand.get().getDefiningOp<memref::CastOp>();
-    if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
+    if (cast && operand.get() != ignore &&
+        !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
       operand.set(cast.getOperand());
       folded = true;
     }
@@ -2270,7 +2271,7 @@ void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
 LogicalResult AffineStoreOp::fold(ArrayRef<Attribute> cstOperands,
                                   SmallVectorImpl<OpFoldResult> &results) {
   /// store(memrefcast) -> store
-  return foldMemRefCast(*this);
+  return foldMemRefCast(*this, getValueToStore());
 }
 
 //===----------------------------------------------------------------------===//
index a4ab6c1d0859f10c9830d69e2d885a4f6d0c4d69..f20234bd1d686c2832ed0e61f401bcc8f39c1672 100644 (file)
@@ -73,11 +73,12 @@ static void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
 /// This is a common class used for patterns of the form
 /// "someop(memrefcast) -> someop".  It folds the source of any memref.cast
 /// into the root operation directly.
-static LogicalResult foldMemRefCast(Operation *op) {
+static LogicalResult foldMemRefCast(Operation *op, Value inner = nullptr) {
   bool folded = false;
   for (OpOperand &operand : op->getOpOperands()) {
     auto cast = operand.get().getDefiningOp<CastOp>();
-    if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
+    if (cast && operand.get() != inner &&
+        !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
       operand.set(cast.getOperand());
       folded = true;
     }
@@ -1425,7 +1426,7 @@ static LogicalResult verify(StoreOp op) {
 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
                             SmallVectorImpl<OpFoldResult> &results) {
   /// store(memrefcast) -> store
-  return foldMemRefCast(*this);
+  return foldMemRefCast(*this, getValueToStore());
 }
 
 //===----------------------------------------------------------------------===//
index 0a47285e18c49e47e4b26e7da4be7dd41fd0d68a..3d6bd57c27ffc97ab89ac87a5da832f49ee136de 100644 (file)
@@ -924,3 +924,15 @@ func @compose_into_affine_vector_load_vector_store(%A : memref<1024xf32>, %u : i
   }
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @no_fold_of_store
+//  CHECK:   %[[cst:.+]] = memref.cast %arg
+//  CHECK:   affine.store %[[cst]]
+func @no_fold_of_store(%arg : memref<32xi8>, %holder: memref<memref<?xi8>>) {
+  %0 = memref.cast %arg : memref<32xi8> to memref<?xi8>
+  affine.store %0, %holder[] : memref<memref<?xi8>>
+  return
+}
+
index 354be2237ec3020aba45a877b62746a8d696a14b..140cd43ede147648994434d49e4fe64dfb35d471 100644 (file)
@@ -206,4 +206,14 @@ func @dim_of_sized_view(%arg : memref<?xi8>, %size: index) -> index {
   return %1 : index
 }
 
+// -----
+
+// CHECK-LABEL: func @no_fold_of_store
+//  CHECK:   %[[cst:.+]] = memref.cast %arg
+//  CHECK:   memref.store %[[cst]]
+func @no_fold_of_store(%arg : memref<32xi8>, %holder: memref<memref<?xi8>>) {
+  %0 = memref.cast %arg : memref<32xi8> to memref<?xi8>
+  memref.store %0, %holder[] : memref<memref<?xi8>>
+  return
+}