From ba916c0cf6d0149f81bf1137e88f7d6fd3b0cc76 Mon Sep 17 00:00:00 2001 From: Quentin Colombet Date: Sat, 27 Aug 2022 01:14:23 +0000 Subject: [PATCH] [mlir][MemRef] Canonicalize reinterpret_cast(extract_strided_metadata) Add a canonicalization step for reinterpret_cast(extract_strided_metadata). This step replaces this sequence of operations by either: - A noop, i.e., the original memref is directly used, or - A plain cast of the original memref The choice is ultimately made based on whether the original memref type is equal to what the reinterpret_cast is producing. For instance, the reinterpret_cast could be changing some dimensions from static to dynamic and in such a case, we need to keep a cast. The transformation is currently only performed when the reinterpret_cast uses exactly the same arguments as what the extract_strided_metadata produces. It may be possible to be more aggressive here but I wanted to start with a relatively simple MLIR patch for my first one! Differential Revision: https://reviews.llvm.org/D132776 --- mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td | 1 + mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 59 ++++++++++++++++++++++++ mlir/test/Dialect/MemRef/canonicalize.mlir | 57 +++++++++++++++++++++++ 3 files changed, 117 insertions(+) diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td index 5ef8d8f..a9b9d54 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -1142,6 +1142,7 @@ def MemRef_ReinterpretCastOp }]; let hasFolder = 1; + let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 616b228..a2c49db 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1600,6 +1600,65 @@ OpFoldResult 
ReinterpretCastOp::fold(ArrayRef<Attribute> /*operands*/) {
   return nullptr;
 }
 
+namespace {
+/// Replace reinterpret_cast(extract_strided_metadata memref) -> memref.
+struct ReinterpretCastOpExtractStridedMetadataFolder
+    : public OpRewritePattern<ReinterpretCastOp> {
+public:
+  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ReinterpretCastOp op,
+                                PatternRewriter &rewriter) const override {
+    auto extractStridedMetadata =
+        op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
+    if (!extractStridedMetadata)
+      return failure();
+    // Check if the reinterpret cast reconstructs a memref with the exact same
+    // properties as the extract strided metadata.
+
+    // First, check that the strides are the same.
+    if (extractStridedMetadata.getStrides().size() != op.getStrides().size())
+      return failure();
+    for (auto [extractStride, reinterpretStride] :
+         llvm::zip(extractStridedMetadata.getStrides(), op.getStrides()))
+      if (extractStride != reinterpretStride)
+        return failure();
+
+    // Second, check the sizes.
+    if (extractStridedMetadata.getSizes().size() != op.getSizes().size())
+      return failure();
+    for (auto [extractSize, reinterpretSize] :
+         llvm::zip(extractStridedMetadata.getSizes(), op.getSizes()))
+      if (extractSize != reinterpretSize)
+        return failure();
+
+    // Finally, check the offset: a single operand equal to the extracted one.
+    if (op.getOffsets().size() != 1 ||
+        extractStridedMetadata.getOffset() != *op.getOffsets().begin())
+      return failure();
+
+    // At this point, we know that the back and forth between extract strided
+    // metadata and reinterpret cast is a noop. However, the final type of the
+    // reinterpret cast may not be exactly the same as the original memref.
+    // E.g., it could be changing a dimension from static to dynamic. Check that
+    // here and add a cast if necessary.
+    Type srcTy = extractStridedMetadata.getSource().getType();
+    if (srcTy == op.getResult().getType())
+      rewriter.replaceOp(op, extractStridedMetadata.getSource());
+    else
+      rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(),
+                                          extractStridedMetadata.getSource());
+
+    return success();
+  }
+};
+} // namespace
+
+void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                    MLIRContext *context) {
+  results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // Reassociative reshape ops
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 138a0f4..45427759 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -740,6 +740,63 @@ func.func @reinterpret_of_subview(%arg : memref, %size1: index, %size2: in
 
 // -----
 
+// Check that a reinterpret cast of an equivalent extract strided metadata
+// is canonicalized to a plain cast when the destination type is different
+// than the type of the original memref.
+// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_type_mistach
+// CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
+// CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] : memref<8x2xf32> to memref) -> memref {
+  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref, index, index, index, index, index
+  %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref to memref
+  return %m2 : memref
+}
+
+// -----
+
+// Check that a reinterpret cast of an equivalent extract strided metadata
+// is completely removed when the original memref has the same type.
+// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_same_type +// CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>) +// CHECK: return %[[ARG]] +func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref<8x2xf32>) -> memref<8x2xf32> { + %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref, index, index, index, index, index + %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref to memref<8x2xf32> + return %m2 : memref<8x2xf32> +} + +// ----- + +// Check that we don't simplify reinterpret cast of extract strided metadata +// when the strides don't match. +// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_stride +// CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>) +// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] +// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [4, 2, 2], strides: [1, 1, %[[STRIDES]]#1] +// CHECK: return %[[RES]] +func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref { + %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref, index, index, index, index, index + %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [4, 2, 2], strides: [1, 1, %strides#1] : memref to memref + return %m2 : memref +} +// ----- + +// Check that we don't simplify reinterpret cast of extract strided metadata +// when the offset doesn't match. 
+// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_offset +// CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>) +// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] +// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [1], sizes: [%[[SIZES]]#0, %[[SIZES]]#1], strides: [%[[STRIDES]]#0, %[[STRIDES]]#1] +// CHECK: return %[[RES]] +func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref { + %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref, index, index, index, index, index + %m2 = memref.reinterpret_cast %base to offset: [1], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref to memref + return %m2 : memref +} + +// ----- + func.func @canonicalize_rank_reduced_subview(%arg0 : memref<8x?xf32>, %arg1 : index) -> memref { %c0 = arith.constant 0 : index -- 2.7.4