rusticl/kernel: prepare for nir caching
authorKarol Herbst <kherbst@redhat.com>
Sun, 17 Apr 2022 12:52:06 +0000 (14:52 +0200)
committerMarge Bot <emma+marge@anholt.net>
Mon, 12 Sep 2022 05:58:13 +0000 (05:58 +0000)
Signed-off-by: Karol Herbst <kherbst@redhat.com>
Acked-by: Alyssa Rosenzweig <alyssa.rosenzweig@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/15439>

src/gallium/frontends/rusticl/api/kernel.rs
src/gallium/frontends/rusticl/core/kernel.rs
src/gallium/frontends/rusticl/core/program.rs

index 5240d0b..7258601 100644 (file)
@@ -164,7 +164,6 @@ pub fn create_kernel(
     // kernel_name such as the number of arguments, the argument types are not the same for all
     // devices for which the program executable has been built.
     let devs = get_devices_with_valid_build(&p)?;
-    let nirs = p.nirs(&name);
     let kernel_args: HashSet<_> = devs.iter().map(|d| p.args(d, &name)).collect();
     if kernel_args.len() != 1 {
         return Err(CL_INVALID_KERNEL_DEFINITION);
@@ -173,7 +172,6 @@ pub fn create_kernel(
     Ok(cl_kernel::from_arc(Kernel::new(
         name,
         p,
-        nirs,
         kernel_args.into_iter().next().unwrap(),
     )))
 }
@@ -206,13 +204,11 @@ pub fn create_kernels_in_program(
         if !kernels.is_null() {
             // we just assume the client isn't stupid
             unsafe {
-                let nirs = p.nirs(&name);
                 kernels
                     .add(num_kernels as usize)
                     .write(cl_kernel::from_arc(Kernel::new(
                         name,
                         p.clone(),
-                        nirs,
                         kernel_args.into_iter().next().unwrap(),
                     )));
             }
index 8959b54..cecc0fd 100644 (file)
@@ -32,7 +32,7 @@ pub enum KernelArgValue {
     LocalMem(usize),
 }
 
-#[derive(PartialEq, Eq, Clone)]
+#[derive(Hash, PartialEq, Eq, Clone)]
 pub enum KernelArgType {
     Constant, // for anything passed by value
     Image,
@@ -53,7 +53,7 @@ pub enum InternalKernelArgType {
     OrderArray,
 }
 
-#[derive(Clone)]
+#[derive(Hash, PartialEq, Eq, Clone)]
 pub struct KernelArg {
     spirv: spirv::SPIRVKernelArg,
     pub kind: KernelArgType,
@@ -70,7 +70,7 @@ pub struct InternalKernelArg {
 }
 
 impl KernelArg {
-    fn from_spirv_nir(spirv: Vec<spirv::SPIRVKernelArg>, nir: &mut NirShader) -> Vec<Self> {
+    fn from_spirv_nir(spirv: &[spirv::SPIRVKernelArg], nir: &mut NirShader) -> Vec<Self> {
         let nir_arg_map: HashMap<_, _> = nir
             .variables_with_mode(
                 nir_variable_mode::nir_var_uniform | nir_variable_mode::nir_var_image,
@@ -79,7 +79,7 @@ impl KernelArg {
             .collect();
         let mut res = Vec::new();
 
-        for (i, s) in spirv.into_iter().enumerate() {
+        for (i, s) in spirv.iter().enumerate() {
             let nir = nir_arg_map.get(&(i as i32)).unwrap();
             let kind = match s.address_qualifier {
                 clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_PRIVATE => {
@@ -109,7 +109,7 @@ impl KernelArg {
             };
 
             res.push(Self {
-                spirv: s,
+                spirv: s.clone(),
                 size: unsafe { glsl_get_cl_size(nir.type_) } as usize,
                 // we'll update it later in the 2nd pass
                 kind: kind,
@@ -454,6 +454,43 @@ fn lower_and_optimize_nir_late(
     res
 }
 
+fn convert_spirv_to_nir(
+    p: &Program,
+    name: &str,
+    args: Vec<spirv::SPIRVKernelArg>,
+) -> (
+    HashMap<Arc<Device>, NirShader>,
+    Vec<KernelArg>,
+    Vec<InternalKernelArg>,
+) {
+    let mut nirs = HashMap::new();
+    let mut args_set = HashSet::new();
+    let mut internal_args_set = HashSet::new();
+
+    // TODO: we could run this in parallel?
+    for d in p.devs_with_build() {
+        let mut nir = p.to_nir(name, d);
+
+        lower_and_optimize_nir_pre_inputs(d, &mut nir, &d.lib_clc);
+
+        let mut args = KernelArg::from_spirv_nir(&args, &mut nir);
+        let mut internal_args = lower_and_optimize_nir_late(d, &mut nir, args.len());
+        KernelArg::assign_locations(&mut args, &mut internal_args, &mut nir);
+
+        args_set.insert(args);
+        internal_args_set.insert(internal_args);
+        nirs.insert(d.clone(), nir);
+    }
+
+    // we want the same (internal) args for every compiled kernel, for now
+    assert!(args_set.len() == 1);
+    assert!(internal_args_set.len() == 1);
+    let args = args_set.into_iter().next().unwrap();
+    let internal_args = internal_args_set.into_iter().next().unwrap();
+
+    (nirs, args, internal_args)
+}
+
 fn extract<'a, const S: usize>(buf: &'a mut &[u8]) -> &'a [u8; S] {
     let val;
     (val, *buf) = (*buf).split_at(S);
@@ -463,30 +500,15 @@ fn extract<'a, const S: usize>(buf: &'a mut &[u8]) -> &'a [u8; S] {
 }
 
 impl Kernel {
-    pub fn new(
-        name: String,
-        prog: Arc<Program>,
-        mut nirs: HashMap<Arc<Device>, NirShader>,
-        args: Vec<spirv::SPIRVKernelArg>,
-    ) -> Arc<Kernel> {
-        nirs.iter_mut()
-            .for_each(|(d, n)| lower_and_optimize_nir_pre_inputs(d, n, &d.lib_clc));
+    pub fn new(name: String, prog: Arc<Program>, args: Vec<spirv::SPIRVKernelArg>) -> Arc<Kernel> {
+        let (mut nirs, args, internal_args) = convert_spirv_to_nir(&prog, &name, args);
+
         let nir = nirs.values_mut().next().unwrap();
         let wgs = nir.workgroup_size();
         let work_group_size = [wgs[0] as usize, wgs[1] as usize, wgs[2] as usize];
-        let mut args = KernelArg::from_spirv_nir(args, nir);
+
         // can't use vec!...
         let values = args.iter().map(|_| RefCell::new(None)).collect();
-        let internal_args: HashSet<_> = nirs
-            .iter_mut()
-            .map(|(d, n)| lower_and_optimize_nir_late(d, n, args.len()))
-            .collect();
-        // we want the same internal args for every compiled kernel, for now
-        assert!(internal_args.len() == 1);
-        let mut internal_args = internal_args.into_iter().next().unwrap();
-
-        nirs.values_mut()
-            .for_each(|n| KernelArg::assign_locations(&mut args, &mut internal_args, n));
 
         Arc::new(Self {
             base: CLObjectBase::new(),
index 1cb1d9a..6c36b75 100644 (file)
@@ -436,27 +436,30 @@ impl Program {
         })
     }
 
-    pub fn nirs(&self, kernel: &str) -> HashMap<Arc<Device>, NirShader> {
+    pub fn devs_with_build(&self) -> Vec<&Arc<Device>> {
         let mut lock = self.build_info();
-        let mut res = HashMap::new();
-        for d in &self.devs {
-            let info = Self::dev_build_info(&mut lock, d);
-            if info.status != CL_BUILD_SUCCESS as cl_build_status {
-                continue;
-            }
-            let nir = info
-                .spirv
-                .as_ref()
-                .unwrap()
-                .to_nir(
-                    kernel,
-                    d.screen
-                        .nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE),
-                    &d.lib_clc,
-                )
-                .unwrap();
-            res.insert(d.clone(), nir);
-        }
-        res
+        self.devs
+            .iter()
+            .filter(|d| {
+                let info = Self::dev_build_info(&mut lock, d);
+                info.status == CL_BUILD_SUCCESS as cl_build_status
+            })
+            .collect()
+    }
+
+    pub fn to_nir(&self, kernel: &str, d: &Arc<Device>) -> NirShader {
+        let mut lock = self.build_info();
+        let info = Self::dev_build_info(&mut lock, d);
+        assert_eq!(info.status, CL_BUILD_SUCCESS as cl_build_status);
+        info.spirv
+            .as_ref()
+            .unwrap()
+            .to_nir(
+                kernel,
+                d.screen
+                    .nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE),
+                &d.lib_clc,
+            )
+            .unwrap()
     }
 }