[mlir][GPU] Add known_block_size and known_grid_size to gpu.func
authorKrzysztof Drewniak <Krzysztof.Drewniak@amd.com>
Fri, 2 Dec 2022 20:38:39 +0000 (20:38 +0000)
committerKrzysztof Drewniak <Krzysztof.Drewniak@amd.com>
Thu, 22 Dec 2022 21:41:46 +0000 (21:41 +0000)
In many cases, the the number of workgroups (the grid size) and the
number of workitems within each group (the block size) that a GPU
kernel will be launched with are known. For example, if gpu.launch is
called with constant block and grid sizes, we know that those are the
only possible sizes that will be used to launch that kernel. In other
cases, a custom code-generation pipeline that eventually produces GPU
kernels may know the launch dimensions of those kernels, or at least
may be able to provide an upper bound on them.

Other GPU programming systems, such as OpenCL, allow capturing such
information to enable compiler optimizations - see
reqd_work_group_size, but MLIR currently has no mechanism for doing so.

This set of attributes is the first step in enabling optimizations
based on the known launch dimensions of kernels. It extends the kernel
outline pass to set these bounds on kernels with constant launch
dimensions and extends integer range inference for GPU index
operations to account for the bounds when they are known.

Subsequent revisions will use this data when lowering GPU operations
to the ROCDL dialect.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D139865

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
mlir/test/Dialect/GPU/int-range-interface.mlir
mlir/test/Dialect/GPU/invalid.mlir
mlir/test/Dialect/GPU/outlining.mlir

index baf9540..4442307 100644 (file)
@@ -205,6 +205,14 @@ def GPU_GPUFuncOp : GPU_Op<"func", [
     coordinate work items. Declarations of GPU functions, i.e. not having the
     body region, are not supported.
 
+    A function may optionally be annotated with the block and/or grid sizes
+    that will be used when it is launched using the `gpu.known_block_size` and
+    `gpu.known_grid_size` attributes, respectively. If set, these attributes must
+    be arrays of three 32-bit integers giving the x, y, and z launch dimensions.
+    Launching a kernel that has these annotations, or that calls a function with
+    these annotations, using a block size or grid size other than what is specified
+    is undefined behavior.
+
     Syntax:
 
     ```
@@ -311,6 +319,36 @@ def GPU_GPUFuncOp : GPU_Op<"func", [
       return "workgroup_attributions";
     }
 
+    static constexpr StringLiteral getKnownBlockSizeAttrName() {
+      return StringLiteral("gpu.known_block_size");
+    }
+
+    static constexpr StringLiteral getKnownGridSizeAttrName() {
+      return StringLiteral("gpu.known_grid_size");
+    }
+
+    /// Returns the block size this kernel will be launched with along
+    /// dimension `dim` if known. The value of gpu.thread_id dim will be strictly
+    /// less than this size.
+    Optional<uint32_t> getKnownBlockSize(gpu::Dimension dim) {
+      if (auto array =
+        (*this)->getAttrOfType<DenseI32ArrayAttr>(getKnownBlockSizeAttrName())) {
+        return array[static_cast<uint32_t>(dim)];
+      }
+      return std::nullopt;
+    }
+
+    /// Returns the grid size this kernel will be launched with along
+    /// dimension `dim` if known. The value of gpu.block_id dim will be strictly
+    /// less than this size.
+    Optional<uint32_t> getKnownGridSize(gpu::Dimension dim) {
+      if (auto array =
+        (*this)->getAttrOfType<DenseI32ArrayAttr>(getKnownGridSizeAttrName())) {
+        return array[static_cast<uint32_t>(dim)];
+      }
+      return std::nullopt;
+    }
+
     /// Returns the argument types of this function.
     ArrayRef<Type> getArgumentTypes() { return getFunctionType().getInputs(); }
 
@@ -329,6 +367,8 @@ def GPU_GPUFuncOp : GPU_Op<"func", [
     LogicalResult verifyBody();
   }];
   let hasCustomAssemblyFormat = 1;
+
+  let hasVerifier = 1;
 }
 
 def GPU_LaunchFuncOp : GPU_Op<"launch_func",
index e1d92b9..d687043 100644 (file)
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
@@ -1057,6 +1058,27 @@ LogicalResult GPUFuncOp::verifyBody() {
   return success();
 }
 
+static LogicalResult verifyKnownLaunchSizeAttr(gpu::GPUFuncOp op,
+                                               StringRef attrName) {
+  auto maybeAttr = op->getAttr(attrName);
+  if (!maybeAttr)
+    return success();
+  auto array = maybeAttr.dyn_cast<DenseI32ArrayAttr>();
+  if (!array)
+    return op.emitOpError(attrName + " must be a dense i32 array");
+  if (array.size() != 3)
+    return op.emitOpError(attrName + " must contain exactly 3 elements");
+  return success();
+}
+
+LogicalResult GPUFuncOp::verify() {
+  if (failed(verifyKnownLaunchSizeAttr(*this, getKnownBlockSizeAttrName())))
+    return failure();
+  if (failed(verifyKnownLaunchSizeAttr(*this, getKnownGridSizeAttrName())))
+    return failure();
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ReturnOp
 //===----------------------------------------------------------------------===//
index 3df44a2..d41823b 100644 (file)
@@ -7,7 +7,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "llvm/ADT/STLForwardCompat.h"
+#include "llvm/Support/MathExtras.h"
+#include <optional>
 
 using namespace mlir;
 using namespace mlir::gpu;
@@ -23,40 +27,107 @@ static ConstantIntRanges getIndexRange(uint64_t umin, uint64_t umax) {
                                          APInt(width, umax));
 }
 
+namespace {
+enum class LaunchDims : uint32_t { Block = 0, Grid = 1 };
+} // end namespace
+
+/// If the operation `op` is in a context that is annotated with maximum
+/// launch dimensions (a launch op with constant block or grid
+/// sizes or a launch_func op with the appropriate dimensions), return
+/// the bound on the maximum size of the dimension that the op is querying.
+/// IDs will be one less than this bound.
+
+static Value valueByDim(KernelDim3 dims, Dimension dim) {
+  switch (dim) {
+  case Dimension::x:
+    return dims.x;
+  case Dimension::y:
+    return dims.y;
+  case Dimension::z:
+    return dims.z;
+  }
+}
+
+static uint64_t zext(uint32_t arg) { return static_cast<uint64_t>(arg); }
+
+template <typename Op>
+static Optional<uint64_t> getKnownLaunchDim(Op op, LaunchDims type) {
+  Dimension dim = op.getDimension();
+  if (auto launch = op->template getParentOfType<LaunchOp>()) {
+    KernelDim3 bounds;
+    switch (type) {
+    case LaunchDims::Block:
+      bounds = launch.getBlockSizeOperandValues();
+      break;
+    case LaunchDims::Grid:
+      bounds = launch.getGridSizeOperandValues();
+      break;
+    }
+    Value maybeBound = valueByDim(bounds, dim);
+    APInt value;
+    if (matchPattern(maybeBound, m_ConstantInt(&value)))
+      return value.getZExtValue();
+  }
+
+  if (auto func = op->template getParentOfType<GPUFuncOp>()) {
+    switch (type) {
+    case LaunchDims::Block:
+      return llvm::transformOptional(func.getKnownBlockSize(dim), zext);
+    case LaunchDims::Grid:
+      return llvm::transformOptional(func.getKnownGridSize(dim), zext);
+    }
+  }
+  return std::nullopt;
+}
+
 void BlockDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                    SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), getIndexRange(1, kMaxDim));
+  Optional<uint64_t> knownVal = getKnownLaunchDim(*this, LaunchDims::Block);
+  if (knownVal)
+    setResultRange(getResult(), getIndexRange(*knownVal, *knownVal));
+  else
+    setResultRange(getResult(), getIndexRange(1, kMaxDim));
 }
 
 void BlockIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                   SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), getIndexRange(0, kMaxDim - 1));
+  uint64_t max = getKnownLaunchDim(*this, LaunchDims::Grid).value_or(kMaxDim);
+  setResultRange(getResult(), getIndexRange(0, max - 1ULL));
 }
 
 void GridDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                   SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), getIndexRange(1, kMaxDim));
+  Optional<uint64_t> knownVal = getKnownLaunchDim(*this, LaunchDims::Grid);
+  if (knownVal)
+    setResultRange(getResult(), getIndexRange(*knownVal, *knownVal));
+  else
+    setResultRange(getResult(), getIndexRange(1, kMaxDim));
 }
 
 void ThreadIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                    SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), getIndexRange(0, kMaxDim - 1));
+  uint64_t max = getKnownLaunchDim(*this, LaunchDims::Block).value_or(kMaxDim);
+  setResultRange(getResult(), getIndexRange(0, max - 1ULL));
 }
 
 void LaneIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                  SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), getIndexRange(0, kMaxSubgroupSize - 1));
+  setResultRange(getResult(), getIndexRange(0, kMaxSubgroupSize - 1ULL));
 }
 
 void SubgroupIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                      SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), getIndexRange(0, kMaxDim - 1));
+  setResultRange(getResult(), getIndexRange(0, kMaxDim - 1ULL));
 }
 
 void GlobalIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                    SetIntRangeFn setResultRange) {
+  uint64_t blockDimMax =
+      getKnownLaunchDim(*this, LaunchDims::Block).value_or(kMaxDim);
+  uint64_t gridDimMax =
+      getKnownLaunchDim(*this, LaunchDims::Grid).value_or(kMaxDim);
   setResultRange(getResult(),
-                 getIndexRange(0, std::numeric_limits<int64_t>::max()));
+                 getIndexRange(0, (blockDimMax * gridDimMax) - 1ULL));
 }
 
 void NumSubgroupsOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
index fadae79..e8883ea 100644 (file)
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include <limits>
 
 namespace mlir {
 #define GEN_PASS_DEF_GPULAUNCHSINKINDEXCOMPUTATIONS
@@ -147,8 +149,27 @@ LogicalResult mlir::sinkOperationsIntoLaunchOp(
   return success();
 }
 
+/// Return the provided KernelDim3 as an array of i32 constants if possible.
+static DenseI32ArrayAttr maybeConstantDimsAttr(gpu::KernelDim3 dims) {
+  SmallVector<int32_t, 3> constants;
+  MLIRContext *ctx = dims.x.getContext();
+  for (Value v : {dims.x, dims.y, dims.z}) {
+    APInt constValue;
+    if (!matchPattern(v, m_ConstantInt(&constValue)))
+      return nullptr;
+    // In the event someone called for a too-large block or grid dimension,
+    // don't set bounds as it is likely to cause more confusing behavior.
+    if (constValue.ugt(std::numeric_limits<uint32_t>::max()))
+      return nullptr;
+    constants.push_back(
+        constValue.getLimitedValue(std::numeric_limits<uint32_t>::max()));
+  }
+  return DenseI32ArrayAttr::get(ctx, constants);
+}
+
 /// Outline the `gpu.launch` operation body into a kernel function. Replace
 /// `gpu.terminator` operations by `gpu.return` in the generated function.
+/// Set block and grid size bounds if known.
 static gpu::GPUFuncOp outlineKernelFuncImpl(gpu::LaunchOp launchOp,
                                             StringRef kernelFnName,
                                             SetVector<Value> &operands) {
@@ -173,6 +194,19 @@ static gpu::GPUFuncOp outlineKernelFuncImpl(gpu::LaunchOp launchOp,
   auto outlinedFunc = builder.create<gpu::GPUFuncOp>(loc, kernelFnName, type);
   outlinedFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
                         builder.getUnitAttr());
+
+  // If we can infer bounds on the grid and/or block sizes from the arguments
+  // to the launch op, propagate them to the generated kernel. This is safe
+  // because multiple launches with the same body are not deduplicated.
+  if (auto blockBounds =
+          maybeConstantDimsAttr(launchOp.getBlockSizeOperandValues()))
+    outlinedFunc->setAttr(gpu::GPUFuncOp::getKnownBlockSizeAttrName(),
+                          blockBounds);
+  if (auto gridBounds =
+          maybeConstantDimsAttr(launchOp.getGridSizeOperandValues()))
+    outlinedFunc->setAttr(gpu::GPUFuncOp::getKnownGridSizeAttrName(),
+                          gridBounds);
+
   BlockAndValueMapping map;
 
   // Map the arguments corresponding to the launch parameters like blockIdx,
index 2c5af08..02aec9d 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-int-range-inference %s | FileCheck %s
+// RUN: mlir-opt -test-int-range-inference -split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: func @launch_func
 func.func @launch_func(%arg0 : index) {
@@ -41,12 +41,18 @@ func.func @launch_func(%arg0 : index) {
     %thread_id_y0 = test.reflect_bounds %thread_id_y
     %thread_id_z0 = test.reflect_bounds %thread_id_z
 
+    // The launch bounds are not constant, and so this can't infer anything
+    // CHECK: test.reflect_bounds {smax = 4294967294 : index, smin = 0 : index, umax = 4294967294 : index, umin = 0 : index}
+    %thread_id_op = gpu.thread_id y
+    %thread_id_op0 = test.reflect_bounds %thread_id_op
     gpu.terminator
   }
 
   func.return
 }
 
+// -----
+
 // CHECK-LABEL: func @kernel
 module attributes {gpu.container_module} {
   gpu.module @gpu_module {
@@ -100,9 +106,9 @@ module attributes {gpu.container_module} {
       %global_id_y = gpu.global_id y
       %global_id_z = gpu.global_id z
 
-      // CHECK: test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index}
-      // CHECK: test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index}
-      // CHECK: test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index}
+      // CHECK: test.reflect_bounds {smax = 9223372036854775807 : index, smin = -9223372036854775808 : index, umax = -8589934592 : index, umin = 0 : index}
+      // CHECK: test.reflect_bounds {smax = 9223372036854775807 : index, smin = -9223372036854775808 : index, umax = -8589934592 : index, umin = 0 : index}
+      // CHECK: test.reflect_bounds {smax = 9223372036854775807 : index, smin = -9223372036854775808 : index, umax = -8589934592 : index, umin = 0 : index}
       %global_id_x0 = test.reflect_bounds %global_id_x
       %global_id_y0 = test.reflect_bounds %global_id_y
       %global_id_z0 = test.reflect_bounds %global_id_z
@@ -126,3 +132,86 @@ module attributes {gpu.container_module} {
   }
 }
 
+// -----
+
+// CHECK-LABEL: func @annotated_kernel
+module attributes {gpu.container_module} {
+  gpu.module @gpu_module {
+    gpu.func @annotated_kernel() kernel
+      attributes {gpu.known_block_size = array<i32: 8, 12, 16>,
+          gpu.known_grid_size = array<i32: 20, 24, 28>} {
+
+      %grid_dim_x = gpu.grid_dim x
+      %grid_dim_y = gpu.grid_dim y
+      %grid_dim_z = gpu.grid_dim z
+
+      // CHECK: test.reflect_bounds {smax = 20 : index, smin = 20 : index, umax = 20 : index, umin = 20 : index}
+      // CHECK: test.reflect_bounds {smax = 24 : index, smin = 24 : index, umax = 24 : index, umin = 24 : index}
+      // CHECK: test.reflect_bounds {smax = 28 : index, smin = 28 : index, umax = 28 : index, umin = 28 : index}
+      %grid_dim_x0 = test.reflect_bounds %grid_dim_x
+      %grid_dim_y0 = test.reflect_bounds %grid_dim_y
+      %grid_dim_z0 = test.reflect_bounds %grid_dim_z
+
+      %block_id_x = gpu.block_id x
+      %block_id_y = gpu.block_id y
+      %block_id_z = gpu.block_id z
+
+      // CHECK: test.reflect_bounds {smax = 19 : index, smin = 0 : index, umax = 19 : index, umin = 0 : index}
+      // CHECK: test.reflect_bounds {smax = 23 : index, smin = 0 : index, umax = 23 : index, umin = 0 : index}
+      // CHECK: test.reflect_bounds {smax = 27 : index, smin = 0 : index, umax = 27 : index, umin = 0 : index}
+      %block_id_x0 = test.reflect_bounds %block_id_x
+      %block_id_y0 = test.reflect_bounds %block_id_y
+      %block_id_z0 = test.reflect_bounds %block_id_z
+
+      %block_dim_x = gpu.block_dim x
+      %block_dim_y = gpu.block_dim y
+      %block_dim_z = gpu.block_dim z
+
+      // CHECK: test.reflect_bounds {smax = 8 : index, smin = 8 : index, umax = 8 : index, umin = 8 : index}
+      // CHECK: test.reflect_bounds {smax = 12 : index, smin = 12 : index, umax = 12 : index, umin = 12 : index}
+      // CHECK: test.reflect_bounds {smax = 16 : index, smin = 16 : index, umax = 16 : index, umin = 16 : index}
+      %block_dim_x0 = test.reflect_bounds %block_dim_x
+      %block_dim_y0 = test.reflect_bounds %block_dim_y
+      %block_dim_z0 = test.reflect_bounds %block_dim_z
+
+      %thread_id_x = gpu.thread_id x
+      %thread_id_y = gpu.thread_id y
+      %thread_id_z = gpu.thread_id z
+
+      // CHECK: test.reflect_bounds {smax = 7 : index, smin = 0 : index, umax = 7 : index, umin = 0 : index}
+      // CHECK: test.reflect_bounds {smax = 11 : index, smin = 0 : index, umax = 11 : index, umin = 0 : index}
+      // CHECK: test.reflect_bounds {smax = 15 : index, smin = 0 : index, umax = 15 : index, umin = 0 : index}
+      %thread_id_x0 = test.reflect_bounds %thread_id_x
+      %thread_id_y0 = test.reflect_bounds %thread_id_y
+      %thread_id_z0 = test.reflect_bounds %thread_id_z
+
+      %global_id_x = gpu.global_id x
+      %global_id_y = gpu.global_id y
+      %global_id_z = gpu.global_id z
+
+      // CHECK: test.reflect_bounds {smax = 159 : index, smin = 0 : index, umax = 159 : index, umin = 0 : index}
+      // CHECK: test.reflect_bounds {smax = 287 : index, smin = 0 : index, umax = 287 : index, umin = 0 : index}
+      // CHECK: test.reflect_bounds {smax = 447 : index, smin = 0 : index, umax = 447 : index, umin = 0 : index}
+      %global_id_x0 = test.reflect_bounds %global_id_x
+      %global_id_y0 = test.reflect_bounds %global_id_y
+      %global_id_z0 = test.reflect_bounds %global_id_z
+
+      %subgroup_size = gpu.subgroup_size : index
+      %lane_id = gpu.lane_id
+      %num_subgroups = gpu.num_subgroups : index
+      %subgroup_id = gpu.subgroup_id : index
+
+      // CHECK: test.reflect_bounds {smax = 128 : index, smin = 1 : index, umax = 128 : index, umin = 1 : index}
+      // CHECK: test.reflect_bounds {smax = 127 : index, smin = 0 : index, umax = 127 : index, umin = 0 : index}
+      // CHECK: test.reflect_bounds {smax = 4294967295 : index, smin = 1 : index, umax = 4294967295 : index, umin = 1 : index}
+      // CHECK: test.reflect_bounds {smax = 4294967294 : index, smin = 0 : index, umax = 4294967294 : index, umin = 0 : index}
+      %subgroup_size0 = test.reflect_bounds %subgroup_size
+      %lane_id0 = test.reflect_bounds %lane_id
+      %num_subgroups0 = test.reflect_bounds %num_subgroups
+      %subgroup_id0 = test.reflect_bounds %subgroup_id
+
+      gpu.return
+    }
+  }
+}
+
index 7a11acb..76a14d3 100644 (file)
@@ -599,3 +599,25 @@ func.func @alloc() {
    %1 = gpu.alloc(%0) : memref<2x?x?xf32, 1>
    return
 }
+
+// -----
+
+module attributes {gpu.container_module} {
+  gpu.module @kernel {
+    // expected-error@+1 {{'gpu.func' op gpu.known_block_size must be a dense i32 array}}
+    gpu.func @kernel() kernel attributes {gpu.known_block_size = 32 : i32} {
+      gpu.return
+    }
+  }
+}
+
+// -----
+
+module attributes {gpu.container_module} {
+  gpu.module @kernel {
+    // expected-error@+1 {{'gpu.func' op gpu.known_block_size must contain exactly 3 elements}}
+    gpu.func @kernel() kernel attributes {gpu.known_block_size = array<i32: 2, 1>} {
+      gpu.return
+    }
+  }
+}
index 5191dcf..422e0c1 100644 (file)
@@ -41,6 +41,8 @@ func.func @launch() {
 // CHECK-LABEL: gpu.module @launch_kernel
 // CHECK-NEXT: gpu.func @launch_kernel
 // CHECK-SAME: (%[[KERNEL_ARG0:.*]]: f32, %[[KERNEL_ARG1:.*]]: memref<?xf32, 1>)
+// CHECK-SAME: gpu.known_block_size = array<i32: 20, 24, 28>
+// CHECK-SAME: gpu.known_grid_size = array<i32: 8, 12, 16>
 // CHECK-NEXT: %[[BID:.*]] = gpu.block_id x
 // CHECK-NEXT: = gpu.block_id y
 // CHECK-NEXT: = gpu.block_id z
@@ -291,3 +293,20 @@ func.func @recursive_device_function() {
 // CHECK:   func @device_function()
 // CHECK:   func @recursive_device_function()
 // CHECK-NOT:   func @device_function
+
+// -----
+
+// CHECK-LABEL: @non_constant_launches
+func.func @non_constant_launches(%arg0 : index) {
+  // CHECK-NOT: gpu.known_block_size
+  // CHECK-NOT: gpu.known_grid_size
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %arg0, %grid_y = %arg0,
+                                       %grid_z = %arg0)
+             threads(%tx, %ty, %tz) in (%block_x = %arg0, %block_y = %arg0,
+                                        %block_z = %arg0) {
+    gpu.terminator
+  }
+  return
+}
+
+// CHECK-DL-LABEL: gpu.module @non_constant_launches_kernel attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<index, 32 : i32>>}