nir/spirv: Add support for a bunch of ALU operations
authorJason Ekstrand <jason.ekstrand@intel.com>
Mon, 4 May 2015 19:04:02 +0000 (12:04 -0700)
committerJason Ekstrand <jason.ekstrand@intel.com>
Sat, 16 May 2015 18:16:33 +0000 (11:16 -0700)
src/glsl/nir/spirv_to_nir.c

index 3f8ce2a..734ffee 100644 (file)
@@ -27,6 +27,7 @@
 
 #include "nir_spirv.h"
 #include "nir_vla.h"
+#include "nir_builder.h"
 #include "spirv.h"
 
 struct vtn_decoration;
@@ -81,6 +82,8 @@ struct vtn_decoration {
 };
 
 struct vtn_builder {
+   nir_builder nb;
+
    nir_shader *shader;
    nir_function_impl *impl;
    struct exec_list *cf_list;
@@ -706,10 +709,191 @@ vtn_handle_texture(struct vtn_builder *b, SpvOp opcode,
 }
 
 static void
+vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
+                      const uint32_t *w, unsigned count)
+{
+   unreachable("Matrix math not handled");
+}
+
+static void
 vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
                const uint32_t *w, unsigned count)
 {
-   unreachable("Unhandled opcode");
+   struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
+   val->type = vtn_value(b, w[1], vtn_value_type_type)->type;
+
+   /* Collect the various SSA sources */
+   unsigned num_inputs = count - 3;
+   nir_ssa_def *src[4];
+   for (unsigned i = 0; i < num_inputs; i++)
+      src[i] = vtn_ssa_value(b, w[i + 3]);
+
+   /* We use the builder for some of the instructions.  Go ahead and
+    * initialize it with the current cf_list.
+    */
+   nir_builder_insert_after_cf_list(&b->nb, b->cf_list);
+
+   /* Indicates that the first two arguments should be swapped.  This is
+    * used for implementing greater-than and less-than-or-equal.
+    */
+   bool swap = false;
+
+   nir_op op;
+   switch (opcode) {
+   /* Basic ALU operations */
+   case SpvOpSNegate:               op = nir_op_ineg;    break;
+   case SpvOpFNegate:               op = nir_op_fneg;    break;
+   case SpvOpNot:                   op = nir_op_inot;    break;
+
+   case SpvOpAny:
+      switch (src[0]->num_components) {
+      case 1:  op = nir_op_imov;    break;
+      case 2:  op = nir_op_bany2;   break;
+      case 3:  op = nir_op_bany3;   break;
+      case 4:  op = nir_op_bany4;   break;
+      }
+      break;
+
+   case SpvOpAll:
+      switch (src[0]->num_components) {
+      case 1:  op = nir_op_imov;    break;
+      case 2:  op = nir_op_ball2;   break;
+      case 3:  op = nir_op_ball3;   break;
+      case 4:  op = nir_op_ball4;   break;
+      }
+      break;
+
+   case SpvOpIAdd:                  op = nir_op_iadd;    break;
+   case SpvOpFAdd:                  op = nir_op_fadd;    break;
+   case SpvOpISub:                  op = nir_op_isub;    break;
+   case SpvOpFSub:                  op = nir_op_fsub;    break;
+   case SpvOpIMul:                  op = nir_op_imul;    break;
+   case SpvOpFMul:                  op = nir_op_fmul;    break;
+   case SpvOpUDiv:                  op = nir_op_udiv;    break;
+   case SpvOpSDiv:                  op = nir_op_idiv;    break;
+   case SpvOpFDiv:                  op = nir_op_fdiv;    break;
+   case SpvOpUMod:                  op = nir_op_umod;    break;
+   case SpvOpSMod:                  op = nir_op_umod;    break; /* FIXME? */
+   case SpvOpFMod:                  op = nir_op_fmod;    break;
+
+   case SpvOpDot:
+      assert(src[0]->num_components == src[1]->num_components);
+      switch (src[0]->num_components) {
+      case 1:  op = nir_op_fmul;    break;
+      case 2:  op = nir_op_fdot2;   break;
+      case 3:  op = nir_op_fdot3;   break;
+      case 4:  op = nir_op_fdot4;   break;
+      }
+      break;
+
+   case SpvOpShiftRightLogical:     op = nir_op_ushr;    break;
+   case SpvOpShiftRightArithmetic:  op = nir_op_ishr;    break;
+   case SpvOpShiftLeftLogical:      op = nir_op_ishl;    break;
+   case SpvOpLogicalOr:             op = nir_op_ior;     break;
+   case SpvOpLogicalXor:            op = nir_op_ixor;    break;
+   case SpvOpLogicalAnd:            op = nir_op_iand;    break;
+   case SpvOpBitwiseOr:             op = nir_op_ior;     break;
+   case SpvOpBitwiseXor:            op = nir_op_ixor;    break;
+   case SpvOpBitwiseAnd:            op = nir_op_iand;    break;
+   case SpvOpSelect:                op = nir_op_bcsel;   break;
+   case SpvOpIEqual:                op = nir_op_ieq;     break;
+
+   /* Comparisons: (TODO: How do we want to handled ordered/unordered?) */
+   case SpvOpFOrdEqual:             op = nir_op_feq;     break;
+   case SpvOpFUnordEqual:           op = nir_op_feq;     break;
+   case SpvOpINotEqual:             op = nir_op_ine;     break;
+   case SpvOpFOrdNotEqual:          op = nir_op_fne;     break;
+   case SpvOpFUnordNotEqual:        op = nir_op_fne;     break;
+   case SpvOpULessThan:             op = nir_op_ult;     break;
+   case SpvOpSLessThan:             op = nir_op_ilt;     break;
+   case SpvOpFOrdLessThan:          op = nir_op_flt;     break;
+   case SpvOpFUnordLessThan:        op = nir_op_flt;     break;
+   case SpvOpUGreaterThan:          op = nir_op_ult;  swap = true;   break;
+   case SpvOpSGreaterThan:          op = nir_op_ilt;  swap = true;   break;
+   case SpvOpFOrdGreaterThan:       op = nir_op_flt;  swap = true;   break;
+   case SpvOpFUnordGreaterThan:     op = nir_op_flt;  swap = true;   break;
+   case SpvOpULessThanEqual:        op = nir_op_uge;  swap = true;   break;
+   case SpvOpSLessThanEqual:        op = nir_op_ige;  swap = true;   break;
+   case SpvOpFOrdLessThanEqual:     op = nir_op_fge;  swap = true;   break;
+   case SpvOpFUnordLessThanEqual:   op = nir_op_fge;  swap = true;   break;
+   case SpvOpUGreaterThanEqual:     op = nir_op_uge;     break;
+   case SpvOpSGreaterThanEqual:     op = nir_op_ige;     break;
+   case SpvOpFOrdGreaterThanEqual:  op = nir_op_fge;     break;
+   case SpvOpFUnordGreaterThanEqual:op = nir_op_fge;     break;
+
+   /* Conversions: */
+   case SpvOpConvertFToU:           op = nir_op_f2u;     break;
+   case SpvOpConvertFToS:           op = nir_op_f2i;     break;
+   case SpvOpConvertSToF:           op = nir_op_i2f;     break;
+   case SpvOpConvertUToF:           op = nir_op_u2f;     break;
+   case SpvOpBitcast:               op = nir_op_imov;    break;
+   case SpvOpUConvert:
+   case SpvOpSConvert:
+      op = nir_op_imov; /* TODO: NIR is 32-bit only; these are no-ops. */
+      break;
+   case SpvOpFConvert:
+      op = nir_op_fmov;
+      break;
+
+   /* Derivatives: */
+   case SpvOpDPdx:         op = nir_op_fddx;          break;
+   case SpvOpDPdy:         op = nir_op_fddy;          break;
+   case SpvOpDPdxFine:     op = nir_op_fddx_fine;     break;
+   case SpvOpDPdyFine:     op = nir_op_fddy_fine;     break;
+   case SpvOpDPdxCoarse:   op = nir_op_fddx_coarse;   break;
+   case SpvOpDPdyCoarse:   op = nir_op_fddy_coarse;   break;
+   case SpvOpFwidth:
+      val->ssa = nir_fadd(&b->nb,
+                          nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
+                          nir_fabs(&b->nb, nir_fddx(&b->nb, src[1])));
+      return;
+   case SpvOpFwidthFine:
+      val->ssa = nir_fadd(&b->nb,
+                          nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
+                          nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[1])));
+      return;
+   case SpvOpFwidthCoarse:
+      val->ssa = nir_fadd(&b->nb,
+                          nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
+                          nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[1])));
+      return;
+
+   case SpvOpVectorTimesScalar:
+      /* The builder will take care of splatting for us. */
+      val->ssa = nir_fmul(&b->nb, src[0], src[1]);
+      return;
+
+   case SpvOpSRem:
+   case SpvOpFRem:
+      unreachable("No NIR equivalent");
+
+   case SpvOpIsNan:
+   case SpvOpIsInf:
+   case SpvOpIsFinite:
+   case SpvOpIsNormal:
+   case SpvOpSignBitSet:
+   case SpvOpLessOrGreater:
+   case SpvOpOrdered:
+   case SpvOpUnordered:
+   default:
+      unreachable("Unhandled opcode");
+   }
+
+   if (swap) {
+      nir_ssa_def *tmp = src[0];
+      src[0] = src[1];
+      src[1] = tmp;
+   }
+
+   nir_alu_instr *instr = nir_alu_instr_create(b->shader, op);
+   nir_ssa_dest_init(&instr->instr, &instr->dest.dest,
+                     glsl_get_vector_elements(val->type), val->name);
+   val->ssa = &instr->dest.dest.ssa;
+
+   for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++)
+      instr->src[i].src = nir_src_for_ssa(src[i]);
+
+   nir_instr_insert_after_cf_list(b->cf_list, &instr->instr);
 }
 
 static bool
@@ -993,7 +1177,6 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
    case SpvOpPtrCastToGeneric:
    case SpvOpGenericCastToPtr:
    case SpvOpBitcast:
-   case SpvOpTranspose:
    case SpvOpIsNan:
    case SpvOpIsInf:
    case SpvOpIsFinite:
@@ -1017,11 +1200,6 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
    case SpvOpFRem:
    case SpvOpFMod:
    case SpvOpVectorTimesScalar:
-   case SpvOpMatrixTimesScalar:
-   case SpvOpVectorTimesMatrix:
-   case SpvOpMatrixTimesVector:
-   case SpvOpMatrixTimesMatrix:
-   case SpvOpOuterProduct:
    case SpvOpDot:
    case SpvOpShiftRightLogical:
    case SpvOpShiftRightArithmetic:
@@ -1067,6 +1245,15 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
       vtn_handle_alu(b, opcode, w, count);
       break;
 
+   case SpvOpTranspose:
+   case SpvOpOuterProduct:
+   case SpvOpMatrixTimesScalar:
+   case SpvOpVectorTimesMatrix:
+   case SpvOpMatrixTimesVector:
+   case SpvOpMatrixTimesMatrix:
+      vtn_handle_matrix_alu(b, opcode, w, count);
+      break;
+
    default:
       unreachable("Unhandled opcode");
    }
@@ -1163,6 +1350,7 @@ spirv_to_nir(const uint32_t *words, size_t word_count,
 
    foreach_list_typed(struct vtn_function, func, node, &b->functions) {
       b->impl = nir_function_impl_create(func->overload);
+      nir_builder_init(&b->nb, b->impl);
       b->cf_list = &b->impl->body;
       vtn_walk_blocks(b, func->start_block, NULL);
    }