[SCF] Add thread_dim_mapping attribute to scf.foreach_thread
author    Nicolas Vasilache <nicolas.vasilache@gmail.com>
Fri, 24 Jun 2022 09:26:22 +0000 (02:26 -0700)
committer Nicolas Vasilache <nicolas.vasilache@gmail.com>
Mon, 27 Jun 2022 11:58:36 +0000 (04:58 -0700)
An optional thread_dim_mapping index array attribute specifies, for each
virtual thread dimension, how it remaps 1-1 to a set of concrete processing
element resources (e.g. a CUDA grid dimension or a level of concrete nested
async parallelism). At this time, the specification is backend-dependent and
is not verified by the op beyond being an index array attribute.
It is the responsibility of the lowering to interpret the index array in the
context of the concrete target the op is lowered to, or to ignore it when
the specification is ill-formed or unsupported for a particular target.
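
As a rough illustration (not part of this patch), a lowering could consume the
attribute along the lines of the sketch below; getThreadDimOrder is a
hypothetical helper built only on the op's existing getRank, the new
getThreadDimMapping accessor, and extractFromI64ArrayAttr:

```
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "llvm/ADT/Sequence.h"

using namespace mlir;

// Hypothetical helper: the order in which a lowering would bind the op's
// virtual thread dims to concrete processing element dims.
static SmallVector<int64_t> getThreadDimOrder(scf::ForeachThreadOp op) {
  int64_t rank = op.getRank();
  // Identity order by default.
  SmallVector<int64_t> order = llvm::to_vector(llvm::seq<int64_t>(0, rank));
  auto mapping = extractFromI64ArrayAttr(op.getThreadDimMapping());
  // The op only checks that the attribute is an index array, so validate the
  // rank here and ignore ill-formed specifications, as described above.
  if (static_cast<int64_t>(mapping.size()) == rank)
    order.assign(mapping.begin(), mapping.end());
  return order;
}
```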

Differential Revision: https://reviews.llvm.org/D128633

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir
mlir/test/Dialect/SCF/ops.mlir

index b36f6b7a1dba68a4c022202b13fe140ddbfee83a..cde966592212d026e547dfe6534a56678d24448c 100644 (file)
@@ -339,6 +339,15 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
     application per thread. Further lowerings are responsible for specifying
     how this is materialized on concrete hardware resources.
 
+    An optional thread_dim_mapping index array attribute specifies for each
+    virtual thread dimension, how it remaps 1-1 to a set of concrete processing
+    element resources (e.g. a CUDA grid dimension or a level of concrete nested
+    async parallelism). At this time, the specification is backend-dependent and
+    is not verified by the op, beyond being an index array attribute.
+    It is the responsibility of the lowering to interpret the index array in the
+    context of the concrete target the op is lowered to, or to ignore it when
+    the specification is ill-formed or unsupported for a particular target.
+
     The only allowed terminator is `scf.foreach_thread.perform_concurrently`,
     which dictates how the partial results of all parallel invocations should be
     reconciled into a full value.
@@ -398,8 +407,27 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
     // Sequential context.
     //
     ```
+
+    Example with thread_dim_mapping attribute:
+    //
+    // Sequential context.
+    //
+    %matmul_and_pointwise:2 = scf.foreach_thread (%thread_id_1, %thread_id_2) in
+         (%num_threads_1, %num_threads_2) -> (tensor<?x?xT>, tensor<?xT>) {
+      //
+      // Parallel context, each thread with id = **(%thread_id_2, %thread_id_1)**
+      // runs its version of the code.
+      //
+       scf.foreach_thread.perform_concurrently {
+         ...
+      }
+    } { thread_dim_mapping = [1, 0] }
+    // Implicit synchronization point.
+    // Sequential context.
+    //
   }];
-  let arguments = (ins Variadic<Index>:$num_threads);
+  let arguments = (ins Variadic<Index>:$num_threads,
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$thread_dim_mapping);
 
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$region);
@@ -411,11 +439,13 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
   let skipDefaultBuilders = 1;
   let builders = [
     // Bodyless builder, result types must be specified.
-    OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$num_threads)>,
+    OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$num_threads,
+                   CArg<"ArrayRef<int64_t>", "{}">:$thread_dim_mapping)>,
     // Builder that takes a bodyBuilder lambda, result types are inferred from
     // the terminator.
     OpBuilder<(ins "ValueRange":$num_threads,
-              "function_ref<void(OpBuilder &, Location, ValueRange)>":$bodyBuilder)>
+                   "ArrayRef<int64_t>":$thread_dim_mapping,
+                   "function_ref<void(OpBuilder &, Location, ValueRange)>":$bodyBuilder)>
   ];
   let extraClassDeclaration = [{
     int64_t getRank() { return getNumThreads().size(); }
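
On the C++ side, the updated bodyless builder can be exercised as in the
following sketch; the wrapper function and the loc / resultType /
numThreadsX / numThreadsY names are placeholders, not code from this patch:

```
#include "mlir/Dialect/SCF/IR/SCF.h"

using namespace mlir;

// Sketch only: build a 2-D bodyless foreach_thread whose two virtual thread
// dims are remapped as [1, 0] through the new builder argument.
static scf::ForeachThreadOp
buildRemappedForeachThread(OpBuilder &b, Location loc, Type resultType,
                           Value numThreadsX, Value numThreadsY) {
  return b.create<scf::ForeachThreadOp>(
      loc, TypeRange{resultType}, ValueRange{numThreadsX, numThreadsY},
      /*thread_dim_mapping=*/ArrayRef<int64_t>{1, 0});
}
```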
index 73eca29ec9270f0547a3c19acb97e5f36364a81a..bd0f16dbd0e07eee091635e8db9873332ce191b1 100644 (file)
@@ -1135,8 +1135,12 @@ ParseResult ForeachThreadOp::parse(OpAsmParser &parser,
 // Bodyless builder, result types must be specified.
 void ForeachThreadOp::build(mlir::OpBuilder &builder,
                             mlir::OperationState &result, TypeRange resultTypes,
-                            ValueRange numThreads) {
+                            ValueRange numThreads,
+                            ArrayRef<int64_t> threadDimMapping) {
   result.addOperands(numThreads);
+  result.addAttribute(
+      // TODO: getThreadDimMappingAttrName() but it is not a static member.
+      "thread_dim_mapping", builder.getI64ArrayAttr(threadDimMapping));
 
   Region *bodyRegion = result.addRegion();
   OpBuilder::InsertionGuard g(builder);
@@ -1156,9 +1160,12 @@ void ForeachThreadOp::build(mlir::OpBuilder &builder,
 // the terminator.
 void ForeachThreadOp::build(
     mlir::OpBuilder &builder, mlir::OperationState &result,
-    ValueRange numThreads,
+    ValueRange numThreads, ArrayRef<int64_t> threadDimMapping,
     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
   result.addOperands(numThreads);
+  result.addAttribute(
+      // TODO: getThreadDimMappingAttrName() but it is not a static member.
+      "thread_dim_mapping", builder.getI64ArrayAttr(threadDimMapping));
 
   OpBuilder::InsertionGuard g(builder);
   Region *bodyRegion = result.addRegion();
index 16ef7b5cad13037b1ad2b472868fdf0669fb39e4..36d0f0cefbae31871f9cfe3aa8a930e493751513 100644 (file)
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
@@ -999,7 +1000,8 @@ struct ForeachThreadOpInterface
     TypeRange newResultTypes;
     auto newForeachThreadOp = rewriter.create<ForeachThreadOp>(
         foreachThreadOp.getLoc(), newResultTypes,
-        foreachThreadOp.getNumThreads());
+        foreachThreadOp.getNumThreads(),
+        extractFromI64ArrayAttr(foreachThreadOp.getThreadDimMapping()));
     newForeachThreadOp.getBody()->getTerminator()->erase();
 
     // Move over block contents of the old op.
index 63d5d88ba0317f4bc20fc516b12a00856b06814e..365195bc7896b2f90bba4546403f0b284f6c79ec 100644 (file)
@@ -130,6 +130,7 @@ func.func @scf_foreach_thread_out_of_place(%in: tensor<100xf32>,
         scf.foreach_thread.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
           tensor<1xf32> into tensor<100xf32>
       }
-  }
+  // CHECK: } {thread_dim_mapping = [5]}
+  } {thread_dim_mapping = [5]}
   return
 }
index 3ba3c04deb1524903b9b9672d478946daf2ea8f6..294017aef622506e4a66aae4fe0e2cb283144153 100644 (file)
@@ -338,11 +338,11 @@ func.func @elide_terminator() -> () {
   %num_threads = arith.constant 100 : index
 
   //      CHECK:    scf.foreach_thread
-  // CHECK-NEXT:  }
+  // CHECK-NEXT:  } {thread_dim_mapping = [42]}
   // CHECK-NEXT:  return
   scf.foreach_thread (%thread_idx) in (%num_threads) -> () {
     scf.foreach_thread.perform_concurrently {
     }
-  }
+  } {thread_dim_mapping = [42]}
   return
 }