#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
+using namespace mlir;
+
namespace mlir {
namespace linalg {
namespace comprehensive_bufferize {
using tensor::ExtractSliceOp;
using tensor::InsertSliceOp;
+namespace {
+/// Extra bufferization state that is required for bufferization of tensor ops.
+/// It is retrieved via `getTensorBufferizationState`, which looks it up under
+/// the tensor dialect namespace.
+struct TensorBufferizationState : public DialectBufferizationState {
+ /// InsertSliceOps that bufferize inplace and do not require a copy.
+ /// Filled in by `InplaceInsertSliceOpAnalysis::run`; an InsertSliceOp that is
+ /// absent from this set is treated as needing a copy when it is bufferized.
+ DenseSet<Operation *> insertSliceOpsWithoutCopy;
+};
+} // namespace
+
+/// Returns the tensor-specific extra bufferization state held by `state`.
+/// The lookup is keyed by the namespace of the tensor dialect.
+static TensorBufferizationState &
+getTensorBufferizationState(BufferizationState &state) {
+  // Name the lookup key first so the templated call below stays on one line.
+  const auto tensorNamespace = tensor::TensorDialect::getDialectNamespace();
+  return state.getDialectState<TensorBufferizationState>(tensorNamespace);
+}
+
struct CastOpInterface
: public BufferizableOpInterface::ExternalModel<CastOpInterface,
tensor::CastOp> {
// catastrophically bad scheduling decision.
// TODO: be very loud about it or even consider failing the pass.
auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
+ TensorBufferizationState &tensorState = getTensorBufferizationState(state);
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
if (!dstMemref)
return failure();
- // A copy of the source buffer is needed if either:
- // - The producer of `source` is not inplace. This is the case where a
- // slice is computed out of place into the inplace full tensor.
- // - The result is not inplace. This is the case where the whole tensor is
- // cloned and the clone needs to be updated.
- // TODO: Is this necessary?
- bool needCopy = !isSourceEquivalentToAMatchingInplaceExtractSliceOp(
- state.aliasInfo, insertSliceOp) ||
- !state.aliasInfo.isInPlace(insertSliceOp->getResult(0));
+ bool needCopy =
+ !tensorState.insertSliceOpsWithoutCopy.contains(insertSliceOp);
if (needCopy) {
// Take a subview of the dst.
auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
} // namespace linalg
} // namespace mlir
+/// Analysis step that records which InsertSliceOps inside `funcOp` can be
+/// bufferized without copying their source buffer.
+///
+/// An op is added to `insertSliceOpsWithoutCopy` only when both conditions
+/// hold:
+///   - its source is equivalent to a matching inplace ExtractSliceOp, and
+///   - its result is inplace.
+/// Any other InsertSliceOp is left out of the set. `newOps` is not used by this
+/// analysis, and the function always returns success.
+LogicalResult mlir::linalg::comprehensive_bufferize::tensor_ext::
+    InplaceInsertSliceOpAnalysis::run(FuncOp funcOp, BufferizationState &state,
+                                      SmallVector<Operation *> &newOps) {
+  TensorBufferizationState &tensorState = getTensorBufferizationState(state);
+  funcOp.walk([&](InsertSliceOp insertSliceOp) {
+    // The producer of the source is not inplace, i.e. the slice was computed
+    // out of place and must still be copied into the destination.
+    if (!isSourceEquivalentToAMatchingInplaceExtractSliceOp(state.aliasInfo,
+                                                            insertSliceOp))
+      return;
+    // The result is not inplace, i.e. the whole tensor is cloned and the clone
+    // has to be updated with a copy.
+    if (!state.aliasInfo.isInPlace(insertSliceOp->getResult(0)))
+      return;
+    // Neither reason for a copy applies: remember this op as copy-free.
+    tensorState.insertSliceOpsWithoutCopy.insert(insertSliceOp);
+  });
+  return success();
+}
+
void mlir::linalg::comprehensive_bufferize::tensor_ext::
registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
registry.addOpInterface<tensor::CastOp, tensor_ext::CastOpInterface>();