radv,aco: implement nir_op_ffma
authorRhys Perry <pendingchaos02@gmail.com>
Wed, 24 Mar 2021 17:17:38 +0000 (17:17 +0000)
committerMarge Bot <emma+marge@anholt.net>
Mon, 13 Dec 2021 11:22:33 +0000 (11:22 +0000)
Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/9805>

src/amd/compiler/aco_instruction_selection.cpp
src/amd/compiler/aco_instruction_selection_setup.cpp
src/amd/compiler/aco_optimizer.cpp
src/amd/vulkan/radv_pipeline.c

index 7b786c8..7366cbb 100644 (file)
@@ -2092,6 +2092,35 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
       }
       break;
    }
+   case nir_op_ffma: {
+      if (dst.regClass() == v2b) {
+         emit_vop3a_instruction(ctx, instr, aco_opcode::v_fma_f16, dst, false, 3);
+      } else if (dst.regClass() == v1 && instr->dest.dest.ssa.bit_size == 16) {
+         assert(instr->dest.dest.ssa.num_components == 2);
+
+         Temp src0 = as_vgpr(ctx, get_alu_src_vop3p(ctx, instr->src[0]));
+         Temp src1 = as_vgpr(ctx, get_alu_src_vop3p(ctx, instr->src[1]));
+         Temp src2 = as_vgpr(ctx, get_alu_src_vop3p(ctx, instr->src[2]));
+
+         /* swizzle to opsel: all swizzles are either 0 (x) or 1 (y) */
+         unsigned opsel_lo = 0, opsel_hi = 0;
+         for (unsigned i = 0; i < 3; i++) {
+            opsel_lo |= (instr->src[i].swizzle[0] & 1) << i;
+            opsel_hi |= (instr->src[i].swizzle[1] & 1) << i;
+         }
+
+         bld.vop3p(aco_opcode::v_pk_fma_f16, Definition(dst), src0, src1, src2, opsel_lo, opsel_hi);
+         emit_split_vector(ctx, dst, 2);
+      } else if (dst.regClass() == v1) {
+         emit_vop3a_instruction(ctx, instr, aco_opcode::v_fma_f32, dst,
+                                ctx->block->fp_mode.must_flush_denorms32, 3);
+      } else if (dst.regClass() == v2) {
+         emit_vop3a_instruction(ctx, instr, aco_opcode::v_fma_f64, dst, false, 3);
+      } else {
+         isel_err(&instr->instr, "Unimplemented NIR instr bit size");
+      }
+      break;
+   }
    case nir_op_fmax: {
       if (dst.regClass() == v2b) {
          // TODO: check fp_mode.must_flush_denorms16_64
index 8ad3a51..ed72d30 100644 (file)
@@ -467,6 +467,7 @@ init_context(isel_context* ctx, nir_shader* shader)
                case nir_op_fmul:
                case nir_op_fadd:
                case nir_op_fsub:
+               case nir_op_ffma:
                case nir_op_fmax:
                case nir_op_fmin:
                case nir_op_fneg:
index b2569a0..ed90c63 100644 (file)
@@ -3112,9 +3112,7 @@ combine_vop3p(opt_ctx& ctx, aco_ptr<Instruction>& instr)
 
    /* check for fneg modifiers */
    if (instr_info.can_use_input_modifiers[(int)instr->opcode]) {
-      /* at this point, we only have 2-operand instructions */
-      assert(instr->operands.size() == 2);
-      for (unsigned i = 0; i < 2; i++) {
+      for (unsigned i = 0; i < instr->operands.size(); i++) {
          Operand& op = instr->operands[i];
          if (!op.isTemp())
             continue;
index 9c6e5a1..fd8cf55 100644 (file)
@@ -3396,6 +3396,7 @@ opt_vectorize_callback(const nir_instr *instr, void *_)
    case nir_op_fadd:
    case nir_op_fsub:
    case nir_op_fmul:
+   case nir_op_ffma:
    case nir_op_fneg:
    case nir_op_fsat:
    case nir_op_fmin: