struct AsyncParallelForPass
: public AsyncParallelForBase<AsyncParallelForPass> {
AsyncParallelForPass() = default;
+
+ AsyncParallelForPass(bool asyncDispatch, int32_t numWorkerThreads,
+ int32_t targetBlockSize) {
+ this->asyncDispatch = asyncDispatch;
+ this->numWorkerThreads = numWorkerThreads;
+ this->targetBlockSize = targetBlockSize;
+ }
+
void runOnOperation() override;
};
// Converts one-dimensional iteration index in the [0, tripCount) interval
// into multidimensional iteration coordinate.
static SmallVector<Value> delinearize(ImplicitLocOpBuilder &b, Value index,
- const SmallVector<Value> &tripCounts) {
+ ArrayRef<Value> tripCounts) {
SmallVector<Value> coords(tripCounts.size());
assert(!tripCounts.empty() && "tripCounts must be not empty");
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
ModuleOp module = op->getParentOfType<ModuleOp>();
- b.setInsertionPointToStart(&module->getRegion(0).front());
ParallelComputeFunctionType computeFuncType =
getParallelComputeFunctionType(op, rewriter);
unsigned offset = 0; // argument offset for arguments decoding
- // Load multiple arguments into values vector.
- auto getArguments = [&](unsigned num_arguments) -> SmallVector<Value> {
- SmallVector<Value> values(num_arguments);
- for (unsigned i = 0; i < num_arguments; ++i)
- values[i] = block->getArgument(offset++);
- return values;
+ // Returns `numArguments` arguments starting from `offset` and updates offset
+ // by moving forward to the next argument.
+ auto getArguments = [&](unsigned numArguments) -> ArrayRef<Value> {
+ auto args = block->getArguments();
+ auto slice = args.drop_front(offset).take_front(numArguments);
+ offset += numArguments;
+ return {slice.begin(), slice.end()};
};
// Block iteration position defined by the block index and size.
Value blockSize = block->getArgument(offset++);
// Constants used below.
- Value c0 = b.create<ConstantOp>(b.getIndexAttr(0));
- Value c1 = b.create<ConstantOp>(b.getIndexAttr(1));
+ Value c0 = b.create<ConstantIndexOp>(0);
+ Value c1 = b.create<ConstantIndexOp>(1);
// Multi-dimensional parallel iteration space defined by the loop trip counts.
- SmallVector<Value> tripCounts = getArguments(op.getNumLoops());
+ ArrayRef<Value> tripCounts = getArguments(op.getNumLoops());
// Compute a product of trip counts to get the size of the flattened
// one-dimensional iteration space.
for (unsigned i = 1; i < tripCounts.size(); ++i)
tripCount = b.create<MulIOp>(tripCount, tripCounts[i]);
- // Parallel operation lower bound, upper bound and step.
- SmallVector<Value> lowerBound = getArguments(op.getNumLoops());
- SmallVector<Value> upperBound = getArguments(op.getNumLoops());
- SmallVector<Value> step = getArguments(op.getNumLoops());
+ // Parallel operation lower bound and step.
+ ArrayRef<Value> lowerBound = getArguments(op.getNumLoops());
+ offset += op.getNumLoops(); // skip upper bound arguments
+ ArrayRef<Value> step = getArguments(op.getNumLoops());
// Remaining arguments are implicit captures of the parallel operation.
- SmallVector<Value> captures = getArguments(block->getNumArguments() - offset);
+ ArrayRef<Value> captures = getArguments(block->getNumArguments() - offset);
// Find one-dimensional iteration bounds: [blockFirstIndex, blockLastIndex]:
// blockFirstIndex = blockIndex * blockSize
Value blockFirstIndex = b.create<MulIOp>(blockIndex, blockSize);
// The last one-dimensional index in the block defined by the `blockIndex`:
- // blockLastIndex = max((blockIndex + 1) * blockSize, tripCount) - 1
- Value blockEnd0 = b.create<AddIOp>(blockIndex, c1);
- Value blockEnd1 = b.create<MulIOp>(blockEnd0, blockSize);
- Value blockEnd2 = b.create<CmpIOp>(CmpIPredicate::sge, blockEnd1, tripCount);
- Value blockEnd3 = b.create<SelectOp>(blockEnd2, tripCount, blockEnd1);
- Value blockLastIndex = b.create<SubIOp>(blockEnd3, c1);
+ // blockLastIndex = max(blockFirstIndex + blockSize, tripCount) - 1
+ Value blockEnd0 = b.create<AddIOp>(blockFirstIndex, blockSize);
+ Value blockEnd1 = b.create<CmpIOp>(CmpIPredicate::sge, blockEnd0, tripCount);
+ Value blockEnd2 = b.create<SelectOp>(blockEnd1, tripCount, blockEnd0);
+ Value blockLastIndex = b.create<SubIOp>(blockEnd2, c1);
// Convert one-dimensional indices to multi-dimensional coordinates.
auto blockFirstCoord = delinearize(b, blockFirstIndex, tripCounts);
auto blockLastCoord = delinearize(b, blockLastIndex, tripCounts);
- // Compute compute loops upper bounds from the block last coordinates:
+ // Compute loops upper bounds derived from the block last coordinates:
// blockEndCoord[i] = blockLastCoord[i] + 1
//
// Block first and last coordinates can be the same along the outer compute
- // dimension when inner compute dimension containts multple blocks.
+ // dimension when inner compute dimension contains multiple blocks.
SmallVector<Value> blockEndCoord(op.getNumLoops());
for (size_t i = 0; i < blockLastCoord.size(); ++i)
blockEndCoord[i] = b.create<AddIOp>(blockLastCoord[i], c1);
isBlockLastCoord[loopIdx] =
nb.create<CmpIOp>(CmpIPredicate::eq, iv, blockLastCoord[loopIdx]);
- // Check if the previous loop is in its first of last iteration.
+ // Check if the previous loop is in its first or last iteration.
if (loopIdx > 0) {
isBlockFirstCoord[loopIdx] = nb.create<AndOp>(
isBlockFirstCoord[loopIdx], isBlockFirstCoord[loopIdx - 1]);
ImplicitLocOpBuilder b(loc, rewriter);
ModuleOp module = computeFunc.func->getParentOfType<ModuleOp>();
- b.setInsertionPointToStart(&module->getRegion(0).front());
ArrayRef<Type> computeFuncInputTypes =
computeFunc.func.type().cast<FunctionType>().getInputs();
b.setInsertionPointToEnd(block);
Type indexTy = b.getIndexType();
- Value c1 = b.create<ConstantOp>(b.getIndexAttr(1));
- Value c2 = b.create<ConstantOp>(b.getIndexAttr(2));
+ Value c1 = b.create<ConstantIndexOp>(1);
+ Value c2 = b.create<ConstantIndexOp>(2);
// Get the async group that will track async dispatch completion.
Value group = block->getArgument(0);
}
// Setup the async dispatch loop body: recursively call dispatch function
- // for second the half of the original range and go to the next iteration.
+ // for the seconds half of the original range and go to the next iteration.
{
b.setInsertionPointToEnd(after);
Value start = after->getArgument(0);
Value end = after->getArgument(1);
Value distance = b.create<SubIOp>(end, start);
Value halfDistance = b.create<SignedDivIOp>(distance, c2);
- Value midIndex = b.create<AddIOp>(after->getArgument(0), halfDistance);
+ Value midIndex = b.create<AddIOp>(start, halfDistance);
// Call parallel compute function inside the async.execute region.
auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
auto execute = b.create<ExecuteOp>(TypeRange(), ValueRange(), ValueRange(),
executeBodyBuilder);
b.create<AddToGroupOp>(indexTy, execute.token(), group);
- b.create<scf::YieldOp>(ValueRange({after->getArgument(0), midIndex}));
+ b.create<scf::YieldOp>(ValueRange({start, midIndex}));
}
// After dispatching async operations to process the tail of the block range
FuncOp asyncDispatchFunction =
createAsyncDispatchFunction(parallelComputeFunction, rewriter);
- Value c0 = b.create<ConstantOp>(b.getIndexAttr(0));
- Value c1 = b.create<ConstantOp>(b.getIndexAttr(1));
+ Value c0 = b.create<ConstantIndexOp>(0);
+ Value c1 = b.create<ConstantIndexOp>(1);
// Create an async.group to wait on all async tokens from the concurrent
// execution of multiple parallel compute function. First block will be
FuncOp compute = parallelComputeFunction.func;
- Value c0 = b.create<ConstantOp>(b.getIndexAttr(0));
- Value c1 = b.create<ConstantOp>(b.getIndexAttr(1));
+ Value c0 = b.create<ConstantIndexOp>(0);
+ Value c1 = b.create<ConstantIndexOp>(1);
// Create an async.group to wait on all async tokens from the concurrent
// execution of multiple parallel compute function. First block will be
for (size_t i = 1; i < tripCounts.size(); ++i)
tripCount = b.create<MulIOp>(tripCount, tripCounts[i]);
- auto indexTy = b.getIndexType();
-
// Do not overload worker threads with too many compute blocks.
- Value maxComputeBlocks = b.create<ConstantOp>(
- indexTy, b.getIndexAttr(numWorkerThreads * kMaxOversharding));
+ Value maxComputeBlocks =
+ b.create<ConstantIndexOp>(numWorkerThreads * kMaxOversharding);
// Target block size from the pass parameters.
- Value targetComputeBlockSize =
- b.create<ConstantOp>(indexTy, b.getIndexAttr(targetBlockSize));
+ Value targetComputeBlockSize = b.create<ConstantIndexOp>(targetBlockSize);
// Compute parallel block size from the parallel problem size:
// blockSize = min(tripCount,
- // max(divup(tripCount, maxComputeBlocks),
+ // max(ceil_div(tripCount, maxComputeBlocks),
// targetComputeBlockSize))
Value bs0 = b.create<SignedCeilDivIOp>(tripCount, maxComputeBlocks);
Value bs1 = b.create<CmpIOp>(CmpIPredicate::sge, bs0, targetComputeBlockSize);
blockCount, tripCounts);
}
- // Parallel operation was replaces with a block iteration loop.
+ // Parallel operation was replaced with a block iteration loop.
rewriter.eraseOp(op);
return success();
std::unique_ptr<Pass> mlir::createAsyncParallelForPass() {
return std::make_unique<AsyncParallelForPass>();
}
+
+std::unique_ptr<Pass>
+mlir::createAsyncParallelForPass(bool asyncDispatch, int32_t numWorkerThreads,
+ int32_t targetBlockSize) {
+ return std::make_unique<AsyncParallelForPass>(asyncDispatch, numWorkerThreads,
+ targetBlockSize);
+}