[mlir] Add map_nested_foreach_thread_to_gpu_threads op to transform dialect
authorGuray Ozen <guray.ozen@gmail.com>
Mon, 19 Sep 2022 10:19:21 +0000 (12:19 +0200)
committerGuray Ozen <guray.ozen@gmail.com>
Mon, 19 Sep 2022 14:27:30 +0000 (16:27 +0200)
This revision adds a new op, `map_nested_foreach_thread_to_gpu_threads`, to the transform dialect. The op searches for `scf.foreach_thread` ops nested inside a `gpu.launch` op and distributes them onto GPU threads using `gpu.thread_id` ops.

Loop mapping is explicit and driven by the `map_nested_foreach_thread_to_gpu_threads` op. The mapping is one-to-one, so the loops disappear after the rewrite.

For the time being, trip counts that are dynamic or larger than the thread sizes are not supported. In the future, these cases could be supported by generating an inner loop with static cyclic scheduling.

The current mechanism allows `scf.foreach_thread` ops to be siblings or nested. When they are nested, there cannot be interleaving code between the loops.
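
A minimal usage sketch, adapted from the test added in this revision (the `%gpu_launch` handle name is illustrative):

```
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
  transform.sequence %arg0 failures(propagate) {
  ^bb1(%arg1: !pdl.operation):
    // Match the gpu.launch op and distribute its nested scf.foreach_thread
    // ops over a 12x9x1 thread block.
    %gpu_launch = transform.structured.match ops{["gpu.launch"]} in %arg0
    transform.structured.map_nested_foreach_thread_to_gpu_threads %gpu_launch { blockDim = [12, 9, 1] }
  }
}
```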

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D133950

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/test/Dialect/Linalg/transform-gpu.mlir [new file with mode: 0644]

index 4f845bc..d5fafca 100644 (file)
@@ -748,6 +748,104 @@ def TileToForeachThreadOp :
   }];
 }
 
+def MapNestedForeachThreadToGpuThreads : 
+  Op<Transform_Dialect, "structured.map_nested_foreach_thread_to_gpu_threads",
+    [FunctionalStyleTransformOpTrait, 
+     MemoryEffectsOpInterface,
+     TransformEachOpTrait,
+     TransformOpInterface]> {
+  let description = [{
+      Target the gpu.launch op and rewrite all scf.foreach_thread ops nested
+      within it to distributed gpu.thread_id ops.
+
+      The operation searches `scf.foreach_thread` ops nested under `target`
+      and maps each such op to GPU threads. Mapping is one-to-one and the
+      induction variables of `scf.foreach_thread` are rewritten to
+      gpu.thread_id according to the thread_dim_mapping attribute.
+      
+      Sibling `scf.foreach_thread` ops are supported, in which case the union
+      of the number of threads is computed and may result in predication.
+
+      Multiple scf.foreach_thread ops per function are supported, in which
+      case the max of all the threads is computed and taken for the global
+      gpu.thread_id. If necessary, scf.foreach_thread ops that do not use the
+      whole thread range result in predicated computations.
+
+      Dynamic `scf.foreach_thread` trip counts are currently not supported.
+      Dynamic block dim sizes are currently not supported.
+
+      Only **bufferized** scf.foreach_thread ops are currently supported.
+      Only scf.foreach_thread ops distributed to **at most 3 dimensions** are
+      currently supported.
+    
+      Barriers are inserted after each scf.foreach_thread op for now.
+
+      The operation alters the block size of the given gpu.launch op using the
+      blockDim argument.
+
+      Return modes:
+      =============
+      This operation ignores non-gpu.launch ops and drops them in the return.
+
+      If any scf.foreach_thread with tensors is found, the transform definitely 
+      fails.    
+
+      If all the scf.foreach_thread operations contained within the LaunchOp 
+      referred to by the `target` PDLOperation lower to GPU properly, the 
+      transform succeeds. Otherwise the transform definitely fails.
+      
+      The returned handle points to the same LaunchOp operand, consuming it and
+      producing a new SSA value to satisfy chaining and linearity of the IR 
+      properties.
+
+      Example:
+      ========
+
+      ```
+      gpu.launch blocks(%bx, %by, %bz) in (%x = %0, %y = %1, %z = %2)
+                 threads(%tx, %ty, %tz) in (%tx = %3, %ty = %4, %tz = %5) {
+        scf.foreach_thread (%i, %j) in (7, 9) {
+          ... // body 1
+        } {thread_dim_mapping = [1, 0, 2]}
+        scf.foreach_thread (%i) in (12) {
+          ... // body 2
+        }
+        gpu.terminator
+      }
+      ```
+      is translated to:
+
+      ```
+      %bdimX = arith.constant 12 : index
+      %bdimY = arith.constant 9 : index
+      gpu.launch blocks(%bx, %by, %bz) in (%x = %0, %y = %1, %z = %2)
+             threads(%tx, %ty, %tz) in (%tx = %bdimX, %ty = %bdimY, %tz = %5) {
+        if (threadIdx.x < 9 && threadIdx.y < 7) {
+          ... // body 1
+        }
+        gpu.barrier
+        if (threadIdx.y < 1) {
+          ... // body 2
+        }
+        gpu.barrier
+        gpu.terminator
+      }      
+      ```
+    }];
+
+  let arguments = (ins PDL_Operation:$target,
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$blockDim);
+  let results = (outs PDL_Operation:$result);
+
+  let assemblyFormat = "$target attr-dict";
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::Operation *target, 
+        ::llvm::SmallVectorImpl<::mlir::Operation *> &results, 
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
      TransformEachOpTrait, TransformOpInterface]> {
index 43185a2..2382188 100644 (file)
@@ -121,6 +121,17 @@ bool areElementwiseOpsFusable(OpOperand *fusedOperand);
 FailureOr<Operation *> fuseElementwiseOps(RewriterBase &rewriter,
                                           OpOperand *fusedOperand);
 
+/// Searches `scf.foreach_thread` ops nested under `target` and maps each such
+/// op to GPU threads. Mapping is one-to-one and the induction variables of
+/// `scf.foreach_thread` are rewritten to gpu.thread_id according to the
+/// thread_dim_mapping attribute. Sibling `scf.foreach_thread` ops are
+/// supported, in which case the union of the number of threads is computed and
+/// may result in predication. Dynamic `scf.foreach_thread` trip counts and
+/// dynamic block dim sizes are currently not supported.
+mlir::WalkResult rewriteMapNestedForeachThreadToGpuThreads(
+    RewriterBase &rewriter, Operation *target,
+    const SmallVector<int64_t> &blockDim, bool syncAfterDistribute);
+
 /// Split the given `op` into two parts along the given iteration space
 /// `dimension` at the specified `splitPoint`, and return the two parts.
 ///
index d542093..9ed0ae8 100644 (file)
@@ -501,6 +501,16 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
       return getBody()->getArguments().drop_front(getRank());
     }
 
+    /// Return the thread indices in the order specified by the
+    /// thread_dim_mapping attribute. Return failure if
+    /// thread_dim_mapping is not a valid permutation.
+    FailureOr<SmallVector<Value>> getPermutedThreadIndices();
+
+    /// Return the number of threads in the order specified by the
+    /// thread_dim_mapping attribute.
+    /// Return failure if thread_dim_mapping is not a valid permutation.
+    FailureOr<SmallVector<OpFoldResult>> getPermutedNumThreads(OpBuilder &b);
+
     // The ensureTerminator method generated by SingleBlockImplicitTerminator is
     // unaware of the fact that our terminator also needs a region to be
     // well-formed. We override it here to ensure that we do the right thing.
index bf4396b..108570f 100644 (file)
@@ -11,6 +11,7 @@
 #include "mlir/AsmParser/AsmParser.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/PDL/IR/PDL.h"
@@ -1166,6 +1167,175 @@ void transform::TileOp::getEffects(
 }
 
 //===----------------------------------------------------------------------===//
+// MapNestedForeachThreadToGpuThreads
+//===----------------------------------------------------------------------===//
+
+/// Searches `scf.foreach_thread` ops nested under `target` and maps each such
+/// op to GPU threads. Mapping is one-to-one and the induction variables of
+/// `scf.foreach_thread` are rewritten to gpu.thread_id according to the
+/// thread_dim_mapping attribute. Sibling `scf.foreach_thread` ops are
+/// supported, in which case the union of the number of threads is computed and
+/// may result in predication. Dynamic `scf.foreach_thread` trip counts and
+/// dynamic block dim sizes are currently not supported.
+static FailureOr<SmallVector<OpFoldResult>> rewriteOneForeachThreadToGpuThreads(
+    RewriterBase &rewriter, scf::ForeachThreadOp foreachThreadOp,
+    const SmallVector<int64_t> &globalBlockDims, bool syncAfterDistribute) {
+  if (foreachThreadOp.getNumResults() > 0)
+    return foreachThreadOp->emitError(
+        "only bufferized scf.foreach_thread lowers to gpu.thread");
+  if (foreachThreadOp.getNumThreads().size() > 3)
+    return foreachThreadOp->emitError(
+        "scf.foreach_thread with rank > 3 does not lower to gpu.thread");
+
+  auto potentialBlockDim = foreachThreadOp.getPermutedNumThreads(rewriter);
+  if (failed(potentialBlockDim) ||
+      llvm::any_of(*potentialBlockDim, [](OpFoldResult ofr) {
+        return !getConstantIntValue(ofr).has_value();
+      }))
+    return foreachThreadOp->emitError("unsupported dynamic blockdim size");
+
+  SmallVector<int64_t> blockDim =
+      llvm::to_vector(llvm::map_range(*potentialBlockDim, [](OpFoldResult ofr) {
+        return getConstantIntValue(ofr).value();
+      }));
+
+  // Step 1. Create the gpu.thread ops
+  Location loc = foreachThreadOp.getLoc();
+  IndexType indexType = rewriter.getIndexType();
+
+  SmallVector<gpu::Dimension> gpuDims{gpu::Dimension::x, gpu::Dimension::y,
+                                      gpu::Dimension::z};
+  SmallVector<Value> threadOps;
+  for (int64_t idx : llvm::seq<int64_t>(0, blockDim.size())) {
+    threadOps.push_back(
+        rewriter.create<gpu::ThreadIdOp>(loc, indexType, gpuDims[idx]));
+  }
+  // Step 2. Maybe create conditionals to predicate the region.
+  Value predicate;
+  for (auto [threadId, blockDim, globalBlockDim] :
+       llvm::zip(threadOps, blockDim, globalBlockDims)) {
+    if (blockDim > globalBlockDim) {
+      return foreachThreadOp.emitOpError("blockDim size overflow: ")
+             << blockDim << " > " << globalBlockDim;
+    }
+    if (blockDim == globalBlockDim)
+      continue;
+    Value tmpPredicate = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ult, threadId,
+        rewriter.create<arith::ConstantIndexOp>(loc, blockDim));
+    predicate =
+        predicate ? rewriter.create<arith::AndIOp>(loc, predicate, tmpPredicate)
+                  : tmpPredicate;
+  }
+
+  // Step 3. Move the body of foreachThreadOp.
+  // Erase the terminator first, it will not be used.
+  rewriter.eraseOp(foreachThreadOp.getTerminator());
+  Block *targetBlock;
+  Block::iterator insertionPoint;
+  if (predicate) {
+    // Step 3.a. If predicated, move at the beginning.
+    auto ifOp =
+        rewriter.create<scf::IfOp>(loc, predicate, /*withElseRegion=*/false);
+    targetBlock = ifOp.thenBlock();
+    insertionPoint = ifOp.thenBlock()->begin();
+  } else {
+    // Step 3.a. Otherwise, move inline just before foreachThreadOp.
+    targetBlock = foreachThreadOp->getBlock();
+    insertionPoint = Block::iterator(foreachThreadOp);
+  }
+  Block &sourceBlock = foreachThreadOp.getRegion().front();
+  targetBlock->getOperations().splice(insertionPoint,
+                                      sourceBlock.getOperations());
+
+  // Step 4. RAUW thread indices to thread ops.
+  SmallVector<Value> threadIndices =
+      *foreachThreadOp.getPermutedThreadIndices();
+  for (auto it : llvm::zip(threadIndices, threadOps)) {
+    Value val = std::get<0>(it);
+    if (!val)
+      continue;
+    for (Operation *user : llvm::make_early_inc_range(val.getUsers())) {
+      rewriter.updateRootInPlace(
+          user, [&]() { user->replaceUsesOfWith(val, std::get<1>(it)); });
+    }
+  }
+
+  // Step 5. syncthreads.
+  // TODO: Need warpsync
+  if (syncAfterDistribute)
+    rewriter.create<gpu::BarrierOp>(loc);
+
+  // Step 6. Erase old op.
+  rewriter.eraseOp(foreachThreadOp);
+
+  return *potentialBlockDim;
+}
+
+mlir::WalkResult mlir::linalg::rewriteMapNestedForeachThreadToGpuThreads(
+    RewriterBase &rewriter, Operation *target,
+    const SmallVector<int64_t> &blockDim, bool syncAfterDistribute) {
+  auto walkResult = target->walk([&](scf::ForeachThreadOp foreachThreadOp) {
+    rewriter.setInsertionPoint(foreachThreadOp);
+    if (failed(rewriteOneForeachThreadToGpuThreads(rewriter, foreachThreadOp,
+                                                   blockDim, true)))
+      return WalkResult::interrupt();
+    return WalkResult::advance();
+  });
+  return walkResult;
+}
+
+// Alter the blockDim of the given gpu.launch kernel.
+static LogicalResult alterGpuLaunchBlockDim(SimpleRewriter &rewriter,
+                                            gpu::LaunchOp gpuLaunch,
+                                            SmallVector<int64_t> blockDim) {
+  gpu::KernelDim3 currentBlockdim = gpuLaunch.getBlockSizeOperandValues();
+  if (blockDim[0] < 1 || blockDim[1] < 1 || blockDim[2] < 1) {
+    gpuLaunch->emitError() << "Given blockDim(" << blockDim[0] << ","
+                           << blockDim[1] << "," << blockDim[2]
+                           << ") is invalid";
+    return failure();
+  }
+  rewriter.setInsertionPointAfterValue(currentBlockdim.x);
+  auto createBlockDimValue = [&](int64_t dim) {
+    return rewriter.create<arith::ConstantIndexOp>(currentBlockdim.x.getLoc(),
+                                                   dim);
+  };
+  gpuLaunch.blockSizeXMutable().assign(createBlockDimValue(blockDim[0]));
+  gpuLaunch.blockSizeYMutable().assign(createBlockDimValue(blockDim[1]));
+  gpuLaunch.blockSizeZMutable().assign(createBlockDimValue(blockDim[2]));
+  return success();
+}
+
+DiagnosedSilenceableFailure
+transform::MapNestedForeachThreadToGpuThreads::applyToOne(
+    Operation *target, SmallVectorImpl<Operation *> &results,
+    transform::TransformState &state) {
+
+  gpu::LaunchOp gpuLaunch = dyn_cast<gpu::LaunchOp>(target);
+  if (!gpuLaunch) {
+    target->emitError("Given target is not gpu.launch");
+    return DiagnosedSilenceableFailure::definiteFailure();
+  }
+
+  SmallVector<int64_t> blockDim = extractFromI64ArrayAttr(getBlockDim());
+  blockDim.resize(/*size=*/3, /*value=*/1);
+  SimpleRewriter rewriter(getContext());
+  rewriter.setInsertionPoint(target);
+  auto walkResult = mlir::linalg::rewriteMapNestedForeachThreadToGpuThreads(
+      rewriter, target, blockDim, true);
+  if (walkResult.wasInterrupted())
+    return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
+
+  LogicalResult result = alterGpuLaunchBlockDim(rewriter, gpuLaunch, blockDim);
+  if (failed(result))
+    return DiagnosedSilenceableFailure::definiteFailure();
+
+  results.assign({target});
+  return DiagnosedSilenceableFailure(success());
+}
+
+//===----------------------------------------------------------------------===//
 // TileToForeachThreadOp
 //===----------------------------------------------------------------------===//
 
index 548bee8..83a2381 100644 (file)
@@ -1244,6 +1244,61 @@ PerformConcurrentlyOp ForeachThreadOp::getTerminator() {
   return cast<PerformConcurrentlyOp>(getBody()->getTerminator());
 }
 
+template <typename T>
+static FailureOr<SmallVector<T>> permute(const SmallVector<T> &vals,
+                                         ArrayRef<int64_t> perm) {
+  if (vals.size() != perm.size())
+    return failure();
+  SmallVector<T> result(vals.size());
+  SmallVector<bool> seen(vals.size());
+  for (auto [idx, val] : llvm::zip(perm, vals)) {
+    // Already seen, invalid thread_dim_mapping.
+    if (seen[idx])
+      return failure();
+    result[idx] = val;
+    seen[idx] = true;
+  }
+  // Some not seen, invalid thread_dim_mapping.
+  if (!llvm::all_of(seen, [](bool b) { return b; }))
+    return failure();
+  return result;
+}
+
+/// Helper to apply the `thread_dim_mapping` permutation of a
+/// `foreachThreadOp` to `values`.
+template <typename T>
+static FailureOr<SmallVector<T>>
+getValuesPermutedByThreadMapping(scf::ForeachThreadOp foreachThreadOp,
+                                 const SmallVector<T> &values) {
+  // Apply mapping permutation if specified.
+  auto mapping = foreachThreadOp.getThreadDimMapping();
+  if (mapping && !mapping.empty()) {
+    auto maybePermuted = permute(values, extractFromI64ArrayAttr(mapping));
+    if (failed(maybePermuted))
+      return foreachThreadOp->emitError("invalid permutation");
+    return *maybePermuted;
+  }
+  return values;
+}
+
+/// Return the thread indices in the order specified by the thread_dim_mapping
+/// attribute. Return failure if thread_dim_mapping is not a valid permutation.
+FailureOr<SmallVector<Value>> ForeachThreadOp::getPermutedThreadIndices() {
+  SmallVector<Value> threadCountValues = this->getThreadIndices();
+  threadCountValues.resize(3, Value());
+  return getValuesPermutedByThreadMapping(*this, threadCountValues);
+}
+
+/// Return the number of threads in the order specified by the
+/// thread_dim_mapping attribute.
+/// Return failure if thread_dim_mapping is not a valid permutation.
+FailureOr<SmallVector<OpFoldResult>>
+ForeachThreadOp::getPermutedNumThreads(OpBuilder &b) {
+  SmallVector<OpFoldResult> threadCountValues = this->getNumThreads();
+  threadCountValues.resize(3, b.getIndexAttr(1));
+  return getValuesPermutedByThreadMapping(*this, threadCountValues);
+}
+
 ForeachThreadOp mlir::scf::getForeachThreadOpThreadIndexOwner(Value val) {
   auto tidxArg = val.dyn_cast<BlockArgument>();
   if (!tidxArg)
diff --git a/mlir/test/Dialect/Linalg/transform-gpu.mlir b/mlir/test/Dialect/Linalg/transform-gpu.mlir
new file mode 100644 (file)
index 0000000..00b750e
--- /dev/null
@@ -0,0 +1,58 @@
+// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file %s | FileCheck %s
+
+!type = memref<2 x 32 x f32>
+!type1d = memref<32 x f32>
+
+// CHECK-LABEL: func.func @saxpy2d(
+// CHECK-SAME:    %[[ARGX:[0-9a-z]+]]: memref<2x32xf32>
+// CHECK-SAME:    %[[ARGY:[0-9a-z]+]]: memref<2x32xf32>
+// CHECK-SAME:    %[[ARGT:[0-9a-z]+]]: memref<32xf32>
+func.func @saxpy2d(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !gpu.async.token) -> !type {
+  %one = arith.constant 1 : index
+  %c12 = arith.constant 12 : index
+  %c9 = arith.constant 9 : index
+  %c7 = arith.constant 7 : index
+//      CHECK:   gpu.launch
+//      CHECK:   %[[TIDX:.*]] = gpu.thread_id  x
+//      CHECK:   %[[TIDY:.*]] = gpu.thread_id  y
+//      CHECK:   %[[C9:.*]] = arith.constant 9 : index
+//      CHECK:   arith.cmpi ult, %[[TIDX]], %[[C9]] : index
+//      CHECK:   %[[C7:.*]] = arith.constant 7 : index
+//      CHECK:   arith.cmpi ult, %[[TIDY]], %[[C7]] : index
+//      CHECK:   memref.load %[[ARGX]][%[[TIDY]], %[[TIDX]]]
+//      CHECK:   memref.load %[[ARGY]][%[[TIDY]], %[[TIDX]]]
+//      CHECK:   gpu.barrier
+//      CHECK:   %[[TIDX2:.*]] = gpu.thread_id  x
+//      CHECK:   %[[TIDY2:.*]] = gpu.thread_id  y
+//      CHECK:   %[[C1:.*]] = arith.constant 1 : index
+//      CHECK:   arith.cmpi ult, %[[TIDY2]], %[[C1]] : index
+//      CHECK:   memref.load %[[ARGT]][%[[TIDX2]]]
+//      CHECK:   gpu.barrier
+  %name = gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %one, %arg10 = %one, %arg11 = %one)
+            threads(%arg6, %arg7, %arg8) in (%arg12 = %one, %arg13 = %one, %arg14 = %one) 
+  {
+    scf.foreach_thread (%i, %j) in (%c7, %c9) {
+        %4 = memref.load %x[%i, %j] : !type        
+        %5 = memref.load %y[%i, %j] : !type
+        %6 = math.fma %alpha, %4, %5 : f32
+        memref.store %6, %y[%i, %j] : !type
+     }  {thread_dim_mapping = [1, 0, 2]}    
+     scf.foreach_thread (%i) in (%c12) {
+        %7 = memref.load %t[%i] : !type1d        
+        %8 = arith.addf %alpha, %7 : f32
+        memref.store %8, %t[%i] : !type1d
+     }  {thread_dim_mapping = [0, 1, 2]}
+    gpu.terminator
+  }
+  return %y : !type
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0
+    transform.structured.map_nested_foreach_thread_to_gpu_threads %funcop { blockDim = [12, 9, 1] }
+  }
+}
+