[mlir][gpu] Allow distributing to different level of IDs without failing
authorThomas Raoux <thomasraoux@google.com>
Fri, 3 Feb 2023 22:53:16 +0000 (22:53 +0000)
committerThomas Raoux <thomasraoux@google.com>
Sat, 4 Feb 2023 02:03:05 +0000 (02:03 +0000)
Change map_nested_foreach_to_threads to ignore foreach_thread not
mapping to threads, this will allow us to call
mapNestedForeachToThreadsImpl with different set of ids to lower
multiple levels. Also adds warpIds attributes.

Differential Revision: https://reviews.llvm.org/D143298

mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td
mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
mlir/test/Dialect/GPU/transform-gpu-failing.mlir
mlir/test/Dialect/GPU/transform-gpu.mlir

index 43ffcc3..c03bdb6 100644 (file)
@@ -43,6 +43,27 @@ def GPUThreadMappingAttr
   }];
 }
 
+def WarpsEnum : I64EnumAttr<"Warps", "threads for loop mapping", [
+    DimX, DimY, DimZ]> {
+  let cppNamespace = "::mlir::gpu";
+}
+
+def GPUWarpMappingAttr : GPU_Attr<"GPUWarpMapping", "warp", [
+  DeclareAttrInterfaceMethods<DeviceMappingAttrInterface> ] >  {
+  let parameters = (ins
+    EnumParameter<WarpsEnum>:$warp
+  );
+  let assemblyFormat = "`<` params `>`";
+  let description = [{
+    An attribute that allows defining thread block parallelism for GPU devices.
+
+    Warp (aka subgroup) are grouped into a grid where grid may be
+    described by a 1-, 2-, or 3-dimensional rectangle. This attribute indicates
+    that thread block parallelism is desired. It can be consumed by lowering to
+    generate GPU code.
+  }];
+}
+
 def BlocksEnum : I64EnumAttr<"Blocks", "threads for loop mapping", [
     DimX, DimY, DimZ]> {
   let cppNamespace = "::mlir::gpu";
index 39eee80..802a915 100644 (file)
@@ -58,9 +58,13 @@ def MapNestedForeachToThreads :
       If any scf.foreach_thread with tensors is found, the transform definitely
       fails.
 
-      If all the scf.foreach_thread operations contained within the LaunchOp
-      referred to by the `target` PDLOperation lower to GPU properly, the
-      transform succeeds. Otherwise the transform definitely fails.
+      If all the scf.foreach_thread operations with gpu.thread mapping contained
+      within the LaunchOp referred to by the `target` PDLOperation lower to GPU
+      properly, the transform succeeds. Otherwise the transform definitely
+      fails.
+
+      scf.foreach_thread operations with mappings other than gpu.thread are
+      ignored.
 
       The returned handle points to the same LaunchOp operand, consuming it and
       producing a new SSA value to satisfy chaining and linearity of the IR
index 7e1f536..0ae8e16 100644 (file)
@@ -42,6 +42,10 @@ int64_t GPUBlockMappingAttr::getMappingId() const {
   return static_cast<int64_t>(getBlock());
 }
 
+int64_t GPUWarpMappingAttr::getMappingId() const {
+  return static_cast<int64_t>(getWarp());
+}
+
 int64_t GPUThreadMappingAttr::getMappingId() const {
   return static_cast<int64_t>(getThread());
 }
index 52b3957..ecad3aa 100644 (file)
@@ -509,6 +509,12 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForeachToThreadsImpl(
     const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes) {
   DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
   target->walk([&](scf::ForeachThreadOp foreachThreadOp) {
+    // Ignore cases with different attributes.
+    for (Attribute map : foreachThreadOp.getMapping()->getValue()) {
+      if (!llvm::is_contained(threadMappingAttributes, map)) {
+        return WalkResult::skip();
+      }
+    }
     diag = checkAttributeType(threadMappingAttributes,
                               foreachThreadOp.getMapping(), transformOp);
     if (diag.succeeded()) {
index e48b48a..8435ecc 100644 (file)
@@ -274,30 +274,4 @@ transform.sequence failures(propagate) {
   transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [32, 32]}
 }
 
-// -----
-
-!type = memref<32x32xf32>
-func.func @saxpy2d_wrong_mapping(%x: !type, %y: !type, %stream : !gpu.async.token) -> !type {
-  %c32 = arith.constant 32 : index
-  %one = arith.constant 1 : index
-  %name = gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %one, %arg10 = %one, %arg11 = %one)
-            threads(%arg6, %arg7, %arg8) in (%arg12 = %one, %arg13 = %one, %arg14 = %one)
-  {
-    scf.foreach_thread (%i, %j) in (%c32, %c32) {
-        %4 = memref.load %x[%i, %j] : !type
-        %5 = memref.load %y[%i, %j] : !type
-        %6 = arith.mulf %4, %5 : f32
-        memref.store %6, %y[%i, %j] : !type
-     }  { mapping = [#gpu.block<x>, #gpu.block<x>] }
-    gpu.terminator
-  }
-  return %y : !type
-}
-
-transform.sequence failures(propagate) {
-^bb1(%arg0: !pdl.operation):
-  %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation
-  // expected-error @below {{mapping must be one of #gpu.thread<x>, #gpu.thread<y>, #gpu.thread<z>}}
-  transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [32, 32]}
-}
 
index 24b474c..035a770 100644 (file)
@@ -230,3 +230,42 @@ transform.sequence failures(propagate) {
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation
   transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [12, 9, 1], syncAfterDistribute = false }
 }
+
+// -----
+
+!type = memref<2 x 32 x f32>
+!type1d = memref<32 x f32>
+
+// CHECK-LABEL: func.func @map_multi_level(
+func.func @map_multi_level(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !gpu.async.token) -> !type {
+  %one = arith.constant 1 : index
+  %c12 = arith.constant 12 : index
+  %c9 = arith.constant 9 : index
+  %c7 = arith.constant 7 : index
+// check that the thread level got distributed but not the warp level.
+//  CHECK-NOT:  {mapping = #gpu.thread
+//      CHECK:  {mapping = [#gpu.warp<x>]}
+  %name = gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %one, %arg10 = %one, %arg11 = %one)
+            threads(%arg6, %arg7, %arg8) in (%arg12 = %one, %arg13 = %one, %arg14 = %one)
+  {
+    scf.foreach_thread (%i, %j) in (%c7, %c9) {
+        %4 = memref.load %x[%i, %j] : !type
+        %5 = memref.load %y[%i, %j] : !type
+        %6 = math.fma %alpha, %4, %5 : f32
+        memref.store %6, %y[%i, %j] : !type
+     }  { mapping = [#gpu.thread<y>, #gpu.thread<x>]}
+     scf.foreach_thread (%i) in (%c12) {
+        %7 = memref.load %t[%i] : !type1d
+        %8 = arith.addf %alpha, %7 : f32
+        memref.store %8, %t[%i] : !type1d
+     }  {mapping = [#gpu.warp<x>] }
+    gpu.terminator
+  }
+  return %y : !type
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg0: !pdl.operation):
+  %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation
+  transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [12, 9] }
+}