intel/rt: Add support for scratch in ray-tracing shaders
authorJason Ekstrand <jason@jlekstrand.net>
Thu, 6 Aug 2020 18:16:53 +0000 (13:16 -0500)
committerMarge Bot <eric+marge@anholt.net>
Wed, 25 Nov 2020 05:37:10 +0000 (05:37 +0000)
In ray-tracing shader stages, we have a real call stack and so we can't
use the normal scratch mechanism.  Instead, the invocation's stack lives
in a memory region of the RT scratch buffer that sits after the HW ray
stacks.  We handle this by asking nir_lower_io to lower local variables
to 64-bit global memory access.  Unlike nir_lower_io for 32-bit offset
scratch, when 64-bit global access is requested, nir_lower_io generates
an address calculation which starts from a load_scratch_base_ptr.  We
then lower this intrinsic to the appropriate address calculation in
brw_nir_lower_rt_intrinsics.

When a COMPUTE_WALKER command is sent to the hardware with the BTD Mode
bit set to true, the hardware generates a set of stack IDs, one for each
invocation.  These then get passed along from one shader invocation to
the next as we trace the ray.  We can use those stack IDs to figure out
which stack our invocation needs to access.  Because we may not be the
first shader in the stack, there's a per-stack offset that gets stored
in the "hotzone".

Reviewed-by: Caio Marcelo de Oliveira Filho <caio.oliveira@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/7356>

src/intel/compiler/brw_nir_lower_rt_intrinsics.c
src/intel/compiler/brw_nir_rt.c
src/intel/compiler/brw_nir_rt.h
src/intel/compiler/brw_nir_rt_builder.h

index 0b06f9e..0f197d9 100644 (file)
@@ -37,6 +37,14 @@ lower_rt_intrinsics_impl(nir_function_impl *impl,
    struct brw_nir_rt_globals_defs globals;
    brw_nir_rt_load_globals(b, &globals);
 
+   nir_ssa_def *hotzone_addr = brw_nir_rt_sw_hotzone_addr(b, devinfo);
+   nir_ssa_def *hotzone = nir_load_global(b, hotzone_addr, 16, 4, 32);
+
+   nir_ssa_def *thread_stack_base_addr = brw_nir_rt_sw_stack_addr(b, devinfo);
+   nir_ssa_def *stack_base_offset = nir_channel(b, hotzone, 0);
+   nir_ssa_def *stack_base_addr =
+      nir_iadd(b, thread_stack_base_addr, nir_u2u64(b, stack_base_offset));
+
    nir_foreach_block(block, impl) {
       nir_foreach_instr_safe(instr, block) {
          if (instr->type != nir_instr_type_intrinsic)
@@ -48,6 +56,11 @@ lower_rt_intrinsics_impl(nir_function_impl *impl,
 
          nir_ssa_def *sysval = NULL;
          switch (intrin->intrinsic) {
+         case nir_intrinsic_load_scratch_base_ptr:
+            assert(nir_intrinsic_base(intrin) == 1);
+            sysval = stack_base_addr;
+            break;
+
          case nir_intrinsic_load_ray_base_mem_addr_intel:
             sysval = globals.base_mem_addr;
             break;
index edbe572..e22cf37 100644 (file)
  */
 
 #include "brw_nir_rt.h"
+#include "nir_builder.h"
+
+static bool
+resize_deref(nir_builder *b, nir_deref_instr *deref,
+             unsigned num_components, unsigned bit_size)
+{
+   assert(deref->dest.is_ssa);
+   if (deref->dest.ssa.num_components == num_components &&
+       deref->dest.ssa.bit_size == bit_size)
+      return false;
+
+   /* NIR requires array indices have to match the deref bit size */
+   if (deref->dest.ssa.bit_size != bit_size &&
+       (deref->deref_type == nir_deref_type_array ||
+        deref->deref_type == nir_deref_type_ptr_as_array)) {
+      b->cursor = nir_before_instr(&deref->instr);
+      assert(deref->arr.index.is_ssa);
+      nir_ssa_def *idx;
+      if (nir_src_is_const(deref->arr.index)) {
+         idx = nir_imm_intN_t(b, nir_src_as_int(deref->arr.index), bit_size);
+      } else {
+         idx = nir_i2i(b, deref->arr.index.ssa, bit_size);
+      }
+      nir_instr_rewrite_src(&deref->instr, &deref->arr.index,
+                            nir_src_for_ssa(idx));
+   }
+
+   deref->dest.ssa.num_components = num_components;
+   deref->dest.ssa.bit_size = bit_size;
+
+   return true;
+}
+
+static bool
+resize_function_temp_derefs(nir_shader *shader)
+{
+   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
+
+   bool progress = false;
+
+   nir_builder b;
+   nir_builder_init(&b, impl);
+
+   nir_foreach_block(block, impl) {
+      nir_foreach_instr_safe(instr, block) {
+         if (instr->type != nir_instr_type_deref)
+            continue;
+
+         nir_deref_instr *deref = nir_instr_as_deref(instr);
+
+         /* We're going to lower all function_temp memory to scratch using
+          * 64-bit addresses.  We need to resize all our derefs first or else
+          * nir_lower_explicit_io will have a fit.
+          */
+         if (nir_deref_mode_is(deref, nir_var_function_temp) &&
+             resize_deref(&b, deref, 1, 64))
+            progress = true;
+      }
+   }
+
+   if (progress) {
+      nir_metadata_preserve(impl, nir_metadata_block_index |
+                                  nir_metadata_dominance);
+   } else {
+      nir_metadata_preserve(impl, nir_metadata_all);
+   }
+
+   return progress;
+}
+
+static void
+lower_rt_scratch(nir_shader *nir)
+{
+   /* First, we to ensure all the local variables have explicit types. */
+   NIR_PASS_V(nir, nir_lower_vars_to_explicit_types,
+              nir_var_function_temp,
+              glsl_get_natural_size_align_bytes);
+
+   NIR_PASS_V(nir, resize_function_temp_derefs);
+
+   /* Now, lower those variables to 64-bit global memory access */
+   NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_function_temp,
+              nir_address_format_64bit_global);
+}
 
 void
 brw_nir_lower_raygen(nir_shader *nir)
 {
    assert(nir->info.stage == MESA_SHADER_RAYGEN);
+   lower_rt_scratch(nir);
 }
 
 void
 brw_nir_lower_any_hit(nir_shader *nir, const struct gen_device_info *devinfo)
 {
    assert(nir->info.stage == MESA_SHADER_ANY_HIT);
+   lower_rt_scratch(nir);
 }
 
 void
 brw_nir_lower_closest_hit(nir_shader *nir)
 {
    assert(nir->info.stage == MESA_SHADER_CLOSEST_HIT);
+   lower_rt_scratch(nir);
 }
 
 void
 brw_nir_lower_miss(nir_shader *nir)
 {
    assert(nir->info.stage == MESA_SHADER_MISS);
+   lower_rt_scratch(nir);
 }
 
 void
 brw_nir_lower_callable(nir_shader *nir)
 {
    assert(nir->info.stage == MESA_SHADER_CALLABLE);
+   lower_rt_scratch(nir);
 }
 
 void
@@ -60,4 +149,5 @@ brw_nir_lower_combined_intersection_any_hit(nir_shader *intersection,
 {
    assert(intersection->info.stage == MESA_SHADER_INTERSECTION);
    assert(any_hit == NULL || any_hit->info.stage == MESA_SHADER_ANY_HIT);
+   lower_rt_scratch(intersection);
 }
index 8bb1b60..2423ee1 100644 (file)
@@ -41,6 +41,9 @@ void brw_nir_lower_combined_intersection_any_hit(nir_shader *intersection,
                                                  const nir_shader *any_hit,
                                                  const struct gen_device_info *devinfo);
 
+/* We require the stack to be 8B aligned at the start of a shader */
+#define BRW_BTD_STACK_ALIGN 8
+
 void brw_nir_lower_rt_intrinsics(nir_shader *shader,
                                  const struct gen_device_info *devinfo);
 
index 14c1cd2..f5c2c42 100644 (file)
@@ -41,6 +41,30 @@ nir_load_global_const_block_intel(nir_builder *b, nir_ssa_def *addr,
    return &load->dest.ssa;
 }
 
+/* We have our own load/store scratch helpers because they emit a global
+ * memory read or write based on the scratch_base_ptr system value rather
+ * than a load/store_scratch intrinsic.
+ */
+static inline nir_ssa_def *
+brw_nir_rt_load_scratch(nir_builder *b, uint32_t offset, unsigned align,
+                        unsigned num_components, unsigned bit_size)
+{
+   nir_ssa_def *addr =
+      nir_iadd_imm(b, nir_load_scratch_base_ptr(b, 1, 1, 64), offset);
+   return nir_load_global(b, addr, MIN2(align, BRW_BTD_STACK_ALIGN),
+                          num_components, bit_size);
+}
+
+static inline void
+brw_nir_rt_store_scratch(nir_builder *b, uint32_t offset, unsigned align,
+                         nir_ssa_def *value, nir_component_mask_t write_mask)
+{
+   nir_ssa_def *addr =
+      nir_iadd_imm(b, nir_load_scratch_base_ptr(b, 1, 1, 64), offset);
+   nir_store_global(b, addr, MIN2(align, BRW_BTD_STACK_ALIGN),
+                    value, write_mask);
+}
+
 static inline void
 assert_def_size(nir_ssa_def *def, unsigned num_components, unsigned bit_size)
 {