From 2d45e332ba321a22996a1584af26d568b375b674 Mon Sep 17 00:00:00 2001 From: "tashuang.zk" Date: Mon, 16 Aug 2021 13:41:55 +0200 Subject: [PATCH] [MLIR][DISC] Revise ParallelLoopTilingPass with inbound_check mode Expand ParallelLoopTilingPass with an inbound_check mode. In default mode, the upper bound of the inner loop is from the min op; in inbound_check mode, the upper bound of the inner loop is the step of the outer loop and an additional inbound check will be emitted inside of the inner loop. This was a 'FIXME' in the original code and a typical usage is for GPU backends, thus the outer loop and inner loop can be mapped to blocks/threads separately. Differential Revision: https://reviews.llvm.org/D105455 --- mlir/include/mlir/Dialect/SCF/Passes.h | 7 +- mlir/include/mlir/Dialect/SCF/Passes.td | 6 +- mlir/include/mlir/Dialect/SCF/Transforms.h | 3 +- .../Dialect/SCF/Transforms/ParallelLoopTiling.cpp | 93 ++++++++++--- .../SCF/parallel-loop-tiling-inbound-check.mlir | 149 +++++++++++++++++++++ 5 files changed, 237 insertions(+), 21 deletions(-) create mode 100644 mlir/test/Dialect/SCF/parallel-loop-tiling-inbound-check.mlir diff --git a/mlir/include/mlir/Dialect/SCF/Passes.h b/mlir/include/mlir/Dialect/SCF/Passes.h index f8ed2c4..df6a272 100644 --- a/mlir/include/mlir/Dialect/SCF/Passes.h +++ b/mlir/include/mlir/Dialect/SCF/Passes.h @@ -36,8 +36,13 @@ std::unique_ptr createParallelLoopFusionPass(); std::unique_ptr createParallelLoopSpecializationPass(); /// Creates a pass which tiles innermost parallel loops. +/// If noMinMaxBounds, the upper bound of the inner loop will +/// be the same value among different outer loop iterations, and +/// an additional inbound check will be emitted inside the internal +/// loops. std::unique_ptr -createParallelLoopTilingPass(llvm::ArrayRef tileSize = {}); +createParallelLoopTilingPass(llvm::ArrayRef tileSize = {}, + bool noMinMaxBounds = false); /// Creates a pass which folds arith ops on induction variable into /// loop range. 
diff --git a/mlir/include/mlir/Dialect/SCF/Passes.td b/mlir/include/mlir/Dialect/SCF/Passes.td index 5e2a3a8..44a7617 100644 --- a/mlir/include/mlir/Dialect/SCF/Passes.td +++ b/mlir/include/mlir/Dialect/SCF/Passes.td @@ -47,7 +47,11 @@ def SCFParallelLoopTiling : FunctionPass<"parallel-loop-tiling"> { let options = [ ListOption<"tileSizes", "parallel-loop-tile-sizes", "int64_t", "Factors to tile parallel loops by", - "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated"> + "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">, + Option<"noMinMaxBounds", "no-min-max-bounds", "bool", + /*default=*/"false", + "Perform tiling with fixed upper bound with inbound check " + "inside the internal loops"> ]; let dependentDialects = ["AffineDialect"]; } diff --git a/mlir/include/mlir/Dialect/SCF/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms.h index 5cb816c..603fd3f 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms.h @@ -87,7 +87,8 @@ LogicalResult peelForLoop(RewriterBase &b, ForOp forOp, scf::IfOp &ifOp); /// The function returns the resulting ParallelOps, i.e. {outer_loop_op, /// inner_loop_op}. 
std::pair -tileParallelLoop(ParallelOp op, llvm::ArrayRef tileSizes); +tileParallelLoop(ParallelOp op, llvm::ArrayRef tileSizes, + bool noMinMaxBounds); /// Populates patterns for SCF structural type conversions and sets up the /// provided ConversionTarget with the appropriate legality configuration for diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp index 8282c07..af8dcaf 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp @@ -33,12 +33,25 @@ using namespace mlir::scf; /// min(%arg5*tileSize[1], %arg3-%i1)) /// step (%arg4, %arg5) /// +/// or, when no-min-max-bounds is true, into +/// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) +/// step (%arg4*tileSize[0], +/// %arg5*tileSize[1]) +/// scf.parallel (%j0, %j1) = (0, 0) to (%arg4*tileSize[0], +/// %arg5*tileSize[1]) +/// step (%arg4, %arg5) +/// %inbound = (%j0 * %arg4 + %i0 < %arg2) && +/// (%j1 * %arg5 + %i1 < %arg3) +/// scf.if (%inbound) +/// .... +/// /// where the uses of %i0 and %i1 in the loop body are replaced by /// %i0 + j0 and %i1 + %j1. // /// The old loop is replaced with the new one. std::pair -mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef tileSizes) { +mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef tileSizes, + bool noMinMaxBounds) { OpBuilder b(op); auto zero = b.create(op.getLoc(), 0); SmallVector tileSizeConstants; @@ -64,8 +77,6 @@ mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef tileSizes) { b.setInsertionPointToStart(outerLoop.getBody()); // Compute min(size, dim - offset) to avoid out-of-bounds accesses. - // FIXME: Instead of using min, we want to replicate the tail. This would give - // the inner loop constant bounds for easy vectorization. 
auto minMap = AffineMap::get( /*dimCount=*/3, /*symbolCount=*/0, {getAffineDimExpr(/*position=*/0, b.getContext()), @@ -76,6 +87,7 @@ mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef tileSizes) { // Create the inner loop with adjusted bounds. SmallVector newBounds; newBounds.reserve(op.upperBound().size()); + bool needInboundCheck = false; for (auto dim : llvm::zip(outerLoop.lowerBound(), outerLoop.upperBound(), outerLoop.step(), outerLoop.getInductionVars(), op.step(), tileSizeConstants)) { @@ -101,6 +113,14 @@ mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef tileSizes) { continue; } } + + // For InboundCheck mode, just use the variable outer step + if (noMinMaxBounds) { + newBounds.push_back(newStep); + needInboundCheck = true; + continue; + } + // Otherwise, we dynamically compute the bound for // each iteration of the outer loop. newBounds.push_back( @@ -111,17 +131,51 @@ mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef tileSizes) { op.getLoc(), SmallVector(newBounds.size(), zero), newBounds, op.step()); - // Steal the body of the old parallel loop and erase it. - innerLoop.region().takeBody(op.region()); - - // Insert computation for new index vectors and replace uses. 
- b.setInsertionPointToStart(innerLoop.getBody()); - for (auto ivs : - llvm::zip(innerLoop.getInductionVars(), outerLoop.getInductionVars())) { - Value inner_index = std::get<0>(ivs); - AddIOp newIndex = - b.create(op.getLoc(), std::get<0>(ivs), std::get<1>(ivs)); - inner_index.replaceAllUsesExcept(newIndex, newIndex); + if (noMinMaxBounds && needInboundCheck) { + b.setInsertionPointToStart(innerLoop.getBody()); + // Insert in-bound check + Value inbound = + b.create(op.getLoc(), b.getIntegerType(1), + b.getIntegerAttr(b.getIntegerType(1), 1)); + for (auto dim : + llvm::zip(outerLoop.upperBound(), outerLoop.getInductionVars(), + innerLoop.getInductionVars(), innerLoop.step())) { + Value outerUpperBound, outerIV, innerIV, innerStep; + std::tie(outerUpperBound, outerIV, innerIV, innerStep) = dim; + // %in_bound = %in_bound && + // (%inner_iv * %inner_step + %outer_iv < %outer_upper_bound) + Value index = b.create( + op.getLoc(), b.create(op.getLoc(), innerIV, innerStep), + outerIV); + Value dimInbound = b.create(op.getLoc(), CmpIPredicate::ult, + index, outerUpperBound); + inbound = b.create(op.getLoc(), inbound, dimInbound); + } + auto ifInbound = b.create(op.getLoc(), + /*resultTypes*/ ArrayRef{}, inbound, + /*hasElseRegion*/ false); + ifInbound.thenRegion().takeBody(op.region()); + Block &thenBlock = ifInbound.thenRegion().front(); + b.setInsertionPointToStart(innerLoop.getBody()); + for (auto ivs : llvm::enumerate(llvm::zip(innerLoop.getInductionVars(), + outerLoop.getInductionVars()))) { + AddIOp newIndex = b.create(op.getLoc(), std::get<0>(ivs.value()), + std::get<1>(ivs.value())); + thenBlock.getArgument(ivs.index()) + .replaceAllUsesExcept(newIndex, newIndex); + } + thenBlock.eraseArguments(llvm::to_vector<4>( + llvm::seq((unsigned)0, thenBlock.getNumArguments()))); + } else { + innerLoop.region().takeBody(op.region()); + b.setInsertionPointToStart(innerLoop.getBody()); + for (auto ivs : llvm::zip(innerLoop.getInductionVars(), + outerLoop.getInductionVars())) 
{ + Value innerIndex = std::get<0>(ivs); + AddIOp newIndex = + b.create(op.getLoc(), std::get<0>(ivs), std::get<1>(ivs)); + innerIndex.replaceAllUsesExcept(newIndex, newIndex); + } } op.erase(); @@ -132,8 +186,10 @@ namespace { struct ParallelLoopTiling : public SCFParallelLoopTilingBase { ParallelLoopTiling() = default; - explicit ParallelLoopTiling(ArrayRef tileSizes) { + explicit ParallelLoopTiling(ArrayRef tileSizes, + bool noMinMaxBounds = false) { this->tileSizes = tileSizes; + this->noMinMaxBounds = noMinMaxBounds; } void runOnFunction() override { @@ -142,13 +198,14 @@ struct ParallelLoopTiling for (ParallelOp ploop : innermostPloops) { // FIXME: Add reduction support. if (ploop.getNumReductions() == 0) - tileParallelLoop(ploop, tileSizes); + tileParallelLoop(ploop, tileSizes, noMinMaxBounds); } } }; } // namespace std::unique_ptr -mlir::createParallelLoopTilingPass(ArrayRef tileSizes) { - return std::make_unique(tileSizes); +mlir::createParallelLoopTilingPass(ArrayRef tileSizes, + bool noMinMaxBounds) { + return std::make_unique(tileSizes, noMinMaxBounds); } diff --git a/mlir/test/Dialect/SCF/parallel-loop-tiling-inbound-check.mlir b/mlir/test/Dialect/SCF/parallel-loop-tiling-inbound-check.mlir new file mode 100644 index 0000000..8f395c3 --- /dev/null +++ b/mlir/test/Dialect/SCF/parallel-loop-tiling-inbound-check.mlir @@ -0,0 +1,149 @@ +// RUN: mlir-opt %s -pass-pipeline='builtin.func(parallel-loop-tiling{parallel-loop-tile-sizes=1,4 no-min-max-bounds=true})' -split-input-file | FileCheck %s + +func @parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index, + %arg3 : index, %arg4 : index, %arg5 : index, + %A: memref, %B: memref, + %C: memref, %result: memref) { + scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) { + %B_elem = memref.load %B[%i0, %i1] : memref + %C_elem = memref.load %C[%i0, %i1] : memref + %sum_elem = addf %B_elem, %C_elem : f32 + memref.store %sum_elem, %result[%i0, %i1] : memref + } + return +} + +// 
CHECK-LABEL: func @parallel_loop( +// CHECK-SAME: [[ARG1:%.*]]: index, [[ARG2:%.*]]: index, [[ARG3:%.*]]: index, [[ARG4:%.*]]: index, [[ARG5:%.*]]: index, [[ARG6:%.*]]: index, [[ARG7:%.*]]: memref, [[ARG8:%.*]]: memref, [[ARG9:%.*]]: memref, [[ARG10:%.*]]: memref) { +// CHECK: [[C0:%.*]] = constant 0 : index +// CHECK: [[C1:%.*]] = constant 1 : index +// CHECK: [[C4:%.*]] = constant 4 : index +// CHECK: [[V1:%.*]] = muli [[ARG5]], [[C1]] : index +// CHECK: [[V2:%.*]] = muli [[ARG6]], [[C4]] : index +// CHECK: scf.parallel ([[V3:%.*]], [[V4:%.*]]) = ([[ARG1]], [[ARG2]]) to ([[ARG3]], [[ARG4]]) step ([[V1]], [[V2]]) { +// CHECK: scf.parallel ([[V7:%.*]], [[V8:%.*]]) = ([[C0]], [[C0]]) to ([[V1]], [[V2]]) step ([[ARG5]], [[ARG6]]) { +// CHECK: [[V9:%.*]] = addi [[V7]], [[V3]] : index +// CHECK: [[V10:%.*]] = addi [[V8]], [[V4]] : index +// CHECK: %true = constant true +// CHECK: [[V11:%.*]] = muli [[V7]], [[ARG5]] : index +// CHECK: [[V12:%.*]] = addi [[V11]], [[V3]] : index +// CHECK: [[V13:%.*]] = cmpi ult, [[V12]], [[ARG3]] : index +// CHECK: [[V14:%.*]] = and %true, [[V13]] : i1 +// CHECK: [[V15:%.*]] = muli [[V8]], [[ARG6]] : index +// CHECK: [[V16:%.*]] = addi [[V15]], [[V4]] : index +// CHECK: [[V17:%.*]] = cmpi ult, [[V16]], [[ARG4]] : index +// CHECK: [[V18:%.*]] = and [[V14]], [[V17]] : i1 +// CHECK: scf.if [[V18]] { +// CHECK: [[V19:%.*]] = memref.load [[ARG8]]{{\[}}[[V9]], [[V10]]] : memref +// CHECK: [[V20:%.*]] = memref.load [[ARG9]]{{\[}}[[V9]], [[V10]]] : memref +// CHECK: [[V21:%.*]] = addf [[V19]], [[V20]] : f32 +// CHECK: memref.store [[V21]], [[ARG10]]{{\[}}[[V9]], [[V10]]] : memref +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: return + +// ----- + +func @static_loop_with_step() { + %c0 = constant 0 : index + %c3 = constant 3 : index + %c22 = constant 22 : index + %c24 = constant 24 : index + scf.parallel (%i0, %i1) = (%c0, %c0) to (%c22, %c24) step (%c3, %c3) { + } + return +} + +// CHECK-LABEL: func @static_loop_with_step() { +// CHECK: 
[[C0:%.*]] = constant 0 : index +// CHECK: [[C3:%.*]] = constant 3 : index +// CHECK: [[C22:%.*]] = constant 22 : index +// CHECK: [[C24:%.*]] = constant 24 : index +// CHECK: [[C0_1:%.*]] = constant 0 : index +// CHECK: [[C1:%.*]] = constant 1 : index +// CHECK: [[C4:%.*]] = constant 4 : index +// CHECK: [[V1:%.*]] = muli [[C3]], [[C1]] : index +// CHECK: [[V2:%.*]] = muli [[C3]], [[C4]] : index +// CHECK: scf.parallel ([[V3:%.*]], [[V4:%.*]]) = ([[C0]], [[C0]]) to ([[C22]], [[C24]]) step ([[V1]], [[V2]]) { +// CHECK: scf.parallel ([[V5:%.*]], [[V6:%.*]]) = ([[C0_1]], [[C0_1]]) to ([[V1]], [[V2]]) step ([[C3]], [[C3]]) { +// CHECK-NOT: scf.if +// CHECK: = addi [[V5]], [[V3]] : index +// CHECK: = addi [[V6]], [[V4]] : index +// CHECK: } +// CHECK: } +// CHECK: return + +// ----- + +func @tile_nested_innermost() { + %c2 = constant 2 : index + %c0 = constant 0 : index + %c1 = constant 1 : index + scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { + scf.parallel (%k, %l) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { + } + } + scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { + } + return +} + +// CHECK-LABEL: func @tile_nested_innermost() { +// CHECK: [[C2:%.*]] = constant 2 : index +// CHECK: [[C0:%.*]] = constant 0 : index +// CHECK: [[C1:%.*]] = constant 1 : index +// CHECK: scf.parallel ([[V1:%.*]], [[V2:%.*]]) = ([[C0]], [[C0]]) to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) { +// CHECK: [[C0_1:%.*]] = constant 0 : index +// CHECK: [[C1_1:%.*]] = constant 1 : index +// CHECK: [[C4:%.*]] = constant 4 : index +// CHECK: [[V3:%.*]] = muli [[C1]], [[C1_1]] : index +// CHECK: [[V4:%.*]] = muli [[C1]], [[C4]] : index +// CHECK: scf.parallel ([[V5:%.*]], [[V6:%.*]]) = ([[C0]], [[C0]]) to ([[C2]], [[C2]]) step ([[V3]], [[V4]]) { +// CHECK: scf.parallel ([[V8:%.*]], [[V9:%.*]]) = ([[C0_1]], [[C0_1]]) to ([[V3]], [[V4]]) step ([[C1]], [[C1]]) { +// CHECK: = addi [[V8]], [[V5]] : index +// CHECK: = addi [[V9]], [[V6]] : index +// CHECK: scf.if +// 
CHECK: } +// CHECK: } +// CHECK: } +// CHECK: [[C0_2:%.*]] = constant 0 : index +// CHECK: [[C1_2:%.*]] = constant 1 : index +// CHECK: [[C4_1:%.*]] = constant 4 : index +// CHECK: [[V10:%.*]] = muli [[C1]], [[C1_2]] : index +// CHECK: [[V11:%.*]] = muli [[C1]], [[C4_1]] : index +// CHECK: scf.parallel ([[V12:%.*]], [[V13:%.*]]) = ([[C0]], [[C0]]) to ([[C2]], [[C2]]) step ([[V10]], [[V11]]) { +// CHECK: scf.parallel ([[V15:%.*]], [[V16:%.*]]) = ([[C0_2]], [[C0_2]]) to ([[V10]], [[V11]]) step ([[C1]], [[C1]]) { +// CHECK: = addi [[V15]], [[V12]] : index +// CHECK: = addi [[V16]], [[V13]] : index +// CHECK: scf.if +// CHECK: } +// CHECK: } +// CHECK: return +// CHECK: } + +// ----- + +func @tile_nested_in_non_ploop() { + %c0 = constant 0 : index + %c1 = constant 1 : index + %c2 = constant 2 : index + scf.for %i = %c0 to %c2 step %c1 { + scf.for %j = %c0 to %c2 step %c1 { + scf.parallel (%k, %l) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { + } + } + } + return +} + +// CHECK-LABEL: func @tile_nested_in_non_ploop +// CHECK: scf.for +// CHECK: scf.for +// CHECK: scf.parallel +// CHECK: scf.parallel +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } -- 2.7.4