nir: Add options to nir_lower_compute_system_values to control compute ID base lowering
authorJesse Natalie <jenatali@microsoft.com>
Fri, 21 Aug 2020 17:40:45 +0000 (10:40 -0700)
committerMarge Bot <eric+marge@anholt.net>
Fri, 21 Aug 2020 22:07:05 +0000 (22:07 +0000)
If no options are provided, existing intrinsics are used.
If the lowering pass indicates there should be offsets used for global
invocation ID or work group ID, then those instructions are lowered to
include the offset.

Reviewed-by: Jason Ekstrand <jason@jlekstrand.net>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/5891>

13 files changed:
src/amd/vulkan/radv_shader.c
src/broadcom/compiler/vir.c
src/compiler/nir/nir.h
src/compiler/nir/nir_lower_system_values.c
src/freedreno/vulkan/tu_shader.c
src/gallium/auxiliary/nir/tgsi_to_nir.c
src/gallium/drivers/freedreno/ir3/ir3_cmdline.c
src/gallium/frontends/clover/nir/invocation.cpp
src/gallium/frontends/vallium/val_pipeline.c
src/intel/compiler/brw_nir.c
src/mesa/state_tracker/st_glsl_to_nir.cpp
src/mesa/state_tracker/st_nir_builtins.c
src/mesa/state_tracker/st_program.c

index 01a22b7..1d227ef 100644 (file)
@@ -540,7 +540,7 @@ radv_shader_compile_to_nir(struct radv_device *device,
                NIR_PASS_V(nir, nir_propagate_invariant);
 
                NIR_PASS_V(nir, nir_lower_system_values);
-               NIR_PASS_V(nir, nir_lower_compute_system_values);
+               NIR_PASS_V(nir, nir_lower_compute_system_values, NULL);
 
                NIR_PASS_V(nir, nir_lower_clip_cull_distance_arrays);
 
index 83eb4d1..a0b0891 100644 (file)
@@ -586,7 +586,7 @@ v3d_lower_nir(struct v3d_compile *c)
 
         NIR_PASS_V(c->s, nir_lower_tex, &tex_options);
         NIR_PASS_V(c->s, nir_lower_system_values);
-        NIR_PASS_V(c->s, nir_lower_compute_system_values);
+        NIR_PASS_V(c->s, nir_lower_compute_system_values, NULL);
 
         NIR_PASS_V(c->s, nir_lower_vars_to_scratch,
                    nir_var_function_temp,
index 7432afd..005f762 100644 (file)
@@ -4276,7 +4276,13 @@ bool nir_lower_subgroups(nir_shader *shader,
 
 bool nir_lower_system_values(nir_shader *shader);
 
-bool nir_lower_compute_system_values(nir_shader *shader);
+typedef struct nir_lower_compute_system_values_options {
+   bool has_base_global_invocation_id:1;
+   bool has_base_work_group_id:1;
+} nir_lower_compute_system_values_options;
+
+bool nir_lower_compute_system_values(nir_shader *shader,
+                                     const nir_lower_compute_system_values_options *options);
 
 enum PACKED nir_lower_tex_packing {
    nir_lower_tex_packing_none = 0,
index bc80d18..b99f655 100644 (file)
@@ -227,6 +227,7 @@ lower_compute_system_value_instr(nir_builder *b,
                                  nir_instr *instr, void *_options)
 {
    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
+   const nir_lower_compute_system_values_options *options = _options;
 
    /* All the intrinsics we care about are loads */
    if (!nir_intrinsic_infos[intrin->intrinsic].has_dest)
@@ -276,7 +277,7 @@ lower_compute_system_value_instr(nir_builder *b,
                                         nir_channel(b, local_size, 1)));
          return nir_u2u(b, nir_vec3(b, id_x, id_y, id_z), bit_size);
       } else {
-         return sanitize_32bit_sysval(b, intrin);
+         return NULL;
       }
 
    case nir_intrinsic_load_local_invocation_index:
@@ -310,7 +311,7 @@ lower_compute_system_value_instr(nir_builder *b,
          index = nir_iadd(b, index, nir_channel(b, local_id, 0));
          return nir_u2u(b, index, bit_size);
       } else {
-         return sanitize_32bit_sysval(b, intrin);
+         return NULL;
       }
 
    case nir_intrinsic_load_local_group_size:
@@ -319,7 +320,7 @@ lower_compute_system_value_instr(nir_builder *b,
           * this point.  We do, however, have to make sure that the intrinsic
           * is only 32-bit.
           */
-         return sanitize_32bit_sysval(b, intrin);
+         return NULL;
       } else {
          /* using a 32 bit constant is safe here as no device/driver needs more
           * than 32 bits for the local size */
@@ -331,8 +332,9 @@ lower_compute_system_value_instr(nir_builder *b,
          return nir_u2u(b, nir_build_imm(b, 3, 32, local_size_const), bit_size);
       }
 
-   case nir_intrinsic_load_global_invocation_id: {
-      if (!b->shader->options->has_cs_global_id) {
+   case nir_intrinsic_load_global_invocation_id_zero_base: {
+      if ((options && options->has_base_work_group_id) ||
+          !b->shader->options->has_cs_global_id) {
          nir_ssa_def *group_size = nir_load_local_group_size(b);
          nir_ssa_def *group_id = nir_load_work_group_id(b, bit_size);
          nir_ssa_def *local_id = nir_load_local_invocation_id(b);
@@ -345,8 +347,21 @@ lower_compute_system_value_instr(nir_builder *b,
       }
    }
 
+   case nir_intrinsic_load_global_invocation_id: {
+      if (options && options->has_base_global_invocation_id)
+         return nir_iadd(b, nir_load_global_invocation_id_zero_base(b, bit_size),
+                            nir_load_base_global_invocation_id(b, bit_size));
+      else if (!b->shader->options->has_cs_global_id)
+         return nir_load_global_invocation_id_zero_base(b, bit_size);
+      else
+         return NULL;
+   }
+
    case nir_intrinsic_load_global_invocation_index: {
-      nir_ssa_def *global_id = nir_load_global_invocation_id(b, bit_size);
+      /* OpenCL's global_linear_id explicitly removes the global offset before computing this */
+      assert(b->shader->info.stage == MESA_SHADER_KERNEL);
+      nir_ssa_def *global_base_id = nir_load_base_global_invocation_id(b, bit_size);
+      nir_ssa_def *global_id = nir_isub(b, nir_load_global_invocation_id(b, bit_size), global_base_id);
       nir_ssa_def *global_size = build_global_group_size(b, bit_size);
 
       /* index = id.x + ((id.y + (id.z * size.y)) * size.x) */
@@ -359,13 +374,22 @@ lower_compute_system_value_instr(nir_builder *b,
       return index;
    }
 
+   case nir_intrinsic_load_work_group_id: {
+      if (options && options->has_base_work_group_id)
+         return nir_iadd(b, nir_u2u(b, nir_load_work_group_id_zero_base(b), bit_size),
+                            nir_load_base_work_group_id(b, bit_size));
+      else
+         return NULL;
+   }
+
    default:
       return NULL;
    }
 }
 
 bool
-nir_lower_compute_system_values(nir_shader *shader)
+nir_lower_compute_system_values(nir_shader *shader,
+                                const nir_lower_compute_system_values_options *options)
 {
    if (shader->info.stage != MESA_SHADER_COMPUTE &&
        shader->info.stage != MESA_SHADER_KERNEL)
@@ -374,5 +398,5 @@ nir_lower_compute_system_values(nir_shader *shader)
    return nir_shader_lower_instructions(shader,
                                         lower_compute_system_value_filter,
                                         lower_compute_system_value_instr,
-                                        NULL);
+                                        (void*)options);
 }
index aca5ea0..afdb7c7 100644 (file)
@@ -765,7 +765,7 @@ tu_shader_create(struct tu_device *dev,
    nir_assign_io_var_locations(nir, nir_var_shader_out, &nir->num_outputs, stage);
 
    NIR_PASS_V(nir, nir_lower_system_values);
-   NIR_PASS_V(nir, nir_lower_compute_system_values);
+   NIR_PASS_V(nir, nir_lower_compute_system_values, NULL);
 
    NIR_PASS_V(nir, nir_lower_frexp);
 
index 7b1b055..1070f22 100644 (file)
@@ -2559,7 +2559,7 @@ ttn_finalize_nir(struct ttn_compile *c, struct pipe_screen *screen)
    NIR_PASS_V(nir, nir_split_var_copies);
    NIR_PASS_V(nir, nir_lower_var_copies);
    NIR_PASS_V(nir, nir_lower_system_values);
-   NIR_PASS_V(nir, nir_lower_compute_system_values);
+   NIR_PASS_V(nir, nir_lower_compute_system_values, NULL);
 
    if (c->cap_packed_uniforms)
       NIR_PASS_V(nir, nir_lower_uniforms_to_ubo, 16);
index 54227e4..e80312d 100644 (file)
@@ -185,7 +185,7 @@ load_glsl(unsigned num_files, char* const* files, gl_shader_stage stage)
                        ir3_glsl_type_size);
 
        NIR_PASS_V(nir, nir_lower_system_values);
-       NIR_PASS_V(nir, nir_lower_compute_system_values);
+       NIR_PASS_V(nir, nir_lower_compute_system_values, NULL);
 
        NIR_PASS_V(nir, nir_lower_frexp);
        NIR_PASS_V(nir, nir_lower_io,
@@ -403,7 +403,7 @@ main(int argc, char **argv)
                /* TODO do this somewhere else */
                nir_lower_int64(nir);
                nir_lower_system_values(nir);
-               nir_lower_compute_system_values(nir);
+               nir_lower_compute_system_values(nir, NULL);
        } else if (num_files > 0) {
                nir = load_glsl(num_files, filenames, stage);
        } else {
index f574720..3656a3c 100644 (file)
@@ -169,7 +169,7 @@ module clover::nir::spirv_to_nir(const module &mod, const device &dev,
                  spirv_options.global_addr_format);
 
       NIR_PASS_V(nir, nir_lower_system_values);
-      NIR_PASS_V(nir, nir_lower_compute_system_values);
+      NIR_PASS_V(nir, nir_lower_compute_system_values, NULL);
 
       if (compiler_options->lower_int64_options)
          NIR_PASS_V(nir, nir_lower_int64);
index 779e2ec..fb0a88a 100644 (file)
@@ -562,7 +562,7 @@ val_shader_compile_to_ir(struct val_pipeline *pipeline,
    if (stage == MESA_SHADER_FRAGMENT)
       val_lower_input_attachments(nir, false);
    NIR_PASS_V(nir, nir_lower_system_values);
-   NIR_PASS_V(nir, nir_lower_compute_system_values);
+   NIR_PASS_V(nir, nir_lower_compute_system_values, NULL);
 
    NIR_PASS_V(nir, nir_lower_clip_cull_distance_arrays);
    nir_remove_dead_variables(nir, nir_var_uniform, NULL);
index 024cea7..4f1b562 100644 (file)
@@ -709,7 +709,7 @@ brw_preprocess_nir(const struct brw_compiler *compiler, nir_shader *nir,
    }
 
    OPT(nir_lower_system_values);
-   OPT(nir_lower_compute_system_values);
+   OPT(nir_lower_compute_system_values, NULL);
 
    const nir_lower_subgroups_options subgroups_options = {
       .ballot_bit_size = 32,
index b5b85ae..b4d78c8 100644 (file)
@@ -771,7 +771,7 @@ st_link_nir(struct gl_context *ctx,
                  st->pipe->screen);
 
       NIR_PASS_V(nir, nir_lower_system_values);
-      NIR_PASS_V(nir, nir_lower_compute_system_values);
+      NIR_PASS_V(nir, nir_lower_compute_system_values, NULL);
 
       NIR_PASS_V(nir, nir_lower_clip_cull_distance_arrays);
 
index cd078f3..1e29526 100644 (file)
@@ -43,7 +43,7 @@ st_nir_finish_builtin_shader(struct st_context *st,
    NIR_PASS_V(nir, nir_split_var_copies);
    NIR_PASS_V(nir, nir_lower_var_copies);
    NIR_PASS_V(nir, nir_lower_system_values);
-   NIR_PASS_V(nir, nir_lower_compute_system_values);
+   NIR_PASS_V(nir, nir_lower_compute_system_values, NULL);
 
    if (nir->options->lower_to_scalar) {
       nir_variable_mode mask =
index 7ba6344..bfc6d90 100644 (file)
@@ -388,7 +388,7 @@ st_translate_prog_to_nir(struct st_context *st, struct gl_program *prog,
 
    NIR_PASS_V(nir, st_nir_lower_wpos_ytransform, prog, screen);
    NIR_PASS_V(nir, nir_lower_system_values);
-   NIR_PASS_V(nir, nir_lower_compute_system_values);
+   NIR_PASS_V(nir, nir_lower_compute_system_values, NULL);
 
    /* Optimise NIR */
    NIR_PASS_V(nir, nir_opt_constant_folding);