ac/llvm: implement udot_4x8/sdot_4x8/udot_2x16/sdot_2x16 opcodes
authorRhys Perry <pendingchaos02@gmail.com>
Tue, 31 Aug 2021 10:53:33 +0000 (11:53 +0100)
committerMarge Bot <eric+marge@anholt.net>
Fri, 3 Sep 2021 13:21:27 +0000 (13:21 +0000)
Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Samuel Pitoiset <samuel.pitoiset@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/12617>

src/amd/llvm/ac_nir_to_llvm.c

index 804a1dc..ef24e29 100644 (file)
@@ -1247,6 +1247,34 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
       break;
    }
 
+   case nir_op_sdot_4x8_iadd:
+   case nir_op_udot_4x8_uadd:
+   case nir_op_sdot_4x8_iadd_sat:
+   case nir_op_udot_4x8_uadd_sat: {
+      const char *name = instr->op == nir_op_sdot_4x8_iadd ||
+                         instr->op == nir_op_sdot_4x8_iadd_sat
+                         ? "llvm.amdgcn.sdot4" : "llvm.amdgcn.udot4";
+      src[3] = LLVMConstInt(ctx->ac.i1, instr->op == nir_op_sdot_4x8_iadd_sat ||
+                                        instr->op == nir_op_udot_4x8_uadd_sat, false);
+      result = ac_build_intrinsic(&ctx->ac, name, def_type, src, 4, AC_FUNC_ATTR_READNONE);
+      break;
+   }
+
+   case nir_op_sdot_2x16_iadd:
+   case nir_op_udot_2x16_uadd:
+   case nir_op_sdot_2x16_iadd_sat:
+   case nir_op_udot_2x16_uadd_sat: {
+      const char *name = instr->op == nir_op_sdot_2x16_iadd ||
+                         instr->op == nir_op_sdot_2x16_iadd_sat
+                         ? "llvm.amdgcn.sdot2" : "llvm.amdgcn.udot2";
+      src[0] = LLVMBuildBitCast(ctx->ac.builder, src[0], ctx->ac.v2i16, "");
+      src[1] = LLVMBuildBitCast(ctx->ac.builder, src[1], ctx->ac.v2i16, "");
+      src[3] = LLVMConstInt(ctx->ac.i1, instr->op == nir_op_sdot_2x16_iadd_sat ||
+                                        instr->op == nir_op_udot_2x16_uadd_sat, false);
+      result = ac_build_intrinsic(&ctx->ac, name, def_type, src, 4, AC_FUNC_ATTR_READNONE);
+      break;
+   }
+
    default:
       fprintf(stderr, "Unknown NIR alu instr: ");
       nir_print_instr(&instr->instr, stderr);