nir/range_analysis: add missing masking of shift amounts
authorRhys Perry <pendingchaos02@gmail.com>
Wed, 15 Mar 2023 16:11:12 +0000 (16:11 +0000)
committerMarge Bot <emma+marge@anholt.net>
Wed, 22 Mar 2023 09:24:18 +0000 (09:24 +0000)
Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Fixes: 72ac3f60261 ("nir: add nir_unsigned_upper_bound and nir_addition_might_overflow")
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/21381>

src/compiler/nir/nir_range_analysis.c

index f4f6c0a..0f44aba 100644 (file)
@@ -1601,12 +1601,14 @@ nir_unsigned_upper_bound_impl(nir_shader *shader, struct hash_table *range_ht,
       case nir_op_ixor:
          res = bitmask(util_last_bit64(src0)) | bitmask(util_last_bit64(src1));
          break;
-      case nir_op_ishl:
+      case nir_op_ishl: {
+         src1 = MIN2(src1, q.scalar.def->bit_size - 1u);
          if (util_last_bit64(src0) + src1 > scalar.def->bit_size)
             res = max; /* overflow */
          else
-            res = src0 << MIN2(src1, scalar.def->bit_size - 1u);
+            res = src0 << src1;
          break;
+      }
       case nir_op_imul:
          if (src0 != 0 && (src0 * src1) / src0 != src1)
             res = max;
@@ -1615,16 +1617,18 @@ nir_unsigned_upper_bound_impl(nir_shader *shader, struct hash_table *range_ht,
          break;
       case nir_op_ushr: {
          nir_ssa_scalar src1_scalar = nir_ssa_scalar_chase_alu_src(scalar, 1);
+         uint32_t mask = q.scalar.def->bit_size - 1u;
          if (nir_ssa_scalar_is_const(src1_scalar))
-            res = src0 >> nir_ssa_scalar_as_uint(src1_scalar);
+            res = src0 >> (nir_ssa_scalar_as_uint(src1_scalar) & mask);
          else
             res = src0;
          break;
       }
       case nir_op_ishr: {
          nir_ssa_scalar src1_scalar = nir_ssa_scalar_chase_alu_src(scalar, 1);
+         uint32_t mask = q.scalar.def->bit_size - 1u;
          if (src0 <= 2147483647 && nir_ssa_scalar_is_const(src1_scalar))
-            res = src0 >> nir_ssa_scalar_as_uint(src1_scalar);
+            res = src0 >> (nir_ssa_scalar_as_uint(src1_scalar) & mask);
          else
             res = src0;
          break;