From c8f4005b0c6529c69a404b021faa0d962de03223 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Fri, 5 Nov 2021 11:29:46 +0900 Subject: [PATCH] [mlir][linalg][bufferize] Add isWritable to op interface By default, OpResult buffers are writable. But there are ops (e.g., ConstantOp) for which this is not the case. The purpose of this commit is to further decouple Comprehensive Bufferize from the Standard dialect. Differential Revision: https://reviews.llvm.org/D112908 --- .../Linalg/Transforms/BufferizableOpInterface.td | 15 +++++++++++++++ .../Linalg/Transforms/ComprehensiveBufferize.cpp | 17 +++++++++++------ 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterface.td index 9227e5f..4ab4f21 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterface.td @@ -158,6 +158,21 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> { return failure(); }] >, + InterfaceMethod< + /*desc=*/[{ + Return `true` if the given OpOperand can be written to in-place. This + is the case for most ops, but some ops such as ConstantOp may + bufferize to non-writable (read-only) memory locations. This method + will never be called on OpResults that do not have a tensor type. + }], + /*retType=*/"bool", + /*methodName=*/"isWritable", + /*args=*/(ins "OpResult":$opResult), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return true; + }] + > ]; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp index a6b6e01..63145b1 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -587,12 +587,12 @@ bool BufferizationAliasInfo::aliasesNonWritableBuffer(Value value) const { return true; } - if (Operation *op = v.getDefiningOp()) { - if (isa(op) || - !dyn_cast(op)) { - LDBG("-----------notWritable op\n"); - return true; - } + auto bufferizableOp = dyn_cast(v.getDefiningOp()); + if (!bufferizableOp || !bufferizableOp.isWritable(v.cast())) { + // Unknown ops are treated conservatively: Assume that it is illegal to + // write to their OpResults in-place. + LDBG("-----------notWritable op\n"); + return true; } } LDBG("---->value is writable\n"); @@ -2421,6 +2421,11 @@ struct ConstantOpInterface return success(); } + + bool isWritable(Operation *op, OpResult opResult) const { + // Memory locations returned by memref::GetGlobalOp may not be written to. + return false; + } }; } // namespace arith_ext -- 2.7.4