nir/spirv/alu: Factor out the opcode table
authorJason Ekstrand <jason.ekstrand@intel.com>
Wed, 13 Jan 2016 01:16:48 +0000 (17:16 -0800)
committerJason Ekstrand <jason.ekstrand@intel.com>
Wed, 13 Jan 2016 23:18:36 +0000 (15:18 -0800)
src/glsl/nir/spirv/vtn_alu.c
src/glsl/nir/spirv/vtn_private.h

index 03ed1f0..4c45e23 100644 (file)
@@ -210,6 +210,101 @@ vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
    }
 }
 
+nir_op
+vtn_nir_alu_op_for_spirv_opcode(SpvOp opcode, bool *swap)
+{
+   /* Indicates that the first two arguments should be swapped.  This is
+    * used for implementing greater-than and less-than-or-equal.
+    */
+   *swap = false;
+
+   switch (opcode) {
+   case SpvOpSNegate:            return nir_op_ineg;
+   case SpvOpFNegate:            return nir_op_fneg;
+   case SpvOpNot:                return nir_op_inot;
+   case SpvOpIAdd:               return nir_op_iadd;
+   case SpvOpFAdd:               return nir_op_fadd;
+   case SpvOpISub:               return nir_op_isub;
+   case SpvOpFSub:               return nir_op_fsub;
+   case SpvOpIMul:               return nir_op_imul;
+   case SpvOpFMul:               return nir_op_fmul;
+   case SpvOpUDiv:               return nir_op_udiv;
+   case SpvOpSDiv:               return nir_op_idiv;
+   case SpvOpFDiv:               return nir_op_fdiv;
+   case SpvOpUMod:               return nir_op_umod;
+   case SpvOpSMod:               return nir_op_umod; /* FIXME? */
+   case SpvOpFMod:               return nir_op_fmod;
+
+   case SpvOpShiftRightLogical:     return nir_op_ushr;
+   case SpvOpShiftRightArithmetic:  return nir_op_ishr;
+   case SpvOpShiftLeftLogical:      return nir_op_ishl;
+   case SpvOpLogicalOr:             return nir_op_ior;
+   case SpvOpLogicalEqual:          return nir_op_ieq;
+   case SpvOpLogicalNotEqual:       return nir_op_ine;
+   case SpvOpLogicalAnd:            return nir_op_iand;
+   case SpvOpLogicalNot:            return nir_op_inot;
+   case SpvOpBitwiseOr:             return nir_op_ior;
+   case SpvOpBitwiseXor:            return nir_op_ixor;
+   case SpvOpBitwiseAnd:            return nir_op_iand;
+   case SpvOpSelect:                return nir_op_bcsel;
+   case SpvOpIEqual:                return nir_op_ieq;
+
+   case SpvOpBitFieldInsert:        return nir_op_bitfield_insert;
+   case SpvOpBitFieldSExtract:      return nir_op_ibitfield_extract;
+   case SpvOpBitFieldUExtract:      return nir_op_ubitfield_extract;
+   case SpvOpBitReverse:            return nir_op_bitfield_reverse;
+   case SpvOpBitCount:              return nir_op_bit_count;
+
+   /* Comparisons: (TODO: How do we want to handled ordered/unordered?) */
+   case SpvOpFOrdEqual:                            return nir_op_feq;
+   case SpvOpFUnordEqual:                          return nir_op_feq;
+   case SpvOpINotEqual:                            return nir_op_ine;
+   case SpvOpFOrdNotEqual:                         return nir_op_fne;
+   case SpvOpFUnordNotEqual:                       return nir_op_fne;
+   case SpvOpULessThan:                            return nir_op_ult;
+   case SpvOpSLessThan:                            return nir_op_ilt;
+   case SpvOpFOrdLessThan:                         return nir_op_flt;
+   case SpvOpFUnordLessThan:                       return nir_op_flt;
+   case SpvOpUGreaterThan:          *swap = true;  return nir_op_ult;
+   case SpvOpSGreaterThan:          *swap = true;  return nir_op_ilt;
+   case SpvOpFOrdGreaterThan:       *swap = true;  return nir_op_flt;
+   case SpvOpFUnordGreaterThan:     *swap = true;  return nir_op_flt;
+   case SpvOpULessThanEqual:        *swap = true;  return nir_op_uge;
+   case SpvOpSLessThanEqual:        *swap = true;  return nir_op_ige;
+   case SpvOpFOrdLessThanEqual:     *swap = true;  return nir_op_fge;
+   case SpvOpFUnordLessThanEqual:   *swap = true;  return nir_op_fge;
+   case SpvOpUGreaterThanEqual:                    return nir_op_uge;
+   case SpvOpSGreaterThanEqual:                    return nir_op_ige;
+   case SpvOpFOrdGreaterThanEqual:                 return nir_op_fge;
+   case SpvOpFUnordGreaterThanEqual:               return nir_op_fge;
+
+   /* Conversions: */
+   case SpvOpConvertFToU:           return nir_op_f2u;
+   case SpvOpConvertFToS:           return nir_op_f2i;
+   case SpvOpConvertSToF:           return nir_op_i2f;
+   case SpvOpConvertUToF:           return nir_op_u2f;
+   case SpvOpBitcast:               return nir_op_imov;
+   case SpvOpUConvert:
+   case SpvOpQuantizeToF16:         return nir_op_fquantize2f16;
+   /* TODO: NIR is 32-bit only; these are no-ops. */
+   case SpvOpSConvert:              return nir_op_imov;
+   case SpvOpFConvert:              return nir_op_fmov;
+
+   /* Derivatives: */
+   case SpvOpDPdx:         return nir_op_fddx;
+   case SpvOpDPdy:         return nir_op_fddy;
+   case SpvOpDPdxFine:     return nir_op_fddx_fine;
+   case SpvOpDPdyFine:     return nir_op_fddy_fine;
+   case SpvOpDPdxCoarse:   return nir_op_fddx_coarse;
+   case SpvOpDPdyCoarse:   return nir_op_fddy_coarse;
+
+   case SpvOpSRem:
+   case SpvOpFRem:
+   default:
+      unreachable("No NIR equivalent");
+   }
+}
+
 void
 vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
                const uint32_t *w, unsigned count)
@@ -237,56 +332,38 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
       src[i] = vtn_src[i]->def;
    }
 
-   /* Indicates that the first two arguments should be swapped.  This is
-    * used for implementing greater-than and less-than-or-equal.
-    */
-   bool swap = false;
-
-   nir_op op;
    switch (opcode) {
-   /* Basic ALU operations */
-   case SpvOpSNegate:               op = nir_op_ineg;    break;
-   case SpvOpFNegate:               op = nir_op_fneg;    break;
-   case SpvOpNot:                   op = nir_op_inot;    break;
-
    case SpvOpAny:
       if (src[0]->num_components == 1) {
-         op = nir_op_imov;
+         val->ssa->def = nir_imov(&b->nb, src[0]);
       } else {
+         nir_op op;
          switch (src[0]->num_components) {
          case 2:  op = nir_op_bany_inequal2; break;
          case 3:  op = nir_op_bany_inequal3; break;
          case 4:  op = nir_op_bany_inequal4; break;
          }
-         src[1] = nir_imm_int(&b->nb, NIR_FALSE);
+         val->ssa->def = nir_build_alu(&b->nb, op, src[0],
+                                       nir_imm_int(&b->nb, NIR_FALSE),
+                                       NULL, NULL);
       }
-      break;
+      return;
 
    case SpvOpAll:
       if (src[0]->num_components == 1) {
-         op = nir_op_imov;
+         val->ssa->def = nir_imov(&b->nb, src[0]);
       } else {
+         nir_op op;
          switch (src[0]->num_components) {
          case 2:  op = nir_op_ball_iequal2;  break;
          case 3:  op = nir_op_ball_iequal3;  break;
          case 4:  op = nir_op_ball_iequal4;  break;
          }
-         src[1] = nir_imm_int(&b->nb, NIR_TRUE);
+         val->ssa->def = nir_build_alu(&b->nb, op, src[0],
+                                       nir_imm_int(&b->nb, NIR_TRUE),
+                                       NULL, NULL);
       }
-      break;
-
-   case SpvOpIAdd:                  op = nir_op_iadd;    break;
-   case SpvOpFAdd:                  op = nir_op_fadd;    break;
-   case SpvOpISub:                  op = nir_op_isub;    break;
-   case SpvOpFSub:                  op = nir_op_fsub;    break;
-   case SpvOpIMul:                  op = nir_op_imul;    break;
-   case SpvOpFMul:                  op = nir_op_fmul;    break;
-   case SpvOpUDiv:                  op = nir_op_udiv;    break;
-   case SpvOpSDiv:                  op = nir_op_idiv;    break;
-   case SpvOpFDiv:                  op = nir_op_fdiv;    break;
-   case SpvOpUMod:                  op = nir_op_umod;    break;
-   case SpvOpSMod:                  op = nir_op_umod;    break; /* FIXME? */
-   case SpvOpFMod:                  op = nir_op_fmod;    break;
+      return;
 
    case SpvOpOuterProduct: {
       for (unsigned i = 0; i < src[1]->num_components; i++) {
@@ -297,14 +374,8 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
    }
 
    case SpvOpDot:
-      assert(src[0]->num_components == src[1]->num_components);
-      switch (src[0]->num_components) {
-      case 1:  op = nir_op_fmul;    break;
-      case 2:  op = nir_op_fdot2;   break;
-      case 3:  op = nir_op_fdot3;   break;
-      case 4:  op = nir_op_fdot4;   break;
-      }
-      break;
+      val->ssa->def = nir_fdot(&b->nb, src[0], src[1]);
+      return;
 
    case SpvOpIAddCarry:
       assert(glsl_type_is_struct(val->ssa->type));
@@ -330,74 +401,6 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
       val->ssa->elems[1]->def = nir_imul_high(&b->nb, src[0], src[1]);
       return;
 
-   case SpvOpShiftRightLogical:     op = nir_op_ushr;    break;
-   case SpvOpShiftRightArithmetic:  op = nir_op_ishr;    break;
-   case SpvOpShiftLeftLogical:      op = nir_op_ishl;    break;
-   case SpvOpLogicalOr:             op = nir_op_ior;     break;
-   case SpvOpLogicalEqual:          op = nir_op_ieq;     break;
-   case SpvOpLogicalNotEqual:       op = nir_op_ine;     break;
-   case SpvOpLogicalAnd:            op = nir_op_iand;    break;
-   case SpvOpLogicalNot:            op = nir_op_inot;    break;
-   case SpvOpBitwiseOr:             op = nir_op_ior;     break;
-   case SpvOpBitwiseXor:            op = nir_op_ixor;    break;
-   case SpvOpBitwiseAnd:            op = nir_op_iand;    break;
-   case SpvOpSelect:                op = nir_op_bcsel;   break;
-   case SpvOpIEqual:                op = nir_op_ieq;     break;
-
-   case SpvOpBitFieldInsert:        op = nir_op_bitfield_insert;     break;
-   case SpvOpBitFieldSExtract:      op = nir_op_ibitfield_extract;   break;
-   case SpvOpBitFieldUExtract:      op = nir_op_ubitfield_extract;   break;
-   case SpvOpBitReverse:            op = nir_op_bitfield_reverse;    break;
-   case SpvOpBitCount:              op = nir_op_bit_count;           break;
-
-   /* Comparisons: (TODO: How do we want to handled ordered/unordered?) */
-   case SpvOpFOrdEqual:             op = nir_op_feq;     break;
-   case SpvOpFUnordEqual:           op = nir_op_feq;     break;
-   case SpvOpINotEqual:             op = nir_op_ine;     break;
-   case SpvOpFOrdNotEqual:          op = nir_op_fne;     break;
-   case SpvOpFUnordNotEqual:        op = nir_op_fne;     break;
-   case SpvOpULessThan:             op = nir_op_ult;     break;
-   case SpvOpSLessThan:             op = nir_op_ilt;     break;
-   case SpvOpFOrdLessThan:          op = nir_op_flt;     break;
-   case SpvOpFUnordLessThan:        op = nir_op_flt;     break;
-   case SpvOpUGreaterThan:          op = nir_op_ult;  swap = true;   break;
-   case SpvOpSGreaterThan:          op = nir_op_ilt;  swap = true;   break;
-   case SpvOpFOrdGreaterThan:       op = nir_op_flt;  swap = true;   break;
-   case SpvOpFUnordGreaterThan:     op = nir_op_flt;  swap = true;   break;
-   case SpvOpULessThanEqual:        op = nir_op_uge;  swap = true;   break;
-   case SpvOpSLessThanEqual:        op = nir_op_ige;  swap = true;   break;
-   case SpvOpFOrdLessThanEqual:     op = nir_op_fge;  swap = true;   break;
-   case SpvOpFUnordLessThanEqual:   op = nir_op_fge;  swap = true;   break;
-   case SpvOpUGreaterThanEqual:     op = nir_op_uge;     break;
-   case SpvOpSGreaterThanEqual:     op = nir_op_ige;     break;
-   case SpvOpFOrdGreaterThanEqual:  op = nir_op_fge;     break;
-   case SpvOpFUnordGreaterThanEqual:op = nir_op_fge;     break;
-
-   /* Conversions: */
-   case SpvOpConvertFToU:           op = nir_op_f2u;     break;
-   case SpvOpConvertFToS:           op = nir_op_f2i;     break;
-   case SpvOpConvertSToF:           op = nir_op_i2f;     break;
-   case SpvOpConvertUToF:           op = nir_op_u2f;     break;
-   case SpvOpBitcast:               op = nir_op_imov;    break;
-   case SpvOpUConvert:
-   case SpvOpSConvert:
-      op = nir_op_imov; /* TODO: NIR is 32-bit only; these are no-ops. */
-      break;
-   case SpvOpFConvert:
-      op = nir_op_fmov;
-      break;
-
-   case SpvOpQuantizeToF16:
-      op = nir_op_fquantize2f16;
-      break;
-
-   /* Derivatives: */
-   case SpvOpDPdx:         op = nir_op_fddx;          break;
-   case SpvOpDPdy:         op = nir_op_fddy;          break;
-   case SpvOpDPdxFine:     op = nir_op_fddx_fine;     break;
-   case SpvOpDPdyFine:     op = nir_op_fddy_fine;     break;
-   case SpvOpDPdxCoarse:   op = nir_op_fddx_coarse;   break;
-   case SpvOpDPdyCoarse:   op = nir_op_fddy_coarse;   break;
    case SpvOpFwidth:
       val->ssa->def = nir_fadd(&b->nb,
                                nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
@@ -419,10 +422,6 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
       val->ssa->def = nir_fmul(&b->nb, src[0], src[1]);
       return;
 
-   case SpvOpSRem:
-   case SpvOpFRem:
-      unreachable("No NIR equivalent");
-
    case SpvOpIsNan:
       val->ssa->def = nir_fne(&b->nb, src[0], src[0]);
       return;
@@ -432,21 +431,18 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
                                       nir_imm_float(&b->nb, INFINITY));
       return;
 
-   case SpvOpIsFinite:
-   case SpvOpIsNormal:
-   case SpvOpSignBitSet:
-   case SpvOpLessOrGreater:
-   case SpvOpOrdered:
-   case SpvOpUnordered:
-   default:
-      unreachable("Unhandled opcode");
-   }
+   default: {
+      bool swap;
+      nir_op op = vtn_nir_alu_op_for_spirv_opcode(opcode, &swap);
 
-   if (swap) {
-      nir_ssa_def *tmp = src[0];
-      src[0] = src[1];
-      src[1] = tmp;
-   }
+      if (swap) {
+         nir_ssa_def *tmp = src[0];
+         src[0] = src[1];
+         src[1] = tmp;
+      }
 
-   val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
+      val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
+      return;
+   } /* default */
+   }
 }
index 1f88eed..129414a 100644 (file)
@@ -393,6 +393,8 @@ typedef void (*vtn_execution_mode_foreach_cb)(struct vtn_builder *,
 void vtn_foreach_execution_mode(struct vtn_builder *b, struct vtn_value *value,
                                 vtn_execution_mode_foreach_cb cb, void *data);
 
+nir_op vtn_nir_alu_op_for_spirv_opcode(SpvOp opcode, bool *swap);
+
 void vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
                     const uint32_t *w, unsigned count);