nir/print: Improve NIR_PRINT=print_consts by using nir_gather_ssa_types()
authorCaio Oliveira <caio.oliveira@intel.com>
Sat, 10 Jun 2023 00:40:58 +0000 (17:40 -0700)
committerMarge Bot <emma+marge@anholt.net>
Wed, 28 Jun 2023 20:17:18 +0000 (20:17 +0000)
The two representations are *always* used for `load_const`, but when
inlining the value as SSA source, use just a single terse
representation.

The choice between integer or float is based on the result of
nir_gather_ssa_types(), with a bias for integer when in doubt.

Also remove extra comment `/* */` syntax since the value is already
enclosed by parenthesis.

---

For illustration, here's some instructions from crucible test
func.shader.averageRounded.uint64_t with NIR_DEBUG=print_consts:

BEFORE:

```
vec1 32 con ssa_23 = load_const (0xfffffffc = -nan)
vec1 32 div ssa_24 = iand ssa_13, ssa_23 /*(0xfffffffc = -nan)*/
vec1 32 con ssa_25 = load_const (0x00000024 = 0.000000)
vec1 32 con ssa_26 = intrinsic load_ubo (ssa_1 /*(0x00000002 = 0.000000)*/, ssa_25 /*(0x00000024 = 0.000000)*/) (access=0, align_mul=1073741824, align_offset=36, range_base=0, range=-1)
vec1 32 con ssa_27 = load_const (0x00000008 = 0.000000)
vec1 32 con ssa_28 = load_const (0x00000007 = 0.000000)
vec1 32 con ssa_29 = iand ssa_4.y, ssa_1 /*(0x00000002 = 0.000000)*/
vec1 32 con ssa_30 = ishl ssa_29, ssa_28 /*(0x00000007 = 0.000000)*/
vec1 32 con ssa_31 = load_const (0x7b000808 = 664776890994587263929995856502063104.000000)
vec1 32 con ssa_32 = ior ssa_31 /*(0x7b000808 = 664776890994587263929995856502063104.000000)*/, ssa_30
```

AFTER:

```
vec1 32 con ssa_23 = load_const (0xfffffffc = -nan)
vec1 32 div ssa_24 = iand ssa_13, ssa_23 (0xfffffffc)
vec1 32 con ssa_25 = load_const (0x00000024 = 0.000000)
vec1 32 con ssa_26 = intrinsic load_ubo (ssa_1 (0x2), ssa_25 (0x24)) (access=0, align_mul=1073741824, align_offset=36, range_base=0, range=-1)
vec1 32 con ssa_27 = load_const (0x00000008 = 0.000000)
vec1 32 con ssa_28 = load_const (0x00000007 = 0.000000)
vec1 32 con ssa_29 = iand ssa_4.y, ssa_1 (0x2)
vec1 32 con ssa_30 = ishl ssa_29, ssa_28 (0x7)
vec1 32 con ssa_31 = load_const (0x7b000808 = 664776890994587263929995856502063104.000000)
vec1 32 con ssa_32 = ior ssa_31 (0x7b000808), ssa_30
```

and some instructions from crucible test func.gs.basic with NIR_DEBUG=print_consts,
now showing float representation being selected:

BEFORE:

```
vec4 32 ssa_10 = load_const (0x3e4ccccd, 0x3e4ccccd, 0x00000000, 0x00000000) = (0.200000, 0.200000, 0.000000, 0.000000)
vec4 32 ssa_9 = intrinsic load_deref (ssa_42) (access=0)
vec4 32 ssa_11 = fadd ssa_9, ssa_10 /*(0x3e4ccccd, 0x3e4ccccd, 0x00000000, 0x00000000) = (0.200000, 0.200000, 0.000000, 0.000000)*/
```

AFTER:

```
vec4 32 ssa_10 = load_const (0x3e4ccccd, 0x3e4ccccd, 0x00000000, 0x00000000) = (0.200000, 0.200000, 0.000000, 0.000000)
vec4 32 ssa_9 = intrinsic load_deref (ssa_42) (access=0)
vec4 32 ssa_11 = fadd ssa_9, ssa_10 (0.200000, 0.200000, 0.000000, 0.000000)
```

Reviewed-by: Jesse Natalie <jenatali@microsoft.com>
Acked-by: Emma Anholt <emma@anholt.net>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/23562>

src/compiler/nir/nir_print.c

index c0a4f5f..c4b6029 100644 (file)
@@ -54,6 +54,13 @@ typedef struct {
    /* an index used to make new non-conflicting names */
    unsigned index;
 
+   /* Used with nir_gather_ssa_types() to identify best representation
+    * to print terse inline constant values together with SSA sources.
+    * Updated per nir_function_impl being printed.
+    */
+   BITSET_WORD *float_types;
+   BITSET_WORD *int_types;
+
    /**
     * Optional table of annotations mapping nir object
     * (such as instr or var) to message to print.
@@ -123,7 +130,7 @@ print_ssa_def(nir_ssa_def *def, print_state *state)
 }
 
 static void
-print_const_from_load(nir_load_const_instr *instr, print_state *state)
+print_const_from_load(nir_load_const_instr *instr, print_state *state, bool terse)
 {
    FILE *fp = state->fp;
 
@@ -133,37 +140,59 @@ print_const_from_load(nir_load_const_instr *instr, print_state *state)
     * and then print in float again for readability.
     */
 
+   bool first_part = true;
+   bool second_part = instr->def.bit_size > 8;
+
+   /* For a terse representation, we pick one.  The load_const instruction itself
+    * will have a complete representation with both parts (when applicable).
+    */
+   if (terse && first_part && second_part) {
+      const unsigned index = instr->def.index;
+      /* If we don't have information or have conflicting information,
+       * use the first part (integer).
+       */
+      if (state->int_types) {
+         first_part = BITSET_TEST(state->int_types, index) ||
+                      !BITSET_TEST(state->float_types, index);
+      }
+      second_part = !first_part;
+   }
+
    fprintf(fp, "(");
 
-   for (unsigned i = 0; i < instr->def.num_components; i++) {
-      if (i != 0)
-         fprintf(fp, ", ");
+   if (first_part) {
+      for (unsigned i = 0; i < instr->def.num_components; i++) {
+         if (i != 0)
+            fprintf(fp, ", ");
 
-      switch (instr->def.bit_size) {
-      case 64:
-         fprintf(fp, "0x%016" PRIx64, instr->value[i].u64);
-         break;
-      case 32:
-         fprintf(fp, "0x%08x", instr->value[i].u32);
-         break;
-      case 16:
-         fprintf(fp, "0x%04x", instr->value[i].u16);
-         break;
-      case 8:
-         fprintf(fp, "0x%02x", instr->value[i].u8);
-         break;
-      case 1:
-         fprintf(fp, "%s", instr->value[i].b ? "true" : "false");
-         break;
+         switch (instr->def.bit_size) {
+         case 64:
+            fprintf(fp, terse ? "0x%" PRIx64 : "0x%016" PRIx64, instr->value[i].u64);
+            break;
+         case 32:
+            fprintf(fp, terse ? "0x%x" : "0x%08x", instr->value[i].u32);
+            break;
+         case 16:
+            fprintf(fp, terse ? "0x%x" : "0x%04x", instr->value[i].u16);
+            break;
+         case 8:
+            fprintf(fp, terse ? "0x%x" : "0x%02x", instr->value[i].u8);
+            break;
+         case 1:
+            fprintf(fp, "%s", instr->value[i].b ? "true" : "false");
+            break;
+         }
       }
    }
 
-   if (instr->def.bit_size > 8) {
+   if (first_part && second_part) {
       if (instr->def.num_components > 1)
          fprintf(fp, ") = (");
       else
          fprintf(fp, " = ");
+   }
 
+   if (second_part) {
       for (unsigned i = 0; i < instr->def.num_components; i++) {
          if (i != 0)
             fprintf(fp, ", ");
@@ -196,7 +225,7 @@ print_load_const_instr(nir_load_const_instr *instr, print_state *state)
 
    fprintf(fp, " = load_const ");
 
-   print_const_from_load(instr, state);
+   print_const_from_load(instr, state, false);
 }
 
 static void
@@ -206,9 +235,8 @@ print_ssa_use(nir_ssa_def *def, print_state *state)
    fprintf(fp, "ssa_%u", def->index);
    nir_instr *instr = def->parent_instr;
    if (instr->type == nir_instr_type_load_const && NIR_DEBUG(PRINT_CONSTS)) {
-      fprintf(fp, " /*");
-      print_const_from_load(nir_instr_as_load_const(instr), state);
-      fprintf(fp, "*/");
+      fprintf(fp, " ");
+      print_const_from_load(nir_instr_as_load_const(instr), state, true);
    }
 }
 
@@ -1752,6 +1780,16 @@ print_function_impl(nir_function_impl *impl, print_state *state)
       fprintf(fp, "\tpreamble %s\n", impl->preamble->name);
    }
 
+   if (NIR_DEBUG(PRINT_CONSTS)) {
+      /* Don't reindex the SSA as suggested by nir_gather_ssa_types() because
+       * nir_print don't modify the shader.  If needed, a limit for ssa_alloc
+       * can be added.
+       */
+      state->float_types = calloc(BITSET_WORDS(impl->ssa_alloc), sizeof(BITSET_WORD));
+      state->int_types = calloc(BITSET_WORDS(impl->ssa_alloc), sizeof(BITSET_WORD));
+      nir_gather_ssa_types(impl, state->float_types, state->int_types);
+   }
+
    nir_foreach_function_temp_variable(var, impl) {
       fprintf(fp, "\t");
       print_var_decl(var, state);
@@ -1769,6 +1807,9 @@ print_function_impl(nir_function_impl *impl, print_state *state)
    }
 
    fprintf(fp, "\tblock block_%u:\n}\n\n", impl->end_block->index);
+
+   free(state->float_types);
+   free(state->int_types);
 }
 
 static void
@@ -1796,6 +1837,8 @@ init_print_state(print_state *state, nir_shader *shader, FILE *fp)
    state->syms = _mesa_set_create(NULL, _mesa_hash_string,
                                   _mesa_key_string_equal);
    state->index = 0;
+   state->int_types = NULL;
+   state->float_types = NULL;
 }
 
 static void