spirv: Handle arbitrary bit sizes for deref array indices
authorJason Ekstrand <jason.ekstrand@intel.com>
Fri, 14 Dec 2018 17:06:07 +0000 (11:06 -0600)
committerJason Ekstrand <jason@jlekstrand.net>
Tue, 8 Jan 2019 00:38:29 +0000 (00:38 +0000)
We already had code in link_as_ssa to handle bit sizes; we just need to
use it.  While we're at it we clean up link_as_ssa a bit and add an
explicit bit_size parameter in preparation for a day when we have derefs
that aren't 32 bit.

Cc: mesa-stable@lists.freedesktop.org
Reviewed-by: Alejandro PiƱeiro <apinheiro@igalia.com>
Reviewed-by: Caio Marcelo de Oliveira Filho <caio.oliveira@intel.com>
src/compiler/spirv/vtn_private.h
src/compiler/spirv/vtn_variables.c

index defcbb8..3573925 100644 (file)
@@ -390,7 +390,7 @@ enum vtn_access_mode {
 
 struct vtn_access_link {
    enum vtn_access_mode mode;
-   uint32_t id;
+   int64_t id;
 };
 
 struct vtn_access_chain {
index d50e445..70bec69 100644 (file)
@@ -65,6 +65,23 @@ vtn_pointer_is_external_block(struct vtn_builder *b,
            b->options->lower_workgroup_access_to_offsets);
 }
 
+static nir_ssa_def *
+vtn_access_link_as_ssa(struct vtn_builder *b, struct vtn_access_link link,
+                       unsigned stride, unsigned bit_size)
+{
+   vtn_assert(stride > 0);
+   if (link.mode == vtn_access_mode_literal) {
+      return nir_imm_intN_t(&b->nb, link.id * stride, bit_size);
+   } else {
+      nir_ssa_def *ssa = vtn_ssa_value(b, link.id)->def;
+      if (ssa->bit_size != bit_size)
+         ssa = nir_i2i(&b->nb, ssa, bit_size);
+      if (stride != 1)
+         ssa = nir_imul_imm(&b->nb, ssa, stride);
+      return ssa;
+   }
+}
+
 /* Dereference the given base pointer by the access chain */
 static struct vtn_pointer *
 vtn_nir_deref_pointer_dereference(struct vtn_builder *b,
@@ -95,13 +112,8 @@ vtn_nir_deref_pointer_dereference(struct vtn_builder *b,
          tail = nir_build_deref_struct(&b->nb, tail, idx);
          type = type->members[idx];
       } else {
-         nir_ssa_def *index;
-         if (deref_chain->link[i].mode == vtn_access_mode_literal) {
-            index = nir_imm_int(&b->nb, deref_chain->link[i].id);
-         } else {
-            vtn_assert(deref_chain->link[i].mode == vtn_access_mode_id);
-            index = vtn_ssa_value(b, deref_chain->link[i].id)->def;
-         }
+         nir_ssa_def *index = vtn_access_link_as_ssa(b, deref_chain->link[i], 1,
+                                                     tail->dest.ssa.bit_size);
          tail = nir_build_deref_array(&b->nb, tail, index);
          type = type->array_element;
       }
@@ -120,26 +132,6 @@ vtn_nir_deref_pointer_dereference(struct vtn_builder *b,
 }
 
 static nir_ssa_def *
-vtn_access_link_as_ssa(struct vtn_builder *b, struct vtn_access_link link,
-                       unsigned stride)
-{
-   vtn_assert(stride > 0);
-   if (link.mode == vtn_access_mode_literal) {
-      return nir_imm_int(&b->nb, link.id * stride);
-   } else if (stride == 1) {
-       nir_ssa_def *ssa = vtn_ssa_value(b, link.id)->def;
-       if (ssa->bit_size != 32)
-          ssa = nir_i2i32(&b->nb, ssa);
-      return ssa;
-   } else {
-      nir_ssa_def *src0 = vtn_ssa_value(b, link.id)->def;
-      if (src0->bit_size != 32)
-         src0 = nir_i2i32(&b->nb, src0);
-      return nir_imul_imm(&b->nb, src0, stride);
-   }
-}
-
-static nir_ssa_def *
 vtn_variable_resource_index(struct vtn_builder *b, struct vtn_variable *var,
                             nir_ssa_def *desc_array_index)
 {
@@ -196,7 +188,7 @@ vtn_ssa_offset_pointer_dereference(struct vtn_builder *b,
          if (glsl_type_is_array(type->type)) {
             if (deref_chain->length >= 1) {
                desc_arr_idx =
-                  vtn_access_link_as_ssa(b, deref_chain->link[0], 1);
+                  vtn_access_link_as_ssa(b, deref_chain->link[0], 1, 32);
                idx++;
                /* This consumes a level of type */
                type = type->array_element;
@@ -212,7 +204,7 @@ vtn_ssa_offset_pointer_dereference(struct vtn_builder *b,
          } else if (deref_chain->ptr_as_array) {
             /* You can't have a zero-length OpPtrAccessChain */
             vtn_assert(deref_chain->length >= 1);
-            desc_arr_idx = vtn_access_link_as_ssa(b, deref_chain->link[0], 1);
+            desc_arr_idx = vtn_access_link_as_ssa(b, deref_chain->link[0], 1, 32);
          } else {
             /* We have a regular non-array SSBO. */
             desc_arr_idx = NULL;
@@ -244,7 +236,7 @@ vtn_ssa_offset_pointer_dereference(struct vtn_builder *b,
           */
          vtn_assert(deref_chain->length >= 1);
          nir_ssa_def *offset_index =
-            vtn_access_link_as_ssa(b, deref_chain->link[0], 1);
+            vtn_access_link_as_ssa(b, deref_chain->link[0], 1, 32);
          idx++;
 
          block_index = vtn_resource_reindex(b, block_index, offset_index);
@@ -298,7 +290,7 @@ vtn_ssa_offset_pointer_dereference(struct vtn_builder *b,
 
       nir_ssa_def *elem_offset =
          vtn_access_link_as_ssa(b, deref_chain->link[idx],
-                                base->ptr_type->stride);
+                                base->ptr_type->stride, offset->bit_size);
       offset = nir_iadd(&b->nb, offset, elem_offset);
       idx++;
    }
@@ -319,7 +311,8 @@ vtn_ssa_offset_pointer_dereference(struct vtn_builder *b,
       case GLSL_TYPE_BOOL:
       case GLSL_TYPE_ARRAY: {
          nir_ssa_def *elem_offset =
-            vtn_access_link_as_ssa(b, deref_chain->link[idx], type->stride);
+            vtn_access_link_as_ssa(b, deref_chain->link[idx],
+                                   type->stride, offset->bit_size);
          offset = nir_iadd(&b->nb, offset, elem_offset);
          type = type->array_element;
          access |= type->access;
@@ -1911,7 +1904,22 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode,
          struct vtn_value *link_val = vtn_untyped_value(b, w[i]);
          if (link_val->value_type == vtn_value_type_constant) {
             chain->link[idx].mode = vtn_access_mode_literal;
-            chain->link[idx].id = link_val->constant->values[0].u32[0];
+            switch (glsl_get_bit_size(link_val->type->type)) {
+            case 8:
+               chain->link[idx].id = link_val->constant->values[0].i8[0];
+               break;
+            case 16:
+               chain->link[idx].id = link_val->constant->values[0].i16[0];
+               break;
+            case 32:
+               chain->link[idx].id = link_val->constant->values[0].i32[0];
+               break;
+            case 64:
+               chain->link[idx].id = link_val->constant->values[0].i64[0];
+               break;
+            default:
+               vtn_fail("Invalid bit size");
+            }
          } else {
             chain->link[idx].mode = vtn_access_mode_id;
             chain->link[idx].id = w[i];