#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/None.h"
SmallVectorImpl<Value> &)>
blockIdGenerator,
SmallVectorImpl<int64_t> &gridDims, TransformOpInterface transformOp) {
+ // Step 0. Target-specific verifications. There is no good place to anchor
+ // those right now: the ForeachThreadOp is target-independent and the
+  // transform op does not apply to an individual ForeachThreadOp.
+ MLIRContext *ctx = foreachThreadOp->getContext();
+ Location loc = foreachThreadOp->getLoc();
+ Attribute bX = GPUBlockMappingAttr::get(ctx, Blocks::DimX);
+ Attribute bY = GPUBlockMappingAttr::get(ctx, Blocks::DimY);
+ Attribute bZ = GPUBlockMappingAttr::get(ctx, Blocks::DimZ);
if (foreachThreadOp.getNumResults() > 0)
return transformOp.emitSilenceableError()
- << "only bufferized scf.foreach_thread lowers to gpu.block_id";
+ << "only bufferized scf.foreach_thread lowers to "
+ "gpu.block_id";
if (foreachThreadOp.getNumThreads().size() > 3)
return transformOp.emitSilenceableError()
- << "scf.foreach_thread with rank > 3 does not lower to gpu.block_id";
-
- // Step 0. Outline the compute workload region and set up the workload
- // operands.
- SmallVector<int64_t> mapping;
+ << "scf.foreach_thread with rank > 3 does not lower to "
+ "gpu.block_id";
+ if (llvm::any_of(foreachThreadOp.getNumThreads(), [](Value v) {
+ return !v.getDefiningOp<arith::ConstantIndexOp>();
+ })) {
+ return transformOp.emitSilenceableError()
+ << "unsupported dynamic griddim size";
+ }
if (!foreachThreadOp.getMapping().has_value())
return transformOp.emitSilenceableError() << "mapping must be present";
- for (DeviceMappingAttrInterface map :
- foreachThreadOp.getMapping()->getValue()) {
- if (auto blockMap = map.dyn_cast<GPUBlockMappingAttr>()) {
- mapping.push_back((int64_t)blockMap.getBlock());
- } else {
- return transformOp.emitSilenceableError()
- << "mapping must be #gpu.block<x/y/z/>";
- }
+ SmallVector<Attribute> blockMapping =
+ llvm::to_vector(foreachThreadOp.getMapping()->getValue());
+ if (llvm::any_of(blockMapping, [](DeviceMappingAttrInterface map) {
+ return !map.isa<GPUBlockMappingAttr>();
+ })) {
+ return transformOp.emitSilenceableError()
+ << "mapping must be #gpu.block<x/y/z/>";
}
- FailureOr<SmallVector<OpFoldResult>> potentialGridDim =
- foreachThreadOp.getPermutedNumThreads(rewriter, mapping);
-
- if (failed(potentialGridDim) ||
- llvm::any_of(*potentialGridDim, [](OpFoldResult ofr) {
- return !getConstantIntValue(ofr).has_value();
- })) {
- return transformOp.emitSilenceableError() << "unsupported dynamic gridDim";
+ // Step 1. Complete the blockMapping to a full mapping (with 1s) if necessary.
+ SmallVector<Value> numBlocks =
+ llvm::to_vector(foreachThreadOp.getNumThreads());
+  // Ensure we have 3 grid dimensions, one for each block id.
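+  // E.g., a 2-D loop mapped to [#gpu.block<y>, #gpu.block<x>] is padded with
+  // a unit-sized #gpu.block<z> entry before sorting.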
+ Value one;
+ for (auto attr : {bX, bY, bZ}) {
+ if (std::find(blockMapping.begin(), blockMapping.end(), attr) ==
+ blockMapping.end()) {
+ blockMapping.push_back(attr);
+ one = one ? one : rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ numBlocks.push_back(one);
+ }
}
- for (OpFoldResult ofr : *potentialGridDim)
- gridDims.push_back(getConstantIntValue(ofr).value());
+  // Step 2. Sort the values by the corresponding GPUBlockMappingAttr.
+ auto comparator = [](Attribute a, Attribute b) -> bool {
+ return static_cast<int64_t>(a.cast<GPUBlockMappingAttr>().getBlock()) <
+ static_cast<int64_t>(b.cast<GPUBlockMappingAttr>().getBlock());
+ };
+ SmallVector<Value> gridDimValues = scf::ForeachThreadOp::getValuesSortedByKey(
+ blockMapping, numBlocks, comparator);
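+  // Extracting the constants below is safe: step 0 already rejected any grid
+  // dimension that is not defined by an arith.constant.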
+ for (Value v : gridDimValues)
+ gridDims.push_back(v.getDefiningOp<arith::ConstantIndexOp>().value());
+ // Step 3. Generate the blockIds using the provided generator and map the
+ // induction variables to the newly created ops.
SmallVector<Value> blockOps;
blockIdGenerator(rewriter, foreachThreadOp, blockOps);
+  assert(blockOps.size() == 3 && "3 block id ops are required");
+  BlockAndValueMapping bvm;
+  for (auto [blockIdx, blockMap] :
+       llvm::zip(foreachThreadOp.getThreadIndices(), blockMapping)) {
+    bvm.map(blockIdx, blockOps[static_cast<int64_t>(
+                          blockMap.cast<GPUBlockMappingAttr>().getBlock())]);
+ }
- // Step 1. Move the body of foreachThreadOp.
+ // Step 4. Move the body of foreachThreadOp.
// Erase the terminator first, it will not be used since we are on buffers.
rewriter.eraseOp(foreachThreadOp.getTerminator());
  Block *targetBlock = foreachThreadOp->getBlock();
  Block::iterator insertionPoint = Block::iterator(foreachThreadOp);
  Block &sourceBlock = foreachThreadOp.getRegion().front();
  targetBlock->getOperations().splice(insertionPoint,
                                      sourceBlock.getOperations());
- // Step 2. RAUW thread indices to thread ops.
- SmallVector<Value> threadIndices =
- *foreachThreadOp.getPermutedThreadIndices(mapping);
- assert(blockOps.size() == 3 && "3 block id ops are required");
- for (auto [blockIdx, blockOp] : llvm::zip(threadIndices, blockOps)) {
- Value val = blockIdx;
- Value blkOp = blockOp;
- if (!val)
- continue;
- for (Operation *user : llvm::make_early_inc_range(val.getUsers()))
- user->replaceUsesOfWith(val, blkOp);
+ // Step 5. RAUW thread indices to thread ops.
+ for (Value blockIdx : foreachThreadOp.getThreadIndices()) {
+ for (Operation *user : llvm::make_early_inc_range(blockIdx.getUsers())) {
+ rewriter.updateRootInPlace(user, [&]() {
+ user->replaceUsesOfWith(blockIdx, bvm.lookup(blockIdx));
+ });
+ }
}
- // Step 3. Erase old op.
+ // Step 6. Erase old op.
rewriter.eraseOp(foreachThreadOp);
return DiagnosedSilenceableFailure::success();
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(foreachOp);
IndexType indexType = rewriter.getIndexType();
- SmallVector<Dimension> gpuDims{Dimension::x, Dimension::y, Dimension::z};
- for (int64_t idx : llvm::seq<int64_t>(0, gpuDims.size())) {
- blockOps.push_back(
- rewriter.create<BlockIdOp>(loc, indexType, gpuDims[idx]));
- }
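+  // Populate blockOps in x, y, z order; callers index it with the numeric
+  // value of the GPUBlockMappingAttr enum.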
+ blockOps = SmallVector<Value>{
+ rewriter.create<BlockIdOp>(loc, indexType, Dimension::x),
+ rewriter.create<BlockIdOp>(loc, indexType, Dimension::y),
+ rewriter.create<BlockIdOp>(loc, indexType, Dimension::z)};
}
DiagnosedSilenceableFailure
RewriterBase &rewriter, scf::ForeachThreadOp foreachThreadOp,
const SmallVectorImpl<int64_t> &globalBlockDims, bool syncAfterDistribute,
llvm::Optional<TransformOpInterface> transformOp) {
+ // Step 0. Target-specific verifications. There is no good place to anchor
+ // those right now: the ForeachThreadOp is target-independent and the
+  // transform op does not apply to an individual ForeachThreadOp.
auto failureHelper =
[&](const Twine &message) -> DiagnosedSilenceableFailure {
    if (transformOp.has_value()) {
      return transformOp->emitSilenceableError() << message;
    }
return emitDefiniteFailure(foreachThreadOp, message);
};
-
+ MLIRContext *ctx = foreachThreadOp->getContext();
+ Location loc = foreachThreadOp->getLoc();
+ Attribute tX = GPUThreadMappingAttr::get(ctx, Threads::DimX);
+ Attribute tY = GPUThreadMappingAttr::get(ctx, Threads::DimY);
+ Attribute tZ = GPUThreadMappingAttr::get(ctx, Threads::DimZ);
if (foreachThreadOp.getNumResults() > 0)
return failureHelper(
"only bufferized scf.foreach_thread lowers to gpu.thread_id");
-
if (foreachThreadOp.getNumThreads().size() > 3)
return failureHelper(
"scf.foreach_thread with rank > 3 does not lower to gpu.thread_id");
-
- SmallVector<int64_t> mapping;
+ if (llvm::any_of(foreachThreadOp.getNumThreads(), [](Value v) {
+ return !v.getDefiningOp<arith::ConstantIndexOp>();
+ })) {
+ return failureHelper("unsupported dynamic blockdim size");
+ }
if (!foreachThreadOp.getMapping().has_value())
return failureHelper("mapping must be present");
- for (DeviceMappingAttrInterface map :
- foreachThreadOp.getMapping()->getValue()) {
- if (auto threadMap = map.dyn_cast<GPUThreadMappingAttr>()) {
- mapping.push_back((int64_t)threadMap.getThread());
- } else {
- return failureHelper("mapping must be #gpu.thread<x/y/z/>");
- }
- }
- FailureOr<SmallVector<OpFoldResult>> potentialBlockDim =
- foreachThreadOp.getPermutedNumThreads(rewriter, mapping);
- if (failed(potentialBlockDim) ||
- llvm::any_of(*potentialBlockDim, [](OpFoldResult ofr) {
- return !getConstantIntValue(ofr).has_value();
+ SmallVector<Attribute> threadMapping =
+ llvm::to_vector(foreachThreadOp.getMapping()->getValue());
+ if (llvm::any_of(threadMapping, [](DeviceMappingAttrInterface map) {
+ return !map.isa<GPUThreadMappingAttr>();
})) {
- return failureHelper("unsupported dynamic blockdim size");
+    return failureHelper("mapping must be #gpu.thread<x/y/z/>");
}
- SmallVector<int64_t> blockDim =
- llvm::to_vector(llvm::map_range(*potentialBlockDim, [](OpFoldResult ofr) {
- return getConstantIntValue(ofr).value();
+ // Step 1. Complete the threadMapping to a full mapping (with 1s) if
+ // necessary.
+ SmallVector<Value> numThreads =
+ llvm::to_vector(foreachThreadOp.getNumThreads());
+  // Ensure we have 3 block dimensions, one for each thread id.
+ Value one;
+ for (auto attr : {tX, tY, tZ}) {
+ if (std::find(threadMapping.begin(), threadMapping.end(), attr) ==
+ threadMapping.end()) {
+ threadMapping.push_back(attr);
+ one = one ? one : rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ numThreads.push_back(one);
+ }
+ }
+
+  // Step 2. Sort the values by the corresponding GPUThreadMappingAttr.
+ auto comparator = [](Attribute a, Attribute b) -> bool {
+ return static_cast<int64_t>(a.cast<GPUThreadMappingAttr>().getThread()) <
+ static_cast<int64_t>(b.cast<GPUThreadMappingAttr>().getThread());
+ };
+ SmallVector<Value> blockDimValues =
+ scf::ForeachThreadOp::getValuesSortedByKey(threadMapping, numThreads,
+ comparator);
+ SmallVector<int64_t> blockDims =
+ llvm::to_vector(llvm::map_range(blockDimValues, [](Value v) {
+ return v.getDefiningOp<arith::ConstantIndexOp>().value();
}));
- // Step 1. Create the gpu.thread ops
- Location loc = foreachThreadOp.getLoc();
+ // Step 3. Create the gpu.thread ops and map the induction variables to the
+ // newly created ops.
IndexType indexType = rewriter.getIndexType();
-
- SmallVector<Dimension> gpuDims{Dimension::x, Dimension::y, Dimension::z};
- SmallVector<Value> threadOps;
- for (int64_t idx : llvm::seq<int64_t>(0, blockDim.size())) {
- threadOps.push_back(
- rewriter.create<ThreadIdOp>(loc, indexType, gpuDims[idx]));
+ SmallVector<Value> threadOps{
+ rewriter.create<ThreadIdOp>(loc, indexType, Dimension::x),
+ rewriter.create<ThreadIdOp>(loc, indexType, Dimension::y),
+ rewriter.create<ThreadIdOp>(loc, indexType, Dimension::z)};
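+  // Map each induction variable to the thread id op selected by its mapping
+  // attribute. llvm::zip stops at the shortest range, so padded dimensions
+  // (which have no induction variable) are skipped.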
+ BlockAndValueMapping bvm;
+  for (auto [threadIdx, threadMap] :
+       llvm::zip(foreachThreadOp.getThreadIndices(), threadMapping)) {
+    bvm.map(threadIdx, threadOps[static_cast<int64_t>(
+                           threadMap.cast<GPUThreadMappingAttr>().getThread())]);
}
- // Step 2. Maybe create conditionals to predicate the region.
+
+ // Step 4. Maybe create conditionals to predicate the region.
Value predicate;
for (auto [threadId, blockDim, globalBlockDim] :
- llvm::zip(threadOps, blockDim, globalBlockDims)) {
+ llvm::zip(threadOps, blockDims, globalBlockDims)) {
if (blockDim > globalBlockDim) {
return failureHelper(
"The requested GPU threads are fewer than the number of loop trip "
: tmpPredicate;
}
- // Step 3. Move the body of foreachThreadOp.
+ // Step 5. Move the body of foreachThreadOp.
// Erase the terminator first, it will not be used.
rewriter.eraseOp(foreachThreadOp.getTerminator());
Block *targetBlock;
Block::iterator insertionPoint;
if (predicate) {
- // Step 3.a. If predicated, move at the beginning.
+ // Step 5.a. If predicated, move at the beginning.
auto ifOp =
rewriter.create<scf::IfOp>(loc, predicate, /*withElseRegion=*/false);
targetBlock = ifOp.thenBlock();
insertionPoint = ifOp.thenBlock()->begin();
} else {
- // Step 3.a. Otherwise, move inline just before foreachThreadOp.
+ // Step 5.b. Otherwise, move inline just before foreachThreadOp.
targetBlock = foreachThreadOp->getBlock();
insertionPoint = Block::iterator(foreachThreadOp);
}
  Block &sourceBlock = foreachThreadOp.getRegion().front();
  targetBlock->getOperations().splice(insertionPoint,
                                      sourceBlock.getOperations());
- // Step 4. RAUW thread indices to thread ops.
- SmallVector<Value> threadIndices =
- *foreachThreadOp.getPermutedThreadIndices(mapping);
- for (auto [threadIdx, threadOp] : llvm::zip(threadIndices, threadOps)) {
- Value val = threadIdx;
- Value op = threadOp;
- if (!val)
- continue;
- for (Operation *user : llvm::make_early_inc_range(val.getUsers())) {
- user->replaceUsesOfWith(val, op);
+ // Step 6. RAUW thread indices to thread ops.
+ for (Value threadIdx : foreachThreadOp.getThreadIndices()) {
+ for (Operation *user : llvm::make_early_inc_range(threadIdx.getUsers())) {
+ rewriter.updateRootInPlace(user, [&]() {
+ user->replaceUsesOfWith(threadIdx, bvm.lookup(threadIdx));
+ });
}
}
- // Step 5. syncthreads.
+ // Step 7. syncthreads.
// TODO: Need warpsync
if (syncAfterDistribute)
rewriter.create<BarrierOp>(loc);
- // Step 6. Erase old op.
+ // Step 8. Erase old op.
rewriter.eraseOp(foreachThreadOp);
return DiagnosedSilenceableFailure::success();
if (body->getArgument(i + getRank()).getType() != getOutputs()[i].getType())
return emitOpError("type mismatch between ")
<< i << "-th output and corresponding block argument";
- if (getMapping().has_value())
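+  // A mapping, when present and non-empty, must provide exactly one device
+  // mapping attribute per loop induction variable.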
+ if (getMapping().has_value() && !getMapping()->empty()) {
+ if (static_cast<int64_t>(getMapping()->size()) != getRank())
+ return emitOpError() << "mapping attribute size must match op rank";
for (auto map : getMapping()->getValue()) {
if (!isa<DeviceMappingAttrInterface>(map))
return emitOpError()
<< getMappingAttrName() << " is not device mapping attribute";
}
+ }
return success();
}
return cast<PerformConcurrentlyOp>(getBody()->getTerminator());
}
-template <typename T>
-static FailureOr<SmallVector<T>> permute(const SmallVector<T> &vals,
- ArrayRef<int64_t> perm) {
- if (vals.size() != perm.size())
- return failure();
- SmallVector<T> result(vals.size());
- SmallVector<bool> seen(vals.size());
- for (auto [idx, val] : llvm::zip(perm, vals)) {
- // Already seen, invalid mapping.
- if (seen[idx])
- return failure();
- result[idx] = val;
- seen[idx] = true;
- }
- // Some not seen, invalid mapping.
- if (!llvm::all_of(seen, [](bool b) { return b; }))
- return failure();
- return result;
-}
-
-/// Helper to get apply the `mapping` permutation of a
-/// `foreachThreadOp` to `values`.
-template <typename T>
-static FailureOr<SmallVector<T>>
-getValuesPermutedByThreadMapping(scf::ForeachThreadOp foreachThreadOp,
- const SmallVector<T> &values,
- ArrayRef<int64_t> mapping) {
- // Apply mapping permutation if specified.
- FailureOr<SmallVector<T>> maybePermuted = permute(values, mapping);
- if (failed(maybePermuted))
- return foreachThreadOp->emitError("invalid permutation");
- return *maybePermuted;
- return values;
-}
-
-/// Return the thread indices in the order specified by the mapping
-/// attribute. Return failure is mapping is not a valid permutation.
-FailureOr<SmallVector<Value>>
-ForeachThreadOp::getPermutedThreadIndices(ArrayRef<int64_t> mapping) {
- SmallVector<Value> threadCountValues = this->getThreadIndices();
- threadCountValues.resize(3, Value());
- return getValuesPermutedByThreadMapping(*this, threadCountValues, mapping);
-}
-
-/// Return the number of threads in the order specified by the
-/// mapping attribute.
-/// Return failure is mapping is not a valid permutation.
-FailureOr<SmallVector<OpFoldResult>>
-ForeachThreadOp::getPermutedNumThreads(OpBuilder &b,
- ArrayRef<int64_t> mapping) {
- SmallVector<OpFoldResult> threadCountValues = this->getNumThreads();
- threadCountValues.resize(3, b.getIndexAttr(1));
- return getValuesPermutedByThreadMapping(*this, threadCountValues, mapping);
+/// Helper to sort `values` according to matching `keys`.
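+/// `keys` and `values` must have matching sizes (or `keys` may be empty, in
+/// which case `values` is returned unchanged); the result reorders `values`
+/// so that their matching keys are sorted by `compare`.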
+SmallVector<Value> ForeachThreadOp::getValuesSortedByKey(
+ ArrayRef<Attribute> keys, ValueRange values,
+ llvm::function_ref<bool(Attribute, Attribute)> compare) {
+ if (keys.empty())
+    return llvm::to_vector(values);
+ assert(keys.size() == values.size() && "unexpected mismatching sizes");
+ auto indices = llvm::to_vector(llvm::seq<int64_t>(0, values.size()));
+ std::sort(indices.begin(), indices.end(),
+ [&](int64_t i, int64_t j) { return compare(keys[i], keys[j]); });
+ SmallVector<Value> res;
+ res.reserve(values.size());
+ for (int64_t i = 0, e = indices.size(); i < e; ++i)
+ res.push_back(values[indices[i]]);
+ return res;
}
ForeachThreadOp mlir::scf::getForeachThreadOpThreadIndexOwner(Value val) {
%5 = memref.load %y[%i, %j] : memref<2 x 32 x f32>
%6 = math.fma %alpha, %4, %5 : f32
memref.store %6, %y[%i, %j] : memref<2 x 32 x f32>
- } { mapping = [#gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z>] }
+ } { mapping = [#gpu.thread<y>, #gpu.thread<x>] }
gpu.terminator
}
%5 = memref.load %y[%i, %j] : memref<2 x 32 x f32>
%6 = math.fma %alpha, %4, %5 : f32
memref.store %6, %y[%i, %j] : memref<2 x 32 x f32>
- } { mapping = [#gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z>] }
+ } { mapping = [#gpu.thread<y>, #gpu.thread<x>] }
gpu.terminator
}
%5 = memref.load %y[%i, %j] : memref<2 x 32 x f32>
%6 = math.fma %alpha, %4, %5 : f32
memref.store %6, %y[%i, %j] : memref<2 x 32 x f32>
- } { mapping = [#gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z>] }
+ } { mapping = [#gpu.thread<y>, #gpu.thread<x>] }
gpu.terminator
}
return %y : memref<2 x 32 x f32>
scf.foreach_thread (%i, %j, %k, %l) in (%c2, %c32,%c32,%c32) {
%4 = memref.load %x[%i, %j, %k, %l] : memref<2x32x32x32xf32>
memref.store %4, %y[%i, %j, %k, %l] : memref<2x32x32x32xf32>
- } { mapping = [#gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z>] }
+ } { mapping = [#gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z>, #gpu.thread<z>] }
gpu.terminator
}
return %y : memref<2x32x32x32xf32>
%5 = memref.load %y[%i, %j] : memref<2 x 32 x f32>
%6 = math.fma %alpha, %4, %5 : f32
memref.store %6, %y[%i, %j] : memref<2 x 32 x f32>
- } { mapping = [#gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z>] }
+ } { mapping = [#gpu.thread<y>, #gpu.thread<x>] }
scf.foreach_thread (%i, %j) in (%c7, %c9) {
%4 = memref.load %x[%i, %j] : memref<2 x 32 x f32>
%5 = memref.load %y[%i, %j] : memref<2 x 32 x f32>
%6 = math.fma %alpha, %4, %5 : f32
memref.store %6, %y[%i, %j] : memref<2 x 32 x f32>
- } { mapping = [#gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z>] }
+ } { mapping = [#gpu.thread<y>, #gpu.thread<x>] }
gpu.terminator
}
%5 = memref.load %y[%i, %j] : memref<2 x 32 x f32>
%6 = math.fma %alpha, %4, %5 : f32
memref.store %6, %y[%i, %j] : memref<2 x 32 x f32>
- } { mapping = [#gpu.thread<x>, #gpu.thread<y>, #gpu.thread<z>] }
+ } { mapping = [#gpu.thread<x>, #gpu.thread<y>] }
scf.foreach_thread (%i, %j) in (%c7, %c9) {
%4 = memref.load %x[%i, %j] : memref<2 x 32 x f32>
%5 = memref.load %y[%i, %j] : memref<2 x 32 x f32>
%6 = math.fma %alpha, %4, %5 : f32
memref.store %6, %y[%i, %j] : memref<2 x 32 x f32>
- } { mapping = [#gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z>] }
+ } { mapping = [#gpu.thread<y>, #gpu.thread<x>] }
return %y : memref<2 x 32 x f32>
}
%5 = memref.load %y[%i, %j] : memref<2 x 32 x f32>
%6 = math.fma %alpha, %4, %5 : f32
memref.store %6, %y[%i, %j] : memref<2 x 32 x f32>
- } { mapping = [#gpu.block<x>, #gpu.block<y>, #gpu.block<z>] }
+ } { mapping = [#gpu.block<x>, #gpu.block<y>] }
return %y : memref<2 x 32 x f32>
}