nir/spirv: Rework UBOs and SSBOs
authorJason Ekstrand <jason.ekstrand@intel.com>
Fri, 8 Jan 2016 00:55:56 +0000 (16:55 -0800)
committerJason Ekstrand <jason.ekstrand@intel.com>
Fri, 8 Jan 2016 06:13:46 +0000 (22:13 -0800)
This completely reworks all block load/store operations.  In particular, it
should get row-major matrices working.

src/glsl/nir/spirv/spirv_to_nir.c

index 8acfc4b..9b3d0ce 100644 (file)
@@ -452,6 +452,9 @@ struct_member_decoration_cb(struct vtn_builder *b,
       break;
    case SpvDecorationColMajor:
       break; /* Nothing to do here.  Column-major is the default. */
+   case SpvDecorationRowMajor:
+      ctx->type->members[member]->row_major = true;
+      break;
    default:
       unreachable("Unhandled member decoration");
    }
@@ -565,18 +568,23 @@ vtn_handle_type(struct vtn_builder *b, SpvOp opcode,
       break;
 
    case SpvOpTypeVector: {
-      const struct glsl_type *base =
-         vtn_value(b, w[2], vtn_value_type_type)->type->type;
+      struct vtn_type *base = vtn_value(b, w[2], vtn_value_type_type)->type;
       unsigned elems = w[3];
 
-      assert(glsl_type_is_scalar(base));
-      val->type->type = glsl_vector_type(glsl_get_base_type(base), elems);
+      assert(glsl_type_is_scalar(base->type));
+      val->type->type = glsl_vector_type(glsl_get_base_type(base->type), elems);
+
+      /* Vectors implicitly have sizeof(base_type) stride.  For now, this
+       * is always 4 bytes.  This will have to change if we want to start
+       * supporting doubles or half-floats.
+       */
+      val->type->stride = 4;
+      val->type->array_element = base;
       break;
    }
 
    case SpvOpTypeMatrix: {
-      struct vtn_type *base =
-         vtn_value(b, w[2], vtn_value_type_type)->type;
+      struct vtn_type *base = vtn_value(b, w[2], vtn_value_type_type)->type;
       unsigned columns = w[3];
 
       assert(glsl_type_is_vector(base->type));
@@ -1241,153 +1249,251 @@ _vtn_variable_store(struct vtn_builder *b,
 }
 
 static nir_ssa_def *
-nir_vulkan_resource_index(nir_builder *b, unsigned set, unsigned binding,
-                          nir_variable_mode mode, nir_ssa_def *array_index)
+deref_array_offset(struct vtn_builder *b, nir_deref *deref)
 {
-   if (array_index == NULL)
-      array_index = nir_imm_int(b, 0);
+   assert(deref->deref_type == nir_deref_type_array);
+   nir_deref_array *deref_array = nir_deref_as_array(deref);
+   nir_ssa_def *offset = nir_imm_int(&b->nb, deref_array->base_offset);
+
+   if (deref_array->deref_array_type == nir_deref_array_type_indirect)
+      offset = nir_iadd(&b->nb, offset, deref_array->indirect.ssa);
+
+   return offset;
+}
+
+static nir_ssa_def *
+get_vulkan_resource_index(struct vtn_builder *b,
+                          nir_deref **deref, struct vtn_type **type)
+{
+   assert((*deref)->deref_type == nir_deref_type_var);
+   nir_variable *var = nir_deref_as_var(*deref)->var;
+
+   assert(var->interface_type && "variable is a block");
+   assert((*deref)->child);
+
+   nir_ssa_def *array_index;
+   if ((*deref)->child->deref_type == nir_deref_type_array) {
+      *deref = (*deref)->child;
+      *type = (*type)->array_element;
+      array_index = deref_array_offset(b, *deref);
+   } else {
+      array_index = nir_imm_int(&b->nb, 0);
+   }
 
    nir_intrinsic_instr *instr =
-      nir_intrinsic_instr_create(b->shader,
+      nir_intrinsic_instr_create(b->nb.shader,
                                  nir_intrinsic_vulkan_resource_index);
    instr->src[0] = nir_src_for_ssa(array_index);
-   instr->const_index[0] = set;
-   instr->const_index[1] = binding;
-   instr->const_index[2] = mode;
+   instr->const_index[0] = var->data.descriptor_set;
+   instr->const_index[1] = var->data.binding;
+   instr->const_index[2] = var->data.mode;
 
    nir_ssa_dest_init(&instr->instr, &instr->dest, 1, NULL);
-   nir_builder_instr_insert(b, &instr->instr);
+   nir_builder_instr_insert(&b->nb, &instr->instr);
 
    return &instr->dest.ssa;
 }
 
-static struct vtn_ssa_value *
-_vtn_block_load(struct vtn_builder *b, nir_intrinsic_op op,
-                unsigned set, unsigned binding, nir_variable_mode mode,
-                nir_ssa_def *index, nir_ssa_def *offset, struct vtn_type *type)
-{
-   struct vtn_ssa_value *val = ralloc(b, struct vtn_ssa_value);
-   val->type = type->type;
-   val->transposed = NULL;
-   if (glsl_type_is_vector_or_scalar(type->type)) {
-      nir_intrinsic_instr *load = nir_intrinsic_instr_create(b->shader, op);
-      load->num_components = glsl_get_vector_elements(type->type);
-
-      switch (op) {
-      case nir_intrinsic_load_ubo:
-      case nir_intrinsic_load_ssbo: {
-         nir_ssa_def *res_index = nir_vulkan_resource_index(&b->nb,
-                                                            set, binding,
-                                                            mode, index);
-         load->src[0] = nir_src_for_ssa(res_index);
-         load->src[1] = nir_src_for_ssa(offset);
-         break;
-      }
+static void
+_vtn_load_store_tail(struct vtn_builder *b, nir_intrinsic_op op, bool load,
+                     nir_ssa_def *index, nir_ssa_def *offset,
+                     struct vtn_ssa_value **inout, const struct glsl_type *type)
+{
+   nir_intrinsic_instr *instr = nir_intrinsic_instr_create(b->nb.shader, op);
+   instr->num_components = glsl_get_vector_elements(type);
 
-      case nir_intrinsic_load_push_constant:
-         load->src[0] = nir_src_for_ssa(offset);
-         break;
+   int src = 0;
+   if (!load) {
+      instr->const_index[0] = (1 << instr->num_components) - 1; /* write mask */
+      instr->src[src++] = nir_src_for_ssa((*inout)->def);
+   }
 
-      default:
-         unreachable("Invalid block load intrinsic");
-      }
+   if (index)
+      instr->src[src++] = nir_src_for_ssa(index);
 
-      nir_ssa_dest_init(&load->instr, &load->dest, load->num_components, NULL);
-      nir_builder_instr_insert(&b->nb, &load->instr);
+   instr->src[src++] = nir_src_for_ssa(offset);
 
-      if (glsl_get_base_type(type->type) == GLSL_TYPE_BOOL) {
-         /* Loads of booleans from externally visible memory need to be
-          * fixed up since they're defined to be zero/nonzero rather than
-          * NIR_FALSE/NIR_TRUE.
-          */
-         val->def = nir_ine(&b->nb, &load->dest.ssa, nir_imm_int(&b->nb, 0));
-      } else {
-         val->def = &load->dest.ssa;
-      }
-   } else {
-      unsigned elems = glsl_get_length(type->type);
-      val->elems = ralloc_array(b, struct vtn_ssa_value *, elems);
-      if (glsl_type_is_struct(type->type)) {
-         for (unsigned i = 0; i < elems; i++) {
-            nir_ssa_def *child_offset =
-               nir_iadd(&b->nb, offset, nir_imm_int(&b->nb, type->offsets[i]));
-            val->elems[i] = _vtn_block_load(b, op, set, binding, mode, index,
-                                            child_offset, type->members[i]);
-         }
-      } else {
-         for (unsigned i = 0; i < elems; i++) {
-            nir_ssa_def *child_offset =
-               nir_iadd(&b->nb, offset, nir_imm_int(&b->nb, i * type->stride));
-            val->elems[i] = _vtn_block_load(b, op, set, binding, mode, index,
-                                            child_offset,type->array_element);
-         }
-      }
+   if (load) {
+      nir_ssa_dest_init(&instr->instr, &instr->dest,
+                        instr->num_components, NULL);
+      *inout = rzalloc(b, struct vtn_ssa_value);
+      (*inout)->def = &instr->dest.ssa;
+      (*inout)->type = type;
    }
 
-   return val;
+   nir_builder_instr_insert(&b->nb, &instr->instr);
+
+   if (load && glsl_get_base_type(type) == GLSL_TYPE_BOOL)
+      (*inout)->def = nir_ine(&b->nb, (*inout)->def, nir_imm_int(&b->nb, 0));
 }
 
+static struct vtn_ssa_value *
+vtn_transpose(struct vtn_builder *b, struct vtn_ssa_value *src);
+
 static void
-vtn_block_get_offset(struct vtn_builder *b, nir_deref_var *src,
-                     struct vtn_type **type, nir_deref *src_tail,
-                     nir_ssa_def **index, nir_ssa_def **offset)
+_vtn_block_load_store(struct vtn_builder *b, nir_intrinsic_op op, bool load,
+                      nir_ssa_def *index, nir_ssa_def *offset, nir_deref *deref,
+                      struct vtn_type *type, struct vtn_ssa_value **inout)
 {
-   nir_deref *deref = &src->deref;
-
-   if (deref->child->deref_type == nir_deref_type_array) {
-      deref = deref->child;
-      *type = (*type)->array_element;
-      nir_deref_array *deref_array = nir_deref_as_array(deref);
-      *index = nir_imm_int(&b->nb, deref_array->base_offset);
-
-      if (deref_array->deref_array_type == nir_deref_array_type_indirect)
-         *index = nir_iadd(&b->nb, *index, deref_array->indirect.ssa);
-   } else {
-      *index = nir_imm_int(&b->nb, 0);
+   if (deref == NULL && load) {
+      assert(*inout == NULL);
+      *inout = rzalloc(b, struct vtn_ssa_value);
+      (*inout)->type = type->type;
    }
 
-   *offset = nir_imm_int(&b->nb, 0);
-   while (deref != src_tail) {
-      deref = deref->child;
-      switch (deref->deref_type) {
-      case nir_deref_type_array: {
-         nir_deref_array *deref_array = nir_deref_as_array(deref);
-         nir_ssa_def *off = nir_imm_int(&b->nb, deref_array->base_offset);
+   enum glsl_base_type base_type = glsl_get_base_type(type->type);
+   switch (base_type) {
+   case GLSL_TYPE_UINT:
+   case GLSL_TYPE_INT:
+   case GLSL_TYPE_FLOAT:
+   case GLSL_TYPE_BOOL:
+      /* This is where things get interesting.  At this point, we've hit
+       * a vector, a scalar, or a matrix.
+       */
+      if (glsl_type_is_matrix(type->type)) {
+         if (deref == NULL) {
+            /* Loading the whole matrix */
+            if (load)
+               (*inout)->elems = ralloc_array(b, struct vtn_ssa_value *, 4);
+
+            struct vtn_ssa_value *transpose;
+            unsigned num_ops, vec_width;
+            if (type->row_major) {
+               num_ops = glsl_get_vector_elements(type->type);
+               vec_width = glsl_get_matrix_columns(type->type);
+               if (load) {
+                  (*inout)->type =
+                     glsl_matrix_type(base_type, vec_width, num_ops);
+               } else {
+                  transpose = vtn_transpose(b, *inout);
+                  inout = &transpose;
+               }
+            } else {
+               num_ops = glsl_get_matrix_columns(type->type);
+               vec_width = glsl_get_vector_elements(type->type);
+            }
 
-         if (deref_array->deref_array_type == nir_deref_array_type_indirect)
-            off = nir_iadd(&b->nb, off, deref_array->indirect.ssa);
+            for (unsigned i = 0; i < num_ops; i++) {
+               nir_ssa_def *elem_offset =
+                  nir_iadd(&b->nb, offset,
+                           nir_imm_int(&b->nb, i * type->stride));
+               _vtn_load_store_tail(b, op, load, index, elem_offset,
+                                    &(*inout)->elems[i],
+                                    glsl_vector_type(base_type, vec_width));
+            }
 
-         off = nir_imul(&b->nb, off, nir_imm_int(&b->nb, (*type)->stride));
-         *offset = nir_iadd(&b->nb, *offset, off);
+            if (load && type->row_major)
+               *inout = vtn_transpose(b, *inout);
+
+            return;
+         } else if (type->row_major) {
+            /* Row-major but with a deref. */
+            nir_ssa_def *col_offset =
+               nir_imul(&b->nb, deref_array_offset(b, deref),
+                        nir_imm_int(&b->nb, type->array_element->stride));
+            offset = nir_iadd(&b->nb, offset, col_offset);
+
+            if (deref->child) {
+               /* Picking off a single element */
+               nir_ssa_def *row_offset =
+                  nir_imul(&b->nb, deref_array_offset(b, deref->child),
+                           nir_imm_int(&b->nb, type->stride));
+               offset = nir_iadd(&b->nb, offset, row_offset);
+               _vtn_load_store_tail(b, op, load, index, offset, inout,
+                                    glsl_scalar_type(base_type));
+               return;
+            } else {
+               unsigned num_comps = glsl_get_vector_elements(type->type);
+               nir_ssa_def *comps[4];
+               for (unsigned i = 0; i < num_comps; i++) {
+                  nir_ssa_def *elem_offset =
+                     nir_iadd(&b->nb, offset,
+                              nir_imm_int(&b->nb, i * type->stride));
+
+                  struct vtn_ssa_value *comp = NULL, temp_val;
+                  if (!load) {
+                     temp_val.def = nir_channel(&b->nb, (*inout)->def, i);
+                     temp_val.type = glsl_scalar_type(base_type);
+                     comp = &temp_val;
+                  }
+                  _vtn_load_store_tail(b, op, load, index, elem_offset,
+                                       &comp, glsl_scalar_type(base_type));
+                  comps[i] = comp->def;
+               }
 
-         *type = (*type)->array_element;
-         break;
+               if (load)
+                  (*inout)->def = nir_vec(&b->nb, comps, num_comps);
+               return;
+            }
+         } else {
+            /* Column-major with a deref. Fall through to array case. */
+         }
+      } else if (deref == NULL) {
+         assert(glsl_type_is_vector_or_scalar(type->type));
+         _vtn_load_store_tail(b, op, load, index, offset, inout, type->type);
+         return;
+      } else {
+         /* Single component of a vector. Fall through to array case. */
       }
+      /* Fall through */
 
-      case nir_deref_type_struct: {
-         nir_deref_struct *deref_struct = nir_deref_as_struct(deref);
+   case GLSL_TYPE_ARRAY:
+      if (deref) {
+         offset = nir_iadd(&b->nb, offset,
+                           nir_imul(&b->nb, deref_array_offset(b, deref),
+                                    nir_imm_int(&b->nb, type->stride)));
 
-         unsigned elem_off = (*type)->offsets[deref_struct->index];
-         *offset = nir_iadd(&b->nb, *offset, nir_imm_int(&b->nb, elem_off));
+         _vtn_block_load_store(b, op, load, index, offset, deref->child,
+                               type->array_element, inout);
+         return;
+      } else {
+         unsigned elems = glsl_get_length(type->type);
+         if (load)
+            (*inout)->elems = ralloc_array(b, struct vtn_ssa_value *, elems);
 
-         *type = (*type)->members[deref_struct->index];
-         break;
+         for (unsigned i = 0; i < elems; i++) {
+            nir_ssa_def *elem_off =
+               nir_iadd(&b->nb, offset, nir_imm_int(&b->nb, i * type->stride));
+            _vtn_block_load_store(b, op, load, index, elem_off, NULL,
+                                  type->array_element, &(*inout)->elems[i]);
+         }
+         return;
       }
+      unreachable("Both branches above return");
 
-      default:
-         unreachable("unknown deref type");
+   case GLSL_TYPE_STRUCT:
+      if (deref) {
+         unsigned member = nir_deref_as_struct(deref)->index;
+         offset = nir_iadd(&b->nb, offset,
+                           nir_imm_int(&b->nb, type->offsets[member]));
+
+         _vtn_block_load_store(b, op, load, index, offset, deref->child,
+                               type->members[member], inout);
+         return;
+      } else {
+         unsigned elems = glsl_get_length(type->type);
+         if (load)
+            (*inout)->elems = ralloc_array(b, struct vtn_ssa_value *, elems);
+
+         for (unsigned i = 0; i < elems; i++) {
+            nir_ssa_def *elem_off =
+               nir_iadd(&b->nb, offset, nir_imm_int(&b->nb, type->offsets[i]));
+            _vtn_block_load_store(b, op, load, index, elem_off, NULL,
+                                  type->members[i], &(*inout)->elems[i]);
+         }
+         return;
       }
+      unreachable("Both branches above return");
+
+   default:
+      unreachable("Invalid block member type");
    }
 }
 
 static struct vtn_ssa_value *
 vtn_block_load(struct vtn_builder *b, nir_deref_var *src,
-               struct vtn_type *type, nir_deref *src_tail)
+               struct vtn_type *type)
 {
-   nir_ssa_def *index;
-   nir_ssa_def *offset;
-   vtn_block_get_offset(b, src, &type, src_tail, &index, &offset);
-
    nir_intrinsic_op op;
    if (src->var->data.mode == nir_var_uniform) {
       if (src->var->data.descriptor_set >= 0) {
@@ -1407,9 +1513,15 @@ vtn_block_load(struct vtn_builder *b, nir_deref_var *src,
       op = nir_intrinsic_load_ssbo;
    }
 
-   return _vtn_block_load(b, op, src->var->data.descriptor_set,
-                          src->var->data.binding, src->var->data.mode,
-                          index, offset, type);
+   nir_deref *block_deref = &src->deref;
+   nir_ssa_def *index = NULL;
+   if (op == nir_intrinsic_load_ubo || op == nir_intrinsic_load_ssbo)
+      index = get_vulkan_resource_index(b, &block_deref, &type);
+
+   struct vtn_ssa_value *value = NULL;
+   _vtn_block_load_store(b, op, true, index, nir_imm_int(&b->nb, 0),
+                         block_deref->child, type, &value);
+   return value;
 }
 
 /*
@@ -1448,13 +1560,11 @@ static struct vtn_ssa_value *
 vtn_variable_load(struct vtn_builder *b, nir_deref_var *src,
                   struct vtn_type *src_type)
 {
-   nir_deref *src_tail = get_deref_tail(src);
-
-   struct vtn_ssa_value *val;
    if (variable_is_external_block(src->var))
-      val = vtn_block_load(b, src, src_type, src_tail);
-   else
-      val = _vtn_variable_load(b, src, src_tail);
+      return vtn_block_load(b, src, src_type);
+
+   nir_deref *src_tail = get_deref_tail(src);
+   struct vtn_ssa_value *val = _vtn_variable_load(b, src, src_tail);
 
    if (src_tail->child) {
       nir_deref_array *vec_deref = nir_deref_as_array(src_tail->child);
@@ -1471,59 +1581,15 @@ vtn_variable_load(struct vtn_builder *b, nir_deref_var *src,
 }
 
 static void
-_vtn_block_store(struct vtn_builder *b, nir_intrinsic_op op,
-                 struct vtn_ssa_value *src, unsigned set, unsigned binding,
-                 nir_variable_mode mode, nir_ssa_def *index,
-                 nir_ssa_def *offset, struct vtn_type *type)
-{
-   assert(src->type == type->type);
-   if (glsl_type_is_vector_or_scalar(type->type)) {
-      nir_intrinsic_instr *store = nir_intrinsic_instr_create(b->shader, op);
-      store->num_components = glsl_get_vector_elements(type->type);
-      store->const_index[0] = (1 << store->num_components) - 1;
-      store->src[0] = nir_src_for_ssa(src->def);
-
-      nir_ssa_def *res_index = nir_vulkan_resource_index(&b->nb,
-                                                         set, binding,
-                                                         mode, index);
-      store->src[1] = nir_src_for_ssa(res_index);
-      store->src[2] = nir_src_for_ssa(offset);
-
-      nir_builder_instr_insert(&b->nb, &store->instr);
-   } else {
-      unsigned elems = glsl_get_length(type->type);
-      if (glsl_type_is_struct(type->type)) {
-         for (unsigned i = 0; i < elems; i++) {
-            nir_ssa_def *child_offset =
-               nir_iadd(&b->nb, offset, nir_imm_int(&b->nb, type->offsets[i]));
-            _vtn_block_store(b, op, src->elems[i], set, binding, mode,
-                             index, child_offset, type->members[i]);
-         }
-      } else {
-         for (unsigned i = 0; i < elems; i++) {
-            nir_ssa_def *child_offset =
-               nir_iadd(&b->nb, offset, nir_imm_int(&b->nb, i * type->stride));
-            _vtn_block_store(b, op, src->elems[i], set, binding, mode,
-                             index, child_offset, type->array_element);
-         }
-      }
-   }
-}
-
-static void
 vtn_block_store(struct vtn_builder *b, struct vtn_ssa_value *src,
-                nir_deref_var *dest, struct vtn_type *type,
-                nir_deref *dest_tail)
+                nir_deref_var *dest, struct vtn_type *type)
 {
-   nir_ssa_def *index;
-   nir_ssa_def *offset;
-   vtn_block_get_offset(b, dest, &type, dest_tail, &index, &offset);
-
-   nir_intrinsic_op op = nir_intrinsic_store_ssbo;
+   nir_deref *block_deref = &dest->deref;
+   nir_ssa_def *index = get_vulkan_resource_index(b, &block_deref, &type);
 
-   return _vtn_block_store(b, op, src, dest->var->data.descriptor_set,
-                           dest->var->data.binding, dest->var->data.mode,
-                           index, offset, type);
+   _vtn_block_load_store(b, nir_intrinsic_store_ssbo, false, index,
+                         nir_imm_int(&b->nb, 0), block_deref->child,
+                         type, &src);
 }
 
 static nir_ssa_def * vtn_vector_insert(struct vtn_builder *b,
@@ -1538,11 +1604,11 @@ void
 vtn_variable_store(struct vtn_builder *b, struct vtn_ssa_value *src,
                    nir_deref_var *dest, struct vtn_type *dest_type)
 {
-   nir_deref *dest_tail = get_deref_tail(dest);
    if (variable_is_external_block(dest->var)) {
       assert(dest->var->data.mode == nir_var_shader_storage);
-      vtn_block_store(b, src, dest, dest_type, dest_tail);
+      vtn_block_store(b, src, dest, dest_type);
    } else {
+      nir_deref *dest_tail = get_deref_tail(dest);
       if (dest_tail->child) {
          struct vtn_ssa_value *val = _vtn_variable_load(b, dest, dest_tail);
          nir_deref_array *deref = nir_deref_as_array(dest_tail->child);