From c1194c2ec35029f96ce75ab54555dccf2b7e8681 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 29 Jun 2021 12:56:15 -0700 Subject: [PATCH] [mlir:Async] Change async-parallel-for block size/count calculation Depends On D105037 Avoid creating too many tasks when the number of workers is large. Reviewed By: herhut Differential Revision: https://reviews.llvm.org/D105126 --- .../Dialect/Async/Transforms/AsyncParallelFor.cpp | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp index a104fb7..373ee8b 100644 --- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp @@ -653,9 +653,19 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op, for (size_t i = 1; i < tripCounts.size(); ++i) tripCount = b.create(tripCount, tripCounts[i]); + // With large number of threads the value of creating many compute blocks + // is reduced because the problem typically becomes memory bound. For small + // number of threads it helps with stragglers. + float overshardingFactor = numWorkerThreads <= 4 ? 8.0 + : numWorkerThreads <= 8 ? 4.0 + : numWorkerThreads <= 16 ? 2.0 + : numWorkerThreads <= 32 ? 1.0 + : numWorkerThreads <= 64 ? 0.8 + : 0.6; + // Do not overload worker threads with too many compute blocks. - Value maxComputeBlocks = - b.create(numWorkerThreads * kMaxOversharding); + Value maxComputeBlocks = b.create( + std::max(1, static_cast(numWorkerThreads * overshardingFactor))); // Target block size from the pass parameters. Value targetComputeBlockSize = b.create(targetBlockSize); @@ -668,7 +678,11 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op, Value bs1 = b.create(CmpIPredicate::sge, bs0, targetComputeBlockSize); Value bs2 = b.create(bs1, bs0, targetComputeBlockSize); Value bs3 = b.create(CmpIPredicate::sle, tripCount, bs2); - Value blockSize = b.create(bs3, tripCount, bs2); + Value blockSize0 = b.create(bs3, tripCount, bs2); + Value blockCount0 = b.create(tripCount, blockSize0); + + // Compute balanced block size for the estimated block count. + Value blockSize = b.create(tripCount, blockCount0); Value blockCount = b.create(tripCount, blockSize); // Create a parallel compute function that takes a block id and computes the -- 2.7.4