From: Benjamin Kramer Date: Mon, 7 Jun 2021 19:57:55 +0000 (+0200) Subject: [mlir] Expose a function to populate tensor constant bufferization patterns X-Git-Tag: llvmorg-14-init~4450 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=c0db8d50ca3ceb1301b2ade2fb86c591a5b64e5c;p=platform%2Fupstream%2Fllvm.git [mlir] Expose a function to populate tensor constant bufferization patterns This makes it easier to use it from other bufferization passes. Differential Revision: https://reviews.llvm.org/D103838 --- diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h index 2b7f3da..c7e331e 100644 --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h @@ -19,6 +19,7 @@ namespace mlir { +class GlobalCreator; class RewritePatternSet; using OwningRewritePatternList = RewritePatternSet; @@ -31,6 +32,12 @@ std::unique_ptr createStdBufferizePass(); /// Creates an instance of func bufferization pass. std::unique_ptr createFuncBufferizePass(); +/// Add patterns to bufferize tensor constants into global memrefs to the given +/// pattern list. +void populateTensorConstantBufferizePatterns( + GlobalCreator &globalCreator, BufferizeTypeConverter &typeConverter, + RewritePatternSet &patterns); + /// Creates an instance of tensor constant bufferization pass. std::unique_ptr createTensorConstantBufferizePass(); diff --git a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp index b40e47c..518405a 100644 --- a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp @@ -81,6 +81,13 @@ public: }; } // namespace +void mlir::populateTensorConstantBufferizePatterns( + GlobalCreator &globalCreator, BufferizeTypeConverter &typeConverter, + RewritePatternSet &patterns) { + patterns.add(globalCreator, typeConverter, + patterns.getContext()); +} + namespace { struct TensorConstantBufferizePass : public TensorConstantBufferizeBase { @@ -94,7 +101,7 @@ struct TensorConstantBufferizePass ConversionTarget target(*context); target.addLegalDialect(); - patterns.add(globals, typeConverter, context); + populateTensorConstantBufferizePatterns(globals, typeConverter, patterns); target.addDynamicallyLegalOp( [&](ConstantOp op) { return typeConverter.isLegal(op.getType()); }); if (failed(applyPartialConversion(module, target, std::move(patterns))))