namespace mlir {
+class GlobalCreator;
class RewritePatternSet;
using OwningRewritePatternList = RewritePatternSet;
/// Creates an instance of func bufferization pass.
std::unique_ptr<Pass> createFuncBufferizePass();
+/// Add patterns to bufferize tensor constants into global memrefs to the given
+/// pattern list.
+void populateTensorConstantBufferizePatterns(
+ GlobalCreator &globalCreator, BufferizeTypeConverter &typeConverter,
+ RewritePatternSet &patterns);
+
/// Creates an instance of tensor constant bufferization pass.
std::unique_ptr<Pass> createTensorConstantBufferizePass();
};
} // namespace
+void mlir::populateTensorConstantBufferizePatterns(
+ GlobalCreator &globalCreator, BufferizeTypeConverter &typeConverter,
+ RewritePatternSet &patterns) {
+ patterns.add<BufferizeTensorConstantOp>(globalCreator, typeConverter,
+ patterns.getContext());
+}
+
namespace {
struct TensorConstantBufferizePass
: public TensorConstantBufferizeBase<TensorConstantBufferizePass> {
ConversionTarget target(*context);
target.addLegalDialect<memref::MemRefDialect>();
- patterns.add<BufferizeTensorConstantOp>(globals, typeConverter, context);
+ populateTensorConstantBufferizePatterns(globals, typeConverter, patterns);
target.addDynamicallyLegalOp<ConstantOp>(
[&](ConstantOp op) { return typeConverter.isLegal(op.getType()); });
if (failed(applyPartialConversion(module, target, std::move(patterns))))