[mlir] Plumb missing paramter to gpu transform op
authorThomas Raoux <thomasraoux@google.com>
Fri, 23 Sep 2022 16:49:23 +0000 (16:49 +0000)
committerThomas Raoux <thomasraoux@google.com>
Fri, 23 Sep 2022 16:58:44 +0000 (16:58 +0000)
rewriteMapNestedForeachThreadToGpuThreads was dropping the paramter to
skip inserting barrier

Differential Revision: https://reviews.llvm.org/D134500

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/test/Dialect/Linalg/transform-gpu.mlir

index e60131b..328b0d2 100644 (file)
@@ -834,7 +834,8 @@ def MapNestedForeachThreadToGpuThreads :
     }];
 
   let arguments = (ins PDL_Operation:$target,
-                   DefaultValuedAttr<I64ArrayAttr, "{}">:$blockDim);
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$blockDim,
+                   DefaultValuedAttr<BoolAttr, "true">:$syncAfterDistribute);
   let results = (outs PDL_Operation:$result);
 
   let assemblyFormat = "$target attr-dict";
index 18df9ed..8f6a6b1 100644 (file)
@@ -1277,8 +1277,8 @@ mlir::WalkResult mlir::linalg::rewriteMapNestedForeachThreadToGpuThreads(
     const SmallVector<int64_t> &blockDim, bool syncAfterDistribute) {
   auto walkResult = target->walk([&](scf::ForeachThreadOp foreachThreadOp) {
     rewriter.setInsertionPoint(foreachThreadOp);
-    if (failed(rewriteOneForeachThreadToGpuThreads(rewriter, foreachThreadOp,
-                                                   blockDim, true)))
+    if (failed(rewriteOneForeachThreadToGpuThreads(
+            rewriter, foreachThreadOp, blockDim, syncAfterDistribute)))
       return WalkResult::interrupt();
     return WalkResult::advance();
   });
@@ -1354,7 +1354,7 @@ transform::MapNestedForeachThreadToGpuThreads::applyToOne(
   SimpleRewriter rewriter(getContext());
   rewriter.setInsertionPoint(target);
   auto walkResult = mlir::linalg::rewriteMapNestedForeachThreadToGpuThreads(
-      rewriter, target, blockDim, true);
+      rewriter, target, blockDim, getSyncAfterDistribute());
   if (walkResult.wasInterrupted())
     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
 
index fbd7bcb..c33b42f 100644 (file)
@@ -138,3 +138,39 @@ transform.with_pdl_patterns {
     transform.structured.map_nested_foreach_thread_to_gpu_threads %gpuLaunch { blockDim = [32, 4, 1] }
   }
 }
+
+// -----
+
+!type = memref<2 x 32 x f32>
+!type1d = memref<32 x f32>
+
+// CHECK-LABEL: func.func @saxpy2d_no_barrier(
+func.func @saxpy2d_no_barrier(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !gpu.async.token) -> !type {
+  %one = arith.constant 1 : index
+  %c12 = arith.constant 12 : index
+  %c9 = arith.constant 9 : index
+  %c7 = arith.constant 7 : index
+//  CHECK-NOT:   gpu.barrier
+//      CHECK:   return
+  %name = gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %one, %arg10 = %one, %arg11 = %one)
+            threads(%arg6, %arg7, %arg8) in (%arg12 = %one, %arg13 = %one, %arg14 = %one)
+  {
+    scf.foreach_thread (%i, %j) in (%c7, %c9) {
+        %4 = memref.load %x[%i, %j] : !type
+        %5 = memref.load %y[%i, %j] : !type
+        %6 = math.fma %alpha, %4, %5 : f32
+        memref.store %6, %y[%i, %j] : !type
+     }  {thread_dim_mapping = [1, 0, 2]}
+    gpu.terminator
+  }
+  return %y : !type
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0
+    transform.structured.map_nested_foreach_thread_to_gpu_threads %funcop { blockDim = [12, 9, 1], syncAfterDistribute = false }
+  }
+}