rusticl/kernel: optimize local size
authorKarol Herbst <kherbst@redhat.com>
Wed, 20 Apr 2022 13:27:57 +0000 (15:27 +0200)
committerMarge Bot <emma+marge@anholt.net>
Mon, 12 Sep 2022 05:58:13 +0000 (05:58 +0000)
Signed-off-by: Karol Herbst <kherbst@redhat.com>
Acked-by: Alyssa Rosenzweig <alyssa.rosenzweig@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/15439>

src/gallium/frontends/rusticl/core/kernel.rs
src/gallium/frontends/rusticl/util/lib.rs
src/gallium/frontends/rusticl/util/math.rs [new file with mode: 0644]

index 867d562..14d85e9 100644 (file)
@@ -11,10 +11,12 @@ use crate::impl_cl_type_trait;
 use mesa_rust::compiler::clc::*;
 use mesa_rust::compiler::nir::*;
 use mesa_rust_gen::*;
+use mesa_rust_util::math::*;
 use mesa_rust_util::serialize::*;
 use rusticl_opencl_gen::*;
 
 use std::cell::RefCell;
+use std::cmp;
 use std::collections::HashMap;
 use std::collections::HashSet;
 use std::convert::TryInto;
@@ -656,6 +658,44 @@ fn extract<'a, const S: usize>(buf: &'a mut &[u8]) -> &'a [u8; S] {
     val.try_into().unwrap()
 }
 
+fn optimize_local_size(d: &Device, grid: &mut [u32; 3], block: &mut [u32; 3]) {
+    let mut threads = d.max_threads_per_block() as u32;
+    let dim_threads = d.max_block_sizes();
+    let subgroups = d.subgroups();
+
+    if !block.contains(&0) {
+        for i in 0..3 {
+            // we already made sure everything is fine
+            grid[i] /= block[i];
+        }
+        return;
+    }
+
+    for i in 0..3 {
+        let t = cmp::min(threads, dim_threads[i] as u32);
+        let gcd = gcd(t, grid[i]);
+
+        block[i] = gcd;
+        grid[i] /= gcd;
+
+        // update limits
+        threads /= block[i];
+    }
+
+    // if we didn't fill the subgroup we can do a bit better if we have threads remaining
+    let total_threads = block[0] * block[1] * block[2];
+    if threads != 1 && total_threads < subgroups {
+        for i in 0..3 {
+            if grid[i] * total_threads < threads {
+                block[i] *= grid[i];
+                grid[i] = 1;
+                // can only do it once as nothing is cleanly divisible
+                break;
+            }
+        }
+    }
+}
+
 impl Kernel {
     pub fn new(name: String, prog: Arc<Program>, args: Vec<spirv::SPIRVKernelArg>) -> Arc<Kernel> {
         let (mut nirs, args, internal_args) = convert_spirv_to_nir(&prog, &name, args);
@@ -706,13 +746,7 @@ impl Kernel {
         let mut img_formats: Vec<u16> = Vec::new();
         let mut img_orders: Vec<u16> = Vec::new();
 
-        for i in 0..3 {
-            if block[i] == 0 {
-                block[i] = 1;
-            } else {
-                grid[i] /= block[i];
-            }
-        }
+        optimize_local_size(&q.device, &mut grid, &mut block);
 
         for (arg, val) in self.args.iter().zip(&self.values) {
             if arg.dead {
index d735982..15ea272 100644 (file)
@@ -1,4 +1,5 @@
 pub mod assert;
+pub mod math;
 pub mod properties;
 pub mod ptr;
 pub mod serialize;
diff --git a/src/gallium/frontends/rusticl/util/math.rs b/src/gallium/frontends/rusticl/util/math.rs
new file mode 100644 (file)
index 0000000..08ff883
--- /dev/null
@@ -0,0 +1,16 @@
+use std::ops::Rem;
+
+pub fn gcd<T>(mut a: T, mut b: T) -> T
+where
+    T: Copy + Default + PartialEq,
+    T: Rem<Output = T>,
+{
+    let mut c = a % b;
+    while c != T::default() {
+        a = b;
+        b = c;
+        c = a % b;
+    }
+
+    b
+}