[mlir][Vector] Add folding of memref_cast into vector_transfer ops
authorNicolas Vasilache <ntv@google.com>
Fri, 5 Jun 2020 17:20:59 +0000 (13:20 -0400)
committerNicolas Vasilache <ntv@google.com>
Fri, 5 Jun 2020 17:27:00 +0000 (13:27 -0400)
Summary:
This revision adds a common folding pattern that starts appearing on
vector_transfer ops.

Differential Revision: https://reviews.llvm.org/D81281

mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir

index 365795f..9ae1c74 100644 (file)
@@ -1061,6 +1061,8 @@ def Vector_TransferReadOp :
         return impl::getTransferMinorIdentityMap(memRefType, vectorType);
       }
   }];
+
+  let hasFolder = 1;
 }
 
 def Vector_TransferWriteOp :
@@ -1150,6 +1152,8 @@ def Vector_TransferWriteOp :
         return impl::getTransferMinorIdentityMap(memRefType, vectorType);
       }
   }];
+
+  let hasFolder = 1;
 }
 
 def Vector_ShapeCastOp :
index 21b62ce..019f5fd 100644 (file)
@@ -1498,6 +1498,30 @@ static LogicalResult verify(TransferReadOp op) {
                               [&op](Twine t) { return op.emitOpError(t); });
 }
 
+/// This is a common class used for patterns of the form
+/// ```
+///    someop(memrefcast) -> someop
+/// ```
+/// It folds the source of the memref_cast into the root operation directly.
+static LogicalResult foldMemRefCast(Operation *op) {
+  bool folded = false;
+  for (OpOperand &operand : op->getOpOperands()) {
+    auto castOp = operand.get().getDefiningOp<MemRefCastOp>();
+    if (castOp && canFoldIntoConsumerOp(castOp)) {
+      operand.set(castOp.getOperand());
+      folded = true;
+    }
+  }
+  return success(folded);
+}
+
+OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
+  /// transfer_read(memrefcast) -> transfer_read
+  if (succeeded(foldMemRefCast(*this)))
+    return getResult();
+  return OpFoldResult();
+}
+
 //===----------------------------------------------------------------------===//
 // TransferWriteOp
 //===----------------------------------------------------------------------===//
@@ -1583,6 +1607,11 @@ static LogicalResult verify(TransferWriteOp op) {
                               [&op](Twine t) { return op.emitOpError(t); });
 }
 
+LogicalResult TransferWriteOp::fold(ArrayRef<Attribute>,
+                                    SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+
 //===----------------------------------------------------------------------===//
 // ShapeCastOp
 //===----------------------------------------------------------------------===//
index db504bf..5e4ba39 100644 (file)
@@ -159,3 +159,19 @@ func @transpose_3D_sequence(%arg : vector<4x3x2xf32>) -> vector<4x3x2xf32> {
   // CHECK-NEXT: return [[ADD]]
   return %7 : vector<4x3x2xf32>
 }
+
+// -----
+
+// CHECK-LABEL: cast_transfers
+func @cast_transfers(%A: memref<4x8xf32>) -> (vector<4x8xf32>) {
+  %c0 = constant 0 : index
+  %f0 = constant 0.0 : f32
+  %0 = memref_cast %A : memref<4x8xf32> to memref<?x?xf32>
+
+  // CHECK: vector.transfer_read %{{.*}} : memref<4x8xf32>, vector<4x8xf32>
+  %1 = vector.transfer_read %0[%c0, %c0], %f0 : memref<?x?xf32>, vector<4x8xf32>
+
+  // CHECK: vector.transfer_write %{{.*}} : vector<4x8xf32>, memref<4x8xf32>
+  vector.transfer_write %1, %0[%c0, %c0] : vector<4x8xf32>, memref<?x?xf32>
+  return %1 : vector<4x8xf32>
+}