}
};
+struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
+ using OpRewritePattern<ViewOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(ViewOp viewOp,
+ PatternRewriter &rewriter) const override {
+ Value memrefOperand = viewOp.getOperand(0);
+ MemRefCastOp memrefCastOp =
+ dyn_cast_or_null<MemRefCastOp>(memrefOperand.getDefiningOp());
+ if (!memrefCastOp)
+ return matchFailure();
+ Value allocOperand = memrefCastOp.getOperand();
+ AllocOp allocOp = dyn_cast_or_null<AllocOp>(allocOperand.getDefiningOp());
+ if (!allocOp)
+ return matchFailure();
+ rewriter.replaceOpWithNewOp<ViewOp>(memrefOperand, viewOp, viewOp.getType(),
+ allocOperand, viewOp.operands());
+ return matchSuccess();
+ }
+};
+
} // end anonymous namespace
void ViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
- results.insert<ViewOpShapeFolder>(context);
+ results.insert<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func @view
func @view(%arg0 : index) {
+ // CHECK: %[[ALLOC_MEM:.*]] = alloc() : memref<2048xi8>
%0 = alloc() : memref<2048xi8>
%c0 = constant 0 : index
%c7 = constant 7 : index
// Test: preserve an existing static dim size while folding a dynamic
// dimension and offset.
- // CHECK: std.view %0[][] : memref<2048xi8> to memref<7x4xf32, #[[VIEW_MAP4]]>
- %5 = view %0[%c15][%c7]
- : memref<2048xi8> to memref<?x4xf32, #TEST_VIEW_MAP2>
+ // CHECK: std.view %[[ALLOC_MEM]][][] : memref<2048xi8> to memref<7x4xf32, #[[VIEW_MAP4]]>
+ %5 = view %0[%c15][%c7] : memref<2048xi8> to memref<?x4xf32, #TEST_VIEW_MAP2>
load %5[%c0, %c0] : memref<?x4xf32, #TEST_VIEW_MAP2>
+ // Test: folding static alloc and memref_cast into a view.
+ // CHECK: std.view %0[][%c15, %c7] : memref<2048xi8> to memref<?x?xf32>
+ %6 = memref_cast %0 : memref<2048xi8> to memref<?xi8>
+ %7 = view %6[%c15][%c7] : memref<?xi8> to memref<?x?xf32>
+ load %7[%c0, %c0] : memref<?x?xf32>
return
}