aco: use VOP3+DPP
authorGeorg Lehmann <dadschoorse@gmail.com>
Sun, 23 Apr 2023 12:55:17 +0000 (14:55 +0200)
committerMarge Bot <emma+marge@anholt.net>
Fri, 12 May 2023 13:31:16 +0000 (13:31 +0000)
Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/22698>

src/amd/compiler/aco_ir.cpp
src/amd/compiler/aco_ir.h
src/amd/compiler/aco_optimizer.cpp
src/amd/compiler/aco_optimizer_postRA.cpp
src/amd/compiler/aco_validate.cpp

index 1543657..c6cdec4 100644 (file)
@@ -336,38 +336,52 @@ convert_to_SDWA(amd_gfx_level gfx_level, aco_ptr<Instruction>& instr)
 }
 
 bool
-can_use_DPP(const aco_ptr<Instruction>& instr, bool pre_ra, bool dpp8)
+can_use_DPP(amd_gfx_level gfx_level, const aco_ptr<Instruction>& instr, bool dpp8)
 {
    assert(instr->isVALU() && !instr->operands.empty());
 
    if (instr->isDPP())
       return instr->isDPP8() == dpp8;
 
-   if (instr->operands.size() && instr->operands[0].isLiteral())
+   if (instr->isSDWA() || instr->isVINTERP_INREG())
       return false;
 
-   if (instr->isSDWA() || instr->isVINTERP_INREG() || instr->isVOP3P())
+   if ((instr->format == Format::VOP3 || instr->isVOP3P()) && gfx_level < GFX11)
       return false;
 
-   if (!pre_ra && (instr->isVOPC() || instr->definitions.size() > 1) &&
-       instr->definitions.back().physReg() != vcc)
+   if ((instr->isVOPC() || instr->definitions.size() > 1) && instr->definitions.back().isFixed() &&
+       instr->definitions.back().physReg() != vcc && gfx_level < GFX11)
       return false;
 
-   if (!pre_ra && instr->operands.size() >= 3 && instr->operands[2].physReg() != vcc)
+   if (instr->operands.size() >= 3 && instr->operands[2].isFixed() &&
+       instr->operands[2].isOfType(RegType::sgpr) && instr->operands[2].physReg() != vcc &&
+       gfx_level < GFX11)
       return false;
 
-   if (instr->isVOP3()) {
+   if (instr->isVOP3() && gfx_level < GFX11) {
       const VALU_instruction* vop3 = &instr->valu();
-      if (vop3->clamp || vop3->omod || vop3->opsel)
+      if (vop3->clamp || vop3->omod)
          return false;
       if (dpp8)
          return false;
-      if (instr->format == Format::VOP3)
+   }
+
+   for (unsigned i = 0; i < instr->operands.size(); i++) {
+      if (instr->operands[i].isLiteral())
          return false;
-      if (instr->operands.size() > 1 && !instr->operands[1].isOfType(RegType::vgpr))
+      if (!instr->operands[i].isOfType(RegType::vgpr) && i < 2)
          return false;
    }
 
+   /* simpler than listing all VOP3P opcodes which do not support DPP */
+   if (instr->isVOP3P()) {
+      return instr->opcode == aco_opcode::v_fma_mix_f32 ||
+             instr->opcode == aco_opcode::v_fma_mixlo_f16 ||
+             instr->opcode == aco_opcode::v_fma_mixhi_f16 ||
+             instr->opcode == aco_opcode::v_dot2_f32_f16 ||
+             instr->opcode == aco_opcode::v_dot2_f32_bf16;
+   }
+
    /* there are more cases but those all take 64-bit inputs */
    return instr->opcode != aco_opcode::v_madmk_f32 && instr->opcode != aco_opcode::v_madak_f32 &&
           instr->opcode != aco_opcode::v_madmk_f16 && instr->opcode != aco_opcode::v_madak_f16 &&
@@ -375,18 +389,31 @@ can_use_DPP(const aco_ptr<Instruction>& instr, bool pre_ra, bool dpp8)
           instr->opcode != aco_opcode::v_fmamk_f16 && instr->opcode != aco_opcode::v_fmaak_f16 &&
           instr->opcode != aco_opcode::v_readfirstlane_b32 &&
           instr->opcode != aco_opcode::v_cvt_f64_i32 &&
-          instr->opcode != aco_opcode::v_cvt_f64_f32 && instr->opcode != aco_opcode::v_cvt_f64_u32;
+          instr->opcode != aco_opcode::v_cvt_f64_f32 &&
+          instr->opcode != aco_opcode::v_cvt_f64_u32 && instr->opcode != aco_opcode::v_mul_lo_u32 &&
+          instr->opcode != aco_opcode::v_mul_lo_i32 && instr->opcode != aco_opcode::v_mul_hi_u32 &&
+          instr->opcode != aco_opcode::v_mul_hi_i32 &&
+          instr->opcode != aco_opcode::v_qsad_pk_u16_u8 &&
+          instr->opcode != aco_opcode::v_mqsad_pk_u16_u8 &&
+          instr->opcode != aco_opcode::v_mqsad_u32_u8 &&
+          instr->opcode != aco_opcode::v_mad_u64_u32 &&
+          instr->opcode != aco_opcode::v_mad_i64_i32 &&
+          instr->opcode != aco_opcode::v_permlane16_b32 &&
+          instr->opcode != aco_opcode::v_permlanex16_b32 &&
+          instr->opcode != aco_opcode::v_permlane64_b32 &&
+          instr->opcode != aco_opcode::v_readlane_b32_e64 &&
+          instr->opcode != aco_opcode::v_writelane_b32_e64;
 }
 
 aco_ptr<Instruction>
-convert_to_DPP(aco_ptr<Instruction>& instr, bool dpp8)
+convert_to_DPP(amd_gfx_level gfx_level, aco_ptr<Instruction>& instr, bool dpp8)
 {
    if (instr->isDPP())
       return NULL;
 
    aco_ptr<Instruction> tmp = std::move(instr);
-   Format format = (Format)(((uint32_t)tmp->format & ~(uint32_t)Format::VOP3) |
-                            (dpp8 ? (uint32_t)Format::DPP8 : (uint32_t)Format::DPP16));
+   Format format =
+      (Format)((uint32_t)tmp->format | (uint32_t)(dpp8 ? Format::DPP8 : Format::DPP16));
    if (dpp8)
       instr.reset(create_instruction<DPP8_instruction>(tmp->opcode, format, tmp->operands.size(),
                                                        tmp->definitions.size()));
@@ -394,8 +421,7 @@ convert_to_DPP(aco_ptr<Instruction>& instr, bool dpp8)
       instr.reset(create_instruction<DPP16_instruction>(tmp->opcode, format, tmp->operands.size(),
                                                         tmp->definitions.size()));
    std::copy(tmp->operands.cbegin(), tmp->operands.cend(), instr->operands.begin());
-   for (unsigned i = 0; i < instr->definitions.size(); i++)
-      instr->definitions[i] = tmp->definitions[i];
+   std::copy(tmp->definitions.cbegin(), tmp->definitions.cend(), instr->definitions.begin());
 
    if (dpp8) {
       DPP8_instruction* dpp = &instr->dpp8();
@@ -410,16 +436,37 @@ convert_to_DPP(aco_ptr<Instruction>& instr, bool dpp8)
 
    instr->valu().neg = tmp->valu().neg;
    instr->valu().abs = tmp->valu().abs;
+   instr->valu().omod = tmp->valu().omod;
+   instr->valu().clamp = tmp->valu().clamp;
    instr->valu().opsel = tmp->valu().opsel;
+   instr->valu().opsel_lo = tmp->valu().opsel_lo;
+   instr->valu().opsel_hi = tmp->valu().opsel_hi;
 
-   if (instr->isVOPC() || instr->definitions.size() > 1)
+   if ((instr->isVOPC() || instr->definitions.size() > 1) && gfx_level < GFX11)
       instr->definitions.back().setFixed(vcc);
 
-   if (instr->operands.size() >= 3)
+   if (instr->operands.size() >= 3 && instr->operands[2].isOfType(RegType::sgpr) &&
+       gfx_level < GFX11)
       instr->operands[2].setFixed(vcc);
 
    instr->pass_flags = tmp->pass_flags;
 
+   /* DPP16 supports input modifiers, so we might no longer need VOP3. */
+   bool remove_vop3 = !dpp8 && !instr->valu().omod && !instr->valu().clamp &&
+                      (instr->isVOP1() || instr->isVOP2() || instr->isVOPC());
+
+   /* VOPC/add_co/sub_co definition needs VCC without VOP3. */
+   remove_vop3 &= instr->definitions.back().regClass().type() != RegType::sgpr ||
+                  !instr->definitions.back().isFixed() ||
+                  instr->definitions.back().physReg() == vcc;
+
+   /* addc/subb/cndmask 3rd operand needs VCC without VOP3. */
+   remove_vop3 &= instr->operands.size() < 3 || !instr->operands[2].isFixed() ||
+                  instr->operands[2].isOfType(RegType::vgpr) || instr->operands[2].physReg() == vcc;
+
+   if (remove_vop3)
+      instr->format = (Format)((uint32_t)instr->format & ~(uint32_t)Format::VOP3);
+
    return tmp;
 }
 
@@ -931,27 +978,77 @@ is_cmpx(aco_opcode op)
 }
 
 bool
-can_swap_operands(aco_ptr<Instruction>& instr, aco_opcode* new_op)
+can_swap_operands(aco_ptr<Instruction>& instr, aco_opcode* new_op, unsigned idx0, unsigned idx1)
 {
+   if (idx0 == idx1) {
+      *new_op = instr->opcode;
+      return true;
+   }
+
+   if (idx0 > idx1)
+      std::swap(idx0, idx1);
+
    if (instr->isDPP())
       return false;
 
-   if (instr->operands[0].isConstant() ||
-       (instr->operands[0].isTemp() && instr->operands[0].getTemp().type() == RegType::sgpr))
+   if (!instr->isVOP3() && !instr->isVOP3P() && !instr->operands[0].isOfType(RegType::vgpr))
       return false;
 
+   if (instr->isVOPC()) {
+      CmpInfo info;
+      if (get_cmp_info(instr->opcode, &info) && info.swapped != aco_opcode::num_opcodes) {
+         *new_op = info.swapped;
+         return true;
+      }
+   }
+
+   /* opcodes not relevant for DPP or SGPRs optimizations are not included. */
    switch (instr->opcode) {
+   case aco_opcode::v_med3_f32: return false; /* order matters for clamp+GFX8+denorm ftz. */
    case aco_opcode::v_add_u32:
    case aco_opcode::v_add_co_u32:
    case aco_opcode::v_add_co_u32_e64:
    case aco_opcode::v_add_i32:
+   case aco_opcode::v_add_i16:
+   case aco_opcode::v_add_u16_e64:
+   case aco_opcode::v_add3_u32:
    case aco_opcode::v_add_f16:
    case aco_opcode::v_add_f32:
+   case aco_opcode::v_mul_i32_i24:
+   case aco_opcode::v_mul_hi_i32_i24:
+   case aco_opcode::v_mul_u32_u24:
+   case aco_opcode::v_mul_hi_u32_u24:
+   case aco_opcode::v_mul_lo_u16:
+   case aco_opcode::v_mul_lo_u16_e64:
    case aco_opcode::v_mul_f16:
    case aco_opcode::v_mul_f32:
+   case aco_opcode::v_mul_legacy_f32:
    case aco_opcode::v_or_b32:
    case aco_opcode::v_and_b32:
    case aco_opcode::v_xor_b32:
+   case aco_opcode::v_xnor_b32:
+   case aco_opcode::v_xor3_b32:
+   case aco_opcode::v_or3_b32:
+   case aco_opcode::v_and_b16:
+   case aco_opcode::v_or_b16:
+   case aco_opcode::v_xor_b16:
+   case aco_opcode::v_max3_f32:
+   case aco_opcode::v_min3_f32:
+   case aco_opcode::v_max3_f16:
+   case aco_opcode::v_min3_f16:
+   case aco_opcode::v_med3_f16:
+   case aco_opcode::v_max3_u32:
+   case aco_opcode::v_min3_u32:
+   case aco_opcode::v_med3_u32:
+   case aco_opcode::v_max3_i32:
+   case aco_opcode::v_min3_i32:
+   case aco_opcode::v_med3_i32:
+   case aco_opcode::v_max3_u16:
+   case aco_opcode::v_min3_u16:
+   case aco_opcode::v_med3_u16:
+   case aco_opcode::v_max3_i16:
+   case aco_opcode::v_min3_i16:
+   case aco_opcode::v_med3_i16:
    case aco_opcode::v_max_f16:
    case aco_opcode::v_max_f32:
    case aco_opcode::v_min_f16:
@@ -973,14 +1070,73 @@ can_swap_operands(aco_ptr<Instruction>& instr, aco_opcode* new_op)
    case aco_opcode::v_sub_co_u32: *new_op = aco_opcode::v_subrev_co_u32; return true;
    case aco_opcode::v_sub_u16: *new_op = aco_opcode::v_subrev_u16; return true;
    case aco_opcode::v_sub_u32: *new_op = aco_opcode::v_subrev_u32; return true;
-   default: {
-      CmpInfo info;
-      if (get_cmp_info(instr->opcode, &info) && info.swapped != aco_opcode::num_opcodes) {
-         *new_op = info.swapped;
-         return true;
-      }
-      return false;
+   case aco_opcode::v_sub_co_u32_e64: *new_op = aco_opcode::v_subrev_co_u32_e64; return true;
+   case aco_opcode::v_subrev_f16: *new_op = aco_opcode::v_sub_f16; return true;
+   case aco_opcode::v_subrev_f32: *new_op = aco_opcode::v_sub_f32; return true;
+   case aco_opcode::v_subrev_co_u32: *new_op = aco_opcode::v_sub_co_u32; return true;
+   case aco_opcode::v_subrev_u16: *new_op = aco_opcode::v_sub_u16; return true;
+   case aco_opcode::v_subrev_u32: *new_op = aco_opcode::v_sub_u32; return true;
+   case aco_opcode::v_subrev_co_u32_e64: *new_op = aco_opcode::v_sub_co_u32_e64; return true;
+   case aco_opcode::v_addc_co_u32:
+   case aco_opcode::v_mad_i32_i24:
+   case aco_opcode::v_mad_u32_u24:
+   case aco_opcode::v_lerp_u8:
+   case aco_opcode::v_sad_u8:
+   case aco_opcode::v_sad_hi_u8:
+   case aco_opcode::v_sad_u16:
+   case aco_opcode::v_sad_u32:
+   case aco_opcode::v_xad_u32:
+   case aco_opcode::v_add_lshl_u32:
+   case aco_opcode::v_and_or_b32:
+   case aco_opcode::v_mad_u16:
+   case aco_opcode::v_mad_i16:
+   case aco_opcode::v_mad_u32_u16:
+   case aco_opcode::v_mad_i32_i16:
+   case aco_opcode::v_maxmin_f32:
+   case aco_opcode::v_minmax_f32:
+   case aco_opcode::v_maxmin_f16:
+   case aco_opcode::v_minmax_f16:
+   case aco_opcode::v_maxmin_u32:
+   case aco_opcode::v_minmax_u32:
+   case aco_opcode::v_maxmin_i32:
+   case aco_opcode::v_minmax_i32:
+   case aco_opcode::v_fma_f32:
+   case aco_opcode::v_fma_legacy_f32:
+   case aco_opcode::v_fmac_f32:
+   case aco_opcode::v_fmac_legacy_f32:
+   case aco_opcode::v_mac_f32:
+   case aco_opcode::v_mac_legacy_f32:
+   case aco_opcode::v_fma_f16:
+   case aco_opcode::v_fmac_f16:
+   case aco_opcode::v_mac_f16:
+   case aco_opcode::v_dot4c_i32_i8:
+   case aco_opcode::v_dot2c_f32_f16:
+   case aco_opcode::v_dot2_f32_f16:
+   case aco_opcode::v_dot2_f32_bf16:
+   case aco_opcode::v_dot2_f16_f16:
+   case aco_opcode::v_dot2_bf16_bf16:
+   case aco_opcode::v_fma_mix_f32:
+   case aco_opcode::v_fma_mixlo_f16:
+   case aco_opcode::v_fma_mixhi_f16:
+   case aco_opcode::v_pk_fmac_f16: {
+      if (idx1 == 2)
+         return false;
+      *new_op = instr->opcode;
+      return true;
+   }
+   case aco_opcode::v_subb_co_u32: {
+      if (idx1 == 2)
+         return false;
+      *new_op = aco_opcode::v_subbrev_co_u32;
+      return true;
    }
+   case aco_opcode::v_subbrev_co_u32: {
+      if (idx1 == 2)
+         return false;
+      *new_op = aco_opcode::v_subb_co_u32;
+      return true;
+   }
+   default: return false;
    }
 }
 
index 7a95b00..04b9342 100644 (file)
@@ -1803,11 +1803,12 @@ bool can_use_opsel(amd_gfx_level gfx_level, aco_opcode op, int idx);
 bool instr_is_16bit(amd_gfx_level gfx_level, aco_opcode op);
 uint8_t get_gfx11_true16_mask(aco_opcode op);
 bool can_use_SDWA(amd_gfx_level gfx_level, const aco_ptr<Instruction>& instr, bool pre_ra);
-bool can_use_DPP(const aco_ptr<Instruction>& instr, bool pre_ra, bool dpp8);
+bool can_use_DPP(amd_gfx_level gfx_level, const aco_ptr<Instruction>& instr, bool dpp8);
 bool can_write_m0(const aco_ptr<Instruction>& instr);
 /* updates "instr" and returns the old instruction (or NULL if no update was needed) */
 aco_ptr<Instruction> convert_to_SDWA(amd_gfx_level gfx_level, aco_ptr<Instruction>& instr);
-aco_ptr<Instruction> convert_to_DPP(aco_ptr<Instruction>& instr, bool dpp8);
+aco_ptr<Instruction> convert_to_DPP(amd_gfx_level gfx_level, aco_ptr<Instruction>& instr,
+                                    bool dpp8);
 bool needs_exec_mask(const Instruction* instr);
 
 aco_opcode get_ordered(aco_opcode op);
@@ -1820,7 +1821,8 @@ unsigned get_cmp_bitsize(aco_opcode op);
 bool is_fp_cmp(aco_opcode op);
 bool is_cmpx(aco_opcode op);
 
-bool can_swap_operands(aco_ptr<Instruction>& instr, aco_opcode* new_op);
+bool can_swap_operands(aco_ptr<Instruction>& instr, aco_opcode* new_op, unsigned idx0 = 0,
+                       unsigned idx1 = 1);
 
 uint32_t get_reduction_identity(ReduceOp op, unsigned idx);
 
index c5d238a..a6d1a60 100644 (file)
@@ -4810,7 +4810,7 @@ select_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
    }
 
    /* Combine DPP copies into VALU. This should be done after creating MAD/FMA. */
-   if (instr->isVALU()) {
+   if (instr->isVALU() && !instr->isDPP()) {
       for (unsigned i = 0; i < instr->operands.size(); i++) {
          if (!instr->operands[i].isTemp())
             continue;
@@ -4819,41 +4819,44 @@ select_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
          if (!info.is_dpp() || info.instr->pass_flags != instr->pass_flags)
             continue;
 
-         aco_opcode swapped_op;
-         if (i != 0 && !can_swap_operands(instr, &swapped_op))
-            continue;
+         if (i != 0) {
+            if (!can_swap_operands(instr, &instr->opcode, 0, i))
+               continue;
+            std::swap(instr->operands[0], instr->operands[i]);
+            instr->valu().neg[0].swap(instr->valu().neg[i]);
+            instr->valu().abs[0].swap(instr->valu().abs[i]);
+            instr->valu().opsel[0].swap(instr->valu().opsel[i]);
+            instr->valu().opsel_lo[0].swap(instr->valu().opsel_lo[i]);
+            instr->valu().opsel_hi[0].swap(instr->valu().opsel_hi[i]);
+         }
 
-         if (instr->isDPP() || !can_use_DPP(instr, true, info.is_dpp8()))
+         if (!can_use_DPP(ctx.program->gfx_level, instr, info.is_dpp8()))
             continue;
 
          bool dpp8 = info.is_dpp8();
          bool input_mods = instr_info.can_use_input_modifiers[(int)instr->opcode] &&
                            instr_info.operand_size[(int)instr->opcode] == 32;
-         if (!dpp8 && (info.instr->dpp16().neg[0] || info.instr->dpp16().abs[0]) && !input_mods)
+         bool mov_uses_mods = info.instr->valu().neg[0] || info.instr->valu().abs[0];
+         if (((dpp8 && ctx.program->gfx_level < GFX11) || !input_mods) && mov_uses_mods)
             continue;
 
-         convert_to_DPP(instr, dpp8);
-
-         if (i != 0) {
-            instr->opcode = swapped_op;
-            std::swap(instr->operands[0], instr->operands[1]);
-            instr->valu().neg[0].swap(instr->valu().neg[1]);
-            instr->valu().abs[0].swap(instr->valu().abs[1]);
-            instr->valu().opsel[0].swap(instr->valu().opsel[1]);
-         }
+         convert_to_DPP(ctx.program->gfx_level, instr, dpp8);
 
          if (dpp8) {
             DPP8_instruction* dpp = &instr->dpp8();
             for (unsigned j = 0; j < 8; ++j)
                dpp->lane_sel[j] = info.instr->dpp8().lane_sel[j];
+            if (mov_uses_mods)
+               instr->format = asVOP3(instr->format);
          } else {
             DPP16_instruction* dpp = &instr->dpp16();
             dpp->dpp_ctrl = info.instr->dpp16().dpp_ctrl;
             dpp->bound_ctrl = info.instr->dpp16().bound_ctrl;
-            dpp->neg[0] ^= info.instr->dpp16().neg[0] && !dpp->abs[0];
-            dpp->abs[0] |= info.instr->dpp16().abs[0];
          }
 
+         instr->valu().neg[0] ^= info.instr->valu().neg[0] && !instr->valu().abs[0];
+         instr->valu().abs[0] |= info.instr->valu().abs[0];
+
          if (--ctx.uses[info.instr->definitions[0].tempId()])
             ctx.uses[info.instr->operands[0].tempId()]++;
          instr->operands[0].setTemp(info.instr->operands[0].getTemp());
index a21bc33..e8d5939 100644 (file)
@@ -485,7 +485,7 @@ try_combine_dpp(pr_opt_ctx& ctx, aco_ptr<Instruction>& instr)
    if (!instr->isVALU() || instr->isDPP())
       return;
 
-   for (unsigned i = 0; i < MIN2(2, instr->operands.size()); i++) {
+   for (unsigned i = 0; i < instr->operands.size(); i++) {
       Idx op_instr_idx = last_writer_idx(ctx, instr->operands[i]);
       if (!op_instr_idx.found())
          continue;
@@ -493,9 +493,6 @@ try_combine_dpp(pr_opt_ctx& ctx, aco_ptr<Instruction>& instr)
       const Instruction* mov = ctx.get(op_instr_idx);
       if (mov->opcode != aco_opcode::v_mov_b32 || !mov->isDPP())
          continue;
-      bool dpp8 = mov->isDPP8();
-      if (!can_use_DPP(instr, false, dpp8))
-         return;
 
       /* If we aren't going to remove the v_mov_b32, we have to ensure that it doesn't overwrite
        * it's own operand before we use it.
@@ -508,12 +505,25 @@ try_combine_dpp(pr_opt_ctx& ctx, aco_ptr<Instruction>& instr)
       if (is_overwritten_since(ctx, mov->operands[0], op_instr_idx))
          continue;
 
-      if (i && !can_swap_operands(instr, &instr->opcode))
-         continue;
-
+      bool dpp8 = mov->isDPP8();
       bool input_mods = instr_info.can_use_input_modifiers[(int)instr->opcode] &&
                         instr_info.operand_size[(int)instr->opcode] == 32;
-      if (!dpp8 && (mov->dpp16().neg[0] || mov->dpp16().abs[0]) && !input_mods)
+      bool mov_uses_mods = mov->valu().neg[0] || mov->valu().abs[0];
+      if (((dpp8 && ctx.program->gfx_level < GFX11) || !input_mods) && mov_uses_mods)
+         continue;
+
+      if (i != 0) {
+         if (!can_swap_operands(instr, &instr->opcode, 0, i))
+            continue;
+         std::swap(instr->operands[0], instr->operands[i]);
+         instr->valu().neg[0].swap(instr->valu().neg[i]);
+         instr->valu().abs[0].swap(instr->valu().abs[i]);
+         instr->valu().opsel[0].swap(instr->valu().opsel[i]);
+         instr->valu().opsel_lo[0].swap(instr->valu().opsel_lo[i]);
+         instr->valu().opsel_hi[0].swap(instr->valu().opsel_hi[i]);
+      }
+
+      if (!can_use_DPP(ctx.program->gfx_level, instr, dpp8))
          continue;
 
       if (!dpp8) /* anything else doesn't make sense in SSA */
@@ -522,27 +532,22 @@ try_combine_dpp(pr_opt_ctx& ctx, aco_ptr<Instruction>& instr)
       if (--ctx.uses[mov->definitions[0].tempId()])
          ctx.uses[mov->operands[0].tempId()]++;
 
-      convert_to_DPP(instr, dpp8);
-
-      if (i) {
-         std::swap(instr->operands[0], instr->operands[1]);
-         instr->valu().neg[0].swap(instr->valu().neg[1]);
-         instr->valu().abs[0].swap(instr->valu().abs[1]);
-         instr->valu().opsel[0].swap(instr->valu().opsel[1]);
-      }
+      convert_to_DPP(ctx.program->gfx_level, instr, dpp8);
 
       instr->operands[0] = mov->operands[0];
 
       if (dpp8) {
          DPP8_instruction* dpp = &instr->dpp8();
          memcpy(dpp->lane_sel, mov->dpp8().lane_sel, sizeof(dpp->lane_sel));
+         if (mov_uses_mods)
+            instr->format = asVOP3(instr->format);
       } else {
          DPP16_instruction* dpp = &instr->dpp16();
          dpp->dpp_ctrl = mov->dpp16().dpp_ctrl;
          dpp->bound_ctrl = true;
-         dpp->neg[0] ^= mov->dpp16().neg[0] && !dpp->abs[0];
-         dpp->abs[0] |= mov->dpp16().abs[0];
       }
+      instr->valu().neg[0] ^= mov->valu().neg[0] && !instr->valu().abs[0];
+      instr->valu().abs[0] |= mov->valu().abs[0];
       return;
    }
 }
index d407329..b95e60b 100644 (file)
@@ -142,12 +142,21 @@ validate_ir(Program* program)
                "Wrong base format for instruction", instr.get());
 
          /* check VOP3 modifiers */
-         if (instr->isVOP3() && instr->format != Format::VOP3) {
+         if (instr->isVOP3() && withoutDPP(instr->format) != Format::VOP3) {
             check(base_format == Format::VOP2 || base_format == Format::VOP1 ||
                      base_format == Format::VOPC || base_format == Format::VINTRP,
                   "Format cannot have VOP3/VOP3B applied", instr.get());
          }
 
+         if (instr->isDPP()) {
+            check(base_format == Format::VOP2 || base_format == Format::VOP1 ||
+                     base_format == Format::VOPC || base_format == Format::VOP3 ||
+                     base_format == Format::VOP3P,
+                  "Format cannot have DPP applied", instr.get());
+            check((!instr->isVOP3() && !instr->isVOP3P()) || program->gfx_level >= GFX11,
+                  "VOP3+DPP is GFX11+ only", instr.get());
+         }
+
          /* check SDWA */
          if (instr->isSDWA()) {
             check(base_format == Format::VOP2 || base_format == Format::VOP1 ||