intel/rt: Implement support for shader call payloads
authorJason Ekstrand <jason@jlekstrand.net>
Fri, 4 Sep 2020 01:20:22 +0000 (20:20 -0500)
committerMarge Bot <eric+marge@anholt.net>
Wed, 25 Nov 2020 05:37:10 +0000 (05:37 +0000)
Both traceRay() and executeCallable() take a payload parameter which
gets passed from the caller to the callee and which the callee can write
to pass data back to the caller.  We implement these by passing a
pointer to the data structure in the callee to the caller as the second
QWord on its stack.  Coming out of spirv_to_nir, the incoming call
payloads get the nir_var_shader_call_data variable mode allowing us to
easily identify them.  Outgoing call payloads get assigned the
nir_var_shader_temp mode and will have been turned into function_temp by
nir_lower_global_vars_to_local.  All we have to do is crawl the shader
looking for references to the nir_var_shader_call_data variable and
rewrite those to use the passed in pointer.  nir_lower_explicit_io will
do the rest for us.

Reviewed-by: Caio Marcelo de Oliveira Filho <caio.oliveira@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/7356>

src/intel/compiler/brw_nir_lower_shader_calls.c
src/intel/compiler/brw_nir_rt.c
src/intel/compiler/brw_nir_rt.h

index c4f67a6..8c6fd84 100644 (file)
@@ -507,14 +507,22 @@ spill_ssa_defs_and_lower_shader_calls(nir_shader *shader, uint32_t num_calls,
          offset = ALIGN(offset, BRW_BTD_STACK_ALIGN);
          max_scratch_size = MAX2(max_scratch_size, offset);
 
-         /* First thing on the called shader's stack is the resume address */
+         /* First thing on the called shader's stack is the resume address
+          * followed by a pointer to the payload.
+          */
          nir_intrinsic_instr *call = nir_instr_as_intrinsic(instr);
          nir_ssa_def *resume_record_addr =
             nir_iadd_imm(b, nir_load_btd_resume_sbt_addr_intel(b),
                          (first_resume_sbt_idx + call_idx) *
                          BRW_BTD_RESUME_SBT_STRIDE);
+         /* By the time we get here, any remaining shader/function memory
+          * pointers have been lowered to SSA values.
+          */
+         assert(nir_get_shader_call_payload_src(call)->is_ssa);
+         nir_ssa_def *payload_addr =
+            nir_get_shader_call_payload_src(call)->ssa;
          brw_nir_rt_store_scratch(b, offset, BRW_BTD_STACK_ALIGN,
-                                  resume_record_addr,
+                                  nir_vec2(b, resume_record_addr, payload_addr),
                                   0xf /* write_mask */);
 
          nir_btd_stack_push_intel(b, offset);
index 12a2d09..6c9fd3a 100644 (file)
@@ -22,7 +22,7 @@
  */
 
 #include "brw_nir_rt.h"
-#include "nir_builder.h"
+#include "brw_nir_rt_builder.h"
 
 static bool
 resize_deref(nir_builder *b, nir_deref_instr *deref,
@@ -56,21 +56,57 @@ resize_deref(nir_builder *b, nir_deref_instr *deref,
 }
 
 static bool
-resize_function_temp_derefs(nir_shader *shader)
+lower_rt_io_derefs(nir_shader *shader)
 {
    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
 
    bool progress = false;
 
+   unsigned num_shader_call_vars = 0;
+   nir_foreach_variable_with_modes(var, shader, nir_var_shader_call_data)
+      num_shader_call_vars++;
+
+   /* At most one payload is allowed because it's an input.  Technically, this
+    * is also true for hit attribute variables.  However, after we inline an
+    * any-hit shader into an intersection shader, we can end up with multiple
+    * hit attribute variables.  They'll end up mapping to a cast from the same
+    * base pointer so this is fine.
+    */
+   assert(num_shader_call_vars <= 1);
+
    nir_builder b;
    nir_builder_init(&b, impl);
 
+   b.cursor = nir_before_cf_list(&impl->body);
+   nir_ssa_def *call_data_addr = NULL;
+   if (num_shader_call_vars > 0) {
+      assert(shader->scratch_size >= BRW_BTD_STACK_CALLEE_DATA_SIZE);
+      call_data_addr =
+         brw_nir_rt_load_scratch(&b, BRW_BTD_STACK_CALL_DATA_PTR_OFFSET, 8,
+                                 1, 64);
+      progress = true;
+   }
+
    nir_foreach_block(block, impl) {
       nir_foreach_instr_safe(instr, block) {
          if (instr->type != nir_instr_type_deref)
             continue;
 
          nir_deref_instr *deref = nir_instr_as_deref(instr);
+         if (nir_deref_mode_is(deref, nir_var_shader_call_data)) {
+            deref->modes = nir_var_function_temp;
+            if (deref->deref_type == nir_deref_type_var) {
+               b.cursor = nir_before_instr(&deref->instr);
+               nir_deref_instr *cast =
+                  nir_build_deref_cast(&b, call_data_addr,
+                                       nir_var_function_temp,
+                                       deref->var->type, 0);
+               nir_ssa_def_rewrite_uses(&deref->dest.ssa,
+                                        nir_src_for_ssa(&cast->dest.ssa));
+               nir_instr_remove(&deref->instr);
+               progress = true;
+            }
+         }
 
          /* We're going to lower all function_temp memory to scratch using
           * 64-bit addresses.  We need to resize all our derefs first or else
@@ -92,17 +128,59 @@ resize_function_temp_derefs(nir_shader *shader)
    return progress;
 }
 
+/** Lowers ray-tracing shader I/O and scratch access
+ *
+ * SPV_KHR_ray_tracing adds three new types of I/O, each of which need their
+ * own bit of special care:
+ *
+ *  - Shader payload data:  This is represented by the IncomingCallableData
+ *    and IncomingRayPayload storage classes which are both represented by
+ *    nir_var_call_data in NIR.  There is at most one of these per-shader and
+ *    they contain payload data passed down the stack from the parent shader
+ *    when it calls executeCallable() or traceRay().  In our implementation,
+ *    the actual storage lives in the calling shader's scratch space and we're
+ *    passed a pointer to it.
+ *
+ *  - Hit attribute data:  This is represented by the HitAttribute storage
+ *    class in SPIR-V and nir_var_ray_hit_attrib in NIR.  For triangle
+ *    geometry, it's supposed to contain two floats which are the barycentric
+ *    coordinates.  For AABS/procedural geometry, it contains the hit data
+ *    written out by the intersection shader.  In our implementation, it's a
+ *    64-bit pointer which points either to the u/v area of the relevant
+ *    MemHit data structure or the space right after the HW ray stack entry.
+ *
+ *  - Shader record buffer data:  This allows read-only access to the data
+ *    stored in the SBT right after the bindless shader handles.  It's
+ *    effectively a UBO with a magic address.  Coming out of spirv_to_nir,
+ *    we get a nir_intrinsic_load_shader_record_ptr which is cast to a
+ *    nir_var_mem_global deref and all access happens through that.  The
+ *    shader_record_ptr system value is handled in brw_nir_lower_rt_intrinsics
+ *    and we assume nir_lower_explicit_io is called elsewhere thanks to
+ *    VK_KHR_buffer_device_address so there's really nothing to do here.
+ *
+ * We also handle lowering any remaining function_temp variables to scratch at
+ * this point.  This gets rid of any remaining arrays and also takes care of
+ * the sending side of ray payloads where we pass pointers to a function_temp
+ * variable down the call stack.
+ */
 static void
-lower_rt_scratch(nir_shader *nir)
+lower_rt_io_and_scratch(nir_shader *nir)
 {
-   /* First, we to ensure all the local variables have explicit types. */
+   /* First, we to ensure all the I/O variables have explicit types.  Because
+    * these are shader-internal and don't come in from outside, they don't
+    * have an explicit memory layout and we have to assign them one.
+    */
    NIR_PASS_V(nir, nir_lower_vars_to_explicit_types,
-              nir_var_function_temp,
+              nir_var_function_temp |
+              nir_var_shader_call_data,
               glsl_get_natural_size_align_bytes);
 
-   NIR_PASS_V(nir, resize_function_temp_derefs);
+   /* Now patch any derefs to I/O vars */
+   NIR_PASS_V(nir, lower_rt_io_derefs);
 
-   /* Now, lower those variables to 64-bit global memory access */
+   /* Finally, lower any remaining function_temp access to 64-bit global
+    * memory access.
+    */
    NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_function_temp,
               nir_address_format_64bit_global);
 }
@@ -112,7 +190,7 @@ brw_nir_lower_raygen(nir_shader *nir)
 {
    assert(nir->info.stage == MESA_SHADER_RAYGEN);
    NIR_PASS_V(nir, brw_nir_lower_shader_returns);
-   lower_rt_scratch(nir);
+   lower_rt_io_and_scratch(nir);
 }
 
 void
@@ -120,7 +198,7 @@ brw_nir_lower_any_hit(nir_shader *nir, const struct gen_device_info *devinfo)
 {
    assert(nir->info.stage == MESA_SHADER_ANY_HIT);
    NIR_PASS_V(nir, brw_nir_lower_shader_returns);
-   lower_rt_scratch(nir);
+   lower_rt_io_and_scratch(nir);
 }
 
 void
@@ -128,7 +206,7 @@ brw_nir_lower_closest_hit(nir_shader *nir)
 {
    assert(nir->info.stage == MESA_SHADER_CLOSEST_HIT);
    NIR_PASS_V(nir, brw_nir_lower_shader_returns);
-   lower_rt_scratch(nir);
+   lower_rt_io_and_scratch(nir);
 }
 
 void
@@ -136,7 +214,7 @@ brw_nir_lower_miss(nir_shader *nir)
 {
    assert(nir->info.stage == MESA_SHADER_MISS);
    NIR_PASS_V(nir, brw_nir_lower_shader_returns);
-   lower_rt_scratch(nir);
+   lower_rt_io_and_scratch(nir);
 }
 
 void
@@ -144,7 +222,7 @@ brw_nir_lower_callable(nir_shader *nir)
 {
    assert(nir->info.stage == MESA_SHADER_CALLABLE);
    NIR_PASS_V(nir, brw_nir_lower_shader_returns);
-   lower_rt_scratch(nir);
+   lower_rt_io_and_scratch(nir);
 }
 
 void
@@ -155,5 +233,5 @@ brw_nir_lower_combined_intersection_any_hit(nir_shader *intersection,
    assert(intersection->info.stage == MESA_SHADER_INTERSECTION);
    assert(any_hit == NULL || any_hit->info.stage == MESA_SHADER_ANY_HIT);
    NIR_PASS_V(intersection, brw_nir_lower_shader_returns);
-   lower_rt_scratch(intersection);
+   lower_rt_io_and_scratch(intersection);
 }
index 8876318..c391301 100644 (file)
@@ -41,9 +41,10 @@ void brw_nir_lower_combined_intersection_any_hit(nir_shader *intersection,
                                                  const nir_shader *any_hit,
                                                  const struct gen_device_info *devinfo);
 
-/* We reserve the first 8B of the stack for callee data pointers */
+/* We reserve the first 16B of the stack for callee data pointers */
 #define BRW_BTD_STACK_RESUME_BSR_ADDR_OFFSET 0
-#define BRW_BTD_STACK_CALLEE_DATA_SIZE 8
+#define BRW_BTD_STACK_CALL_DATA_PTR_OFFSET 8
+#define BRW_BTD_STACK_CALLEE_DATA_SIZE 16
 
 /* We require the stack to be 8B aligned at the start of a shader */
 #define BRW_BTD_STACK_ALIGN 8