From e312fc49ae1ec86999676edc9c02a4ac0bc39cec Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Tue, 13 Jul 2021 10:20:10 +0000 Subject: [PATCH] [mlir][Linalg] Add layout specification support to bufferization. Previously, linalg bufferization always had to be conservative at function boundaries and assume the most dynamic strided memref layout. This revision introduces the mechanism to specify a linalg.buffer_layout function argument attribute that carries an affine map used to set a less pessimistic layout. Reviewed By: ThomasRaoux Differential Revision: https://reviews.llvm.org/D105859 --- mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td | 5 + .../Linalg/Transforms/ComprehensiveBufferize.cpp | 102 ++++++++++++++++++++- .../Linalg/comprehensive-module-bufferize.mlir | 40 ++++++++ 3 files changed, 144 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td index 49ececc..ce36323 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td @@ -48,6 +48,11 @@ def Linalg_Dialect : Dialect { constexpr const static ::llvm::StringLiteral kInplaceableAttrName = "linalg.inplaceable"; + /// Attribute name used to mark the bufferization layout for region + /// arguments during linalg comprehensive bufferization.
+ constexpr const static ::llvm::StringLiteral + kBufferLayoutAttrName = "linalg.buffer_layout"; + using RegionBuilderFunType = llvm::function_ref; RegionBuilderFunType getRegionBuilder(StringRef name) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp index be39eec..333a129 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -324,9 +324,11 @@ setInPlaceFuncArgument(BlockArgument bbArg, /// Remove the attribute that triggers inplace bufferization on a FuncOp /// argument `bbArg`. -static void removeInPlaceFuncArgument(BlockArgument bbArg) { +static void removeBufferizationFuncArguments(BlockArgument bbArg) { auto funcOp = cast(bbArg.getOwner()->getParentOp()); funcOp.removeArgAttr(bbArg.getArgNumber(), + LinalgDialect::kBufferLayoutAttrName); + funcOp.removeArgAttr(bbArg.getArgNumber(), LinalgDialect::kInplaceableAttrName); } @@ -2608,6 +2610,96 @@ static void applyEnablingTransformations(ModuleOp moduleOp) { (void)applyPatternsAndFoldGreedily(moduleOp, std::move(patterns)); } +static void +foreachCaller(const DenseMap> &callerMap, + FuncOp callee, llvm::function_ref doit) { + auto itCallers = callerMap.find(callee); + if (itCallers == callerMap.end()) + return; + for (Operation *caller : itCallers->second) + doit(caller); +} + +/// Postprocess the linalg.buffer_layout annotation across function boundaries. +/// This is a purely mechanical process that may later become part of a +/// separate pass with its own layout assignment heuristic. 
+static void layoutPostProcessing(ModuleOp moduleOp) { + SmallVector orderedFuncOps; + DenseMap> callerMap; + auto res = getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap); + assert(succeeded(res) && "unexpected getFuncOpsOrderedByCalls failure"); + + for (FuncOp funcOp : orderedFuncOps) { + DenseMap> operandsPerCaller; + foreachCaller(callerMap, funcOp, [&](Operation *caller) { + operandsPerCaller.try_emplace(caller, SmallVector()); + }); + + SmallVector argumentTypes; + // Iterate on each function argument and check if it was marked with a + // desired layout. + for (auto it : llvm::enumerate(funcOp.getType().getInputs())) { + int argNumber = it.index(); + Type inputType = it.value(); + auto memrefType = inputType.dyn_cast(); + auto layoutAttr = funcOp.getArgAttrOfType( + argNumber, LinalgDialect::kBufferLayoutAttrName); + AffineMap desiredLayoutMap = + layoutAttr ? layoutAttr.getValue() : AffineMap(); + AffineMap currentLayoutMap = + memrefType ? getStridedLinearLayoutMap(memrefType) : AffineMap(); + if (!memrefType || !layoutAttr || desiredLayoutMap == currentLayoutMap) { + argumentTypes.push_back(inputType); + foreachCaller(callerMap, funcOp, [&](Operation *caller) { + operandsPerCaller.find(caller)->getSecond().push_back( + caller->getOperand(argNumber)); + }); + continue; + } + + // Compute the buffer type with desired layout and add to input argument + // types. + MemRefType desiredMemrefType = MemRefType::get( + memrefType.getShape(), memrefType.getElementType(), desiredLayoutMap); + argumentTypes.push_back(desiredMemrefType); + + // If funcOp's body is not empty, change the bbArg type and propagate. + if (!funcOp.body().empty()) { + BlockArgument bbArg = funcOp.getArgument(argNumber); + bbArg.setType(desiredMemrefType); + OpBuilder b(bbArg.getContext()); + b.setInsertionPointToStart(bbArg.getOwner()); + // Cast back to the original memrefType and let it canonicalize.
+ Value cast = + b.create(funcOp.getLoc(), memrefType, bbArg); + bbArg.replaceAllUsesExcept(cast, cast.getDefiningOp()); + } + + // Cast to desired buffer type on all callers to `funcOp`. + // TODO: on the callee side, this may even have to trigger a copy to + // change the layout. For now let the memref::CastOp fail to verify in + // such cases. + auto castArg = [&](Operation *caller) { + OpBuilder b(caller); + Value newOperand = b.create( + funcOp.getLoc(), desiredMemrefType, caller->getOperand(argNumber)); + operandsPerCaller.find(caller)->getSecond().push_back(newOperand); + }; + foreachCaller(callerMap, funcOp, castArg); + } + + // Set operands with cast buffer on all callers to `funcOp`. + foreachCaller(callerMap, funcOp, [&](Operation *caller) { + caller->setOperands(operandsPerCaller.lookup(caller)); + }); + + // Finally set the funcOp type to update the arguments. + auto newFuncType = FunctionType::get(moduleOp.getContext(), argumentTypes, + funcOp.getType().getResults()); + funcOp.setType(newFuncType); + } +} + void LinalgComprehensiveModuleBufferize::runOnOperation() { ModuleOp moduleOp = getOperation(); applyEnablingTransformations(moduleOp); @@ -2672,12 +2764,16 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() { } } - // Post-pass cleanup of inplaceable attributes. + // Perform a post-processing pass of layout modification at function boundary + // according to the kBufferLayoutAttrName. + layoutPostProcessing(moduleOp); + + // Post-pass cleanup of inplaceable and buffer_layout attributes. 
moduleOp.walk( [&](Operation *op) { op->removeAttr(kInPlaceResultsAttrName); }); moduleOp.walk([&](FuncOp op) { for (BlockArgument bbArg : op.getArguments()) - removeInPlaceFuncArgument(bbArg); + removeBufferizationFuncArguments(bbArg); }); OpPassManager cleanupPipeline(OpPassManager("module")); diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir index b29cf6e..56278ef 100644 --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir @@ -555,3 +555,43 @@ func @tiled_dot(%A: tensor, %B: tensor, %c: tensor {linalg.in // CHECK-NOT: tensor return %1 : tensor } + +// ----- + +// CHECK: #[[$DYNAMIC:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> + +// CHECK: func private @external_func(memref) +func private @external_func(tensor) + +// CHECK: func @callee( +// CHECK-SAME: %[[A:[0-9a-zA-Z]*]]: memref +// CHECK-SAME: %[[B:[0-9a-zA-Z]*]]: memref +// CHECK-SAME: %[[C:[0-9a-zA-Z]*]]: memref +func @callee(%A : tensor {linalg.buffer_layout = affine_map<(i)[s0, s1] -> (i)>}, + %B : tensor, + %C : tensor) { +// CHECK-NEXT: %[[CASTED:.*]] = memref.cast %[[A]] : memref to memref +// CHECK-NEXT: call @external_func(%[[CASTED]]) : (memref) -> () + call @external_func(%A) : (tensor) -> () + +// CHECK-NEXT: call @external_func(%[[B]]) : (memref) -> () + call @external_func(%B) : (tensor) -> () + +// CHECK-NEXT: call @external_func(%[[C]]) : (memref) -> () + call @external_func(%C) : (tensor) -> () + + return +} + +// CHECK: func @entry( +// CHECK-SAME: %[[A:[0-9a-zA-Z]*]]: memref +// CHECK-SAME: %[[B:[0-9a-zA-Z]*]]: memref +// CHECK-SAME: %[[C:[0-9a-zA-Z]*]]: memref +func @entry(%A : tensor {linalg.buffer_layout = affine_map<(i)[s0, s1] -> (i)>}, + %B : tensor {linalg.buffer_layout = affine_map<(i)[s0, s1] -> (i)>}, + %C : tensor) { +// CHECK-NEXT: %[[CASTED_B:.*]] = memref.cast %[[B]] : memref to memref +// CHECK-NEXT: call 
@callee(%[[A]], %[[CASTED_B]], %[[C]]) + call @callee(%A, %B, %C) : (tensor, tensor, tensor) -> () + return +} -- 2.7.4