-/// Remove the attribute that triggers inplace bufferization on a FuncOp
-/// argument `bbArg`.
-static void removeInPlaceFuncArgument(BlockArgument bbArg) {
+/// Remove the attributes that trigger inplace bufferization and buffer layout
+/// assignment on a FuncOp argument `bbArg`.
+static void removeBufferizationFuncArguments(BlockArgument bbArg) {
auto funcOp = cast<FuncOp>(bbArg.getOwner()->getParentOp());
funcOp.removeArgAttr(bbArg.getArgNumber(),
+ LinalgDialect::kBufferLayoutAttrName);
+ funcOp.removeArgAttr(bbArg.getArgNumber(),
LinalgDialect::kInplaceableAttrName);
}
(void)applyPatternsAndFoldGreedily(moduleOp, std::move(patterns));
}
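+/// Apply `doit` to every caller of `callee` recorded in `callerMap`. This is
+/// a no-op when `callee` has no recorded callers.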
+static void
+foreachCaller(const DenseMap<FuncOp, DenseSet<Operation *>> &callerMap,
+ FuncOp callee, llvm::function_ref<void(Operation *)> doit) {
+ auto itCallers = callerMap.find(callee);
+ if (itCallers == callerMap.end())
+ return;
+ for (Operation *caller : itCallers->second)
+ doit(caller);
+}
+
+/// Postprocess the linalg.buffer_layout annotation across function boundaries.
+/// This is a purely mechanical process that may later become part of a
+/// separate pass with its own layout assignment heuristic.
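+///
+/// For example, an argument annotated with an identity layout, such as
+///   %A : tensor<?xf32> {linalg.buffer_layout = affine_map<(i)[s0, s1] -> (i)>}
+/// bufferizes to `memref<?xf32>` instead of the default fully dynamic strided
+/// layout, with `memref.cast` ops inserted to reconcile the two types.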
+static void layoutPostProcessing(ModuleOp moduleOp) {
+ SmallVector<FuncOp> orderedFuncOps;
+ DenseMap<FuncOp, DenseSet<Operation *>> callerMap;
+  auto res = getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap);
+  (void)res;
+  assert(succeeded(res) && "unexpected getFuncOpsOrderedByCalls failure");
+
+ for (FuncOp funcOp : orderedFuncOps) {
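+    // New operands for each caller of `funcOp`, filled in below as each
+    // argument is processed.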
+ DenseMap<Operation *, SmallVector<Value>> operandsPerCaller;
+ foreachCaller(callerMap, funcOp, [&](Operation *caller) {
+ operandsPerCaller.try_emplace(caller, SmallVector<Value>());
+ });
+
+ SmallVector<Type> argumentTypes;
+    // Iterate over each function argument and check whether it was marked
+    // with a desired layout.
+ for (auto it : llvm::enumerate(funcOp.getType().getInputs())) {
+ int argNumber = it.index();
+ Type inputType = it.value();
+ auto memrefType = inputType.dyn_cast<MemRefType>();
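+      // Fetch the layout requested via the `linalg.buffer_layout` argument
+      // attribute, if any, as well as the argument's current strided layout.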
+ auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
+ argNumber, LinalgDialect::kBufferLayoutAttrName);
+ AffineMap desiredLayoutMap =
+ layoutAttr ? layoutAttr.getValue() : AffineMap();
+ AffineMap currentLayoutMap =
+ memrefType ? getStridedLinearLayoutMap(memrefType) : AffineMap();
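+      // Nothing to do if the argument is not a memref, carries no layout
+      // attribute, or already has the desired layout: keep the type and
+      // forward the original operand on all callers.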
+ if (!memrefType || !layoutAttr || desiredLayoutMap == currentLayoutMap) {
+ argumentTypes.push_back(inputType);
+ foreachCaller(callerMap, funcOp, [&](Operation *caller) {
+ operandsPerCaller.find(caller)->getSecond().push_back(
+ caller->getOperand(argNumber));
+ });
+ continue;
+ }
+
+ // Compute the buffer type with desired layout and add to input argument
+ // types.
+ MemRefType desiredMemrefType = MemRefType::get(
+ memrefType.getShape(), memrefType.getElementType(), desiredLayoutMap);
+ argumentTypes.push_back(desiredMemrefType);
+
+ // If funcOp's body is not empty, change the bbArg type and propagate.
+ if (!funcOp.body().empty()) {
+ BlockArgument bbArg = funcOp.getArgument(argNumber);
+ bbArg.setType(desiredMemrefType);
+ OpBuilder b(bbArg.getContext());
+ b.setInsertionPointToStart(bbArg.getOwner());
+ // Cast back to the original memrefType and let it canonicalize.
+ Value cast =
+ b.create<memref::CastOp>(funcOp.getLoc(), memrefType, bbArg);
+ bbArg.replaceAllUsesExcept(cast, cast.getDefiningOp());
+ }
+
+ // Cast to desired buffer type on all callers to `funcOp`.
+ // TODO: on the callee side, this may even have to trigger a copy to
+ // change the layout. For now let the memref::CastOp fail to verify in
+ // such cases.
+ auto castArg = [&](Operation *caller) {
+ OpBuilder b(caller);
+ Value newOperand = b.create<memref::CastOp>(
+ funcOp.getLoc(), desiredMemrefType, caller->getOperand(argNumber));
+ operandsPerCaller.find(caller)->getSecond().push_back(newOperand);
+ };
+ foreachCaller(callerMap, funcOp, castArg);
+ }
+
+ // Set operands with cast buffer on all callers to `funcOp`.
+ foreachCaller(callerMap, funcOp, [&](Operation *caller) {
+ caller->setOperands(operandsPerCaller.lookup(caller));
+ });
+
+ // Finally set the funcOp type to update the arguments.
+ auto newFuncType = FunctionType::get(moduleOp.getContext(), argumentTypes,
+ funcOp.getType().getResults());
+ funcOp.setType(newFuncType);
+ }
+}
+
void LinalgComprehensiveModuleBufferize::runOnOperation() {
ModuleOp moduleOp = getOperation();
applyEnablingTransformations(moduleOp);
}
}
- // Post-pass cleanup of inplaceable attributes.
+  // Perform a layout post-processing pass that modifies memref layouts at
+  // function boundaries according to the kBufferLayoutAttrName attribute.
+ layoutPostProcessing(moduleOp);
+
+ // Post-pass cleanup of inplaceable and buffer_layout attributes.
moduleOp.walk(
[&](Operation *op) { op->removeAttr(kInPlaceResultsAttrName); });
moduleOp.walk([&](FuncOp op) {
for (BlockArgument bbArg : op.getArguments())
- removeInPlaceFuncArgument(bbArg);
+ removeBufferizationFuncArguments(bbArg);
});
OpPassManager cleanupPipeline(OpPassManager("module"));
// CHECK-NOT: tensor
return %1 : tensor<f32>
}
+
+// -----
+
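+// The linalg.buffer_layout annotation drives the memref layout in the
+// bufferized function signature; memref.cast ops reconcile mismatches at
+// call sites.
+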
+// CHECK: #[[$DYNAMIC:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+
+// CHECK: func private @external_func(memref<?xf32, #[[$DYNAMIC]]>)
+func private @external_func(tensor<?xf32>)
+
+// CHECK: func @callee(
+// CHECK-SAME: %[[A:[0-9a-zA-Z]*]]: memref<?xf32>
+// CHECK-SAME: %[[B:[0-9a-zA-Z]*]]: memref<?xf32, #[[$DYNAMIC]]>
+// CHECK-SAME: %[[C:[0-9a-zA-Z]*]]: memref<?xf32, #[[$DYNAMIC]]>
+func @callee(%A : tensor<?xf32> {linalg.buffer_layout = affine_map<(i)[s0, s1] -> (i)>},
+ %B : tensor<?xf32>,
+ %C : tensor<?xf32>) {
+// CHECK-NEXT: %[[CASTED:.*]] = memref.cast %[[A]] : memref<?xf32> to memref<?xf32, #[[$DYNAMIC]]>
+// CHECK-NEXT: call @external_func(%[[CASTED]]) : (memref<?xf32, #[[$DYNAMIC]]>) -> ()
+ call @external_func(%A) : (tensor<?xf32>) -> ()
+
+// CHECK-NEXT: call @external_func(%[[B]]) : (memref<?xf32, #[[$DYNAMIC]]>) -> ()
+ call @external_func(%B) : (tensor<?xf32>) -> ()
+
+// CHECK-NEXT: call @external_func(%[[C]]) : (memref<?xf32, #[[$DYNAMIC]]>) -> ()
+ call @external_func(%C) : (tensor<?xf32>) -> ()
+
+ return
+}
+
+// CHECK: func @entry(
+// CHECK-SAME: %[[A:[0-9a-zA-Z]*]]: memref<?xf32>
+// CHECK-SAME: %[[B:[0-9a-zA-Z]*]]: memref<?xf32>
+// CHECK-SAME: %[[C:[0-9a-zA-Z]*]]: memref<?xf32, #[[$DYNAMIC]]>
+func @entry(%A : tensor<?xf32> {linalg.buffer_layout = affine_map<(i)[s0, s1] -> (i)>},
+ %B : tensor<?xf32> {linalg.buffer_layout = affine_map<(i)[s0, s1] -> (i)>},
+ %C : tensor<?xf32>) {
+// CHECK-NEXT: %[[CASTED_B:.*]] = memref.cast %[[B]] : memref<?xf32> to memref<?xf32, #[[$DYNAMIC]]>
+// CHECK-NEXT: call @callee(%[[A]], %[[CASTED_B]], %[[C]])
+ call @callee(%A, %B, %C) : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> ()
+ return
+}