nir/print: Use src_type when printing consts in SSA uses
authorCaio Oliveira <caio.oliveira@intel.com>
Tue, 13 Jun 2023 04:59:34 +0000 (21:59 -0700)
committerMarge Bot <emma+marge@anholt.net>
Wed, 28 Jun 2023 20:17:18 +0000 (20:17 +0000)
If the src_type is not available, untie by looking at the results from
nir_gather_ssa_types(). If that is ambiguous, just pick uint.

Now in print_const_from_load() when the type is invalid, print the full
constant form (with both padded hex and float); when the passed type
is valid, print the terse form based on it.

Reviewed-by: Jesse Natalie <jenatali@microsoft.com>
Acked-by: Emma Anholt <emma@anholt.net>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/23562>

src/compiler/nir/nir_print.c

index db40197..c911294 100644 (file)
@@ -130,85 +130,102 @@ print_ssa_def(nir_ssa_def *def, print_state *state)
 }
 
 static void
-print_const_from_load(nir_load_const_instr *instr, print_state *state, bool terse)
+print_hex_padded_const_value(const nir_const_value *value, unsigned bit_size, FILE *fp)
 {
-   FILE *fp = state->fp;
+   switch (bit_size) {
+   case 64: fprintf(fp, "0x%016" PRIx64, value->u64); break;
+   case 32: fprintf(fp, "0x%08x", value->u32); break;
+   case 16: fprintf(fp, "0x%04x", value->u16); break;
+   case 8:  fprintf(fp, "0x%02x", value->u8); break;
+   default:
+      unreachable("unhandled bit size");
+   }
+}
 
-   /*
-    * we don't really know the type of the constant (if it will be used as a
-    * float or an int), so just print the raw constant in hex for fidelity
-    * and then print in float again for readability.
-    */
+static void
+print_hex_terse_const_value(const nir_const_value *value, unsigned bit_size, FILE *fp)
+{
+   switch (bit_size) {
+   case 64: fprintf(fp, "0x%" PRIx64, value->u64); break;
+   case 32: fprintf(fp, "0x%x", value->u32); break;
+   case 16: fprintf(fp, "0x%x", value->u16); break;
+   case 8:  fprintf(fp, "0x%x", value->u8); break;
+   default:
+      unreachable("unhandled bit size");
+   }
+}
 
-   bool first_part = true;
-   bool second_part = instr->def.bit_size > 8;
+static void
+print_float_const_value(const nir_const_value *value, unsigned bit_size, FILE *fp)
+{
+   switch (bit_size) {
+   case 64: fprintf(fp, "%f", value->f64); break;
+   case 32: fprintf(fp, "%f", value->f32); break;
+   case 16: fprintf(fp, "%f", _mesa_half_to_float(value->u16)); break;
+   default:
+      unreachable("unhandled bit size");
+   }
+}
 
-   /* For a terse representation, we pick one.  The load_const instruction itself
-    * will have a complete representation with both parts (when applicable).
-    */
-   if (terse && first_part && second_part) {
-      const unsigned index = instr->def.index;
-      /* If we don't have information or have conflicting information,
-       * use the first part (integer).
-       */
-      if (state->int_types) {
-         first_part = BITSET_TEST(state->int_types, index) ||
-                      !BITSET_TEST(state->float_types, index);
+static void
+print_const_from_load(nir_load_const_instr *instr, print_state *state, nir_alu_type type)
+{
+   FILE *fp = state->fp;
+
+   const unsigned bit_size = instr->def.bit_size;
+   const unsigned num_components = instr->def.num_components;
+
+   /* There's only one way to print booleans. */
+   if (bit_size == 1) {
+      fprintf(fp, "(");
+      for (unsigned i = 0; i < num_components; i++) {
+         if (i != 0)
+            fprintf(fp, ", ");
+         fprintf(fp, "%s", instr->value[i].b ? "true" : "false");
       }
-      second_part = !first_part;
+      fprintf(fp, ")");
+      return;
    }
 
    fprintf(fp, "(");
 
-   if (first_part) {
-      for (unsigned i = 0; i < instr->def.num_components; i++) {
+   type = nir_alu_type_get_base_type(type);
+
+   if (type != nir_type_invalid) {
+      for (unsigned i = 0; i < num_components; i++) {
+         const nir_const_value *v = &instr->value[i];
          if (i != 0)
             fprintf(fp, ", ");
-
-         switch (instr->def.bit_size) {
-         case 64:
-            fprintf(fp, terse ? "0x%" PRIx64 : "0x%016" PRIx64, instr->value[i].u64);
-            break;
-         case 32:
-            fprintf(fp, terse ? "0x%x" : "0x%08x", instr->value[i].u32);
-            break;
-         case 16:
-            fprintf(fp, terse ? "0x%x" : "0x%04x", instr->value[i].u16);
+         switch (type) {
+         case nir_type_float:
+            print_float_const_value(v, bit_size, fp);
             break;
-         case 8:
-            fprintf(fp, terse ? "0x%x" : "0x%02x", instr->value[i].u8);
-            break;
-         case 1:
-            fprintf(fp, "%s", instr->value[i].b ? "true" : "false");
+         case nir_type_int:
+         case nir_type_uint:
+            print_hex_terse_const_value(v, bit_size, fp);
             break;
+
+         default:
+            unreachable("invalid nir alu base type");
          }
       }
-   }
-
-   if (first_part && second_part) {
-      if (instr->def.num_components > 1)
-         fprintf(fp, ") = (");
-      else
-         fprintf(fp, " = ");
-   }
-
-   if (second_part) {
-      for (unsigned i = 0; i < instr->def.num_components; i++) {
+   } else {
+      for (unsigned i = 0; i < num_components; i++) {
          if (i != 0)
             fprintf(fp, ", ");
+         print_hex_padded_const_value(&instr->value[i], bit_size, fp);
+      }
 
-         switch (instr->def.bit_size) {
-         case 64:
-            fprintf(fp, "%f", instr->value[i].f64);
-            break;
-         case 32:
-            fprintf(fp, "%f", instr->value[i].f32);
-            break;
-         case 16:
-            fprintf(fp, "%f", _mesa_half_to_float(instr->value[i].u16));
-            break;
-         default:
-            unreachable("unhandled bit size");
+      if (bit_size > 8) {
+         if (num_components > 1)
+            fprintf(fp, ") = (");
+         else
+            fprintf(fp, " = ");
+
+         for (unsigned i = 0; i < num_components; i++) {
+            if (i != 0)
+               fprintf(fp, ", ");
+            print_float_const_value(&instr->value[i], bit_size, fp);
          }
       }
    }
@@ -225,22 +242,42 @@ print_load_const_instr(nir_load_const_instr *instr, print_state *state)
 
    fprintf(fp, " = load_const ");
 
-   print_const_from_load(instr, state, false);
+   /* In the definition, print all interpretations of the value. */
+   print_const_from_load(instr, state, nir_type_invalid);
 }
 
 static void
-print_ssa_use(nir_ssa_def *def, print_state *state)
+print_ssa_use(nir_ssa_def *def, print_state *state, nir_alu_type src_type)
 {
    FILE *fp = state->fp;
    fprintf(fp, "ssa_%u", def->index);
    nir_instr *instr = def->parent_instr;
+
    if (instr->type == nir_instr_type_load_const && !NIR_DEBUG(PRINT_NO_INLINE_CONSTS)) {
+      nir_load_const_instr *load_const = nir_instr_as_load_const(instr);
       fprintf(fp, " ");
-      print_const_from_load(nir_instr_as_load_const(instr), state, true);
+
+      nir_alu_type type = nir_alu_type_get_base_type(src_type);
+
+      if (type == nir_type_invalid && state->int_types) {
+         const unsigned index = load_const->def.index;
+         const bool inferred_int = BITSET_TEST(state->int_types, index);
+         const bool inferred_float = BITSET_TEST(state->float_types, index);
+
+         if (inferred_float && !inferred_int)
+            type = nir_type_float;
+      }
+
+      if (type == nir_type_invalid)
+         type = nir_type_uint;
+
+      /* For a constant in a source, always pick one interpretation. */
+      assert(type != nir_type_invalid);
+      print_const_from_load(load_const, state, type);
    }
 }
 
-static void print_src(const nir_src *src, print_state *state);
+static void print_src(const nir_src *src, print_state *state, nir_alu_type src_type);
 
 static void
 print_reg_src(const nir_reg_src *src, print_state *state)
@@ -251,7 +288,7 @@ print_reg_src(const nir_reg_src *src, print_state *state)
       fprintf(fp, "[%u", src->base_offset);
       if (src->indirect != NULL) {
          fprintf(fp, " + ");
-         print_src(src->indirect, state);
+         print_src(src->indirect, state, nir_type_invalid);
       }
       fprintf(fp, "]");
    }
@@ -267,17 +304,17 @@ print_reg_dest(nir_reg_dest *dest, print_state *state)
       fprintf(fp, "[%u", dest->base_offset);
       if (dest->indirect != NULL) {
          fprintf(fp, " + ");
-         print_src(dest->indirect, state);
+         print_src(dest->indirect, state, nir_type_invalid);
       }
       fprintf(fp, "]");
    }
 }
 
 static void
-print_src(const nir_src *src, print_state *state)
+print_src(const nir_src *src, print_state *state, nir_alu_type src_type)
 {
    if (src->is_ssa)
-      print_ssa_use(src->ssa, state);
+      print_ssa_use(src->ssa, state, src_type);
    else
       print_reg_src(&src->reg, state);
 }
@@ -307,7 +344,8 @@ print_alu_src(nir_alu_instr *instr, unsigned src, print_state *state)
    if (instr->src[src].abs)
       fprintf(fp, "abs(");
 
-   print_src(&instr->src[src].src, state);
+   const nir_op_info *info = &nir_op_infos[instr->op];
+   print_src(&instr->src[src].src, state, info->input_types[src]);
 
    bool print_swizzle = false;
    nir_component_mask_t used_channels = 0;
@@ -773,7 +811,7 @@ print_deref_link(const nir_deref_instr *instr, bool whole_chain, print_state *st
       return;
    } else if (instr->deref_type == nir_deref_type_cast) {
       fprintf(fp, "(%s *)", glsl_get_type_name(instr->type));
-      print_src(&instr->parent, state);
+      print_src(&instr->parent, state, nir_type_invalid);
       return;
    }
 
@@ -808,7 +846,7 @@ print_deref_link(const nir_deref_instr *instr, bool whole_chain, print_state *st
    if (whole_chain) {
       print_deref_link(parent, whole_chain, state);
    } else {
-      print_src(&instr->parent, state);
+      print_src(&instr->parent, state, nir_type_invalid);
    }
 
    if (is_parent_cast || need_deref)
@@ -826,7 +864,7 @@ print_deref_link(const nir_deref_instr *instr, bool whole_chain, print_state *st
          fprintf(fp, "[%"PRId64"]", nir_src_as_int(instr->arr.index));
       } else {
          fprintf(fp, "[");
-         print_src(&instr->arr.index, state);
+         print_src(&instr->arr.index, state, nir_type_invalid);
          fprintf(fp, "]");
       }
       break;
@@ -958,7 +996,7 @@ print_intrinsic_instr(nir_intrinsic_instr *instr, print_state *state)
       if (i != 0)
          fprintf(fp, ", ");
 
-      print_src(&instr->src[i], state);
+      print_src(&instr->src[i], state, nir_intrinsic_instr_src_type(instr, i));
    }
 
    fprintf(fp, ") (");
@@ -1400,7 +1438,7 @@ print_tex_instr(nir_tex_instr *instr, print_state *state)
          fprintf(fp, ", ");
       }
 
-      print_src(&instr->src[i].src, state);
+      print_src(&instr->src[i].src, state, nir_tex_instr_src_type(instr, i));
       fprintf(fp, " ");
 
       switch(instr->src[i].src_type) {
@@ -1520,7 +1558,7 @@ print_call_instr(nir_call_instr *instr, print_state *state)
       if (i != 0)
          fprintf(fp, ", ");
 
-      print_src(&instr->params[i], state);
+      print_src(&instr->params[i], state, nir_type_invalid);
    }
 }
 
@@ -1554,7 +1592,7 @@ print_jump_instr(nir_jump_instr *instr, print_state *state)
    case nir_jump_goto_if:
       fprintf(fp, "goto block_%u if ",
               instr->target ? instr->target->index : -1);
-      print_src(&instr->condition, state);
+      print_src(&instr->condition, state, nir_type_invalid);
       fprintf(fp, " else block_%u",
               instr->else_target ? instr->else_target->index : -1);
       break;
@@ -1580,7 +1618,7 @@ print_phi_instr(nir_phi_instr *instr, print_state *state)
          fprintf(fp, ", ");
 
       fprintf(fp, "block_%u: ", src->pred->index);
-      print_src(&src->src, state);
+      print_src(&src->src, state, nir_type_invalid);
    }
 }
 
@@ -1594,7 +1632,7 @@ print_parallel_copy_instr(nir_parallel_copy_instr *instr, print_state *state)
 
       print_dest(&entry->dest, state);
       fprintf(fp, " = ");
-      print_src(&entry->src, state);
+      print_src(&entry->src, state, nir_type_invalid);
    }
 }
 
@@ -1695,7 +1733,7 @@ print_if(nir_if *if_stmt, print_state *state, unsigned tabs)
 
    print_tabs(tabs, fp);
    fprintf(fp, "if ");
-   print_src(&if_stmt->condition, state);
+   print_src(&if_stmt->condition, state, nir_type_invalid);
    switch (if_stmt->control) {
    case nir_selection_control_flatten:
       fprintf(fp, " /* flatten */");