nir/range_analysis: use perform_analysis() in nir_analyze_range()
authorRhys Perry <pendingchaos02@gmail.com>
Tue, 14 Feb 2023 21:38:41 +0000 (21:38 +0000)
committerMarge Bot <emma+marge@anholt.net>
Wed, 22 Mar 2023 09:24:18 +0000 (09:24 +0000)
Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/21381>

src/compiler/nir/nir_range_analysis.c

index aecd601..532335c 100644 (file)
@@ -126,18 +126,15 @@ is_not_zero(enum ssa_ranges r)
    return r == gt_zero || r == lt_zero || r == ne_zero;
 }
 
-static void *
+static uint32_t
 pack_data(const struct ssa_result_range r)
 {
-   return (void *)(uintptr_t)(r.range | r.is_integral << 8 | r.is_finite << 9 |
-                              r.is_a_number << 10);
+   return r.range | r.is_integral << 8 | r.is_finite << 9 | r.is_a_number << 10;
 }
 
 static struct ssa_result_range
-unpack_data(const void *p)
+unpack_data(uint32_t v)
 {
-   const uintptr_t v = (uintptr_t) p;
-
    return (struct ssa_result_range){
       .range       = v & 0xff,
       .is_integral = (v & 0x00100) != 0,
@@ -146,31 +143,6 @@ unpack_data(const void *p)
    };
 }
 
-static void *
-pack_key(const struct nir_alu_instr *instr, nir_alu_type type)
-{
-   uintptr_t type_encoding;
-   uintptr_t ptr = (uintptr_t) instr;
-
-   /* The low 2 bits have to be zero or this whole scheme falls apart. */
-   assert((ptr & 0x3) == 0);
-
-   /* NIR is typeless in the sense that sequences of bits have whatever
-    * meaning is attached to them by the instruction that consumes them.
-    * However, the number of bits must match between producer and consumer.
-    * As a result, the number of bits does not need to be encoded here.
-    */
-   switch (nir_alu_type_get_base_type(type)) {
-   case nir_type_int:   type_encoding = 0; break;
-   case nir_type_uint:  type_encoding = 1; break;
-   case nir_type_bool:  type_encoding = 2; break;
-   case nir_type_float: type_encoding = 3; break;
-   default: unreachable("Invalid base type.");
-   }
-
-   return (void *)(ptr | type_encoding);
-}
-
 static nir_alu_type
 nir_alu_src_type(const nir_alu_instr *instr, unsigned src)
 {
@@ -319,7 +291,7 @@ analyze_constant(const struct nir_alu_instr *instr, unsigned src,
 }
 
 /**
- * Short-hand name for use in the tables in analyze_expression.  If this name
+ * Short-hand name for use in the tables in process_fp_query.  If this name
  * becomes a problem on some compiler, we can change it to _.
  */
 #define _______ unknown
@@ -502,6 +474,53 @@ union_ranges(enum ssa_ranges a, enum ssa_ranges b)
 #define ASSERT_UNION_OF_DISJOINT_MATCHES_UNKNOWN_2_SOURCE(t)
 #endif /* !defined(NDEBUG) */
 
+struct fp_query {
+   struct analysis_query head;
+   const nir_alu_instr *instr;
+   unsigned src;
+   nir_alu_type use_type;
+};
+
+static void
+push_fp_query(struct analysis_state *state, const nir_alu_instr *alu, unsigned src, nir_alu_type type)
+{
+   struct fp_query *pushed_q = push_analysis_query(state, sizeof(struct fp_query));
+   pushed_q->instr = alu;
+   pushed_q->src = src;
+   pushed_q->use_type = type == nir_type_invalid ? nir_alu_src_type(alu, src) : type;
+}
+
+static uintptr_t
+get_fp_key(struct analysis_query *q)
+{
+   struct fp_query *fp_q = (struct fp_query *)q;
+   const nir_src *src = &fp_q->instr->src[fp_q->src].src;
+
+   if (!src->is_ssa || src->ssa->parent_instr->type != nir_instr_type_alu)
+      return 0;
+
+   uintptr_t type_encoding;
+   uintptr_t ptr = (uintptr_t)nir_instr_as_alu(src->ssa->parent_instr);
+
+   /* The low 2 bits have to be zero or this whole scheme falls apart. */
+   assert((ptr & 0x3) == 0);
+
+   /* NIR is typeless in the sense that sequences of bits have whatever
+    * meaning is attached to them by the instruction that consumes them.
+    * However, the number of bits must match between producer and consumer.
+    * As a result, the number of bits does not need to be encoded here.
+    */
+   switch (nir_alu_type_get_base_type(fp_q->use_type)) {
+   case nir_type_int:   type_encoding = 0; break;
+   case nir_type_uint:  type_encoding = 1; break;
+   case nir_type_bool:  type_encoding = 2; break;
+   case nir_type_float: type_encoding = 3; break;
+   default: unreachable("Invalid base type.");
+   }
+
+   return ptr | type_encoding;
+}
+
 /**
  * Analyze an expression to determine the range of its result
  *
@@ -511,21 +530,32 @@ union_ranges(enum ssa_ranges a, enum ssa_ranges b)
  * This function implements this grammar as a recursive-descent parser.  Some
  * (but not all) of the grammar is listed in-line in the function.
  */
-static struct ssa_result_range
-analyze_expression(const nir_alu_instr *instr, unsigned src,
-                   struct hash_table *ht, nir_alu_type use_type)
+static void
+process_fp_query(struct analysis_state *state, struct analysis_query *aq, uint32_t *result,
+                 const uint32_t *src_res)
 {
    /* Ensure that the _Pragma("GCC unroll 7") above are correct. */
    STATIC_ASSERT(last_range + 1 == 7);
 
-   if (!instr->src[src].src.is_ssa)
-      return (struct ssa_result_range){unknown, false, false, false};
+   struct fp_query q = *(struct fp_query *)aq;
+   const nir_alu_instr *instr = q.instr;
+   unsigned src = q.src;
+   nir_alu_type use_type = q.use_type;
 
-   if (nir_src_is_const(instr->src[src].src))
-      return analyze_constant(instr, src, use_type);
+   if (!instr->src[src].src.is_ssa) {
+      *result = pack_data((struct ssa_result_range){unknown, false, false, false});
+      return;
+   }
 
-   if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu)
-      return (struct ssa_result_range){unknown, false, false, false};
+   if (nir_src_is_const(instr->src[src].src)) {
+      *result = pack_data(analyze_constant(instr, src, use_type));
+      return;
+   }
+
+   if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu) {
+      *result = pack_data((struct ssa_result_range){unknown, false, false, false});
+      return;
+   }
 
    const struct nir_alu_instr *const alu =
        nir_instr_as_alu(instr->src[src].src.ssa->parent_instr);
@@ -544,13 +574,62 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
       if (use_base_type != src_base_type &&
           (use_base_type == nir_type_float ||
            src_base_type == nir_type_float)) {
-         return (struct ssa_result_range){unknown, false, false, false};
+         *result = pack_data((struct ssa_result_range){unknown, false, false, false});
+         return;
       }
    }
 
-   struct hash_entry *he = _mesa_hash_table_search(ht, pack_key(alu, use_type));
-   if (he != NULL)
-      return unpack_data(he->data);
+   if (!aq->pushed_queries) {
+      switch (alu->op) {
+      case nir_op_bcsel:
+         push_fp_query(state, alu, 1, use_type);
+         push_fp_query(state, alu, 2, use_type);
+         return;
+      case nir_op_mov:
+         push_fp_query(state, alu, 0, use_type);
+         return;
+      case nir_op_i2f32:
+      case nir_op_u2f32:
+      case nir_op_fabs:
+      case nir_op_fexp2:
+      case nir_op_frcp:
+      case nir_op_fneg:
+      case nir_op_fsat:
+      case nir_op_fsign:
+      case nir_op_ffloor:
+      case nir_op_fceil:
+      case nir_op_ftrunc:
+      case nir_op_fdot2:
+      case nir_op_fdot3:
+      case nir_op_fdot4:
+      case nir_op_fdot8:
+      case nir_op_fdot16:
+      case nir_op_fdot2_replicated:
+      case nir_op_fdot3_replicated:
+      case nir_op_fdot4_replicated:
+      case nir_op_fdot8_replicated:
+      case nir_op_fdot16_replicated:
+         push_fp_query(state, alu, 0, nir_type_invalid);
+         return;
+      case nir_op_fadd:
+      case nir_op_fmax:
+      case nir_op_fmin:
+      case nir_op_fmul:
+      case nir_op_fmulz:
+      case nir_op_fpow:
+         push_fp_query(state, alu, 0, nir_type_invalid);
+         push_fp_query(state, alu, 1, nir_type_invalid);
+         return;
+      case nir_op_ffma:
+      case nir_op_flrp:
+         push_fp_query(state, alu, 0, nir_type_invalid);
+         push_fp_query(state, alu, 1, nir_type_invalid);
+         push_fp_query(state, alu, 2, nir_type_invalid);
+         return;
+      default:
+         break;
+      }
+   }
 
    struct ssa_result_range r = {unknown, false, false, false};
 
@@ -666,10 +745,8 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
       break;
 
    case nir_op_bcsel: {
-      const struct ssa_result_range left =
-         analyze_expression(alu, 1, ht, use_type);
-      const struct ssa_result_range right =
-         analyze_expression(alu, 2, ht, use_type);
+      const struct ssa_result_range left = unpack_data(src_res[0]);
+      const struct ssa_result_range right = unpack_data(src_res[1]);
 
       r.is_integral = left.is_integral && right.is_integral;
 
@@ -694,7 +771,7 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
 
    case nir_op_i2f32:
    case nir_op_u2f32:
-      r = analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
+      r = unpack_data(src_res[0]);
 
       r.is_integral = true;
       r.is_a_number = true;
@@ -706,7 +783,7 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
       break;
 
    case nir_op_fabs:
-      r = analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
+      r = unpack_data(src_res[0]);
 
       switch (r.range) {
       case unknown:
@@ -728,10 +805,8 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
       break;
 
    case nir_op_fadd: {
-      const struct ssa_result_range left =
-         analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
-      const struct ssa_result_range right =
-         analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1));
+      const struct ssa_result_range left = unpack_data(src_res[0]);
+      const struct ssa_result_range right = unpack_data(src_res[1]);
 
       r.is_integral = left.is_integral && right.is_integral;
       r.range = fadd_table[left.range][right.range];
@@ -755,7 +830,7 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
          ge_zero, ge_zero, ge_zero, gt_zero, gt_zero, ge_zero, gt_zero
       };
 
-      r = analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
+      r = unpack_data(src_res[0]);
 
       ASSERT_UNION_OF_DISJOINT_MATCHES_UNKNOWN_1_SOURCE(table);
       ASSERT_UNION_OF_EQ_AND_STRICT_INEQ_MATCHES_NONSTRICT_1_SOURCE(table);
@@ -770,10 +845,8 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
    }
 
    case nir_op_fmax: {
-      const struct ssa_result_range left =
-         analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
-      const struct ssa_result_range right =
-         analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1));
+      const struct ssa_result_range left = unpack_data(src_res[0]);
+      const struct ssa_result_range right = unpack_data(src_res[1]);
 
       r.is_integral = left.is_integral && right.is_integral;
 
@@ -856,10 +929,8 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
    }
 
    case nir_op_fmin: {
-      const struct ssa_result_range left =
-         analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
-      const struct ssa_result_range right =
-         analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1));
+      const struct ssa_result_range left = unpack_data(src_res[0]);
+      const struct ssa_result_range right = unpack_data(src_res[1]);
 
       r.is_integral = left.is_integral && right.is_integral;
 
@@ -943,10 +1014,8 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
 
    case nir_op_fmul:
    case nir_op_fmulz: {
-      const struct ssa_result_range left =
-         analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
-      const struct ssa_result_range right =
-         analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1));
+      const struct ssa_result_range left = unpack_data(src_res[0]);
+      const struct ssa_result_range right = unpack_data(src_res[1]);
 
       r.is_integral = left.is_integral && right.is_integral;
 
@@ -981,7 +1050,7 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
 
    case nir_op_frcp:
       r = (struct ssa_result_range){
-         analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)).range,
+         unpack_data(src_res[0]).range,
          false,
          false, /* Various cases can result in NaN, so assume the worst. */
          false  /*    "      "    "     "    "  "    "    "    "    "    */
@@ -989,18 +1058,16 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
       break;
 
    case nir_op_mov:
-      r = analyze_expression(alu, 0, ht, use_type);
+      r = unpack_data(src_res[0]);
       break;
 
    case nir_op_fneg:
-      r = analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
-
+      r = unpack_data(src_res[0]);
       r.range = fneg_table[r.range];
       break;
 
    case nir_op_fsat: {
-      const struct ssa_result_range left =
-         analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
+      const struct ssa_result_range left = unpack_data(src_res[0]);
 
       /* fsat(NaN) = 0. */
       r.is_a_number = true;
@@ -1035,7 +1102,7 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
 
    case nir_op_fsign:
       r = (struct ssa_result_range){
-         analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)).range,
+         unpack_data(src_res[0]).range,
          true,
          true, /* fsign is -1, 0, or 1, even for NaN, so it must be a number. */
          true  /* fsign is -1, 0, or 1, even for NaN, so it must be finite. */
@@ -1048,8 +1115,7 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
       break;
 
    case nir_op_ffloor: {
-      const struct ssa_result_range left =
-         analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
+      const struct ssa_result_range left = unpack_data(src_res[0]);
 
       r.is_integral = true;
 
@@ -1070,8 +1136,7 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
    }
 
    case nir_op_fceil: {
-      const struct ssa_result_range left =
-         analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
+      const struct ssa_result_range left = unpack_data(src_res[0]);
 
       r.is_integral = true;
 
@@ -1092,8 +1157,7 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
    }
 
    case nir_op_ftrunc: {
-      const struct ssa_result_range left =
-         analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
+      const struct ssa_result_range left = unpack_data(src_res[0]);
 
       r.is_integral = true;
 
@@ -1139,8 +1203,7 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
    case nir_op_fdot4_replicated:
    case nir_op_fdot8_replicated:
    case nir_op_fdot16_replicated: {
-      const struct ssa_result_range left =
-         analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
+      const struct ssa_result_range left = unpack_data(src_res[0]);
 
       /* If the two sources are the same SSA value, then the result is either
        * NaN or some number >= 0.  If one source is the negation of the other,
@@ -1211,10 +1274,8 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
          /* eq_zero */ { ge_zero, gt_zero, gt_zero, eq_zero, ge_zero, ge_zero, gt_zero },
       };
 
-      const struct ssa_result_range left =
-         analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
-      const struct ssa_result_range right =
-         analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1));
+      const struct ssa_result_range left = unpack_data(src_res[0]);
+      const struct ssa_result_range right = unpack_data(src_res[1]);
 
       ASSERT_UNION_OF_DISJOINT_MATCHES_UNKNOWN_2_SOURCE(table);
       ASSERT_UNION_OF_EQ_AND_STRICT_INEQ_MATCHES_NONSTRICT_2_SOURCE(table);
@@ -1230,12 +1291,9 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
    }
 
    case nir_op_ffma: {
-      const struct ssa_result_range first =
-         analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
-      const struct ssa_result_range second =
-         analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1));
-      const struct ssa_result_range third =
-         analyze_expression(alu, 2, ht, nir_alu_src_type(alu, 2));
+      const struct ssa_result_range first = unpack_data(src_res[0]);
+      const struct ssa_result_range second = unpack_data(src_res[1]);
+      const struct ssa_result_range third = unpack_data(src_res[2]);
 
       r.is_integral = first.is_integral && second.is_integral &&
                       third.is_integral;
@@ -1261,12 +1319,9 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
    }
 
    case nir_op_flrp: {
-      const struct ssa_result_range first =
-         analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
-      const struct ssa_result_range second =
-         analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1));
-      const struct ssa_result_range third =
-         analyze_expression(alu, 2, ht, nir_alu_src_type(alu, 2));
+      const struct ssa_result_range first = unpack_data(src_res[0]);
+      const struct ssa_result_range second = unpack_data(src_res[1]);
+      const struct ssa_result_range third = unpack_data(src_res[2]);
 
       r.is_integral = first.is_integral && second.is_integral &&
                       third.is_integral;
@@ -1296,18 +1351,29 @@ analyze_expression(const nir_alu_instr *instr, unsigned src,
    /* Just like isfinite(), the is_finite flag implies the value is a number. */
    assert((int) r.is_finite <= (int) r.is_a_number);
 
-   _mesa_hash_table_insert(ht, pack_key(alu, use_type), pack_data(r));
-   return r;
+   *result = pack_data(r);
 }
 
 #undef _______
 
 struct ssa_result_range
 nir_analyze_range(struct hash_table *range_ht,
-                  const nir_alu_instr *instr, unsigned src)
+                  const nir_alu_instr *alu, unsigned src)
 {
-   return analyze_expression(instr, src, range_ht,
-                             nir_alu_src_type(instr, src));
+   struct fp_query query_alloc[64];
+   uint32_t result_alloc[64];
+
+   struct analysis_state state;
+   state.range_ht = range_ht;
+   util_dynarray_init_from_stack(&state.query_stack, query_alloc, sizeof(query_alloc));
+   util_dynarray_init_from_stack(&state.result_stack, result_alloc, sizeof(result_alloc));
+   state.query_size = sizeof(struct fp_query);
+   state.get_key = &get_fp_key;
+   state.process_query = &process_fp_query;
+
+   push_fp_query(&state, alu, src, nir_type_invalid);
+
+   return unpack_data(perform_analysis(&state));
 }
 
 static uint32_t bitmask(uint32_t size) {