nir: Add shader enqueue data structures and handling
authorKonstantin Seurer <konstantin.seurer@gmail.com>
Tue, 1 Aug 2023 12:35:21 +0000 (14:35 +0200)
committerMarge Bot <emma+marge@anholt.net>
Fri, 18 Aug 2023 16:57:22 +0000 (16:57 +0000)
There are two new variable modes:
- nir_var_mem_node_payload
- nir_var_mem_node_payload_in

Also add a few more intrinsics and some shader info.

Reviewed-by: Mike Blumenkrantz <michael.blumenkrantz@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/24512>

src/compiler/nir/nir.c
src/compiler/nir/nir.h
src/compiler/nir/nir_intrinsics.py
src/compiler/nir/nir_lower_io.c
src/compiler/nir/nir_lower_system_values.c
src/compiler/nir/nir_print.c
src/compiler/nir/nir_validate.c
src/compiler/shader_enums.c
src/compiler/shader_enums.h
src/compiler/shader_info.h

index 4a10e2f..4e2b0b2 100644 (file)
@@ -245,6 +245,8 @@ nir_shader_add_variable(nir_shader *shader, nir_variable *var)
    case nir_var_shader_call_data:
    case nir_var_ray_hit_attrib:
    case nir_var_mem_task_payload:
+   case nir_var_mem_node_payload:
+   case nir_var_mem_node_payload_in:
    case nir_var_mem_global:
       break;
 
@@ -2265,6 +2267,10 @@ nir_intrinsic_from_system_value(gl_system_value val)
       return nir_intrinsic_load_frag_size;
    case SYSTEM_VALUE_FRAG_INVOCATION_COUNT:
       return nir_intrinsic_load_frag_invocation_count;
+   case SYSTEM_VALUE_SHADER_INDEX:
+      return nir_intrinsic_load_shader_index;
+   case SYSTEM_VALUE_COALESCED_INPUT_COUNT:
+      return nir_intrinsic_load_coalesced_input_count;
    default:
       unreachable("system value does not directly correspond to intrinsic");
    }
@@ -2421,6 +2427,10 @@ nir_system_value_from_intrinsic(nir_intrinsic_op intrin)
       return SYSTEM_VALUE_FRAG_SIZE;
    case nir_intrinsic_load_frag_invocation_count:
       return SYSTEM_VALUE_FRAG_INVOCATION_COUNT;
+   case nir_intrinsic_load_shader_index:
+      return SYSTEM_VALUE_SHADER_INDEX;
+   case nir_intrinsic_load_coalesced_input_count:
+      return SYSTEM_VALUE_COALESCED_INPUT_COUNT;
    default:
       unreachable("intrinsic doesn't produce a system value");
    }
index 11b1668..6b5414d 100644 (file)
@@ -191,14 +191,16 @@ typedef enum {
    nir_var_mem_ssbo              = (1 << 9),
    nir_var_mem_constant          = (1 << 10),
    nir_var_mem_task_payload      = (1 << 11),
+   nir_var_mem_node_payload      = (1 << 12),
+   nir_var_mem_node_payload_in   = (1 << 13),
 
    /* Generic modes intentionally come last. See encode_dref_modes() in
     * nir_serialize.c for more details.
     */
-   nir_var_shader_temp           = (1 << 12),
-   nir_var_function_temp         = (1 << 13),
-   nir_var_mem_shared            = (1 << 14),
-   nir_var_mem_global            = (1 << 15),
+   nir_var_shader_temp           = (1 << 14),
+   nir_var_function_temp         = (1 << 15),
+   nir_var_mem_shared            = (1 << 16),
+   nir_var_mem_global            = (1 << 17),
 
    nir_var_mem_generic           = (nir_var_shader_temp |
                                     nir_var_function_temp |
@@ -212,11 +214,10 @@ typedef enum {
     * is only for mesh stages.
     */
    nir_var_vec_indexable_modes   = nir_var_mem_ubo | nir_var_mem_ssbo |
-                                   nir_var_mem_shared | nir_var_mem_global |
-                                   nir_var_mem_push_const |
-                                   nir_var_mem_task_payload |
-                                   nir_var_shader_out,
-   nir_num_variable_modes        = 16,
+                                 nir_var_mem_shared | nir_var_mem_global |
+                                 nir_var_mem_push_const | nir_var_mem_task_payload |
+                                 nir_var_shader_out,
+   nir_num_variable_modes        = 18,
    nir_var_all                   = (1 << nir_num_variable_modes) - 1,
 } nir_variable_mode;
 MESA_DEFINE_CPP_ENUM_BITFIELD_OPERATORS(nir_variable_mode)
@@ -476,7 +477,7 @@ typedef struct nir_variable {
        *
        * \sa nir_variable_mode
        */
-      unsigned mode : 16;
+      unsigned mode : 18;
 
       /**
        * Is the variable read-only?
@@ -766,6 +767,9 @@ typedef struct nir_variable {
             uint16_t stride;
          } xfb;
       };
+
+      /** Name of the node this payload will be enqueued to. */
+      const char *node_name;
    } data;
 
    /**
index eccade1..782e087 100644 (file)
@@ -1927,3 +1927,20 @@ intrinsic("load_texture_size_etna", src_comp=[1], dest_comp=3,
 
 # src[] = { field }.
 load("push_constant_zink", [1], [COMPONENT], [CAN_ELIMINATE, CAN_REORDER])
+
+system_value("shader_index", 1, bit_sizes=[32])
+
+system_value("coalesced_input_count", 1, bit_sizes=[32])
+
+# Initialize a payload array per scope
+#
+#   0. Payloads deref
+#   1. Payload count
+#   2. Node index
+intrinsic("initialize_node_payloads", src_comp=[-1, 1, 1], indices=[EXECUTION_SCOPE])
+
+# Optionally enqueue payloads after shader finished writing to them
+intrinsic("enqueue_node_payloads", src_comp=[-1])
+
+# Returns true if it has been called for every payload.
+intrinsic("finalize_incoming_node_payload", src_comp=[-1], dest_comp=1)
index 8623b61..1ca260d 100644 (file)
@@ -2421,6 +2421,10 @@ lower_vars_to_explicit(nir_shader *shader,
    case nir_var_mem_task_payload:
       offset = shader->info.task_payload_size;
       break;
+   case nir_var_mem_node_payload:
+      assert(!shader->info.cs.node_payloads_size);
+      offset = 0;
+      break;
    case nir_var_mem_global:
       offset = shader->global_mem_size;
       break;
@@ -2429,6 +2433,7 @@ lower_vars_to_explicit(nir_shader *shader,
       break;
    case nir_var_shader_call_data:
    case nir_var_ray_hit_attrib:
+   case nir_var_mem_node_payload_in:
       offset = 0;
       break;
    default:
@@ -2470,6 +2475,9 @@ lower_vars_to_explicit(nir_shader *shader,
    case nir_var_mem_task_payload:
       shader->info.task_payload_size = offset;
       break;
+   case nir_var_mem_node_payload:
+      shader->info.cs.node_payloads_size = offset;
+      break;
    case nir_var_mem_global:
       shader->global_mem_size = offset;
       break;
@@ -2478,6 +2486,7 @@ lower_vars_to_explicit(nir_shader *shader,
       break;
    case nir_var_shader_call_data:
    case nir_var_ray_hit_attrib:
+   case nir_var_mem_node_payload_in:
       break;
    default:
       unreachable("Unsupported mode");
@@ -2504,7 +2513,8 @@ nir_lower_vars_to_explicit_types(nir_shader *shader,
       nir_var_mem_shared | nir_var_mem_global | nir_var_mem_constant |
       nir_var_shader_temp | nir_var_function_temp | nir_var_uniform |
       nir_var_shader_call_data | nir_var_ray_hit_attrib |
-      nir_var_mem_task_payload;
+      nir_var_mem_task_payload | nir_var_mem_node_payload |
+      nir_var_mem_node_payload_in;
    assert(!(modes & ~supported) && "unsupported");
 
    bool progress = false;
@@ -2529,6 +2539,10 @@ nir_lower_vars_to_explicit_types(nir_shader *shader,
       progress |= lower_vars_to_explicit(shader, &shader->variables, nir_var_ray_hit_attrib, type_info);
    if (modes & nir_var_mem_task_payload)
       progress |= lower_vars_to_explicit(shader, &shader->variables, nir_var_mem_task_payload, type_info);
+   if (modes & nir_var_mem_node_payload)
+      progress |= lower_vars_to_explicit(shader, &shader->variables, nir_var_mem_node_payload, type_info);
+   if (modes & nir_var_mem_node_payload_in)
+      progress |= lower_vars_to_explicit(shader, &shader->variables, nir_var_mem_node_payload_in, type_info);
 
    nir_foreach_function_impl(impl, shader) {
       if (modes & nir_var_function_temp)
index 25e15ea..c020c24 100644 (file)
@@ -749,6 +749,9 @@ lower_compute_system_value_instr(nir_builder *b,
       return num_wgs;
    }
 
+   case nir_intrinsic_load_shader_index:
+      return nir_imm_int(b, b->shader->info.cs.shader_index);
+
    default:
       return NULL;
    }
index 494d61f..275d615 100644 (file)
@@ -727,6 +727,10 @@ get_variable_mode_str(nir_variable_mode mode, bool want_local_global_mode)
       return "ray_hit_attrib";
    case nir_var_mem_task_payload:
       return "task_payload";
+   case nir_var_mem_node_payload:
+      return "node_payload";
+   case nir_var_mem_node_payload_in:
+      return "node_payload_in";
    default:
       if (mode && (mode & nir_var_mem_generic) == mode)
          return "generic";
index 7ec9b28..1aa9c81 100644 (file)
@@ -1641,6 +1641,10 @@ nir_validate_shader(nir_shader *shader, const char *when)
        shader->info.stage == MESA_SHADER_MESH)
       valid_modes |= nir_var_mem_task_payload;
 
+   if (shader->info.stage == MESA_SHADER_COMPUTE)
+      valid_modes |= nir_var_mem_node_payload |
+                     nir_var_mem_node_payload_in;
+
    exec_list_validate(&shader->variables);
    nir_foreach_variable_in_shader(var, shader)
       validate_var_decl(var, valid_modes, &state);
index b3a3570..ea36dc8 100644 (file)
@@ -389,6 +389,8 @@ gl_system_value_name(gl_system_value sysval)
      ENUM(SYSTEM_VALUE_FULLY_COVERED),
      ENUM(SYSTEM_VALUE_FRAG_SIZE),
      ENUM(SYSTEM_VALUE_FRAG_INVOCATION_COUNT),
+     ENUM(SYSTEM_VALUE_SHADER_INDEX),
+     ENUM(SYSTEM_VALUE_COALESCED_INPUT_COUNT),
    };
    STATIC_ASSERT(ARRAY_SIZE(names) == SYSTEM_VALUE_MAX);
    return NAME(sysval);
index 10358f9..31e8ebb 100644 (file)
@@ -916,6 +916,10 @@ typedef enum
    SYSTEM_VALUE_FRAG_SIZE,
    SYSTEM_VALUE_FRAG_INVOCATION_COUNT,
 
+   /* SPV_AMDX_shader_enqueue */
+   SYSTEM_VALUE_SHADER_INDEX,
+   SYSTEM_VALUE_COALESCED_INPUT_COUNT,
+
    SYSTEM_VALUE_MAX             /**< Number of values */
 } gl_system_value;
 
index a62f616..785473a 100644 (file)
@@ -101,6 +101,7 @@ struct spirv_supported_capabilities {
    bool ray_tracing_position_fetch;
    bool runtime_descriptor_array;
    bool shader_clock;
+   bool shader_enqueue;
    bool shader_viewport_index_layer;
    bool shader_viewport_mask_nv;
    bool sparse_residency;
@@ -529,6 +530,15 @@ typedef struct shader_info {
           *   AddressingModelPhysical64: 64
           */
          unsigned ptr_size;
+
+         /** Index provided by VkPipelineShaderStageNodeCreateInfoAMDX or ShaderIndexAMDX */
+         uint32_t shader_index;
+
+         /** Maximum size required by any output node payload array */
+         uint32_t node_payloads_size;
+
+         /** Static workgroup count for overwriting the enqueued workgroup count. (0 if dynamic) */
+         uint32_t workgroup_count[3];
       } cs;
 
       /* Applies to both TCS and TES. */