nir/range_analysis: use perform_analysis() in nir_unsigned_upper_bound()
author  Rhys Perry <pendingchaos02@gmail.com>
Fri, 10 Feb 2023 16:24:39 +0000 (16:24 +0000)
committer  Marge Bot <emma+marge@anholt.net>
Wed, 22 Mar 2023 09:24:18 +0000 (09:24 +0000)
Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/21381>
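
This change replaces the recursive nir_unsigned_upper_bound_impl() with small uub_query records driven by the perform_analysis() worklist (defined in nir_range_analysis.c outside this hunk): a query whose bound depends on other SSA values pushes sub-queries and returns, and perform_analysis() re-runs it with their bounds in src[] once they have been resolved, so the analysis no longer needs the old recursion-depth limit. Below is a minimal, self-contained sketch of that push/resume shape only; struct node, tree_max() and the toy MAX reduction are hypothetical stand-ins, not Mesa code.

    #include <assert.h>
    #include <stdint.h>
    #include <stdio.h>

    /* A node is either a leaf carrying a value or an inner node with exactly
     * two children (kept that simple to keep the toy short). */
    struct node {
       uint32_t value;               /* used when the node is a leaf */
       const struct node *child[2];  /* both NULL for a leaf */
    };

    /* One pending request, mirroring struct uub_query/analysis_query below. */
    struct query {
       const struct node *n;
       unsigned pushed;              /* number of sub-queries we are waiting on */
    };

    static uint32_t
    tree_max(const struct node *root)
    {
       struct query queries[64];     /* the real code grows these with util_dynarray */
       uint32_t results[64];
       unsigned num_queries = 0, num_results = 0;

       queries[num_queries++] = (struct query){ .n = root };

       while (num_queries) {
          struct query *q = &queries[num_queries - 1];

          if (q->n->child[0] && q->pushed == 0) {
             /* First visit: queue both children, then retry this query later. */
             q->pushed = 2;
             queries[num_queries++] = (struct query){ .n = q->n->child[0] };
             queries[num_queries++] = (struct query){ .n = q->n->child[1] };
             continue;
          }

          /* Leaf, or every sub-result now sits on top of the result stack
           * (in reverse push order, which MAX does not care about). */
          uint32_t res = q->n->value;
          for (unsigned i = 0; i < q->pushed; i++)
             res = res > results[num_results - 1 - i] ? res : results[num_results - 1 - i];
          num_results -= q->pushed;

          num_queries--;                /* this query is done...               */
          results[num_results++] = res; /* ...publish its result for its parent */
       }

       assert(num_results == 1);
       return results[0];
    }

    int
    main(void)
    {
       struct node a = { .value = 3 }, b = { .value = 7 }, c = { .value = 5 };
       struct node ab = { .child = { &a, &b } };
       struct node root = { .child = { &ab, &c } };
       printf("%u\n", (unsigned)tree_max(&root)); /* prints 7 */
       return 0;
    }

The real framework additionally caches finished bounds in range_ht (keyed by get_uub_key()) and keeps each pushed query's result slot in push order, which is why get_alu_uub() below can read src[0], src[1], src[2] in the same order it pushed the ALU sources.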

src/compiler/nir/nir_range_analysis.c

index 0de5955..aecd601 100644
@@ -1398,424 +1398,465 @@ static const nir_unsigned_upper_bound_config default_ub_config = {
    },
 };
 
-static uint32_t
-nir_unsigned_upper_bound_impl(nir_shader *shader, struct hash_table *range_ht,
-                              nir_ssa_scalar scalar,
-                              const nir_unsigned_upper_bound_config *config,
-                              unsigned stack_depth)
-{
-   assert(scalar.def->bit_size <= 32);
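+/* One pending upper-bound request for perform_analysis().  head.pushed_queries
+ * counts the sub-queries this query is waiting on; their results come back
+ * through the src[] array when the query is re-run. */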
+struct uub_query {
+   struct analysis_query head;
+   nir_ssa_scalar scalar;
+};
 
-   if (!config)
-      config = &default_ub_config;
-   if (nir_ssa_scalar_is_const(scalar))
-      return nir_ssa_scalar_as_uint(scalar);
+static void
+push_uub_query(struct analysis_state *state, nir_ssa_scalar scalar)
+{
+   struct uub_query *pushed_q = push_analysis_query(state, sizeof(struct uub_query));
+   pushed_q->scalar = scalar;
+}
 
+static uintptr_t
+get_uub_key(struct analysis_query *q)
+{
+   nir_ssa_scalar scalar = ((struct uub_query *)q)->scalar;
    /* keys can't be 0, so we have to add 1 to the index */
-   void *key = (void*)(((uintptr_t)(scalar.def->index + 1) << 4) | scalar.comp);
-   struct hash_entry *he = _mesa_hash_table_search(range_ht, key);
-   if (he != NULL)
-      return (uintptr_t)he->data;
-
-   uint32_t max = bitmask(scalar.def->bit_size);
-
-   /* Avoid stack overflows. 200 is just a random setting, that happened to work with wine stacks
-    * which tend to be smaller than normal Linux ones. */
-   if (stack_depth >= 200)
-      return max;
-
-   if (scalar.def->parent_instr->type == nir_instr_type_intrinsic) {
-      uint32_t res = max;
-      nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(scalar.def->parent_instr);
-      switch (intrin->intrinsic) {
-      case nir_intrinsic_load_local_invocation_index:
-         /* The local invocation index is used under the hood by RADV for
-          * some non-compute-like shaders (eg. LS and NGG). These technically
-          * run in workgroups on the HW, even though this fact is not exposed
-          * by the API.
-          * They can safely use the same code path here as variable sized
-          * compute-like shader stages.
-          */
-         if (!gl_shader_stage_uses_workgroup(shader->info.stage) ||
-             shader->info.workgroup_size_variable) {
-            res = config->max_workgroup_invocations - 1;
-         } else {
-            res = (shader->info.workgroup_size[0] *
-                   shader->info.workgroup_size[1] *
-                   shader->info.workgroup_size[2]) - 1u;
-         }
-         break;
-      case nir_intrinsic_load_local_invocation_id:
-         if (shader->info.workgroup_size_variable)
-            res = config->max_workgroup_size[scalar.comp] - 1u;
-         else
-            res = shader->info.workgroup_size[scalar.comp] - 1u;
-         break;
-      case nir_intrinsic_load_workgroup_id:
-         res = config->max_workgroup_count[scalar.comp] - 1u;
-         break;
-      case nir_intrinsic_load_num_workgroups:
-         res = config->max_workgroup_count[scalar.comp];
-         break;
-      case nir_intrinsic_load_global_invocation_id:
-         if (shader->info.workgroup_size_variable) {
-            res = mul_clamp(config->max_workgroup_size[scalar.comp],
-                            config->max_workgroup_count[scalar.comp]) - 1u;
-         } else {
-            res = (shader->info.workgroup_size[scalar.comp] *
-                   config->max_workgroup_count[scalar.comp]) - 1u;
-         }
-         break;
-      case nir_intrinsic_load_invocation_id:
-         if (shader->info.stage == MESA_SHADER_TESS_CTRL)
-            res = shader->info.tess.tcs_vertices_out
-                  ? (shader->info.tess.tcs_vertices_out - 1)
-                  : 511; /* Generous maximum output patch size of 512 */
-         break;
-      case nir_intrinsic_load_subgroup_invocation:
-      case nir_intrinsic_first_invocation:
-         res = config->max_subgroup_size - 1;
-         break;
-      case nir_intrinsic_mbcnt_amd: {
-         uint32_t src0 = config->max_subgroup_size - 1;
-         uint32_t src1 = nir_unsigned_upper_bound_impl(shader, range_ht, nir_get_ssa_scalar(intrin->src[1].ssa, 0),
-                                                       config, stack_depth + 1);
+   unsigned shift_amount = ffs(NIR_MAX_VEC_COMPONENTS) - 1;
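+   /* Constants map to key 0, so the framework never caches them; they are
+    * folded directly in process_uub_query(). */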
+   return nir_ssa_scalar_is_const(scalar)
+          ? 0
+          : ((uintptr_t)(scalar.def->index + 1) << shift_amount) | scalar.comp;
+}
 
-         if (src0 + src1 < src0)
-            res = max; /* overflow */
-         else
-            res = src0 + src1;
-         break;
+static void
+get_intrinsic_uub(struct analysis_state *state, struct uub_query q, uint32_t *result,
+                  const uint32_t *src)
+{
+   nir_shader *shader = state->shader;
+   const nir_unsigned_upper_bound_config *config = state->config;
+
+   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(q.scalar.def->parent_instr);
+   switch (intrin->intrinsic) {
+   case nir_intrinsic_load_local_invocation_index:
+      /* The local invocation index is used under the hood by RADV for
+       * some non-compute-like shaders (eg. LS and NGG). These technically
+       * run in workgroups on the HW, even though this fact is not exposed
+       * by the API.
+       * They can safely use the same code path here as variable sized
+       * compute-like shader stages.
+       */
+      if (!gl_shader_stage_uses_workgroup(shader->info.stage) ||
+          shader->info.workgroup_size_variable) {
+         *result = config->max_workgroup_invocations - 1;
+      } else {
+         *result = (shader->info.workgroup_size[0] *
+                    shader->info.workgroup_size[1] *
+                    shader->info.workgroup_size[2]) - 1u;
       }
-      case nir_intrinsic_load_subgroup_size:
-         res = config->max_subgroup_size;
-         break;
-      case nir_intrinsic_load_subgroup_id:
-      case nir_intrinsic_load_num_subgroups: {
-         uint32_t workgroup_size = config->max_workgroup_invocations;
-         if (gl_shader_stage_uses_workgroup(shader->info.stage) &&
-             !shader->info.workgroup_size_variable) {
-            workgroup_size = shader->info.workgroup_size[0] *
-                             shader->info.workgroup_size[1] *
-                             shader->info.workgroup_size[2];
+      break;
+   case nir_intrinsic_load_local_invocation_id:
+      if (shader->info.workgroup_size_variable)
+         *result = config->max_workgroup_size[q.scalar.comp] - 1u;
+      else
+         *result = shader->info.workgroup_size[q.scalar.comp] - 1u;
+      break;
+   case nir_intrinsic_load_workgroup_id:
+      *result = config->max_workgroup_count[q.scalar.comp] - 1u;
+      break;
+   case nir_intrinsic_load_num_workgroups:
+      *result = config->max_workgroup_count[q.scalar.comp];
+      break;
+   case nir_intrinsic_load_global_invocation_id:
+      if (shader->info.workgroup_size_variable) {
+         *result = mul_clamp(config->max_workgroup_size[q.scalar.comp],
+                             config->max_workgroup_count[q.scalar.comp]) - 1u;
+      } else {
+         *result = (shader->info.workgroup_size[q.scalar.comp] *
+                    config->max_workgroup_count[q.scalar.comp]) - 1u;
+      }
+      break;
+   case nir_intrinsic_load_invocation_id:
+      if (shader->info.stage == MESA_SHADER_TESS_CTRL)
+         *result = shader->info.tess.tcs_vertices_out
+                   ? (shader->info.tess.tcs_vertices_out - 1)
+                   : 511; /* Generous maximum output patch size of 512 */
+      break;
+   case nir_intrinsic_load_subgroup_invocation:
+   case nir_intrinsic_first_invocation:
+      *result = config->max_subgroup_size - 1;
+      break;
+   case nir_intrinsic_mbcnt_amd: {
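+      /* First visit: queue a query for src[1]; the second visit reads its
+       * bound from src[0] and adds it to the subgroup-derived bound. */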
+      if (!q.head.pushed_queries) {
+         push_uub_query(state, nir_get_ssa_scalar(intrin->src[1].ssa, 0));
+         return;
+      } else {
+         uint32_t src0 = config->max_subgroup_size - 1;
+         uint32_t src1 = src[0];
+         if (src0 + src1 >= src0) /* check overflow */
+            *result = src0 + src1;
+      }
+      break;
+   }
+   case nir_intrinsic_load_subgroup_size:
+      *result = config->max_subgroup_size;
+      break;
+   case nir_intrinsic_load_subgroup_id:
+   case nir_intrinsic_load_num_subgroups: {
+      uint32_t workgroup_size = config->max_workgroup_invocations;
+      if (gl_shader_stage_uses_workgroup(shader->info.stage) &&
+          !shader->info.workgroup_size_variable) {
+         workgroup_size = shader->info.workgroup_size[0] *
+                          shader->info.workgroup_size[1] *
+                          shader->info.workgroup_size[2];
+      }
+      *result = DIV_ROUND_UP(workgroup_size, config->min_subgroup_size);
+      if (intrin->intrinsic == nir_intrinsic_load_subgroup_id)
+         (*result)--;
+      break;
+   }
+   case nir_intrinsic_load_input: {
+      if (shader->info.stage == MESA_SHADER_VERTEX && nir_src_is_const(intrin->src[0])) {
+         nir_variable *var = lookup_input(shader, nir_intrinsic_base(intrin));
+         if (var) {
+            int loc = var->data.location - VERT_ATTRIB_GENERIC0;
+            if (loc >= 0)
+               *result = config->vertex_attrib_max[loc];
          }
-         res = DIV_ROUND_UP(workgroup_size, config->min_subgroup_size);
-         if (intrin->intrinsic == nir_intrinsic_load_subgroup_id)
-            res--;
-         break;
       }
-      case nir_intrinsic_load_input: {
-         if (shader->info.stage == MESA_SHADER_VERTEX && nir_src_is_const(intrin->src[0])) {
-            nir_variable *var = lookup_input(shader, nir_intrinsic_base(intrin));
-            if (var) {
-               int loc = var->data.location - VERT_ATTRIB_GENERIC0;
-               if (loc >= 0)
-                  res = config->vertex_attrib_max[loc];
-            }
+      break;
+   }
+   case nir_intrinsic_reduce:
+   case nir_intrinsic_inclusive_scan:
+   case nir_intrinsic_exclusive_scan: {
+      nir_op op = nir_intrinsic_reduction_op(intrin);
+      if (op == nir_op_umin || op == nir_op_umax || op == nir_op_imin || op == nir_op_imax) {
+         if (!q.head.pushed_queries) {
+            push_uub_query(state, nir_get_ssa_scalar(intrin->src[0].ssa, q.scalar.comp));
+            return;
+         } else {
+            *result = src[0];
          }
-         break;
       }
-      case nir_intrinsic_reduce:
-      case nir_intrinsic_inclusive_scan:
-      case nir_intrinsic_exclusive_scan: {
-         nir_op op = nir_intrinsic_reduction_op(intrin);
-         if (op == nir_op_umin || op == nir_op_umax || op == nir_op_imin || op == nir_op_imax)
-            res = nir_unsigned_upper_bound_impl(shader, range_ht, nir_get_ssa_scalar(intrin->src[0].ssa, scalar.comp),
-                                                config, stack_depth + 1);
-         break;
+      break;
+   }
+   case nir_intrinsic_read_first_invocation:
+   case nir_intrinsic_read_invocation:
+   case nir_intrinsic_shuffle:
+   case nir_intrinsic_shuffle_xor:
+   case nir_intrinsic_shuffle_up:
+   case nir_intrinsic_shuffle_down:
+   case nir_intrinsic_quad_broadcast:
+   case nir_intrinsic_quad_swap_horizontal:
+   case nir_intrinsic_quad_swap_vertical:
+   case nir_intrinsic_quad_swap_diagonal:
+   case nir_intrinsic_quad_swizzle_amd:
+   case nir_intrinsic_masked_swizzle_amd:
+      if (!q.head.pushed_queries) {
+         push_uub_query(state, nir_get_ssa_scalar(intrin->src[0].ssa, q.scalar.comp));
+         return;
+      } else {
+         *result = src[0];
       }
-      case nir_intrinsic_read_first_invocation:
-      case nir_intrinsic_read_invocation:
-      case nir_intrinsic_shuffle:
-      case nir_intrinsic_shuffle_xor:
-      case nir_intrinsic_shuffle_up:
-      case nir_intrinsic_shuffle_down:
-      case nir_intrinsic_quad_broadcast:
-      case nir_intrinsic_quad_swap_horizontal:
-      case nir_intrinsic_quad_swap_vertical:
-      case nir_intrinsic_quad_swap_diagonal:
-      case nir_intrinsic_quad_swizzle_amd:
-      case nir_intrinsic_masked_swizzle_amd:
-         res = nir_unsigned_upper_bound_impl(shader, range_ht, nir_get_ssa_scalar(intrin->src[0].ssa, scalar.comp),
-                                             config, stack_depth + 1);
-         break;
-      case nir_intrinsic_write_invocation_amd: {
-         uint32_t src0 = nir_unsigned_upper_bound_impl(shader, range_ht, nir_get_ssa_scalar(intrin->src[0].ssa, scalar.comp),
-                                                       config, stack_depth + 1);
-         uint32_t src1 = nir_unsigned_upper_bound_impl(shader, range_ht, nir_get_ssa_scalar(intrin->src[1].ssa, scalar.comp),
-                                                       config, stack_depth + 1);
-         res = MAX2(src0, src1);
-         break;
+      break;
+   case nir_intrinsic_write_invocation_amd:
+      if (!q.head.pushed_queries) {
+         push_uub_query(state, nir_get_ssa_scalar(intrin->src[0].ssa, q.scalar.comp));
+         push_uub_query(state, nir_get_ssa_scalar(intrin->src[1].ssa, q.scalar.comp));
+         return;
+      } else {
+         *result = MAX2(src[0], src[1]);
       }
-      case nir_intrinsic_load_tess_rel_patch_id_amd:
-      case nir_intrinsic_load_tcs_num_patches_amd:
-         /* Very generous maximum: TCS/TES executed by largest possible workgroup */
-         res = config->max_workgroup_invocations / MAX2(shader->info.tess.tcs_vertices_out, 1u);
+      break;
+   case nir_intrinsic_load_tess_rel_patch_id_amd:
+   case nir_intrinsic_load_tcs_num_patches_amd:
+      /* Very generous maximum: TCS/TES executed by largest possible workgroup */
+      *result = config->max_workgroup_invocations / MAX2(shader->info.tess.tcs_vertices_out, 1u);
+      break;
+   case nir_intrinsic_load_typed_buffer_amd: {
+      const enum pipe_format format = nir_intrinsic_format(intrin);
+      if (format == PIPE_FORMAT_NONE)
          break;
-      case nir_intrinsic_load_typed_buffer_amd: {
-         const enum pipe_format format = nir_intrinsic_format(intrin);
-         if (format == PIPE_FORMAT_NONE)
-            break;
-
-         const struct util_format_description* desc = util_format_description(format);
-         if (desc->channel[scalar.comp].type != UTIL_FORMAT_TYPE_UNSIGNED)
-            break;
-
-         if (desc->channel[scalar.comp].normalized) {
-            res = fui(1.0);
-            break;
-         }
 
-         const uint32_t chan_max = u_uintN_max(desc->channel[scalar.comp].size);
-         res = desc->channel[scalar.comp].pure_integer ? chan_max : fui(chan_max);
-         break;
-      }
-      case nir_intrinsic_load_scalar_arg_amd:
-      case nir_intrinsic_load_vector_arg_amd: {
-         uint32_t upper_bound = nir_intrinsic_arg_upper_bound_u32_amd(intrin);
-         if (upper_bound)
-            res = upper_bound;
+      const struct util_format_description* desc = util_format_description(format);
+      if (desc->channel[q.scalar.comp].type != UTIL_FORMAT_TYPE_UNSIGNED)
          break;
-      }
-      default:
+
+      if (desc->channel[q.scalar.comp].normalized) {
+         *result = fui(1.0);
          break;
       }
-      if (res != max)
-         _mesa_hash_table_insert(range_ht, key, (void*)(uintptr_t)res);
-      return res;
-   }
 
-   if (scalar.def->parent_instr->type == nir_instr_type_phi) {
-      nir_cf_node *prev = nir_cf_node_prev(&scalar.def->parent_instr->block->cf_node);
+      const uint32_t chan_max = u_uintN_max(desc->channel[q.scalar.comp].size);
+      *result = desc->channel[q.scalar.comp].pure_integer ? chan_max : fui(chan_max);
+      break;
+   }
+   case nir_intrinsic_load_scalar_arg_amd:
+   case nir_intrinsic_load_vector_arg_amd: {
+      uint32_t upper_bound = nir_intrinsic_arg_upper_bound_u32_amd(intrin);
+      if (upper_bound)
+         *result = upper_bound;
+      break;
+   }
+   default:
+      break;
+   }
+}
 
-      uint32_t res = 0;
-      if (!prev || prev->type == nir_cf_node_block) {
-         _mesa_hash_table_insert(range_ht, key, (void*)(uintptr_t)max);
+static void
+get_alu_uub(struct analysis_state *state, struct uub_query q, uint32_t *result, const uint32_t *src)
+{
+   nir_op op = nir_ssa_scalar_alu_op(q.scalar);
+
+   /* Early exit for unsupported ALU opcodes. */
+   switch (op) {
+   case nir_op_umin:
+   case nir_op_imin:
+   case nir_op_imax:
+   case nir_op_umax:
+   case nir_op_iand:
+   case nir_op_ior:
+   case nir_op_ixor:
+   case nir_op_ishl:
+   case nir_op_imul:
+   case nir_op_ushr:
+   case nir_op_ishr:
+   case nir_op_iadd:
+   case nir_op_umod:
+   case nir_op_udiv:
+   case nir_op_bcsel:
+   case nir_op_b32csel:
+   case nir_op_ubfe:
+   case nir_op_bfm:
+   case nir_op_fmul:
+   case nir_op_fmulz:
+   case nir_op_extract_u8:
+   case nir_op_extract_i8:
+   case nir_op_extract_u16:
+   case nir_op_extract_i16:
+   case nir_op_b2i8:
+   case nir_op_b2i16:
+   case nir_op_b2i32:
+      break;
+   case nir_op_u2u1:
+   case nir_op_u2u8:
+   case nir_op_u2u16:
+   case nir_op_u2u32:
+   case nir_op_f2u32:
+      if (nir_ssa_scalar_chase_alu_src(q.scalar, 0).def->bit_size > 32) {
+         /* If src is >32 bits, return max */
+         return;
+      }
+      break;
+   default:
+      return;
+   }
 
-         struct set *visited = _mesa_pointer_set_create(NULL);
-         nir_ssa_scalar defs[64];
-         unsigned def_count = search_phi_bcsel(scalar, defs, 64, visited);
-         _mesa_set_destroy(visited, NULL);
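+   /* First visit: queue one query per ALU source; the second visit finds
+    * their bounds in src[0..num_inputs-1], in push order. */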
+   if (!q.head.pushed_queries) {
+      for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++)
+         push_uub_query(state, nir_ssa_scalar_chase_alu_src(q.scalar, i));
+      return;
+   }
 
-         for (unsigned i = 0; i < def_count; i++)
-            res = MAX2(res, nir_unsigned_upper_bound_impl(shader, range_ht, defs[i], config, stack_depth + 1));
+   uint32_t max = bitmask(q.scalar.def->bit_size);
+   switch (op) {
+   case nir_op_umin:
+      *result = src[0] < src[1] ? src[0] : src[1];
+      break;
+   case nir_op_imin:
+   case nir_op_imax:
+   case nir_op_umax:
+      *result = src[0] > src[1] ? src[0] : src[1];
+      break;
+   case nir_op_iand:
+      *result = bitmask(util_last_bit64(src[0])) & bitmask(util_last_bit64(src[1]));
+      break;
+   case nir_op_ior:
+   case nir_op_ixor:
+      *result = bitmask(util_last_bit64(src[0])) | bitmask(util_last_bit64(src[1]));
+      break;
+   case nir_op_ishl: {
+      uint32_t src1 = MIN2(src[1], q.scalar.def->bit_size - 1u);
+      if (util_last_bit64(src[0]) + src1 <= q.scalar.def->bit_size) /* check overflow */
+         *result = src[0] << src1;
+      break;
+   }
+   case nir_op_imul:
+      if (src[0] == 0 || (src[0] * src[1]) / src[0] == src[1]) /* check overflow */
+         *result = src[0] * src[1];
+      break;
+   case nir_op_ushr: {
+      nir_ssa_scalar src1_scalar = nir_ssa_scalar_chase_alu_src(q.scalar, 1);
+      uint32_t mask = q.scalar.def->bit_size - 1u;
+      if (nir_ssa_scalar_is_const(src1_scalar))
+         *result = src[0] >> (nir_ssa_scalar_as_uint(src1_scalar) & mask);
+      else
+         *result = src[0];
+      break;
+   }
+   case nir_op_ishr: {
+      nir_ssa_scalar src1_scalar = nir_ssa_scalar_chase_alu_src(q.scalar, 1);
+      uint32_t mask = q.scalar.def->bit_size - 1u;
+      if (src[0] <= 2147483647 && nir_ssa_scalar_is_const(src1_scalar))
+         *result = src[0] >> (nir_ssa_scalar_as_uint(src1_scalar) & mask);
+      else
+         *result = src[0];
+      break;
+   }
+   case nir_op_iadd:
+      if (src[0] + src[1] >= src[0]) /* check overflow */
+         *result = src[0] + src[1];
+      break;
+   case nir_op_umod:
+      *result = src[1] ? src[1] - 1 : 0;
+      break;
+   case nir_op_udiv: {
+      nir_ssa_scalar src1_scalar = nir_ssa_scalar_chase_alu_src(q.scalar, 1);
+      if (nir_ssa_scalar_is_const(src1_scalar))
+         *result = nir_ssa_scalar_as_uint(src1_scalar)
+                   ? src[0] / nir_ssa_scalar_as_uint(src1_scalar) : 0;
+      else
+         *result = src[0];
+      break;
+   }
+   case nir_op_bcsel:
+   case nir_op_b32csel:
+      *result = src[1] > src[2] ? src[1] : src[2];
+      break;
+   case nir_op_ubfe:
+      *result = bitmask(MIN2(src[2], q.scalar.def->bit_size));
+      break;
+   case nir_op_bfm: {
+      nir_ssa_scalar src1_scalar = nir_ssa_scalar_chase_alu_src(q.scalar, 1);
+      if (nir_ssa_scalar_is_const(src1_scalar)) {
+         uint32_t src0 = MIN2(src[0], 31);
+         uint32_t src1 = nir_ssa_scalar_as_uint(src1_scalar) & 0x1fu;
+         *result = bitmask(src0) << src1;
       } else {
-         nir_foreach_phi_src(src, nir_instr_as_phi(scalar.def->parent_instr)) {
-            res = MAX2(res, nir_unsigned_upper_bound_impl(
-               shader, range_ht, nir_get_ssa_scalar(src->src.ssa, scalar.comp), config, stack_depth + 1));
-         }
+         uint32_t src0 = MIN2(src[0], 31);
+         uint32_t src1 = MIN2(src[1], 31);
+         *result = bitmask(MIN2(src0 + src1, 32));
       }
-
-      _mesa_hash_table_insert(range_ht, key, (void*)(uintptr_t)res);
-      return res;
+      break;
    }
+   /* limited floating-point support for f2u32(fmul(load_input(), <constant>)) */
+   case nir_op_f2u32:
+      /* infinity/NaN starts at 0x7f800000u, negative numbers at 0x80000000 */
+      if (src[0] < 0x7f800000u) {
+         float val;
+         memcpy(&val, &src[0], 4);
+         *result = (uint32_t)val;
+      }
+      break;
+   case nir_op_fmul:
+   case nir_op_fmulz:
+      /* infinity/NaN starts at 0x7f800000u, negative numbers at 0x80000000 */
+      if (src[0] < 0x7f800000u && src[1] < 0x7f800000u) {
+         float src0_f, src1_f;
+         memcpy(&src0_f, &src[0], 4);
+         memcpy(&src1_f, &src[1], 4);
+         /* not a proper rounding-up multiplication, but should be good enough */
+         float max_f = ceilf(src0_f) * ceilf(src1_f);
+         memcpy(result, &max_f, 4);
+      }
+      break;
+   case nir_op_u2u1:
+   case nir_op_u2u8:
+   case nir_op_u2u16:
+   case nir_op_u2u32:
+      *result = MIN2(src[0], max);
+      break;
+   case nir_op_b2i8:
+   case nir_op_b2i16:
+   case nir_op_b2i32:
+      *result = 1;
+      break;
+   case nir_op_sad_u8x4:
+      *result = src[2] + 4 * 255;
+      break;
+   case nir_op_extract_u8:
+      *result = MIN2(src[0], UINT8_MAX);
+      break;
+   case nir_op_extract_i8:
+      *result = (src[0] >= 0x80) ? max : MIN2(src[0], INT8_MAX);
+      break;
+   case nir_op_extract_u16:
+      *result = MIN2(src[0], UINT16_MAX);
+      break;
+   case nir_op_extract_i16:
+      *result = (src[0] >= 0x8000) ? max : MIN2(src[0], INT16_MAX);
+      break;
+   default:
+      break;
+   }
+}
 
-   if (nir_ssa_scalar_is_alu(scalar)) {
-      nir_op op = nir_ssa_scalar_alu_op(scalar);
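+/* Phis take the maximum over their sources.  A phi whose block does not
+ * directly follow an if or loop may be a loop header, so its sources can
+ * depend on the phi itself; that cycle is broken by seeding range_ht with the
+ * worst-case bound and gathering the phi/bcsel network with search_phi_bcsel(). */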
+static void
+get_phi_uub(struct analysis_state *state, struct uub_query q, uint32_t *result, const uint32_t *src)
+{
+   nir_phi_instr *phi = nir_instr_as_phi(q.scalar.def->parent_instr);
 
-      switch (op) {
-      case nir_op_umin:
-      case nir_op_imin:
-      case nir_op_imax:
-      case nir_op_umax:
-      case nir_op_iand:
-      case nir_op_ior:
-      case nir_op_ixor:
-      case nir_op_ishl:
-      case nir_op_imul:
-      case nir_op_ushr:
-      case nir_op_ishr:
-      case nir_op_iadd:
-      case nir_op_umod:
-      case nir_op_udiv:
-      case nir_op_bcsel:
-      case nir_op_b32csel:
-      case nir_op_ubfe:
-      case nir_op_bfm:
-      case nir_op_fmul:
-      case nir_op_fmulz:
-      case nir_op_extract_u8:
-      case nir_op_extract_i8:
-      case nir_op_extract_u16:
-      case nir_op_extract_i16:
-      case nir_op_b2i8:
-      case nir_op_b2i16:
-      case nir_op_b2i32:
-         break;
-      case nir_op_u2u1:
-      case nir_op_u2u8:
-      case nir_op_u2u16:
-      case nir_op_u2u32:
-      case nir_op_f2u32:
-         if (nir_ssa_scalar_chase_alu_src(scalar, 0).def->bit_size > 32) {
-            /* If src is >32 bits, return max */
-            return max;
-         }
-         break;
-      default:
-         return max;
-      }
+   if (exec_list_is_empty(&phi->srcs))
+      return;
 
-      uint32_t src0 = nir_unsigned_upper_bound_impl(shader, range_ht, nir_ssa_scalar_chase_alu_src(scalar, 0),
-                                                    config, stack_depth + 1);
-      uint32_t src1 = max, src2 = max;
-      if (nir_op_infos[op].num_inputs > 1)
-         src1 = nir_unsigned_upper_bound_impl(shader, range_ht, nir_ssa_scalar_chase_alu_src(scalar, 1),
-                                              config, stack_depth + 1);
-      if (nir_op_infos[op].num_inputs > 2)
-         src2 = nir_unsigned_upper_bound_impl(shader, range_ht, nir_ssa_scalar_chase_alu_src(scalar, 2),
-                                              config, stack_depth + 1);
-
-      uint32_t res = max;
-      switch (op) {
-      case nir_op_umin:
-         res = src0 < src1 ? src0 : src1;
-         break;
-      case nir_op_imin:
-      case nir_op_imax:
-      case nir_op_umax:
-         res = src0 > src1 ? src0 : src1;
-         break;
-      case nir_op_iand:
-         res = bitmask(util_last_bit64(src0)) & bitmask(util_last_bit64(src1));
-         break;
-      case nir_op_ior:
-      case nir_op_ixor:
-         res = bitmask(util_last_bit64(src0)) | bitmask(util_last_bit64(src1));
-         break;
-      case nir_op_ishl: {
-         src1 = MIN2(src1, scalar.def->bit_size - 1u);
-         if (util_last_bit64(src0) + src1 > scalar.def->bit_size)
-            res = max; /* overflow */
-         else
-            res = src0 << src1;
-         break;
-      }
-      case nir_op_imul:
-         if (src0 != 0 && (src0 * src1) / src0 != src1)
-            res = max;
-         else
-            res = src0 * src1;
-         break;
-      case nir_op_ushr: {
-         nir_ssa_scalar src1_scalar = nir_ssa_scalar_chase_alu_src(scalar, 1);
-         uint32_t mask = scalar.def->bit_size - 1u;
-         if (nir_ssa_scalar_is_const(src1_scalar))
-            res = src0 >> (nir_ssa_scalar_as_uint(src1_scalar) & mask);
-         else
-            res = src0;
-         break;
-      }
-      case nir_op_ishr: {
-         nir_ssa_scalar src1_scalar = nir_ssa_scalar_chase_alu_src(scalar, 1);
-         uint32_t mask = scalar.def->bit_size - 1u;
-         if (src0 <= 2147483647 && nir_ssa_scalar_is_const(src1_scalar))
-            res = src0 >> (nir_ssa_scalar_as_uint(src1_scalar) & mask);
-         else
-            res = src0;
-         break;
-      }
-      case nir_op_iadd:
-         if (src0 + src1 < src0)
-            res = max; /* overflow */
-         else
-            res = src0 + src1;
-         break;
-      case nir_op_umod:
-         res = src1 ? src1 - 1 : 0;
-         break;
-      case nir_op_udiv: {
-         nir_ssa_scalar src1_scalar = nir_ssa_scalar_chase_alu_src(scalar, 1);
-         if (nir_ssa_scalar_is_const(src1_scalar))
-            res = nir_ssa_scalar_as_uint(src1_scalar) ? src0 / nir_ssa_scalar_as_uint(src1_scalar) : 0;
-         else
-            res = src0;
-         break;
-      }
-      case nir_op_bcsel:
-      case nir_op_b32csel:
-         res = src1 > src2 ? src1 : src2;
-         break;
-      case nir_op_ubfe:
-         res = bitmask(MIN2(src2, scalar.def->bit_size));
-         break;
-      case nir_op_bfm: {
-         nir_ssa_scalar src1_scalar = nir_ssa_scalar_chase_alu_src(scalar, 1);
-         if (nir_ssa_scalar_is_const(src1_scalar)) {
-            src0 = MIN2(src0, 31);
-            src1 = nir_ssa_scalar_as_uint(src1_scalar) & 0x1fu;
-            res = bitmask(src0) << src1;
-         } else {
-            src0 = MIN2(src0, 31);
-            src1 = MIN2(src1, 31);
-            res = bitmask(MIN2(src0 + src1, 32));
-         }
-         break;
-      }
-      /* limited floating-point support for f2u32(fmul(load_input(), <constant>)) */
-      case nir_op_f2u32:
-         /* infinity/NaN starts at 0x7f800000u, negative numbers at 0x80000000 */
-         if (src0 < 0x7f800000u) {
-            float val;
-            memcpy(&val, &src0, 4);
-            res = (uint32_t)val;
-         }
-         break;
-      case nir_op_fmul:
-      case nir_op_fmulz:
-         /* infinity/NaN starts at 0x7f800000u, negative numbers at 0x80000000 */
-         if (src0 < 0x7f800000u && src1 < 0x7f800000u) {
-            float src0_f, src1_f;
-            memcpy(&src0_f, &src0, 4);
-            memcpy(&src1_f, &src1, 4);
-            /* not a proper rounding-up multiplication, but should be good enough */
-            float max_f = ceilf(src0_f) * ceilf(src1_f);
-            memcpy(&res, &max_f, 4);
-         }
-         break;
-      case nir_op_u2u1:
-      case nir_op_u2u8:
-      case nir_op_u2u16:
-      case nir_op_u2u32:
-         res = MIN2(src0, max);
-         break;
-      case nir_op_b2i8:
-      case nir_op_b2i16:
-      case nir_op_b2i32:
-         res = 1;
-         break;
-      case nir_op_sad_u8x4:
-         res = src2 + 4 * 255;
-         break;
-      case nir_op_extract_u8:
-         res = MIN2(src0, UINT8_MAX);
-         break;
-      case nir_op_extract_i8:
-         res = (src0 >= 0x80) ? max : MIN2(src0, INT8_MAX);
-         break;
-      case nir_op_extract_u16:
-         res = MIN2(src0, UINT16_MAX);
-         break;
-      case nir_op_extract_i16:
-         res = (src0 >= 0x8000) ? max : MIN2(src0, INT16_MAX);
-         break;
-      default:
-         res = max;
-         break;
-      }
-      _mesa_hash_table_insert(range_ht, key, (void*)(uintptr_t)res);
-      return res;
+   if (q.head.pushed_queries) {
+      *result = src[0];
+      for (unsigned i = 1; i < q.head.pushed_queries; i++)
+         *result = MAX2(*result, src[i]);
+      return;
    }
 
-   return max;
+   nir_cf_node *prev = nir_cf_node_prev(&phi->instr.block->cf_node);
+   if (!prev || prev->type == nir_cf_node_block) {
+      /* Resolve cycles by inserting max into range_ht. */
+      uint32_t max = bitmask(q.scalar.def->bit_size);
+      _mesa_hash_table_insert(state->range_ht, (void*)get_uub_key(&q.head), (void*)(uintptr_t)max);
+
+      struct set *visited = _mesa_pointer_set_create(NULL);
+      nir_ssa_scalar *defs = alloca(sizeof(nir_ssa_scalar) * 64);
+      unsigned def_count = search_phi_bcsel(q.scalar, defs, 64, visited);
+      _mesa_set_destroy(visited, NULL);
+
+      for (unsigned i = 0; i < def_count; i++)
+         push_uub_query(state, defs[i]);
+   } else {
+      nir_foreach_phi_src(src, phi)
+         push_uub_query(state, nir_get_ssa_scalar(src->src.ssa, q.scalar.comp));
+   }
+}
+
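+/* Callback invoked by perform_analysis() for every query: start from the
+ * all-ones bound for the bit size, then refine it based on how the value is
+ * defined (constant, intrinsic, ALU op or phi). */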
+static void
+process_uub_query(struct analysis_state *state, struct analysis_query *aq, uint32_t *result,
+                  const uint32_t *src)
+{
+   struct uub_query q = *(struct uub_query *)aq;
+
+   *result = bitmask(q.scalar.def->bit_size);
+   if (nir_ssa_scalar_is_const(q.scalar))
+      *result = nir_ssa_scalar_as_uint(q.scalar);
+   else if (q.scalar.def->parent_instr->type == nir_instr_type_intrinsic)
+      get_intrinsic_uub(state, q, result, src);
+   else if (nir_ssa_scalar_is_alu(q.scalar))
+      get_alu_uub(state, q, result, src);
+   else if (q.scalar.def->parent_instr->type == nir_instr_type_phi)
+      get_phi_uub(state, q, result, src);
 }
 
 uint32_t
 nir_unsigned_upper_bound(nir_shader *shader, struct hash_table *range_ht,
                          nir_ssa_scalar scalar,
                          const nir_unsigned_upper_bound_config *config)
 {
-   return nir_unsigned_upper_bound_impl(shader, range_ht, scalar, config, 0);
+   if (!config)
+      config = &default_ub_config;
+
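+   /* Stack storage for the common case; util_dynarray falls back to the heap
+    * if the analysis needs more room. */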
+   struct uub_query query_alloc[16];
+   uint32_t result_alloc[16];
+
+   struct analysis_state state;
+   state.shader = shader;
+   state.config = config;
+   state.range_ht = range_ht;
+   util_dynarray_init_from_stack(&state.query_stack, query_alloc, sizeof(query_alloc));
+   util_dynarray_init_from_stack(&state.result_stack, result_alloc, sizeof(result_alloc));
+   state.query_size = sizeof(struct uub_query);
+   state.get_key = &get_uub_key;
+   state.process_query = &process_uub_query;
+
+   push_uub_query(&state, scalar);
+
+   return perform_analysis(&state);
 }
 
 bool