for (size_t i = 1; i < tripCounts.size(); ++i)
tripCount = b.create<MulIOp>(tripCount, tripCounts[i]);
+ // With large number of threads the value of creating many compute blocks
+ // is reduced because the problem typically becomes memory bound. For small
+ // number of threads it helps with stragglers.
+ float overshardingFactor = numWorkerThreads <= 4 ? 8.0
+ : numWorkerThreads <= 8 ? 4.0
+ : numWorkerThreads <= 16 ? 2.0
+ : numWorkerThreads <= 32 ? 1.0
+ : numWorkerThreads <= 64 ? 0.8
+ : 0.6;
+
// Do not overload worker threads with too many compute blocks.
- Value maxComputeBlocks =
- b.create<ConstantIndexOp>(numWorkerThreads * kMaxOversharding);
+ Value maxComputeBlocks = b.create<ConstantIndexOp>(
+ std::max(1, static_cast<int>(numWorkerThreads * overshardingFactor)));
// Target block size from the pass parameters.
Value targetComputeBlockSize = b.create<ConstantIndexOp>(targetBlockSize);
Value bs1 = b.create<CmpIOp>(CmpIPredicate::sge, bs0, targetComputeBlockSize);
Value bs2 = b.create<SelectOp>(bs1, bs0, targetComputeBlockSize);
Value bs3 = b.create<CmpIOp>(CmpIPredicate::sle, tripCount, bs2);
- Value blockSize = b.create<SelectOp>(bs3, tripCount, bs2);
+ Value blockSize0 = b.create<SelectOp>(bs3, tripCount, bs2);
+ Value blockCount0 = b.create<SignedCeilDivIOp>(tripCount, blockSize0);
+
+ // Compute balanced block size for the estimated block count.
+ Value blockSize = b.create<SignedCeilDivIOp>(tripCount, blockCount0);
Value blockCount = b.create<SignedCeilDivIOp>(tripCount, blockSize);
// Create a parallel compute function that takes a block id and computes the