nir/shrink_vectors: shrink ALU properly
authorDaniel Schürmann <daniel@schuermann.dev>
Wed, 7 Jul 2021 14:43:13 +0000 (16:43 +0200)
committerMarge Bot <eric+marge@anholt.net>
Mon, 26 Jul 2021 09:24:37 +0000 (09:24 +0000)
ALU instructions of which not all components are read,
can be shrunk to the number of read components.
Previously, this would only remove trailing components.

This patch enables to remove components from any position.

Stat changes for softpipe:
total instructions in shared programs: 3001291 -> 2984698 (-0.55%)
instructions in affected programs: 225585 -> 208992 (-7.36%)
total loops in shared programs: 1389 -> 1358 (-2.23%)
loops in affected programs: 36 -> 5 (-86.11%)

Reviewed-by: Emma Anholt <emma@anholt.net>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/11411>

src/compiler/nir/nir_opt_shrink_vectors.c

index 45284dc..c1283d3 100644 (file)
  *
  * Due to various optimization passes (or frontend implementations,
  * particularly prog_to_nir), we may have instructions generating vectors
- * whose components don't get read by any instruction.  While it can be tricky
- * to eliminate either unused low components of a writemask (you might need to
- * increment some offset from a load_uniform, for example) or channels in the
- * middle of a partially set writemask (you might need to reswizzle ALU ops
- * using the value), it is trivial to just drop the trailing components.
+ * whose components don't get read by any instruction. As it can be tricky
+ * to eliminate unused low components or channels in the middle of a writemask
+ * (you might need to increment some offset from a load_uniform, for example),
+ * it is trivial to just drop the trailing components. For vector ALU only used
+ * by ALU, this pass eliminates arbitrary channels and reswizzles the uses.
  *
  * This pass is probably only of use to vector backends -- scalar backends
  * typically get unused def channel trimming by scalarizing and dead code
@@ -75,13 +75,64 @@ opt_shrink_vectors_alu(nir_builder *b, nir_alu_instr *instr)
 {
    nir_ssa_def *def = &instr->dest.dest.ssa;
 
+   /* Nothing to shrink */
+   if (def->num_components == 1)
+      return false;
+
    if (nir_op_infos[instr->op].output_size == 0) {
-      if (shrink_dest_to_read_mask(def)) {
-         instr->dest.write_mask &=
-            BITFIELD_MASK(def->num_components);
+      /* don't remove any channels if used by an intrinsic */
+      nir_foreach_use_safe(use_src, def) {
+         if (use_src->parent_instr->type == nir_instr_type_intrinsic)
+            return false;
+      }
+
+      unsigned mask = nir_ssa_def_components_read(def);
+      unsigned last_bit = util_last_bit(mask);
+      unsigned num_components = util_bitcount(mask);
+      if (mask == 0 || num_components == def->num_components)
+         return false;
+
+      const bool is_bitfield_mask = last_bit == num_components;
 
+      if (is_bitfield_mask) {
+         /* just reduce the number of components and return */
+         def->num_components = num_components;
+         instr->dest.write_mask = mask;
          return true;
       }
+
+      /* update sources */
+      for (int i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
+         unsigned index = 0;
+         for (int j = 0; j < last_bit; j++) {
+            if ((mask >> j) & 0x1)
+               instr->src[i].swizzle[index++] = instr->src[i].swizzle[j];
+         }
+         assert(index == num_components);
+      }
+
+      /* compute new dest swizzles */
+      uint8_t reswizzle[NIR_MAX_VEC_COMPONENTS] = { 0 };
+      unsigned index = 0;
+      for (int i = 0; i < def->num_components; i++) {
+         if ((mask >> i) & 0x1)
+            reswizzle[i] = index++;
+      }
+      assert(index == num_components);
+
+      /* update dest */
+      def->num_components = num_components;
+      instr->dest.write_mask = BITFIELD_MASK(num_components);
+
+      /* update uses */
+      nir_foreach_use(use_src, def) {
+         assert(use_src->parent_instr->type == nir_instr_type_alu);
+         nir_alu_src *alu_src = (nir_alu_src*)use_src;
+         for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
+            alu_src->swizzle[i] = reswizzle[alu_src->swizzle[i]];
+      }
+
+      return true;
    } else {
 
       switch (instr->op) {