nir/spirv: Split ALU operations out into their own file
author     Jason Ekstrand <jason.ekstrand@intel.com>
Fri, 8 Jan 2016 19:02:17 +0000 (11:02 -0800)
committer  Jason Ekstrand <jason.ekstrand@intel.com>
Fri, 8 Jan 2016 19:26:43 +0000 (11:26 -0800)
src/glsl/Makefile.sources
src/glsl/nir/spirv/spirv_to_nir.c
src/glsl/nir/spirv/vtn_alu.c [new file with mode: 0644]
src/glsl/nir/spirv/vtn_private.h

diff --git a/src/glsl/Makefile.sources b/src/glsl/Makefile.sources
index 97fac86..89113bc 100644
@@ -95,8 +95,10 @@ NIR_FILES = \
 SPIRV_FILES = \
        nir/spirv/nir_spirv.h \
        nir/spirv/spirv_to_nir.c \
+       nir/spirv/vtn_alu.c \
        nir/spirv/vtn_cfg.c \
-       nir/spirv/vtn_glsl450.c
+       nir/spirv/vtn_glsl450.c \
+       nir/spirv/vtn_private.h
 
 # libglsl
 
diff --git a/src/glsl/nir/spirv/spirv_to_nir.c b/src/glsl/nir/spirv/spirv_to_nir.c
index 919be09..191d35d 100644
@@ -1327,9 +1327,6 @@ _vtn_load_store_tail(struct vtn_builder *b, nir_intrinsic_op op, bool load,
       (*inout)->def = nir_ine(&b->nb, (*inout)->def, nir_imm_int(&b->nb, 0));
 }
 
-static struct vtn_ssa_value *
-vtn_transpose(struct vtn_builder *b, struct vtn_ssa_value *src);
-
 static void
 _vtn_block_load_store(struct vtn_builder *b, nir_intrinsic_op op, bool load,
                       nir_ssa_def *index, nir_ssa_def *offset, nir_deref *deref,
@@ -1365,7 +1362,7 @@ _vtn_block_load_store(struct vtn_builder *b, nir_intrinsic_op op, bool load,
                   (*inout)->type =
                      glsl_matrix_type(base_type, vec_width, num_ops);
                } else {
-                  transpose = vtn_transpose(b, *inout);
+                  transpose = vtn_ssa_transpose(b, *inout);
                   inout = &transpose;
                }
             } else {
@@ -1383,7 +1380,7 @@ _vtn_block_load_store(struct vtn_builder *b, nir_intrinsic_op op, bool load,
             }
 
             if (load && type->row_major)
-               *inout = vtn_transpose(b, *inout);
+               *inout = vtn_ssa_transpose(b, *inout);
 
             return;
          } else if (type->row_major) {
@@ -2074,7 +2071,7 @@ vtn_handle_function_call(struct vtn_builder *b, SpvOp opcode,
    }
 }
 
-static struct vtn_ssa_value *
+struct vtn_ssa_value *
 vtn_create_ssa_value(struct vtn_builder *b, const struct glsl_type *type)
 {
    struct vtn_ssa_value *val = rzalloc(b, struct vtn_ssa_value);
@@ -2598,8 +2595,8 @@ create_vec(nir_shader *shader, unsigned num_components)
    return vec;
 }
 
-static struct vtn_ssa_value *
-vtn_transpose(struct vtn_builder *b, struct vtn_ssa_value *src)
+struct vtn_ssa_value *
+vtn_ssa_transpose(struct vtn_builder *b, struct vtn_ssa_value *src)
 {
    if (src->transposed)
       return src->transposed;
@@ -2628,411 +2625,6 @@ vtn_transpose(struct vtn_builder *b, struct vtn_ssa_value *src)
    return dest;
 }
 
-/*
- * Normally, column vectors in SPIR-V correspond to a single NIR SSA
- * definition. But for matrix multiplies, we want to do one routine for
- * multiplying a matrix by a matrix and then pretend that vectors are matrices
- * with one column. So we "wrap" these things, and unwrap the result before we
- * send it off.
- */
-
-static struct vtn_ssa_value *
-vtn_wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
-{
-   if (val == NULL)
-      return NULL;
-
-   if (glsl_type_is_matrix(val->type))
-      return val;
-
-   struct vtn_ssa_value *dest = rzalloc(b, struct vtn_ssa_value);
-   dest->type = val->type;
-   dest->elems = ralloc_array(b, struct vtn_ssa_value *, 1);
-   dest->elems[0] = val;
-
-   return dest;
-}
-
-static struct vtn_ssa_value *
-vtn_unwrap_matrix(struct vtn_ssa_value *val)
-{
-   if (glsl_type_is_matrix(val->type))
-         return val;
-
-   return val->elems[0];
-}
-
-static struct vtn_ssa_value *
-vtn_matrix_multiply(struct vtn_builder *b,
-                    struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
-{
-
-   struct vtn_ssa_value *src0 = vtn_wrap_matrix(b, _src0);
-   struct vtn_ssa_value *src1 = vtn_wrap_matrix(b, _src1);
-   struct vtn_ssa_value *src0_transpose = vtn_wrap_matrix(b, _src0->transposed);
-   struct vtn_ssa_value *src1_transpose = vtn_wrap_matrix(b, _src1->transposed);
-
-   unsigned src0_rows = glsl_get_vector_elements(src0->type);
-   unsigned src0_columns = glsl_get_matrix_columns(src0->type);
-   unsigned src1_columns = glsl_get_matrix_columns(src1->type);
-
-   const struct glsl_type *dest_type;
-   if (src1_columns > 1) {
-      dest_type = glsl_matrix_type(glsl_get_base_type(src0->type),
-                                   src0_rows, src1_columns);
-   } else {
-      dest_type = glsl_vector_type(glsl_get_base_type(src0->type), src0_rows);
-   }
-   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
-
-   dest = vtn_wrap_matrix(b, dest);
-
-   bool transpose_result = false;
-   if (src0_transpose && src1_transpose) {
-      /* transpose(A) * transpose(B) = transpose(B * A) */
-      src1 = src0_transpose;
-      src0 = src1_transpose;
-      src0_transpose = NULL;
-      src1_transpose = NULL;
-      transpose_result = true;
-   }
-
-   if (src0_transpose && !src1_transpose &&
-       glsl_get_base_type(src0->type) == GLSL_TYPE_FLOAT) {
-      /* We already have the rows of src0 and the columns of src1 available,
-       * so we can just take the dot product of each row with each column to
-       * get the result.
-       */
-
-      for (unsigned i = 0; i < src1_columns; i++) {
-         nir_alu_instr *vec = create_vec(b->shader, src0_rows);
-         for (unsigned j = 0; j < src0_rows; j++) {
-            vec->src[j].src =
-               nir_src_for_ssa(nir_fdot(&b->nb, src0_transpose->elems[j]->def,
-                                        src1->elems[i]->def));
-         }
-
-         nir_builder_instr_insert(&b->nb, &vec->instr);
-         dest->elems[i]->def = &vec->dest.dest.ssa;
-      }
-   } else {
-      /* We don't handle the case where src1 is transposed but not src0, since
-       * the general case only uses individual components of src1 so the
-       * optimizer should chew through the transpose we emitted for src1.
-       */
-
-      for (unsigned i = 0; i < src1_columns; i++) {
-         /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
-         dest->elems[i]->def =
-            nir_fmul(&b->nb, src0->elems[0]->def,
-                     vtn_vector_extract(b, src1->elems[i]->def, 0));
-         for (unsigned j = 1; j < src0_columns; j++) {
-            dest->elems[i]->def =
-               nir_fadd(&b->nb, dest->elems[i]->def,
-                        nir_fmul(&b->nb, src0->elems[j]->def,
-                                 vtn_vector_extract(b,
-                                                    src1->elems[i]->def, j)));
-         }
-      }
-   }
-
-   dest = vtn_unwrap_matrix(dest);
-
-   if (transpose_result)
-      dest = vtn_transpose(b, dest);
-
-   return dest;
-}
-
-static struct vtn_ssa_value *
-vtn_mat_times_scalar(struct vtn_builder *b,
-                     struct vtn_ssa_value *mat,
-                     nir_ssa_def *scalar)
-{
-   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
-   for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
-      if (glsl_get_base_type(mat->type) == GLSL_TYPE_FLOAT)
-         dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
-      else
-         dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
-   }
-
-   return dest;
-}
-
-static void
-vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
-                      const uint32_t *w, unsigned count)
-{
-   struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
-
-   switch (opcode) {
-   case SpvOpTranspose: {
-      struct vtn_ssa_value *src = vtn_ssa_value(b, w[3]);
-      val->ssa = vtn_transpose(b, src);
-      break;
-   }
-
-   case SpvOpOuterProduct: {
-      struct vtn_ssa_value *src0 = vtn_ssa_value(b, w[3]);
-      struct vtn_ssa_value *src1 = vtn_ssa_value(b, w[4]);
-
-      val->ssa = vtn_matrix_multiply(b, src0, vtn_transpose(b, src1));
-      break;
-   }
-
-   case SpvOpMatrixTimesScalar: {
-      struct vtn_ssa_value *mat = vtn_ssa_value(b, w[3]);
-      struct vtn_ssa_value *scalar = vtn_ssa_value(b, w[4]);
-
-      if (mat->transposed) {
-         val->ssa = vtn_transpose(b, vtn_mat_times_scalar(b, mat->transposed,
-                                                          scalar->def));
-      } else {
-         val->ssa = vtn_mat_times_scalar(b, mat, scalar->def);
-      }
-      break;
-   }
-
-   case SpvOpVectorTimesMatrix:
-   case SpvOpMatrixTimesVector:
-   case SpvOpMatrixTimesMatrix: {
-      struct vtn_ssa_value *src0 = vtn_ssa_value(b, w[3]);
-      struct vtn_ssa_value *src1 = vtn_ssa_value(b, w[4]);
-
-      if (opcode == SpvOpVectorTimesMatrix) {
-         val->ssa = vtn_matrix_multiply(b, vtn_transpose(b, src1), src0);
-      } else {
-         val->ssa = vtn_matrix_multiply(b, src0, src1);
-      }
-      break;
-   }
-
-   default: unreachable("unknown matrix opcode");
-   }
-}
-
-static void
-vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
-               const uint32_t *w, unsigned count)
-{
-   struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
-   const struct glsl_type *type =
-      vtn_value(b, w[1], vtn_value_type_type)->type->type;
-   val->ssa = vtn_create_ssa_value(b, type);
-
-   /* Collect the various SSA sources */
-   const unsigned num_inputs = count - 3;
-   nir_ssa_def *src[4];
-   for (unsigned i = 0; i < num_inputs; i++)
-      src[i] = vtn_ssa_value(b, w[i + 3])->def;
-   for (unsigned i = num_inputs; i < 4; i++)
-      src[i] = NULL;
-
-   /* Indicates that the first two arguments should be swapped.  This is
-    * used for implementing greater-than and less-than-or-equal.
-    */
-   bool swap = false;
-
-   nir_op op;
-   switch (opcode) {
-   /* Basic ALU operations */
-   case SpvOpSNegate:               op = nir_op_ineg;    break;
-   case SpvOpFNegate:               op = nir_op_fneg;    break;
-   case SpvOpNot:                   op = nir_op_inot;    break;
-
-   case SpvOpAny:
-      if (src[0]->num_components == 1) {
-         op = nir_op_imov;
-      } else {
-         switch (src[0]->num_components) {
-         case 2:  op = nir_op_bany_inequal2; break;
-         case 3:  op = nir_op_bany_inequal3; break;
-         case 4:  op = nir_op_bany_inequal4; break;
-         }
-         src[1] = nir_imm_int(&b->nb, NIR_FALSE);
-      }
-      break;
-
-   case SpvOpAll:
-      if (src[0]->num_components == 1) {
-         op = nir_op_imov;
-      } else {
-         switch (src[0]->num_components) {
-         case 2:  op = nir_op_ball_iequal2;  break;
-         case 3:  op = nir_op_ball_iequal3;  break;
-         case 4:  op = nir_op_ball_iequal4;  break;
-         }
-         src[1] = nir_imm_int(&b->nb, NIR_TRUE);
-      }
-      break;
-
-   case SpvOpIAdd:                  op = nir_op_iadd;    break;
-   case SpvOpFAdd:                  op = nir_op_fadd;    break;
-   case SpvOpISub:                  op = nir_op_isub;    break;
-   case SpvOpFSub:                  op = nir_op_fsub;    break;
-   case SpvOpIMul:                  op = nir_op_imul;    break;
-   case SpvOpFMul:                  op = nir_op_fmul;    break;
-   case SpvOpUDiv:                  op = nir_op_udiv;    break;
-   case SpvOpSDiv:                  op = nir_op_idiv;    break;
-   case SpvOpFDiv:                  op = nir_op_fdiv;    break;
-   case SpvOpUMod:                  op = nir_op_umod;    break;
-   case SpvOpSMod:                  op = nir_op_umod;    break; /* FIXME? */
-   case SpvOpFMod:                  op = nir_op_fmod;    break;
-
-   case SpvOpDot:
-      assert(src[0]->num_components == src[1]->num_components);
-      switch (src[0]->num_components) {
-      case 1:  op = nir_op_fmul;    break;
-      case 2:  op = nir_op_fdot2;   break;
-      case 3:  op = nir_op_fdot3;   break;
-      case 4:  op = nir_op_fdot4;   break;
-      }
-      break;
-
-   case SpvOpIAddCarry:
-      assert(glsl_type_is_struct(val->ssa->type));
-      val->ssa->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
-      val->ssa->elems[1]->def =
-         nir_b2i(&b->nb, nir_uadd_carry(&b->nb, src[0], src[1]));
-      return;
-
-   case SpvOpISubBorrow:
-      assert(glsl_type_is_struct(val->ssa->type));
-      val->ssa->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
-      val->ssa->elems[1]->def =
-         nir_b2i(&b->nb, nir_usub_borrow(&b->nb, src[0], src[1]));
-      return;
-
-   case SpvOpUMulExtended:
-      assert(glsl_type_is_struct(val->ssa->type));
-      val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
-      val->ssa->elems[1]->def = nir_umul_high(&b->nb, src[0], src[1]);
-      return;
-
-   case SpvOpSMulExtended:
-      assert(glsl_type_is_struct(val->ssa->type));
-      val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
-      val->ssa->elems[1]->def = nir_imul_high(&b->nb, src[0], src[1]);
-      return;
-
-   case SpvOpShiftRightLogical:     op = nir_op_ushr;    break;
-   case SpvOpShiftRightArithmetic:  op = nir_op_ishr;    break;
-   case SpvOpShiftLeftLogical:      op = nir_op_ishl;    break;
-   case SpvOpLogicalOr:             op = nir_op_ior;     break;
-   case SpvOpLogicalEqual:          op = nir_op_ieq;     break;
-   case SpvOpLogicalNotEqual:       op = nir_op_ine;     break;
-   case SpvOpLogicalAnd:            op = nir_op_iand;    break;
-   case SpvOpLogicalNot:            op = nir_op_inot;    break;
-   case SpvOpBitwiseOr:             op = nir_op_ior;     break;
-   case SpvOpBitwiseXor:            op = nir_op_ixor;    break;
-   case SpvOpBitwiseAnd:            op = nir_op_iand;    break;
-   case SpvOpSelect:                op = nir_op_bcsel;   break;
-   case SpvOpIEqual:                op = nir_op_ieq;     break;
-
-   case SpvOpBitFieldInsert:        op = nir_op_bitfield_insert;     break;
-   case SpvOpBitFieldSExtract:      op = nir_op_ibitfield_extract;   break;
-   case SpvOpBitFieldUExtract:      op = nir_op_ubitfield_extract;   break;
-   case SpvOpBitReverse:            op = nir_op_bitfield_reverse;    break;
-   case SpvOpBitCount:              op = nir_op_bit_count;           break;
-
-   /* Comparisons: (TODO: How do we want to handled ordered/unordered?) */
-   case SpvOpFOrdEqual:             op = nir_op_feq;     break;
-   case SpvOpFUnordEqual:           op = nir_op_feq;     break;
-   case SpvOpINotEqual:             op = nir_op_ine;     break;
-   case SpvOpFOrdNotEqual:          op = nir_op_fne;     break;
-   case SpvOpFUnordNotEqual:        op = nir_op_fne;     break;
-   case SpvOpULessThan:             op = nir_op_ult;     break;
-   case SpvOpSLessThan:             op = nir_op_ilt;     break;
-   case SpvOpFOrdLessThan:          op = nir_op_flt;     break;
-   case SpvOpFUnordLessThan:        op = nir_op_flt;     break;
-   case SpvOpUGreaterThan:          op = nir_op_ult;  swap = true;   break;
-   case SpvOpSGreaterThan:          op = nir_op_ilt;  swap = true;   break;
-   case SpvOpFOrdGreaterThan:       op = nir_op_flt;  swap = true;   break;
-   case SpvOpFUnordGreaterThan:     op = nir_op_flt;  swap = true;   break;
-   case SpvOpULessThanEqual:        op = nir_op_uge;  swap = true;   break;
-   case SpvOpSLessThanEqual:        op = nir_op_ige;  swap = true;   break;
-   case SpvOpFOrdLessThanEqual:     op = nir_op_fge;  swap = true;   break;
-   case SpvOpFUnordLessThanEqual:   op = nir_op_fge;  swap = true;   break;
-   case SpvOpUGreaterThanEqual:     op = nir_op_uge;     break;
-   case SpvOpSGreaterThanEqual:     op = nir_op_ige;     break;
-   case SpvOpFOrdGreaterThanEqual:  op = nir_op_fge;     break;
-   case SpvOpFUnordGreaterThanEqual:op = nir_op_fge;     break;
-
-   /* Conversions: */
-   case SpvOpConvertFToU:           op = nir_op_f2u;     break;
-   case SpvOpConvertFToS:           op = nir_op_f2i;     break;
-   case SpvOpConvertSToF:           op = nir_op_i2f;     break;
-   case SpvOpConvertUToF:           op = nir_op_u2f;     break;
-   case SpvOpBitcast:               op = nir_op_imov;    break;
-   case SpvOpUConvert:
-   case SpvOpSConvert:
-      op = nir_op_imov; /* TODO: NIR is 32-bit only; these are no-ops. */
-      break;
-   case SpvOpFConvert:
-      op = nir_op_fmov;
-      break;
-
-   /* Derivatives: */
-   case SpvOpDPdx:         op = nir_op_fddx;          break;
-   case SpvOpDPdy:         op = nir_op_fddy;          break;
-   case SpvOpDPdxFine:     op = nir_op_fddx_fine;     break;
-   case SpvOpDPdyFine:     op = nir_op_fddy_fine;     break;
-   case SpvOpDPdxCoarse:   op = nir_op_fddx_coarse;   break;
-   case SpvOpDPdyCoarse:   op = nir_op_fddy_coarse;   break;
-   case SpvOpFwidth:
-      val->ssa->def = nir_fadd(&b->nb,
-                               nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
-                               nir_fabs(&b->nb, nir_fddx(&b->nb, src[1])));
-      return;
-   case SpvOpFwidthFine:
-      val->ssa->def = nir_fadd(&b->nb,
-                               nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
-                               nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[1])));
-      return;
-   case SpvOpFwidthCoarse:
-      val->ssa->def = nir_fadd(&b->nb,
-                               nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
-                               nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[1])));
-      return;
-
-   case SpvOpVectorTimesScalar:
-      /* The builder will take care of splatting for us. */
-      val->ssa->def = nir_fmul(&b->nb, src[0], src[1]);
-      return;
-
-   case SpvOpSRem:
-   case SpvOpFRem:
-      unreachable("No NIR equivalent");
-
-   case SpvOpIsNan:
-      val->ssa->def = nir_fne(&b->nb, src[0], src[0]);
-      return;
-
-   case SpvOpIsInf:
-      val->ssa->def = nir_feq(&b->nb, nir_fabs(&b->nb, src[0]),
-                                      nir_imm_float(&b->nb, INFINITY));
-      return;
-
-   case SpvOpIsFinite:
-   case SpvOpIsNormal:
-   case SpvOpSignBitSet:
-   case SpvOpLessOrGreater:
-   case SpvOpOrdered:
-   case SpvOpUnordered:
-   default:
-      unreachable("Unhandled opcode");
-   }
-
-   if (swap) {
-      nir_ssa_def *tmp = src[0];
-      src[0] = src[1];
-      src[1] = tmp;
-   }
-
-   val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
-}
-
 static nir_ssa_def *
 vtn_vector_extract(struct vtn_builder *b, nir_ssa_def *src, unsigned index)
 {
@@ -3835,16 +3427,13 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
    case SpvOpBitFieldUExtract:
    case SpvOpBitReverse:
    case SpvOpBitCount:
-      vtn_handle_alu(b, opcode, w, count);
-      break;
-
    case SpvOpTranspose:
    case SpvOpOuterProduct:
    case SpvOpMatrixTimesScalar:
    case SpvOpVectorTimesMatrix:
    case SpvOpMatrixTimesVector:
    case SpvOpMatrixTimesMatrix:
-      vtn_handle_matrix_alu(b, opcode, w, count);
+      vtn_handle_alu(b, opcode, w, count);
       break;
 
    case SpvOpVectorExtractDynamic:
diff --git a/src/glsl/nir/spirv/vtn_alu.c b/src/glsl/nir/spirv/vtn_alu.c
new file mode 100644
index 0000000..a8c6e5c
--- /dev/null
@@ -0,0 +1,420 @@
+/*
+ * Copyright © 2016 Intel Corporation
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the "Software"),
+ * to deal in the Software without restriction, including without limitation
+ * the rights to use, copy, modify, merge, publish, distribute, sublicense,
+ * and/or sell copies of the Software, and to permit persons to whom the
+ * Software is furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice (including the next
+ * paragraph) shall be included in all copies or substantial portions of the
+ * Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
+ * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+ * IN THE SOFTWARE.
+ */
+
+#include "vtn_private.h"
+
+/*
+ * Normally, column vectors in SPIR-V correspond to a single NIR SSA
+ * definition. But for matrix multiplies, we want to do one routine for
+ * multiplying a matrix by a matrix and then pretend that vectors are matrices
+ * with one column. So we "wrap" these things, and unwrap the result before we
+ * send it off.
+ */
+
+static struct vtn_ssa_value *
+wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
+{
+   if (val == NULL)
+      return NULL;
+
+   if (glsl_type_is_matrix(val->type))
+      return val;
+
+   struct vtn_ssa_value *dest = rzalloc(b, struct vtn_ssa_value);
+   dest->type = val->type;
+   dest->elems = ralloc_array(b, struct vtn_ssa_value *, 1);
+   dest->elems[0] = val;
+
+   return dest;
+}
+
+static struct vtn_ssa_value *
+unwrap_matrix(struct vtn_ssa_value *val)
+{
+   if (glsl_type_is_matrix(val->type))
+         return val;
+
+   return val->elems[0];
+}
+
+static struct vtn_ssa_value *
+matrix_multiply(struct vtn_builder *b,
+                struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
+{
+
+   struct vtn_ssa_value *src0 = wrap_matrix(b, _src0);
+   struct vtn_ssa_value *src1 = wrap_matrix(b, _src1);
+   struct vtn_ssa_value *src0_transpose = wrap_matrix(b, _src0->transposed);
+   struct vtn_ssa_value *src1_transpose = wrap_matrix(b, _src1->transposed);
+
+   unsigned src0_rows = glsl_get_vector_elements(src0->type);
+   unsigned src0_columns = glsl_get_matrix_columns(src0->type);
+   unsigned src1_columns = glsl_get_matrix_columns(src1->type);
+
+   const struct glsl_type *dest_type;
+   if (src1_columns > 1) {
+      dest_type = glsl_matrix_type(glsl_get_base_type(src0->type),
+                                   src0_rows, src1_columns);
+   } else {
+      dest_type = glsl_vector_type(glsl_get_base_type(src0->type), src0_rows);
+   }
+   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
+
+   dest = wrap_matrix(b, dest);
+
+   bool transpose_result = false;
+   if (src0_transpose && src1_transpose) {
+      /* transpose(A) * transpose(B) = transpose(B * A) */
+      src1 = src0_transpose;
+      src0 = src1_transpose;
+      src0_transpose = NULL;
+      src1_transpose = NULL;
+      transpose_result = true;
+   }
+
+   if (src0_transpose && !src1_transpose &&
+       glsl_get_base_type(src0->type) == GLSL_TYPE_FLOAT) {
+      /* We already have the rows of src0 and the columns of src1 available,
+       * so we can just take the dot product of each row with each column to
+       * get the result.
+       */
+
+      for (unsigned i = 0; i < src1_columns; i++) {
+         nir_ssa_def *vec_src[4];
+         for (unsigned j = 0; j < src0_rows; j++) {
+            vec_src[j] = nir_fdot(&b->nb, src0_transpose->elems[j]->def,
+                                          src1->elems[i]->def);
+         }
+         dest->elems[i]->def = nir_vec(&b->nb, vec_src, src0_rows);
+      }
+   } else {
+      /* We don't handle the case where src1 is transposed but not src0, since
+       * the general case only uses individual components of src1 so the
+       * optimizer should chew through the transpose we emitted for src1.
+       */
+
+      for (unsigned i = 0; i < src1_columns; i++) {
+         /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
+         dest->elems[i]->def =
+            nir_fmul(&b->nb, src0->elems[0]->def,
+                     nir_channel(&b->nb, src1->elems[i]->def, 0));
+         for (unsigned j = 1; j < src0_columns; j++) {
+            dest->elems[i]->def =
+               nir_fadd(&b->nb, dest->elems[i]->def,
+                        nir_fmul(&b->nb, src0->elems[j]->def,
+                                 nir_channel(&b->nb, src1->elems[i]->def, j)));
+         }
+      }
+   }
+
+   dest = unwrap_matrix(dest);
+
+   if (transpose_result)
+      dest = vtn_ssa_transpose(b, dest);
+
+   return dest;
+}
+
+static struct vtn_ssa_value *
+mat_times_scalar(struct vtn_builder *b,
+                 struct vtn_ssa_value *mat,
+                 nir_ssa_def *scalar)
+{
+   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
+   for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
+      if (glsl_get_base_type(mat->type) == GLSL_TYPE_FLOAT)
+         dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
+      else
+         dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
+   }
+
+   return dest;
+}
+
+static void
+vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
+                      struct vtn_value *dest,
+                      struct vtn_ssa_value *src0, struct vtn_ssa_value *src1)
+{
+   switch (opcode) {
+   case SpvOpTranspose:
+      dest->ssa = vtn_ssa_transpose(b, src0);
+      break;
+
+   case SpvOpOuterProduct:
+      dest->ssa = matrix_multiply(b, src0, vtn_ssa_transpose(b, src1));
+      break;
+
+   case SpvOpMatrixTimesScalar:
+      if (src0->transposed) {
+         dest->ssa = vtn_ssa_transpose(b, mat_times_scalar(b, src0->transposed,
+                                                           src1->def));
+      } else {
+         dest->ssa = mat_times_scalar(b, src0, src1->def);
+      }
+      break;
+
+   case SpvOpVectorTimesMatrix:
+   case SpvOpMatrixTimesVector:
+   case SpvOpMatrixTimesMatrix:
+      if (opcode == SpvOpVectorTimesMatrix) {
+         dest->ssa = matrix_multiply(b, vtn_ssa_transpose(b, src1), src0);
+      } else {
+         dest->ssa = matrix_multiply(b, src0, src1);
+      }
+      break;
+
+   default: unreachable("unknown matrix opcode");
+   }
+}
+
+void
+vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
+               const uint32_t *w, unsigned count)
+{
+   struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
+   const struct glsl_type *type =
+      vtn_value(b, w[1], vtn_value_type_type)->type->type;
+
+   /* Collect the various SSA sources */
+   const unsigned num_inputs = count - 3;
+   struct vtn_ssa_value *vtn_src[4] = { NULL, };
+   for (unsigned i = 0; i < num_inputs; i++)
+      vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
+
+   if (glsl_type_is_matrix(vtn_src[0]->type) ||
+       (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
+      vtn_handle_matrix_alu(b, opcode, val, vtn_src[0], vtn_src[1]);
+      return;
+   }
+
+   val->ssa = vtn_create_ssa_value(b, type);
+   nir_ssa_def *src[4] = { NULL, };
+   for (unsigned i = 0; i < num_inputs; i++) {
+      assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
+      src[i] = vtn_src[i]->def;
+   }
+
+   /* Indicates that the first two arguments should be swapped.  This is
+    * used for implementing greater-than and less-than-or-equal.
+    */
+   bool swap = false;
+
+   nir_op op;
+   switch (opcode) {
+   /* Basic ALU operations */
+   case SpvOpSNegate:               op = nir_op_ineg;    break;
+   case SpvOpFNegate:               op = nir_op_fneg;    break;
+   case SpvOpNot:                   op = nir_op_inot;    break;
+
+   case SpvOpAny:
+      if (src[0]->num_components == 1) {
+         op = nir_op_imov;
+      } else {
+         switch (src[0]->num_components) {
+         case 2:  op = nir_op_bany_inequal2; break;
+         case 3:  op = nir_op_bany_inequal3; break;
+         case 4:  op = nir_op_bany_inequal4; break;
+         }
+         src[1] = nir_imm_int(&b->nb, NIR_FALSE);
+      }
+      break;
+
+   case SpvOpAll:
+      if (src[0]->num_components == 1) {
+         op = nir_op_imov;
+      } else {
+         switch (src[0]->num_components) {
+         case 2:  op = nir_op_ball_iequal2;  break;
+         case 3:  op = nir_op_ball_iequal3;  break;
+         case 4:  op = nir_op_ball_iequal4;  break;
+         }
+         src[1] = nir_imm_int(&b->nb, NIR_TRUE);
+      }
+      break;
+
+   case SpvOpIAdd:                  op = nir_op_iadd;    break;
+   case SpvOpFAdd:                  op = nir_op_fadd;    break;
+   case SpvOpISub:                  op = nir_op_isub;    break;
+   case SpvOpFSub:                  op = nir_op_fsub;    break;
+   case SpvOpIMul:                  op = nir_op_imul;    break;
+   case SpvOpFMul:                  op = nir_op_fmul;    break;
+   case SpvOpUDiv:                  op = nir_op_udiv;    break;
+   case SpvOpSDiv:                  op = nir_op_idiv;    break;
+   case SpvOpFDiv:                  op = nir_op_fdiv;    break;
+   case SpvOpUMod:                  op = nir_op_umod;    break;
+   case SpvOpSMod:                  op = nir_op_umod;    break; /* FIXME? */
+   case SpvOpFMod:                  op = nir_op_fmod;    break;
+
+   case SpvOpDot:
+      assert(src[0]->num_components == src[1]->num_components);
+      switch (src[0]->num_components) {
+      case 1:  op = nir_op_fmul;    break;
+      case 2:  op = nir_op_fdot2;   break;
+      case 3:  op = nir_op_fdot3;   break;
+      case 4:  op = nir_op_fdot4;   break;
+      }
+      break;
+
+   case SpvOpIAddCarry:
+      assert(glsl_type_is_struct(val->ssa->type));
+      val->ssa->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
+      val->ssa->elems[1]->def =
+         nir_b2i(&b->nb, nir_uadd_carry(&b->nb, src[0], src[1]));
+      return;
+
+   case SpvOpISubBorrow:
+      assert(glsl_type_is_struct(val->ssa->type));
+      val->ssa->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
+      val->ssa->elems[1]->def =
+         nir_b2i(&b->nb, nir_usub_borrow(&b->nb, src[0], src[1]));
+      return;
+
+   case SpvOpUMulExtended:
+      assert(glsl_type_is_struct(val->ssa->type));
+      val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
+      val->ssa->elems[1]->def = nir_umul_high(&b->nb, src[0], src[1]);
+      return;
+
+   case SpvOpSMulExtended:
+      assert(glsl_type_is_struct(val->ssa->type));
+      val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
+      val->ssa->elems[1]->def = nir_imul_high(&b->nb, src[0], src[1]);
+      return;
+
+   case SpvOpShiftRightLogical:     op = nir_op_ushr;    break;
+   case SpvOpShiftRightArithmetic:  op = nir_op_ishr;    break;
+   case SpvOpShiftLeftLogical:      op = nir_op_ishl;    break;
+   case SpvOpLogicalOr:             op = nir_op_ior;     break;
+   case SpvOpLogicalEqual:          op = nir_op_ieq;     break;
+   case SpvOpLogicalNotEqual:       op = nir_op_ine;     break;
+   case SpvOpLogicalAnd:            op = nir_op_iand;    break;
+   case SpvOpLogicalNot:            op = nir_op_inot;    break;
+   case SpvOpBitwiseOr:             op = nir_op_ior;     break;
+   case SpvOpBitwiseXor:            op = nir_op_ixor;    break;
+   case SpvOpBitwiseAnd:            op = nir_op_iand;    break;
+   case SpvOpSelect:                op = nir_op_bcsel;   break;
+   case SpvOpIEqual:                op = nir_op_ieq;     break;
+
+   case SpvOpBitFieldInsert:        op = nir_op_bitfield_insert;     break;
+   case SpvOpBitFieldSExtract:      op = nir_op_ibitfield_extract;   break;
+   case SpvOpBitFieldUExtract:      op = nir_op_ubitfield_extract;   break;
+   case SpvOpBitReverse:            op = nir_op_bitfield_reverse;    break;
+   case SpvOpBitCount:              op = nir_op_bit_count;           break;
+
+   /* Comparisons: (TODO: How do we want to handled ordered/unordered?) */
+   case SpvOpFOrdEqual:             op = nir_op_feq;     break;
+   case SpvOpFUnordEqual:           op = nir_op_feq;     break;
+   case SpvOpINotEqual:             op = nir_op_ine;     break;
+   case SpvOpFOrdNotEqual:          op = nir_op_fne;     break;
+   case SpvOpFUnordNotEqual:        op = nir_op_fne;     break;
+   case SpvOpULessThan:             op = nir_op_ult;     break;
+   case SpvOpSLessThan:             op = nir_op_ilt;     break;
+   case SpvOpFOrdLessThan:          op = nir_op_flt;     break;
+   case SpvOpFUnordLessThan:        op = nir_op_flt;     break;
+   case SpvOpUGreaterThan:          op = nir_op_ult;  swap = true;   break;
+   case SpvOpSGreaterThan:          op = nir_op_ilt;  swap = true;   break;
+   case SpvOpFOrdGreaterThan:       op = nir_op_flt;  swap = true;   break;
+   case SpvOpFUnordGreaterThan:     op = nir_op_flt;  swap = true;   break;
+   case SpvOpULessThanEqual:        op = nir_op_uge;  swap = true;   break;
+   case SpvOpSLessThanEqual:        op = nir_op_ige;  swap = true;   break;
+   case SpvOpFOrdLessThanEqual:     op = nir_op_fge;  swap = true;   break;
+   case SpvOpFUnordLessThanEqual:   op = nir_op_fge;  swap = true;   break;
+   case SpvOpUGreaterThanEqual:     op = nir_op_uge;     break;
+   case SpvOpSGreaterThanEqual:     op = nir_op_ige;     break;
+   case SpvOpFOrdGreaterThanEqual:  op = nir_op_fge;     break;
+   case SpvOpFUnordGreaterThanEqual:op = nir_op_fge;     break;
+
+   /* Conversions: */
+   case SpvOpConvertFToU:           op = nir_op_f2u;     break;
+   case SpvOpConvertFToS:           op = nir_op_f2i;     break;
+   case SpvOpConvertSToF:           op = nir_op_i2f;     break;
+   case SpvOpConvertUToF:           op = nir_op_u2f;     break;
+   case SpvOpBitcast:               op = nir_op_imov;    break;
+   case SpvOpUConvert:
+   case SpvOpSConvert:
+      op = nir_op_imov; /* TODO: NIR is 32-bit only; these are no-ops. */
+      break;
+   case SpvOpFConvert:
+      op = nir_op_fmov;
+      break;
+
+   /* Derivatives: */
+   case SpvOpDPdx:         op = nir_op_fddx;          break;
+   case SpvOpDPdy:         op = nir_op_fddy;          break;
+   case SpvOpDPdxFine:     op = nir_op_fddx_fine;     break;
+   case SpvOpDPdyFine:     op = nir_op_fddy_fine;     break;
+   case SpvOpDPdxCoarse:   op = nir_op_fddx_coarse;   break;
+   case SpvOpDPdyCoarse:   op = nir_op_fddy_coarse;   break;
+   case SpvOpFwidth:
+      val->ssa->def = nir_fadd(&b->nb,
+                               nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
+                               nir_fabs(&b->nb, nir_fddx(&b->nb, src[1])));
+      return;
+   case SpvOpFwidthFine:
+      val->ssa->def = nir_fadd(&b->nb,
+                               nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
+                               nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[1])));
+      return;
+   case SpvOpFwidthCoarse:
+      val->ssa->def = nir_fadd(&b->nb,
+                               nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
+                               nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[1])));
+      return;
+
+   case SpvOpVectorTimesScalar:
+      /* The builder will take care of splatting for us. */
+      val->ssa->def = nir_fmul(&b->nb, src[0], src[1]);
+      return;
+
+   case SpvOpSRem:
+   case SpvOpFRem:
+      unreachable("No NIR equivalent");
+
+   case SpvOpIsNan:
+      val->ssa->def = nir_fne(&b->nb, src[0], src[0]);
+      return;
+
+   case SpvOpIsInf:
+      val->ssa->def = nir_feq(&b->nb, nir_fabs(&b->nb, src[0]),
+                                      nir_imm_float(&b->nb, INFINITY));
+      return;
+
+   case SpvOpIsFinite:
+   case SpvOpIsNormal:
+   case SpvOpSignBitSet:
+   case SpvOpLessOrGreater:
+   case SpvOpOrdered:
+   case SpvOpUnordered:
+   default:
+      unreachable("Unhandled opcode");
+   }
+
+   if (swap) {
+      nir_ssa_def *tmp = src[0];
+      src[0] = src[1];
+      src[1] = tmp;
+   }
+
+   val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
+}
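
For reference, the shortcut taken in matrix_multiply() above when both operands arrive with cached transposes is the standard transpose identity; a one-line derivation (plain math, not part of the patch):

    \[ (B A)^{\top} = A^{\top} B^{\top} \quad\Longrightarrow\quad A^{\top} B^{\top} = (B A)^{\top} \]

so the function multiplies the already-available untransposed operands in swapped order (B * A) and marks the result for a single final vtn_ssa_transpose(), rather than materializing both input transposes up front.
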
diff --git a/src/glsl/nir/spirv/vtn_private.h b/src/glsl/nir/spirv/vtn_private.h
index 0fa7dd4..14355c9 100644
@@ -363,6 +363,12 @@ vtn_value(struct vtn_builder *b, uint32_t value_id,
 
 struct vtn_ssa_value *vtn_ssa_value(struct vtn_builder *b, uint32_t value_id);
 
+struct vtn_ssa_value *vtn_create_ssa_value(struct vtn_builder *b,
+                                           const struct glsl_type *type);
+
+struct vtn_ssa_value *vtn_ssa_transpose(struct vtn_builder *b,
+                                        struct vtn_ssa_value *src);
+
 void vtn_variable_store(struct vtn_builder *b, struct vtn_ssa_value *src,
                         nir_deref_var *dest, struct vtn_type *dest_type);
 
@@ -384,5 +390,8 @@ typedef void (*vtn_execution_mode_foreach_cb)(struct vtn_builder *,
 void vtn_foreach_execution_mode(struct vtn_builder *b, struct vtn_value *value,
                                 vtn_execution_mode_foreach_cb cb, void *data);
 
+void vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
+                    const uint32_t *w, unsigned count);
+
 bool vtn_handle_glsl450_instruction(struct vtn_builder *b, uint32_t ext_opcode,
                                     const uint32_t *words, unsigned count);
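
With vtn_create_ssa_value() now declared in vtn_private.h, other vtn_*.c files can build per-column matrix values the same way vtn_alu.c does. A minimal sketch (negate_matrix is a hypothetical helper, not part of this commit; it assumes the per-column elems[] layout of struct vtn_ssa_value used by mat_times_scalar() above):

    #include "vtn_private.h"

    /* Hypothetical example: negate every column of a float matrix value,
     * mirroring the per-column pattern of mat_times_scalar().
     */
    static struct vtn_ssa_value *
    negate_matrix(struct vtn_builder *b, struct vtn_ssa_value *mat)
    {
       /* Allocate a destination of the same GLSL type; vtn_create_ssa_value()
        * sets up the elems[] array for matrix types.
        */
       struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);

       for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++)
          dest->elems[i]->def = nir_fneg(&b->nb, mat->elems[i]->def);

       return dest;
    }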