rusticl: Move nir compilation to Program
authorAntonio Gomes <antoniospg100@gmail.com>
Mon, 10 Apr 2023 23:57:07 +0000 (20:57 -0300)
committerMarge Bot <emma+marge@anholt.net>
Wed, 26 Apr 2023 20:49:42 +0000 (20:49 +0000)
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/22434>

src/gallium/frontends/rusticl/api/kernel.rs
src/gallium/frontends/rusticl/api/program.rs
src/gallium/frontends/rusticl/core/kernel.rs
src/gallium/frontends/rusticl/core/program.rs

index 523a6f2..8a7e180 100644 (file)
@@ -170,11 +170,7 @@ pub fn create_kernel(
         return Err(CL_INVALID_KERNEL_DEFINITION);
     }
 
-    Ok(cl_kernel::from_arc(Kernel::new(
-        name,
-        p,
-        kernel_args.into_iter().next().unwrap(),
-    )))
+    Ok(cl_kernel::from_arc(Kernel::new(name, p)))
 }
 
 pub fn create_kernels_in_program(
@@ -207,11 +203,7 @@ pub fn create_kernels_in_program(
             unsafe {
                 kernels
                     .add(num_kernels as usize)
-                    .write(cl_kernel::from_arc(Kernel::new(
-                        name,
-                        p.clone(),
-                        kernel_args.into_iter().next().unwrap(),
-                    )));
+                    .write(cl_kernel::from_arc(Kernel::new(name, p.clone())));
             }
         }
         num_kernels += 1;
index d987a9e..a4a4b4e 100644 (file)
@@ -235,7 +235,10 @@ pub fn create_program_with_binary(
         return Err(err);
     }
 
-    Ok(cl_program::from_arc(Program::from_bins(c, devs, &bins)))
+    let prog = Program::from_bins(c, devs, &bins);
+    prog.build_nirs();
+
+    Ok(cl_program::from_arc(prog))
     //• CL_INVALID_BINARY if an invalid program binary was encountered for any device. binary_status will return specific status for each device.
 }
 
@@ -289,6 +292,7 @@ pub fn build_program(
     //• CL_INVALID_OPERATION if program was not created with clCreateProgramWithSource, clCreateProgramWithIL or clCreateProgramWithBinary.
 
     if res {
+        p.build_nirs();
         Ok(())
     } else {
         if Platform::dbg().program {
@@ -431,6 +435,9 @@ pub fn link_program(
         CL_LINK_PROGRAM_FAILURE
     };
 
+    // Pre build nir kernels
+    res.build_nirs();
+
     let res = cl_program::from_arc(res);
 
     call_cb(pfn_notify, res, user_data);
index ccc1903..b0fea9e 100644 (file)
@@ -22,7 +22,6 @@ use rusticl_opencl_gen::*;
 use std::cell::RefCell;
 use std::cmp;
 use std::collections::HashMap;
-use std::collections::HashSet;
 use std::convert::TryInto;
 use std::os::raw::c_void;
 use std::ptr;
@@ -255,7 +254,7 @@ impl InternalKernelArg {
 }
 
 struct KernelDevStateInner {
-    nir: NirShader,
+    nir: Arc<NirShader>,
     constant_buffer: Option<Arc<PipeResource>>,
     cso: *mut c_void,
     info: pipe_compute_state_object_info,
@@ -276,7 +275,7 @@ impl Drop for KernelDevState {
 }
 
 impl KernelDevState {
-    fn new(nirs: HashMap<Arc<Device>, NirShader>) -> Arc<Self> {
+    fn new(nirs: HashMap<Arc<Device>, Arc<NirShader>>) -> Arc<Self> {
         let states = nirs
             .into_iter()
             .map(|(dev, nir)| {
@@ -736,94 +735,62 @@ fn deserialize_nir(
     Some((nir, args, internal_args))
 }
 
-fn convert_spirv_to_nir(
+pub fn convert_spirv_to_nir(
     p: &Program,
     name: &str,
-    args: Vec<spirv::SPIRVKernelArg>,
-) -> (
-    HashMap<Arc<Device>, NirShader>,
-    Vec<KernelArg>,
-    Vec<InternalKernelArg>,
-    String,
-) {
-    let mut nirs = HashMap::new();
-    let mut args_set = HashSet::new();
-    let mut internal_args_set = HashSet::new();
-    let mut attributes_string_set = HashSet::new();
-
-    // TODO: we could run this in parallel?
-    for d in p.devs_with_build() {
-        let cache = d.screen().shader_cache();
-        let key = p.hash_key(d, name);
-
-        let res = if let Some(cache) = &cache {
-            cache.get(&mut key.unwrap()).and_then(|entry| {
-                let mut bin: &[u8] = &entry;
-                deserialize_nir(&mut bin, d)
-            })
-        } else {
-            None
-        };
-
-        let (nir, args, internal_args) = if let Some(res) = res {
-            res
-        } else {
-            let mut nir = p.to_nir(name, d);
-
-            /* this is a hack until we support fp16 properly and check for denorms inside
-             * vstore/vload_half
-             */
-            nir.preserve_fp16_denorms();
+    args: &[spirv::SPIRVKernelArg],
+    dev: &Arc<Device>,
+) -> (NirShader, Vec<KernelArg>, Vec<InternalKernelArg>, String) {
+    let cache = dev.screen().shader_cache();
+    let key = p.hash_key(dev, name);
+
+    let res = if let Some(cache) = &cache {
+        cache.get(&mut key.unwrap()).and_then(|entry| {
+            let mut bin: &[u8] = &entry;
+            deserialize_nir(&mut bin, dev)
+        })
+    } else {
+        None
+    };
 
-            lower_and_optimize_nir_pre_inputs(d, &mut nir, &d.lib_clc);
-            let mut args = KernelArg::from_spirv_nir(&args, &mut nir);
-            let internal_args = lower_and_optimize_nir_late(d, &mut nir, &mut args);
+    let (nir, args, internal_args) = if let Some(res) = res {
+        res
+    } else {
+        let mut nir = p.to_nir(name, dev);
 
-            if let Some(cache) = cache {
-                let mut bin = Vec::new();
-                let mut nir = nir.serialize();
+        /* this is a hack until we support fp16 properly and check for denorms inside
+         * vstore/vload_half
+         */
+        nir.preserve_fp16_denorms();
 
-                bin.extend_from_slice(&nir.len().to_ne_bytes());
-                bin.append(&mut nir);
+        lower_and_optimize_nir_pre_inputs(dev, &mut nir, &dev.lib_clc);
+        let mut args = KernelArg::from_spirv_nir(args, &mut nir);
+        let internal_args = lower_and_optimize_nir_late(dev, &mut nir, &mut args);
 
-                bin.extend_from_slice(&args.len().to_ne_bytes());
-                for arg in &args {
-                    bin.append(&mut arg.serialize());
-                }
+        if let Some(cache) = cache {
+            let mut bin = Vec::new();
+            let mut nir = nir.serialize();
 
-                bin.extend_from_slice(&internal_args.len().to_ne_bytes());
-                for arg in &internal_args {
-                    bin.append(&mut arg.serialize());
-                }
+            bin.extend_from_slice(&nir.len().to_ne_bytes());
+            bin.append(&mut nir);
 
-                cache.put(&bin, &mut key.unwrap());
+            bin.extend_from_slice(&args.len().to_ne_bytes());
+            for arg in &args {
+                bin.append(&mut arg.serialize());
             }
 
-            (nir, args, internal_args)
-        };
+            bin.extend_from_slice(&internal_args.len().to_ne_bytes());
+            for arg in &internal_args {
+                bin.append(&mut arg.serialize());
+            }
 
-        args_set.insert(args);
-        internal_args_set.insert(internal_args);
-        nirs.insert(d.clone(), nir);
-        attributes_string_set.insert(p.attribute_str(name, d));
-    }
+            cache.put(&bin, &mut key.unwrap());
+        }
 
-    // we want the same (internal) args for every compiled kernel, for now
-    assert!(args_set.len() == 1);
-    assert!(internal_args_set.len() == 1);
-    assert!(attributes_string_set.len() == 1);
-    let args = args_set.into_iter().next().unwrap();
-    let internal_args = internal_args_set.into_iter().next().unwrap();
-
-    // spec: For kernels not created from OpenCL C source and the clCreateProgramWithSource API call
-    // the string returned from this query [CL_KERNEL_ATTRIBUTES] will be empty.
-    let attributes_string = if p.is_src() {
-        attributes_string_set.into_iter().next().unwrap()
-    } else {
-        String::new()
+        (nir, args, internal_args)
     };
 
-    (nirs, args, internal_args, attributes_string)
+    (nir, args, internal_args, p.attribute_str(name, dev))
 }
 
 fn extract<'a, const S: usize>(buf: &'a mut &[u8]) -> &'a [u8; S] {
@@ -835,9 +802,12 @@ fn extract<'a, const S: usize>(buf: &'a mut &[u8]) -> &'a [u8; S] {
 }
 
 impl Kernel {
-    pub fn new(name: String, prog: Arc<Program>, args: Vec<spirv::SPIRVKernelArg>) -> Arc<Kernel> {
-        let (mut nirs, args, internal_args, attributes_string) =
-            convert_spirv_to_nir(&prog, &name, args);
+    pub fn new(name: String, prog: Arc<Program>) -> Arc<Kernel> {
+        let nir_kernel_build = prog.get_nir_kernel_build(&name);
+        let mut nirs = nir_kernel_build.nirs;
+        let args = nir_kernel_build.args;
+        let internal_args = nir_kernel_build.internal_args;
+        let attributes_string = nir_kernel_build.attributes_string;
 
         let nir = nirs.values_mut().next().unwrap();
         let wgs = nir.workgroup_size();
index afed3c8..232b25e 100644 (file)
@@ -1,6 +1,7 @@
 use crate::api::icd::*;
 use crate::core::context::*;
 use crate::core::device::*;
+use crate::core::kernel::*;
 use crate::core::platform::Platform;
 use crate::impl_cl_type_trait;
 
@@ -63,10 +64,19 @@ pub struct Program {
     pub kernel_count: AtomicU32,
     spec_constants: Mutex<HashMap<u32, nir_const_value>>,
     build: Mutex<ProgramBuild>,
+    nir_builds: Mutex<HashMap<String, NirKernelBuild>>,
 }
 
 impl_cl_type_trait!(cl_program, Program, CL_INVALID_PROGRAM);
 
+#[derive(Clone)]
+pub struct NirKernelBuild {
+    pub nirs: HashMap<Arc<Device>, Arc<NirShader>>,
+    pub args: Vec<KernelArg>,
+    pub internal_args: Vec<InternalKernelArg>,
+    pub attributes_string: String,
+}
+
 struct ProgramBuild {
     builds: HashMap<Arc<Device>, ProgramDevBuild>,
     kernels: Vec<String>,
@@ -157,6 +167,7 @@ impl Program {
                 builds: Self::create_default_builds(devs),
                 kernels: Vec::new(),
             }),
+            nir_builds: Mutex::new(HashMap::new()),
         })
     }
 
@@ -229,6 +240,7 @@ impl Program {
                 builds: builds,
                 kernels: kernels.into_iter().collect(),
             }),
+            nir_builds: Mutex::new(HashMap::new()),
         })
     }
 
@@ -245,6 +257,7 @@ impl Program {
                 builds: builds,
                 kernels: Vec::new(),
             }),
+            nir_builds: Mutex::new(HashMap::new()),
         })
     }
 
@@ -259,6 +272,20 @@ impl Program {
         l.builds.get_mut(dev).unwrap()
     }
 
+    fn nir_build_info(&self) -> MutexGuard<HashMap<String, NirKernelBuild>> {
+        self.nir_builds.lock().unwrap()
+    }
+
+    pub fn get_nir_kernel_build(&self, name: &str) -> NirKernelBuild {
+        let info = self.nir_build_info();
+        info.get(name).unwrap().clone()
+    }
+
+    pub fn set_nir_kernel_build(&self, name: &str, nir_build: NirKernelBuild) {
+        let mut info = self.nir_build_info();
+        info.insert(String::from(name), nir_build);
+    }
+
     pub fn status(&self, dev: &Arc<Device>) -> cl_build_status {
         Self::dev_build_info(&mut self.build_info(), dev).status
     }
@@ -496,9 +523,58 @@ impl Program {
                 builds: builds,
                 kernels: kernels.into_iter().collect(),
             }),
+            nir_builds: Mutex::new(HashMap::new()),
         })
     }
 
+    pub fn build_nir_kernel(&self, name: &str, args: Vec<spirv::SPIRVKernelArg>) -> NirKernelBuild {
+        let mut nirs = HashMap::new();
+        let mut args_set = HashSet::new();
+        let mut internal_args_set = HashSet::new();
+        let mut attributes_string_set = HashSet::new();
+
+        // TODO: we could run this in parallel?
+        for d in self.devs_with_build() {
+            let (nir, args, internal_args, attributes_string) =
+                convert_spirv_to_nir(self, name, &args, d);
+            nirs.insert(d.clone(), Arc::new(nir));
+            args_set.insert(args);
+            internal_args_set.insert(internal_args);
+            attributes_string_set.insert(attributes_string);
+        }
+
+        // we want the same (internal) args for every compiled kernel, for now
+        assert!(args_set.len() == 1);
+        assert!(internal_args_set.len() == 1);
+        assert!(attributes_string_set.len() == 1);
+        let args = args_set.into_iter().next().unwrap();
+        let internal_args = internal_args_set.into_iter().next().unwrap();
+
+        // spec: For kernels not created from OpenCL C source and the clCreateProgramWithSource API call
+        // the string returned from this query [CL_KERNEL_ATTRIBUTES] will be empty.
+        let attributes_string = if self.is_src() {
+            attributes_string_set.into_iter().next().unwrap()
+        } else {
+            String::new()
+        };
+
+        NirKernelBuild {
+            nirs: nirs,
+            args: args,
+            internal_args: internal_args,
+            attributes_string: attributes_string,
+        }
+    }
+
+    pub fn build_nirs(&self) {
+        let devs = self.devs_with_build();
+        for k in &self.kernels() {
+            let kernel_args: HashSet<_> = devs.iter().map(|d| self.args(d, k)).collect();
+            let nir_build = self.build_nir_kernel(k, kernel_args.into_iter().next().unwrap());
+            self.set_nir_kernel_build(k, nir_build);
+        }
+    }
+
     pub(super) fn hash_key(&self, dev: &Arc<Device>, name: &str) -> Option<cache_key> {
         if let Some(cache) = dev.screen().shader_cache() {
             let mut lock = self.build_info();