From 3f6c0fb2ff750c9246aee41eb8ad086518752edf Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Fri, 5 Nov 2021 10:42:59 +0900 Subject: [PATCH] [mlir][linalg][bufferize] Add MemCpyFn to AllocationCallbacks struct This in preparation of decoupling BufferizableOpInterface, Comprehensive Bufferize and dialects. The goal of this CL is to make `getResultBuffer` (and other `bufferize` functions) independent of `LinalgOps`. Differential Revision: https://reviews.llvm.org/D112907 --- .../Transforms/ComprehensiveBufferize.h | 20 +++++++++++++++---- .../Transforms/ComprehensiveBufferize.cpp | 17 +++++++++++----- 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h index e3b59d5daa60..94cb52b4bca5 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h @@ -172,16 +172,28 @@ Optional defaultAllocationFn(OpBuilder &b, Location loc, /// `defaultAllocationFn`. void defaultDeallocationFn(OpBuilder &b, Location loc, Value allocatedBuffer); +/// Default memory copy function that is used by the comprehensive bufferization +/// pass. Creates a `linalg.copy` op. +void defaultMemCpyFn(OpBuilder &b, Location loc, Value from, Value to); + /// Callback functions that are used by the comprehensive bufferization pass to /// allocate/deallocate memory. These default to use the /// `defaultAllocationFn`/`defaultDeallocationFn`, but can be overridden by the /// caller. The `deallocationFn` is gauranteed to recieve the `Value` returned /// by the `allocationFn`. struct AllocationCallbacks { - std::function(OpBuilder &b, Location loc, Value shapedValue)> - allocationFn = defaultAllocationFn; - std::function deallocationFn = - defaultDeallocationFn; + using AllocationFn = + std::function(OpBuilder &, Location, Value)>; + using DeallocationFn = std::function; + using MemCpyFn = std::function; + + AllocationCallbacks(AllocationFn allocFn, DeallocationFn deallocFn, + MemCpyFn copyFn) + : allocationFn(allocFn), deallocationFn(deallocFn), memCpyFn(copyFn) {} + + AllocationFn allocationFn; + DeallocationFn deallocationFn; + MemCpyFn memCpyFn; }; /// Bufferize one particular op. diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp index dad19a94080d..a6b6e0131d19 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -1274,7 +1274,7 @@ static Value getResultBuffer(OpBuilder &b, OpResult result, if (!skipCopy) { // Set insertion point now that potential alloc/dealloc are introduced. b.setInsertionPoint(op); - b.create(loc, operandBuffer, resultBuffer); + allocationFns.memCpyFn(b, loc, operandBuffer, resultBuffer); } return resultBuffer; } @@ -1669,6 +1669,11 @@ void mlir::linalg::defaultDeallocationFn(OpBuilder &b, Location loc, b.create(loc, allocatedBuffer); } +void mlir::linalg::defaultMemCpyFn(OpBuilder &b, Location loc, Value from, + Value to) { + b.create(loc, from, to); +} + LogicalResult mlir::linalg::bufferizeOp( Operation *op, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo, AllocationCallbacks allocationFns, @@ -2258,11 +2263,13 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() { // command line option. So this is set up at the start of the pass. if (useAlloca) { AllocationCallbacks allocaAllocationFns = { - allocationFnUsingAlloca, [](OpBuilder &b, Location loc, Value v) {}}; + allocationFnUsingAlloca, [](OpBuilder &b, Location loc, Value v) {}, + defaultMemCpyFn}; allocationFns = std::make_unique(std::move(allocaAllocationFns)); } else { - allocationFns = std::make_unique(); + allocationFns = std::make_unique( + defaultAllocationFn, defaultDeallocationFn, defaultMemCpyFn); } } ModuleOp moduleOp = getOperation(); @@ -3222,7 +3229,7 @@ struct ExtractSliceOpInterface if (alloc) { // Do not copy if the copied data is never read. if (isValueRead(extractSliceOp.result())) - b.create(extractSliceOp.getLoc(), subView, alloc); + allocationFn.memCpyFn(b, extractSliceOp.getLoc(), subView, alloc); subView = alloc; } @@ -3344,7 +3351,7 @@ struct InsertSliceOpInterface insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides()); // Insert new alias. aliasInfo.insertNewBufferAlias(subView, dstMemref); - b.create(insertSliceOp.getLoc(), srcMemref, subView); + allocationFn.memCpyFn(b, insertSliceOp.getLoc(), srcMemref, subView); } map(bvm, insertSliceOp.result(), dstMemref); -- 2.34.1