From b56bf30d3cc15896956061fdbeb6d078b63ec91f Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Fri, 5 Jun 2020 13:20:59 -0400 Subject: [PATCH] [mlir][Vector] Add folding of memref_cast into vector_transfer ops Summary: This revision adds a common folding pattern that starts appearing on vector_transfer ops. Differential Revision: https://reviews.llvm.org/D81281 --- mlir/include/mlir/Dialect/Vector/VectorOps.td | 4 ++++ mlir/lib/Dialect/Vector/VectorOps.cpp | 29 +++++++++++++++++++++++++++ mlir/test/Dialect/Vector/canonicalize.mlir | 16 +++++++++++++++ 3 files changed, 49 insertions(+) diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td index 365795f..9ae1c74 100644 --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -1061,6 +1061,8 @@ def Vector_TransferReadOp : return impl::getTransferMinorIdentityMap(memRefType, vectorType); } }]; + + let hasFolder = 1; } def Vector_TransferWriteOp : @@ -1150,6 +1152,8 @@ def Vector_TransferWriteOp : return impl::getTransferMinorIdentityMap(memRefType, vectorType); } }]; + + let hasFolder = 1; } def Vector_ShapeCastOp : diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp index 21b62ce..019f5fd 100644 --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -1498,6 +1498,30 @@ static LogicalResult verify(TransferReadOp op) { [&op](Twine t) { return op.emitOpError(t); }); } +/// This is a common class used for patterns of the form +/// ``` +/// someop(memrefcast) -> someop +/// ``` +/// It folds the source of the memref_cast into the root operation directly. +static LogicalResult foldMemRefCast(Operation *op) { + bool folded = false; + for (OpOperand &operand : op->getOpOperands()) { + auto castOp = operand.get().getDefiningOp(); + if (castOp && canFoldIntoConsumerOp(castOp)) { + operand.set(castOp.getOperand()); + folded = true; + } + } + return success(folded); +} + +OpFoldResult TransferReadOp::fold(ArrayRef) { + /// transfer_read(memrefcast) -> transfer_read + if (succeeded(foldMemRefCast(*this))) + return getResult(); + return OpFoldResult(); +} + //===----------------------------------------------------------------------===// // TransferWriteOp //===----------------------------------------------------------------------===// @@ -1583,6 +1607,11 @@ static LogicalResult verify(TransferWriteOp op) { [&op](Twine t) { return op.emitOpError(t); }); } +LogicalResult TransferWriteOp::fold(ArrayRef, + SmallVectorImpl &) { + return foldMemRefCast(*this); +} + //===----------------------------------------------------------------------===// // ShapeCastOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index db504bf..5e4ba39 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -159,3 +159,19 @@ func @transpose_3D_sequence(%arg : vector<4x3x2xf32>) -> vector<4x3x2xf32> { // CHECK-NEXT: return [[ADD]] return %7 : vector<4x3x2xf32> } + +// ----- + +// CHECK-LABEL: cast_transfers +func @cast_transfers(%A: memref<4x8xf32>) -> (vector<4x8xf32>) { + %c0 = constant 0 : index + %f0 = constant 0.0 : f32 + %0 = memref_cast %A : memref<4x8xf32> to memref + + // CHECK: vector.transfer_read %{{.*}} : memref<4x8xf32>, vector<4x8xf32> + %1 = vector.transfer_read %0[%c0, %c0], %f0 : memref, vector<4x8xf32> + + // CHECK: vector.transfer_write %{{.*}} : vector<4x8xf32>, memref<4x8xf32> + vector.transfer_write %1, %0[%c0, %c0] : vector<4x8xf32>, memref + return %1 : vector<4x8xf32> +} -- 2.7.4