From: Karol Herbst Date: Mon, 5 Sep 2022 15:22:56 +0000 (+0200) Subject: rusticl/program: enable spirv X-Git-Tag: upstream/23.3.3~13224 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=13a4c49cb182ce4fc199c0aa34fbea7eee6f8304;p=platform%2Fupstream%2Fmesa.git rusticl/program: enable spirv Signed-off-by: Karol Herbst Reviewed-by: Adam Jackson Part-of: --- diff --git a/src/gallium/frontends/rusticl/api/device.rs b/src/gallium/frontends/rusticl/api/device.rs index 21d4ab3..ae094dd 100644 --- a/src/gallium/frontends/rusticl/api/device.rs +++ b/src/gallium/frontends/rusticl/api/device.rs @@ -2,6 +2,7 @@ use crate::api::icd::*; use crate::api::platform::*; use crate::api::util::*; use crate::core::device::*; +use crate::core::version::*; use mesa_rust_gen::*; use mesa_rust_util::ptr::*; @@ -14,16 +15,13 @@ use std::ptr; use std::sync::Arc; use std::sync::Once; -// TODO spec constants need to be implemented -const SPIRV_SUPPORT_STRING: &str = ""; -// "SPIR-V_1.0 SPIR-V_1.1 SPIR-V_1.2 SPIR-V_1.3 SPIR-V_1.4 SPIR-V_1.5"; -const SPIRV_SUPPORT: [cl_name_version; 0] = [ -/* mk_cl_version_ext(1, 0, 0, b"SPIR-V"), - mk_cl_version_ext(1, 1, 0, b"SPIR-V"), - mk_cl_version_ext(1, 2, 0, b"SPIR-V"), - mk_cl_version_ext(1, 3, 0, b"SPIR-V"), - mk_cl_version_ext(1, 4, 0, b"SPIR-V"), - mk_cl_version_ext(1, 5, 0, b"SPIR-V"),*/ +const SPIRV_SUPPORT_STRING: &str = "SPIR-V_1.0 SPIR-V_1.1 SPIR-V_1.2 SPIR-V_1.3 SPIR-V_1.4"; +const SPIRV_SUPPORT: [cl_name_version; 5] = [ + mk_cl_version_ext(1, 0, 0, "SPIR-V"), + mk_cl_version_ext(1, 1, 0, "SPIR-V"), + mk_cl_version_ext(1, 2, 0, "SPIR-V"), + mk_cl_version_ext(1, 3, 0, "SPIR-V"), + mk_cl_version_ext(1, 4, 0, "SPIR-V"), ]; impl CLInfo for cl_device_id { diff --git a/src/gallium/frontends/rusticl/api/platform.rs b/src/gallium/frontends/rusticl/api/platform.rs index 3dd7624..80d6cbf 100644 --- a/src/gallium/frontends/rusticl/api/platform.rs +++ b/src/gallium/frontends/rusticl/api/platform.rs @@ -10,7 +10,7 @@ use rusticl_opencl_gen::*; #[allow(non_camel_case_types)] pub struct _cl_platform_id { dispatch: &'static cl_icd_dispatch, - extensions: [cl_name_version; 1], + extensions: [cl_name_version; 2], } impl CLInfo for cl_platform_id { @@ -18,8 +18,7 @@ impl CLInfo for cl_platform_id { let p = self.get_ref()?; Ok(match q { // TODO spirv - CL_PLATFORM_EXTENSIONS => cl_prop("cl_khr_icd"), - // CL_PLATFORM_EXTENSIONS => cl_prop("cl_khr_icd cl_khr_il_program"), + CL_PLATFORM_EXTENSIONS => cl_prop("cl_khr_icd cl_khr_il_program"), CL_PLATFORM_EXTENSIONS_WITH_VERSION => { cl_prop::>(p.extensions.to_vec()) } @@ -41,8 +40,7 @@ static PLATFORM: _cl_platform_id = _cl_platform_id { dispatch: &DISPATCH, extensions: [ mk_cl_version_ext(1, 0, 0, "cl_khr_icd"), - // TODO spirv - // mk_cl_version_ext(1, 0, 0, "cl_khr_il_program"), + mk_cl_version_ext(1, 0, 0, "cl_khr_il_program"), ], }; diff --git a/src/gallium/frontends/rusticl/api/program.rs b/src/gallium/frontends/rusticl/api/program.rs index 15d9dcb..daa2180 100644 --- a/src/gallium/frontends/rusticl/api/program.rs +++ b/src/gallium/frontends/rusticl/api/program.rs @@ -237,17 +237,16 @@ pub fn create_program_with_il( il: *const ::std::os::raw::c_void, length: usize, ) -> CLResult { - let _c = context.get_arc()?; + let c = context.get_arc()?; // CL_INVALID_VALUE if il is NULL or if length is zero. if il.is_null() || length == 0 { return Err(CL_INVALID_VALUE); } - // let spirv = unsafe { slice::from_raw_parts(il.cast(), length) }; - // TODO SPIR-V - // Ok(cl_program::from_arc(Program::from_spirv(c, spirv))) - Err(CL_INVALID_OPERATION) + // SAFETY: according to API spec + let spirv = unsafe { slice::from_raw_parts(il.cast(), length) }; + Ok(cl_program::from_arc(Program::from_spirv(c, spirv))) } pub fn build_program( @@ -417,29 +416,36 @@ pub fn link_program( pub fn set_program_specialization_constant( program: cl_program, - _spec_id: cl_uint, - _spec_size: usize, + spec_id: cl_uint, + spec_size: usize, spec_value: *const ::std::os::raw::c_void, ) -> CLResult<()> { - let _program = program.get_ref()?; + let program = program.get_ref()?; // CL_INVALID_PROGRAM if program is not a valid program object created from an intermediate // language (e.g. SPIR-V) // TODO: or if the intermediate language does not support specialization constants. - // if program.il.is_empty() { - // Err(CL_INVALID_PROGRAM)? - // } + if program.il.is_empty() { + return Err(CL_INVALID_PROGRAM); + } - // TODO: CL_INVALID_VALUE if spec_size does not match the size of the specialization constant in the module, + if spec_size != program.get_spec_constant_size(spec_id).into() { + // CL_INVALID_VALUE if spec_size does not match the size of the specialization constant in + // the module, + return Err(CL_INVALID_VALUE); + } // or if spec_value is NULL. if spec_value.is_null() { return Err(CL_INVALID_VALUE); } - Err(CL_INVALID_OPERATION) + // SAFETY: according to API spec + program.set_spec_constant(spec_id, unsafe { + slice::from_raw_parts(spec_value.cast(), spec_size) + }); - //• CL_INVALID_SPEC_ID if spec_id is not a valid specialization constant identifier. + Ok(()) } pub fn set_program_release_callback( diff --git a/src/gallium/frontends/rusticl/core/device.rs b/src/gallium/frontends/rusticl/core/device.rs index 5a73f6f..99cdce9 100644 --- a/src/gallium/frontends/rusticl/core/device.rs +++ b/src/gallium/frontends/rusticl/core/device.rs @@ -483,8 +483,7 @@ impl Device { add_ext(1, 0, 0, "cl_khr_byte_addressable_store", ""); add_ext(1, 0, 0, "cl_khr_global_int32_base_atomics", ""); add_ext(1, 0, 0, "cl_khr_global_int32_extended_atomics", ""); - // TODO spirv - // add_ext(1, 0, 0, "cl_khr_il_program", ""); + add_ext(1, 0, 0, "cl_khr_il_program", ""); add_ext(1, 0, 0, "cl_khr_local_int32_base_atomics", ""); add_ext(1, 0, 0, "cl_khr_local_int32_extended_atomics", ""); diff --git a/src/gallium/frontends/rusticl/core/program.rs b/src/gallium/frontends/rusticl/core/program.rs index 42d2f8a..60bb813 100644 --- a/src/gallium/frontends/rusticl/core/program.rs +++ b/src/gallium/frontends/rusticl/core/program.rs @@ -53,7 +53,7 @@ pub struct Program { pub src: CString, pub il: Vec, pub kernel_count: AtomicU32, - spec_constants: Mutex>, + spec_constants: Mutex>, build: Mutex, } @@ -144,7 +144,7 @@ impl Program { src: src, il: Vec::new(), kernel_count: AtomicU32::new(0), - spec_constants: Mutex::new(Vec::new()), + spec_constants: Mutex::new(HashMap::new()), build: Mutex::new(ProgramBuild { builds: builds, kernels: Vec::new(), @@ -217,7 +217,7 @@ impl Program { src: CString::new("").unwrap(), il: Vec::new(), kernel_count: AtomicU32::new(0), - spec_constants: Mutex::new(Vec::new()), + spec_constants: Mutex::new(HashMap::new()), build: Mutex::new(ProgramBuild { builds: builds, kernels: kernels.into_iter().collect(), @@ -250,7 +250,7 @@ impl Program { src: CString::new("").unwrap(), il: spirv.to_vec(), kernel_count: AtomicU32::new(0), - spec_constants: Mutex::new(Vec::new()), + spec_constants: Mutex::new(HashMap::new()), build: Mutex::new(ProgramBuild { builds: builds, kernels: Vec::new(), @@ -518,7 +518,7 @@ impl Program { src: CString::new("").unwrap(), il: Vec::new(), kernel_count: AtomicU32::new(0), - spec_constants: Mutex::new(Vec::new()), + spec_constants: Mutex::new(HashMap::new()), build: Mutex::new(ProgramBuild { builds: builds, kernels: kernels.into_iter().collect(), @@ -535,6 +535,15 @@ impl Program { let spirv = info.spirv.as_ref().unwrap(); let mut bin = spirv.to_bin().to_vec(); bin.extend_from_slice(name.as_bytes()); + + for (k, v) in self.spec_constants.lock().unwrap().iter() { + bin.extend_from_slice(&k.to_ne_bytes()); + unsafe { + // SAFETY: we fully initialize this union + bin.extend_from_slice(&v.u64_.to_ne_bytes()); + } + } + Some(cache.gen_key(&bin)) } else { None @@ -571,7 +580,19 @@ impl Program { } pub fn to_nir(&self, kernel: &str, d: &Arc) -> NirShader { + let constants = self.spec_constants.lock().unwrap(); + let mut spec_constants: Vec<_> = constants + .iter() + .map(|(&id, &value)| nir_spirv_specialization { + id: id, + value: value, + defined_on_module: true, + }) + .collect(); + drop(constants); + let mut lock = self.build_info(); + let info = Self::dev_build_info(&mut lock, d); assert_eq!(info.status, CL_BUILD_SUCCESS as cl_build_status); info.spirv @@ -582,7 +603,7 @@ impl Program { d.screen .nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE), &d.lib_clc, - &mut [], + &mut spec_constants, d.address_bits(), ) .unwrap() @@ -595,4 +616,27 @@ impl Program { pub fn is_src(&self) -> bool { !self.src.to_bytes().is_empty() } + + pub fn get_spec_constant_size(&self, spec_id: u32) -> u8 { + let lock = self.build_info(); + let spirv = lock.builds.values().next().unwrap().spirv.as_ref().unwrap(); + spirv + .spec_constant(spec_id) + .map_or(0, spirv::CLCSpecConstantType::size) + } + + pub fn set_spec_constant(&self, spec_id: u32, data: &[u8]) { + let mut lock = self.spec_constants.lock().unwrap(); + let mut val = nir_const_value::default(); + + match data.len() { + 1 => val.u8_ = u8::from_ne_bytes(data.try_into().unwrap()), + 2 => val.u16_ = u16::from_ne_bytes(data.try_into().unwrap()), + 4 => val.u32_ = u32::from_ne_bytes(data.try_into().unwrap()), + 8 => val.u64_ = u64::from_ne_bytes(data.try_into().unwrap()), + _ => unreachable!("Spec constant with invalid size!"), + }; + + lock.insert(spec_id, val); + } } diff --git a/src/gallium/frontends/rusticl/mesa/compiler/clc/spirv.rs b/src/gallium/frontends/rusticl/mesa/compiler/clc/spirv.rs index fafa250..b957b37 100644 --- a/src/gallium/frontends/rusticl/mesa/compiler/clc/spirv.rs +++ b/src/gallium/frontends/rusticl/mesa/compiler/clc/spirv.rs @@ -368,6 +368,17 @@ impl SPIRVBin { } } + pub fn spec_constant(&self, spec_id: u32) -> Option { + let info = self.info?; + let spec_constants = + unsafe { slice::from_raw_parts(info.spec_constants, info.num_spec_constants as usize) }; + + spec_constants + .iter() + .find(|sc| sc.id == spec_id) + .map(|sc| sc.type_) + } + pub fn print(&self) { unsafe { clc_dump_spirv(&self.spirv, stderr_ptr()); @@ -429,3 +440,25 @@ impl SPIRVKernelArg { }) } } + +pub trait CLCSpecConstantType { + fn size(self) -> u8; +} + +impl CLCSpecConstantType for clc_spec_constant_type { + fn size(self) -> u8 { + match self { + Self::CLC_SPEC_CONSTANT_INT64 + | Self::CLC_SPEC_CONSTANT_UINT64 + | Self::CLC_SPEC_CONSTANT_DOUBLE => 8, + Self::CLC_SPEC_CONSTANT_INT32 + | Self::CLC_SPEC_CONSTANT_UINT32 + | Self::CLC_SPEC_CONSTANT_FLOAT => 4, + Self::CLC_SPEC_CONSTANT_INT16 | Self::CLC_SPEC_CONSTANT_UINT16 => 2, + Self::CLC_SPEC_CONSTANT_INT8 + | Self::CLC_SPEC_CONSTANT_UINT8 + | Self::CLC_SPEC_CONSTANT_BOOL => 1, + Self::CLC_SPEC_CONSTANT_UNKNOWN => 0, + } + } +}