[mlir] Add map_nested_foreach_thread_to_gpu_blocks op to transform dialect
author    Guray Ozen <guray.ozen@gmail.com>  Fri, 23 Sep 2022 12:42:51 +0000 (14:42 +0200)
committer Guray Ozen <guray.ozen@gmail.com>  Fri, 23 Sep 2022 14:27:10 +0000 (16:27 +0200)
This revision adds a new op, `map_nested_foreach_thread_to_gpu_blocks`, to the transform dialect.
If the `generate_gpu_launch` attribute is set, the op first generates a `gpu.launch` op and moves the top-level `scf.foreach_thread` inside it. Otherwise, `target` must already be a `gpu.launch` op. The op searches for top-level `scf.foreach_thread` ops inside the `gpu.launch` and distributes them over GPU blocks by rewriting their induction variables to `gpu.block_id`.
Loop mapping is explicit and given by the `map_nested_foreach_thread_to_gpu_blocks` op. The mapping is one-to-one, so the loops disappear after the rewrite.
This revision also adds the `gpu` dialect as a dependent dialect, since the new op can create a `gpu::LaunchOp` for a given `scf::ForeachThreadOp`.
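
A minimal usage sketch (adapted from the new transform-gpu.mlir test; the handle names and the func.func match pattern are illustrative):

  transform.with_pdl_patterns {
  ^bb0(%arg0: !pdl.operation):
    transform.sequence %arg0 failures(propagate) {
    ^bb1(%arg1: !pdl.operation):
      %funcop = transform.structured.match ops{["func.func"]} in %arg0
      %gpu_launch = transform.structured.map_nested_foreach_thread_to_gpu_blocks %funcop { generate_gpu_launch }
      transform.structured.map_nested_foreach_thread_to_gpu_threads %gpu_launch { blockDim = [32, 4, 1] }
    }
  }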

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D134190

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/test/Dialect/Linalg/transform-gpu.mlir

index d5fafca..e60131b 100644 (file)
@@ -846,6 +846,63 @@ def MapNestedForeachThreadToGpuThreads :
   }];
 }
 
+def MapNestedForeachThreadToGpuBlocks : Op<Transform_Dialect, 
+    "structured.map_nested_foreach_thread_to_gpu_blocks",
+    [FunctionalStyleTransformOpTrait,
+     MemoryEffectsOpInterface,
+     TransformOpInterface,
+     TransformEachOpTrait]> {
+  let description = [{
+    Target the gpu_launch op and rewrite the top level `scf.foreach_thread`
+    to be distributed over gpu.block_id. If the `generate_gpu_launch`
+    attribute is set, the op first generates a `gpu_launch` and moves the
+    top level `scf.foreach_thread` inside it.
+
+    The operation searches for top level `scf.foreach_thread` ops under
+    `gpu_launch` and maps each such op to GPU blocks. Mapping is
+    one-to-one and the induction variables of `scf.foreach_thread` are
+    rewritten to gpu.block_id according to the `thread_dim_mapping` attribute.
+
+    Dynamic `scf.foreach_thread` trip counts are currently not supported.
+    Dynamic block dim sizes are currently not supported.
+
+    Only **bufferized** `scf.foreach_thread` ops are currently supported.
+    Only `scf.foreach_thread` ops distributed to **at most 3 dimensions**
+    are currently supported.
+
+    The operation alters the grid size of the given gpu_launch using the
+    `gridDim` argument.
+
+    #### Return modes:
+    
+    This operation ignores non-gpu_launch ops and drops them in the return.
+
+    If any scf.foreach_thread with tensors is found, the transform definitely 
+    fails.    
+
+    If all the scf.foreach_thread operations contained within the LaunchOp 
+    referred to by the `target` PDLOperation lower to GPU properly, the 
+    transform succeeds. Otherwise the transform definitely fails.
+
+    The returned handle points to the same LaunchOp operand, consuming it and
+    producing a new SSA value to satisfy chaining and linearity of the IR 
+    properties.
+  }];
+
+  let arguments = (ins PDL_Operation:$target,
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$gridDim,
+                   UnitAttr:$generate_gpu_launch);
+  let results = (outs PDL_Operation:$result);
+
+  let assemblyFormat = "$target attr-dict";
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::Operation *target, 
+        ::llvm::SmallVectorImpl<::mlir::Operation *> &results, 
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
      TransformEachOpTrait, TransformOpInterface]> {
index b658671..11a34cc 100644 (file)
@@ -125,6 +125,21 @@ bool areElementwiseOpsFusable(OpOperand *fusedOperand);
 FailureOr<Operation *> fuseElementwiseOps(RewriterBase &rewriter,
                                           OpOperand *fusedOperand);
 
+/// Maps the top level `scf.foreach_thread` op to GPU Thread Blocks. Mapping is
+/// one-to-one and the induction variables of `scf.foreach_thread` are rewritten
+/// to gpu.block_id according to the thread_dim_mapping attribute. Dynamic
+/// `scf.foreach_thread` trip counts are currently not supported. Dynamic block
+/// dim sizes are currently not supported.
+LogicalResult rewriteTopLevelForeachThreadToGpuBlocks(
+    RewriterBase &rewriter, scf::ForeachThreadOp foreachThreadOp,
+    function_ref<void(Operation *, const SmallVector<int64_t> &, IndexType,
+                      SmallVector<Value> &)>
+        blockIdGenerator,
+    SmallVector<int64_t> &gridDims);
+
+/// Finds the top level scf::ForeachThreadOp of the given target.
+FailureOr<scf::ForeachThreadOp> findTopLevelForeachThreadOp(Operation *target);
+
 /// Searches `scf.foreach_thread` ops nested under `target` and maps each such
 /// op to GPU threads. Mapping is one-to-one and the induction variables of
 /// `scf.foreach_thread` are rewritten to gpu.thread_id according to the
index ca3c932..18df9ed 100644 (file)
@@ -1285,25 +1285,56 @@ mlir::WalkResult mlir::linalg::rewriteMapNestedForeachThreadToGpuThreads(
   return walkResult;
 }
 
-// Alter blockDim of the given kernel
-static LogicalResult alterGpuLaunchBlockDim(SimpleRewriter &rewriter,
-                                            gpu::LaunchOp gpuLaunch,
-                                            SmallVector<int64_t> blockDim) {
-  gpu::KernelDim3 currentBlockdim = gpuLaunch.getBlockSizeOperandValues();
-  if (blockDim[0] < 1 || blockDim[1] < 1 || blockDim[2] < 1) {
-    gpuLaunch->emitError() << "Given blockDim(" << blockDim[0] << ","
-                           << blockDim[1] << "," << blockDim[2]
-                           << ") is invalid";
+static LogicalResult
+checkGpuLimits(Optional<int64_t> gridDimX, Optional<int64_t> gridDimY,
+               Optional<int64_t> gridDimZ, Optional<int64_t> blockDimX,
+               Optional<int64_t> blockDimY, Optional<int64_t> blockDimZ) {
+  // TODO: The limits should live in the gpu dialect, but they are not there
+  // yet. They should eventually be read from the common gpu dialect.
+  if ((blockDimX.value_or(1) * blockDimY.value_or(1) * blockDimZ.value_or(1)) >
+          1024 ||
+      gridDimY.value_or(1) > 65535 || gridDimZ.value_or(1) > 65535 ||
+      gridDimX.value_or(1) > 2147483647)
+    return failure();
+  return success();
+}
+
+/// Alter the grid or block dimensions of the given kernel launch.
+static LogicalResult alterGpuLaunch(SimpleRewriter &rewriter,
+                                    gpu::LaunchOp gpuLaunch,
+                                    Optional<int64_t> gridDimX = llvm::None,
+                                    Optional<int64_t> gridDimY = llvm::None,
+                                    Optional<int64_t> gridDimZ = llvm::None,
+                                    Optional<int64_t> blockDimX = llvm::None,
+                                    Optional<int64_t> blockDimY = llvm::None,
+                                    Optional<int64_t> blockDimZ = llvm::None) {
+  if (failed(checkGpuLimits(gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY,
+                            blockDimZ))) {
+    gpuLaunch->emitError(
+        "Requested kernel thread configuration is larger than the limits");
     return failure();
   }
+
+  gpu::KernelDim3 currentBlockdim = gpuLaunch.getBlockSizeOperandValues();
+  OpBuilder::InsertionGuard guard(rewriter);
   rewriter.setInsertionPointAfterValue(currentBlockdim.x);
-  auto createBlockDimValue = [&](int64_t dim) {
+  auto createConstValue = [&](int dim) {
     return rewriter.create<arith::ConstantIndexOp>(currentBlockdim.x.getLoc(),
                                                    dim);
   };
-  gpuLaunch.blockSizeXMutable().assign(createBlockDimValue(blockDim[0]));
-  gpuLaunch.blockSizeYMutable().assign(createBlockDimValue(blockDim[1]));
-  gpuLaunch.blockSizeZMutable().assign(createBlockDimValue(blockDim[2]));
+
+  if (gridDimX.has_value())
+    gpuLaunch.gridSizeXMutable().assign(createConstValue(gridDimX.value()));
+  if (gridDimY.has_value())
+    gpuLaunch.gridSizeYMutable().assign(createConstValue(gridDimY.value()));
+  if (gridDimZ.has_value())
+    gpuLaunch.gridSizeZMutable().assign(createConstValue(gridDimZ.value()));
+  if (blockDimX.has_value())
+    gpuLaunch.blockSizeXMutable().assign(createConstValue(blockDimX.value()));
+  if (blockDimY.has_value())
+    gpuLaunch.blockSizeYMutable().assign(createConstValue(blockDimY.value()));
+  if (blockDimZ.has_value())
+    gpuLaunch.blockSizeZMutable().assign(createConstValue(blockDimZ.value()));
   return success();
 }
 
@@ -1327,7 +1358,9 @@ transform::MapNestedForeachThreadToGpuThreads::applyToOne(
   if (walkResult.wasInterrupted())
     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
 
-  LogicalResult result = alterGpuLaunchBlockDim(rewriter, gpuLaunch, blockDim);
+  LogicalResult result =
+      alterGpuLaunch(rewriter, gpuLaunch, llvm::None, llvm::None, llvm::None,
+                     blockDim[0], blockDim[1], blockDim[2]);
   if (failed(result))
     return DiagnosedSilenceableFailure::definiteFailure();
 
@@ -1336,6 +1369,184 @@ transform::MapNestedForeachThreadToGpuThreads::applyToOne(
 }
 
 //===----------------------------------------------------------------------===//
+// MapNestedForeachThreadToGpuBlocks
+//===----------------------------------------------------------------------===//
+
+LogicalResult mlir::linalg::rewriteTopLevelForeachThreadToGpuBlocks(
+    RewriterBase &rewriter, scf::ForeachThreadOp foreachThreadOp,
+    function_ref<void(Operation *, const SmallVector<int64_t> &, IndexType,
+                      SmallVector<Value> &)>
+        blockIdGenerator,
+    SmallVector<int64_t> &gridDims) {
+  if (foreachThreadOp.getNumResults() > 0)
+    return foreachThreadOp->emitError(
+        "only bufferized scf.foreach_thread lowers to gpu.block_id");
+  if (foreachThreadOp.getNumThreads().size() > 3)
+    return foreachThreadOp->emitError(
+        "scf.foreach_thread with rank > 3 does not lower to gpu.block_id");
+
+  // Step 0. Outline the compute workload region and set up the workload
+  // operands.
+  auto potentialGridDim = foreachThreadOp.getPermutedNumThreads(rewriter);
+  if (failed(potentialGridDim) ||
+      llvm::any_of(*potentialGridDim, [](OpFoldResult ofr) {
+        return !getConstantIntValue(ofr).has_value();
+      }))
+    return foreachThreadOp->emitError("unsupported dynamic gridDim");
+
+  for (OpFoldResult ofr : *potentialGridDim)
+    gridDims.push_back(getConstantIntValue(ofr).value());
+
+  IndexType indexType = rewriter.getIndexType();
+  SmallVector<Value> blockOps;
+  blockIdGenerator(foreachThreadOp, gridDims, indexType, blockOps);
+
+  // Step 1. Move the body of foreachThreadOp.
+  // Erase the terminator first, it will not be used since we are on buffers.
+  rewriter.eraseOp(foreachThreadOp.getTerminator());
+  Block *targetBlock = foreachThreadOp->getBlock();
+  Block::iterator insertionPoint = Block::iterator(foreachThreadOp);
+  Block &sourceBlock = foreachThreadOp.getRegion().front();
+  targetBlock->getOperations().splice(insertionPoint,
+                                      sourceBlock.getOperations());
+
+  // Step 2. RAUW thread indices to thread ops.
+  SmallVector<Value> threadIndices =
+      *foreachThreadOp.getPermutedThreadIndices();
+  assert(blockOps.size() == 3 && "3 block id ops are required");
+  for (auto it : llvm::zip(threadIndices, blockOps)) {
+    Value val = std::get<0>(it);
+    if (!val)
+      continue;
+    for (Operation *user : llvm::make_early_inc_range(val.getUsers())) {
+      rewriter.updateRootInPlace(
+          user, [&]() { user->replaceUsesOfWith(val, std::get<1>(it)); });
+    }
+  }
+
+  // Step 3. Erase old op.
+  rewriter.eraseOp(foreachThreadOp);
+
+  return success();
+}
+
+FailureOr<scf::ForeachThreadOp>
+mlir::linalg::findTopLevelForeachThreadOp(Operation *target) {
+  scf::ForeachThreadOp topLevelForeachThreadOp;
+  auto walkResult = target->walk([&](scf::ForeachThreadOp foreachThreadOp) {
+    if (foreachThreadOp->getParentOfType<scf::ForeachThreadOp>())
+      return WalkResult::advance();
+    if (topLevelForeachThreadOp)
+      // TODO: Handle multiple foreach ops if there are no dependences between them
+      return WalkResult::interrupt();
+    topLevelForeachThreadOp = foreachThreadOp;
+    return WalkResult::advance();
+  });
+
+  if (walkResult.wasInterrupted())
+    return target->emitError(
+        "could not find a unique topLevel scf.foreach_thread");
+
+  return topLevelForeachThreadOp;
+}
+
+/// Create a gpu::LaunchOp with the given kernel configuration.
+static FailureOr<gpu::LaunchOp>
+createGpuLaunch(RewriterBase &rewriter, Location loc,
+                Optional<int64_t> gridDimX = llvm::None,
+                Optional<int64_t> gridDimY = llvm::None,
+                Optional<int64_t> gridDimZ = llvm::None,
+                Optional<int64_t> blockDimX = llvm::None,
+                Optional<int64_t> blockDimY = llvm::None,
+                Optional<int64_t> blockDimZ = llvm::None) {
+  if (failed(checkGpuLimits(gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY,
+                            blockDimZ)))
+    return failure();
+  auto createConstant = [&](int dim) {
+    return rewriter.create<arith::ConstantIndexOp>(loc, dim);
+  };
+  Value one = createConstant(1);
+  Value gridSizeX =
+      gridDimX.has_value() ? createConstant(gridDimX.value()) : one;
+  Value gridSizeY =
+      gridDimY.has_value() ? createConstant(gridDimY.value()) : one;
+  Value gridSizeZ =
+      gridDimZ.has_value() ? createConstant(gridDimZ.value()) : one;
+  Value blockSizeX =
+      blockDimX.has_value() ? createConstant(blockDimX.value()) : one;
+  Value blockSizeY =
+      blockDimY.has_value() ? createConstant(blockDimY.value()) : one;
+  Value blockSizeZ =
+      blockDimZ.has_value() ? createConstant(blockDimZ.value()) : one;
+  auto launchOp = rewriter.create<gpu::LaunchOp>(
+      loc, gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ);
+  rewriter.setInsertionPointToEnd(&launchOp.body().front());
+  rewriter.create<gpu::TerminatorOp>(loc);
+  return launchOp;
+}
+
+DiagnosedSilenceableFailure
+transform::MapNestedForeachThreadToGpuBlocks::applyToOne(
+    Operation *target, SmallVectorImpl<Operation *> &results,
+    transform::TransformState &state) {
+  gpu::LaunchOp gpuLaunch = dyn_cast<gpu::LaunchOp>(target);
+  SimpleRewriter rewriter(getContext());
+
+  if (!getGenerateGpuLaunch() && !gpuLaunch) {
+    target->emitError("Given target is not gpu.launch, set "
+                      "`generate_gpu_launch` attribute");
+    return DiagnosedSilenceableFailure::definiteFailure();
+  }
+
+  auto res = mlir::linalg::findTopLevelForeachThreadOp(target);
+  if (failed(res))
+    return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
+
+  scf::ForeachThreadOp topLevelForeachThreadOp = *res;
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(topLevelForeachThreadOp);
+
+  // Generate gpu launch here and move the foreach_thread inside
+  if (getGenerateGpuLaunch()) {
+    FailureOr<gpu::LaunchOp> maybeGpuLaunch =
+        createGpuLaunch(rewriter, target->getLoc());
+    if (failed(maybeGpuLaunch))
+      return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
+    gpuLaunch = *maybeGpuLaunch;
+    rewriter.setInsertionPointToStart(&gpuLaunch.body().front());
+    Operation *newForeachThreadOp = rewriter.clone(*topLevelForeachThreadOp);
+    rewriter.eraseOp(topLevelForeachThreadOp);
+    topLevelForeachThreadOp =
+        dyn_cast<scf::ForeachThreadOp>(newForeachThreadOp);
+  }
+
+  auto generateBlocks = [&](Operation *op, const SmallVector<int64_t> &gridDims,
+                            IndexType indexType, SmallVector<Value> &blockOps) {
+    Location loc = op->getLoc();
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(op);
+    SmallVector<gpu::Dimension> gpuDims{gpu::Dimension::x, gpu::Dimension::y,
+                                        gpu::Dimension::z};
+    for (int64_t idx : llvm::seq<int64_t>(0, gridDims.size())) {
+      blockOps.push_back(
+          rewriter.create<gpu::BlockIdOp>(loc, indexType, gpuDims[idx]));
+    }
+  };
+
+  SmallVector<int64_t> gridDim = extractFromI64ArrayAttr(getGridDim());
+  if (failed(mlir::linalg::rewriteTopLevelForeachThreadToGpuBlocks(
+          rewriter, topLevelForeachThreadOp, generateBlocks, gridDim)))
+    return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
+
+  if (failed(alterGpuLaunch(rewriter, gpuLaunch, gridDim[0], gridDim[1],
+                            gridDim[2])))
+    return DiagnosedSilenceableFailure::definiteFailure();
+
+  results.assign({gpuLaunch});
+  return DiagnosedSilenceableFailure(success());
+}
+
+//===----------------------------------------------------------------------===//
 // TileToForeachThreadOp
 //===----------------------------------------------------------------------===//
 
@@ -1562,6 +1773,7 @@ public:
     declareGeneratedDialect<arith::ArithmeticDialect>();
     declareGeneratedDialect<scf::SCFDialect>();
     declareGeneratedDialect<vector::VectorDialect>();
+    declareGeneratedDialect<gpu::GPUDialect>();
 
     registerTransformOps<
 #define GET_OP_LIST
index 00b750e..fbd7bcb 100644 (file)
@@ -1,4 +1,45 @@
-// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file  -canonicalize -cse %s | FileCheck %s
+
+!type = memref<2 x 32 x f32>
+!type1d = memref<32 x f32>
+
+// CHECK-LABEL: func.func @saxpy2dblock(
+// CHECK-SAME:    %[[ARGX:[0-9a-z]+]]: memref<2x32xf32>
+// CHECK-SAME:    %[[ARGY:[0-9a-z]+]]: memref<2x32xf32>
+// CHECK-SAME:    %[[ARGT:[0-9a-z]+]]: memref<32xf32>
+func.func @saxpy2dblock(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !gpu.async.token) -> !type {
+  %c9 = arith.constant 9 : index
+  %c7 = arith.constant 7 : index
+  %one = arith.constant 1 : index
+//      CHECK:   gpu.launch
+//      CHECK:   %[[BLKX:.*]] = gpu.block_id  x
+//      CHECK:   %[[BLKY:.*]] = gpu.block_id  y
+//      CHECK:   memref.load %[[ARGX]][%[[BLKX]], %[[BLKY]]]
+//      CHECK:   memref.load %[[ARGY]][%[[BLKX]], %[[BLKY]]]
+  %name = gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %one, %arg10 = %one, %arg11 = %one)
+            threads(%arg6, %arg7, %arg8) in (%arg12 = %one, %arg13 = %one, %arg14 = %one) 
+  {
+    scf.foreach_thread (%i, %j) in (%c7, %c9) {
+        %4 = memref.load %x[%i, %j] : !type        
+        %5 = memref.load %y[%i, %j] : !type
+        %6 = math.fma %alpha, %4, %5 : f32
+        memref.store %6, %y[%i, %j] : !type
+     }  {thread_dim_mapping = [0, 1, 2]}               
+    gpu.terminator
+  }  
+  return %y : !type
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0
+    transform.structured.map_nested_foreach_thread_to_gpu_blocks %funcop { blockDim = [12, 9, 1]}
+  }
+}
+
+// -----
 
 !type = memref<2 x 32 x f32>
 !type1d = memref<32 x f32>
@@ -12,21 +53,20 @@ func.func @saxpy2d(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !g
   %c12 = arith.constant 12 : index
   %c9 = arith.constant 9 : index
   %c7 = arith.constant 7 : index
-//      CHECK:   gpu.launch
+//      CHECK:   %[[C1:.*]] = arith.constant 1 : index
+//      CHECK:   %[[C12:.*]] = arith.constant 12 : index
+//      CHECK:   %[[C9:.*]] = arith.constant 9 : index
+//      CHECK:   %[[C7:.*]] = arith.constant 7 : index
+//      CHECK:   gpu.launch async [%{{.*}}] blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C1]], %{{.*}} = %[[C1]], %{{.*}} = %[[C1]]) threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C12]], %{{.*}} = %[[C9]], %{{.*}} = %[[C1]])
 //      CHECK:   %[[TIDX:.*]] = gpu.thread_id  x
 //      CHECK:   %[[TIDY:.*]] = gpu.thread_id  y
-//      CHECK:   %[[C9:.*]] = arith.constant 9 : index
 //      CHECK:   arith.cmpi ult, %[[TIDX]], %[[C9]] : index
-//      CHECK:   %[[C7:.*]] = arith.constant 7 : index
 //      CHECK:   arith.cmpi ult, %[[TIDY]], %[[C7]] : index
 //      CHECK:   memref.load %[[ARGX]][%[[TIDY]], %[[TIDX]]]
 //      CHECK:   memref.load %[[ARGY]][%[[TIDY]], %[[TIDX]]]
 //      CHECK:   gpu.barrier
-//      CHECK:   %[[TIDX2:.*]] = gpu.thread_id  x
-//      CHECK:   %[[TIDY2:.*]] = gpu.thread_id  y
-//      CHECK:   %[[C1:.*]] = arith.constant 1 : index
-//      CHECK:   arith.cmpi ult, %[[TIDY2]], %[[C1]] : index
-//      CHECK:   memref.load %[[ARGT]][%[[TIDX2]]]
+//      CHECK:   arith.cmpi ult, %[[TIDY]], %[[C1]] : index
+//      CHECK:   memref.load %[[ARGT]][%[[TIDX]]]
 //      CHECK:   gpu.barrier
   %name = gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %one, %arg10 = %one, %arg11 = %one)
             threads(%arg6, %arg7, %arg8) in (%arg12 = %one, %arg13 = %one, %arg14 = %one) 
@@ -56,3 +96,45 @@ transform.with_pdl_patterns {
   }
 }
 
+// -----
+
+!type4d = memref<32x64x4x32xf32>
+
+// CHECK-LABEL: func.func @saxpy4d(
+// CHECK-SAME:    %[[ARGX:[0-9a-z]+]]: memref<32x64x4x32xf32>
+// CHECK-SAME:    %[[ARGY:[0-9a-z]+]]: memref<32x64x4x32xf32>
+func.func @saxpy4d(%x: !type4d, %y: !type4d, %alpha : f32) -> !type4d {
+  %c32 = arith.constant 32 : index  
+  %c64 = arith.constant 64 : index  
+  %c4 = arith.constant 4 : index  
+//      CHECK:   %[[C32:.*]] = arith.constant 32 : index
+//      CHECK:   %[[C64:.*]] = arith.constant 64 : index
+//      CHECK:   %[[C4:.*]] = arith.constant 4 : index
+//      CHECK:   %[[C1:.*]] = arith.constant 1 : index
+//      CHECK:   gpu.launch blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C32]], %{{.*}} = %[[C64]], %{{.*}} = %[[C1]]) threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C32]], %{{.*}} = %[[C4]], %{{.*}} = %[[C1]])
+//      CHECK:   %[[BLKX:.*]] = gpu.block_id  x
+//      CHECK:   %[[BLKY:.*]] = gpu.block_id  y
+//      CHECK:   %[[TIDX:.*]] = gpu.thread_id  x
+//      CHECK:   %[[TIDY:.*]] = gpu.thread_id  y
+//      CHECK:   memref.load %[[ARGX]][%[[BLKX]], %[[BLKY]], %[[TIDY]], %[[TIDX]]]
+//      CHECK:   memref.load %[[ARGY]][%[[BLKX]], %[[BLKY]], %[[TIDY]], %[[TIDX]]]
+  scf.foreach_thread (%i, %j) in (%c32, %c64) {
+    scf.foreach_thread (%k, %l) in (%c4, %c32) {
+      %4 = memref.load %x[%i, %j, %k, %l] : !type4d        
+      %5 = memref.load %y[%i, %j, %k, %l] : !type4d
+      %6 = math.fma %alpha, %4, %5 : f32
+      memref.store %6, %y[%i, %j, %k, %l] : !type4d
+    }  {thread_dim_mapping = [1, 0, 2]}
+  }  {thread_dim_mapping = [0, 1, 2]}
+  return %y : !type4d
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %funcop = transform.structured.match ops{["func.func"]} in %arg0
+    %gpuLaunch = transform.structured.map_nested_foreach_thread_to_gpu_blocks %funcop { generate_gpu_launch }
+    transform.structured.map_nested_foreach_thread_to_gpu_threads %gpuLaunch { blockDim = [32, 4, 1] }
+  }
+}