// Verifier hook for linalg.tiled_loop: performs no structural checks yet and
// always succeeds.
static LogicalResult verify(TiledLoopOp op) { return success(); }
+namespace {
+
+// Folds away TiledLoopOp output tensors when the following conditions are met:
+// * result of `linalg.tiled_loop` has no uses
+// * output tensor is the argument of `linalg.yield`
+//
+// Example:
+//
+// %0 = linalg.tiled_loop ... outs (%out, %out_buf:tensor<...>, memref<...>) {
+// ...
+// linalg.yield %out : tensor ...
+// }
+//
+// Becomes
+//
+// linalg.tiled_loop ... outs (%out_buf:memref<...>) {
+// ...
+// linalg.yield
+// }
+struct TiledLoopResultsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
+ using OpRewritePattern<linalg::TiledLoopOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::TiledLoopOp tiledLoop,
+ PatternRewriter &rewriter) const final {
+ // Nothing to fold when the loop produces no tensor results at all.
+ if (tiledLoop.getNumResults() == 0)
+ return failure();
+
+ Block *block = tiledLoop.getBody();
+ auto yieldOp = cast<linalg::YieldOp>(block->getTerminator());
+
+ // Match the pattern and collect output buffers that will replace the output
+ // tensors and also the ops that will be ignored when cloning the body.
+ SmallVector<Value, 2> newOutputOperands, newYieldArgs;
+ // Index of the loop result / yield operand corresponding to the current
+ // tensor output. It advances only for tensor-typed outputs (memref outputs
+ // produce no results), so yield operands are assumed to be in the same
+ // order as the tensor outputs.
+ int resultId = 0;
+ for (Value out : tiledLoop.outputs()) {
+ // Memref outputs are always kept as-is; they have no associated result.
+ if (!out.getType().isa<RankedTensorType>()) {
+ newOutputOperands.push_back(out);
+ continue;
+ }
+ Value result = tiledLoop.getResult(resultId);
+ Value yieldArg = yieldOp.getOperand(resultId);
+ // Keep the tensor output unless it is trivially forwarded by the yield
+ // (yieldArg == out) AND its result is dead.
+ if (yieldArg != out || !result.use_empty()) {
+ newOutputOperands.push_back(out);
+ newYieldArgs.push_back(yieldArg);
+ }
+ ++resultId;
+ }
+ // No output was removed: report match failure so the pattern driver does
+ // not re-apply this pattern indefinitely.
+ if (newOutputOperands.size() == tiledLoop.outputs().size())
+ return failure();
+
+ Location loc = tiledLoop.getLoc();
+ auto newTiledLoop = rewriter.create<TiledLoopOp>(
+ loc, tiledLoop.lowerBound(), tiledLoop.upperBound(), tiledLoop.step(),
+ tiledLoop.inputs(), newOutputOperands, tiledLoop.iterator_types());
+
+ // Clone the region ignoring the def-chain for linalg.yield args:
+ // unnecessary `subtensor_insert`, `tensor_load` and `cast` ops.
+ BlockAndValueMapping bvm;
+ bvm.map(tiledLoop.getInductionVars(), newTiledLoop.getInductionVars());
+ OpBuilder innerBuilder =
+ OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener());
+ for (auto &op : tiledLoop.getBody()->without_terminator())
+ innerBuilder.clone(op, bvm);
+ innerBuilder.create<linalg::YieldOp>(loc, newYieldArgs);
+ // NOTE(review): the old loop is erased without mapping its *kept* results
+ // (those with !result.use_empty()) to the corresponding results of
+ // `newTiledLoop`. If any kept tensor result still has uses, eraseOp would
+ // leave dangling uses -- confirm whether rewriter.replaceOp with the new
+ // loop's results is required here instead.
+ rewriter.eraseOp(tiledLoop);
+
+ return success();
+ }
+};
+} // namespace
+
+// Registers TiledLoopOp canonicalization patterns; currently only the
+// dead-output-tensor folder (TiledLoopResultsFolder).
+void TiledLoopOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<TiledLoopResultsFolder>(context);
+}
+
+// Op-level fold hook: delegates to foldMemRefCast (defined elsewhere in this
+// file), which presumably absorbs memref.cast producers into the loop's
+// operands -- its body is not visible in this chunk, so confirm there.
+// Operand attributes and the result vector are unused by the delegate.
+LogicalResult TiledLoopOp::fold(ArrayRef<Attribute>,
+ SmallVectorImpl<OpFoldResult> &) {
+ return foldMemRefCast(*this);
+}
+
/////// Operations corresponding to library calls defined with Tablegen ////////
template <typename LinalgPoolingOp>
// CHECK: return %[[FILL]] : tensor<6x4xf32>
return %reshape : tensor<6x4xf32>
}
+
+// -----
+
+// Affine maps for this test file; none of them is referenced by the visible
+// function below.
+#map0 = affine_map<(d0) -> (24, -d0 + 192)>
+#map1 = affine_map<(d0, d1)[s0] -> (d0 * 192 + s0 + d1)>
+#map2 = affine_map<(d0) -> (16, -d0 + 192)>
+
+func private @foo(%A: memref<192x192xf32>, %B: memref<192x192xf32>,
+ %C: memref<192x192xf32>) -> ()
+
+// Regression test for the canonicalization that drops dead output tensors
+// from linalg.tiled_loop: %useless has no uses and the yield returns the
+// output tensor argument %C_tensor unchanged, so the tensor output should be
+// folded away, leaving only the memref output %C.
+func @fold_tiled_loop_results(%A: memref<192x192xf32>, %B: memref<192x192xf32>,
+ %C: memref<192x192xf32>,
+ %C_tensor: tensor<192x192xf32>) {
+ %cst = constant 0.000000e+00 : f32
+ %c24 = constant 24 : index
+ %c16 = constant 16 : index
+ %c0 = constant 0 : index
+ %c192 = constant 192 : index
+ %useless = linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192)
+ step (%c24, %c16)
+ ins (%A, %B: memref<192x192xf32>, memref<192x192xf32>)
+ outs (%C_tensor, %C :tensor<192x192xf32>, memref<192x192xf32>) {
+ call @foo(%A, %B, %C) : (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>)-> ()
+ linalg.yield %C_tensor : tensor<192x192xf32>
+ }
+ return
+}
+
+// After canonicalization: no loop result remains (CHECK-NOT on the `%... =`
+// form), the loop carries only the memref output, and the yield is empty.
+// CHECK-LABEL: func @fold_tiled_loop_results(
+// CHECK-SAME: %[[A:.*]]: [[TY:.*]], %[[B:.*]]: [[TY]], %[[C:.*]]: [[TY]],
+// CHECK-SAME: %[[C_TENSOR:.*]]: tensor<{{.*}}>) {
+// CHECK: %[[C24:.*]] = constant 24 : index
+// CHECK: %[[C16:.*]] = constant 16 : index
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[C192:.*]] = constant 192 : index
+
+// CHECK-NOT: %{{.*}} = linalg.tiled_loop
+// CHECK: linalg.tiled_loop (%{{.*}}, %{{.*}}) = (%[[C0]], %[[C0]])
+// CHECK-SAME: to (%[[C192]], %[[C192]]) step (%[[C24]], %[[C16]])
+// CHECK-SAME: ins (%[[A]], %[[B]]: memref<192x192xf32>, memref<192x192xf32>)
+// CHECK-SAME: outs (%[[C]]:memref<192x192xf32>) {
+// CHECK-NEXT: call @foo(%[[A]], %[[B]], %[[C]])
+// CHECK-NEXT: linalg.yield