From 6b1668397fd33440847f5a82675c5b83c4137018 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Mon, 12 Jul 2021 10:10:26 +0000 Subject: [PATCH] [mlir][Linalg] Improve comprehensive bufferization for scf.yield. Previously, comprehensive bufferization of scf.yield did not have enough information to detect whether an enclosing scf::for bbargs would bufferize to a buffer equivalent to that of the matching scf::yield operand. As a consequence a separate sanity check step would be required to determine whether bufferization occurred properly. This late check would miss the case of calling a function in a loop. Instead, we now pass and update aliasInfo during bufferization and it is possible to improve bufferization of scf::yield and drop that post-pass check. Add an example use case that was failing previously. This slightly modifies the error conditions, which are also updated as part of this revision. Differential Revision: https://reviews.llvm.org/D105803 --- .../Linalg/Transforms/ComprehensiveBufferize.cpp | 66 +++++++--------------- .../comprehensive-module-bufferize-invalid.mlir | 29 ++++++++-- 2 files changed, 45 insertions(+), 50 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp index 8c37beb..be39eec 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -2075,14 +2075,24 @@ static LogicalResult bufferize(OpBuilder &b, scf::YieldOp yieldOp, auto tensorType = operand.get().getType().dyn_cast(); if (!tensorType) continue; + OpOperand &forOperand = forOp.getOpOperandForResult( forOp->getResult(operand.getOperandNumber())); auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand); - if (getInPlace(bbArg) == InPlaceSpec::True) - operand.set(bbArg); - else - operand.set( - b.create(yieldOp.getLoc(), lookup(bvm, bbArg))); + Value yieldedBuffer = lookup(bvm, 
operand.get()); + Value bbArgBuffer = lookup(bvm, bbArg); + if (!aliasInfo.areEquivalentBufferizedValues(yieldedBuffer, bbArgBuffer)) { + // TODO: this could get resolved with copies but it can also turn into + // swaps so we need to be careful about order of copies. + return yieldOp->emitError() + << "Yield operand #" << operand.getOperandNumber() + << " does not bufferize to an equivalent buffer to the matching" + << " enclosing scf::for operand"; + } + + // Buffers are equivalent so the work is already done and we just yield the + // bbArg so that it later canonicalizes away. + operand.set(bbArg); } return success(); } @@ -2205,38 +2215,6 @@ bufferizableInPlaceAnalysis(OpOperand &operand, OpResult result, return success(); } -/// Return `failure()` if either -/// scf::YieldOp are not explicitly bufferized and we need to perform a separate -/// sanity check for now. -static LogicalResult -bufferizationSanityCheck(scf::YieldOp yieldOp, - const BufferizationAliasInfo &aliasInfo) { - auto parentForOp = yieldOp->getParentOfType(); - if (!parentForOp) - return yieldOp->emitError() << "not nested under ForOp"; - - for (OpOperand &operand : yieldOp->getOpOperands()) { - OpResult matchingForOpResult = - parentForOp->getResult(operand.getOperandNumber()); - // Nothing to do if operand bufferizes out of place. - if (getInPlace(matchingForOpResult) != InPlaceSpec::True) - continue; - OpOperand &machingForOpOperand = - parentForOp.getOpOperandForResult(matchingForOpResult); - BlockArgument matchingForOpIterArg = - parentForOp.getRegionIterArgForOpOperand(machingForOpOperand); - if (!aliasInfo.areEquivalentBufferizedValues(matchingForOpIterArg, - operand.get())) { - return yieldOp->emitError() - << "Yield operand #" << operand.getOperandNumber() - << " does not bufferize to an equivalent buffer to the matching" - << " enclosing scf::for operand -> Fail the pass\n"; - } - } - - return success(); -} - /// Analyze the `funcOp` body to determine which OpResults are inplaceable. 
static LogicalResult inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo, @@ -2275,13 +2253,14 @@ inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo, return failure(); } - // Bufferize all ops except ExtractSliceOp and InsertSliceOp which are handled - // separately. + // Analyze all ops that return a tensors, except ExtractSliceOp and + // InsertSliceOp which are handled separately. // Walk other ops in reverse for better interference behavior. for (Operation *op : reverse(nonSliceOps)) for (OpOperand &opOperand : op->getOpOperands()) if (OpResult result = getInplaceableOpResult(opOperand)) - if (failed(bufferizableInPlaceAnalysis(opOperand, result, aliasInfo, + if (result.getType().isa() && + failed(bufferizableInPlaceAnalysis(opOperand, result, aliasInfo, domInfo))) return failure(); @@ -2292,14 +2271,9 @@ inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo, if (failed(bufferizableInPlaceAnalysis(extractSliceOp, aliasInfo, domInfo))) return failure(); - // Sanity checks. 
- auto walkResult = funcOp.walk([&](scf::YieldOp yieldOp) -> WalkResult { - return bufferizationSanityCheck(yieldOp, aliasInfo); - }); - LDBG("End InPlaceAnalysisFuncOpInternals:\n" << funcOp << '\n'); - return success(!walkResult.wasInterrupted()); + return success(); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir index 15be096..cdf35c0 100644 --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir @@ -18,7 +18,7 @@ func private @foo() -> tensor // expected-error @+1 {{cannot bufferize a FuncOp with tensors and without a unique ReturnOp}} func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor, %t2 : tensor) - -> (tensor, tensor) + -> (tensor, tensor) { cond_br %cond1, ^bb1, ^bb2 @@ -64,7 +64,7 @@ func @scf_for(%A : tensor, // Throw a wrench in the system by swapping yielded values: this result in a // ping-pong of values at each iteration on which we currently want to fail. 
- // expected-error @+1 {{Yield operand #1 does not bufferize to an equivalent buffer}} + // expected-error @+1 {{Yield operand #0 does not bufferize to an equivalent buffer}} scf.yield %ttB, %ttA : tensor, tensor } @@ -73,6 +73,27 @@ func @scf_for(%A : tensor, // ----- +func private @fun_with_side_effects(%A: tensor {linalg.inplaceable = true}) + +func @foo(%A: tensor {linalg.inplaceable = true}) -> (tensor) { + call @fun_with_side_effects(%A) : (tensor) -> () + return %A: tensor +} + +func @scf_yield_needs_copy(%A : tensor {linalg.inplaceable = true}, %iters : index) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %res = scf.for %arg0 = %c0 to %iters step %c1 iter_args(%bbarg = %A) -> (tensor) { + %r = call @foo(%A) : (tensor) -> (tensor) + // expected-error @+1 {{Yield operand #0 does not bufferize to an equivalent buffer}} + scf.yield %r : tensor + } + call @fun_with_side_effects(%res) : (tensor) -> () + return +} + +// ----- + func @extract_slice_fun(%A : tensor {linalg.inplaceable = true}) -> tensor<4xf32> { @@ -92,8 +113,8 @@ func @extract_slice_fun(%A : tensor {linalg.inplaceable = true}) func @scf_yield(%b : i1, %A : tensor<4xf32>, %B : tensor<4xf32>) -> tensor<4xf32> { - %r = scf.if %b -> (tensor<4xf32>) { - // expected-error @+1 {{not nested under ForOp}} + // expected-error @+1 {{unsupported op with tensors}} + %r = scf.if %b -> (tensor<4xf32>) { scf.yield %A : tensor<4xf32> } else { scf.yield %B : tensor<4xf32> -- 2.7.4