Remove all function calls related to buffer equivalence from bufferize implementations.
Add a new PostAnalysisStep for scf.for that ensures that yielded values are equivalent to the corresponding BBArgs. (This was previously checked in `bufferize`.) This will be relaxed in a subsequent commit.
Note: This commit changes two test cases. Those test cases were broken by
design and should never have passed; the new scf.for PostAnalysisStep
fixes this bug.
Differential Revision: https://reviews.llvm.org/D114927
namespace comprehensive_bufferize {
namespace scf_ext {
+/// Equivalence analysis for scf.for. Raise an error if iter_args are not
+/// equivalent to their corresponding loop yield values.
+///
+/// This post-analysis step visits every scf.yield nested inside the FuncOp
+/// and fails if a yielded tensor does not bufferize to a buffer equivalent
+/// to that of the matching iter_arg. (This was previously checked during
+/// `bufferize` itself.) `newOps` receives any ops created by the step; this
+/// step only verifies and creates none.
+struct AssertDestinationPassingStyle : public PostAnalysisStep {
+ LogicalResult run(FuncOp funcOp, BufferizationState &state,
+ SmallVector<Operation *> &newOps) override;
+};
+
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry);
} // namespace scf_ext
auto globalMemref = globalCreator.getGlobalFor(constantOp);
Value memref = b.create<memref::GetGlobalOp>(
constantOp.getLoc(), globalMemref.type(), globalMemref.getName());
- state.aliasInfo.insertNewBufferEquivalence(memref, constantOp.getResult());
state.mapBuffer(constantOp, memref);
return success();
/// Return `true` if a value was marked as in-place bufferized.
bool BufferizationAliasInfo::isInPlace(OpResult opResult) const {
- bool inplace = inplaceBufferized.contains(opResult);
-#ifndef NDEBUG
- if (inplace) {
- auto bufferizableOp =
- dyn_cast<BufferizableOpInterface>(opResult.getDefiningOp());
- assert(bufferizableOp &&
- "expected that in-place bufferized op is bufferizable");
- SmallVector<OpOperand *> operands =
- bufferizableOp.getAliasingOpOperand(opResult);
- for (OpOperand *operand : operands)
- assert(areAliasingBufferizedValues(operand->get(), opResult) &&
- "expected that in-place bufferized OpResult aliases with "
- "aliasing OpOperand");
- }
-#endif // NDEBUG
- return inplace;
+ // Membership in `inplaceBufferized` is now the sole criterion. The
+ // NDEBUG-only cross-check against the alias sets was dropped together
+ // with the rest of the buffer-equivalence bookkeeping in this change.
+ return inplaceBufferized.contains(opResult);
}
/// Set the inPlace bufferization spec to true.
Value casted = allocated.getValue();
if (memRefType && memRefType != allocMemRefType) {
casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
- aliasInfo.insertNewBufferEquivalence(casted, allocated.getValue());
}
// 2. Create memory deallocation.
return failure();
// Insert mapping and aliasing info.
- state.aliasInfo.createAliasInfoEntry(resultBuffer);
- state.aliasInfo.insertNewBufferEquivalence(opResult, resultBuffer);
state.mapBuffer(opResult, resultBuffer);
// Insert new operand and bbArg.
body->insertArgument(nextOutputBBArgIndex, resultBuffer.getType());
BlockArgument oldTensorBBArg = body->getArgument(oldOutputBBArgIndex);
// Insert mapping and aliasing info.
- state.aliasInfo.createAliasInfoEntry(newBufferBBArg);
- state.aliasInfo.insertNewBufferEquivalence(oldTensorBBArg,
- newBufferBBArg);
state.mapBuffer(oldTensorBBArg, newBufferBBArg);
// Set operand of `linalg.yield` to the bbArg so it just canonicalizes
BlockArgument oldTensorBBArg = body->getArgument(oldInputBBArgIndex);
// Insert mapping and aliasing info.
- state.aliasInfo.createAliasInfoEntry(newBufferBBArg);
- state.aliasInfo.insertNewBufferEquivalence(oldTensorBBArg,
- newBufferBBArg);
state.mapBuffer(oldTensorBBArg, newBufferBBArg);
// Increment indices.
BufferizationState &state) {
LLVM_DEBUG(DBGS() << "Begin bufferizeFuncOpBoundary:\n" << funcOp << "\n");
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
- BufferizationAliasInfo &aliasInfo = state.aliasInfo;
// If nothing to do then we are done.
if (!llvm::any_of(funcOp.getType().getInputs(), isaTensor) &&
auto castOp = b.create<memref::CastOp>(
funcOp.getLoc(), toMemrefOp.memref().getType(), memref);
toMemrefOp.memref().replaceAllUsesWith(castOp);
- aliasInfo.insertNewBufferEquivalence(castOp.dest(),
- toMemrefOp.memref());
}
}
// Replace all remaining uses by a to_tensor.
if (!bbArg.use_empty()) {
auto toTensorOp =
b.create<bufferization::ToTensorOp>(funcOp.getLoc(), memref);
- aliasInfo.insertNewBufferEquivalence(toTensorOp, bbArg);
bbArg.replaceAllUsesWith(toTensorOp);
}
frontBlock.eraseArgument(0);
Value buffer = state.lookupBuffer(callOp->getOperand(idx));
// Add CallOp operand/result equivalence: this is interprocedural
// info.
- state.aliasInfo.insertNewBufferEquivalence(oldRes, buffer);
state.mapBuffer(oldRes, buffer);
// Add a ToTensorOp to kill all uses of the CallOp return.
// Replace all uses of the CallOp results so we can erase the CallOp.
b.create<bufferization::ToTensorOp>(callOp.getLoc(), buffer);
oldRes.replaceAllUsesWith(toTensorOp);
// Add new op equivalence info.
- state.aliasInfo.insertNewBufferEquivalence(toTensorOp, buffer);
state.mapBuffer(toTensorOp, buffer);
continue;
}
Value castBuffer =
b.create<memref::CastOp>(callOp.getLoc(), memRefType, buffer);
// Add new op equivalence info.
- state.aliasInfo.insertNewBufferEquivalence(castBuffer, buffer);
state.mapBuffer(tensorOperand, castBuffer);
buffer = castBuffer;
}
Value returnTensor = b.create<bufferization::ToTensorOp>(
returnOp.getLoc(), v);
operand.set(returnTensor);
- state.aliasInfo.insertNewBufferEquivalence(returnTensor, v);
state.mapBuffer(returnTensor, v);
}
return success();
: getContiguousOrUnrankedMemRefType(tensorType);
Value bufferCast = b.create<bufferization::ToMemrefOp>(funcOp.getLoc(),
memRefType, bbArg);
- state.aliasInfo.insertNewBufferEquivalence(bufferCast, bbArg);
state.mapBuffer(bbArg, bufferCast);
}
if (!resultBuffer)
return failure();
- state.aliasInfo.createAliasInfoEntry(resultBuffer);
state.mapBuffer(opResult, resultBuffer);
}
OpOperand &opOperand = forOp.getOpOperandForResult(opResult);
BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(opOperand);
- state.aliasInfo.createAliasInfoEntry(resultBuffer);
- state.aliasInfo.insertNewBufferEquivalence(bbArg, resultBuffer);
state.mapBuffer(bbArg, resultBuffer);
state.mapBuffer(opResult, resultBuffer);
}
OpOperand &forOperand = forOp.getOpOperandForResult(
forOp->getResult(operand.getOperandNumber()));
auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
- if (!state.aliasInfo.areEquivalentBufferizedValues(operand.get(),
- bbArg)) {
- // TODO: this could get resolved with copies but it can also turn into
- // swaps so we need to be careful about order of copies.
- return yieldOp->emitError()
- << "Yield operand #" << operand.getOperandNumber()
- << " does not bufferize to an equivalent buffer to the matching"
- << " enclosing scf::for operand";
- }
// Buffers are equivalent so the work is already done and we just yield
// the bbArg so that it later canonicalizes away.
}
};
+// Walk all scf.yield ops in the function and verify that each yielded tensor
+// bufferizes to a buffer equivalent to the one of its matching iter_arg.
+// Emits an error and returns failure on the first violation found.
+// Note: `newOps` is unused; this step only verifies and creates no ops.
+LogicalResult mlir::linalg::comprehensive_bufferize::scf_ext::
+ AssertDestinationPassingStyle::run(FuncOp funcOp, BufferizationState &state,
+ SmallVector<Operation *> &newOps) {
+ LogicalResult status = success();
+ funcOp->walk([&](scf::YieldOp yieldOp) {
+ // Only scf.yield ops that terminate an scf.for are checked; yields of
+ // other parents (e.g. scf.if) are skipped.
+ auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp());
+ if (!forOp)
+ return WalkResult::advance();
+
+ for (OpOperand &operand : yieldOp->getOpOperands()) {
+ // Non-tensor yield operands do not participate in bufferization.
+ auto tensorType = operand.get().getType().dyn_cast<TensorType>();
+ if (!tensorType)
+ continue;
+
+ // Map the yield operand to the iter_arg (region BBArg) at the same
+ // position and require both to bufferize to equivalent buffers.
+ OpOperand &forOperand = forOp.getOpOperandForResult(
+ forOp->getResult(operand.getOperandNumber()));
+ auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
+ if (!state.aliasInfo.areEquivalentBufferizedValues(operand.get(),
+ bbArg)) {
+ // TODO: this could get resolved with copies but it can also turn into
+ // swaps so we need to be careful about order of copies.
+ status =
+ yieldOp->emitError()
+ << "Yield operand #" << operand.getOperandNumber()
+ << " does not bufferize to an equivalent buffer to the matching"
+ << " enclosing scf::for operand";
+ return WalkResult::interrupt();
+ }
+ }
+
+ return WalkResult::advance();
+ });
+ return status;
+}
+
+
struct YieldOpInterface
: public BufferizableOpInterface::ExternalModel<YieldOpInterface,
scf::YieldOp> {
castOp.getResult().getType(), layout, memorySpace);
Value res =
b.create<memref::CastOp>(castOp.getLoc(), memRefType, resultBuffer);
- state.aliasInfo.insertNewBufferEquivalence(res, castOp.getResult());
state.mapBuffer(castOp.getResult(), res);
return success();
}
b.create<memref::StoreOp>(loc, insertOp.scalar(), destMemref,
insertOp.indices());
state.mapBuffer(insertOp, destMemref);
- state.aliasInfo.insertNewBufferAlias(insertOp, destMemref);
return success();
}
Value subView = b.create<memref::SubViewOp>(
loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(),
insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
- // Insert new alias.
- state.aliasInfo.insertNewBufferAlias(subView, dstMemref);
// Copy tensor.
Value srcMemref = state.lookupBuffer(insertSliceOp.source());
state.options.allocationFns->memCpyFn(b, insertSliceOp.getLoc(),
// TODO: Find a way to enable this step automatically when bufferizing tensor
// dialect ops.
options.addPostAnalysisStep<tensor_ext::InplaceInsertSliceOpAnalysis>();
+ options.addPostAnalysisStep<scf_ext::AssertDestinationPassingStyle>();
ModuleOp moduleOp = getOperation();
applyEnablingTransformations(moduleOp);
// Read from %t1 via alias %e.
%v2 = vector.transfer_read %e[%s], %cst : tensor<?xf32>, vector<5xf32>
- scf.yield %e, %v2 : tensor<?xf32>, vector<5xf32>
+ scf.yield %t2, %v2 : tensor<?xf32>, vector<5xf32>
}
// CHECK: __inplace_results_attr__ = ["true", "false"]
// This loop does not read from %t1. It only writes to it.
// CHECK: scf.for
%r, %v3 = scf.for %i = %c0 to %s step %c1 iter_args(%t2 = %t1, %v0 = %v) -> (tensor<?xf32>, vector<5xf32>) {
- // CHECK: tensor.extract_slice
- // CHECK-SAME: __inplace_results_attr__ = ["true"]
- %e = tensor.extract_slice %t2[%s][%s][1] : tensor<?xf32> to tensor<?xf32>
-
- // Write to %t1 via alias. (Overwrite %t3.)
+ // Write to %t1 via %t2. (Overwrite %t3.)
// CHECK: linalg.generic
// CHECK-SAME: __inplace_results_attr__ = ["true"]
- %o2 = linalg.generic #trait outs (%e : tensor<?xf32>) {
+ %o2 = linalg.generic #trait outs (%t2 : tensor<?xf32>) {
^bb(%0: f32) :
linalg.yield %cst : f32
} -> (tensor<?xf32>)
}
// Use %t3 in some way without reading it, so that it does not get DCE'd.
- // CHECK: linalg.generic
- // CHECK-SAME: __inplace_results_attr__ = ["true"]
+ // CHECK: linalg.generic
+ // CHECK-SAME: __inplace_results_attr__ = ["true"]
%o = linalg.generic #trait outs (%t3 : tensor<?xf32>) {
^bb(%0: f32) :
linalg.yield %cst : f32