From c0db8d50ca3ceb1301b2ade2fb86c591a5b64e5c Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Mon, 7 Jun 2021 21:57:55 +0200 Subject: [PATCH] [mlir] Expose a function to populate tensor constant bufferization patterns This makes it easier to use it from other bufferization passes. Differential Revision: https://reviews.llvm.org/D103838 --- mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h | 7 +++++++ .../Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp | 9 ++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h index 2b7f3da..c7e331e 100644 --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h @@ -19,6 +19,7 @@ namespace mlir { +class GlobalCreator; class RewritePatternSet; using OwningRewritePatternList = RewritePatternSet; @@ -31,6 +32,12 @@ std::unique_ptr createStdBufferizePass(); /// Creates an instance of func bufferization pass. std::unique_ptr createFuncBufferizePass(); +/// Add patterns to bufferize tensor constants into global memrefs to the given +/// pattern list. +void populateTensorConstantBufferizePatterns( + GlobalCreator &globalCreator, BufferizeTypeConverter &typeConverter, + RewritePatternSet &patterns); + /// Creates an instance of tensor constant bufferization pass. std::unique_ptr createTensorConstantBufferizePass(); diff --git a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp index b40e47c..518405a 100644 --- a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp @@ -81,6 +81,13 @@ public: }; } // namespace +void mlir::populateTensorConstantBufferizePatterns( + GlobalCreator &globalCreator, BufferizeTypeConverter &typeConverter, + RewritePatternSet &patterns) { + patterns.add(globalCreator, typeConverter, + patterns.getContext()); +} + namespace { struct TensorConstantBufferizePass : public TensorConstantBufferizeBase { @@ -94,7 +101,7 @@ struct TensorConstantBufferizePass ConversionTarget target(*context); target.addLegalDialect(); - patterns.add(globals, typeConverter, context); + populateTensorConstantBufferizePatterns(globals, typeConverter, patterns); target.addDynamicallyLegalOp( [&](ConstantOp op) { return typeConverter.isLegal(op.getType()); }); if (failed(applyPartialConversion(module, target, std::move(patterns)))) -- 2.7.4