[mlir][bufferization] Generalize and rename isMemoryWrite
authorMatthias Springer <springerm@google.com>
Mon, 30 Jan 2023 08:26:15 +0000 (09:26 +0100)
committerMatthias Springer <springerm@google.com>
Mon, 30 Jan 2023 08:34:04 +0000 (09:34 +0100)
The name of the method was confusing. It is bufferizesToMemoryWrite, but from the perspective of OpResults.

`bufferizesToMemoryWrite(OpResult)` now supports ops with regions that do not have aliasing OpOperands (such as `scf.if`). These ops no longer need to implement `isMemoryWrite`.

Differential Revision: https://reviews.llvm.org/D141684

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp

index 0ca2086..ccd2711 100644 (file)
@@ -331,6 +331,11 @@ public:
   /// the op is not bufferizable.
   bool bufferizesToMemoryWrite(OpOperand &opOperand) const;
 
+  /// Return true if the given `value` bufferizes to a memory write. Return
+  /// true if the value is a block argument. Return `true` if the defining op is
+  /// not bufferizable. Otherwise, consult the BufferizableOpInterface.
+  bool bufferizesToMemoryWrite(Value value) const;
+
   /// Return true if `opOperand` does neither read nor write but bufferizes to
   /// an alias. Return false if the op is not bufferizable.
   bool bufferizesToAliasOnly(OpOperand &opOperand) const;
@@ -545,6 +550,12 @@ defaultGetBufferType(Value value, const BufferizationOptions &options,
                      const DenseMap<Value, BaseMemRefType> &fixedTypes);
 
 /// This is the default implementation of
+/// BufferizableOpInterface::resultBufferizesToMemoryWrite. Should not be called
+/// from other places.
+bool defaultResultBufferizesToMemoryWrite(OpResult opResult,
+                                          const AnalysisState &state);
+
+/// This is the default implementation of
 /// BufferizableOpInterface::isRepetitiveRegion. Should not be called from other
 /// places.
 bool defaultIsRepetitiveRegion(BufferizableOpInterface bufferizableOp,
index c8a917e..488165b 100644 (file)
@@ -92,35 +92,63 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
          }]
       >,
       InterfaceMethod<
-          /*desc=*/[{
-            Return `true` if the given OpResult is a memory write. This is the
-            case if in the following cases:
+        /*desc=*/[{
+          Return `true` if the given OpResult bufferizes to a memory write.
+          This is the same property as `bufferizesToMemoryWrite`, but from The
+          perspective of OpResults.
+
+          This method will never be called on OpResults that do not have a
+          tensor type.
+
+          This method has a default implementation. By default, it returns
+          `true` if any of the following three cases applies.
+
+          1. There is no corresponding aliasing OpOperand.
+
+             Example: `tensor.generate ... : tensor<10xf32>`
+             The op fills a newly allocated buffer and bufferizes to a memory
+             write.
+
+             Counter-example: bufferization.alloc_tensor
+             The op just allocates and does not specifiy the data of the tensor,
+             so resultBufferizesToMemoryWrite is overridden to return false.
+
+          2. At least one aliasing OpOperand bufferizes to a memory write.
 
-            * The corresponding aliasing OpOperand bufferizes to a memory write.
-            * Or: There is no corresponding aliasing OpOperand.
+             Example: `tensor.insert %f into %t[...] : tensor<?xf32>`
+             The destination OpOperand bufferizes to a memory write, so the
+             result also bufferizes to a memory write.
 
-            If the OpResult has multiple aliasing OpOperands, this method
-            returns `true` if at least one of them bufferizes to a memory write.
+          3. At least one aliasing OpOperand's value is defined inside the
+             defining op of the given OpResult and it is a memory write or the
+             reverse SSA use-def chain ends in the defining op.
+
+             According to this rule, an aliasing OpOperand value that is defined
+             inside this op and is bufferizing to a memory write makes the given
+             OpResult bufferize to a memory write.
+
+             Example:
+             ```
+             %r = scf.if ... -> tensor<?xf32> {
+               %1 = tensor.insert %f into %t[...] : tensor<?xf32>
+               scf.yield %1 : tensor<?xf32>
+             } else { ... }
+             ```
+             The scf.if result bufferizes to a memory write because %1 (an
+             OpResult defined inside the scf.if op) bufferizes to a memory
+             write.
           }],
-          /*retType=*/"bool",
-          /*methodName=*/"isMemoryWrite",
-          /*args=*/(ins "::mlir::OpResult":$opResult,
-                        "const ::mlir::bufferization::AnalysisState &":$state),
-          /*methodBody=*/"",
-          /*defaultImplementation=*/[{
-            auto bufferizableOp =
-                cast<BufferizableOpInterface>($_op.getOperation());
-            SmallVector<OpOperand*> opOperands =
-              bufferizableOp.getAliasingOpOperand(opResult, state);
-            if (opOperands.empty())
-              return true;
-            return llvm::any_of(
-                opOperands,
-                [&](OpOperand *operand) {
-                  return bufferizableOp.bufferizesToMemoryWrite(*operand,
-                                                                state);
-                });
-          }]
+        /*retType=*/"bool",
+        /*methodName=*/"resultBufferizesToMemoryWrite",
+        /*args=*/(ins "::mlir::OpResult":$opResult,
+                      "const ::mlir::bufferization::AnalysisState &":$state),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          assert(opResult.getDefiningOp() == $_op.getOperation() &&
+                 "invalid OpResult");
+          return bufferization::detail::defaultResultBufferizesToMemoryWrite(
+              opResult, state);
+        }]
       >,
       InterfaceMethod<
         /*desc=*/[{
index 279f597..909ac44 100644 (file)
@@ -89,7 +89,8 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
     LogicalResult bufferize(RewriterBase &rewriter,
                             const BufferizationOptions &options);
 
-    bool isMemoryWrite(OpResult opResult, const AnalysisState &state);
+    bool resultBufferizesToMemoryWrite(OpResult opResult,
+                                       const AnalysisState &state);
 
     bool bufferizesToAllocation(OpResult opResult) { return true; }
 
index 9e7dbf5..8abb9c3 100644 (file)
@@ -405,6 +405,16 @@ bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
   return false;
 }
 
+bool AnalysisState::bufferizesToMemoryWrite(Value value) const {
+  auto opResult = value.dyn_cast<OpResult>();
+  if (!opResult)
+    return true;
+  auto bufferizableOp = getOptions().dynCastBufferizableOp(value);
+  if (!bufferizableOp)
+    return true;
+  return bufferizableOp.resultBufferizesToMemoryWrite(opResult, *this);
+}
+
 /// Return true if the given value is read by an op that bufferizes to a memory
 /// read. Also takes into account ops that create an alias but do not read by
 /// themselves (e.g., ExtractSliceOp).
@@ -473,15 +483,8 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
 // Find the Values of the last preceding write of a given Value.
 llvm::SetVector<Value>
 AnalysisState::findLastPrecedingWrite(Value value) const {
-  return findValueInReverseUseDefChain(value, [&](Value value) {
-    Operation *op = value.getDefiningOp();
-    if (!op)
-      return true;
-    auto bufferizableOp = options.dynCastBufferizableOp(op);
-    if (!bufferizableOp)
-      return true;
-    return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
-  });
+  return findValueInReverseUseDefChain(
+      value, [&](Value v) { return this->bufferizesToMemoryWrite(v); });
 }
 
 AnalysisState::AnalysisState(const BufferizationOptions &options)
@@ -618,6 +621,70 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
       .getResult();
 }
 
+bool bufferization::detail::defaultResultBufferizesToMemoryWrite(
+    OpResult opResult, const AnalysisState &state) {
+  auto bufferizableOp = cast<BufferizableOpInterface>(opResult.getDefiningOp());
+  SmallVector<OpOperand *> opOperands =
+      bufferizableOp.getAliasingOpOperand(opResult, state);
+
+  // Case 1: OpResults that have no aliasing OpOperand usually bufferize to
+  // memory writes.
+  if (opOperands.empty())
+    return true;
+
+  // Case 2: If an aliasing OpOperand bufferizes to a memory write, the OpResult
+  // may bufferize to a memory write.
+  if (llvm::any_of(opOperands, [&](OpOperand *operand) {
+        return state.bufferizesToMemoryWrite(*operand);
+      }))
+    return true;
+
+  // Case 3: Check if a nested aliasing OpOperand value bufferizes to a memory
+  // write. (Or: The reverse SSA use-def chain ends inside the reigon.) In that
+  // case, the OpResult bufferizes to a memory write. E.g.:
+  //
+  // %0 = "some_writing_op" : tensor<?xf32>
+  // %r = scf.if ... -> tensor<?xf32> {
+  //   scf.yield %0 : tensor<?xf32>
+  // } else {
+  //   %1 = "another_writing_op"(%0) : tensor<?xf32>
+  //   scf.yield %1 : tensor<?xf32>
+  // }
+  // "some_reading_op"(%r)
+  //
+  // %r bufferizes to a memory write because an aliasing OpOperand value (%1)
+  // bufferizes to a memory write and the defining op is inside the scf.if.
+  //
+  // Note: This treatment of surrouding ops is useful for ops that have a
+  // region but no OpOperand such as scf.if or scf.execute_region. It simplifies
+  // the analysis considerably.
+  //
+  // "another_writing_op" in the above example should be able to bufferize
+  // inplace in the absence of another read of %0. However, if the scf.if op
+  // would not be considered a "write", the analysis would detect the
+  // following conflict:
+  //
+  // * read = some_reading_op
+  // * lastWrite = %0  (Note: The last write of %r would be a set: {%0, %1}.)
+  // * conflictingWrite = %1
+  //
+  auto isMemoryWriteInsideOp = [&](Value v) {
+    Operation *op = getOwnerOfValue(v);
+    if (!opResult.getDefiningOp()->isAncestor(op))
+      return false;
+    return state.bufferizesToMemoryWrite(v);
+  };
+  for (OpOperand *operand : opOperands) {
+    if (!state
+             .findValueInReverseUseDefChain(operand->get(),
+                                            isMemoryWriteInsideOp,
+                                            /*followEquivalentOnly=*/false)
+             .empty())
+      return true;
+  }
+  return false;
+}
+
 FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
     Value value, const BufferizationOptions &options,
     const DenseMap<Value, BaseMemRefType> &fixedTypes) {
index e6f93d6..61fc3aa 100644 (file)
@@ -206,8 +206,8 @@ LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
   return success();
 }
 
-bool AllocTensorOp::isMemoryWrite(OpResult opResult,
-                                  const AnalysisState &state) {
+bool AllocTensorOp::resultBufferizesToMemoryWrite(OpResult opResult,
+                                                  const AnalysisState &state) {
   // AllocTensorOps do not write unless they have a `copy` value.
   return static_cast<bool>(getCopy());
 }
index ef87fa9..2e48067 100644 (file)
@@ -277,10 +277,7 @@ void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) {
       // memory write or not.
       SetVector<Value> lastWrites = findLastPrecedingWrite(opResult);
       bool isUndefined = llvm::none_of(lastWrites, [&](Value lastWrite) {
-        if (auto bufferizableOp = getOptions().dynCastBufferizableOp(lastWrite))
-          return bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>(),
-                                              *this);
-        return true;
+        return this->bufferizesToMemoryWrite(lastWrite);
       });
       if (isUndefined)
         for (OpOperand &use : opResult.getUses())
@@ -356,19 +353,6 @@ static bool happensBefore(Operation *a, Operation *b,
   return false;
 }
 
-/// Return `true` if the given tensor value is a memory write. Most values are
-/// tensor writes, but ops that define a tensor SSA value without specifying its
-/// contents (e.g., alloc_tensor) are not.
-static bool isMemoryWrite(Value value, const AnalysisState &state) {
-  auto opResult = value.dyn_cast<OpResult>();
-  if (!opResult)
-    return true;
-  auto bufferizableOp = state.getOptions().dynCastBufferizableOp(value);
-  if (!bufferizableOp)
-    return true;
-  return bufferizableOp.isMemoryWrite(opResult, state);
-}
-
 /// Return `true` if op dominance can be used to rule out read-after-write
 /// conflicts wrt. the given reads and writes.
 ///
@@ -471,7 +455,7 @@ bool canUseOpDominance(const DenseSet<OpOperand *> &usesRead,
   // In case of a read, take the region which the read value is defined.
   for (OpOperand *uRead : usesRead) {
     // Optimization: Skip reads of values that have no defined contents.
-    if (!isMemoryWrite(uRead->get(), state))
+    if (!state.bufferizesToMemoryWrite(uRead->get()))
       continue;
     Region *r = getEnclosingRepetitiveRegion(uRead->get(), options);
     if (!commonEnclosingRegion.has_value()) {
index 7cfb974..4a2991e 100644 (file)
@@ -123,18 +123,6 @@ struct ExecuteRegionOpInterface
     return {&yieldOp->getOpOperand(resultNum)};
   }
 
-  // TODO: For better bufferization results, this could return `true` only if
-  // there is a memory write in the region.
-  bool isMemoryWrite(Operation *op, OpResult opResult,
-                     const AnalysisState &state) const {
-    // Similar to scf.if, results of this op are always considered memory writes
-    // in the analysis. This is a useful pattern for all ops that have tensor
-    // OpResults but no tensor OpOperands. By default, `isMemoryWrite` is
-    // implemented in terms of `bufferizesToMemoryWrite`, which does not work on
-    // ops without OpOperands.
-    return true;
-  }
-
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
@@ -190,37 +178,6 @@ struct IfOpInterface
             &ifOp.elseYield()->getOpOperand(resultNum)};
   }
 
-  // TODO: For better bufferization results, this could return `true` only if
-  // there is a memory write in one (or both) of the branches. Since this is not
-  // allowed at the moment, we should never encounter scf.ifs that yield
-  // unmodified tensors. Such scf.yield ops could just fold away.
-  bool isMemoryWrite(Operation *op, OpResult opResult,
-                     const AnalysisState &state) const {
-    // IfOp results are always considered memory writes in the analysis. This
-    // design decision simplifies the analysis considerably. E.g., consider the
-    // following test case:
-    //
-    // %0 = "some_writing_op" : tensor<?xf32>
-    // %r = scf.if %c -> (tensor<?xf32>) {
-    //   scf.yield %0
-    // } else {
-    //   %1 = "another_writing_op"(%0) : tensor<?xf32>
-    // }
-    // "some_reading_op"(%r)
-    //
-    // "another_writing_op" in the above example should be able to bufferize
-    // inplace in the absence of another read of %0. However, if the scf.if op
-    // would not be considered a "write", the analysis would detect the
-    // following conflict:
-    //
-    // * read = some_reading_op
-    // * lastWrite = %0  (Note: The last write of %r would be a set: {%0, %1}.)
-    // * conflictingWrite = %1
-    //
-    // For more details, check the "scf.IfOp" section of the design document.
-    return true;
-  }
-
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     OpBuilder::InsertionGuard g(rewriter);
index 171c12d..8ea6075 100644 (file)
@@ -46,18 +46,6 @@ struct AssumingOpInterface
     return {&yieldOp->getOpOperand(resultNum)};
   }
 
-  // TODO: For better bufferization results, this could return `true` only if
-  // there is a memory write in the region.
-  bool isMemoryWrite(Operation *op, OpResult opResult,
-                     const AnalysisState &state) const {
-    // Similar to scf.if, results of this op are always considered memory writes
-    // in the analysis. This is a useful pattern for all ops that have tensor
-    // OpResults but no tensor OpOperands. By default, `isMemoryWrite` is
-    // implemented in terms of `bufferizesToMemoryWrite`, which does not work on
-    // ops without OpOperands.
-    return true;
-  }
-
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto assumingOp = cast<shape::AssumingOp>(op);
index aaf1e86..f7ff03e 100644 (file)
@@ -111,8 +111,8 @@ struct LoadOpInterface
 struct NewOpInterface
     : public BufferizableOpInterface::ExternalModel<NewOpInterface,
                                                     sparse_tensor::NewOp> {
-  bool isMemoryWrite(Operation *op, OpResult opResult,
-                     const AnalysisState &state) const {
+  bool resultBufferizesToMemoryWrite(Operation *op, OpResult opResult,
+                                     const AnalysisState &state) const {
     // NewOps allocate but do not write.
     return false;
   }