rusticl/kernel: fix local buffers
authorKarol Herbst <kherbst@redhat.com>
Wed, 4 May 2022 17:43:26 +0000 (19:43 +0200)
committerMarge Bot <emma+marge@anholt.net>
Mon, 12 Sep 2022 05:58:13 +0000 (05:58 +0000)
Signed-off-by: Karol Herbst <kherbst@redhat.com>
Acked-by: Alyssa Rosenzweig <alyssa.rosenzweig@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/15439>

src/gallium/frontends/rusticl/core/kernel.rs
src/gallium/frontends/rusticl/util/math.rs

index 2e693cb..1e36bd7 100644 (file)
@@ -749,7 +749,7 @@ impl Kernel {
         let offsets = create_kernel_arr::<u64>(offsets, 0);
         let mut input: Vec<u8> = Vec::new();
         let mut resource_info = Vec::new();
-        let mut local_size: u32 = nir.shared_size();
+        let mut local_size: u64 = nir.shared_size() as u64;
         let printf_size = q.device.printf_buffer_size() as u32;
         let mut samplers = Vec::new();
         let mut iviews = Vec::new();
@@ -804,8 +804,10 @@ impl Kernel {
                 }
                 KernelArgValue::LocalMem(size) => {
                     // TODO 32 bit
-                    input.extend_from_slice(&[0; 8]);
-                    local_size += *size as u32;
+                    let pot = cmp::min(*size, 0x80);
+                    local_size = align(local_size, pot.next_power_of_two() as u64);
+                    input.extend_from_slice(&local_size.to_ne_bytes());
+                    local_size += *size as u64;
                 }
                 KernelArgValue::Sampler(sampler) => {
                     samplers.push(sampler.pipe());
@@ -900,7 +902,7 @@ impl Kernel {
                     init_data.len() as u32,
                 );
             }
-            let cso = ctx.create_compute_state(nir, input.len() as u32, local_size);
+            let cso = ctx.create_compute_state(nir, input.len() as u32, local_size as u32);
 
             ctx.bind_compute_state(cso);
             ctx.bind_sampler_states(&samplers);
index 08ff883..f6e8d96 100644 (file)
@@ -1,4 +1,6 @@
+use std::ops::Add;
 use std::ops::Rem;
+use std::ops::Sub;
 
 pub fn gcd<T>(mut a: T, mut b: T) -> T
 where
@@ -14,3 +16,20 @@ where
 
     b
 }
+
+pub fn align<T>(val: T, a: T) -> T
+where
+    T: Add<Output = T>,
+    T: Copy,
+    T: Default,
+    T: PartialEq,
+    T: Rem<Output = T>,
+    T: Sub<Output = T>,
+{
+    let tmp = val % a;
+    if tmp == T::default() {
+        val
+    } else {
+        val + (a - tmp)
+    }
+}