use crate::api::platform::*;
use crate::api::util::*;
use crate::core::device::*;
+use crate::core::version::*;
use mesa_rust_gen::*;
use mesa_rust_util::ptr::*;
use std::sync::Arc;
use std::sync::Once;
-// TODO spec constants need to be implemented
-const SPIRV_SUPPORT_STRING: &str = "";
-// "SPIR-V_1.0 SPIR-V_1.1 SPIR-V_1.2 SPIR-V_1.3 SPIR-V_1.4 SPIR-V_1.5";
-const SPIRV_SUPPORT: [cl_name_version; 0] = [
-/* mk_cl_version_ext(1, 0, 0, b"SPIR-V"),
- mk_cl_version_ext(1, 1, 0, b"SPIR-V"),
- mk_cl_version_ext(1, 2, 0, b"SPIR-V"),
- mk_cl_version_ext(1, 3, 0, b"SPIR-V"),
- mk_cl_version_ext(1, 4, 0, b"SPIR-V"),
- mk_cl_version_ext(1, 5, 0, b"SPIR-V"),*/
+const SPIRV_SUPPORT_STRING: &str = "SPIR-V_1.0 SPIR-V_1.1 SPIR-V_1.2 SPIR-V_1.3 SPIR-V_1.4";
+const SPIRV_SUPPORT: [cl_name_version; 5] = [
+ mk_cl_version_ext(1, 0, 0, "SPIR-V"),
+ mk_cl_version_ext(1, 1, 0, "SPIR-V"),
+ mk_cl_version_ext(1, 2, 0, "SPIR-V"),
+ mk_cl_version_ext(1, 3, 0, "SPIR-V"),
+ mk_cl_version_ext(1, 4, 0, "SPIR-V"),
];
impl CLInfo<cl_device_info> for cl_device_id {
#[allow(non_camel_case_types)]
pub struct _cl_platform_id {
dispatch: &'static cl_icd_dispatch,
- extensions: [cl_name_version; 1],
+ extensions: [cl_name_version; 2],
}
impl CLInfo<cl_platform_info> for cl_platform_id {
let p = self.get_ref()?;
Ok(match q {
// TODO spirv
- CL_PLATFORM_EXTENSIONS => cl_prop("cl_khr_icd"),
- // CL_PLATFORM_EXTENSIONS => cl_prop("cl_khr_icd cl_khr_il_program"),
+ CL_PLATFORM_EXTENSIONS => cl_prop("cl_khr_icd cl_khr_il_program"),
CL_PLATFORM_EXTENSIONS_WITH_VERSION => {
cl_prop::<Vec<cl_name_version>>(p.extensions.to_vec())
}
dispatch: &DISPATCH,
extensions: [
mk_cl_version_ext(1, 0, 0, "cl_khr_icd"),
- // TODO spirv
- // mk_cl_version_ext(1, 0, 0, "cl_khr_il_program"),
+ mk_cl_version_ext(1, 0, 0, "cl_khr_il_program"),
],
};
il: *const ::std::os::raw::c_void,
length: usize,
) -> CLResult<cl_program> {
- let _c = context.get_arc()?;
+ let c = context.get_arc()?;
// CL_INVALID_VALUE if il is NULL or if length is zero.
if il.is_null() || length == 0 {
return Err(CL_INVALID_VALUE);
}
- // let spirv = unsafe { slice::from_raw_parts(il.cast(), length) };
- // TODO SPIR-V
- // Ok(cl_program::from_arc(Program::from_spirv(c, spirv)))
- Err(CL_INVALID_OPERATION)
+ // SAFETY: according to API spec
+ let spirv = unsafe { slice::from_raw_parts(il.cast(), length) };
+ Ok(cl_program::from_arc(Program::from_spirv(c, spirv)))
}
pub fn build_program(
pub fn set_program_specialization_constant(
program: cl_program,
- _spec_id: cl_uint,
- _spec_size: usize,
+ spec_id: cl_uint,
+ spec_size: usize,
spec_value: *const ::std::os::raw::c_void,
) -> CLResult<()> {
- let _program = program.get_ref()?;
+ let program = program.get_ref()?;
// CL_INVALID_PROGRAM if program is not a valid program object created from an intermediate
// language (e.g. SPIR-V)
// TODO: or if the intermediate language does not support specialization constants.
- // if program.il.is_empty() {
- // Err(CL_INVALID_PROGRAM)?
- // }
+ if program.il.is_empty() {
+ return Err(CL_INVALID_PROGRAM);
+ }
- // TODO: CL_INVALID_VALUE if spec_size does not match the size of the specialization constant in the module,
+ if spec_size != program.get_spec_constant_size(spec_id).into() {
+ // CL_INVALID_VALUE if spec_size does not match the size of the specialization constant in
+ // the module,
+ return Err(CL_INVALID_VALUE);
+ }
// or if spec_value is NULL.
if spec_value.is_null() {
return Err(CL_INVALID_VALUE);
}
- Err(CL_INVALID_OPERATION)
+ // SAFETY: according to API spec
+ program.set_spec_constant(spec_id, unsafe {
+ slice::from_raw_parts(spec_value.cast(), spec_size)
+ });
- //• CL_INVALID_SPEC_ID if spec_id is not a valid specialization constant identifier.
+ Ok(())
}
pub fn set_program_release_callback(
add_ext(1, 0, 0, "cl_khr_byte_addressable_store", "");
add_ext(1, 0, 0, "cl_khr_global_int32_base_atomics", "");
add_ext(1, 0, 0, "cl_khr_global_int32_extended_atomics", "");
- // TODO spirv
- // add_ext(1, 0, 0, "cl_khr_il_program", "");
+ add_ext(1, 0, 0, "cl_khr_il_program", "");
add_ext(1, 0, 0, "cl_khr_local_int32_base_atomics", "");
add_ext(1, 0, 0, "cl_khr_local_int32_extended_atomics", "");
pub src: CString,
pub il: Vec<u8>,
pub kernel_count: AtomicU32,
- spec_constants: Mutex<Vec<spirv::SpecConstant>>,
+ spec_constants: Mutex<HashMap<u32, nir_const_value>>,
build: Mutex<ProgramBuild>,
}
src: src,
il: Vec::new(),
kernel_count: AtomicU32::new(0),
- spec_constants: Mutex::new(Vec::new()),
+ spec_constants: Mutex::new(HashMap::new()),
build: Mutex::new(ProgramBuild {
builds: builds,
kernels: Vec::new(),
src: CString::new("").unwrap(),
il: Vec::new(),
kernel_count: AtomicU32::new(0),
- spec_constants: Mutex::new(Vec::new()),
+ spec_constants: Mutex::new(HashMap::new()),
build: Mutex::new(ProgramBuild {
builds: builds,
kernels: kernels.into_iter().collect(),
src: CString::new("").unwrap(),
il: spirv.to_vec(),
kernel_count: AtomicU32::new(0),
- spec_constants: Mutex::new(Vec::new()),
+ spec_constants: Mutex::new(HashMap::new()),
build: Mutex::new(ProgramBuild {
builds: builds,
kernels: Vec::new(),
src: CString::new("").unwrap(),
il: Vec::new(),
kernel_count: AtomicU32::new(0),
- spec_constants: Mutex::new(Vec::new()),
+ spec_constants: Mutex::new(HashMap::new()),
build: Mutex::new(ProgramBuild {
builds: builds,
kernels: kernels.into_iter().collect(),
let spirv = info.spirv.as_ref().unwrap();
let mut bin = spirv.to_bin().to_vec();
bin.extend_from_slice(name.as_bytes());
+
+ for (k, v) in self.spec_constants.lock().unwrap().iter() {
+ bin.extend_from_slice(&k.to_ne_bytes());
+ unsafe {
+ // SAFETY: we fully initialize this union
+ bin.extend_from_slice(&v.u64_.to_ne_bytes());
+ }
+ }
+
Some(cache.gen_key(&bin))
} else {
None
}
pub fn to_nir(&self, kernel: &str, d: &Arc<Device>) -> NirShader {
+ let constants = self.spec_constants.lock().unwrap();
+ let mut spec_constants: Vec<_> = constants
+ .iter()
+ .map(|(&id, &value)| nir_spirv_specialization {
+ id: id,
+ value: value,
+ defined_on_module: true,
+ })
+ .collect();
+ drop(constants);
+
let mut lock = self.build_info();
+
let info = Self::dev_build_info(&mut lock, d);
assert_eq!(info.status, CL_BUILD_SUCCESS as cl_build_status);
info.spirv
d.screen
.nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE),
&d.lib_clc,
- &mut [],
+ &mut spec_constants,
d.address_bits(),
)
.unwrap()
pub fn is_src(&self) -> bool {
!self.src.to_bytes().is_empty()
}
+
+ pub fn get_spec_constant_size(&self, spec_id: u32) -> u8 {
+ let lock = self.build_info();
+ let spirv = lock.builds.values().next().unwrap().spirv.as_ref().unwrap();
+ spirv
+ .spec_constant(spec_id)
+ .map_or(0, spirv::CLCSpecConstantType::size)
+ }
+
+ pub fn set_spec_constant(&self, spec_id: u32, data: &[u8]) {
+ let mut lock = self.spec_constants.lock().unwrap();
+ let mut val = nir_const_value::default();
+
+ match data.len() {
+ 1 => val.u8_ = u8::from_ne_bytes(data.try_into().unwrap()),
+ 2 => val.u16_ = u16::from_ne_bytes(data.try_into().unwrap()),
+ 4 => val.u32_ = u32::from_ne_bytes(data.try_into().unwrap()),
+ 8 => val.u64_ = u64::from_ne_bytes(data.try_into().unwrap()),
+ _ => unreachable!("Spec constant with invalid size!"),
+ };
+
+ lock.insert(spec_id, val);
+ }
}
}
}
+ pub fn spec_constant(&self, spec_id: u32) -> Option<clc_spec_constant_type> {
+ let info = self.info?;
+ let spec_constants =
+ unsafe { slice::from_raw_parts(info.spec_constants, info.num_spec_constants as usize) };
+
+ spec_constants
+ .iter()
+ .find(|sc| sc.id == spec_id)
+ .map(|sc| sc.type_)
+ }
+
pub fn print(&self) {
unsafe {
clc_dump_spirv(&self.spirv, stderr_ptr());
})
}
}
+
+pub trait CLCSpecConstantType {
+ fn size(self) -> u8;
+}
+
+impl CLCSpecConstantType for clc_spec_constant_type {
+ fn size(self) -> u8 {
+ match self {
+ Self::CLC_SPEC_CONSTANT_INT64
+ | Self::CLC_SPEC_CONSTANT_UINT64
+ | Self::CLC_SPEC_CONSTANT_DOUBLE => 8,
+ Self::CLC_SPEC_CONSTANT_INT32
+ | Self::CLC_SPEC_CONSTANT_UINT32
+ | Self::CLC_SPEC_CONSTANT_FLOAT => 4,
+ Self::CLC_SPEC_CONSTANT_INT16 | Self::CLC_SPEC_CONSTANT_UINT16 => 2,
+ Self::CLC_SPEC_CONSTANT_INT8
+ | Self::CLC_SPEC_CONSTANT_UINT8
+ | Self::CLC_SPEC_CONSTANT_BOOL => 1,
+ Self::CLC_SPEC_CONSTANT_UNKNOWN => 0,
+ }
+ }
+}