nir/loop_analyze: Add a function to evaluate an ALU as constant
author Ian Romanick <ian.d.romanick@intel.com>
Tue, 14 Feb 2023 01:23:14 +0000 (17:23 -0800)
committer Marge Bot <emma+marge@anholt.net>
Thu, 6 Apr 2023 23:50:27 +0000 (23:50 +0000)
...with a substitution. This function is largely a copy-and-paste of
try_fold_alu (nir_opt_constant_folding.c), and an argument could be made
that this function belongs in that file.

v2: Some changes that should have been in this commit were mistakenly
squashed into "nir/loop_analyze: Use try_eval_const_alu and induction
variable basis info".

Reviewed-by: Timothy Arceri <tarceri@itsqueeze.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/3445>

src/compiler/nir/nir_loop_analyze.c

index a9be89c..068078f 100644 (file)
@@ -705,6 +705,122 @@ eval_const_binop(nir_op op, unsigned bit_size,
    return dest;
 }
 
+/**
+ * Search \c originals for the SSA definition \c key.
+ *
+ * Entries are compared by pointer identity.
+ *
+ * \param originals         Array of \c num_replacements SSA definitions.
+ * \param key               SSA definition to look for.
+ * \param num_replacements  Number of entries in \c originals.
+ *
+ * \returns the index of the first entry equal to \c key, or -1 if no entry
+ * matches.
+ */
+static int
+find_replacement(const nir_ssa_def **originals, const nir_ssa_def *key,
+                 unsigned num_replacements)
+{
+   /* The index is unsigned so that the bound check against
+    * num_replacements is not a mixed signed/unsigned comparison.
+    */
+   for (unsigned i = 0; i < num_replacements; i++) {
+      if (originals[i] == key)
+         return (int)i;
+   }
+
+   return -1;
+}
+
+/**
+ * Try to evaluate an ALU instruction as a constant with replacements
+ *
+ * Much like \c nir_opt_constant_folding.c:try_fold_alu, this method attempts
+ * to evaluate an ALU instruction as a constant. There are two significant
+ * differences.
+ *
+ * First, this method performs the evaluation recursively. If a source of
+ * the ALU instruction is neither a constant nor one of \c originals, but is
+ * itself produced by an ALU instruction, that instruction is evaluated first.
+ *
+ * Second, if an SSA value listed in \c originals is encountered as a source
+ * of the ALU instruction, the corresponding entry of \c replacements is
+ * substituted.
+ *
+ * The intended purpose of this function is to evaluate an arbitrary
+ * expression involving a loop induction variable. In this case, an entry of
+ * \c originals would be the phi node associated with the induction variable,
+ * and the matching entry of \c replacements is the initial value of the
+ * induction variable.
+ *
+ * \param dest              Receives the result; written for
+ *                          \c alu->dest.dest.ssa.num_components components.
+ * \param alu               ALU instruction to evaluate.
+ * \param originals         SSA values that should be substituted.
+ * \param replacements      \c replacements[i] points to the per-component
+ *                          constant values substituted for \c originals[i].
+ * \param num_replacements  Number of entries in \c originals and
+ *                          \c replacements.
+ * \param execution_mode    Forwarded unchanged to \c nir_eval_const_opcode.
+ *
+ * \returns true if the ALU instruction can be evaluated as constant (after
+ * applying the previously described substitution) or false otherwise.
+ */
+static bool
+try_eval_const_alu(nir_const_value *dest, nir_alu_instr *alu,
+                   const nir_ssa_def **originals,
+                   const nir_const_value **replacements,
+                   unsigned num_replacements, unsigned execution_mode)
+{
+   nir_const_value src[NIR_MAX_VEC_COMPONENTS][NIR_MAX_VEC_COMPONENTS];
+
+   if (!alu->dest.dest.is_ssa)
+      return false;
+
+   /* In the case that any outputs/inputs have unsized types, then we need to
+    * guess the bit-size. In this case, the validator ensures that all
+    * bit-sizes match so we can just take the bit-size from first
+    * output/input with an unsized type. If all the outputs/inputs are sized
+    * then we don't need to guess the bit-size at all because the code we
+    * generate for constant opcodes in this case already knows the sizes of
+    * the types involved and does not need the provided bit-size for anything
+    * (although it still requires to receive a valid bit-size).
+    */
+   unsigned bit_size = 0;
+   if (!nir_alu_type_get_type_size(nir_op_infos[alu->op].output_type))
+      bit_size = alu->dest.dest.ssa.bit_size;
+
+   for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++) {
+      if (!alu->src[i].src.is_ssa)
+         return false;
+
+      if (bit_size == 0 &&
+          !nir_alu_type_get_type_size(nir_op_infos[alu->op].input_types[i]))
+         bit_size = alu->src[i].src.ssa->bit_size;
+
+      nir_instr *src_instr = alu->src[i].src.ssa->parent_instr;
+
+      /* Number of components of source i that this instruction reads. */
+      const unsigned src_components = nir_ssa_alu_instr_src_components(alu, i);
+
+      if (src_instr->type == nir_instr_type_load_const) {
+         nir_load_const_instr *load_const = nir_instr_as_load_const(src_instr);
+
+         for (unsigned j = 0; j < src_components; j++)
+            src[i][j] = load_const->value[alu->src[i].swizzle[j]];
+      } else {
+         int r = find_replacement(originals, alu->src[i].src.ssa,
+                                  num_replacements);
+
+         if (r >= 0) {
+            for (unsigned j = 0; j < src_components; j++)
+               src[i][j] = replacements[r][alu->src[i].swizzle[j]];
+         } else if (src_instr->type == nir_instr_type_alu) {
+            /* Evaluate the producing instruction into a temporary so that
+             * this source's swizzle can be applied to its result, exactly
+             * as is done for the load_const and replacement cases above.
+             */
+            nir_const_value evaluated[NIR_MAX_VEC_COMPONENTS];
+
+            memset(evaluated, 0, sizeof(evaluated));
+
+            if (!try_eval_const_alu(evaluated, nir_instr_as_alu(src_instr),
+                                    originals, replacements, num_replacements,
+                                    execution_mode))
+               return false;
+
+            memset(src[i], 0, sizeof(src[i]));
+
+            for (unsigned j = 0; j < src_components; j++)
+               src[i][j] = evaluated[alu->src[i].swizzle[j]];
+         } else {
+            return false;
+         }
+      }
+
+      /* We shouldn't have any source modifiers in the optimization loop. */
+      assert(!alu->src[i].abs && !alu->src[i].negate);
+   }
+
+   if (bit_size == 0)
+      bit_size = 32;
+
+   /* We shouldn't have any saturate modifiers in the optimization loop. */
+   assert(!alu->dest.saturate);
+
+   /* nir_eval_const_opcode takes an array of per-source pointers. */
+   nir_const_value *srcs[NIR_MAX_VEC_COMPONENTS];
+
+   for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; ++i)
+      srcs[i] = src[i];
+
+   nir_eval_const_opcode(alu->op, dest, alu->dest.dest.ssa.num_components,
+                         bit_size, srcs, execution_mode);
+
+   return true;
+}
+
 static int32_t
 get_iteration(nir_op cond_op, nir_const_value initial, nir_const_value step,
               nir_const_value limit, unsigned bit_size,