[mlir][linalg] Add output tensor args folding for linalg.tiled_loop.
authorAlexander Belyaev <pifon@google.com>
Thu, 25 Mar 2021 17:08:30 +0000 (18:08 +0100)
committerAlexander Belyaev <pifon@google.com>
Thu, 25 Mar 2021 17:11:05 +0000 (18:11 +0100)
Folds away TiledLoopOp output tensors when the following conditions are met:
* result of `linalg.tiled_loop` has no uses
* output tensor is the argument of `linalg.yield`

Example:

```
%0 = linalg.tiled_loop ...  outs (%out, %out_buf:tensor<...>, memref<...>) {
  ...
  linalg.yield %out : tensor ...
}
```

Becomes

```
linalg.tiled_loop ...  outs (%out_buf:memref<...>) {
  ...
  linalg.yield
}
```

Differential Revision: https://reviews.llvm.org/D99333

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir

index d54efbe..fe67207 100644 (file)
@@ -584,6 +584,9 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
     }
     unsigned getNumLoops() { return step().size(); }
   }];
+
+  let hasCanonicalizer = 1;
+  let hasFolder = 1;
 }
 
 
index fdb2e4f..744f027 100644 (file)
@@ -1943,6 +1943,87 @@ bool TiledLoopOp::isDefinedOutsideOfLoop(Value value) {
 
 static LogicalResult verify(TiledLoopOp op) { return success(); }
 
+namespace {
+
+// Folds away TiledLoopOp output tensors when the following conditions are met:
+// * result of `linalg.tiled_loop` has no uses
+// * output tensor is the argument of `linalg.yield`
+//
+// Example:
+//
+// %0 = linalg.tiled_loop ...  outs (%out, %out_buf:tensor<...>, memref<...>) {
+//   ...
+//   linalg.yield %out : tensor ...
+// }
+//
+// Becomes
+//
+// linalg.tiled_loop ...  outs (%out_buf:memref<...>) {
+//   ...
+//   linalg.yield
+// }
+struct TiledLoopResultsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
+  using OpRewritePattern<linalg::TiledLoopOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::TiledLoopOp tiledLoop,
+                                PatternRewriter &rewriter) const final {
+    if (tiledLoop.getNumResults() == 0)
+      return failure();
+
+    Block *block = tiledLoop.getBody();
+    auto yieldOp = cast<linalg::YieldOp>(block->getTerminator());
+
+    // Match the pattern and collect output buffers that will replace the output
+    // tensors and also the ops that will be ignored when cloning the body.
+    SmallVector<Value, 2> newOutputOperands, newYieldArgs;
+    int resultId = 0;
+    for (Value out : tiledLoop.outputs()) {
+      if (!out.getType().isa<RankedTensorType>()) {
+        newOutputOperands.push_back(out);
+        continue;
+      }
+      Value result = tiledLoop.getResult(resultId);
+      Value yieldArg = yieldOp.getOperand(resultId);
+      if (yieldArg != out || !result.use_empty()) {
+        newOutputOperands.push_back(out);
+        newYieldArgs.push_back(yieldArg);
+      }
+      ++resultId;
+    }
+    if (newOutputOperands.size() == tiledLoop.outputs().size())
+      return failure();
+
+    Location loc = tiledLoop.getLoc();
+    auto newTiledLoop = rewriter.create<TiledLoopOp>(
+        loc, tiledLoop.lowerBound(), tiledLoop.upperBound(), tiledLoop.step(),
+        tiledLoop.inputs(), newOutputOperands, tiledLoop.iterator_types());
+
+    // Clone the region ignoring the def-chain for linalg.yield args:
+    // unnecessary `subtensor_insert`, `tensor_load` and `cast` ops.
+    BlockAndValueMapping bvm;
+    bvm.map(tiledLoop.getInductionVars(), newTiledLoop.getInductionVars());
+    OpBuilder innerBuilder =
+        OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener());
+    for (auto &op : tiledLoop.getBody()->without_terminator())
+      innerBuilder.clone(op, bvm);
+    innerBuilder.create<linalg::YieldOp>(loc, newYieldArgs);
+    rewriter.eraseOp(tiledLoop);
+
+    return success();
+  }
+};
+} // namespace
+
+void TiledLoopOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                              MLIRContext *context) {
+  results.insert<TiledLoopResultsFolder>(context);
+}
+
+LogicalResult TiledLoopOp::fold(ArrayRef<Attribute>,
+                                SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+
 /////// Operations corresponding to library calls defined with Tablegen ////////
 
 template <typename LinalgPoolingOp>
index 5ec93dd..44f9dbd 100644 (file)
@@ -818,3 +818,46 @@ func @fold_fill_reshape() -> tensor<6x4xf32> {
   // CHECK: return %[[FILL]] : tensor<6x4xf32>
   return %reshape : tensor<6x4xf32>
 }
+
+// -----
+
+#map0 = affine_map<(d0) -> (24, -d0 + 192)>
+#map1 = affine_map<(d0, d1)[s0] -> (d0 * 192 + s0 + d1)>
+#map2 = affine_map<(d0) -> (16, -d0 + 192)>
+
+func private @foo(%A: memref<192x192xf32>, %B: memref<192x192xf32>,
+                  %C: memref<192x192xf32>) -> ()
+
+func @fold_tiled_loop_results(%A: memref<192x192xf32>, %B: memref<192x192xf32>,
+                              %C: memref<192x192xf32>,
+                              %C_tensor: tensor<192x192xf32>) {
+  %cst = constant 0.000000e+00 : f32
+  %c24 = constant 24 : index
+  %c16 = constant 16 : index
+  %c0 = constant 0 : index
+  %c192 = constant 192 : index
+  %useless = linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192)
+      step (%c24, %c16)
+      ins (%A, %B: memref<192x192xf32>, memref<192x192xf32>)
+      outs (%C_tensor, %C :tensor<192x192xf32>, memref<192x192xf32>) {
+        call @foo(%A, %B, %C) : (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>)-> ()
+    linalg.yield %C_tensor : tensor<192x192xf32>
+  }
+  return
+}
+
+// CHECK-LABEL: func @fold_tiled_loop_results(
+// CHECK-SAME:    %[[A:.*]]: [[TY:.*]], %[[B:.*]]: [[TY]], %[[C:.*]]: [[TY]],
+// CHECK-SAME:    %[[C_TENSOR:.*]]: tensor<{{.*}}>) {
+// CHECK:  %[[C24:.*]] = constant 24 : index
+// CHECK:  %[[C16:.*]] = constant 16 : index
+// CHECK:  %[[C0:.*]] = constant 0 : index
+// CHECK:  %[[C192:.*]] = constant 192 : index
+
+// CHECK-NOT: %{{.*}} = linalg.tiled_loop
+// CHECK:  linalg.tiled_loop (%{{.*}}, %{{.*}}) = (%[[C0]], %[[C0]])
+// CHECK-SAME: to (%[[C192]], %[[C192]]) step (%[[C24]], %[[C16]])
+// CHECK-SAME: ins (%[[A]], %[[B]]: memref<192x192xf32>, memref<192x192xf32>)
+// CHECK-SAME: outs (%[[C]]:memref<192x192xf32>) {
+// CHECK-NEXT:   call @foo(%[[A]], %[[B]], %[[C]])
+// CHECK-NEXT:   linalg.yield