--- /dev/null
+//===- ParallelLoopTiling.cpp - Tiles loop.parallel ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements loop tiling on parallel loops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/Dialect/LoopOps/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/Support/CommandLine.h"
+
+using namespace mlir;
+using loop::ParallelOp;
+
+/// Tile a parallel loop of the form
+///   loop.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
+///                                             step (%arg4, %arg5)
+///
+/// into
+///   loop.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
+///                                             step (%arg4*tileSize[0],
+///                                                   %arg5*tileSize[1])
+///     loop.parallel (%j0, %j1) = (0, 0) to (min(tileSize[0], %arg2-%i0)
+///                                           min(tileSize[1], %arg3-%i1))
+///                                        step (%arg4, %arg5)
+/// where the min is taken against the *outer* induction variables %i0/%i1, so
+/// partial tiles at the boundary are clamped. The old loop is replaced with
+/// the new one.
+static void tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes) {
+  OpBuilder b(op);
+  auto zero = b.create<ConstantIndexOp>(op.getLoc(), 0);
+
+  // Materialize one tile-size constant per loop dimension. Dimensions beyond
+  // the user-provided list default to a tile size of 1 (effectively untiled).
+  SmallVector<Value, 2> tileSizeConstants;
+  size_t numDims = op.upperBound().size();
+  tileSizeConstants.reserve(numDims);
+  for (size_t dim = 0; dim != numDims; ++dim) {
+    int64_t size = dim < tileSizes.size() ? tileSizes[dim] : 1;
+    tileSizeConstants.push_back(b.create<ConstantIndexOp>(op.getLoc(), size));
+  }
+
+  // Create the outer loop: original bounds, but strides of step * tileSize.
+  SmallVector<Value, 2> newSteps;
+  newSteps.reserve(op.step().size());
+  for (auto step : llvm::zip(op.step(), tileSizeConstants)) {
+    newSteps.push_back(
+        b.create<MulIOp>(op.getLoc(), std::get<0>(step), std::get<1>(step)));
+  }
+  auto outerLoop = b.create<ParallelOp>(op.getLoc(), op.lowerBound(),
+                                        op.upperBound(), newSteps);
+  b.setInsertionPointToStart(outerLoop.getBody());
+
+  // Compute min(size, dim - offset) to avoid out-of-bounds accesses.
+  // FIXME: Instead of using min, we want to replicate the tail. This would give
+  // the inner loop constant bounds for easy vectorization.
+  auto minMap = AffineMap::get(
+      /*dimCount=*/3, /*symbolCount=*/0,
+      {getAffineDimExpr(/*position=*/0, b.getContext()),
+       getAffineDimExpr(/*position=*/1, b.getContext()) -
+           getAffineDimExpr(/*position=*/2, b.getContext())});
+
+  // Create the inner loop: zero lower bounds, clamped tile-size upper bounds,
+  // and the original steps.
+  SmallVector<Value, 2> newBounds;
+  newBounds.reserve(numDims);
+  for (auto bounds : llvm::zip(tileSizeConstants, outerLoop.upperBound(),
+                               outerLoop.getInductionVars())) {
+    newBounds.push_back(b.create<AffineMinOp>(
+        op.getLoc(), b.getIndexType(), minMap,
+        ValueRange{std::get<0>(bounds), std::get<1>(bounds),
+                   std::get<2>(bounds)}));
+  }
+  auto innerLoop = b.create<ParallelOp>(
+      op.getLoc(), SmallVector<Value, 2>(newBounds.size(), zero), newBounds,
+      op.step());
+
+  // Steal the body of the old parallel loop and erase it.
+  innerLoop.region().takeBody(op.region());
+  op.erase();
+}
+
+/// Collect the innermost parallel loops under `block` into `loops`. Assumes
+/// that ParallelOps are only directly nested. Returns true if `block`
+/// contains at least one ParallelOp.
+static bool getInnermostNestedLoops(Block *block,
+                                    SmallVectorImpl<ParallelOp> &loops) {
+  bool foundLoop = false;
+  for (ParallelOp candidate : block->getOps<ParallelOp>()) {
+    foundLoop = true;
+    // Only record loops whose body contains no further parallel loops.
+    bool containsNested = getInnermostNestedLoops(candidate.getBody(), loops);
+    if (!containsNested)
+      loops.push_back(candidate);
+  }
+  return foundLoop;
+}
+
+namespace {
+/// Function pass that tiles the innermost `loop.parallel` ops with the
+/// statically configured tile sizes (missing sizes default to 1 inside
+/// tileParallelLoop).
+struct ParallelLoopTiling : public FunctionPass<ParallelLoopTiling> {
+  ParallelLoopTiling() = default;
+  ParallelLoopTiling(const ParallelLoopTiling &) {} // tileSize is non-copyable.
+  explicit ParallelLoopTiling(ArrayRef<int64_t> tileSizes) {
+    this->tileSizes = tileSizes;
+  }
+
+  void runOnFunction() override {
+    // Gather the innermost loops first: tiling erases and recreates ops, so
+    // we must not transform while still walking the IR.
+    SmallVector<ParallelOp, 2> innermostLoops;
+    for (Block &block : getFunction())
+      getInnermostNestedLoops(&block, innermostLoops);
+    for (ParallelOp loop : innermostLoops)
+      tileParallelLoop(loop, tileSizes);
+  }
+
+  ListOption<int64_t> tileSizes{
+      *this, "parallel-loop-tile-sizes",
+      llvm::cl::desc("factors to tile parallel loops by"), llvm::cl::ZeroOrMore,
+      llvm::cl::MiscFlags::CommaSeparated};
+};
+} // namespace
+
+/// Creates a pass that tiles the innermost parallel loops with the given
+/// tile sizes; dimensions without a provided size are tiled by 1.
+std::unique_ptr<Pass>
+mlir::createParallelLoopTilingPass(ArrayRef<int64_t> tileSizes) {
+  return std::make_unique<ParallelLoopTiling>(tileSizes);
+}
+
+/// Static registration of the pass with the global pass registry under the
+/// "parallel-loop-tiling" name.
+static PassRegistration<ParallelLoopTiling> pass("parallel-loop-tiling",
+                                                 "Tile parallel loops.");
--- /dev/null
+// RUN: mlir-opt %s -pass-pipeline='func(parallel-loop-tiling{parallel-loop-tile-sizes=1,4})' -split-input-file | FileCheck %s --dump-input-on-failure
+
+// Elementwise addition over a 2-d parallel loop with fully dynamic bounds and
+// steps; the RUN line tiles it with sizes [1, 4].
+func @parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index,
+                    %arg3 : index, %arg4 : index, %arg5 : index,
+                    %A: memref<?x?xf32>, %B: memref<?x?xf32>,
+                    %C: memref<?x?xf32>, %result: memref<?x?xf32>) {
+  loop.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
+    %B_elem = load %B[%i0, %i1] : memref<?x?xf32>
+    %C_elem = load %C[%i0, %i1] : memref<?x?xf32>
+    %sum_elem = addf %B_elem, %C_elem : f32
+    store %sum_elem, %result[%i0, %i1] : memref<?x?xf32>
+  }
+  return
+}
+
+// CHECK: #map0 = affine_map<(d0, d1, d2) -> (d0, d1 - d2)>
+// CHECK-LABEL: func @parallel_loop(
+// CHECK-SAME: [[VAL_0:%.*]]: index, [[VAL_1:%.*]]: index, [[VAL_2:%.*]]: index, [[VAL_3:%.*]]: index, [[VAL_4:%.*]]: index, [[VAL_5:%.*]]: index, [[VAL_6:%.*]]: memref<?x?xf32>, [[VAL_7:%.*]]: memref<?x?xf32>, [[VAL_8:%.*]]: memref<?x?xf32>, [[VAL_9:%.*]]: memref<?x?xf32>) {
+// CHECK: [[VAL_10:%.*]] = constant 0 : index
+// CHECK: [[VAL_11:%.*]] = constant 1 : index
+// CHECK: [[VAL_12:%.*]] = constant 4 : index
+// CHECK: [[VAL_13:%.*]] = muli [[VAL_4]], [[VAL_11]] : index
+// CHECK: [[VAL_14:%.*]] = muli [[VAL_5]], [[VAL_12]] : index
+// CHECK: loop.parallel ([[VAL_15:%.*]], [[VAL_16:%.*]]) = ([[VAL_0]], [[VAL_1]]) to ([[VAL_2]], [[VAL_3]]) step ([[VAL_13]], [[VAL_14]]) {
+// CHECK: [[VAL_17:%.*]] = affine.min #map0([[VAL_11]], [[VAL_2]], [[VAL_15]])
+// CHECK: [[VAL_18:%.*]] = affine.min #map0([[VAL_12]], [[VAL_3]], [[VAL_16]])
+// CHECK: loop.parallel ([[VAL_19:%.*]], [[VAL_20:%.*]]) = ([[VAL_10]], [[VAL_10]]) to ([[VAL_17]], [[VAL_18]]) step ([[VAL_4]], [[VAL_5]]) {
+// CHECK: [[VAL_21:%.*]] = load [[VAL_7]]{{\[}}[[VAL_19]], [[VAL_20]]] : memref<?x?xf32>
+// CHECK: [[VAL_22:%.*]] = load [[VAL_8]]{{\[}}[[VAL_19]], [[VAL_20]]] : memref<?x?xf32>
+// CHECK: [[VAL_23:%.*]] = addf [[VAL_21]], [[VAL_22]] : f32
+// CHECK: store [[VAL_23]], [[VAL_9]]{{\[}}[[VAL_19]], [[VAL_20]]] : memref<?x?xf32>
+// CHECK: }
+// CHECK: }
+// CHECK: return
+
+// -----
+
+// Only innermost parallel loops are tiled: the loop nested inside the first
+// loop.parallel and the second top-level loop.parallel, but not the outer
+// loop of the nest.
+func @tile_nested_innermost() {
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    loop.parallel (%k, %l) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    }
+  }
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+  }
+  return
+}
+
+// CHECK-LABEL: func @tile_nested_innermost() {
+// CHECK: [[VAL_24:%.*]] = constant 2 : index
+// CHECK: [[VAL_25:%.*]] = constant 0 : index
+// CHECK: [[VAL_26:%.*]] = constant 1 : index
+// CHECK: loop.parallel ([[VAL_27:%.*]], [[VAL_28:%.*]]) = ([[VAL_25]], [[VAL_25]]) to ([[VAL_24]], [[VAL_24]]) step ([[VAL_26]], [[VAL_26]]) {
+// CHECK: [[VAL_29:%.*]] = constant 0 : index
+// CHECK: [[VAL_30:%.*]] = constant 1 : index
+// CHECK: [[VAL_31:%.*]] = constant 4 : index
+// CHECK: [[VAL_32:%.*]] = muli [[VAL_26]], [[VAL_30]] : index
+// CHECK: [[VAL_33:%.*]] = muli [[VAL_26]], [[VAL_31]] : index
+// CHECK: loop.parallel ([[VAL_34:%.*]], [[VAL_35:%.*]]) = ([[VAL_25]], [[VAL_25]]) to ([[VAL_24]], [[VAL_24]]) step ([[VAL_32]], [[VAL_33]]) {
+// CHECK: [[VAL_36:%.*]] = affine.min #map0([[VAL_30]], [[VAL_24]], [[VAL_34]])
+// CHECK: [[VAL_37:%.*]] = affine.min #map0([[VAL_31]], [[VAL_24]], [[VAL_35]])
+// CHECK: loop.parallel ([[VAL_38:%.*]], [[VAL_39:%.*]]) = ([[VAL_29]], [[VAL_29]]) to ([[VAL_36]], [[VAL_37]]) step ([[VAL_26]], [[VAL_26]]) {
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: [[VAL_40:%.*]] = constant 0 : index
+// CHECK: [[VAL_41:%.*]] = constant 1 : index
+// CHECK: [[VAL_42:%.*]] = constant 4 : index
+// CHECK: [[VAL_43:%.*]] = muli [[VAL_26]], [[VAL_41]] : index
+// CHECK: [[VAL_44:%.*]] = muli [[VAL_26]], [[VAL_42]] : index
+// CHECK: loop.parallel ([[VAL_45:%.*]], [[VAL_46:%.*]]) = ([[VAL_25]], [[VAL_25]]) to ([[VAL_24]], [[VAL_24]]) step ([[VAL_43]], [[VAL_44]]) {
+// CHECK: [[VAL_47:%.*]] = affine.min #map0([[VAL_41]], [[VAL_24]], [[VAL_45]])
+// CHECK: [[VAL_48:%.*]] = affine.min #map0([[VAL_42]], [[VAL_24]], [[VAL_46]])
+// CHECK: loop.parallel ([[VAL_49:%.*]], [[VAL_50:%.*]]) = ([[VAL_40]], [[VAL_40]]) to ([[VAL_47]], [[VAL_48]]) step ([[VAL_26]], [[VAL_26]]) {
+// CHECK: }
+// CHECK: }
+// CHECK: return
+// CHECK: }