application per thread. Further lowerings are responsible for specifying
how this is materialized on concrete hardware resources.
+ An optional thread_dim_mapping index array attribute specifies, for each
+ virtual thread dimension, how it remaps 1-1 to a set of concrete processing
+ element resources (e.g. a CUDA grid dimension or a level of concrete nested
+ async parallelism). At this time, the specification is backend-dependent and
+ is not verified by the op, beyond being an index array attribute.
+ It is the responsibility of the lowering to interpret the index array in the
+ context of the concrete target the op is lowered to, or to ignore it when
+ the specification is ill-formed or unsupported for a particular target.
+
The only allowed terminator is `scf.foreach_thread.perform_concurrently`,
which dictates how the partial results of all parallel invocations should be
reconciled into a full value.
// Sequential context.
//
```
+
+ Example with thread_dim_mapping attribute:
+ //
+ // Sequential context.
+ //
+ %matmul_and_pointwise:2 = scf.foreach_thread (%thread_id_1, %thread_id_2) in
+ (%num_threads_1, %num_threads_2) -> (tensor<?x?xT>, tensor<?xT>) {
+ //
+ // Parallel context, each thread with id = **(%thread_id_2, %thread_id_1)**
+ // runs its version of the code.
+ //
+ scf.foreach_thread.perform_concurrently {
+ ...
+ }
+ } { thread_dim_mapping = [1, 0] }
+ // Implicit synchronization point.
+ // Sequential context.
+ //
}];
- let arguments = (ins Variadic<Index>:$num_threads);
+ let arguments = (ins Variadic<Index>:$num_threads,
+ DefaultValuedAttr<I64ArrayAttr, "{}">:$thread_dim_mapping);
let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$region);
let skipDefaultBuilders = 1;
let builders = [
// Bodyless builder, result types must be specified.
- OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$num_threads)>,
+ OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$num_threads,
+ CArg<"ArrayRef<int64_t>", "{}">:$thread_dim_mapping)>,
// Builder that takes a bodyBuilder lambda, result types are inferred from
// the terminator.
OpBuilder<(ins "ValueRange":$num_threads,
- "function_ref<void(OpBuilder &, Location, ValueRange)>":$bodyBuilder)>
+ "ArrayRef<int64_t>":$thread_dim_mapping,
+ "function_ref<void(OpBuilder &, Location, ValueRange)>":$bodyBuilder)>
];
let extraClassDeclaration = [{
int64_t getRank() { return getNumThreads().size(); }
// Bodyless builder, result types must be specified.
void ForeachThreadOp::build(mlir::OpBuilder &builder,
mlir::OperationState &result, TypeRange resultTypes,
- ValueRange numThreads) {
+ ValueRange numThreads,
+ ArrayRef<int64_t> threadDimMapping) {
result.addOperands(numThreads);
+ result.addAttribute(
+ // TODO: use getThreadDimMappingAttrName() instead of the string literal; it is not a static member.
+ "thread_dim_mapping", builder.getI64ArrayAttr(threadDimMapping));
Region *bodyRegion = result.addRegion();
OpBuilder::InsertionGuard g(builder);
// the terminator.
void ForeachThreadOp::build(
mlir::OpBuilder &builder, mlir::OperationState &result,
- ValueRange numThreads,
+ ValueRange numThreads, ArrayRef<int64_t> threadDimMapping,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
result.addOperands(numThreads);
+ result.addAttribute(
+ // TODO: use getThreadDimMappingAttrName() instead of the string literal; it is not a static member.
+ "thread_dim_mapping", builder.getI64ArrayAttr(threadDimMapping));
OpBuilder::InsertionGuard g(builder);
Region *bodyRegion = result.addRegion();
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
TypeRange newResultTypes;
auto newForeachThreadOp = rewriter.create<ForeachThreadOp>(
foreachThreadOp.getLoc(), newResultTypes,
- foreachThreadOp.getNumThreads());
+ foreachThreadOp.getNumThreads(),
+ extractFromI64ArrayAttr(foreachThreadOp.getThreadDimMapping()));
newForeachThreadOp.getBody()->getTerminator()->erase();
// Move over block contents of the old op.