spirv: Implement vload[a]_half[n] and vstore[a]_half[n][_r]
authorJesse Natalie <jenatali@microsoft.com>
Thu, 30 Jul 2020 23:45:46 +0000 (16:45 -0700)
committerMarge Bot <eric+marge@anholt.net>
Thu, 1 Oct 2020 18:36:53 +0000 (18:36 +0000)
Note, the aligned versions aren't handled specially yet.

The float16buffer capability is now at least partially supported after
this patch, so move it to be supported when kernels are supported.

v2 (Jason Ekstrand):
 - A few cosmetic cleanups around type/base_type
 - Rebased on top of the big SPIR-V SSA value rework
 - Use the new version of the conversion helpers

Reviewed-by: Jesse Natalie <jenatali@microsoft.com>
Reviewed-by: Jason Ekstrand <jason@jlekstrand.net>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/6945>

src/compiler/spirv/spirv_to_nir.c
src/compiler/spirv/vtn_alu.c
src/compiler/spirv/vtn_opencl.c
src/compiler/spirv/vtn_private.h

index c891e75..0c1ac86 100644 (file)
@@ -4133,7 +4133,6 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
          break;
 
       case SpvCapabilityLinkage:
-      case SpvCapabilityFloat16Buffer:
       case SpvCapabilitySparseResidency:
          vtn_warn("Unsupported SPIR-V capability: %s",
                   spirv_capability_to_string(cap));
@@ -4181,6 +4180,7 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
          break;
 
       case SpvCapabilityKernel:
+      case SpvCapabilityFloat16Buffer:
          spv_check_supported(kernel, cap);
          break;
 
index 205eca0..9daffbc 100644 (file)
@@ -382,7 +382,7 @@ handle_no_contraction(struct vtn_builder *b, struct vtn_value *val, int member,
    b->nb.exact = true;
 }
 
-static nir_rounding_mode
+nir_rounding_mode
 vtn_rounding_mode_to_nir(struct vtn_builder *b, SpvFPRoundingMode mode)
 {
    switch (mode) {
index 13adb84..301cb11 100644 (file)
@@ -614,7 +614,8 @@ handle_core(struct vtn_builder *b, uint32_t opcode,
 
 static void
 _handle_v_load_store(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
-                     const uint32_t *w, unsigned count, bool load)
+                     const uint32_t *w, unsigned count, bool load,
+                     bool vec_aligned, nir_rounding_mode rounding)
 {
    struct vtn_type *type;
    if (load)
@@ -629,12 +630,28 @@ _handle_v_load_store(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
    nir_ssa_def *offset = vtn_get_nir_ssa(b, w[5 + a]);
    struct vtn_value *p = vtn_value(b, w[6 + a], vtn_value_type_pointer);
 
+   enum glsl_base_type ptr_base_type =
+      glsl_get_base_type(p->pointer->type->type);
+   if (base_type != ptr_base_type) {
+      vtn_fail_if(ptr_base_type != GLSL_TYPE_FLOAT16 ||
+                  (base_type != GLSL_TYPE_FLOAT &&
+                   base_type != GLSL_TYPE_DOUBLE),
+                  "vload/vstore cannot do type conversion. "
+                  "vload/vstore_half can only convert from half to other "
+                  "floating-point types.");
+   }
+
    struct vtn_ssa_value *comps[NIR_MAX_VEC_COMPONENTS];
    nir_ssa_def *ncomps[NIR_MAX_VEC_COMPONENTS];
 
-   nir_ssa_def *moffset = nir_imul_imm(&b->nb, offset, components);
+   nir_ssa_def *moffset = nir_imul_imm(&b->nb, offset,
+      (vec_aligned && components == 3) ? 4 : components);
    nir_deref_instr *deref = vtn_pointer_to_deref(b, p->pointer);
 
+   unsigned alignment = vec_aligned ? glsl_get_cl_alignment(type->type) :
+                                      glsl_get_bit_size(type->type) / 8;
+   deref = nir_alignment_deref_cast(&b->nb, deref, alignment, 0);
+
    for (int i = 0; i < components; i++) {
       nir_ssa_def *coffset = nir_iadd_imm(&b->nb, moffset, i);
       nir_deref_instr *arr_deref = nir_build_deref_ptr_as_array(&b->nb, deref, coffset);
@@ -642,10 +659,30 @@ _handle_v_load_store(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
       if (load) {
          comps[i] = vtn_local_load(b, arr_deref, p->type->access);
          ncomps[i] = comps[i]->def;
+         if (base_type != ptr_base_type) {
+            assert(ptr_base_type == GLSL_TYPE_FLOAT16 &&
+                   (base_type == GLSL_TYPE_FLOAT ||
+                    base_type == GLSL_TYPE_DOUBLE));
+            ncomps[i] = nir_f2fN(&b->nb, ncomps[i],
+                                 glsl_base_type_get_bit_size(base_type));
+         }
       } else {
          struct vtn_ssa_value *ssa = vtn_create_ssa_value(b, glsl_scalar_type(base_type));
          struct vtn_ssa_value *val = vtn_ssa_value(b, w[5]);
          ssa->def = nir_channel(&b->nb, val->def, i);
+         if (base_type != ptr_base_type) {
+            assert(ptr_base_type == GLSL_TYPE_FLOAT16 &&
+                   (base_type == GLSL_TYPE_FLOAT ||
+                    base_type == GLSL_TYPE_DOUBLE));
+            if (rounding == nir_rounding_mode_undef) {
+               ssa->def = nir_f2f16(&b->nb, ssa->def);
+            } else {
+               ssa->def = nir_convert_alu_types(&b->nb, ssa->def,
+                                                nir_type_float,
+                                                nir_type_float16,
+                                                rounding, false);
+            }
+         }
          vtn_local_store(b, ssa, arr_deref, p->type->access);
       }
    }
@@ -658,14 +695,27 @@ static void
 vtn_handle_opencl_vload(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
                         const uint32_t *w, unsigned count)
 {
-   _handle_v_load_store(b, opcode, w, count, true);
+   _handle_v_load_store(b, opcode, w, count, true,
+                        opcode == OpenCLstd_Vloada_halfn,
+                        nir_rounding_mode_undef);
 }
 
 static void
 vtn_handle_opencl_vstore(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
                          const uint32_t *w, unsigned count)
 {
-   _handle_v_load_store(b, opcode, w, count, false);
+   _handle_v_load_store(b, opcode, w, count, false,
+                        opcode == OpenCLstd_Vstorea_halfn,
+                        nir_rounding_mode_undef);
+}
+
+static void
+vtn_handle_opencl_vstore_half_r(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
+                                const uint32_t *w, unsigned count)
+{
+   _handle_v_load_store(b, opcode, w, count, false,
+                        opcode == OpenCLstd_Vstorea_halfn_r,
+                        vtn_rounding_mode_to_nir(b, w[8]));
 }
 
 static nir_ssa_def *
@@ -895,11 +945,22 @@ vtn_handle_opencl_instruction(struct vtn_builder *b, SpvOp ext_opcode,
       handle_instr(b, ext_opcode, w + 5, count - 5, w + 1, handle_special);
       return true;
    case OpenCLstd_Vloadn:
+   case OpenCLstd_Vload_half:
+   case OpenCLstd_Vload_halfn:
+   case OpenCLstd_Vloada_halfn:
       vtn_handle_opencl_vload(b, cl_opcode, w, count);
       return true;
    case OpenCLstd_Vstoren:
+   case OpenCLstd_Vstore_half:
+   case OpenCLstd_Vstore_halfn:
+   case OpenCLstd_Vstorea_halfn:
       vtn_handle_opencl_vstore(b, cl_opcode, w, count);
       return true;
+   case OpenCLstd_Vstore_half_r:
+   case OpenCLstd_Vstore_halfn_r:
+   case OpenCLstd_Vstorea_halfn_r:
+      vtn_handle_opencl_vstore_half_r(b, cl_opcode, w, count);
+      return true;
    case OpenCLstd_Shuffle:
       handle_instr(b, ext_opcode, w + 5, count - 5, w + 1, handle_shuffle);
       return true;
index 54c7b11..f5f4ce8 100644 (file)
@@ -900,6 +900,9 @@ enum vtn_variable_mode vtn_storage_class_to_mode(struct vtn_builder *b,
 nir_address_format vtn_mode_to_address_format(struct vtn_builder *b,
                                               enum vtn_variable_mode);
 
+nir_rounding_mode vtn_rounding_mode_to_nir(struct vtn_builder *b,
+                                           SpvFPRoundingMode mode);
+
 static inline uint32_t
 vtn_align_u32(uint32_t v, uint32_t a)
 {