rusticl/program: enable spirv
authorKarol Herbst <kherbst@redhat.com>
Mon, 5 Sep 2022 15:22:56 +0000 (17:22 +0200)
committerMarge Bot <emma+marge@anholt.net>
Mon, 13 Feb 2023 12:45:07 +0000 (12:45 +0000)
Signed-off-by: Karol Herbst <kherbst@redhat.com>
Reviewed-by: Adam Jackson <ajax@redhat.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/19008>

src/gallium/frontends/rusticl/api/device.rs
src/gallium/frontends/rusticl/api/platform.rs
src/gallium/frontends/rusticl/api/program.rs
src/gallium/frontends/rusticl/core/device.rs
src/gallium/frontends/rusticl/core/program.rs
src/gallium/frontends/rusticl/mesa/compiler/clc/spirv.rs

index 21d4ab3..ae094dd 100644 (file)
@@ -2,6 +2,7 @@ use crate::api::icd::*;
 use crate::api::platform::*;
 use crate::api::util::*;
 use crate::core::device::*;
+use crate::core::version::*;
 
 use mesa_rust_gen::*;
 use mesa_rust_util::ptr::*;
@@ -14,16 +15,13 @@ use std::ptr;
 use std::sync::Arc;
 use std::sync::Once;
 
-// TODO spec constants need to be implemented
-const SPIRV_SUPPORT_STRING: &str = "";
-//    "SPIR-V_1.0 SPIR-V_1.1 SPIR-V_1.2 SPIR-V_1.3 SPIR-V_1.4 SPIR-V_1.5";
-const SPIRV_SUPPORT: [cl_name_version; 0] = [
-/*    mk_cl_version_ext(1, 0, 0, b"SPIR-V"),
-    mk_cl_version_ext(1, 1, 0, b"SPIR-V"),
-    mk_cl_version_ext(1, 2, 0, b"SPIR-V"),
-    mk_cl_version_ext(1, 3, 0, b"SPIR-V"),
-    mk_cl_version_ext(1, 4, 0, b"SPIR-V"),
-    mk_cl_version_ext(1, 5, 0, b"SPIR-V"),*/
+const SPIRV_SUPPORT_STRING: &str = "SPIR-V_1.0 SPIR-V_1.1 SPIR-V_1.2 SPIR-V_1.3 SPIR-V_1.4";
+const SPIRV_SUPPORT: [cl_name_version; 5] = [
+    mk_cl_version_ext(1, 0, 0, "SPIR-V"),
+    mk_cl_version_ext(1, 1, 0, "SPIR-V"),
+    mk_cl_version_ext(1, 2, 0, "SPIR-V"),
+    mk_cl_version_ext(1, 3, 0, "SPIR-V"),
+    mk_cl_version_ext(1, 4, 0, "SPIR-V"),
 ];
 
 impl CLInfo<cl_device_info> for cl_device_id {
index 3dd7624..80d6cbf 100644 (file)
@@ -10,7 +10,7 @@ use rusticl_opencl_gen::*;
 #[allow(non_camel_case_types)]
 pub struct _cl_platform_id {
     dispatch: &'static cl_icd_dispatch,
-    extensions: [cl_name_version; 1],
+    extensions: [cl_name_version; 2],
 }
 
 impl CLInfo<cl_platform_info> for cl_platform_id {
@@ -18,8 +18,7 @@ impl CLInfo<cl_platform_info> for cl_platform_id {
         let p = self.get_ref()?;
         Ok(match q {
             // TODO spirv
-            CL_PLATFORM_EXTENSIONS => cl_prop("cl_khr_icd"),
-            //            CL_PLATFORM_EXTENSIONS => cl_prop("cl_khr_icd cl_khr_il_program"),
+            CL_PLATFORM_EXTENSIONS => cl_prop("cl_khr_icd cl_khr_il_program"),
             CL_PLATFORM_EXTENSIONS_WITH_VERSION => {
                 cl_prop::<Vec<cl_name_version>>(p.extensions.to_vec())
             }
@@ -41,8 +40,7 @@ static PLATFORM: _cl_platform_id = _cl_platform_id {
     dispatch: &DISPATCH,
     extensions: [
         mk_cl_version_ext(1, 0, 0, "cl_khr_icd"),
-        // TODO spirv
-        //        mk_cl_version_ext(1, 0, 0, "cl_khr_il_program"),
+        mk_cl_version_ext(1, 0, 0, "cl_khr_il_program"),
     ],
 };
 
index 15d9dcb..daa2180 100644 (file)
@@ -237,17 +237,16 @@ pub fn create_program_with_il(
     il: *const ::std::os::raw::c_void,
     length: usize,
 ) -> CLResult<cl_program> {
-    let _c = context.get_arc()?;
+    let c = context.get_arc()?;
 
     // CL_INVALID_VALUE if il is NULL or if length is zero.
     if il.is_null() || length == 0 {
         return Err(CL_INVALID_VALUE);
     }
 
-    //    let spirv = unsafe { slice::from_raw_parts(il.cast(), length) };
-    // TODO SPIR-V
-    //    Ok(cl_program::from_arc(Program::from_spirv(c, spirv)))
-    Err(CL_INVALID_OPERATION)
+    // SAFETY: according to API spec
+    let spirv = unsafe { slice::from_raw_parts(il.cast(), length) };
+    Ok(cl_program::from_arc(Program::from_spirv(c, spirv)))
 }
 
 pub fn build_program(
@@ -417,29 +416,36 @@ pub fn link_program(
 
 pub fn set_program_specialization_constant(
     program: cl_program,
-    _spec_id: cl_uint,
-    _spec_size: usize,
+    spec_id: cl_uint,
+    spec_size: usize,
     spec_value: *const ::std::os::raw::c_void,
 ) -> CLResult<()> {
-    let _program = program.get_ref()?;
+    let program = program.get_ref()?;
 
     // CL_INVALID_PROGRAM if program is not a valid program object created from an intermediate
     // language (e.g. SPIR-V)
     // TODO: or if the intermediate language does not support specialization constants.
-    //    if program.il.is_empty() {
-    //        Err(CL_INVALID_PROGRAM)?
-    //    }
+    if program.il.is_empty() {
+        return Err(CL_INVALID_PROGRAM);
+    }
 
-    // TODO: CL_INVALID_VALUE if spec_size does not match the size of the specialization constant in the module,
+    if spec_size != program.get_spec_constant_size(spec_id).into() {
+        // CL_INVALID_VALUE if spec_size does not match the size of the specialization constant in
+        // the module,
+        return Err(CL_INVALID_VALUE);
+    }
 
     // or if spec_value is NULL.
     if spec_value.is_null() {
         return Err(CL_INVALID_VALUE);
     }
 
-    Err(CL_INVALID_OPERATION)
+    // SAFETY: according to API spec
+    program.set_spec_constant(spec_id, unsafe {
+        slice::from_raw_parts(spec_value.cast(), spec_size)
+    });
 
-    //• CL_INVALID_SPEC_ID if spec_id is not a valid specialization constant identifier.
+    Ok(())
 }
 
 pub fn set_program_release_callback(
index 5a73f6f..99cdce9 100644 (file)
@@ -483,8 +483,7 @@ impl Device {
         add_ext(1, 0, 0, "cl_khr_byte_addressable_store", "");
         add_ext(1, 0, 0, "cl_khr_global_int32_base_atomics", "");
         add_ext(1, 0, 0, "cl_khr_global_int32_extended_atomics", "");
-        // TODO spirv
-        // add_ext(1, 0, 0, "cl_khr_il_program", "");
+        add_ext(1, 0, 0, "cl_khr_il_program", "");
         add_ext(1, 0, 0, "cl_khr_local_int32_base_atomics", "");
         add_ext(1, 0, 0, "cl_khr_local_int32_extended_atomics", "");
 
index 42d2f8a..60bb813 100644 (file)
@@ -53,7 +53,7 @@ pub struct Program {
     pub src: CString,
     pub il: Vec<u8>,
     pub kernel_count: AtomicU32,
-    spec_constants: Mutex<Vec<spirv::SpecConstant>>,
+    spec_constants: Mutex<HashMap<u32, nir_const_value>>,
     build: Mutex<ProgramBuild>,
 }
 
@@ -144,7 +144,7 @@ impl Program {
             src: src,
             il: Vec::new(),
             kernel_count: AtomicU32::new(0),
-            spec_constants: Mutex::new(Vec::new()),
+            spec_constants: Mutex::new(HashMap::new()),
             build: Mutex::new(ProgramBuild {
                 builds: builds,
                 kernels: Vec::new(),
@@ -217,7 +217,7 @@ impl Program {
             src: CString::new("").unwrap(),
             il: Vec::new(),
             kernel_count: AtomicU32::new(0),
-            spec_constants: Mutex::new(Vec::new()),
+            spec_constants: Mutex::new(HashMap::new()),
             build: Mutex::new(ProgramBuild {
                 builds: builds,
                 kernels: kernels.into_iter().collect(),
@@ -250,7 +250,7 @@ impl Program {
             src: CString::new("").unwrap(),
             il: spirv.to_vec(),
             kernel_count: AtomicU32::new(0),
-            spec_constants: Mutex::new(Vec::new()),
+            spec_constants: Mutex::new(HashMap::new()),
             build: Mutex::new(ProgramBuild {
                 builds: builds,
                 kernels: Vec::new(),
@@ -518,7 +518,7 @@ impl Program {
             src: CString::new("").unwrap(),
             il: Vec::new(),
             kernel_count: AtomicU32::new(0),
-            spec_constants: Mutex::new(Vec::new()),
+            spec_constants: Mutex::new(HashMap::new()),
             build: Mutex::new(ProgramBuild {
                 builds: builds,
                 kernels: kernels.into_iter().collect(),
@@ -535,6 +535,15 @@ impl Program {
             let spirv = info.spirv.as_ref().unwrap();
             let mut bin = spirv.to_bin().to_vec();
             bin.extend_from_slice(name.as_bytes());
+
+            for (k, v) in self.spec_constants.lock().unwrap().iter() {
+                bin.extend_from_slice(&k.to_ne_bytes());
+                unsafe {
+                    // SAFETY: we fully initialize this union
+                    bin.extend_from_slice(&v.u64_.to_ne_bytes());
+                }
+            }
+
             Some(cache.gen_key(&bin))
         } else {
             None
@@ -571,7 +580,19 @@ impl Program {
     }
 
     pub fn to_nir(&self, kernel: &str, d: &Arc<Device>) -> NirShader {
+        let constants = self.spec_constants.lock().unwrap();
+        let mut spec_constants: Vec<_> = constants
+            .iter()
+            .map(|(&id, &value)| nir_spirv_specialization {
+                id: id,
+                value: value,
+                defined_on_module: true,
+            })
+            .collect();
+        drop(constants);
+
         let mut lock = self.build_info();
+
         let info = Self::dev_build_info(&mut lock, d);
         assert_eq!(info.status, CL_BUILD_SUCCESS as cl_build_status);
         info.spirv
@@ -582,7 +603,7 @@ impl Program {
                 d.screen
                     .nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE),
                 &d.lib_clc,
-                &mut [],
+                &mut spec_constants,
                 d.address_bits(),
             )
             .unwrap()
@@ -595,4 +616,27 @@ impl Program {
     pub fn is_src(&self) -> bool {
         !self.src.to_bytes().is_empty()
     }
+
+    pub fn get_spec_constant_size(&self, spec_id: u32) -> u8 {
+        let lock = self.build_info();
+        let spirv = lock.builds.values().next().unwrap().spirv.as_ref().unwrap();
+        spirv
+            .spec_constant(spec_id)
+            .map_or(0, spirv::CLCSpecConstantType::size)
+    }
+
+    pub fn set_spec_constant(&self, spec_id: u32, data: &[u8]) {
+        let mut lock = self.spec_constants.lock().unwrap();
+        let mut val = nir_const_value::default();
+
+        match data.len() {
+            1 => val.u8_ = u8::from_ne_bytes(data.try_into().unwrap()),
+            2 => val.u16_ = u16::from_ne_bytes(data.try_into().unwrap()),
+            4 => val.u32_ = u32::from_ne_bytes(data.try_into().unwrap()),
+            8 => val.u64_ = u64::from_ne_bytes(data.try_into().unwrap()),
+            _ => unreachable!("Spec constant with invalid size!"),
+        };
+
+        lock.insert(spec_id, val);
+    }
 }
index fafa250..b957b37 100644 (file)
@@ -368,6 +368,17 @@ impl SPIRVBin {
         }
     }
 
+    pub fn spec_constant(&self, spec_id: u32) -> Option<clc_spec_constant_type> {
+        let info = self.info?;
+        let spec_constants =
+            unsafe { slice::from_raw_parts(info.spec_constants, info.num_spec_constants as usize) };
+
+        spec_constants
+            .iter()
+            .find(|sc| sc.id == spec_id)
+            .map(|sc| sc.type_)
+    }
+
     pub fn print(&self) {
         unsafe {
             clc_dump_spirv(&self.spirv, stderr_ptr());
@@ -429,3 +440,25 @@ impl SPIRVKernelArg {
         })
     }
 }
+
+pub trait CLCSpecConstantType {
+    fn size(self) -> u8;
+}
+
+impl CLCSpecConstantType for clc_spec_constant_type {
+    fn size(self) -> u8 {
+        match self {
+            Self::CLC_SPEC_CONSTANT_INT64
+            | Self::CLC_SPEC_CONSTANT_UINT64
+            | Self::CLC_SPEC_CONSTANT_DOUBLE => 8,
+            Self::CLC_SPEC_CONSTANT_INT32
+            | Self::CLC_SPEC_CONSTANT_UINT32
+            | Self::CLC_SPEC_CONSTANT_FLOAT => 4,
+            Self::CLC_SPEC_CONSTANT_INT16 | Self::CLC_SPEC_CONSTANT_UINT16 => 2,
+            Self::CLC_SPEC_CONSTANT_INT8
+            | Self::CLC_SPEC_CONSTANT_UINT8
+            | Self::CLC_SPEC_CONSTANT_BOOL => 1,
+            Self::CLC_SPEC_CONSTANT_UNKNOWN => 0,
+        }
+    }
+}