nir/opt_vectorize: hash whether a swizzle accesses elements beyond the maximum vector...
authorDaniel Schürmann <daniel@schuermann.dev>
Fri, 11 Sep 2020 10:05:17 +0000 (11:05 +0100)
committerMarge Bot <eric+marge@anholt.net>
Thu, 31 Dec 2020 16:44:58 +0000 (16:44 +0000)
Swizzles that access components outside of the maximum
vector size cannot be vectorized with each other.
This patch creates different hash bins for this case.

For example accesses to .x and .y are considered different variables
compared to accesses to .z and .w for 16-bit vec2.

This prevents the vectorization of things like
   vec2 16 ssa_3 = iadd ssa_1.xz, ssa_2.xz

Reviewed-by: Connor Abbott <cwabbott0@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/6666>

src/compiler/nir/nir_opt_vectorize.c

index 3f06a19..79f8008 100644 (file)
@@ -39,11 +39,18 @@ hash_src(uint32_t hash, const nir_src *src)
 }
 
 static uint32_t
-hash_alu_src(uint32_t hash, const nir_alu_src *src)
+hash_alu_src(uint32_t hash, const nir_alu_src *src,
+             uint32_t num_components, uint32_t max_vec)
 {
    assert(!src->abs && !src->negate);
 
-   /* intentionally don't hash swizzle */
+   /* hash whether a swizzle accesses elements beyond the maximum
+    * vectorization factor:
+    * For example accesses to .x and .y are considered different variables
+    * compared to accesses to .z and .w for 16-bit vec2.
+    */
+   uint32_t swizzle = (src->swizzle[0] & ~(max_vec - 1));
+   hash = HASH(hash, swizzle);
 
    return hash_src(hash, &src->src);
 }
@@ -59,7 +66,9 @@ hash_instr(const void *data)
    hash = HASH(hash, alu->dest.dest.ssa.bit_size);
 
    for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++)
-      hash = hash_alu_src(hash, &alu->src[i]);
+      hash = hash_alu_src(hash, &alu->src[i],
+                          alu->dest.dest.ssa.num_components,
+                          instr->pass_flags);
 
    return hash;
 }
@@ -75,13 +84,18 @@ srcs_equal(const nir_src *src1, const nir_src *src2)
 }
 
 static bool
-alu_srcs_equal(const nir_alu_src *src1, const nir_alu_src *src2)
+alu_srcs_equal(const nir_alu_src *src1, const nir_alu_src *src2,
+               uint32_t max_vec)
 {
    assert(!src1->abs);
    assert(!src1->negate);
    assert(!src2->abs);
    assert(!src2->negate);
 
+   uint32_t mask = ~(max_vec - 1);
+   if ((src1->swizzle[0] & mask) != (src2->swizzle[0] & mask))
+      return false;
+
    return srcs_equal(&src1->src, &src2->src);
 }
 
@@ -103,7 +117,7 @@ instrs_equal(const void *data1, const void *data2)
       return false;
 
    for (unsigned i = 0; i < nir_op_infos[alu1->op].num_inputs; i++) {
-      if (!alu_srcs_equal(&alu1->src[i], &alu2->src[i]))
+      if (!alu_srcs_equal(&alu1->src[i], &alu2->src[i], instr1->pass_flags))
          return false;
    }
 
@@ -139,6 +153,14 @@ instr_can_rewrite(nir_instr *instr, bool vectorize_16bit)
       for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++) {
          if (nir_op_infos[alu->op].input_sizes[i] != 0)
             return false;
+
+         /* don't hash instructions which are already swizzled
+          * outside of max_components: these should better be scalarized */
+         uint32_t mask = vectorize_16bit ? ~1 : ~3;
+         for (unsigned j = 0; j < alu->dest.dest.ssa.num_components; j++) {
+            if ((alu->src[i].swizzle[0] & mask) != (alu->src[i].swizzle[i] & mask))
+               return false;
+         }
       }
 
       return true;
@@ -188,6 +210,7 @@ instr_try_combine(struct nir_shader *nir, struct set *instr_set,
    nir_ssa_dest_init(&new_alu->instr, &new_alu->dest.dest,
                      total_components, alu1->dest.dest.ssa.bit_size, NULL);
    new_alu->dest.write_mask = (1 << total_components) - 1;
+   new_alu->instr.pass_flags = alu1->instr.pass_flags;
 
    /* If either channel is exact, we have to preserve it even if it's
     * not optimal for other channels.
@@ -341,6 +364,9 @@ vec_instr_set_add_or_rewrite(struct nir_shader *nir, struct set *instr_set,
    if (filter && !filter(instr, data))
       return false;
 
+   /* set max vector to instr pass flags: this is used to hash swizzles */
+   instr->pass_flags = nir->options->vectorize_vec2_16bit ? 2 : 4;
+
    struct set_entry *entry = _mesa_set_search(instr_set, instr);
    if (entry) {
       nir_instr *old_instr = (nir_instr *) entry->key;
@@ -377,7 +403,8 @@ vectorize_block(struct nir_shader *nir, nir_block *block,
    }
 
    nir_foreach_instr_reverse(instr, block) {
-      if (instr_can_rewrite(instr, nir->options->vectorize_vec2_16bit))
+      if (instr_can_rewrite(instr, nir->options->vectorize_vec2_16bit) &&
+          (!filter || filter(instr, data)))
          _mesa_set_remove_key(instr_set, instr);
    }