aco: make validation work without SSA temps
authorGeorg Lehmann <dadschoorse@gmail.com>
Wed, 7 Jun 2023 18:36:56 +0000 (20:36 +0200)
committerMarge Bot <emma+marge@anholt.net>
Mon, 12 Jun 2023 19:43:17 +0000 (19:43 +0000)
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/23507>

src/amd/compiler/aco_validate.cpp

index e32747b..2d0db3c 100644 (file)
@@ -338,10 +338,10 @@ validate_ir(Program* program)
                if (instr->isVOPC() || instr->opcode == aco_opcode::v_readfirstlane_b32 ||
                    instr->opcode == aco_opcode::v_readlane_b32 ||
                    instr->opcode == aco_opcode::v_readlane_b32_e64) {
-                  check(instr->definitions[0].getTemp().type() == RegType::sgpr,
+                  check(instr->definitions[0].regClass().type() == RegType::sgpr,
                         "Wrong Definition type for VALU instruction", instr.get());
                } else {
-                  check(instr->definitions[0].getTemp().type() == RegType::vgpr,
+                  check(instr->definitions[0].regClass().type() == RegType::vgpr,
                         "Wrong Definition type for VALU instruction", instr.get());
                }
 
@@ -352,35 +352,30 @@ validate_ir(Program* program)
                   if (instr->opcode == aco_opcode::v_readfirstlane_b32 ||
                       instr->opcode == aco_opcode::v_readlane_b32 ||
                       instr->opcode == aco_opcode::v_readlane_b32_e64) {
-                     check(i != 1 || (op.isTemp() && op.regClass().type() == RegType::sgpr) ||
-                              op.isConstant(),
+                     check(i != 1 || op.isOfType(RegType::sgpr) || op.isConstant(),
                            "Must be a SGPR or a constant", instr.get());
-                     check(i == 1 || (op.isTemp() && op.regClass().type() == RegType::vgpr &&
-                                      op.bytes() <= 4),
+                     check(i == 1 || (op.isOfType(RegType::vgpr) && op.bytes() <= 4),
                            "Wrong Operand type for VALU instruction", instr.get());
                      continue;
                   }
                   if (instr->opcode == aco_opcode::v_permlane16_b32 ||
                       instr->opcode == aco_opcode::v_permlanex16_b32) {
-                     check(i != 0 || (op.isTemp() && op.regClass().type() == RegType::vgpr),
+                     check(i != 0 || op.isOfType(RegType::vgpr),
                            "Operand 0 of v_permlane must be VGPR", instr.get());
-                     check(i == 0 || (op.isTemp() && op.regClass().type() == RegType::sgpr) ||
-                              op.isConstant(),
+                     check(i == 0 || op.isOfType(RegType::sgpr) || op.isConstant(),
                            "Lane select operands of v_permlane must be SGPR or constant",
                            instr.get());
                   }
 
                   if (instr->opcode == aco_opcode::v_writelane_b32 ||
                       instr->opcode == aco_opcode::v_writelane_b32_e64) {
-                     check(i != 2 || (op.isTemp() && op.regClass().type() == RegType::vgpr &&
-                                      op.bytes() <= 4),
+                     check(i != 2 || (op.isOfType(RegType::vgpr) && op.bytes() <= 4),
                            "Wrong Operand type for VALU instruction", instr.get());
-                     check(i == 2 || (op.isTemp() && op.regClass().type() == RegType::sgpr) ||
-                              op.isConstant(),
+                     check(i == 2 || op.isOfType(RegType::sgpr) || op.isConstant(),
                            "Must be a SGPR or a constant", instr.get());
                      continue;
                   }
-                  if (op.isTemp() && instr->operands[i].regClass().type() == RegType::sgpr) {
+                  if (op.isOfType(RegType::sgpr)) {
                      check(scalar_mask & (1 << i), "Wrong source position for SGPR argument",
                            instr.get());
 
@@ -425,10 +420,10 @@ validate_ir(Program* program)
 
             if (instr->isSOP1() || instr->isSOP2()) {
                if (!instr->definitions.empty())
-                  check(instr->definitions[0].getTemp().type() == RegType::sgpr,
+                  check(instr->definitions[0].regClass().type() == RegType::sgpr,
                         "Wrong Definition type for SALU instruction", instr.get());
                for (const Operand& op : instr->operands) {
-                  check(op.isConstant() || op.regClass().type() <= RegType::sgpr,
+                  check(op.isConstant() || op.isOfType(RegType::sgpr),
                         "Wrong Operand type for SALU instruction", instr.get());
                }
             }
@@ -444,19 +439,19 @@ validate_ir(Program* program)
                }
                check(size == instr->definitions[0].bytes(),
                      "Definition size does not match operand sizes", instr.get());
-               if (instr->definitions[0].getTemp().type() == RegType::sgpr) {
+               if (instr->definitions[0].regClass().type() == RegType::sgpr) {
                   for (const Operand& op : instr->operands) {
                      check(op.isConstant() || op.regClass().type() == RegType::sgpr,
                            "Wrong Operand type for scalar vector", instr.get());
                   }
                }
             } else if (instr->opcode == aco_opcode::p_extract_vector) {
-               check((instr->operands[0].isTemp()) && instr->operands[1].isConstant(),
+               check(!instr->operands[0].isConstant() && instr->operands[1].isConstant(),
                      "Wrong Operand types", instr.get());
                check((instr->operands[1].constantValue() + 1) * instr->definitions[0].bytes() <=
                         instr->operands[0].bytes(),
                      "Index out of range", instr.get());
-               check(instr->definitions[0].getTemp().type() == RegType::vgpr ||
+               check(instr->definitions[0].regClass().type() == RegType::vgpr ||
                         instr->operands[0].regClass().type() == RegType::sgpr,
                      "Cannot extract SGPR value from VGPR vector", instr.get());
                check(program->gfx_level >= GFX9 ||
@@ -464,14 +459,14 @@ validate_ir(Program* program)
                         instr->operands[0].regClass().type() == RegType::vgpr,
                      "Cannot extract subdword from SGPR before GFX9+", instr.get());
             } else if (instr->opcode == aco_opcode::p_split_vector) {
-               check(instr->operands[0].isTemp(), "Operand must be a temporary", instr.get());
+               check(!instr->operands[0].isConstant(), "Operand must not be constant", instr.get());
                unsigned size = 0;
                for (const Definition& def : instr->definitions) {
                   size += def.bytes();
                }
                check(size == instr->operands[0].bytes(),
                      "Operand size does not match definition sizes", instr.get());
-               if (instr->operands[0].getTemp().type() == RegType::vgpr) {
+               if (instr->operands[0].isOfType(RegType::vgpr)) {
                   for (const Definition& def : instr->definitions)
                      check(def.regClass().type() == RegType::vgpr,
                            "Wrong Definition type for VGPR split_vector", instr.get());
@@ -486,10 +481,10 @@ validate_ir(Program* program)
                for (unsigned i = 0; i < instr->operands.size(); i++) {
                   check(instr->definitions[i].bytes() == instr->operands[i].bytes(),
                         "Operand and Definition size must match", instr.get());
-                  if (instr->operands[i].isTemp()) {
-                     check((instr->definitions[i].getTemp().type() ==
+                  if (instr->operands[i].hasRegClass()) {
+                     check((instr->definitions[i].regClass().type() ==
                             instr->operands[i].regClass().type()) ||
-                              (instr->definitions[i].getTemp().type() == RegType::vgpr &&
+                              (instr->definitions[i].regClass().type() == RegType::vgpr &&
                                instr->operands[i].regClass().type() == RegType::sgpr),
                            "Operand and Definition types do not match", instr.get());
                      check(instr->definitions[i].regClass().is_linear_vgpr() ==
@@ -504,7 +499,7 @@ validate_ir(Program* program)
             } else if (instr->opcode == aco_opcode::p_phi) {
                check(instr->operands.size() == block.logical_preds.size(),
                      "Number of Operands does not match number of predecessors", instr.get());
-               check(instr->definitions[0].getTemp().type() == RegType::vgpr,
+               check(instr->definitions[0].regClass().type() == RegType::vgpr,
                      "Logical Phi Definition must be vgpr", instr.get());
                for (const Operand& op : instr->operands)
                   check(instr->definitions[0].size() == op.size(),
@@ -520,26 +515,27 @@ validate_ir(Program* program)
                      "Number of Operands does not match number of predecessors", instr.get());
             } else if (instr->opcode == aco_opcode::p_extract ||
                        instr->opcode == aco_opcode::p_insert) {
-               check(instr->operands[0].isTemp(), "Data operand must be temporary", instr.get());
+               check(!instr->operands[0].isConstant(), "Data operand must not be constant",
+                     instr.get());
                check(instr->operands[1].isConstant(), "Index must be constant", instr.get());
                if (instr->opcode == aco_opcode::p_extract)
                   check(instr->operands[3].isConstant(), "Sign-extend flag must be constant",
                         instr.get());
 
-               check(instr->definitions[0].getTemp().type() != RegType::sgpr ||
-                        instr->operands[0].getTemp().type() == RegType::sgpr,
+               check(instr->definitions[0].regClass().type() != RegType::sgpr ||
+                        instr->operands[0].regClass().type() == RegType::sgpr,
                      "Can't extract/insert VGPR to SGPR", instr.get());
 
                if (instr->opcode == aco_opcode::p_insert)
                   check(instr->operands[0].bytes() == instr->definitions[0].bytes(),
                         "Sizes of p_insert data operand and definition must match", instr.get());
 
-               if (instr->definitions[0].getTemp().type() == RegType::sgpr)
+               if (instr->definitions[0].regClass().type() == RegType::sgpr)
                   check(instr->definitions.size() >= 2 && instr->definitions[1].isFixed() &&
                            instr->definitions[1].physReg() == scc,
                         "SGPR extract/insert needs an SCC definition", instr.get());
 
-               unsigned data_bits = instr->operands[0].getTemp().bytes() * 8u;
+               unsigned data_bits = instr->operands[0].bytes() * 8u;
                unsigned op_bits = instr->operands[2].constantValue();
 
                if (instr->opcode == aco_opcode::p_insert) {
@@ -558,20 +554,19 @@ validate_ir(Program* program)
             } else if (instr->opcode == aco_opcode::p_jump_to_epilog) {
                check(instr->definitions.size() == 0, "p_jump_to_epilog must have 0 definitions",
                      instr.get());
-               check(instr->operands.size() > 0 &&
-                        instr->operands[0].getTemp().type() == RegType::sgpr &&
-                        instr->operands[0].getTemp().size() == 2,
+               check(instr->operands.size() > 0 && instr->operands[0].isOfType(RegType::sgpr) &&
+                        instr->operands[0].size() == 2,
                      "First operand of p_jump_to_epilog must be a SGPR", instr.get());
                for (unsigned i = 1; i < instr->operands.size(); i++) {
-                  check(instr->operands[i].getTemp().type() == RegType::vgpr ||
-                           instr->operands[i].isUndefined(),
-                        "Other operands of p_jump_to_epilog must be VGPRs or undef", instr.get());
+                  check(
+                     instr->operands[i].isOfType(RegType::vgpr) || instr->operands[i].isUndefined(),
+                     "Other operands of p_jump_to_epilog must be VGPRs or undef", instr.get());
                }
             } else if (instr->opcode == aco_opcode::p_dual_src_export_gfx11) {
                check(instr->definitions.size() == 6,
                      "p_dual_src_export_gfx11 must have 6 definitions", instr.get());
-               check(instr->definitions[2].getTemp().type() == RegType::vgpr &&
-                        instr->definitions[2].getTemp().size() == 1,
+               check(instr->definitions[2].regClass().type() == RegType::vgpr &&
+                        instr->definitions[2].regClass().size() == 1,
                      "Third definition of p_dual_src_export_gfx11 must be a v1", instr.get());
                check(instr->definitions[3].regClass() == program->lane_mask,
                      "Fourth definition of p_dual_src_export_gfx11 must be a lane mask", instr.get());
@@ -582,9 +577,9 @@ validate_ir(Program* program)
                check(instr->operands.size() == 8, "p_dual_src_export_gfx11 must have 8 operands",
                      instr.get());
                for (unsigned i = 0; i < instr->operands.size(); i++) {
-                  check(instr->operands[i].getTemp().type() == RegType::vgpr ||
-                           instr->operands[i].isUndefined(),
-                        "Operands of p_dual_src_export_gfx11 must be VGPRs or undef", instr.get());
+                  check(
+                     instr->operands[i].isOfType(RegType::vgpr) || instr->operands[i].isUndefined(),
+                     "Operands of p_dual_src_export_gfx11 must be VGPRs or undef", instr.get());
                }
             } else if (instr->opcode == aco_opcode::p_start_linear_vgpr) {
                check(instr->definitions.size() == 1, "Must have one definition", instr.get());
@@ -618,17 +613,13 @@ validate_ir(Program* program)
          }
          case Format::SMEM: {
             if (instr->operands.size() >= 1)
-               check((instr->operands[0].isFixed() && !instr->operands[0].isConstant()) ||
-                        (instr->operands[0].isTemp() &&
-                         instr->operands[0].regClass().type() == RegType::sgpr),
-                     "SMEM operands must be sgpr", instr.get());
+               check(instr->operands[0].isOfType(RegType::sgpr), "SMEM operands must be sgpr",
+                     instr.get());
             if (instr->operands.size() >= 2)
-               check(instr->operands[1].isConstant() ||
-                        (instr->operands[1].isTemp() &&
-                         instr->operands[1].regClass().type() == RegType::sgpr),
+               check(instr->operands[1].isConstant() || instr->operands[1].isOfType(RegType::sgpr),
                      "SMEM offset must be constant or sgpr", instr.get());
             if (!instr->definitions.empty())
-               check(instr->definitions[0].getTemp().type() == RegType::sgpr,
+               check(instr->definitions[0].regClass().type() == RegType::sgpr,
                      "SMEM result must be sgpr", instr.get());
             break;
          }
@@ -636,15 +627,11 @@ validate_ir(Program* program)
          case Format::MUBUF: {
             check(instr->operands.size() > 1, "VMEM instructions must have at least one operand",
                   instr.get());
-            check(instr->operands[1].hasRegClass() &&
-                     instr->operands[1].regClass().type() == RegType::vgpr,
+            check(instr->operands[1].isOfType(RegType::vgpr),
                   "VADDR must be in vgpr for VMEM instructions", instr.get());
-            check(
-               instr->operands[0].isTemp() && instr->operands[0].regClass().type() == RegType::sgpr,
-               "VMEM resource constant must be sgpr", instr.get());
-            check(instr->operands.size() < 4 ||
-                     (instr->operands[3].isTemp() &&
-                      instr->operands[3].regClass().type() == RegType::vgpr),
+            check(instr->operands[0].isOfType(RegType::sgpr), "VMEM resource constant must be sgpr",
+                  instr.get());
+            check(instr->operands.size() < 4 || instr->operands[3].isOfType(RegType::vgpr),
                   "VMEM write data must be vgpr", instr.get());
 
             const bool d16 = instr->opcode == aco_opcode::buffer_load_dword || // FIXME: used to spill subdword variables
@@ -668,8 +655,7 @@ validate_ir(Program* program)
                              instr->opcode == aco_opcode::tbuffer_load_format_d16_xyz ||
                              instr->opcode == aco_opcode::tbuffer_load_format_d16_xyzw;
             if (instr->definitions.size()) {
-               check(instr->definitions[0].isTemp() &&
-                        instr->definitions[0].regClass().type() == RegType::vgpr,
+               check(instr->definitions[0].regClass().type() == RegType::vgpr,
                      "VMEM definitions[0] (VDATA) must be VGPR", instr.get());
                check(d16 || !instr->definitions[0].regClass().is_subdword(),
                      "Only D16 opcodes can load subdword values.", instr.get());
@@ -699,12 +685,13 @@ validate_ir(Program* program)
             }
 
             if (instr->mimg().strict_wqm) {
-               check(instr->operands[3].isTemp() && instr->operands[3].regClass().is_linear_vgpr(),
+               check(instr->operands[3].hasRegClass() &&
+                        instr->operands[3].regClass().is_linear_vgpr(),
                      "MIMG operands[3] must be temp linear VGPR.", instr.get());
 
                unsigned total_size = 0;
                for (unsigned i = 4; i < instr->operands.size(); i++) {
-                  check(instr->operands[i].isTemp() && instr->operands[i].regClass() == v1,
+                  check(instr->operands[i].hasRegClass() && instr->operands[i].regClass() == v1,
                         "MIMG operands[4+] (VADDR) must be v1", instr.get());
                   total_size += instr->operands[i].bytes();
                }
@@ -733,8 +720,7 @@ validate_ir(Program* program)
             }
 
             if (instr->definitions.size()) {
-               check(instr->definitions[0].isTemp() &&
-                        instr->definitions[0].regClass().type() == RegType::vgpr,
+               check(instr->definitions[0].regClass().type() == RegType::vgpr,
                      "MIMG definitions[0] (VDATA) must be VGPR", instr.get());
                check(instr->mimg().d16 || !instr->definitions[0].regClass().is_subdword(),
                      "Only D16 MIMG instructions can load subdword values.", instr.get());
@@ -745,19 +731,17 @@ validate_ir(Program* program)
          }
          case Format::DS: {
             for (const Operand& op : instr->operands) {
-               check((op.isTemp() && op.regClass().type() == RegType::vgpr) || op.physReg() == m0 ||
-                     op.isUndefined(),
+               check(op.isOfType(RegType::vgpr) || op.physReg() == m0 || op.isUndefined(),
                      "Only VGPRs are valid DS instruction operands", instr.get());
             }
             if (!instr->definitions.empty())
-               check(instr->definitions[0].getTemp().type() == RegType::vgpr,
+               check(instr->definitions[0].regClass().type() == RegType::vgpr,
                      "DS instruction must return VGPR", instr.get());
             break;
          }
          case Format::EXP: {
             for (unsigned i = 0; i < 4; i++)
-               check(instr->operands[i].hasRegClass() &&
-                        instr->operands[i].regClass().type() == RegType::vgpr,
+               check(instr->operands[i].isOfType(RegType::vgpr),
                      "Only VGPRs are valid Export arguments", instr.get());
             break;
          }
@@ -766,25 +750,22 @@ validate_ir(Program* program)
                   instr.get());
             FALLTHROUGH;
          case Format::GLOBAL:
-            check(
-               instr->operands[0].isTemp() && instr->operands[0].regClass().type() == RegType::vgpr,
-               "FLAT/GLOBAL address must be vgpr", instr.get());
+            check(instr->operands[0].isOfType(RegType::vgpr), "FLAT/GLOBAL address must be vgpr",
+                  instr.get());
             FALLTHROUGH;
          case Format::SCRATCH: {
-            check(instr->operands[0].hasRegClass() &&
-                     instr->operands[0].regClass().type() == RegType::vgpr,
+            check(instr->operands[0].isOfType(RegType::vgpr),
                   "FLAT/GLOBAL/SCRATCH address must be undefined or vgpr", instr.get());
-            check(instr->operands[1].hasRegClass() &&
-                     instr->operands[1].regClass().type() == RegType::sgpr,
+            check(instr->operands[1].isOfType(RegType::sgpr),
                   "FLAT/GLOBAL/SCRATCH sgpr address must be undefined or sgpr", instr.get());
             if (instr->format == Format::SCRATCH && program->gfx_level < GFX10_3)
-               check(instr->operands[0].isTemp() || instr->operands[1].isTemp(),
+               check(!instr->operands[0].isUndefined() || !instr->operands[1].isUndefined(),
                      "SCRATCH must have either SADDR or ADDR operand", instr.get());
             if (!instr->definitions.empty())
-               check(instr->definitions[0].getTemp().type() == RegType::vgpr,
+               check(instr->definitions[0].regClass().type() == RegType::vgpr,
                      "FLAT/GLOBAL/SCRATCH result must be vgpr", instr.get());
             else
-               check(instr->operands[2].regClass().type() == RegType::vgpr,
+               check(instr->operands[2].isOfType(RegType::vgpr),
                      "FLAT/GLOBAL/SCRATCH data must be vgpr", instr.get());
             break;
          }