[mlir] Expose a function to populate tensor constant bufferization patterns
authorBenjamin Kramer <benny.kra@googlemail.com>
Mon, 7 Jun 2021 19:57:55 +0000 (21:57 +0200)
committerBenjamin Kramer <benny.kra@googlemail.com>
Wed, 9 Jun 2021 11:47:33 +0000 (13:47 +0200)
This makes it easier to use it from other bufferization passes.

Differential Revision: https://reviews.llvm.org/D103838

mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp

index 2b7f3da..c7e331e 100644 (file)
@@ -19,6 +19,7 @@
 
 namespace mlir {
 
+class GlobalCreator;
 class RewritePatternSet;
 using OwningRewritePatternList = RewritePatternSet;
 
@@ -31,6 +32,12 @@ std::unique_ptr<Pass> createStdBufferizePass();
 /// Creates an instance of func bufferization pass.
 std::unique_ptr<Pass> createFuncBufferizePass();
 
+/// Add patterns to bufferize tensor constants into global memrefs to the given
+/// pattern list.
+void populateTensorConstantBufferizePatterns(
+    GlobalCreator &globalCreator, BufferizeTypeConverter &typeConverter,
+    RewritePatternSet &patterns);
+
 /// Creates an instance of tensor constant bufferization pass.
 std::unique_ptr<Pass> createTensorConstantBufferizePass();
 
index b40e47c..518405a 100644 (file)
@@ -81,6 +81,13 @@ public:
 };
 } // namespace
 
+void mlir::populateTensorConstantBufferizePatterns(
+    GlobalCreator &globalCreator, BufferizeTypeConverter &typeConverter,
+    RewritePatternSet &patterns) {
+  patterns.add<BufferizeTensorConstantOp>(globalCreator, typeConverter,
+                                          patterns.getContext());
+}
+
 namespace {
 struct TensorConstantBufferizePass
     : public TensorConstantBufferizeBase<TensorConstantBufferizePass> {
@@ -94,7 +101,7 @@ struct TensorConstantBufferizePass
     ConversionTarget target(*context);
 
     target.addLegalDialect<memref::MemRefDialect>();
-    patterns.add<BufferizeTensorConstantOp>(globals, typeConverter, context);
+    populateTensorConstantBufferizePatterns(globals, typeConverter, patterns);
     target.addDynamicallyLegalOp<ConstantOp>(
         [&](ConstantOp op) { return typeConverter.isLegal(op.getType()); });
     if (failed(applyPartialConversion(module, target, std::move(patterns))))