From e5df6ee6054b2f6a47e09b3cb613b48fc6f3307e Mon Sep 17 00:00:00 2001 From: Georg Lehmann Date: Wed, 7 Jun 2023 20:36:56 +0200 Subject: [PATCH] aco: make validation work without SSA temps MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Reviewed-by: Daniel Schürmann Part-of: --- src/amd/compiler/aco_validate.cpp | 137 ++++++++++++++++---------------------- 1 file changed, 59 insertions(+), 78 deletions(-) diff --git a/src/amd/compiler/aco_validate.cpp b/src/amd/compiler/aco_validate.cpp index e32747b..2d0db3c 100644 --- a/src/amd/compiler/aco_validate.cpp +++ b/src/amd/compiler/aco_validate.cpp @@ -338,10 +338,10 @@ validate_ir(Program* program) if (instr->isVOPC() || instr->opcode == aco_opcode::v_readfirstlane_b32 || instr->opcode == aco_opcode::v_readlane_b32 || instr->opcode == aco_opcode::v_readlane_b32_e64) { - check(instr->definitions[0].getTemp().type() == RegType::sgpr, + check(instr->definitions[0].regClass().type() == RegType::sgpr, "Wrong Definition type for VALU instruction", instr.get()); } else { - check(instr->definitions[0].getTemp().type() == RegType::vgpr, + check(instr->definitions[0].regClass().type() == RegType::vgpr, "Wrong Definition type for VALU instruction", instr.get()); } @@ -352,35 +352,30 @@ validate_ir(Program* program) if (instr->opcode == aco_opcode::v_readfirstlane_b32 || instr->opcode == aco_opcode::v_readlane_b32 || instr->opcode == aco_opcode::v_readlane_b32_e64) { - check(i != 1 || (op.isTemp() && op.regClass().type() == RegType::sgpr) || - op.isConstant(), + check(i != 1 || op.isOfType(RegType::sgpr) || op.isConstant(), "Must be a SGPR or a constant", instr.get()); - check(i == 1 || (op.isTemp() && op.regClass().type() == RegType::vgpr && - op.bytes() <= 4), + check(i == 1 || (op.isOfType(RegType::vgpr) && op.bytes() <= 4), "Wrong Operand type for VALU instruction", instr.get()); continue; } if (instr->opcode == aco_opcode::v_permlane16_b32 || instr->opcode == aco_opcode::v_permlanex16_b32) { - check(i != 0 || (op.isTemp() && op.regClass().type() == RegType::vgpr), + check(i != 0 || op.isOfType(RegType::vgpr), "Operand 0 of v_permlane must be VGPR", instr.get()); - check(i == 0 || (op.isTemp() && op.regClass().type() == RegType::sgpr) || - op.isConstant(), + check(i == 0 || op.isOfType(RegType::sgpr) || op.isConstant(), "Lane select operands of v_permlane must be SGPR or constant", instr.get()); } if (instr->opcode == aco_opcode::v_writelane_b32 || instr->opcode == aco_opcode::v_writelane_b32_e64) { - check(i != 2 || (op.isTemp() && op.regClass().type() == RegType::vgpr && - op.bytes() <= 4), + check(i != 2 || (op.isOfType(RegType::vgpr) && op.bytes() <= 4), "Wrong Operand type for VALU instruction", instr.get()); - check(i == 2 || (op.isTemp() && op.regClass().type() == RegType::sgpr) || - op.isConstant(), + check(i == 2 || op.isOfType(RegType::sgpr) || op.isConstant(), "Must be a SGPR or a constant", instr.get()); continue; } - if (op.isTemp() && instr->operands[i].regClass().type() == RegType::sgpr) { + if (op.isOfType(RegType::sgpr)) { check(scalar_mask & (1 << i), "Wrong source position for SGPR argument", instr.get()); @@ -425,10 +420,10 @@ validate_ir(Program* program) if (instr->isSOP1() || instr->isSOP2()) { if (!instr->definitions.empty()) - check(instr->definitions[0].getTemp().type() == RegType::sgpr, + check(instr->definitions[0].regClass().type() == RegType::sgpr, "Wrong Definition type for SALU instruction", instr.get()); for (const Operand& op : instr->operands) { - check(op.isConstant() || op.regClass().type() <= RegType::sgpr, + check(op.isConstant() || op.isOfType(RegType::sgpr), "Wrong Operand type for SALU instruction", instr.get()); } } @@ -444,19 +439,19 @@ validate_ir(Program* program) } check(size == instr->definitions[0].bytes(), "Definition size does not match operand sizes", instr.get()); - if (instr->definitions[0].getTemp().type() == RegType::sgpr) { + if (instr->definitions[0].regClass().type() == RegType::sgpr) { for (const Operand& op : instr->operands) { check(op.isConstant() || op.regClass().type() == RegType::sgpr, "Wrong Operand type for scalar vector", instr.get()); } } } else if (instr->opcode == aco_opcode::p_extract_vector) { - check((instr->operands[0].isTemp()) && instr->operands[1].isConstant(), + check(!instr->operands[0].isConstant() && instr->operands[1].isConstant(), "Wrong Operand types", instr.get()); check((instr->operands[1].constantValue() + 1) * instr->definitions[0].bytes() <= instr->operands[0].bytes(), "Index out of range", instr.get()); - check(instr->definitions[0].getTemp().type() == RegType::vgpr || + check(instr->definitions[0].regClass().type() == RegType::vgpr || instr->operands[0].regClass().type() == RegType::sgpr, "Cannot extract SGPR value from VGPR vector", instr.get()); check(program->gfx_level >= GFX9 || @@ -464,14 +459,14 @@ validate_ir(Program* program) instr->operands[0].regClass().type() == RegType::vgpr, "Cannot extract subdword from SGPR before GFX9+", instr.get()); } else if (instr->opcode == aco_opcode::p_split_vector) { - check(instr->operands[0].isTemp(), "Operand must be a temporary", instr.get()); + check(!instr->operands[0].isConstant(), "Operand must not be constant", instr.get()); unsigned size = 0; for (const Definition& def : instr->definitions) { size += def.bytes(); } check(size == instr->operands[0].bytes(), "Operand size does not match definition sizes", instr.get()); - if (instr->operands[0].getTemp().type() == RegType::vgpr) { + if (instr->operands[0].isOfType(RegType::vgpr)) { for (const Definition& def : instr->definitions) check(def.regClass().type() == RegType::vgpr, "Wrong Definition type for VGPR split_vector", instr.get()); @@ -486,10 +481,10 @@ validate_ir(Program* program) for (unsigned i = 0; i < instr->operands.size(); i++) { check(instr->definitions[i].bytes() == instr->operands[i].bytes(), "Operand and Definition size must match", instr.get()); - if (instr->operands[i].isTemp()) { - check((instr->definitions[i].getTemp().type() == + if (instr->operands[i].hasRegClass()) { + check((instr->definitions[i].regClass().type() == instr->operands[i].regClass().type()) || - (instr->definitions[i].getTemp().type() == RegType::vgpr && + (instr->definitions[i].regClass().type() == RegType::vgpr && instr->operands[i].regClass().type() == RegType::sgpr), "Operand and Definition types do not match", instr.get()); check(instr->definitions[i].regClass().is_linear_vgpr() == @@ -504,7 +499,7 @@ validate_ir(Program* program) } else if (instr->opcode == aco_opcode::p_phi) { check(instr->operands.size() == block.logical_preds.size(), "Number of Operands does not match number of predecessors", instr.get()); - check(instr->definitions[0].getTemp().type() == RegType::vgpr, + check(instr->definitions[0].regClass().type() == RegType::vgpr, "Logical Phi Definition must be vgpr", instr.get()); for (const Operand& op : instr->operands) check(instr->definitions[0].size() == op.size(), @@ -520,26 +515,27 @@ validate_ir(Program* program) "Number of Operands does not match number of predecessors", instr.get()); } else if (instr->opcode == aco_opcode::p_extract || instr->opcode == aco_opcode::p_insert) { - check(instr->operands[0].isTemp(), "Data operand must be temporary", instr.get()); + check(!instr->operands[0].isConstant(), "Data operand must not be constant", + instr.get()); check(instr->operands[1].isConstant(), "Index must be constant", instr.get()); if (instr->opcode == aco_opcode::p_extract) check(instr->operands[3].isConstant(), "Sign-extend flag must be constant", instr.get()); - check(instr->definitions[0].getTemp().type() != RegType::sgpr || - instr->operands[0].getTemp().type() == RegType::sgpr, + check(instr->definitions[0].regClass().type() != RegType::sgpr || + instr->operands[0].regClass().type() == RegType::sgpr, "Can't extract/insert VGPR to SGPR", instr.get()); if (instr->opcode == aco_opcode::p_insert) check(instr->operands[0].bytes() == instr->definitions[0].bytes(), "Sizes of p_insert data operand and definition must match", instr.get()); - if (instr->definitions[0].getTemp().type() == RegType::sgpr) + if (instr->definitions[0].regClass().type() == RegType::sgpr) check(instr->definitions.size() >= 2 && instr->definitions[1].isFixed() && instr->definitions[1].physReg() == scc, "SGPR extract/insert needs an SCC definition", instr.get()); - unsigned data_bits = instr->operands[0].getTemp().bytes() * 8u; + unsigned data_bits = instr->operands[0].bytes() * 8u; unsigned op_bits = instr->operands[2].constantValue(); if (instr->opcode == aco_opcode::p_insert) { @@ -558,20 +554,19 @@ validate_ir(Program* program) } else if (instr->opcode == aco_opcode::p_jump_to_epilog) { check(instr->definitions.size() == 0, "p_jump_to_epilog must have 0 definitions", instr.get()); - check(instr->operands.size() > 0 && - instr->operands[0].getTemp().type() == RegType::sgpr && - instr->operands[0].getTemp().size() == 2, + check(instr->operands.size() > 0 && instr->operands[0].isOfType(RegType::sgpr) && + instr->operands[0].size() == 2, "First operand of p_jump_to_epilog must be a SGPR", instr.get()); for (unsigned i = 1; i < instr->operands.size(); i++) { - check(instr->operands[i].getTemp().type() == RegType::vgpr || - instr->operands[i].isUndefined(), - "Other operands of p_jump_to_epilog must be VGPRs or undef", instr.get()); + check( + instr->operands[i].isOfType(RegType::vgpr) || instr->operands[i].isUndefined(), + "Other operands of p_jump_to_epilog must be VGPRs or undef", instr.get()); } } else if (instr->opcode == aco_opcode::p_dual_src_export_gfx11) { check(instr->definitions.size() == 6, "p_dual_src_export_gfx11 must have 6 definitions", instr.get()); - check(instr->definitions[2].getTemp().type() == RegType::vgpr && - instr->definitions[2].getTemp().size() == 1, + check(instr->definitions[2].regClass().type() == RegType::vgpr && + instr->definitions[2].regClass().size() == 1, "Third definition of p_dual_src_export_gfx11 must be a v1", instr.get()); check(instr->definitions[3].regClass() == program->lane_mask, "Fourth definition of p_dual_src_export_gfx11 must be a lane mask", instr.get()); @@ -582,9 +577,9 @@ validate_ir(Program* program) check(instr->operands.size() == 8, "p_dual_src_export_gfx11 must have 8 operands", instr.get()); for (unsigned i = 0; i < instr->operands.size(); i++) { - check(instr->operands[i].getTemp().type() == RegType::vgpr || - instr->operands[i].isUndefined(), - "Operands of p_dual_src_export_gfx11 must be VGPRs or undef", instr.get()); + check( + instr->operands[i].isOfType(RegType::vgpr) || instr->operands[i].isUndefined(), + "Operands of p_dual_src_export_gfx11 must be VGPRs or undef", instr.get()); } } else if (instr->opcode == aco_opcode::p_start_linear_vgpr) { check(instr->definitions.size() == 1, "Must have one definition", instr.get()); @@ -618,17 +613,13 @@ validate_ir(Program* program) } case Format::SMEM: { if (instr->operands.size() >= 1) - check((instr->operands[0].isFixed() && !instr->operands[0].isConstant()) || - (instr->operands[0].isTemp() && - instr->operands[0].regClass().type() == RegType::sgpr), - "SMEM operands must be sgpr", instr.get()); + check(instr->operands[0].isOfType(RegType::sgpr), "SMEM operands must be sgpr", + instr.get()); if (instr->operands.size() >= 2) - check(instr->operands[1].isConstant() || - (instr->operands[1].isTemp() && - instr->operands[1].regClass().type() == RegType::sgpr), + check(instr->operands[1].isConstant() || instr->operands[1].isOfType(RegType::sgpr), "SMEM offset must be constant or sgpr", instr.get()); if (!instr->definitions.empty()) - check(instr->definitions[0].getTemp().type() == RegType::sgpr, + check(instr->definitions[0].regClass().type() == RegType::sgpr, "SMEM result must be sgpr", instr.get()); break; } @@ -636,15 +627,11 @@ validate_ir(Program* program) case Format::MUBUF: { check(instr->operands.size() > 1, "VMEM instructions must have at least one operand", instr.get()); - check(instr->operands[1].hasRegClass() && - instr->operands[1].regClass().type() == RegType::vgpr, + check(instr->operands[1].isOfType(RegType::vgpr), "VADDR must be in vgpr for VMEM instructions", instr.get()); - check( - instr->operands[0].isTemp() && instr->operands[0].regClass().type() == RegType::sgpr, - "VMEM resource constant must be sgpr", instr.get()); - check(instr->operands.size() < 4 || - (instr->operands[3].isTemp() && - instr->operands[3].regClass().type() == RegType::vgpr), + check(instr->operands[0].isOfType(RegType::sgpr), "VMEM resource constant must be sgpr", + instr.get()); + check(instr->operands.size() < 4 || instr->operands[3].isOfType(RegType::vgpr), "VMEM write data must be vgpr", instr.get()); const bool d16 = instr->opcode == aco_opcode::buffer_load_dword || // FIXME: used to spill subdword variables @@ -668,8 +655,7 @@ validate_ir(Program* program) instr->opcode == aco_opcode::tbuffer_load_format_d16_xyz || instr->opcode == aco_opcode::tbuffer_load_format_d16_xyzw; if (instr->definitions.size()) { - check(instr->definitions[0].isTemp() && - instr->definitions[0].regClass().type() == RegType::vgpr, + check(instr->definitions[0].regClass().type() == RegType::vgpr, "VMEM definitions[0] (VDATA) must be VGPR", instr.get()); check(d16 || !instr->definitions[0].regClass().is_subdword(), "Only D16 opcodes can load subdword values.", instr.get()); @@ -699,12 +685,13 @@ validate_ir(Program* program) } if (instr->mimg().strict_wqm) { - check(instr->operands[3].isTemp() && instr->operands[3].regClass().is_linear_vgpr(), + check(instr->operands[3].hasRegClass() && + instr->operands[3].regClass().is_linear_vgpr(), "MIMG operands[3] must be temp linear VGPR.", instr.get()); unsigned total_size = 0; for (unsigned i = 4; i < instr->operands.size(); i++) { - check(instr->operands[i].isTemp() && instr->operands[i].regClass() == v1, + check(instr->operands[i].hasRegClass() && instr->operands[i].regClass() == v1, "MIMG operands[4+] (VADDR) must be v1", instr.get()); total_size += instr->operands[i].bytes(); } @@ -733,8 +720,7 @@ validate_ir(Program* program) } if (instr->definitions.size()) { - check(instr->definitions[0].isTemp() && - instr->definitions[0].regClass().type() == RegType::vgpr, + check(instr->definitions[0].regClass().type() == RegType::vgpr, "MIMG definitions[0] (VDATA) must be VGPR", instr.get()); check(instr->mimg().d16 || !instr->definitions[0].regClass().is_subdword(), "Only D16 MIMG instructions can load subdword values.", instr.get()); @@ -745,19 +731,17 @@ validate_ir(Program* program) } case Format::DS: { for (const Operand& op : instr->operands) { - check((op.isTemp() && op.regClass().type() == RegType::vgpr) || op.physReg() == m0 || - op.isUndefined(), + check(op.isOfType(RegType::vgpr) || op.physReg() == m0 || op.isUndefined(), "Only VGPRs are valid DS instruction operands", instr.get()); } if (!instr->definitions.empty()) - check(instr->definitions[0].getTemp().type() == RegType::vgpr, + check(instr->definitions[0].regClass().type() == RegType::vgpr, "DS instruction must return VGPR", instr.get()); break; } case Format::EXP: { for (unsigned i = 0; i < 4; i++) - check(instr->operands[i].hasRegClass() && - instr->operands[i].regClass().type() == RegType::vgpr, + check(instr->operands[i].isOfType(RegType::vgpr), "Only VGPRs are valid Export arguments", instr.get()); break; } @@ -766,25 +750,22 @@ validate_ir(Program* program) instr.get()); FALLTHROUGH; case Format::GLOBAL: - check( - instr->operands[0].isTemp() && instr->operands[0].regClass().type() == RegType::vgpr, - "FLAT/GLOBAL address must be vgpr", instr.get()); + check(instr->operands[0].isOfType(RegType::vgpr), "FLAT/GLOBAL address must be vgpr", + instr.get()); FALLTHROUGH; case Format::SCRATCH: { - check(instr->operands[0].hasRegClass() && - instr->operands[0].regClass().type() == RegType::vgpr, + check(instr->operands[0].isOfType(RegType::vgpr), "FLAT/GLOBAL/SCRATCH address must be undefined or vgpr", instr.get()); - check(instr->operands[1].hasRegClass() && - instr->operands[1].regClass().type() == RegType::sgpr, + check(instr->operands[1].isOfType(RegType::sgpr), "FLAT/GLOBAL/SCRATCH sgpr address must be undefined or sgpr", instr.get()); if (instr->format == Format::SCRATCH && program->gfx_level < GFX10_3) - check(instr->operands[0].isTemp() || instr->operands[1].isTemp(), + check(!instr->operands[0].isUndefined() || !instr->operands[1].isUndefined(), "SCRATCH must have either SADDR or ADDR operand", instr.get()); if (!instr->definitions.empty()) - check(instr->definitions[0].getTemp().type() == RegType::vgpr, + check(instr->definitions[0].regClass().type() == RegType::vgpr, "FLAT/GLOBAL/SCRATCH result must be vgpr", instr.get()); else - check(instr->operands[2].regClass().type() == RegType::vgpr, + check(instr->operands[2].isOfType(RegType::vgpr), "FLAT/GLOBAL/SCRATCH data must be vgpr", instr.get()); break; } -- 2.7.4