From 3dde5c231e15480466e7aa08ebbd1d7e5a2dd9e4 Mon Sep 17 00:00:00 2001 From: Antonio Gomes Date: Tue, 4 Jul 2023 21:33:41 -0300 Subject: [PATCH] rusticl: Drop some Kernel data and have a NirKernelBuild ref instead Part-of: --- src/gallium/frontends/rusticl/api/kernel.rs | 10 +++--- src/gallium/frontends/rusticl/core/kernel.rs | 51 ++++++++++++--------------- src/gallium/frontends/rusticl/core/program.rs | 8 ++--- 3 files changed, 32 insertions(+), 37 deletions(-) diff --git a/src/gallium/frontends/rusticl/api/kernel.rs b/src/gallium/frontends/rusticl/api/kernel.rs index 8f09c3b..c135f3a 100644 --- a/src/gallium/frontends/rusticl/api/kernel.rs +++ b/src/gallium/frontends/rusticl/api/kernel.rs @@ -21,13 +21,13 @@ impl CLInfo for cl_kernel { fn query(&self, q: cl_kernel_info, _: &[u8]) -> CLResult>> { let kernel = self.get_ref()?; Ok(match q { - CL_KERNEL_ATTRIBUTES => cl_prop::<&str>(&kernel.attributes_string), + CL_KERNEL_ATTRIBUTES => cl_prop::<&str>(&kernel.build.attributes_string), CL_KERNEL_CONTEXT => { let ptr = Arc::as_ptr(&kernel.prog.context); cl_prop::(cl_context::from_ptr(ptr)) } CL_KERNEL_FUNCTION_NAME => cl_prop::<&str>(&kernel.name), - CL_KERNEL_NUM_ARGS => cl_prop::(kernel.args.len() as cl_uint), + CL_KERNEL_NUM_ARGS => cl_prop::(kernel.build.args.len() as cl_uint), CL_KERNEL_PROGRAM => { let ptr = Arc::as_ptr(&kernel.prog); cl_prop::(cl_program::from_ptr(ptr)) @@ -45,7 +45,7 @@ impl CLInfoObj for cl_kernel { let kernel = self.get_ref()?; // CL_INVALID_ARG_INDEX if arg_index is not a valid argument index. - if idx as usize >= kernel.args.len() { + if idx as usize >= kernel.build.args.len() { return Err(CL_INVALID_ARG_INDEX); } @@ -229,7 +229,7 @@ fn set_kernel_arg( let k = kernel.get_arc()?; // CL_INVALID_ARG_INDEX if arg_index is not a valid argument index. - if let Some(arg) = k.args.get(arg_index as usize) { + if let Some(arg) = k.build.args.get(arg_index as usize) { // CL_INVALID_ARG_SIZE if arg_size does not match the size of the data type for an argument // that is not a memory object or if the argument is a memory object and // arg_size != sizeof(cl_mem) or if arg_size is zero and the argument is declared with the @@ -329,7 +329,7 @@ fn set_kernel_arg_svm_pointer( return Err(CL_INVALID_OPERATION); } - if let Some(arg) = kernel.args.get(arg_index) { + if let Some(arg) = kernel.build.args.get(arg_index) { if !matches!( arg.kind, KernelArgType::MemConstant | KernelArgType::MemGlobal diff --git a/src/gallium/frontends/rusticl/core/kernel.rs b/src/gallium/frontends/rusticl/core/kernel.rs index ee2e4c2..9713bdf 100644 --- a/src/gallium/frontends/rusticl/core/kernel.rs +++ b/src/gallium/frontends/rusticl/core/kernel.rs @@ -273,15 +273,15 @@ impl Drop for KernelDevState { } impl KernelDevState { - fn new(nirs: HashMap, Arc>) -> Arc { + fn new(nirs: &HashMap, Arc>) -> Arc { let states = nirs - .into_iter() + .iter() .map(|(dev, nir)| { let mut cso = dev .helper_ctx() - .create_compute_state(&nir, nir.shared_size()); + .create_compute_state(nir, nir.shared_size()); let info = dev.helper_ctx().compute_state_info(cso); - let cb = Self::create_nir_constant_buffer(&dev, &nir); + let cb = Self::create_nir_constant_buffer(dev, nir); // if we can't share the cso between threads, destroy it now. if !dev.shareable_shaders() { @@ -290,9 +290,9 @@ impl KernelDevState { }; ( - dev, + dev.clone(), KernelDevStateInner { - nir: nir, + nir: nir.clone(), constant_buffer: cb, cso: cso, info: info, @@ -333,11 +333,9 @@ pub struct Kernel { pub base: CLObjectBase, pub prog: Arc, pub name: String, - pub args: Vec, pub values: Vec>>, pub work_group_size: [usize; 3], - pub attributes_string: String, - internal_args: Vec, + pub build: Arc, dev_state: Arc, } @@ -798,17 +796,18 @@ fn extract<'a, const S: usize>(buf: &'a mut &[u8]) -> &'a [u8; S] { impl Kernel { pub fn new(name: String, prog: Arc) -> Arc { let nir_kernel_build = prog.get_nir_kernel_build(&name); - let mut nirs = nir_kernel_build.nirs; - let args = nir_kernel_build.args; - let internal_args = nir_kernel_build.internal_args; - let attributes_string = nir_kernel_build.attributes_string; + let nirs = &nir_kernel_build.nirs; - let nir = nirs.values_mut().next().unwrap(); + let nir = nirs.values().next().unwrap(); let wgs = nir.workgroup_size(); let work_group_size = [wgs[0] as usize, wgs[1] as usize, wgs[2] as usize]; // can't use vec!... - let values = args.iter().map(|_| RefCell::new(None)).collect(); + let values = nir_kernel_build + .args + .iter() + .map(|_| RefCell::new(None)) + .collect(); // increase ref prog.kernel_count.fetch_add(1, Ordering::Relaxed); @@ -817,12 +816,10 @@ impl Kernel { base: CLObjectBase::new(), prog: prog, name: name, - args: args, work_group_size: work_group_size, - attributes_string: attributes_string, values: values, - internal_args: internal_args, dev_state: KernelDevState::new(nirs), + build: nir_kernel_build, }) } @@ -899,7 +896,7 @@ impl Kernel { self.optimize_local_size(&q.device, &mut grid, &mut block); - for (arg, val) in self.args.iter().zip(&self.values) { + for (arg, val) in self.build.args.iter().zip(&self.values) { if arg.dead { continue; } @@ -986,7 +983,7 @@ impl Kernel { variable_local_size -= dev_state.nir.shared_size() as u64; let mut printf_buf = None; - for arg in &self.internal_args { + for arg in &self.build.internal_args { if arg.offset > input.len() { input.resize(arg.offset, 0); } @@ -1134,7 +1131,7 @@ impl Kernel { } pub fn access_qualifier(&self, idx: cl_uint) -> cl_kernel_arg_access_qualifier { - let aq = self.args[idx as usize].spirv.access_qualifier; + let aq = self.build.args[idx as usize].spirv.access_qualifier; if aq == clc_kernel_arg_access_qualifier::CLC_KERNEL_ARG_ACCESS_READ @@ -1151,7 +1148,7 @@ impl Kernel { } pub fn address_qualifier(&self, idx: cl_uint) -> cl_kernel_arg_address_qualifier { - match self.args[idx as usize].spirv.address_qualifier { + match self.build.args[idx as usize].spirv.address_qualifier { clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_PRIVATE => { CL_KERNEL_ARG_ADDRESS_PRIVATE } @@ -1168,7 +1165,7 @@ impl Kernel { } pub fn type_qualifier(&self, idx: cl_uint) -> cl_kernel_arg_type_qualifier { - let tq = self.args[idx as usize].spirv.type_qualifier; + let tq = self.build.args[idx as usize].spirv.type_qualifier; let zero = clc_kernel_arg_type_qualifier(0); let mut res = CL_KERNEL_ARG_TYPE_NONE; @@ -1188,11 +1185,11 @@ impl Kernel { } pub fn arg_name(&self, idx: cl_uint) -> &String { - &self.args[idx as usize].spirv.name + &self.build.args[idx as usize].spirv.name } pub fn arg_type_name(&self, idx: cl_uint) -> &String { - &self.args[idx as usize].spirv.type_name + &self.build.args[idx as usize].spirv.type_name } pub fn priv_mem_size(&self, dev: &Arc) -> cl_ulong { @@ -1223,11 +1220,9 @@ impl Clone for Kernel { base: CLObjectBase::new(), prog: self.prog.clone(), name: self.name.clone(), - args: self.args.clone(), values: self.values.clone(), work_group_size: self.work_group_size, - attributes_string: self.attributes_string.clone(), - internal_args: self.internal_args.clone(), + build: self.build.clone(), dev_state: self.dev_state.clone(), } } diff --git a/src/gallium/frontends/rusticl/core/program.rs b/src/gallium/frontends/rusticl/core/program.rs index 72840bd..9574f0e 100644 --- a/src/gallium/frontends/rusticl/core/program.rs +++ b/src/gallium/frontends/rusticl/core/program.rs @@ -82,7 +82,7 @@ pub(super) struct ProgramBuild { builds: HashMap, ProgramDevBuild>, spec_constants: HashMap, kernels: Vec, - kernel_builds: HashMap, + kernel_builds: HashMap>, } impl ProgramBuild { @@ -148,12 +148,12 @@ impl ProgramBuild { self.kernel_builds.insert( kernel_name.clone(), - NirKernelBuild { + Arc::new(NirKernelBuild { nirs: nirs, args: args, internal_args: internal_args, attributes_string: attributes_string, - }, + }), ); } } @@ -418,7 +418,7 @@ impl Program { self.build.lock().unwrap() } - pub fn get_nir_kernel_build(&self, name: &str) -> NirKernelBuild { + pub fn get_nir_kernel_build(&self, name: &str) -> Arc { let info = self.build_info(); info.kernel_builds.get(name).unwrap().clone() } -- 2.7.4