From 137974fabbffedbea78cd6b7fc0d30b8b604659a Mon Sep 17 00:00:00 2001 From: Rhys Perry Date: Mon, 30 Aug 2021 13:56:17 +0100 Subject: [PATCH] spirv: use sdot_2x16 and udot_2x16 opcodes Signed-off-by: Rhys Perry Reviewed-by: Ian Romanick Part-of: --- src/compiler/spirv/vtn_alu.c | 108 +++++++++++++++++++++++++------------------ 1 file changed, 64 insertions(+), 44 deletions(-) diff --git a/src/compiler/spirv/vtn_alu.c b/src/compiler/spirv/vtn_alu.c index ed73118..443878c 100644 --- a/src/compiler/spirv/vtn_alu.c +++ b/src/compiler/spirv/vtn_alu.c @@ -892,6 +892,7 @@ vtn_handle_integer_dot(struct vtn_builder *b, SpvOp opcode, spirv_op_to_string(opcode)); } + unsigned packed_bit_size = 8; if (glsl_type_is_vector(vtn_src[0]->type)) { /* FINISHME: Is this actually as good or better for platforms that don't * have the special instructions (i.e., one or both of has_dot_4x8 or @@ -902,6 +903,14 @@ vtn_handle_integer_dot(struct vtn_builder *b, SpvOp opcode, glsl_get_bit_size(dest_type) <= 32) { src[0] = nir_pack_32_4x8(&b->nb, src[0]); src[1] = nir_pack_32_4x8(&b->nb, src[1]); + } else if (glsl_get_vector_elements(vtn_src[0]->type) == 2 && + glsl_get_bit_size(vtn_src[0]->type) == 16 && + glsl_get_bit_size(dest_type) <= 32 && + opcode != SpvOpSUDotKHR && + opcode != SpvOpSUDotAccSatKHR) { + src[0] = nir_pack_32_2x16(&b->nb, src[0]); + src[1] = nir_pack_32_2x16(&b->nb, src[1]); + packed_bit_size = 16; } } else if (glsl_type_is_scalar(vtn_src[0]->type) && glsl_type_is_32bit(vtn_src[0]->type)) { @@ -1012,53 +1021,64 @@ vtn_handle_integer_dot(struct vtn_builder *b, SpvOp opcode, assert(src[0]->bit_size == 32 && src[1]->bit_size == 32); nir_ssa_def *const zero = nir_imm_zero(&b->nb, 1, 32); - bool is_signed; - - switch (opcode) { - case SpvOpSDotKHR: - dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero); - is_signed = true; - break; - - case SpvOpUDotKHR: - dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero); - is_signed = false; - break; - - case SpvOpSUDotKHR: - dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero); - is_signed = true; - break; - - case SpvOpSDotAccSatKHR: - if (dest_size == 32) - dest = nir_sdot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]); - else + bool is_signed = opcode == SpvOpSDotKHR || opcode == SpvOpSUDotKHR || + opcode == SpvOpSDotAccSatKHR || opcode == SpvOpSUDotAccSatKHR; + + if (packed_bit_size == 16) { + switch (opcode) { + case SpvOpSDotKHR: + dest = nir_sdot_2x16_iadd(&b->nb, src[0], src[1], zero); + break; + case SpvOpUDotKHR: + dest = nir_udot_2x16_uadd(&b->nb, src[0], src[1], zero); + break; + case SpvOpSDotAccSatKHR: + if (dest_size == 32) + dest = nir_sdot_2x16_iadd_sat(&b->nb, src[0], src[1], src[2]); + else + dest = nir_sdot_2x16_iadd(&b->nb, src[0], src[1], zero); + break; + case SpvOpUDotAccSatKHR: + if (dest_size == 32) + dest = nir_udot_2x16_uadd_sat(&b->nb, src[0], src[1], src[2]); + else + dest = nir_udot_2x16_uadd(&b->nb, src[0], src[1], zero); + break; + default: + unreachable("Invalid opcode."); + } + } else { + switch (opcode) { + case SpvOpSDotKHR: dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero); - - is_signed = true; - break; - - case SpvOpUDotAccSatKHR: - if (dest_size == 32) - dest = nir_udot_4x8_uadd_sat(&b->nb, src[0], src[1], src[2]); - else + break; + case SpvOpUDotKHR: dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero); - - is_signed = false; - break; - - case SpvOpSUDotAccSatKHR: - if (dest_size == 32) - dest = nir_sudot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]); - else + break; + case SpvOpSUDotKHR: dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero); - - is_signed = true; - break; - - default: - unreachable("Invalid opcode."); + break; + case SpvOpSDotAccSatKHR: + if (dest_size == 32) + dest = nir_sdot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]); + else + dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero); + break; + case SpvOpUDotAccSatKHR: + if (dest_size == 32) + dest = nir_udot_4x8_uadd_sat(&b->nb, src[0], src[1], src[2]); + else + dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero); + break; + case SpvOpSUDotAccSatKHR: + if (dest_size == 32) + dest = nir_sudot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]); + else + dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero); + break; + default: + unreachable("Invalid opcode."); + } } if (dest_size != 32) { -- 2.7.4