dzn: Hook up subgroup size to compute shader compilation
authorJesse Natalie <jenatali@microsoft.com>
Tue, 2 May 2023 23:18:14 +0000 (16:18 -0700)
committerJesse Natalie <jenatali@microsoft.com>
Tue, 2 May 2023 23:39:10 +0000 (16:39 -0700)
Previously this was only in the graphics path... where it does nothing,
since D3D only supports wave size control for compute. Whoops.

Fixes: db083070 ("dzn: Implement subgroup size control extension")
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/22810>

src/microsoft/vulkan/dzn_pipeline.c

index f0ac0b6..a03c0d2 100644 (file)
@@ -2417,7 +2417,7 @@ dzn_compute_pipeline_compile_shader(struct dzn_device *device,
                                     D3D12_SHADER_BYTECODE *shader,
                                     const VkComputePipelineCreateInfo *info)
 {
-   uint8_t spirv_hash[SHA1_DIGEST_LENGTH], pipeline_hash[SHA1_DIGEST_LENGTH];
+   uint8_t spirv_hash[SHA1_DIGEST_LENGTH], pipeline_hash[SHA1_DIGEST_LENGTH], nir_hash[SHA1_DIGEST_LENGTH];
    VkResult ret = VK_SUCCESS;
    nir_shader *nir = NULL;
 
@@ -2439,8 +2439,24 @@ dzn_compute_pipeline_compile_shader(struct dzn_device *device,
          goto out;
    }
 
-   struct dzn_nir_options options = { .nir_opts = dxil_get_nir_compiler_options() };
-   ret = dzn_pipeline_get_nir_shader(device, layout, cache, spirv_hash,
+   const VkPipelineShaderStageRequiredSubgroupSizeCreateInfo *subgroup_size =
+      (const VkPipelineShaderStageRequiredSubgroupSizeCreateInfo *)
+      vk_find_struct_const(info->stage.pNext, PIPELINE_SHADER_STAGE_REQUIRED_SUBGROUP_SIZE_CREATE_INFO);
+   enum gl_subgroup_size subgroup_enum = subgroup_size && subgroup_size->requiredSubgroupSize >= 8 ?
+      subgroup_size->requiredSubgroupSize : SUBGROUP_SIZE_FULL_SUBGROUPS;
+
+   if (cache) {
+      struct mesa_sha1 nir_hash_ctx;
+      _mesa_sha1_init(&nir_hash_ctx);
+      _mesa_sha1_update(&nir_hash_ctx, &subgroup_enum, sizeof(subgroup_enum));
+      _mesa_sha1_update(&nir_hash_ctx, spirv_hash, sizeof(spirv_hash));
+      _mesa_sha1_final(&nir_hash_ctx, nir_hash);
+   }
+   struct dzn_nir_options options = {
+      .nir_opts = dxil_get_nir_compiler_options(),
+      .subgroup_size = subgroup_enum,
+   };
+   ret = dzn_pipeline_get_nir_shader(device, layout, cache, nir_hash,
                                      &info->stage, MESA_SHADER_COMPUTE,
                                      &options, &nir);
    if (ret != VK_SUCCESS)