nir/lower_int64: Implement lowering of 64-bit integer to 64-bit float conversions.
author Francisco Jerez <currojerez@riseup.net>
Mon, 17 Oct 2022 21:05:38 +0000 (14:05 -0700)
committer Marge Bot <emma+marge@anholt.net>
Sat, 29 Oct 2022 19:45:44 +0000 (19:45 +0000)
This involves computing the significand at 64-bit precision and
implementing the normalization and packing manually instead of relying
on u2f32, since the significand can no longer be represented as a
32-bit integer.  This fixes 64-bit integer to 64-bit float conversions
on devices that support 64-bit float natively but lack 64-bit integer
support, like Intel MTL hardware.

Reviewed-by: Boris Brezillon <boris.brezillon@collabora.com> (v1)
Reviewed-by: Ian Romanick <ian.d.romanick@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/19128>
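
For reference, below is a minimal host-side C sketch of the unsigned half
of the algorithm (leading-bit exponent, round-to-nearest-even on the
discarded bits, renormalization after a rounding carry, manual packing of
the binary64 bit pattern).  It is not the NIR builder code from the patch;
the helper name u64_to_f64_ref and the GCC/Clang builtin __builtin_clzll
are assumptions made for illustration only.

#include <stdint.h>
#include <string.h>

static double
u64_to_f64_ref(uint64_t x)
{
   /* Zero is handled up front here instead of through the
    * negative-exponent bcsel used by the lowering.
    */
   if (x == 0)
      return 0.0;

   /* Unbiased exponent: position of the most significant set bit. */
   int exp = 63 - __builtin_clzll(x);

   /* Keep at most 53 significand bits (52 explicit plus the implicit
    * leading one), discarding the rest.
    */
   int discard = exp > 52 ? exp - 52 : 0;
   uint64_t significand = x >> discard;

   /* Round to nearest, ties to even, based on the discarded low bits. */
   if (discard > 0) {
      uint64_t rem = x & ((UINT64_C(1) << discard) - 1);
      uint64_t half = UINT64_C(1) << (discard - 1);
      if (rem > half || (rem == half && (significand & 1)))
         significand++;
   }

   /* Normalize so the leading one lands in bit 52. */
   if (exp < 52)
      significand <<= 52 - exp;

   /* Rounding may have carried into bit 53: renormalize. */
   if (significand >> 53) {
      significand >>= 1;
      exp++;
   }

   /* Pack: the exponent field overwrites the implicit leading one. */
   uint64_t bits = ((uint64_t)(exp + 1023) << 52) |
                   (significand & ((UINT64_C(1) << 52) - 1));

   double res;
   memcpy(&res, &bits, sizeof(res));
   return res;
}

As a tie-breaking check, 2^53 + 1 lies exactly halfway between the
representable values 2^53 and 2^53 + 2 and rounds to the even significand,
i.e. 2^53, while 2^53 + 3 rounds up to 2^53 + 4.  The signed i2f64 path in
the patch reuses the same unsigned core and multiplies the sign back in
via x_sign at the end.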

src/compiler/nir/nir_lower_int64.c

index 10db7cd..005343f 100644
@@ -701,6 +701,9 @@ lower_2f(nir_builder *b, nir_ssa_def *x, unsigned dest_bit_size,
    unsigned significand_bits;
 
    switch (dest_bit_size) {
+   case 64:
+      significand_bits = 52;
+      break;
    case 32:
       significand_bits = 23;
       break;
@@ -714,8 +717,9 @@ lower_2f(nir_builder *b, nir_ssa_def *x, unsigned dest_bit_size,
    nir_ssa_def *discard =
       nir_imax(b, nir_isub(b, exp, nir_imm_int(b, significand_bits)),
                   nir_imm_int(b, 0));
-   nir_ssa_def *significand =
-      COND_LOWER_CAST(b, u2u32, COND_LOWER_OP(b, ushr, x, discard));
+   nir_ssa_def *significand = COND_LOWER_OP(b, ushr, x, discard);
+   if (significand_bits < 32)
+      significand = COND_LOWER_CAST(b, u2u32, significand);
 
    /* Round-to-nearest-even implementation:
     * - if the non-representable part of the significand is higher than half
@@ -731,19 +735,63 @@ lower_2f(nir_builder *b, nir_ssa_def *x, unsigned dest_bit_size,
    nir_ssa_def *rem = COND_LOWER_OP(b, iand, x, rem_mask);
    nir_ssa_def *halfway = nir_iand(b, COND_LOWER_CMP(b, ieq, rem, half),
                                    nir_ine(b, discard, nir_imm_int(b, 0)));
-   nir_ssa_def *is_odd = nir_i2b(b, nir_iand(b, significand, nir_imm_int(b, 1)));
+   nir_ssa_def *is_odd = COND_LOWER_CMP(b, ine, nir_imm_int64(b, 0),
+                                         COND_LOWER_OP(b, iand, x, lsb_mask));
    nir_ssa_def *round_up = nir_ior(b, COND_LOWER_CMP(b, ilt, half, rem),
                                    nir_iand(b, halfway, is_odd));
-   significand = nir_iadd(b, significand, nir_b2i32(b, round_up));
+   if (significand_bits >= 32)
+      significand = COND_LOWER_OP(b, iadd, significand,
+                                  COND_LOWER_CAST(b, b2i64, round_up));
+   else
+      significand = nir_iadd(b, significand, nir_b2i32(b, round_up));
 
    nir_ssa_def *res;
 
-   if (dest_bit_size == 32)
+   if (dest_bit_size == 64) {
+      /* Compute the left shift required to normalize the original
+       * unrounded input manually.
+       */
+      nir_ssa_def *shift =
+         nir_imax(b, nir_isub(b, nir_imm_int(b, significand_bits), exp),
+                  nir_imm_int(b, 0));
+      significand = COND_LOWER_OP(b, ishl, significand, shift);
+
+      /* Check whether normalization led to overflow of the available
+       * significand bits, which can only happen if round_up was true
+       * above, in which case we need to add carry to the exponent and
+       * discard an extra bit from the significand.  Note that we
+       * don't need to repeat the round-up logic again, since the LSB
+       * of the significand is guaranteed to be zero if there was
+       * overflow.
+       */
+      nir_ssa_def *carry = nir_b2i32(
+         b, nir_uge(b, nir_unpack_64_2x32_split_y(b, significand),
+                    nir_imm_int(b, 1 << (significand_bits - 31))));
+      significand = COND_LOWER_OP(b, ishr, significand, carry);
+      exp = nir_iadd(b, exp, carry);
+
+      /* Compute the biased exponent, taking care to handle a zero
+       * input correctly, which would have caused exp to be negative.
+       */
+      nir_ssa_def *biased_exp = nir_bcsel(b, nir_ilt(b, exp, nir_imm_int(b, 0)),
+                                          nir_imm_int(b, 0),
+                                          nir_iadd(b, exp, nir_imm_int(b, 1023)));
+
+      /* Pack the significand and exponent manually. */
+      nir_ssa_def *lo = nir_unpack_64_2x32_split_x(b, significand);
+      nir_ssa_def *hi = nir_bitfield_insert(
+         b, nir_unpack_64_2x32_split_y(b, significand),
+         biased_exp, nir_imm_int(b, 20), nir_imm_int(b, 11));
+
+      res = nir_pack_64_2x32_split(b, lo, hi);
+
+   } else if (dest_bit_size == 32) {
       res = nir_fmul(b, nir_u2f32(b, significand),
                      nir_fexp2(b, nir_u2f32(b, discard)));
-   else
+   } else {
       res = nir_fmul(b, nir_u2f16(b, significand),
                      nir_fexp2(b, nir_u2f16(b, discard)));
+   }
 
    if (src_is_signed)
       res = nir_fmul(b, res, x_sign);
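
An aside on the final pack in the hunk above: bit 52 of the normalized
significand (the implicit leading one) sits in bit 20 of the high dword,
so the nir_bitfield_insert at offset 20, size 11 both writes the biased
exponent and drops the leading one, as the IEEE-754 binary64 layout
requires.  A scalar equivalent, using the hypothetical helper name
pack_f64_hi, would look like this:

#include <stdint.h>

/* Replace bits 20..30 of the significand's high dword with the 11-bit
 * biased exponent.  Bit 20 holds the implicit leading one, so it is
 * intentionally overwritten, and bits 21..31 are already zero for a
 * normalized significand below 2^53.
 */
static uint32_t
pack_f64_hi(uint32_t significand_hi, uint32_t biased_exp)
{
   return (significand_hi & ~(0x7ffu << 20)) | ((biased_exp & 0x7ffu) << 20);
}
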
@@ -818,6 +866,8 @@ nir_lower_int64_op_to_options_mask(nir_op opcode)
    case nir_op_u2u16:
    case nir_op_u2u32:
    case nir_op_u2u64:
+   case nir_op_i2f64:
+   case nir_op_u2f64:
    case nir_op_i2f32:
    case nir_op_u2f32:
    case nir_op_i2f16: