[mlir:Async] Change async-parallel-for block size/count calculation
authorEugene Zhulenev <ezhulenev@google.com>
Tue, 29 Jun 2021 19:56:15 +0000 (12:56 -0700)
committerEugene Zhulenev <ezhulenev@google.com>
Tue, 29 Jun 2021 19:57:11 +0000 (12:57 -0700)
Depends On D105037

Avoid creating too many tasks when the number of workers is large.

Reviewed By: herhut

Differential Revision: https://reviews.llvm.org/D105126

mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp

index a104fb7..373ee8b 100644 (file)
@@ -653,9 +653,19 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
   for (size_t i = 1; i < tripCounts.size(); ++i)
     tripCount = b.create<MulIOp>(tripCount, tripCounts[i]);
 
+  // With large number of threads the value of creating many compute blocks
+  // is reduced because the problem typically becomes memory bound. For small
+  // number of threads it helps with stragglers.
+  float overshardingFactor = numWorkerThreads <= 4    ? 8.0
+                             : numWorkerThreads <= 8  ? 4.0
+                             : numWorkerThreads <= 16 ? 2.0
+                             : numWorkerThreads <= 32 ? 1.0
+                             : numWorkerThreads <= 64 ? 0.8
+                                                      : 0.6;
+
   // Do not overload worker threads with too many compute blocks.
-  Value maxComputeBlocks =
-      b.create<ConstantIndexOp>(numWorkerThreads * kMaxOversharding);
+  Value maxComputeBlocks = b.create<ConstantIndexOp>(
+      std::max(1, static_cast<int>(numWorkerThreads * overshardingFactor)));
 
   // Target block size from the pass parameters.
   Value targetComputeBlockSize = b.create<ConstantIndexOp>(targetBlockSize);
@@ -668,7 +678,11 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
   Value bs1 = b.create<CmpIOp>(CmpIPredicate::sge, bs0, targetComputeBlockSize);
   Value bs2 = b.create<SelectOp>(bs1, bs0, targetComputeBlockSize);
   Value bs3 = b.create<CmpIOp>(CmpIPredicate::sle, tripCount, bs2);
-  Value blockSize = b.create<SelectOp>(bs3, tripCount, bs2);
+  Value blockSize0 = b.create<SelectOp>(bs3, tripCount, bs2);
+  Value blockCount0 = b.create<SignedCeilDivIOp>(tripCount, blockSize0);
+
+  // Compute balanced block size for the estimated block count.
+  Value blockSize = b.create<SignedCeilDivIOp>(tripCount, blockCount0);
   Value blockCount = b.create<SignedCeilDivIOp>(tripCount, blockSize);
 
   // Create a parallel compute function that takes a block id and computes the