d3d12: Implement launch_grid
author Jesse Natalie <jenatali@microsoft.com>
Fri, 31 Dec 2021 21:50:42 +0000 (13:50 -0800)
committer Marge Bot <emma+marge@anholt.net>
Tue, 11 Jan 2022 01:36:56 +0000 (01:36 +0000)
Some more refactoring in d3d12_draw.cpp to re-use a bunch of state
and descriptor management, and some refactoring of the dirty states.

Reviewed-by: Sil Vilerino <sivileri@microsoft.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/14367>

src/gallium/drivers/d3d12/d3d12_context.cpp
src/gallium/drivers/d3d12/d3d12_context.h
src/gallium/drivers/d3d12/d3d12_draw.cpp

index f3d0052..08fefc9 100644 (file)
@@ -2225,6 +2225,7 @@ d3d12_context_create(struct pipe_screen *pscreen, void *priv, unsigned flags)
    ctx->base.clear_render_target = d3d12_clear_render_target;
    ctx->base.clear_depth_stencil = d3d12_clear_depth_stencil;
    ctx->base.draw_vbo = d3d12_draw_vbo;
+   ctx->base.launch_grid = d3d12_launch_grid;
    ctx->base.flush = d3d12_flush;
    ctx->base.flush_resource = d3d12_flush_resource;
 
index 1fdbf13..b37a742 100644 (file)
@@ -62,6 +62,8 @@ enum d3d12_dirty_flags
    D3D12_DIRTY_ROOT_SIGNATURE   = (1 << 14),
    D3D12_DIRTY_STREAM_OUTPUT    = (1 << 15),
    D3D12_DIRTY_STRIP_CUT_VALUE  = (1 << 16),
+   D3D12_DIRTY_COMPUTE_SHADER   = (1 << 17),
+   D3D12_DIRTY_COMPUTE_ROOT_SIGNATURE = (1 << 18),
 };
 
 enum d3d12_shader_dirty_flags
@@ -73,11 +75,16 @@ enum d3d12_shader_dirty_flags
    D3D12_SHADER_DIRTY_IMAGE         = (1 << 4),
 };
 
-#define D3D12_DIRTY_PSO (D3D12_DIRTY_BLEND | D3D12_DIRTY_RASTERIZER | D3D12_DIRTY_ZSA | \
-                         D3D12_DIRTY_FRAMEBUFFER | D3D12_DIRTY_SAMPLE_MASK | \
-                         D3D12_DIRTY_VERTEX_ELEMENTS | D3D12_DIRTY_PRIM_MODE | \
-                         D3D12_DIRTY_SHADER | D3D12_DIRTY_ROOT_SIGNATURE | \
-                         D3D12_DIRTY_STRIP_CUT_VALUE)
+#define D3D12_DIRTY_GFX_PSO (D3D12_DIRTY_BLEND | D3D12_DIRTY_RASTERIZER | D3D12_DIRTY_ZSA | \
+                             D3D12_DIRTY_FRAMEBUFFER | D3D12_DIRTY_SAMPLE_MASK | \
+                             D3D12_DIRTY_VERTEX_ELEMENTS | D3D12_DIRTY_PRIM_MODE | \
+                             D3D12_DIRTY_SHADER | D3D12_DIRTY_ROOT_SIGNATURE | \
+                             D3D12_DIRTY_STRIP_CUT_VALUE)
+#define D3D12_DIRTY_COMPUTE_PSO (D3D12_DIRTY_COMPUTE_SHADER | D3D12_DIRTY_COMPUTE_ROOT_SIGNATURE) /* bits that make the cached compute PSO stale */
+
+#define D3D12_DIRTY_COMPUTE_MASK (D3D12_DIRTY_COMPUTE_SHADER | D3D12_DIRTY_COMPUTE_ROOT_SIGNATURE) /* dirty bits owned by the compute path */
+#define D3D12_DIRTY_GFX_MASK (~D3D12_DIRTY_COMPUTE_MASK) /* every other dirty bit; parenthesized so the expansion is a single operand */
+
 
 #define D3D12_SHADER_DIRTY_ALL (D3D12_SHADER_DIRTY_CONSTBUF | D3D12_SHADER_DIRTY_SAMPLER_VIEWS | \
                                 D3D12_SHADER_DIRTY_SAMPLERS | D3D12_SHADER_DIRTY_SSBO | \
@@ -321,6 +328,10 @@ d3d12_draw_vbo(struct pipe_context *pctx,
                unsigned num_draws);
 
 void
+d3d12_launch_grid(struct pipe_context *pctx, /* compute dispatch entry point; hooked up as ctx->base.launch_grid in d3d12_context_create */
+                  const struct pipe_grid_info *info);
+
+void
 d3d12_blit(struct pipe_context *pctx,
            const struct pipe_blit_info *info);
 
index bbf89a5..dd7bbbb 100644 (file)
@@ -338,11 +338,11 @@ fill_image_descriptors(struct d3d12_context *ctx,
 }
 
 static unsigned
-fill_state_vars(struct d3d12_context *ctx,
-                const struct pipe_draw_info *dinfo,
-                const struct pipe_draw_start_count_bias *draw,
-                struct d3d12_shader *shader,
-                uint32_t *values)
+fill_graphics_state_vars(struct d3d12_context *ctx,
+                         const struct pipe_draw_info *dinfo,
+                         const struct pipe_draw_start_count_bias *draw,
+                         struct d3d12_shader *shader,
+                         uint32_t *values)
 {
    unsigned size = 0;
 
@@ -379,13 +379,14 @@ fill_state_vars(struct d3d12_context *ctx,
 }
 
 static bool
-check_descriptors_left(struct d3d12_context *ctx)
+check_descriptors_left(struct d3d12_context *ctx, bool compute)
 {
    struct d3d12_batch *batch = d3d12_current_batch(ctx);
    unsigned needed_descs = 0;
 
-   for (unsigned i = 0; i < D3D12_GFX_SHADER_STAGES; ++i) {
-      struct d3d12_shader_selector *shader = ctx->gfx_stages[i];
+   unsigned count = compute ? 1 : D3D12_GFX_SHADER_STAGES;
+   for (unsigned i = 0; i < count; ++i) {
+      struct d3d12_shader_selector *shader = compute ? ctx->compute_state : ctx->gfx_stages[i];
 
       if (!shader)
          continue;
@@ -400,8 +401,8 @@ check_descriptors_left(struct d3d12_context *ctx)
       return false;
 
    needed_descs = 0;
-   for (unsigned i = 0; i < D3D12_GFX_SHADER_STAGES; ++i) {
-      struct d3d12_shader_selector *shader = ctx->gfx_stages[i];
+   for (unsigned i = 0; i < count; ++i) {
+      struct d3d12_shader_selector *shader = compute ? ctx->compute_state : ctx->gfx_stages[i];
 
       if (!shader)
          continue;
@@ -417,6 +418,59 @@ check_descriptors_left(struct d3d12_context *ctx)
 
 #define MAX_DESCRIPTOR_TABLES (D3D12_GFX_SHADER_STAGES * 4)
 
+static void /* Per-stage helper shared by the graphics and compute root-parameter updates. */
+update_shader_stage_root_parameters(struct d3d12_context *ctx,
+                                    const struct d3d12_shader_selector *shader_sel,
+                                    unsigned &num_params, /* in/out: running root-parameter index, advanced once per table slot this stage uses */
+                                    unsigned &num_root_descriptors, /* in/out: number of entries written to the two arrays below */
+                                    D3D12_GPU_DESCRIPTOR_HANDLE root_desc_tables[MAX_DESCRIPTOR_TABLES], /* out: handles for the dirty resource categories */
+                                    int root_desc_indices[MAX_DESCRIPTOR_TABLES]) /* out: root-parameter index paired with each handle */
+{
+   auto stage = shader_sel->stage;
+   struct d3d12_shader *shader = shader_sel->current;
+   uint64_t dirty = ctx->shader_dirty[stage]; /* D3D12_SHADER_DIRTY_* bits for this stage */
+   assert(shader);
+
+   if (shader->num_cb_bindings > 0) {
+      if (dirty & D3D12_SHADER_DIRTY_CONSTBUF) {
+         assert(num_root_descriptors < MAX_DESCRIPTOR_TABLES);
+         root_desc_tables[num_root_descriptors] = fill_cbv_descriptors(ctx, shader, stage);
+         root_desc_indices[num_root_descriptors++] = num_params;
+      }
+      num_params++; /* the slot is counted even when the table is clean and not refilled */
+   }
+   if (shader->end_srv_binding > 0) {
+      if (dirty & D3D12_SHADER_DIRTY_SAMPLER_VIEWS) {
+         assert(num_root_descriptors < MAX_DESCRIPTOR_TABLES);
+         root_desc_tables[num_root_descriptors] = fill_srv_descriptors(ctx, shader, stage);
+         root_desc_indices[num_root_descriptors++] = num_params;
+      }
+      num_params++;
+      if (dirty & D3D12_SHADER_DIRTY_SAMPLERS) { /* samplers take the slot right after the SRVs, under the same end_srv_binding check */
+         assert(num_root_descriptors < MAX_DESCRIPTOR_TABLES);
+         root_desc_tables[num_root_descriptors] = fill_sampler_descriptors(ctx, shader_sel, stage);
+         root_desc_indices[num_root_descriptors++] = num_params;
+      }
+      num_params++;
+   }
+   if (shader->nir->info.num_ssbos > 0) {
+      if (dirty & D3D12_SHADER_DIRTY_SSBO) {
+         assert(num_root_descriptors < MAX_DESCRIPTOR_TABLES);
+         root_desc_tables[num_root_descriptors] = fill_ssbo_descriptors(ctx, shader, stage);
+         root_desc_indices[num_root_descriptors++] = num_params;
+      }
+      num_params++;
+   }
+   if (shader->nir->info.num_images > 0) {
+      if (dirty & D3D12_SHADER_DIRTY_IMAGE) {
+         assert(num_root_descriptors < MAX_DESCRIPTOR_TABLES);
+         root_desc_tables[num_root_descriptors] = fill_image_descriptors(ctx, shader, stage);
+         root_desc_indices[num_root_descriptors++] = num_params;
+      }
+      num_params++;
+   }
+}
+
 static unsigned
 update_graphics_root_parameters(struct d3d12_context *ctx,
                                 const struct pipe_draw_info *dinfo,
@@ -425,64 +479,39 @@ update_graphics_root_parameters(struct d3d12_context *ctx,
                                 int root_desc_indices[MAX_DESCRIPTOR_TABLES])
 {
    unsigned num_params = 0;
-   unsigned num_root_desciptors = 0;
+   unsigned num_root_descriptors = 0;
 
    for (unsigned i = 0; i < D3D12_GFX_SHADER_STAGES; ++i) {
-      if (!ctx->gfx_stages[i])
+      struct d3d12_shader_selector *shader_sel = ctx->gfx_stages[i];
+      if (!shader_sel)
          continue;
 
-      struct d3d12_shader_selector *shader_sel = ctx->gfx_stages[i];
-      struct d3d12_shader *shader = shader_sel->current;
-      uint64_t dirty = ctx->shader_dirty[i];
-      assert(shader);
-
-      if (shader->num_cb_bindings > 0) {
-         if (dirty & D3D12_SHADER_DIRTY_CONSTBUF) {
-            assert(num_root_desciptors < MAX_DESCRIPTOR_TABLES);
-            root_desc_tables[num_root_desciptors] = fill_cbv_descriptors(ctx, shader, i);
-            root_desc_indices[num_root_desciptors++] = num_params;
-         }
-         num_params++;
-      }
-      if (shader->end_srv_binding > 0) {
-         if (dirty & D3D12_SHADER_DIRTY_SAMPLER_VIEWS) {
-            assert(num_root_desciptors < MAX_DESCRIPTOR_TABLES);
-            root_desc_tables[num_root_desciptors] = fill_srv_descriptors(ctx, shader, i);
-            root_desc_indices[num_root_desciptors++] = num_params;
-         }
-         num_params++;
-         if (dirty & D3D12_SHADER_DIRTY_SAMPLERS) {
-            assert(num_root_desciptors < MAX_DESCRIPTOR_TABLES);
-            root_desc_tables[num_root_desciptors] = fill_sampler_descriptors(ctx, shader_sel, i);
-            root_desc_indices[num_root_desciptors++] = num_params;
-         }
-         num_params++;
-      }
-      if (shader->nir->info.num_ssbos > 0) {
-         if (dirty & D3D12_SHADER_DIRTY_SSBO) {
-            assert(num_root_desciptors < MAX_DESCRIPTOR_TABLES);
-            root_desc_tables[num_root_desciptors] = fill_ssbo_descriptors(ctx, shader, i);
-            root_desc_indices[num_root_desciptors++] = num_params;
-         }
-         num_params++;
-      }
-      if (shader->nir->info.num_images > 0) {
-         if (dirty & D3D12_SHADER_DIRTY_IMAGE) {
-            assert(num_root_desciptors < MAX_DESCRIPTOR_TABLES);
-            root_desc_tables[num_root_desciptors] = fill_image_descriptors(ctx, shader, i);
-            root_desc_indices[num_root_desciptors++] = num_params;
-         }
-         num_params++;
-      }
+      update_shader_stage_root_parameters(ctx, shader_sel, num_params, num_root_descriptors, root_desc_tables, root_desc_indices);
       /* TODO Don't always update state vars */
-      if (shader->num_state_vars > 0) {
+      if (shader_sel->current->num_state_vars > 0) {
          uint32_t constants[D3D12_MAX_STATE_VARS * 4];
-         unsigned size = fill_state_vars(ctx, dinfo, draw, shader, constants);
+         unsigned size = fill_graphics_state_vars(ctx, dinfo, draw, shader_sel->current, constants);
          ctx->cmdlist->SetGraphicsRoot32BitConstants(num_params, size, constants, 0);
          num_params++;
       }
    }
-   return num_root_desciptors;
+   return num_root_descriptors;
+}
+
+static unsigned /* Compute counterpart of update_graphics_root_parameters; returns how many table entries were written. */
+update_compute_root_parameters(struct d3d12_context *ctx,
+                               const struct pipe_grid_info *info, /* not read in this function */
+                               D3D12_GPU_DESCRIPTOR_HANDLE root_desc_tables[MAX_DESCRIPTOR_TABLES],
+                               int root_desc_indices[MAX_DESCRIPTOR_TABLES])
+{
+   unsigned num_params = 0; /* only threaded through the helper; not returned */
+   unsigned num_root_descriptors = 0;
+
+   struct d3d12_shader_selector *shader_sel = ctx->compute_state;
+   if (shader_sel) { /* with no compute shader selector bound, nothing is filled and 0 is returned */
+      update_shader_stage_root_parameters(ctx, shader_sel, num_params, num_root_descriptors, root_desc_tables, root_desc_indices);
+   }
+   return num_root_descriptors;
 }
 
 static bool
@@ -761,14 +790,14 @@ d3d12_draw_vbo(struct pipe_context *pctx,
       }
    }
 
-   if (!ctx->current_gfx_pso || ctx->state_dirty & D3D12_DIRTY_PSO) {
+   if (!ctx->current_gfx_pso || ctx->state_dirty & D3D12_DIRTY_GFX_PSO) {
       ctx->current_gfx_pso = d3d12_get_gfx_pipeline_state(ctx);
       assert(ctx->current_gfx_pso);
    }
 
    ctx->cmdlist_dirty |= ctx->state_dirty;
 
-   if (!check_descriptors_left(ctx))
+   if (!check_descriptors_left(ctx, false))
       d3d12_flush_cmdlist(ctx);
    batch = d3d12_current_batch(ctx);
 
@@ -777,7 +806,7 @@ d3d12_draw_vbo(struct pipe_context *pctx,
       ctx->cmdlist->SetGraphicsRootSignature(ctx->gfx_pipeline_state.root_signature);
    }
 
-   if (ctx->cmdlist_dirty & D3D12_DIRTY_PSO) {
+   if (ctx->cmdlist_dirty & D3D12_DIRTY_GFX_PSO) {
       assert(ctx->current_gfx_pso);
       d3d12_batch_reference_object(batch, ctx->current_gfx_pso);
       ctx->cmdlist->SetPipelineState(ctx->current_gfx_pso);
@@ -785,7 +814,7 @@ d3d12_draw_vbo(struct pipe_context *pctx,
 
    D3D12_GPU_DESCRIPTOR_HANDLE root_desc_tables[MAX_DESCRIPTOR_TABLES];
    int root_desc_indices[MAX_DESCRIPTOR_TABLES];
-   unsigned num_root_desciptors = update_graphics_root_parameters(ctx, dinfo, &draws[0], root_desc_tables, root_desc_indices);
+   unsigned num_root_descriptors = update_graphics_root_parameters(ctx, dinfo, &draws[0], root_desc_tables, root_desc_indices);
 
    bool need_zero_one_depth_range = d3d12_need_zero_one_depth_range(ctx);
    if (need_zero_one_depth_range != ctx->need_zero_one_depth_range) {
@@ -923,7 +952,7 @@ d3d12_draw_vbo(struct pipe_context *pctx,
 
    d3d12_apply_resource_states(ctx);
 
-   for (unsigned i = 0; i < num_root_desciptors; ++i)
+   for (unsigned i = 0; i < num_root_descriptors; ++i)
       ctx->cmdlist->SetGraphicsRootDescriptorTable(root_desc_indices[i], root_desc_tables[i]);
 
    if (dinfo->index_size > 0)
@@ -934,13 +963,14 @@ d3d12_draw_vbo(struct pipe_context *pctx,
       ctx->cmdlist->DrawInstanced(draws[0].count, dinfo->instance_count,
                                   draws[0].start, dinfo->start_instance);
 
-   ctx->state_dirty = 0;
+   ctx->state_dirty &= D3D12_DIRTY_COMPUTE_MASK;
    batch->pending_memory_barrier = false;
 
-   if (index_buffer)
-      ctx->cmdlist_dirty = 0;
-   else
-      ctx->cmdlist_dirty &= D3D12_DIRTY_INDEX_BUFFER;
+   ctx->cmdlist_dirty &= D3D12_DIRTY_COMPUTE_MASK |
+      (index_buffer ? 0 : D3D12_DIRTY_INDEX_BUFFER);
+
+   /* The next dispatch needs to reassert the compute PSO */
+   ctx->cmdlist_dirty |= D3D12_DIRTY_COMPUTE_SHADER;
 
    for (unsigned i = 0; i < D3D12_GFX_SHADER_STAGES; ++i)
       ctx->shader_dirty[i] = 0;
@@ -952,3 +982,69 @@ d3d12_draw_vbo(struct pipe_context *pctx,
       }
    }
 }
+
+void /* Records Dispatch(info->grid[0..2]) after syncing the compute shader, root signature, PSO and descriptor tables. */
+d3d12_launch_grid(struct pipe_context *pctx, const struct pipe_grid_info *info)
+{
+   struct d3d12_context *ctx = d3d12_context(pctx);
+   struct d3d12_batch *batch;
+
+   d3d12_select_compute_shader_variants(ctx, info);
+   d3d12_validate_queries(ctx);
+   struct d3d12_shader *shader = ctx->compute_state ? ctx->compute_state->current : NULL;
+   if (ctx->compute_pipeline_state.stage != shader) { /* shader differs from the one recorded in the pipeline state */
+      ctx->compute_pipeline_state.stage = shader;
+      ctx->state_dirty |= D3D12_DIRTY_COMPUTE_SHADER;
+   }
+
+   if (!ctx->compute_pipeline_state.root_signature || ctx->state_dirty & D3D12_DIRTY_COMPUTE_SHADER) { /* re-query when unset or when the shader changed */
+      ID3D12RootSignature *root_signature = d3d12_get_root_signature(ctx, true);
+      if (ctx->compute_pipeline_state.root_signature != root_signature) {
+         ctx->compute_pipeline_state.root_signature = root_signature;
+         ctx->state_dirty |= D3D12_DIRTY_COMPUTE_ROOT_SIGNATURE;
+         ctx->shader_dirty[PIPE_SHADER_COMPUTE] |= D3D12_SHADER_DIRTY_ALL; /* new root signature: refill every compute descriptor table */
+      }
+   }
+
+   if (!ctx->current_compute_pso || ctx->state_dirty & D3D12_DIRTY_COMPUTE_PSO) {
+      ctx->current_compute_pso = d3d12_get_compute_pipeline_state(ctx);
+      assert(ctx->current_compute_pso);
+   }
+
+   ctx->cmdlist_dirty |= ctx->state_dirty; /* carry state changes over to the command-list dirty bits tested below */
+
+   if (!check_descriptors_left(ctx, true)) /* true selects the compute shader instead of the graphics stages */
+      d3d12_flush_cmdlist(ctx);
+   batch = d3d12_current_batch(ctx); /* fetched only after the possible flush above */
+
+   if (ctx->cmdlist_dirty & D3D12_DIRTY_COMPUTE_ROOT_SIGNATURE) {
+      d3d12_batch_reference_object(batch, ctx->compute_pipeline_state.root_signature);
+      ctx->cmdlist->SetComputeRootSignature(ctx->compute_pipeline_state.root_signature);
+   }
+
+   if (ctx->cmdlist_dirty & D3D12_DIRTY_COMPUTE_PSO) {
+      assert(ctx->current_compute_pso);
+      d3d12_batch_reference_object(batch, ctx->current_compute_pso);
+      ctx->cmdlist->SetPipelineState(ctx->current_compute_pso);
+   }
+
+   D3D12_GPU_DESCRIPTOR_HANDLE root_desc_tables[MAX_DESCRIPTOR_TABLES];
+   int root_desc_indices[MAX_DESCRIPTOR_TABLES];
+   unsigned num_root_descriptors = update_compute_root_parameters(ctx, info, root_desc_tables, root_desc_indices);
+
+   d3d12_apply_resource_states(ctx);
+
+   for (unsigned i = 0; i < num_root_descriptors; ++i) /* bind only the tables that were refilled for this dispatch */
+      ctx->cmdlist->SetComputeRootDescriptorTable(root_desc_indices[i], root_desc_tables[i]);
+
+   ctx->cmdlist->Dispatch(info->grid[0], info->grid[1], info->grid[2]);
+
+   ctx->state_dirty &= D3D12_DIRTY_GFX_MASK; /* clears only the compute bits; graphics bits stay pending */
+   ctx->cmdlist_dirty &= D3D12_DIRTY_GFX_MASK;
+
+   /* The next draw needs to reassert the graphics PSO */
+   ctx->cmdlist_dirty |= D3D12_DIRTY_SHADER;
+   batch->pending_memory_barrier = false;
+
+   ctx->shader_dirty[PIPE_SHADER_COMPUTE] = 0; /* reset the per-stage dirty bits for compute */
+}