microsoft/compiler: Add some more float16 support
authorJesse Natalie <jenatali@microsoft.com>
Tue, 6 Apr 2021 17:48:26 +0000 (10:48 -0700)
committerMarge Bot <eric+marge@anholt.net>
Fri, 9 Apr 2021 01:54:33 +0000 (01:54 +0000)
We can support float16 constants, b2f16, and casts to float16.

Reviewed-by: Enrico Galli <enrico.galli@intel.com>
Reviewed-by: Michael Tang <tangm@microsoft.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/10063>

src/microsoft/compiler/dxil_module.c
src/microsoft/compiler/dxil_module.h
src/microsoft/compiler/nir_to_dxil.c

index 09f2252..7f772c9 100644 (file)
@@ -1624,6 +1624,30 @@ dxil_module_get_int_const(struct dxil_module *m, intmax_t value,
 }
 
 const struct dxil_value *
+dxil_module_get_float16_const(struct dxil_module *m, uint16_t value)
+{
+   const struct dxil_type *type = get_float16_type(m);
+   if (!type)
+      return NULL;
+
+   struct dxil_const *c;
+   LIST_FOR_EACH_ENTRY(c, &m->const_list, head) {
+      if (c->value.type != type || c->undef)
+         continue;
+
+      if (c->int_value == (uintmax_t)value)
+         return &c->value;
+   }
+
+   c = create_const(m, type, false);
+   if (!c)
+      return NULL;
+
+   c->int_value = (uintmax_t)value;
+   return &c->value;
+}
+
+const struct dxil_value *
 dxil_module_get_float_const(struct dxil_module *m, float value)
 {
    const struct dxil_type *type = get_float32_type(m);
@@ -2026,6 +2050,15 @@ emit_int_value(struct dxil_module *m, int64_t value)
 }
 
 static bool
+emit_float16_value(struct dxil_module *m, uint16_t value)
+{
+   if (!value)
+      return emit_null_value(m);
+   uint64_t data = value;
+   return emit_record_no_abbrev(&m->buf, CST_CODE_FLOAT, &data, 1);
+}
+
+static bool
 emit_float_value(struct dxil_module *m, float value)
 {
    uint64_t data = fui(value);
@@ -2087,6 +2120,10 @@ emit_consts(struct dxil_module *m)
 
       case TYPE_FLOAT:
          switch (curr_type->float_bits) {
+         case 16:
+            if (!emit_float16_value(m, (uint16_t)(uintmax_t)c->int_value))
+               return false;
+            break;
          case 32:
             if (!emit_float_value(m, c->float_value))
                return false;
index ca607f5..e569860 100644 (file)
@@ -331,6 +331,9 @@ dxil_module_get_int_const(struct dxil_module *m, intmax_t value,
                           unsigned bit_size);
 
 const struct dxil_value *
+dxil_module_get_float16_const(struct dxil_module *m, uint16_t);
+
+const struct dxil_value *
 dxil_module_get_float_const(struct dxil_module *m, float value);
 
 const struct dxil_value *
index 0d71f3f..e8b4506 100644 (file)
@@ -1555,11 +1555,13 @@ get_cast_op(nir_alu_instr *alu)
       return DXIL_CAST_FPTOUI;
 
    /* int -> float */
+   case nir_op_i2f16:
    case nir_op_i2f32:
    case nir_op_i2f64:
       return DXIL_CAST_SITOFP;
 
    /* uint -> float */
+   case nir_op_u2f16:
    case nir_op_u2f32:
    case nir_op_u2f64:
       return DXIL_CAST_UITOFP;
@@ -1736,6 +1738,22 @@ static bool emit_select(struct ntd_context *ctx, nir_alu_instr *alu,
 }
 
 static bool
+emit_b2f16(struct ntd_context *ctx, nir_alu_instr *alu, const struct dxil_value *val)
+{
+   assert(val);
+
+   struct dxil_module *m = &ctx->mod;
+
+   const struct dxil_value *c1 = dxil_module_get_float16_const(m, 0x3C00);
+   const struct dxil_value *c0 = dxil_module_get_float16_const(m, 0);
+
+   if (!c0 || !c1)
+      return false;
+
+   return emit_select(ctx, alu, val, c1, c0);
+}
+
+static bool
 emit_b2f32(struct ntd_context *ctx, nir_alu_instr *alu, const struct dxil_value *val)
 {
    assert(val);
@@ -2056,6 +2074,7 @@ emit_alu(struct ntd_context *ctx, nir_alu_instr *alu)
       return emit_cast(ctx, alu, src[0]);
 
    case nir_op_f2b32: return emit_f2b32(ctx, alu, src[0]);
+   case nir_op_b2f16: return emit_b2f16(ctx, alu, src[0]);
    case nir_op_b2f32: return emit_b2f32(ctx, alu, src[0]);
    default:
       NIR_INSTR_UNSUPPORTED(&alu->instr);