[mlir][Linalg] Add layout specification support to bufferization.
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Tue, 13 Jul 2021 10:20:10 +0000 (10:20 +0000)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Tue, 13 Jul 2021 10:22:18 +0000 (10:22 +0000)
Previously, linalg bufferization always had to be conservative at function boundaries and assume the most dynamic strided memref layout.
This revision introduce the mechanism to specify a  linalg.buffer_layout function argument attribute that carries an affine map used to set a less pessimistic layout.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D105859

mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir

index 49ececc..ce36323 100644 (file)
@@ -48,6 +48,11 @@ def Linalg_Dialect : Dialect {
     constexpr const static ::llvm::StringLiteral
       kInplaceableAttrName = "linalg.inplaceable";
 
+    /// Attribute name used to mark the bufferization layout for region
+    // arguments during linalg comprehensive bufferization.
+    constexpr const static ::llvm::StringLiteral
+      kBufferLayoutAttrName = "linalg.buffer_layout";
+
     using RegionBuilderFunType =
       llvm::function_ref<void(ImplicitLocOpBuilder &b, Block &)>;
     RegionBuilderFunType getRegionBuilder(StringRef name) {
index be39eec..333a129 100644 (file)
@@ -324,9 +324,11 @@ setInPlaceFuncArgument(BlockArgument bbArg,
 
 /// Remove the attribute that triggers inplace bufferization on a FuncOp
 /// argument `bbArg`.
-static void removeInPlaceFuncArgument(BlockArgument bbArg) {
+static void removeBufferizationFuncArguments(BlockArgument bbArg) {
   auto funcOp = cast<FuncOp>(bbArg.getOwner()->getParentOp());
   funcOp.removeArgAttr(bbArg.getArgNumber(),
+                       LinalgDialect::kBufferLayoutAttrName);
+  funcOp.removeArgAttr(bbArg.getArgNumber(),
                        LinalgDialect::kInplaceableAttrName);
 }
 
@@ -2608,6 +2610,96 @@ static void applyEnablingTransformations(ModuleOp moduleOp) {
   (void)applyPatternsAndFoldGreedily(moduleOp, std::move(patterns));
 }
 
+static void
+foreachCaller(const DenseMap<FuncOp, DenseSet<Operation *>> &callerMap,
+              FuncOp callee, llvm::function_ref<void(Operation *)> doit) {
+  auto itCallers = callerMap.find(callee);
+  if (itCallers == callerMap.end())
+    return;
+  for (Operation *caller : itCallers->second)
+    doit(caller);
+}
+
+/// Postprocess the linalg.buffer_layout annotation across function boundaries.
+/// This is a purely mechanical process that may later become part of a
+/// separate pass with its own layout assignment heuristic.
+static void layoutPostProcessing(ModuleOp moduleOp) {
+  SmallVector<FuncOp> orderedFuncOps;
+  DenseMap<FuncOp, DenseSet<Operation *>> callerMap;
+  auto res = getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap);
+  assert(succeeded(res) && "unexpected getFuncOpsOrderedByCalls failure");
+
+  for (FuncOp funcOp : orderedFuncOps) {
+    DenseMap<Operation *, SmallVector<Value>> operandsPerCaller;
+    foreachCaller(callerMap, funcOp, [&](Operation *caller) {
+      operandsPerCaller.try_emplace(caller, SmallVector<Value>());
+    });
+
+    SmallVector<Type> argumentTypes;
+    // Iterate on each function argument and check it it was marked with a
+    // desired layout.
+    for (auto it : llvm::enumerate(funcOp.getType().getInputs())) {
+      int argNumber = it.index();
+      Type inputType = it.value();
+      auto memrefType = inputType.dyn_cast<MemRefType>();
+      auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
+          argNumber, LinalgDialect::kBufferLayoutAttrName);
+      AffineMap desiredLayoutMap =
+          layoutAttr ? layoutAttr.getValue() : AffineMap();
+      AffineMap currentLayoutMap =
+          memrefType ? getStridedLinearLayoutMap(memrefType) : AffineMap();
+      if (!memrefType || !layoutAttr || desiredLayoutMap == currentLayoutMap) {
+        argumentTypes.push_back(inputType);
+        foreachCaller(callerMap, funcOp, [&](Operation *caller) {
+          operandsPerCaller.find(caller)->getSecond().push_back(
+              caller->getOperand(argNumber));
+        });
+        continue;
+      }
+
+      // Compute the buffer type with desired layout and add to input argument
+      // types.
+      MemRefType desiredMemrefType = MemRefType::get(
+          memrefType.getShape(), memrefType.getElementType(), desiredLayoutMap);
+      argumentTypes.push_back(desiredMemrefType);
+
+      // If funcOp's body is not empty, change the bbArg type and propagate.
+      if (!funcOp.body().empty()) {
+        BlockArgument bbArg = funcOp.getArgument(argNumber);
+        bbArg.setType(desiredMemrefType);
+        OpBuilder b(bbArg.getContext());
+        b.setInsertionPointToStart(bbArg.getOwner());
+        // Cast back to the original memrefType and let it canonicalize.
+        Value cast =
+            b.create<memref::CastOp>(funcOp.getLoc(), memrefType, bbArg);
+        bbArg.replaceAllUsesExcept(cast, cast.getDefiningOp());
+      }
+
+      // Cast to desired buffer type on all callers to `funcOp`.
+      // TODO: on the callee side, this may even have to trigger a copy to
+      // change the layout. For now let the memref::CastOp fail to verify in
+      // such cases.
+      auto castArg = [&](Operation *caller) {
+        OpBuilder b(caller);
+        Value newOperand = b.create<memref::CastOp>(
+            funcOp.getLoc(), desiredMemrefType, caller->getOperand(argNumber));
+        operandsPerCaller.find(caller)->getSecond().push_back(newOperand);
+      };
+      foreachCaller(callerMap, funcOp, castArg);
+    }
+
+    // Set operands with cast buffer on all callers to `funcOp`.
+    foreachCaller(callerMap, funcOp, [&](Operation *caller) {
+      caller->setOperands(operandsPerCaller.lookup(caller));
+    });
+
+    // Finally set the funcOp type to update the arguments.
+    auto newFuncType = FunctionType::get(moduleOp.getContext(), argumentTypes,
+                                         funcOp.getType().getResults());
+    funcOp.setType(newFuncType);
+  }
+}
+
 void LinalgComprehensiveModuleBufferize::runOnOperation() {
   ModuleOp moduleOp = getOperation();
   applyEnablingTransformations(moduleOp);
@@ -2672,12 +2764,16 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
     }
   }
 
-  // Post-pass cleanup of inplaceable attributes.
+  // Perform a post-processing pass of layout modification at function boundary
+  // according to the kBufferLayoutAttrName.
+  layoutPostProcessing(moduleOp);
+
+  // Post-pass cleanup of inplaceable and buffer_layout attributes.
   moduleOp.walk(
       [&](Operation *op) { op->removeAttr(kInPlaceResultsAttrName); });
   moduleOp.walk([&](FuncOp op) {
     for (BlockArgument bbArg : op.getArguments())
-      removeInPlaceFuncArgument(bbArg);
+      removeBufferizationFuncArguments(bbArg);
   });
 
   OpPassManager cleanupPipeline(OpPassManager("module"));
index b29cf6e..56278ef 100644 (file)
@@ -555,3 +555,43 @@ func @tiled_dot(%A: tensor<?xf32>, %B: tensor<?xf32>, %c: tensor<f32> {linalg.in
   // CHECK-NOT: tensor
   return %1 : tensor<f32>
 }
+
+// -----
+
+// CHECK: #[[$DYNAMIC:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+
+// CHECK: func private @external_func(memref<?xf32, #[[$DYNAMIC]]>)
+func private @external_func(tensor<?xf32>)
+
+//      CHECK: func @callee(
+// CHECK-SAME:   %[[A:[0-9a-zA-Z]*]]: memref<?xf32>
+// CHECK-SAME:   %[[B:[0-9a-zA-Z]*]]: memref<?xf32, #[[$DYNAMIC]]>
+// CHECK-SAME:   %[[C:[0-9a-zA-Z]*]]: memref<?xf32, #[[$DYNAMIC]]>
+func @callee(%A : tensor<?xf32> {linalg.buffer_layout = affine_map<(i)[s0, s1] -> (i)>},
+             %B : tensor<?xf32>,
+             %C : tensor<?xf32>) {
+// CHECK-NEXT: %[[CASTED:.*]] = memref.cast %[[A]] : memref<?xf32> to memref<?xf32, #[[$DYNAMIC]]>
+// CHECK-NEXT: call @external_func(%[[CASTED]]) : (memref<?xf32, #[[$DYNAMIC]]>) -> ()
+  call @external_func(%A) : (tensor<?xf32>) -> ()
+
+// CHECK-NEXT: call @external_func(%[[B]]) : (memref<?xf32, #[[$DYNAMIC]]>) -> ()
+  call @external_func(%B) : (tensor<?xf32>) -> ()
+
+// CHECK-NEXT: call @external_func(%[[C]]) : (memref<?xf32, #[[$DYNAMIC]]>) -> ()
+  call @external_func(%C) : (tensor<?xf32>) -> ()
+
+  return
+}
+
+//      CHECK: func @entry(
+// CHECK-SAME:   %[[A:[0-9a-zA-Z]*]]: memref<?xf32>
+// CHECK-SAME:   %[[B:[0-9a-zA-Z]*]]: memref<?xf32>
+// CHECK-SAME:   %[[C:[0-9a-zA-Z]*]]: memref<?xf32, #[[$DYNAMIC]]>
+func @entry(%A : tensor<?xf32> {linalg.buffer_layout = affine_map<(i)[s0, s1] -> (i)>},
+            %B : tensor<?xf32> {linalg.buffer_layout = affine_map<(i)[s0, s1] -> (i)>},
+            %C : tensor<?xf32>) {
+// CHECK-NEXT: %[[CASTED_B:.*]] = memref.cast %[[B]] : memref<?xf32> to memref<?xf32, #[[$DYNAMIC]]>
+// CHECK-NEXT: call @callee(%[[A]], %[[CASTED_B]], %[[C]])
+  call @callee(%A, %B, %C) : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> ()
+  return
+}