nir: Print divergence information for registers as well as SSA defs
authorKenneth Graunke <kenneth@whitecape.org>
Thu, 9 Feb 2023 23:18:05 +0000 (15:18 -0800)
committerMarge Bot <emma+marge@anholt.net>
Tue, 14 Feb 2023 17:54:11 +0000 (17:54 +0000)
This patch causes us to print "con" and "div" for registers as well as
SSA defs.  We print it on both register declarations, and destinations.
The latter isn't strictly necessary, but it is handy to be able to see
e.g. a convergent value being assigned to a divergent register without
having to constantly refer back to definitions that might be much
earlier in the program.  I originally printed it for sources as well,
but that got to be a bit wordy, so I dropped that.

Reviewed-by: Faith Ekstrand <faith.ekstrand@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/21299>

src/compiler/nir/nir_print.c

index 294fca7..25868c8 100644 (file)
@@ -91,11 +91,22 @@ static const char *sizes[] = { "error", "vec1", "vec2", "vec3", "vec4",
                                "error", "error", "error", "error",
                                "error", "error", "error", "vec16"};
 
+static const char *
+divergence_status(print_state *state, bool divergent)
+{
+   if (state->shader->info.divergence_analysis_run)
+      return divergent ? "div " : "con ";
+
+   return "";
+}
+
 static void
 print_register_decl(nir_register *reg, print_state *state)
 {
    FILE *fp = state->fp;
-   fprintf(fp, "decl_reg %s %u ", sizes[reg->num_components], reg->bit_size);
+   fprintf(fp, "decl_reg %s %u %s", sizes[reg->num_components],
+           reg->bit_size, divergence_status(state, reg->divergent));
+
    print_register(reg, state);
    if (reg->num_array_elems != 0)
       fprintf(fp, "[%u]", reg->num_array_elems);
@@ -107,12 +118,8 @@ print_ssa_def(nir_ssa_def *def, print_state *state)
 {
    FILE *fp = state->fp;
 
-   const char *divergence = "";
-   if (state->shader->info.divergence_analysis_run)
-      divergence = def->divergent ? "div " : "con ";
-
    fprintf(fp, "%s %2u %sssa_%u", sizes[def->num_components], def->bit_size,
-           divergence, def->index);
+           divergence_status(state, def->divergent), def->index);
 }
 
 static void
@@ -226,6 +233,7 @@ static void
 print_reg_dest(nir_reg_dest *dest, print_state *state)
 {
    FILE *fp = state->fp;
+   fprintf(fp, "%s", divergence_status(state, dest->reg->divergent));
    print_register(dest->reg, state);
    if (dest->reg->num_array_elems != 0) {
       fprintf(fp, "[%u", dest->base_offset);