rusticl/kernel: fix more 32 bit problems
author Karol Herbst <kherbst@redhat.com>
Wed, 26 Oct 2022 22:58:47 +0000 (00:58 +0200)
committer Marge Bot <emma+marge@anholt.net>
Fri, 28 Oct 2022 18:46:33 +0000 (18:46 +0000)
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/19353>

src/gallium/frontends/rusticl/core/kernel.rs

index 5da5ac6..11199bf 100644 (file)
@@ -386,6 +386,17 @@ fn lower_and_optimize_nir_late(
     nir: &mut NirShader,
     args: &mut [KernelArg],
 ) -> Vec<InternalKernelArg> {
+    let address_bits_base_type;
+    let address_bits_ptr_type;
+
+    if dev.address_bits() == 64 {
+        address_bits_base_type = glsl_base_type::GLSL_TYPE_UINT64;
+        address_bits_ptr_type = unsafe { glsl_uint64_t_type() };
+    } else {
+        address_bits_base_type = glsl_base_type::GLSL_TYPE_UINT;
+        address_bits_ptr_type = unsafe { glsl_uint_type() };
+    };
+
     let mut res = Vec::new();
     let nir_options = unsafe {
         &*dev
@@ -454,11 +465,12 @@ fn lower_and_optimize_nir_late(
     res.push(InternalKernelArg {
         kind: InternalKernelArgType::GlobalWorkOffsets,
         offset: 0,
-        size: 24,
+        size: (3 * dev.address_bits() / 8) as usize,
     });
+
     lower_state.base_global_invoc_id = nir.add_var(
         nir_variable_mode::nir_var_uniform,
-        unsafe { glsl_vector_type(glsl_base_type::GLSL_TYPE_UINT64, 3) },
+        unsafe { glsl_vector_type(address_bits_base_type, 3) },
         args.len() + res.len() - 1,
         "base_global_invocation_id",
     );
@@ -470,7 +482,7 @@ fn lower_and_optimize_nir_late(
         });
         lower_state.const_buf = nir.add_var(
             nir_variable_mode::nir_var_uniform,
-            unsafe { glsl_uint64_t_type() },
+            address_bits_ptr_type,
             args.len() + res.len() - 1,
             "constant_buffer_addr",
         );
@@ -483,7 +495,7 @@ fn lower_and_optimize_nir_late(
         });
         lower_state.printf_buf = nir.add_var(
             nir_variable_mode::nir_var_uniform,
-            unsafe { glsl_uint64_t_type() },
+            address_bits_ptr_type,
             args.len() + res.len() - 1,
             "printf_buffer_addr",
         );
@@ -804,6 +816,11 @@ impl Kernel {
         let mut tex_orders: Vec<u16> = Vec::new();
         let mut img_formats: Vec<u16> = Vec::new();
         let mut img_orders: Vec<u16> = Vec::new();
+        let null_ptr: &[u8] = if q.device.address_bits() == 64 {
+            &[0; 8]
+        } else {
+            &[0; 4]
+        };
 
         optimize_local_size(&q.device, &mut grid, &mut block);
 
@@ -824,7 +841,11 @@ impl Kernel {
                 KernelArgValue::MemObject(mem) => {
                     let res = mem.get_res_of_dev(&q.device)?;
                     if mem.is_buffer() {
-                        input.extend_from_slice(&mem.offset.to_ne_bytes());
+                        if q.device.address_bits() == 64 {
+                            input.extend_from_slice(&mem.offset.to_ne_bytes());
+                        } else {
+                            input.extend_from_slice(&(mem.offset as u32).to_ne_bytes());
+                        }
                         resource_info.push((Some(res.clone()), arg.offset));
                     } else {
                         let format = mem.image_format.to_pipe_format().unwrap();
@@ -852,7 +873,11 @@ impl Kernel {
                     // TODO 32 bit
                     let pot = cmp::min(*size, 0x80);
                     local_size = align(local_size, pot.next_power_of_two() as u64);
-                    input.extend_from_slice(&local_size.to_ne_bytes());
+                    if q.device.address_bits() == 64 {
+                        input.extend_from_slice(&local_size.to_ne_bytes());
+                    } else {
+                        input.extend_from_slice(&(local_size as u32).to_ne_bytes());
+                    }
                     local_size += *size as u64;
                 }
                 KernelArgValue::Sampler(sampler) => {
@@ -863,7 +888,7 @@ impl Kernel {
                         arg.kind == KernelArgType::MemGlobal
                             || arg.kind == KernelArgType::MemConstant
                     );
-                    input.extend_from_slice(&[0; 8]);
+                    input.extend_from_slice(null_ptr);
                 }
             }
         }
@@ -875,7 +900,7 @@ impl Kernel {
             }
             match arg.kind {
                 InternalKernelArgType::ConstantBuffer => {
-                    input.extend_from_slice(&[0; 8]);
+                    input.extend_from_slice(null_ptr);
                     let buf = nir.get_constant_buffer();
                     let res = Arc::new(
                         q.device
@@ -892,7 +917,15 @@ impl Kernel {
                     resource_info.push((Some(res), arg.offset));
                 }
                 InternalKernelArgType::GlobalWorkOffsets => {
-                    input.extend_from_slice(&cl_prop::<[u64; 3]>(offsets));
+                    if q.device.address_bits() == 64 {
+                        input.extend_from_slice(&cl_prop::<[u64; 3]>(offsets));
+                    } else {
+                        input.extend_from_slice(&cl_prop::<[u32; 3]>([
+                            offsets[0] as u32,
+                            offsets[1] as u32,
+                            offsets[2] as u32,
+                        ]));
+                    }
                 }
                 InternalKernelArgType::PrintfBuffer => {
                     let buf = Arc::new(
@@ -902,7 +935,7 @@ impl Kernel {
                             .unwrap(),
                     );
 
-                    input.extend_from_slice(&[0; 8]);
+                    input.extend_from_slice(null_ptr);
                     resource_info.push((Some(buf.clone()), arg.offset));
 
                     printf_buf = Some(buf);