intel/rt: Add support for hit attributes
authorJason Ekstrand <jason@jlekstrand.net>
Thu, 6 Aug 2020 21:42:14 +0000 (16:42 -0500)
committerMarge Bot <eric+marge@anholt.net>
Wed, 25 Nov 2020 05:37:10 +0000 (05:37 +0000)
For triangle geometry, the hit attributes are always two floats which
contain the barycentric coordinates of the hit.  For procedural
geometry, they're an arbitrary blob of data passed from the intersection
shader to the hit shaders.  In our implementation, we stash that data
right after the HW RayQuery in the ray stack.

Reviewed-by: Caio Marcelo de Oliveira Filho <caio.oliveira@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/7356>

src/intel/compiler/brw_nir_rt.c

index 9c7f7be..c7ca1cd 100644 (file)
@@ -66,6 +66,10 @@ lower_rt_io_derefs(nir_shader *shader)
    nir_foreach_variable_with_modes(var, shader, nir_var_shader_call_data)
       num_shader_call_vars++;
 
+   unsigned num_ray_hit_attrib_vars = 0;
+   nir_foreach_variable_with_modes(var, shader, nir_var_ray_hit_attrib)
+      num_ray_hit_attrib_vars++;
+
    /* At most one payload is allowed because it's an input.  Technically, this
     * is also true for hit attribute variables.  However, after we inline an
     * any-hit shader into an intersection shader, we can end up with multiple
@@ -87,6 +91,22 @@ lower_rt_io_derefs(nir_shader *shader)
       progress = true;
    }
 
+   gl_shader_stage stage = shader->info.stage;
+   nir_ssa_def *hit_attrib_addr = NULL;
+   if (num_ray_hit_attrib_vars > 0) {
+      assert(stage == MESA_SHADER_ANY_HIT ||
+             stage == MESA_SHADER_CLOSEST_HIT ||
+             stage == MESA_SHADER_INTERSECTION);
+      nir_ssa_def *hit_addr =
+         brw_nir_rt_mem_hit_addr(&b, stage == MESA_SHADER_CLOSEST_HIT);
+      /* The vec2 barycentrics are in 2nd and 3rd dwords of MemHit */
+      nir_ssa_def *bary_addr = nir_iadd_imm(&b, hit_addr, 4);
+      hit_attrib_addr = nir_bcsel(&b, nir_load_leaf_procedural_intel(&b),
+                                      brw_nir_rt_hit_attrib_data_addr(&b),
+                                      bary_addr);
+      progress = true;
+   }
+
    nir_foreach_block(block, impl) {
       nir_foreach_instr_safe(instr, block) {
          if (instr->type != nir_instr_type_deref)
@@ -106,6 +126,19 @@ lower_rt_io_derefs(nir_shader *shader)
                nir_instr_remove(&deref->instr);
                progress = true;
             }
+         } else if (nir_deref_mode_is(deref, nir_var_ray_hit_attrib)) {
+            deref->modes = nir_var_function_temp;
+            if (deref->deref_type == nir_deref_type_var) {
+               b.cursor = nir_before_instr(&deref->instr);
+               nir_deref_instr *cast =
+                  nir_build_deref_cast(&b, hit_attrib_addr,
+                                       nir_var_function_temp,
+                                       deref->type, 0);
+               nir_ssa_def_rewrite_uses(&deref->dest.ssa,
+                                        nir_src_for_ssa(&cast->dest.ssa));
+               nir_instr_remove(&deref->instr);
+               progress = true;
+            }
          }
 
          /* We're going to lower all function_temp memory to scratch using
@@ -172,18 +205,20 @@ lower_rt_io_and_scratch(nir_shader *nir)
     */
    NIR_PASS_V(nir, nir_lower_vars_to_explicit_types,
               nir_var_function_temp |
-              nir_var_shader_call_data,
+              nir_var_shader_call_data |
+              nir_var_ray_hit_attrib,
               glsl_get_natural_size_align_bytes);
 
    /* Now patch any derefs to I/O vars */
    NIR_PASS_V(nir, lower_rt_io_derefs);
 
-   /* Finally, lower any remaining function_temp and mem_constant access to
-    * 64-bit global memory access.
+   /* Finally, lower any remaining function_temp, mem_constant, or
+    * ray_hit_attrib access to 64-bit global memory access.
     */
    NIR_PASS_V(nir, nir_lower_explicit_io,
               nir_var_function_temp |
-              nir_var_mem_constant,
+              nir_var_mem_constant |
+              nir_var_ray_hit_attrib,
               nir_address_format_64bit_global);
 }