From 1e25109f93ffe5b28b28a2359e69143b7fb4aa5f Mon Sep 17 00:00:00 2001 From: Ahmed Taei Date: Tue, 7 Jan 2020 17:46:40 -0800 Subject: [PATCH] Canonicalize static alloc followed by memref_cast and std.view Summary: Rewrite alloc, memref_cast, std.view into allo, std.view by droping memref_cast. Reviewers: nicolasvasilache Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D72379 --- mlir/lib/Dialect/StandardOps/Ops.cpp | 22 +++++++++++++++++++++- mlir/test/Transforms/canonicalize.mlir | 11 ++++++++--- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp index b4fdcc4..f929619 100644 --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -2527,11 +2527,31 @@ struct ViewOpShapeFolder : public OpRewritePattern { } }; +struct ViewOpMemrefCastFolder : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(ViewOp viewOp, + PatternRewriter &rewriter) const override { + Value memrefOperand = viewOp.getOperand(0); + MemRefCastOp memrefCastOp = + dyn_cast_or_null(memrefOperand.getDefiningOp()); + if (!memrefCastOp) + return matchFailure(); + Value allocOperand = memrefCastOp.getOperand(); + AllocOp allocOp = dyn_cast_or_null(allocOperand.getDefiningOp()); + if (!allocOp) + return matchFailure(); + rewriter.replaceOpWithNewOp(memrefOperand, viewOp, viewOp.getType(), + allocOperand, viewOp.operands()); + return matchSuccess(); + } +}; + } // end anonymous namespace void ViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); + results.insert(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir index 09db088..a6c2326 100644 --- a/mlir/test/Transforms/canonicalize.mlir +++ b/mlir/test/Transforms/canonicalize.mlir @@ -695,6 +695,7 @@ func @cast_values(%arg0: tensor<*xi32>, %arg1: memref) -> (tensor<2xi32>, // CHECK-LABEL: func @view func @view(%arg0 : index) { + // CHECK: %[[ALLOC_MEM:.*]] = alloc() : memref<2048xi8> %0 = alloc() : memref<2048xi8> %c0 = constant 0 : index %c7 = constant 7 : index @@ -730,11 +731,15 @@ func @view(%arg0 : index) { // Test: preserve an existing static dim size while folding a dynamic // dimension and offset. - // CHECK: std.view %0[][] : memref<2048xi8> to memref<7x4xf32, #[[VIEW_MAP4]]> - %5 = view %0[%c15][%c7] - : memref<2048xi8> to memref + // CHECK: std.view %[[ALLOC_MEM]][][] : memref<2048xi8> to memref<7x4xf32, #[[VIEW_MAP4]]> + %5 = view %0[%c15][%c7] : memref<2048xi8> to memref load %5[%c0, %c0] : memref + // Test: folding static alloc and memref_cast into a view. + // CHECK: std.view %0[][%c15, %c7] : memref<2048xi8> to memref + %6 = memref_cast %0 : memref<2048xi8> to memref + %7 = view %6[%c15][%c7] : memref to memref + load %7[%c0, %c0] : memref return } -- 2.7.4