#include "nir_spirv.h"
#include "nir_vla.h"
+#include "nir_builder.h"
#include "spirv.h"
struct vtn_decoration;
};
struct vtn_builder {
+ nir_builder nb;
+
nir_shader *shader;
nir_function_impl *impl;
struct exec_list *cf_list;
}
static void
+vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
+ const uint32_t *w, unsigned count)
+{
+ unreachable("Matrix math not handled");
+}
+
+static void
vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
const uint32_t *w, unsigned count)
{
- unreachable("Unhandled opcode");
+ struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
+ val->type = vtn_value(b, w[1], vtn_value_type_type)->type;
+
+ /* Collect the various SSA sources */
+ unsigned num_inputs = count - 3;
+ nir_ssa_def *src[4];
+ for (unsigned i = 0; i < num_inputs; i++)
+ src[i] = vtn_ssa_value(b, w[i + 3]);
+
+ /* We use the builder for some of the instructions. Go ahead and
+ * initialize it with the current cf_list.
+ */
+ nir_builder_insert_after_cf_list(&b->nb, b->cf_list);
+
+ /* Indicates that the first two arguments should be swapped. This is
+ * used for implementing greater-than and less-than-or-equal.
+ */
+ bool swap = false;
+
+ nir_op op;
+ switch (opcode) {
+ /* Basic ALU operations */
+ case SpvOpSNegate: op = nir_op_ineg; break;
+ case SpvOpFNegate: op = nir_op_fneg; break;
+ case SpvOpNot: op = nir_op_inot; break;
+
+ case SpvOpAny:
+ switch (src[0]->num_components) {
+ case 1: op = nir_op_imov; break;
+ case 2: op = nir_op_bany2; break;
+ case 3: op = nir_op_bany3; break;
+ case 4: op = nir_op_bany4; break;
+ }
+ break;
+
+ case SpvOpAll:
+ switch (src[0]->num_components) {
+ case 1: op = nir_op_imov; break;
+ case 2: op = nir_op_ball2; break;
+ case 3: op = nir_op_ball3; break;
+ case 4: op = nir_op_ball4; break;
+ }
+ break;
+
+ case SpvOpIAdd: op = nir_op_iadd; break;
+ case SpvOpFAdd: op = nir_op_fadd; break;
+ case SpvOpISub: op = nir_op_isub; break;
+ case SpvOpFSub: op = nir_op_fsub; break;
+ case SpvOpIMul: op = nir_op_imul; break;
+ case SpvOpFMul: op = nir_op_fmul; break;
+ case SpvOpUDiv: op = nir_op_udiv; break;
+ case SpvOpSDiv: op = nir_op_idiv; break;
+ case SpvOpFDiv: op = nir_op_fdiv; break;
+ case SpvOpUMod: op = nir_op_umod; break;
+ case SpvOpSMod: op = nir_op_umod; break; /* FIXME? */
+ case SpvOpFMod: op = nir_op_fmod; break;
+
+ case SpvOpDot:
+ assert(src[0]->num_components == src[1]->num_components);
+ switch (src[0]->num_components) {
+ case 1: op = nir_op_fmul; break;
+ case 2: op = nir_op_fdot2; break;
+ case 3: op = nir_op_fdot3; break;
+ case 4: op = nir_op_fdot4; break;
+ }
+ break;
+
+ case SpvOpShiftRightLogical: op = nir_op_ushr; break;
+ case SpvOpShiftRightArithmetic: op = nir_op_ishr; break;
+ case SpvOpShiftLeftLogical: op = nir_op_ishl; break;
+ case SpvOpLogicalOr: op = nir_op_ior; break;
+ case SpvOpLogicalXor: op = nir_op_ixor; break;
+ case SpvOpLogicalAnd: op = nir_op_iand; break;
+ case SpvOpBitwiseOr: op = nir_op_ior; break;
+ case SpvOpBitwiseXor: op = nir_op_ixor; break;
+ case SpvOpBitwiseAnd: op = nir_op_iand; break;
+ case SpvOpSelect: op = nir_op_bcsel; break;
+ case SpvOpIEqual: op = nir_op_ieq; break;
+
+ /* Comparisons: (TODO: How do we want to handled ordered/unordered?) */
+ case SpvOpFOrdEqual: op = nir_op_feq; break;
+ case SpvOpFUnordEqual: op = nir_op_feq; break;
+ case SpvOpINotEqual: op = nir_op_ine; break;
+ case SpvOpFOrdNotEqual: op = nir_op_fne; break;
+ case SpvOpFUnordNotEqual: op = nir_op_fne; break;
+ case SpvOpULessThan: op = nir_op_ult; break;
+ case SpvOpSLessThan: op = nir_op_ilt; break;
+ case SpvOpFOrdLessThan: op = nir_op_flt; break;
+ case SpvOpFUnordLessThan: op = nir_op_flt; break;
+ case SpvOpUGreaterThan: op = nir_op_ult; swap = true; break;
+ case SpvOpSGreaterThan: op = nir_op_ilt; swap = true; break;
+ case SpvOpFOrdGreaterThan: op = nir_op_flt; swap = true; break;
+ case SpvOpFUnordGreaterThan: op = nir_op_flt; swap = true; break;
+ case SpvOpULessThanEqual: op = nir_op_uge; swap = true; break;
+ case SpvOpSLessThanEqual: op = nir_op_ige; swap = true; break;
+ case SpvOpFOrdLessThanEqual: op = nir_op_fge; swap = true; break;
+ case SpvOpFUnordLessThanEqual: op = nir_op_fge; swap = true; break;
+ case SpvOpUGreaterThanEqual: op = nir_op_uge; break;
+ case SpvOpSGreaterThanEqual: op = nir_op_ige; break;
+ case SpvOpFOrdGreaterThanEqual: op = nir_op_fge; break;
+ case SpvOpFUnordGreaterThanEqual:op = nir_op_fge; break;
+
+ /* Conversions: */
+ case SpvOpConvertFToU: op = nir_op_f2u; break;
+ case SpvOpConvertFToS: op = nir_op_f2i; break;
+ case SpvOpConvertSToF: op = nir_op_i2f; break;
+ case SpvOpConvertUToF: op = nir_op_u2f; break;
+ case SpvOpBitcast: op = nir_op_imov; break;
+ case SpvOpUConvert:
+ case SpvOpSConvert:
+ op = nir_op_imov; /* TODO: NIR is 32-bit only; these are no-ops. */
+ break;
+ case SpvOpFConvert:
+ op = nir_op_fmov;
+ break;
+
+ /* Derivatives: */
+ case SpvOpDPdx: op = nir_op_fddx; break;
+ case SpvOpDPdy: op = nir_op_fddy; break;
+ case SpvOpDPdxFine: op = nir_op_fddx_fine; break;
+ case SpvOpDPdyFine: op = nir_op_fddy_fine; break;
+ case SpvOpDPdxCoarse: op = nir_op_fddx_coarse; break;
+ case SpvOpDPdyCoarse: op = nir_op_fddy_coarse; break;
+ case SpvOpFwidth:
+ val->ssa = nir_fadd(&b->nb,
+ nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
+ nir_fabs(&b->nb, nir_fddx(&b->nb, src[1])));
+ return;
+ case SpvOpFwidthFine:
+ val->ssa = nir_fadd(&b->nb,
+ nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
+ nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[1])));
+ return;
+ case SpvOpFwidthCoarse:
+ val->ssa = nir_fadd(&b->nb,
+ nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
+ nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[1])));
+ return;
+
+ case SpvOpVectorTimesScalar:
+ /* The builder will take care of splatting for us. */
+ val->ssa = nir_fmul(&b->nb, src[0], src[1]);
+ return;
+
+ case SpvOpSRem:
+ case SpvOpFRem:
+ unreachable("No NIR equivalent");
+
+ case SpvOpIsNan:
+ case SpvOpIsInf:
+ case SpvOpIsFinite:
+ case SpvOpIsNormal:
+ case SpvOpSignBitSet:
+ case SpvOpLessOrGreater:
+ case SpvOpOrdered:
+ case SpvOpUnordered:
+ default:
+ unreachable("Unhandled opcode");
+ }
+
+ if (swap) {
+ nir_ssa_def *tmp = src[0];
+ src[0] = src[1];
+ src[1] = tmp;
+ }
+
+ nir_alu_instr *instr = nir_alu_instr_create(b->shader, op);
+ nir_ssa_dest_init(&instr->instr, &instr->dest.dest,
+ glsl_get_vector_elements(val->type), val->name);
+ val->ssa = &instr->dest.dest.ssa;
+
+ for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++)
+ instr->src[i].src = nir_src_for_ssa(src[i]);
+
+ nir_instr_insert_after_cf_list(b->cf_list, &instr->instr);
}
static bool
case SpvOpPtrCastToGeneric:
case SpvOpGenericCastToPtr:
case SpvOpBitcast:
- case SpvOpTranspose:
case SpvOpIsNan:
case SpvOpIsInf:
case SpvOpIsFinite:
case SpvOpFRem:
case SpvOpFMod:
case SpvOpVectorTimesScalar:
- case SpvOpMatrixTimesScalar:
- case SpvOpVectorTimesMatrix:
- case SpvOpMatrixTimesVector:
- case SpvOpMatrixTimesMatrix:
- case SpvOpOuterProduct:
case SpvOpDot:
case SpvOpShiftRightLogical:
case SpvOpShiftRightArithmetic:
vtn_handle_alu(b, opcode, w, count);
break;
+ case SpvOpTranspose:
+ case SpvOpOuterProduct:
+ case SpvOpMatrixTimesScalar:
+ case SpvOpVectorTimesMatrix:
+ case SpvOpMatrixTimesVector:
+ case SpvOpMatrixTimesMatrix:
+ vtn_handle_matrix_alu(b, opcode, w, count);
+ break;
+
default:
unreachable("Unhandled opcode");
}
foreach_list_typed(struct vtn_function, func, node, &b->functions) {
b->impl = nir_function_impl_create(func->overload);
+ nir_builder_init(&b->nb, b->impl);
b->cf_list = &b->impl->body;
vtn_walk_blocks(b, func->start_block, NULL);
}