namespace bufferization {
class BufferizableOpInterface;
-struct BufferizationOptions;
class BufferizationState;
+struct DialectBufferizationState;
/// Options for ComprehensiveBufferize.
struct BufferizationOptions {
/// Memcpy function: Generate a memcpy between two buffers.
using MemCpyFn =
std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
+ /// Initializer function for bufferization state.
+ using BufferizationStateInitFn = std::function<void(BufferizationState &)>;
+ /// Initializer function for dialect-specific bufferization state.
+ using DialectStateInitFn =
+ std::function<std::unique_ptr<DialectBufferizationState>()>;
/// An op filter entry. Filters can be used to specify which ops should be
/// processed by the bufferization.
/// DENY-filtered and have at least one matching ALLOW filter are processed.
SmallVector<OpFilterEntry> opFilter;
+ /// Initializer functions for bufferization state. These can be used to
+ /// initialize dialect-specific bufferization state.
+ SmallVector<BufferizationStateInitFn> stateInitializers;
+
+ /// Add a bufferization state initializer that initializes the specified
+ /// dialect-specific bufferization state.
+ void addDialectStateInitializer(StringRef name, DialectStateInitFn fn);
+
private:
/// Allow a dialect.
template <typename DialectT>
return static_cast<StateT &>(*dialectState[name]);
}
+ void insertDialectState(StringRef name,
+ std::unique_ptr<DialectBufferizationState> state) {
+ assert(!dialectState.count(name) && "dialect state already initialized");
+ dialectState[name] = std::move(state);
+ }
+
/// Return a reference to the BufferizationOptions.
const BufferizationOptions &getOptions() const { return options; }
return nullptr;
}
+void BufferizationOptions::addDialectStateInitializer(StringRef name,
+ DialectStateInitFn fn) {
+ stateInitializers.push_back(
+ [=](BufferizationState &state) { state.insertDialectState(name, fn()); });
+}
+
//===----------------------------------------------------------------------===//
// Helper functions for BufferizableOpInterface
//===----------------------------------------------------------------------===//
}
BufferizationState::BufferizationState(const BufferizationOptions &options)
- : options(options) {}
+ : options(options) {
+ for (const BufferizationOptions::BufferizationStateInitFn &fn :
+ options.stateInitializers)
+ fn(*this);
+}
// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {