d3d12: Run point sprite lowering pass on multi-stream GS when safe
authorJesse Natalie <jenatali@microsoft.com>
Fri, 14 Jan 2022 21:43:15 +0000 (13:43 -0800)
committerMarge Bot <emma+marge@anholt.net>
Fri, 21 Jan 2022 23:08:26 +0000 (23:08 +0000)
In the case of a multi-stream GS that is attempting to output wide
points to stream 0, we can support this by lowering stream 0 to
triangles and then removing the other streams. This is only valid
to do if the other streams are not being written to stream output,
either if they're not present in the SO info or no buffer is bound.

Fixes the arb_gpu_shader5/arb_gpu_shader5-emitstreamvertex_nodraw
piglit test which does this.

Reviewed-by: Sil Vilerino <sivileri@microsoft.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/14624>

src/gallium/drivers/d3d12/d3d12_compiler.cpp
src/gallium/drivers/d3d12/d3d12_lower_point_sprite.c

index 9c6c68b..134d268 100644 (file)
@@ -375,6 +375,19 @@ fill_mode_lowered(struct d3d12_context *ctx, const struct pipe_draw_info *dinfo)
 }
 
 static bool
+has_stream_out_for_streams(struct d3d12_context *ctx)
+{
+   unsigned mask = ctx->gfx_stages[PIPE_SHADER_GEOMETRY]->initial->info.gs.active_stream_mask & ~1;
+   for (unsigned i = 0; i < ctx->gfx_pipeline_state.so_info.num_outputs; ++i) {
+      unsigned stream = ctx->gfx_pipeline_state.so_info.output[i].stream;
+      if (((1 << stream) & mask) &&
+         ctx->so_buffer_views[stream].SizeInBytes)
+         return true;
+   }
+   return false;
+}
+
+static bool
 needs_point_sprite_lowering(struct d3d12_context *ctx, const struct pipe_draw_info *dinfo)
 {
    struct d3d12_shader_selector *vs = ctx->gfx_stages[PIPE_SHADER_VERTEX];
@@ -384,7 +397,9 @@ needs_point_sprite_lowering(struct d3d12_context *ctx, const struct pipe_draw_in
       /* There is an user GS; Check if it outputs points with PSIZE */
       return (gs->initial->info.gs.output_primitive == GL_POINTS &&
               (gs->initial->info.outputs_written & VARYING_BIT_PSIZ ||
-                 ctx->gfx_pipeline_state.rast->base.point_size > 1.0));
+                 ctx->gfx_pipeline_state.rast->base.point_size > 1.0) &&
+              (gs->initial->info.gs.active_stream_mask == 1 ||
+                 !has_stream_out_for_streams(ctx)));
    } else {
       /* No user GS; check if we are drawing wide points */
       return ((dinfo->mode == PIPE_PRIM_POINTS ||
index 35eec6f..1c4160b 100644 (file)
@@ -162,43 +162,44 @@ lower_emit_vertex(nir_intrinsic_instr *instr, nir_builder *b, struct lower_state
    get_scaled_point_size(b, state, &point_width, &point_height);
 
    nir_instr_remove(&instr->instr);
-
-   for (unsigned i = 0; i < 4; i++) {
-      /* All outputs need to be emitted for each vertex */
-      for (unsigned slot = 0; slot < VARYING_SLOT_MAX; ++slot) {
-         if (state->varying[slot] != NULL) {
-            nir_store_var(b, state->varying_out[slot], state->varying[slot],
-                          state->varying_write_mask[slot]);
+   if (stream_id == 0) {
+      for (unsigned i = 0; i < 4; i++) {
+         /* All outputs need to be emitted for each vertex */
+         for (unsigned slot = 0; slot < VARYING_SLOT_MAX; ++slot) {
+            if (state->varying[slot] != NULL) {
+               nir_store_var(b, state->varying_out[slot], state->varying[slot],
+                             state->varying_write_mask[slot]);
+            }
          }
+
+         /* pos = scaled_point_size * point_dir + point_pos */
+         nir_ssa_def *point_dir = get_point_dir(b, state, i);
+         nir_ssa_def *pos = nir_vec4(b,
+                                     nir_ffma(b,
+                                              point_width,
+                                              nir_channel(b, point_dir, 0),
+                                              nir_channel(b, state->point_pos, 0)),
+                                     nir_ffma(b,
+                                              point_height,
+                                              nir_channel(b, point_dir, 1),
+                                              nir_channel(b, state->point_pos, 1)),
+                                     nir_channel(b, state->point_pos, 2),
+                                     nir_channel(b, state->point_pos, 3));
+         nir_store_var(b, state->pos_out, pos, 0xf);
+
+         /* point coord */
+         nir_ssa_def *point_coord = get_point_coord(b, state, i);
+         for (unsigned j = 0; j < state->num_point_coords; ++j)
+            nir_store_var(b, state->point_coord_out[j], point_coord, 0xf);
+
+         /* EmitVertex */
+         nir_emit_vertex(b, .stream_id = stream_id);
       }
 
-      /* pos = scaled_point_size * point_dir + point_pos */
-      nir_ssa_def *point_dir = get_point_dir(b, state, i);
-      nir_ssa_def *pos = nir_vec4(b,
-                                  nir_ffma(b,
-                                           point_width,
-                                           nir_channel(b, point_dir, 0),
-                                           nir_channel(b, state->point_pos, 0)),
-                                  nir_ffma(b,
-                                           point_height,
-                                           nir_channel(b, point_dir, 1),
-                                           nir_channel(b, state->point_pos, 1)),
-                                  nir_channel(b, state->point_pos, 2),
-                                  nir_channel(b, state->point_pos, 3));
-      nir_store_var(b, state->pos_out, pos, 0xf);
-
-      /* point coord */
-      nir_ssa_def *point_coord = get_point_coord(b, state, i);
-      for (unsigned j = 0; j < state->num_point_coords; ++j)
-         nir_store_var(b, state->point_coord_out[j], point_coord, 0xf);
-
-      /* EmitVertex */
-      nir_emit_vertex(b, .stream_id = stream_id);
+      /* EndPrimitive */
+      nir_end_primitive(b, .stream_id = stream_id);
    }
 
-   /* EndPrimitive */
-   nir_end_primitive(b, .stream_id = stream_id);
-
    /* Reset everything */
    state->point_pos = NULL;
    state->point_size = NULL;
@@ -298,7 +299,9 @@ d3d12_lower_point_sprite(nir_shader *shader,
    }
 
    shader->info.gs.output_primitive = GL_TRIANGLE_STRIP;
-   shader->info.gs.vertices_out *= 4;
+   shader->info.gs.vertices_out = shader->info.gs.vertices_out * 4 /
+      util_bitcount(shader->info.gs.active_stream_mask);
+   shader->info.gs.active_stream_mask = 1;
 
    return progress;
 }