nir/loop_analyze: Use try_eval_const_alu and induction variable basis info
authorIan Romanick <ian.d.romanick@intel.com>
Tue, 14 Feb 2023 01:33:29 +0000 (17:33 -0800)
committerMarge Bot <emma+marge@anholt.net>
Thu, 6 Apr 2023 23:50:27 +0000 (23:50 +0000)
This dramatically simplifies will_break_on_first_iteration, and, much
more importantly, makes it significantly more flexible. It is now
possible to handle loops with more complex exit condition and other
kinds of increment operations.

Reviewed-by: Timothy Arceri <tarceri@itsqueeze.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/3445>

src/compiler/nir/nir_loop_analyze.c

index a439bcc..995bf90 100644 (file)
@@ -905,40 +905,20 @@ get_iteration(nir_op cond_op, nir_const_value initial, nir_const_value step,
 }
 
 static bool
-will_break_on_first_iteration(nir_const_value step,
-                              nir_alu_type induction_base_type,
-                              unsigned trip_offset,
-                              nir_op cond_op, unsigned bit_size,
-                              nir_const_value initial,
-                              nir_const_value limit,
-                              bool limit_rhs, bool invert_cond,
-                              unsigned execution_mode)
+will_break_on_first_iteration(nir_alu_instr *cond_alu, nir_ssa_def *basis,
+                              nir_ssa_def *limit_basis,
+                              nir_const_value initial, nir_const_value limit,
+                              bool invert_cond, unsigned execution_mode)
 {
-   if (trip_offset == 1) {
-      nir_op add_op;
-      switch (induction_base_type) {
-      case nir_type_float:
-         add_op = nir_op_fadd;
-         break;
-      case nir_type_int:
-      case nir_type_uint:
-         add_op = nir_op_iadd;
-         break;
-      default:
-         unreachable("Unhandled induction variable base type!");
-      }
+   nir_const_value result;
 
-      initial = eval_const_binop(add_op, bit_size, initial, step,
-                                 execution_mode);
-   }
+   const nir_ssa_def *originals[2] = { basis, limit_basis };
+   const nir_const_value *replacements[2] = { &initial, &limit };
 
-   nir_const_value *src[2];
-   src[limit_rhs ? 0 : 1] = &initial;
-   src[limit_rhs ? 1 : 0] = &limit;
+   ASSERTED bool success = try_eval_const_alu(&result, cond_alu, originals,
+                                              replacements, 2, execution_mode);
 
-   /* Evaluate the loop exit condition */
-   nir_const_value result;
-   nir_eval_const_opcode(cond_op, &result, 1, bit_size, src, execution_mode);
+   assert(success);
 
    return invert_cond ? !result.b : result.b;
 }
@@ -993,7 +973,8 @@ test_iterations(int32_t iter_int, nir_const_value step,
 }
 
 static int
-calculate_iterations(nir_const_value initial, nir_const_value step,
+calculate_iterations(nir_ssa_def *basis, nir_ssa_def *limit_basis,
+                     nir_const_value initial, nir_const_value step,
                      nir_const_value limit, nir_alu_instr *alu,
                      nir_ssa_scalar cond, nir_op alu_op, bool limit_rhs,
                      bool invert_cond, unsigned execution_mode)
@@ -1043,10 +1024,8 @@ calculate_iterations(nir_const_value initial, nir_const_value step,
     * however if the loop condition is false on the first iteration
     * get_iteration's assumption is broken. Handle such loops first.
     */
-   if (will_break_on_first_iteration(step, induction_base_type, trip_offset,
-                                     alu_op, bit_size, initial,
-                                     limit, limit_rhs, invert_cond,
-                                     execution_mode)) {
+   if (will_break_on_first_iteration(cond_alu, basis, limit_basis, initial,
+                                     limit, invert_cond, execution_mode)) {
       return 0;
    }
 
@@ -1329,7 +1308,8 @@ find_trip_count(loop_info_state *state, unsigned execution_mode)
       nir_const_value initial_val = nir_ssa_scalar_as_const_value(initial_s);
       nir_const_value step_val = nir_ssa_scalar_as_const_value(alu_s);
 
-      int iterations = calculate_iterations(initial_val, step_val, limit_val,
+      int iterations = calculate_iterations(lv->basis, limit.def,
+                                            initial_val, step_val, limit_val,
                                             nir_instr_as_alu(lv->update_src->src.parent_instr),
                                             cond,
                                             alu_op, limit_rhs,