nir: Handle wider unaligned loads in lower_mem_access_bit_size
authorFaith Ekstrand <faith.ekstrand@collabora.com>
Mon, 27 Feb 2023 14:45:22 +0000 (08:45 -0600)
committerMarge Bot <emma+marge@anholt.net>
Fri, 3 Mar 2023 02:00:39 +0000 (02:00 +0000)
Reviewed-by: M Henning <drawoc@darkrefraction.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/21524>

src/compiler/nir/nir_lower_mem_access_bit_sizes.c

index a44f774..16269f6 100644 (file)
@@ -110,6 +110,8 @@ lower_mem_load(nir_builder *b, nir_intrinsic_instr *intrin,
       const unsigned bytes_left = bytes_read - chunk_start;
       const uint32_t chunk_align_offset =
          (whole_align_offset + chunk_start) % align_mul;
+      const uint32_t chunk_align =
+         nir_combined_align(align_mul, chunk_align_offset);
       requested = mem_access_size_align_cb(intrin->intrinsic, bytes_left,
                                            align_mul, chunk_align_offset,
                                            offset_is_const, cb_data);
@@ -118,9 +120,9 @@ lower_mem_load(nir_builder *b, nir_intrinsic_instr *intrin,
       assert(util_is_power_of_two_nonzero(requested.align));
       if (align_mul < requested.align) {
          /* For this case, we need to be able to shift the value so we assume
-          * there's at most one component.
+          * the alignment is less than the size of a single component.  This
+          * ensures that we don't need to upcast in order to shift.
           */
-         assert(requested.num_components == 1);
          assert(requested.bit_size >= requested.align * 8);
 
          uint64_t align_mask = requested.align - 1;
@@ -133,12 +135,43 @@ lower_mem_load(nir_builder *b, nir_intrinsic_instr *intrin,
                               requested.align, 0, NULL,
                               requested.num_components, requested.bit_size);
 
-         nir_ssa_def *shifted =
-            nir_ushr(b, &load->dest.ssa, nir_imul_imm(b, pad, 8));
+         unsigned max_pad = requested.align - chunk_align;
+         unsigned requested_bytes =
+            requested.num_components * requested.bit_size / 8;
+         chunk_bytes = MIN2(bytes_left, requested_bytes - max_pad);
 
-         chunk_bytes = MIN2(bytes_left, align_mul);
-         assert(num_chunks < ARRAY_SIZE(chunks));
-         chunks[num_chunks++] = nir_u2uN(b, shifted, chunk_bytes * 8);
+         nir_ssa_def *shift = nir_imul_imm(b, pad, 8);
+         nir_ssa_def *shifted = nir_ushr(b, &load->dest.ssa, shift);
+
+         if (load->dest.ssa.num_components > 1) {
+            nir_ssa_def *rev_shift =
+               nir_isub_imm(b, load->dest.ssa.bit_size, shift);
+            nir_ssa_def *rev_shifted = nir_ishl(b, &load->dest.ssa, rev_shift);
+
+            nir_ssa_def *comps[NIR_MAX_VEC_COMPONENTS];
+            for (unsigned i = 1; i < load->dest.ssa.num_components; i++)
+               comps[i - 1] = nir_channel(b, rev_shifted, i);
+
+            comps[load->dest.ssa.num_components - 1] =
+               nir_imm_zero(b, 1, load->dest.ssa.bit_size);
+
+            rev_shifted = nir_vec(b, comps, load->dest.ssa.num_components);
+            shifted = nir_bcsel(b, nir_ieq_imm(b, shift, 0), &load->dest.ssa,
+                                   nir_ior(b, shifted, rev_shifted));
+         }
+
+         unsigned chunk_bit_size = MIN2(8 << (ffs(chunk_bytes) - 1), bit_size);
+         unsigned chunk_num_components = chunk_bytes / (chunk_bit_size / 8);
+
+         /* There's no guarantee that chunk_num_components is a valid NIR
+          * vector size, so just loop one chunk component at a time
+          */
+         for (unsigned i = 0; i < chunk_num_components; i++) {
+            assert(num_chunks < ARRAY_SIZE(chunks));
+            chunks[num_chunks++] =
+               nir_extract_bits(b, &shifted, 1, i * chunk_bit_size,
+                                1, chunk_bit_size);
+         }
       } else if (chunk_align_offset % requested.align) {
          /* In this case, we know how much to adjust the offset */
          uint32_t delta = chunk_align_offset % requested.align;