From 23ac24f5b1fdde73cf8ec1ef6cbe08d73d6776f5 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Daniel=20Sch=C3=BCrmann?= Date: Fri, 28 Feb 2020 20:17:44 +0100 Subject: [PATCH] aco: add missing conversion operations for small bitsizes MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Reviewed-by: Rhys Perry Reviewed-By: Timur Kristóf Part-of: --- src/amd/compiler/aco_instruction_selection.cpp | 195 +++++++++++++++++++-- .../compiler/aco_instruction_selection_setup.cpp | 9 +- 2 files changed, 190 insertions(+), 14 deletions(-) diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp index 8b8f3fa..4589b40 100644 --- a/src/amd/compiler/aco_instruction_selection.cpp +++ b/src/amd/compiler/aco_instruction_selection.cpp @@ -1900,8 +1900,27 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } break; } + case nir_op_f2f16: + case nir_op_f2f16_rtne: { + Temp src = get_alu_src(ctx, instr->src[0]); + if (instr->src[0].src.ssa->bit_size == 64) + src = bld.vop1(aco_opcode::v_cvt_f32_f64, bld.def(v1), src); + src = bld.vop1(aco_opcode::v_cvt_f16_f32, bld.def(v1), src); + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), src); + break; + } + case nir_op_f2f16_rtz: { + Temp src = get_alu_src(ctx, instr->src[0]); + if (instr->src[0].src.ssa->bit_size == 64) + src = bld.vop1(aco_opcode::v_cvt_f32_f64, bld.def(v1), src); + src = bld.vop3(aco_opcode::v_cvt_pkrtz_f16_f32, bld.def(v1), src, Operand(0u)); + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), src); + break; + } case nir_op_f2f32: { - if (instr->src[0].src.ssa->bit_size == 64) { + if (instr->src[0].src.ssa->bit_size == 16) { + emit_vop1_instruction(ctx, instr, aco_opcode::v_cvt_f32_f16, dst); + } else if (instr->src[0].src.ssa->bit_size == 64) { emit_vop1_instruction(ctx, instr, aco_opcode::v_cvt_f32_f64, dst); } else { fprintf(stderr, "Unimplemented NIR instr bit size: "); @@ -1911,13 +1930,10 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_f2f64: { - if (instr->src[0].src.ssa->bit_size == 32) { - emit_vop1_instruction(ctx, instr, aco_opcode::v_cvt_f64_f32, dst); - } else { - fprintf(stderr, "Unimplemented NIR instr bit size: "); - nir_print_instr(&instr->instr, stderr); - fprintf(stderr, "\n"); - } + Temp src = get_alu_src(ctx, instr->src[0]); + if (instr->src[0].src.ssa->bit_size == 16) + src = bld.vop1(aco_opcode::v_cvt_f32_f16, bld.def(v1), src); + bld.vop1(aco_opcode::v_cvt_f64_f32, Definition(dst), src); break; } case nir_op_i2f32: { @@ -1969,6 +1985,36 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } break; } + case nir_op_f2i16: { + Temp src = get_alu_src(ctx, instr->src[0]); + if (instr->src[0].src.ssa->bit_size == 16) + src = bld.vop1(aco_opcode::v_cvt_i16_f16, bld.def(v1), src); + else if (instr->src[0].src.ssa->bit_size == 32) + src = bld.vop1(aco_opcode::v_cvt_i32_f32, bld.def(v1), src); + else + src = bld.vop1(aco_opcode::v_cvt_i32_f64, bld.def(v1), src); + + if (dst.type() == RegType::vgpr) + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), src); + else + bld.pseudo(aco_opcode::p_as_uniform, Definition(dst), src); + break; + } + case nir_op_f2u16: { + Temp src = get_alu_src(ctx, instr->src[0]); + if (instr->src[0].src.ssa->bit_size == 16) + src = bld.vop1(aco_opcode::v_cvt_u16_f16, bld.def(v1), src); + else if (instr->src[0].src.ssa->bit_size == 32) + src = bld.vop1(aco_opcode::v_cvt_u32_f32, bld.def(v1), src); + else + src = bld.vop1(aco_opcode::v_cvt_u32_f64, bld.def(v1), src); + + if (dst.type() == RegType::vgpr) + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), src); + else + bld.pseudo(aco_opcode::p_as_uniform, Definition(dst), src); + break; + } case nir_op_f2i32: { Temp src = get_alu_src(ctx, instr->src[0]); if (instr->src[0].src.ssa->bit_size == 32) { @@ -2190,9 +2236,91 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } break; } + case nir_op_i2i8: + case nir_op_u2u8: { + Temp src = get_alu_src(ctx, instr->src[0]); + /* we can actually just say dst = src */ + if (src.regClass() == s1) + bld.copy(Definition(dst), src); + else + emit_extract_vector(ctx, src, 0, dst); + break; + } + case nir_op_i2i16: { + Temp src = get_alu_src(ctx, instr->src[0]); + if (instr->src[0].src.ssa->bit_size == 8) { + if (dst.regClass() == s1) { + bld.sop1(aco_opcode::s_sext_i32_i8, Definition(dst), Operand(src)); + } else { + assert(src.regClass() == v1b); + aco_ptr sdwa{create_instruction(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)}; + sdwa->operands[0] = Operand(src); + sdwa->definitions[0] = Definition(dst); + sdwa->sel[0] = sdwa_sbyte; + sdwa->dst_sel = sdwa_sword; + ctx->block->instructions.emplace_back(std::move(sdwa)); + } + } else { + Temp src = get_alu_src(ctx, instr->src[0]); + /* we can actually just say dst = src */ + if (src.regClass() == s1) + bld.copy(Definition(dst), src); + else + emit_extract_vector(ctx, src, 0, dst); + } + break; + } + case nir_op_u2u16: { + Temp src = get_alu_src(ctx, instr->src[0]); + if (instr->src[0].src.ssa->bit_size == 8) { + if (dst.regClass() == s1) + bld.sop2(aco_opcode::s_and_b32, Definition(dst), bld.def(s1, scc), Operand(0xFFu), src); + else { + assert(src.regClass() == v1b); + aco_ptr sdwa{create_instruction(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)}; + sdwa->operands[0] = Operand(src); + sdwa->definitions[0] = Definition(dst); + sdwa->sel[0] = sdwa_ubyte; + sdwa->dst_sel = sdwa_uword; + ctx->block->instructions.emplace_back(std::move(sdwa)); + } + } else { + Temp src = get_alu_src(ctx, instr->src[0]); + /* we can actually just say dst = src */ + if (src.regClass() == s1) + bld.copy(Definition(dst), src); + else + emit_extract_vector(ctx, src, 0, dst); + } + break; + } case nir_op_i2i32: { Temp src = get_alu_src(ctx, instr->src[0]); - if (instr->src[0].src.ssa->bit_size == 64) { + if (instr->src[0].src.ssa->bit_size == 8) { + if (dst.regClass() == s1) { + bld.sop1(aco_opcode::s_sext_i32_i8, Definition(dst), Operand(src)); + } else { + assert(src.regClass() == v1b); + aco_ptr sdwa{create_instruction(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)}; + sdwa->operands[0] = Operand(src); + sdwa->definitions[0] = Definition(dst); + sdwa->sel[0] = sdwa_sbyte; + sdwa->dst_sel = sdwa_sdword; + ctx->block->instructions.emplace_back(std::move(sdwa)); + } + } else if (instr->src[0].src.ssa->bit_size == 16) { + if (dst.regClass() == s1) { + bld.sop1(aco_opcode::s_sext_i32_i16, Definition(dst), Operand(src)); + } else { + assert(src.regClass() == v2b); + aco_ptr sdwa{create_instruction(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)}; + sdwa->operands[0] = Operand(src); + sdwa->definitions[0] = Definition(dst); + sdwa->sel[0] = sdwa_sword; + sdwa->dst_sel = sdwa_udword; + ctx->block->instructions.emplace_back(std::move(sdwa)); + } + } else if (instr->src[0].src.ssa->bit_size == 64) { /* we can actually just say dst = src, as it would map the lower register */ emit_extract_vector(ctx, src, 0, dst); } else { @@ -2204,12 +2332,29 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } case nir_op_u2u32: { Temp src = get_alu_src(ctx, instr->src[0]); - if (instr->src[0].src.ssa->bit_size == 16) { + if (instr->src[0].src.ssa->bit_size == 8) { + if (dst.regClass() == s1) + bld.sop2(aco_opcode::s_and_b32, Definition(dst), bld.def(s1, scc), Operand(0xFFu), src); + else { + assert(src.regClass() == v1b); + aco_ptr sdwa{create_instruction(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)}; + sdwa->operands[0] = Operand(src); + sdwa->definitions[0] = Definition(dst); + sdwa->sel[0] = sdwa_ubyte; + sdwa->dst_sel = sdwa_udword; + ctx->block->instructions.emplace_back(std::move(sdwa)); + } + } else if (instr->src[0].src.ssa->bit_size == 16) { if (dst.regClass() == s1) { bld.sop2(aco_opcode::s_and_b32, Definition(dst), bld.def(s1, scc), Operand(0xFFFFu), src); } else { - // TODO: do better with SDWA - bld.vop2(aco_opcode::v_and_b32, Definition(dst), Operand(0xFFFFu), src); + assert(src.regClass() == v2b); + aco_ptr sdwa{create_instruction(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)}; + sdwa->operands[0] = Operand(src); + sdwa->definitions[0] = Definition(dst); + sdwa->sel[0] = sdwa_uword; + sdwa->dst_sel = sdwa_udword; + ctx->block->instructions.emplace_back(std::move(sdwa)); } } else if (instr->src[0].src.ssa->bit_size == 64) { /* we can actually just say dst = src, as it would map the lower register */ @@ -2298,6 +2443,32 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) case nir_op_unpack_64_2x32_split_y: bld.pseudo(aco_opcode::p_split_vector, bld.def(dst.regClass()), Definition(dst), get_alu_src(ctx, instr->src[0])); break; + case nir_op_unpack_32_2x16_split_x: + if (dst.type() == RegType::vgpr) { + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(dst.regClass()), get_alu_src(ctx, instr->src[0])); + } else { + bld.copy(Definition(dst), get_alu_src(ctx, instr->src[0])); + } + break; + case nir_op_unpack_32_2x16_split_y: + if (dst.type() == RegType::vgpr) { + bld.pseudo(aco_opcode::p_split_vector, bld.def(dst.regClass()), Definition(dst), get_alu_src(ctx, instr->src[0])); + } else { + bld.sop2(aco_opcode::s_bfe_u32, Definition(dst), get_alu_src(ctx, instr->src[0]), Operand(uint32_t(16 << 16 | 16))); + } + break; + case nir_op_pack_32_2x16_split: { + Temp src0 = get_alu_src(ctx, instr->src[0]); + Temp src1 = get_alu_src(ctx, instr->src[1]); + if (dst.regClass() == v1) { + bld.pseudo(aco_opcode::p_create_vector, Definition(dst), src0, src1); + } else { + src0 = bld.sop2(aco_opcode::s_and_b32, bld.def(s1), bld.def(s1, scc), src0, Operand(0xFFFFu)); + src1 = bld.sop2(aco_opcode::s_lshl_b32, bld.def(s1), bld.def(s1, scc), src1, Operand(16u)); + bld.sop2(aco_opcode::s_or_b32, Definition(dst), bld.def(s1, scc), src0, src1); + } + break; + } case nir_op_pack_half_2x16: { Temp src = get_alu_src(ctx, instr->src[0], 2); diff --git a/src/amd/compiler/aco_instruction_selection_setup.cpp b/src/amd/compiler/aco_instruction_selection_setup.cpp index d7d294e..ac7df92 100644 --- a/src/amd/compiler/aco_instruction_selection_setup.cpp +++ b/src/amd/compiler/aco_instruction_selection_setup.cpp @@ -305,6 +305,9 @@ void init_context(isel_context *ctx, nir_shader *shader) case nir_op_fround_even: case nir_op_fsin: case nir_op_fcos: + case nir_op_f2f16: + case nir_op_f2f16_rtz: + case nir_op_f2f16_rtne: case nir_op_f2f32: case nir_op_f2f64: case nir_op_u2f32: @@ -328,13 +331,15 @@ void init_context(isel_context *ctx, nir_shader *shader) case nir_op_cube_face_coord: type = RegType::vgpr; break; + case nir_op_f2i16: + case nir_op_f2u16: + case nir_op_f2i32: + case nir_op_f2u32: case nir_op_f2i64: case nir_op_f2u64: case nir_op_b2i32: case nir_op_b2b32: case nir_op_b2f32: - case nir_op_f2i32: - case nir_op_f2u32: case nir_op_mov: type = ctx->divergent_vals[alu_instr->dest.dest.ssa.index] ? RegType::vgpr : RegType::sgpr; break; -- 2.7.4