throw build_error();
}
- nir->info.cs.local_size_variable = true;
+ nir->info.cs.local_size_variable = sym.reqd_work_group_size[0] == 0;
+ nir->info.cs.local_size[0] = sym.reqd_work_group_size[0];
+ nir->info.cs.local_size[1] = sym.reqd_work_group_size[1];
+ nir->info.cs.local_size[2] = sym.reqd_work_group_size[2];
nir_validate_shader(nir, "clover");
// Inline all functions first.
text.data.insert(text.data.end(), blob.data, blob.data + blob.size);
m.syms.emplace_back(sym.name, std::string(),
- std::vector<size_t>(), section_id, 0, args);
+ sym.reqd_work_group_size, section_id, 0, args);
m.secs.push_back(text);
section_id++;
}
std::string kernel_name;
size_t kernel_nb = 0u;
std::vector<module::argument> args;
+ std::vector<size_t> req_local_size;
module m;
+ std::unordered_map<SpvId, std::vector<size_t> > req_local_sizes;
std::unordered_map<SpvId, std::string> kernels;
std::unordered_map<SpvId, module::argument> types;
std::unordered_map<SpvId, SpvId> pointer_types;
source.data() + (i + 3u) * sizeof(uint32_t));
break;
+ case SpvOpExecutionMode:
+ switch (get<SpvExecutionMode>(inst, 2)) {
+ case SpvExecutionModeLocalSize:
+ req_local_sizes[get<SpvId>(inst, 1)] = {
+ get<uint32_t>(inst, 3),
+ get<uint32_t>(inst, 4),
+ get<uint32_t>(inst, 5)
+ };
+ break;
+ default:
+ break;
+ }
+
case SpvOpDecorate: {
const auto id = get<SpvId>(inst, 1);
const auto decoration = get<SpvDecoration>(inst, 2);
}
case SpvOpFunction: {
- const auto kernels_iter = kernels.find(get<SpvId>(inst, 2));
+ auto id = get<SpvId>(inst, 2);
+ const auto kernels_iter = kernels.find(id);
if (kernels_iter != kernels.end())
kernel_name = kernels_iter->second;
+
+ const auto req_local_size_iter = req_local_sizes.find(id);
+ if (req_local_size_iter != req_local_sizes.end())
+ req_local_size = (*req_local_size_iter).second;
+ else
+ req_local_size = { 0, 0, 0 };
+
break;
}
args[i].info.type_name = param_type_names[kernel_name][i];
m.syms.emplace_back(kernel_name, std::string(),
- std::vector<size_t>(), 0, kernel_nb, args);
+ req_local_size, 0, kernel_nb, args);
++kernel_nb;
kernel_name.clear();
args.clear();