nir: Add component mask re-interpret helpers
authorJason Ekstrand <jason@jlekstrand.net>
Fri, 25 Sep 2020 21:01:03 +0000 (16:01 -0500)
committerMarge Bot <eric+marge@anholt.net>
Fri, 2 Oct 2020 07:30:49 +0000 (07:30 +0000)
These are based on the ones which already existed in the load/store
vectorization pass but I made some improvements while moving them.  In
particular,

 1. They're both faster if the bit sizes are equal
 2. The check is faster if old_bit_size > new_bit_size
 3. The check now fails if it would use more than NIR_MAX_VEC_COMPONENTS

Reviewed-by: Jesse Natalie <jenatali@microsoft.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/6871>

src/compiler/nir/nir.c
src/compiler/nir/nir.h
src/compiler/nir/nir_opt_load_store_vectorize.c

index 160e6e6..34faa34 100644 (file)
 
 #include "main/menums.h" /* BITFIELD64_MASK */
 
+
+/** Return true if the component mask "mask" with bit size "old_bit_size" can
+ * be re-interpreted to be used with "new_bit_size".
+ */
+bool
+nir_component_mask_can_reinterpret(nir_component_mask_t mask,
+                                   unsigned old_bit_size,
+                                   unsigned new_bit_size)
+{
+   assert(util_is_power_of_two_nonzero(old_bit_size));
+   assert(util_is_power_of_two_nonzero(new_bit_size));
+
+   if (old_bit_size == new_bit_size)
+      return true;
+
+   if (old_bit_size == 1 || new_bit_size == 1)
+      return false;
+
+   if (old_bit_size > new_bit_size) {
+      unsigned ratio = old_bit_size / new_bit_size;
+      return util_last_bit(mask) * ratio <= NIR_MAX_VEC_COMPONENTS;
+   }
+
+   unsigned iter = mask;
+   while (iter) {
+      int start, count;
+      u_bit_scan_consecutive_range(&iter, &start, &count);
+      start *= old_bit_size;
+      count *= old_bit_size;
+      if (start % new_bit_size != 0)
+         return false;
+      if (count % new_bit_size != 0)
+         return false;
+   }
+   return true;
+}
+
+/** Re-interprets a component mask "mask" with bit size "old_bit_size" so that
+ * it can be used can be used with "new_bit_size".
+ */
+nir_component_mask_t
+nir_component_mask_reinterpret(nir_component_mask_t mask,
+                               unsigned old_bit_size,
+                               unsigned new_bit_size)
+{
+   assert(nir_component_mask_can_reinterpret(mask, old_bit_size, new_bit_size));
+
+   if (old_bit_size == new_bit_size)
+      return mask;
+
+   nir_component_mask_t new_mask = 0;
+   unsigned iter = mask;
+   while (iter) {
+      int start, count;
+      u_bit_scan_consecutive_range(&iter, &start, &count);
+      start = start * old_bit_size / new_bit_size;
+      count = count * old_bit_size / new_bit_size;
+      new_mask |= BITFIELD_RANGE(start, count);
+   }
+   return new_mask;
+}
+
 nir_shader *
 nir_shader_create(void *mem_ctx,
                   gl_shader_stage stage,
index 9942153..ef4d33a 100644 (file)
@@ -76,6 +76,14 @@ nir_num_components_valid(unsigned num_components)
            num_components == 16;
 }
 
+bool nir_component_mask_can_reinterpret(nir_component_mask_t mask,
+                                        unsigned old_bit_size,
+                                        unsigned new_bit_size);
+nir_component_mask_t
+nir_component_mask_reinterpret(nir_component_mask_t mask,
+                               unsigned old_bit_size,
+                               unsigned new_bit_size);
+
 /** Defines a cast function
  *
  * This macro defines a cast function from in_type to out_type where
index d28419c..593a022 100644 (file)
@@ -625,25 +625,6 @@ cast_deref(nir_builder *b, unsigned num_components, unsigned bit_size, nir_deref
    return nir_build_deref_cast(b, &deref->dest.ssa, deref->mode, type, 0);
 }
 
-/* Return true if the write mask "write_mask" of a store with "old_bit_size"
- * bits per element can be represented for a store with "new_bit_size" bits per
- * element. */
-static bool
-writemask_representable(unsigned write_mask, unsigned old_bit_size, unsigned new_bit_size)
-{
-   while (write_mask) {
-      int start, count;
-      u_bit_scan_consecutive_range(&write_mask, &start, &count);
-      start *= old_bit_size;
-      count *= old_bit_size;
-      if (start % new_bit_size != 0)
-         return false;
-      if (count % new_bit_size != 0)
-         return false;
-   }
-   return true;
-}
-
 /* Return true if "new_bit_size" is a usable bit size for a vectorized load/store
  * of "low" and "high". */
 static bool
@@ -683,33 +664,17 @@ new_bitsize_acceptable(struct vectorize_ctx *ctx, unsigned new_bit_size,
          return false;
 
       unsigned write_mask = nir_intrinsic_write_mask(low->intrin);
-      if (!writemask_representable(write_mask, get_bit_size(low), new_bit_size))
+      if (!nir_component_mask_can_reinterpret(write_mask, get_bit_size(low), new_bit_size))
          return false;
 
       write_mask = nir_intrinsic_write_mask(high->intrin);
-      if (!writemask_representable(write_mask, get_bit_size(high), new_bit_size))
+      if (!nir_component_mask_can_reinterpret(write_mask, get_bit_size(high), new_bit_size))
          return false;
    }
 
    return true;
 }
 
-/* Updates a write mask, "write_mask", so that it can be used with a
- * "new_bit_size"-bit store instead of a "old_bit_size"-bit store. */
-static uint32_t
-update_writemask(unsigned write_mask, unsigned old_bit_size, unsigned new_bit_size)
-{
-   uint32_t res = 0;
-   while (write_mask) {
-      int start, count;
-      u_bit_scan_consecutive_range(&write_mask, &start, &count);
-      start = start * old_bit_size / new_bit_size;
-      count = count * old_bit_size / new_bit_size;
-      res |= ((1 << count) - 1) << start;
-   }
-   return res;
-}
-
 static nir_deref_instr *subtract_deref(nir_builder *b, nir_deref_instr *deref, int64_t offset)
 {
    /* avoid adding another deref to the path */
@@ -847,8 +812,12 @@ vectorize_stores(nir_builder *b, struct vectorize_ctx *ctx,
    /* get new writemasks */
    uint32_t low_write_mask = nir_intrinsic_write_mask(low->intrin);
    uint32_t high_write_mask = nir_intrinsic_write_mask(high->intrin);
-   low_write_mask = update_writemask(low_write_mask, get_bit_size(low), new_bit_size);
-   high_write_mask = update_writemask(high_write_mask, get_bit_size(high), new_bit_size);
+   low_write_mask = nir_component_mask_reinterpret(low_write_mask,
+                                                   get_bit_size(low),
+                                                   new_bit_size);
+   high_write_mask = nir_component_mask_reinterpret(high_write_mask,
+                                                    get_bit_size(high),
+                                                    new_bit_size);
    high_write_mask <<= high_start / new_bit_size;
 
    uint32_t write_mask = low_write_mask | high_write_mask;