[mlir][spirv] Respect client API requirements for 64-bit index
author Lei Zhang <antiagainst@google.com>
Mon, 27 Feb 2023 06:15:18 +0000 (06:15 +0000)
committer Lei Zhang <antiagainst@google.com>
Mon, 27 Feb 2023 06:16:50 +0000 (06:16 +0000)
Vulkan requires GPU processor ID/count builtin variables to be
32-bit scalars or vectors in all cases. Similarly, there
are special requirements for OpenCL. We need to make sure those
rules are respected when the conversion uses 64-bit integers for index.

Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D144819

mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
mlir/test/Conversion/GPUToSPIRV/builtins.mlir

index 51b753a..3775189 100644 (file)
@@ -144,14 +144,31 @@ LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
     SourceOp op, typename SourceOp::Adaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
-  auto indexType = typeConverter->getIndexType();
-
-  // SPIR-V invocation builtin variables are a vector of type <3xi32>
-  auto spirvBuiltin =
-      spirv::getBuiltinVariableValue(op, builtin, indexType, rewriter);
-  rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
-      op, indexType, spirvBuiltin,
+  Type indexType = typeConverter->getIndexType();
+
+  // For Vulkan, these SPIR-V builtin variables are required to be a vector of
+  // type <3xi32> by the spec:
+  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/NumWorkgroups.html
+  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/WorkgroupId.html
+  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/WorkgroupSize.html
+  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/LocalInvocationId.html
+  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/LocalInvocationId.html
+  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/GlobalInvocationId.html
+  //
+  // For OpenCL, it depends on the Physical32/Physical64 addressing model:
+  // https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Env.html#_built_in_variables
+  bool forShader =
+      typeConverter->getTargetEnv().allows(spirv::Capability::Shader);
+  Type builtinType = forShader ? rewriter.getIntegerType(32) : indexType;
+
+  Value vector =
+      spirv::getBuiltinVariableValue(op, builtin, builtinType, rewriter);
+  Value dim = rewriter.create<spirv::CompositeExtractOp>(
+      op.getLoc(), builtinType, vector,
       rewriter.getI32ArrayAttr({static_cast<int32_t>(op.getDimension())}));
+  if (forShader && builtinType != indexType)
+    dim = rewriter.create<spirv::UConvertOp>(op.getLoc(), indexType, dim);
+  rewriter.replaceOp(op, dim);
   return success();
 }
 
@@ -161,11 +178,23 @@ SingleDimLaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
     SourceOp op, typename SourceOp::Adaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
-  auto indexType = typeConverter->getIndexType();
-
-  auto spirvBuiltin =
-      spirv::getBuiltinVariableValue(op, builtin, indexType, rewriter);
-  rewriter.replaceOp(op, spirvBuiltin);
+  Type indexType = typeConverter->getIndexType();
+  Type i32Type = rewriter.getIntegerType(32);
+
+  // For Vulkan, these SPIR-V builtin variables are required to be a scalar of
+  // type i32 by the spec:
+  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/NumSubgroups.html
+  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/SubgroupId.html
+  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/SubgroupSize.html
+  //
+  // For OpenCL, they are also required to be i32:
+  // https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Env.html#_built_in_variables
+  Value builtinValue =
+      spirv::getBuiltinVariableValue(op, builtin, i32Type, rewriter);
+  if (i32Type != indexType)
+    builtinValue = rewriter.create<spirv::UConvertOp>(op.getLoc(), indexType,
+                                                      builtinValue);
+  rewriter.replaceOp(op, builtinValue);
   return success();
 }
 
index df2efbe..29ae5f2 100644 (file)
@@ -1,4 +1,5 @@
 // RUN: mlir-opt -split-input-file -convert-gpu-to-spirv="use-64bit-index=false" %s -o - | FileCheck %s --check-prefix=INDEX32
+// RUN: mlir-opt -split-input-file -convert-gpu-to-spirv="use-64bit-index=true" %s -o - | FileCheck %s --check-prefix=INDEX64
 
 module attributes {
   gpu.container_module,
@@ -13,12 +14,15 @@ module attributes {
 
   // INDEX32-LABEL:  spirv.module @{{.*}} Logical GLSL450
   // INDEX32: spirv.GlobalVariable [[WORKGROUPID:@.*]] built_in("WorkgroupId") : !spirv.ptr<vector<3xi32>, Input>
+  // INDEX64-LABEL:  spirv.module @{{.*}} Logical GLSL450
+  // INDEX64: spirv.GlobalVariable [[WORKGROUPID:@.*]] built_in("WorkgroupId") : !spirv.ptr<vector<3xi32>, Input>
   gpu.module @kernels {
     gpu.func @builtin_workgroup_id_x() kernel
       attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
       // INDEX32: [[ADDRESS:%.*]] = spirv.mlir.addressof [[WORKGROUPID]]
       // INDEX32-NEXT: [[VEC:%.*]] = spirv.Load "Input" [[ADDRESS]]
       // INDEX32-NEXT: {{%.*}} = spirv.CompositeExtract [[VEC]]{{\[}}0 : i32{{\]}}
+      // INDEX64: spirv.UConvert %{{.+}} : i32 to i64
       %0 = gpu.block_id x
       gpu.return
     }
@@ -422,11 +426,14 @@ module attributes {
 } {
   // INDEX32-LABEL:  spirv.module @{{.*}} Logical GLSL450
   // INDEX32: spirv.GlobalVariable [[SUBGROUPSIZE:@.*]] built_in("SubgroupSize") : !spirv.ptr<i32, Input>
+  // INDEX64-LABEL:  spirv.module @{{.*}} Logical GLSL450
+  // INDEX64: spirv.GlobalVariable [[SUBGROUPSIZE:@.*]] built_in("SubgroupSize") : !spirv.ptr<i32, Input>
   gpu.module @kernels {
     gpu.func @builtin_subgroup_size() kernel
       attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
       // INDEX32: [[ADDRESS:%.*]] = spirv.mlir.addressof [[SUBGROUPSIZE]]
       // INDEX32-NEXT: {{%.*}} = spirv.Load "Input" [[ADDRESS]]
+      // INDEX64: spirv.UConvert %{{.+}} : i32 to i64
       %0 = gpu.subgroup_size : index
       gpu.return
     }