[mlir][linalg][bufferize] Add bufferRelation to op interface
authorMatthias Springer <springerm@google.com>
Fri, 8 Oct 2021 03:14:05 +0000 (12:14 +0900)
committerMatthias Springer <springerm@google.com>
Fri, 8 Oct 2021 05:28:24 +0000 (14:28 +0900)
Currently supported are: BufferRelation::None, BufferRelation::Equivalent.

Differential Revision: https://reviews.llvm.org/D111376

mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp

index 82ada9f40f97511fce5254bc022b7a50729b0799..b0249ff38a569fc24cba6ab4b85db5c1b2c8495e 100644 (file)
@@ -67,8 +67,7 @@ public:
 
   /// Set the inPlace bufferization spec to true.
   /// Merge result's and operand's aliasing sets and iterate to a fixed point.
-  void bufferizeInPlace(OpResult result, OpOperand &operand,
-                        BufferRelation bufferRelation = BufferRelation::None);
+  void bufferizeInPlace(OpResult result, OpOperand &operand);
 
   /// Set the inPlace bufferization spec to false.
   void bufferizeOutOfPlace(OpResult result);
index 0c4ded7bb9cf52f32616156d77350cc00b848657..685eb51928cc70105e26d8259e6323217e6bdca7 100644 (file)
@@ -136,6 +136,8 @@ using namespace mlir;
 using namespace linalg;
 using namespace tensor;
 
+using BufferRelation = BufferizationAliasInfo::BufferRelation;
+
 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
 #define LDBG(X) LLVM_DEBUG(DBGS() << X)
 
@@ -421,8 +423,10 @@ static std::string printValueInfo(Value value, bool prefix) {
 //      buffers in memory.
 //   3. Whether an op operand, when bufferized inplace, aliases a return value.
 //   4. Whether an op return value, when bufferized inplace, aliases an operand.
-//   5. Wheher an op bufferizes to a memory read.
-//   6. Wheher an op bufferizes to a memory write.
+//   5. Whether an op bufferizes to a memory read.
+//   6. Whether an op bufferizes to a memory write.
+//   7. The buffer relationship between an operand and it corresponding result
+//      (in case of in-place bufferization).
 // These interfaces are necessary to distinguish between various cases and allow
 // special inplace behavior for (ExtractSliceOp, InsertSliceOp) pairs.
 //===----------------------------------------------------------------------===//
@@ -682,6 +686,16 @@ bufferizesToMemoryWrite(OpOperand &opOperand,
          getInPlace(opResult) == inPlaceSpec;
 }
 
+/// Returns the relationship between the operand and the its corresponding
+/// OpResult that it may alias with.
+static BufferRelation bufferRelation(OpOperand &operand) {
+  return TypeSwitch<Operation *, BufferRelation>(operand.getOwner())
+      // ExtractSliceOp returns a subview of the original tensor.
+      .Case([&](ExtractSliceOp op) { return BufferRelation::None; })
+      // All other ops: Buffers are equivalent.
+      .Default([&](Operation *op) { return BufferRelation::Equivalent; });
+}
+
 //===----------------------------------------------------------------------===//
 // Bufferization-specific alias analysis.
 //===----------------------------------------------------------------------===//
@@ -787,13 +801,12 @@ bool BufferizationAliasInfo::aliasesInPlaceWrite(Value value) const {
 
 /// Set the inPlace bufferization spec to true.
 void BufferizationAliasInfo::bufferizeInPlace(OpResult result,
-                                              OpOperand &operand,
-                                              BufferRelation bufferRelation) {
+                                              OpOperand &operand) {
   setInPlaceOpResult(result, InPlaceSpec::True);
   aliasInfo.unionSets(result, operand.get());
   // Dump the updated alias analysis.
   LLVM_DEBUG(dumpAliases());
-  if (bufferRelation == BufferRelation::Equivalent)
+  if (bufferRelation(operand) == BufferRelation::Equivalent)
     equivalentInfo.unionSets(result, operand.get());
   // Dump the updated equivalence analysis.
   LLVM_DEBUG(dumpEquivalences());
@@ -2293,8 +2306,7 @@ bufferizableInPlaceAnalysis(OpOperand &operand,
   else
     // TODO: Atm, all inplace bufferizations yield equivalent tensors. Support
     // more cases on a per-need basis.
-    aliasInfo.bufferizeInPlace(
-        result, operand, BufferizationAliasInfo::BufferRelation::Equivalent);
+    aliasInfo.bufferizeInPlace(result, operand);
 
   LDBG("Done inplace analysis for result #" << resultNumber << '\n');