[mlir][Linalg] Improve comprehensive bufferization for scf.yield.
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Mon, 12 Jul 2021 10:10:26 +0000 (10:10 +0000)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Mon, 12 Jul 2021 10:36:25 +0000 (10:36 +0000)
Previously, comprehensive bufferization of scf.yield did not have enough information
to detect whether an enclosing scf::for bbargs would bufferize to a buffer equivalent
to that of the matching scf::yield operand.
As a consequence a separate sanity check step would be required to determine whether
bufferization occured properly.
This late check would miss the case of calling a function in an loop.

Instead, we now pass and update aliasInfo during bufferization and it is possible to
imrpove bufferization of scf::yield and drop that post-pass check.

Add an example use case that was failing previously.
This slightly modifies the error conditions, which are also updated as part of this
revision.

Differential Revision: https://reviews.llvm.org/D105803

mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir

index 8c37beb..be39eec 100644 (file)
@@ -2075,14 +2075,24 @@ static LogicalResult bufferize(OpBuilder &b, scf::YieldOp yieldOp,
     auto tensorType = operand.get().getType().dyn_cast<TensorType>();
     if (!tensorType)
       continue;
+
     OpOperand &forOperand = forOp.getOpOperandForResult(
         forOp->getResult(operand.getOperandNumber()));
     auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
-    if (getInPlace(bbArg) == InPlaceSpec::True)
-      operand.set(bbArg);
-    else
-      operand.set(
-          b.create<memref::TensorLoadOp>(yieldOp.getLoc(), lookup(bvm, bbArg)));
+    Value yieldedBuffer = lookup(bvm, operand.get());
+    Value bbArgBuffer = lookup(bvm, bbArg);
+    if (!aliasInfo.areEquivalentBufferizedValues(yieldedBuffer, bbArgBuffer)) {
+      // TODO: this could get resolved with copies but it can also turn into
+      // swaps so we need to be careful about order of copies.
+      return yieldOp->emitError()
+             << "Yield operand #" << operand.getOperandNumber()
+             << " does not bufferize to an equivalent buffer to the matching"
+             << " enclosing scf::for operand";
+    }
+
+    // Buffers are equivalent so the work is already done and we just yield the
+    // bbArg so that it later canonicalizes away.
+    operand.set(bbArg);
   }
   return success();
 }
@@ -2205,38 +2215,6 @@ bufferizableInPlaceAnalysis(OpOperand &operand, OpResult result,
   return success();
 }
 
-/// Return `failure()` if either
-/// scf::YieldOp are not explicitly bufferized and we need to perform a separate
-/// sanity check for now.
-static LogicalResult
-bufferizationSanityCheck(scf::YieldOp yieldOp,
-                         const BufferizationAliasInfo &aliasInfo) {
-  auto parentForOp = yieldOp->getParentOfType<scf::ForOp>();
-  if (!parentForOp)
-    return yieldOp->emitError() << "not nested under ForOp";
-
-  for (OpOperand &operand : yieldOp->getOpOperands()) {
-    OpResult matchingForOpResult =
-        parentForOp->getResult(operand.getOperandNumber());
-    // Nothing to do if operand bufferizes out of place.
-    if (getInPlace(matchingForOpResult) != InPlaceSpec::True)
-      continue;
-    OpOperand &machingForOpOperand =
-        parentForOp.getOpOperandForResult(matchingForOpResult);
-    BlockArgument matchingForOpIterArg =
-        parentForOp.getRegionIterArgForOpOperand(machingForOpOperand);
-    if (!aliasInfo.areEquivalentBufferizedValues(matchingForOpIterArg,
-                                                 operand.get())) {
-      return yieldOp->emitError()
-             << "Yield operand #" << operand.getOperandNumber()
-             << " does not bufferize to an equivalent buffer to the matching"
-             << " enclosing scf::for operand -> Fail the pass\n";
-    }
-  }
-
-  return success();
-}
-
 /// Analyze the `funcOp` body to determine which OpResults are inplaceable.
 static LogicalResult
 inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
@@ -2275,13 +2253,14 @@ inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
       return failure();
   }
 
-  // Bufferize all ops except ExtractSliceOp and InsertSliceOp which are handled
-  // separately.
+  // Analyze all ops that return a tensors, except ExtractSliceOp and
+  // InsertSliceOp which are handled separately.
   // Walk other ops in reverse for better interference behavior.
   for (Operation *op : reverse(nonSliceOps))
     for (OpOperand &opOperand : op->getOpOperands())
       if (OpResult result = getInplaceableOpResult(opOperand))
-        if (failed(bufferizableInPlaceAnalysis(opOperand, result, aliasInfo,
+        if (result.getType().isa<TensorType>() &&
+            failed(bufferizableInPlaceAnalysis(opOperand, result, aliasInfo,
                                                domInfo)))
           return failure();
 
@@ -2292,14 +2271,9 @@ inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
     if (failed(bufferizableInPlaceAnalysis(extractSliceOp, aliasInfo, domInfo)))
       return failure();
 
-  // Sanity checks.
-  auto walkResult = funcOp.walk([&](scf::YieldOp yieldOp) -> WalkResult {
-    return bufferizationSanityCheck(yieldOp, aliasInfo);
-  });
-
   LDBG("End InPlaceAnalysisFuncOpInternals:\n" << funcOp << '\n');
 
-  return success(!walkResult.wasInterrupted());
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
index 15be096..cdf35c0 100644 (file)
@@ -18,7 +18,7 @@ func private @foo() -> tensor<?xf32>
 
 // expected-error @+1 {{cannot bufferize a FuncOp with tensors and without a unique ReturnOp}}
 func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor<f32>, %t2 : tensor<f32>)
-    -> (tensor<f32>, tensor<f32>) 
+    -> (tensor<f32>, tensor<f32>)
 {
   cond_br %cond1, ^bb1, ^bb2
 
@@ -64,7 +64,7 @@ func @scf_for(%A : tensor<?xf32>,
     // Throw a wrench in the system by swapping yielded values: this result in a
     // ping-pong of values at each iteration on which we currently want to fail.
 
-    // expected-error @+1 {{Yield operand #1 does not bufferize to an equivalent buffer}}
+    // expected-error @+1 {{Yield operand #0 does not bufferize to an equivalent buffer}}
     scf.yield %ttB, %ttA : tensor<?xf32>, tensor<?xf32>
   }
 
@@ -73,6 +73,27 @@ func @scf_for(%A : tensor<?xf32>,
 
 // -----
 
+func private @fun_with_side_effects(%A: tensor<?xf32> {linalg.inplaceable = true})
+
+func @foo(%A: tensor<?xf32> {linalg.inplaceable = true}) -> (tensor<?xf32>) {
+  call @fun_with_side_effects(%A) : (tensor<?xf32>) -> ()
+  return %A: tensor<?xf32>
+}
+
+func @scf_yield_needs_copy(%A : tensor<?xf32> {linalg.inplaceable = true}, %iters : index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %res = scf.for %arg0 = %c0 to %iters step %c1 iter_args(%bbarg = %A) -> (tensor<?xf32>) {
+    %r = call @foo(%A) : (tensor<?xf32>) -> (tensor<?xf32>)
+    // expected-error @+1 {{Yield operand #0 does not bufferize to an equivalent buffer}}
+    scf.yield %r : tensor<?xf32>
+  }
+  call @fun_with_side_effects(%res) : (tensor<?xf32>) -> ()
+  return
+}
+
+// -----
+
 func @extract_slice_fun(%A : tensor<?xf32> {linalg.inplaceable = true})
   ->  tensor<4xf32>
 {
@@ -92,8 +113,8 @@ func @extract_slice_fun(%A : tensor<?xf32> {linalg.inplaceable = true})
 
 func @scf_yield(%b : i1, %A : tensor<4xf32>, %B : tensor<4xf32>) -> tensor<4xf32>
 {
-  %r = scf.if %b -> (tensor<4xf32>) { 
-    // expected-error @+1 {{not nested under ForOp}}
+  // expected-error @+1 {{unsupported op with tensors}}
+  %r = scf.if %b -> (tensor<4xf32>) {
     scf.yield %A : tensor<4xf32>
   } else {
     scf.yield %B : tensor<4xf32>