clover/spirv: support CL_KERNEL_COMPILE_WORK_GROUP_SIZE
authorKarol Herbst <kherbst@redhat.com>
Sun, 23 Aug 2020 14:46:05 +0000 (16:46 +0200)
committerMarge Bot <eric+marge@anholt.net>
Wed, 7 Oct 2020 13:18:22 +0000 (13:18 +0000)
Reviewed-by: Serge Martin <edb@sigluy.net>
Reviewed-by: Francisco Jerez <currojerez@riseup.net>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/4974>

src/gallium/frontends/clover/nir/invocation.cpp
src/gallium/frontends/clover/spirv/invocation.cpp

index bb27895..f3dd26f 100644 (file)
@@ -288,7 +288,10 @@ module clover::nir::spirv_to_nir(const module &mod, const device &dev,
          throw build_error();
       }
 
-      nir->info.cs.local_size_variable = true;
+      nir->info.cs.local_size_variable = sym.reqd_work_group_size[0] == 0;
+      nir->info.cs.local_size[0] = sym.reqd_work_group_size[0];
+      nir->info.cs.local_size[1] = sym.reqd_work_group_size[1];
+      nir->info.cs.local_size[2] = sym.reqd_work_group_size[2];
       nir_validate_shader(nir, "clover");
 
       // Inline all functions first.
@@ -391,7 +394,7 @@ module clover::nir::spirv_to_nir(const module &mod, const device &dev,
       text.data.insert(text.data.end(), blob.data, blob.data + blob.size);
 
       m.syms.emplace_back(sym.name, std::string(),
-                          std::vector<size_t>(), section_id, 0, args);
+                          sym.reqd_work_group_size, section_id, 0, args);
       m.secs.push_back(text);
       section_id++;
    }
index 0ac6a99..33e7229 100644 (file)
@@ -135,9 +135,11 @@ namespace {
       std::string kernel_name;
       size_t kernel_nb = 0u;
       std::vector<module::argument> args;
+      std::vector<size_t> req_local_size;
 
       module m;
 
+      std::unordered_map<SpvId, std::vector<size_t> > req_local_sizes;
       std::unordered_map<SpvId, std::string> kernels;
       std::unordered_map<SpvId, module::argument> types;
       std::unordered_map<SpvId, SpvId> pointer_types;
@@ -185,6 +187,19 @@ namespace {
                                source.data() + (i + 3u) * sizeof(uint32_t));
             break;
 
+         case SpvOpExecutionMode:
+            switch (get<SpvExecutionMode>(inst, 2)) {
+            case SpvExecutionModeLocalSize:
+               req_local_sizes[get<SpvId>(inst, 1)] = {
+                  get<uint32_t>(inst, 3),
+                  get<uint32_t>(inst, 4),
+                  get<uint32_t>(inst, 5)
+               };
+               break;
+            default:
+               break;
+            }
+
          case SpvOpDecorate: {
             const auto id = get<SpvId>(inst, 1);
             const auto decoration = get<SpvDecoration>(inst, 2);
@@ -367,9 +382,17 @@ namespace {
          }
 
          case SpvOpFunction: {
-            const auto kernels_iter = kernels.find(get<SpvId>(inst, 2));
+            auto id = get<SpvId>(inst, 2);
+            const auto kernels_iter = kernels.find(id);
             if (kernels_iter != kernels.end())
                kernel_name = kernels_iter->second;
+
+            const auto req_local_size_iter = req_local_sizes.find(id);
+            if (req_local_size_iter != req_local_sizes.end())
+               req_local_size =  (*req_local_size_iter).second;
+            else
+               req_local_size = { 0, 0, 0 };
+
             break;
          }
 
@@ -428,7 +451,7 @@ namespace {
                args[i].info.type_name = param_type_names[kernel_name][i];
 
             m.syms.emplace_back(kernel_name, std::string(),
-                                std::vector<size_t>(), 0, kernel_nb, args);
+                                req_local_size, 0, kernel_nb, args);
             ++kernel_nb;
             kernel_name.clear();
             args.clear();