// Default constructor: compiler-generated, takes no configuration arguments.
AsyncParallelForPass() = default;
// Configuring constructor: copies each argument into the member of the same
// name. The third parameter is called `targetBlockSize` on the removed (`-`)
// lines and `minTaskSize` on the added (`+`) lines; the member assignment is
// renamed to match.
// NOTE(review): the members assigned here are not declared in this view —
// presumably they come from the pass base class; confirm.
AsyncParallelForPass(bool asyncDispatch, int32_t numWorkerThreads,
- int32_t targetBlockSize) {
+ int32_t minTaskSize) {
this->asyncDispatch = asyncDispatch;
this->numWorkerThreads = numWorkerThreads;
- this->targetBlockSize = targetBlockSize;
+ this->minTaskSize = minTaskSize;
}
// Overridden pass hook; only declared here, the body is defined elsewhere.
void runOnOperation() override;
// Rewrite pattern matching `scf::ParallelOp`. It carries three configuration
// values captured at construction time; the third one is renamed from
// `targetBlockSize` (removed `-` lines) to `minTaskSize` (added `+` lines).
struct AsyncParallelForRewrite : public OpRewritePattern<scf::ParallelOp> {
public:
// Forwards `ctx` to the `OpRewritePattern` base and stores the remaining
// arguments in the private members of the same name.
AsyncParallelForRewrite(MLIRContext *ctx, bool asyncDispatch,
- int32_t numWorkerThreads, int32_t targetBlockSize)
+ int32_t numWorkerThreads, int32_t minTaskSize)
: OpRewritePattern(ctx), asyncDispatch(asyncDispatch),
- numWorkerThreads(numWorkerThreads), targetBlockSize(targetBlockSize) {}
+ numWorkerThreads(numWorkerThreads), minTaskSize(minTaskSize) {}
// Pattern entry point for a matched `scf::ParallelOp`; declared here, the
// definition is not part of this struct body.
LogicalResult matchAndRewrite(scf::ParallelOp op,
PatternRewriter &rewriter) const override;
private:
// Configuration captured by the constructor above.
bool asyncDispatch;
int32_t numWorkerThreads;
- int32_t targetBlockSize;
+ int32_t minTaskSize;
};
struct ParallelComputeFunctionType {
// Dispatch parallel compute functions by submitting all async compute tasks
// from a simple for loop in the caller thread.
static void
-doSequantialDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
+doSequentialDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
ParallelComputeFunction &parallelComputeFunction,
scf::ParallelOp op, Value blockSize, Value blockCount,
const SmallVector<Value> &tripCounts) {
std::max(1, static_cast<int>(numWorkerThreads * overshardingFactor)));
// Minimum task size from the pass parameters.
- Value targetComputeBlock = b.create<ConstantIndexOp>(targetBlockSize);
+ Value minTaskSizeCst = b.create<ConstantIndexOp>(minTaskSize);
// Compute parallel block size from the parallel problem size:
// blockSize = min(tripCount,
// max(ceil_div(tripCount, maxComputeBlocks),
- // targetComputeBlock))
+ // minTaskSize))
Value bs0 = b.create<SignedCeilDivIOp>(tripCount, maxComputeBlocks);
- Value bs1 = b.create<CmpIOp>(CmpIPredicate::sge, bs0, targetComputeBlock);
- Value bs2 = b.create<SelectOp>(bs1, bs0, targetComputeBlock);
+ Value bs1 = b.create<CmpIOp>(CmpIPredicate::sge, bs0, minTaskSizeCst);
+ Value bs2 = b.create<SelectOp>(bs1, bs0, minTaskSizeCst);
Value bs3 = b.create<CmpIOp>(CmpIPredicate::sle, tripCount, bs2);
Value blockSize0 = b.create<SelectOp>(bs3, tripCount, bs2);
Value blockCount0 = b.create<SignedCeilDivIOp>(tripCount, blockSize0);
doAsyncDispatch(b, rewriter, parallelComputeFunction, op, blockSize,
blockCount, tripCounts);
} else {
- doSequantialDispatch(b, rewriter, parallelComputeFunction, op, blockSize,
+ doSequentialDispatch(b, rewriter, parallelComputeFunction, op, blockSize,
blockCount, tripCounts);
}
RewritePatternSet patterns(ctx);
patterns.add<AsyncParallelForRewrite>(ctx, asyncDispatch, numWorkerThreads,
- targetBlockSize);
+ minTaskSize);
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
return std::make_unique<AsyncParallelForPass>();
}
// Factory overload taking explicit configuration: constructs an
// `AsyncParallelForPass` from the three arguments and returns it as a
// `std::unique_ptr<Pass>`. The removed (`-`) signature names the third
// parameter `targetBlockSize`; the added (`+`) signature names it
// `minTaskSize` and reflows the declaration. The value is forwarded unchanged.
-std::unique_ptr<Pass>
-mlir::createAsyncParallelForPass(bool asyncDispatch, int32_t numWorkerThreads,
- int32_t targetBlockSize) {
+std::unique_ptr<Pass> mlir::createAsyncParallelForPass(bool asyncDispatch,
+ int32_t numWorkerThreads,
+ int32_t minTaskSize) {
return std::make_unique<AsyncParallelForPass>(asyncDispatch, numWorkerThreads,
- targetBlockSize);
+ minTaskSize);
}