aco: Clean up and fix quad group instructions with WQM.
author Timur Kristóf <timur.kristof@gmail.com>
Tue, 23 Nov 2021 15:50:20 +0000 (16:50 +0100)
committer Marge Bot <emma+marge@anholt.net>
Thu, 9 Dec 2021 17:36:51 +0000 (17:36 +0000)
According to the Vulkan spec, chapter 9.25 "Helper Invocations",
quad group operations must also be executed by helper invocations.

This commit cleans up the quad group instruction code by unifying
the quad broadcast code path with the other quad operations, and
then calling emit_wqm just once at the end.

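The unified path has roughly the following shape (an abridged sketch
of the new code in the diff below, not a literal excerpt; the
per-bit-size handling and the GFX8+ DPP vs. ds_swizzle selection are
elided):

   case nir_intrinsic_quad_broadcast:
   case nir_intrinsic_quad_swap_horizontal:
   case nir_intrinsic_quad_swap_vertical:
   case nir_intrinsic_quad_swap_diagonal:
   case nir_intrinsic_quad_swizzle_amd: {
      /* Pick the DPP control for the requested quad permutation. */
      uint16_t dpp_ctrl = ...;
      /* Emit the permutation into a temporary instead of straight into dst. */
      Temp tmp = ...;
      /* Single WQM copy at the end: helper invocations must also execute quad ops. */
      emit_wqm(bld, tmp, dst, true);
      break;
   }
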
Fixes: 93c8ebfa780ebd1495095e794731881aef29e7d3
Closes: https://gitlab.freedesktop.org/mesa/mesa/-/issues/5570
Signed-off-by: Timur Kristóf <timur.kristof@gmail.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/13929>

src/amd/compiler/aco_instruction_selection.cpp

index d14cb78..ad496f7 100644 (file)
@@ -8465,146 +8465,106 @@ visit_intrinsic(isel_context* ctx, nir_intrinsic_instr* instr)
       }
       break;
    }
-   case nir_intrinsic_quad_broadcast: {
-      Temp src = get_ssa_temp(ctx, instr->src[0].ssa);
-      if (!nir_dest_is_divergent(instr->dest)) {
-         emit_uniform_subgroup(ctx, instr, src);
-      } else {
-         Temp dst = get_ssa_temp(ctx, &instr->dest.ssa);
-         unsigned lane = nir_src_as_const_value(instr->src[1])->u32;
-         uint32_t dpp_ctrl = dpp_quad_perm(lane, lane, lane, lane);
-
-         if (instr->dest.ssa.bit_size != 1)
-            src = as_vgpr(ctx, src);
-
-         if (instr->dest.ssa.bit_size == 1) {
-            assert(src.regClass() == bld.lm);
-            assert(dst.regClass() == bld.lm);
-            uint32_t half_mask = 0x11111111u << lane;
-            Temp mask_tmp = bld.pseudo(aco_opcode::p_create_vector, bld.def(s2),
-                                       Operand::c32(half_mask), Operand::c32(half_mask));
-            Temp tmp = bld.tmp(bld.lm);
-            bld.sop1(Builder::s_wqm, Definition(tmp),
-                     bld.sop2(Builder::s_and, bld.def(bld.lm), bld.def(s1, scc), mask_tmp,
-                              bld.sop2(Builder::s_and, bld.def(bld.lm), bld.def(s1, scc), src,
-                                       Operand(exec, bld.lm))));
-            emit_wqm(bld, tmp, dst);
-         } else if (instr->dest.ssa.bit_size == 8) {
-            Temp tmp = bld.tmp(v1);
-            if (ctx->program->chip_class >= GFX8)
-               emit_wqm(bld, bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(v1), src, dpp_ctrl), tmp);
-            else
-               emit_wqm(bld,
-                        bld.ds(aco_opcode::ds_swizzle_b32, bld.def(v1), src, (1 << 15) | dpp_ctrl),
-                        tmp);
-            bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v3b), tmp);
-         } else if (instr->dest.ssa.bit_size == 16) {
-            Temp tmp = bld.tmp(v1);
-            if (ctx->program->chip_class >= GFX8)
-               emit_wqm(bld, bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(v1), src, dpp_ctrl), tmp);
-            else
-               emit_wqm(bld,
-                        bld.ds(aco_opcode::ds_swizzle_b32, bld.def(v1), src, (1 << 15) | dpp_ctrl),
-                        tmp);
-            bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp);
-         } else if (instr->dest.ssa.bit_size == 32) {
-            if (ctx->program->chip_class >= GFX8)
-               emit_wqm(bld, bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(v1), src, dpp_ctrl), dst);
-            else
-               emit_wqm(bld,
-                        bld.ds(aco_opcode::ds_swizzle_b32, bld.def(v1), src, (1 << 15) | dpp_ctrl),
-                        dst);
-         } else if (instr->dest.ssa.bit_size == 64) {
-            Temp lo = bld.tmp(v1), hi = bld.tmp(v1);
-            bld.pseudo(aco_opcode::p_split_vector, Definition(lo), Definition(hi), src);
-            if (ctx->program->chip_class >= GFX8) {
-               lo = emit_wqm(bld, bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(v1), lo, dpp_ctrl));
-               hi = emit_wqm(bld, bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(v1), hi, dpp_ctrl));
-            } else {
-               lo = emit_wqm(
-                  bld, bld.ds(aco_opcode::ds_swizzle_b32, bld.def(v1), lo, (1 << 15) | dpp_ctrl));
-               hi = emit_wqm(
-                  bld, bld.ds(aco_opcode::ds_swizzle_b32, bld.def(v1), hi, (1 << 15) | dpp_ctrl));
-            }
-            bld.pseudo(aco_opcode::p_create_vector, Definition(dst), lo, hi);
-            emit_split_vector(ctx, dst, 2);
-         } else {
-            isel_err(&instr->instr, "Unimplemented NIR instr bit size");
-         }
-      }
-      break;
-   }
+   case nir_intrinsic_quad_broadcast:
    case nir_intrinsic_quad_swap_horizontal:
    case nir_intrinsic_quad_swap_vertical:
    case nir_intrinsic_quad_swap_diagonal:
    case nir_intrinsic_quad_swizzle_amd: {
       Temp src = get_ssa_temp(ctx, instr->src[0].ssa);
+
       if (!nir_dest_is_divergent(instr->dest)) {
          emit_uniform_subgroup(ctx, instr, src);
          break;
       }
+
+      /* Quad broadcast lane. */
+      unsigned lane = 0;
+      /* Use VALU for the bool instructions that don't have a SALU-only special case. */
+      bool bool_use_valu = instr->dest.ssa.bit_size == 1;
+
       uint16_t dpp_ctrl = 0;
+
       switch (instr->intrinsic) {
       case nir_intrinsic_quad_swap_horizontal: dpp_ctrl = dpp_quad_perm(1, 0, 3, 2); break;
       case nir_intrinsic_quad_swap_vertical: dpp_ctrl = dpp_quad_perm(2, 3, 0, 1); break;
       case nir_intrinsic_quad_swap_diagonal: dpp_ctrl = dpp_quad_perm(3, 2, 1, 0); break;
       case nir_intrinsic_quad_swizzle_amd: dpp_ctrl = nir_intrinsic_swizzle_mask(instr); break;
+      case nir_intrinsic_quad_broadcast:
+         lane = nir_src_as_const_value(instr->src[1])->u32;
+         dpp_ctrl = dpp_quad_perm(lane, lane, lane, lane);
+         bool_use_valu = false;
+         break;
       default: break;
       }
-      if (ctx->program->chip_class < GFX8)
-         dpp_ctrl |= (1 << 15);
 
       Temp dst = get_ssa_temp(ctx, &instr->dest.ssa);
+      Temp tmp(dst);
 
-      if (instr->dest.ssa.bit_size != 1)
-         src = as_vgpr(ctx, src);
-
-      if (instr->dest.ssa.bit_size == 1) {
-         assert(src.regClass() == bld.lm);
+      /* Setup source. */
+      if (bool_use_valu)
          src = bld.vop2_e64(aco_opcode::v_cndmask_b32, bld.def(v1), Operand::zero(),
                             Operand::c32(-1), src);
+      else if (instr->dest.ssa.bit_size != 1)
+         src = as_vgpr(ctx, src);
+
+      /* Setup temporary destination. */
+      if (bool_use_valu)
+         tmp = bld.tmp(v1);
+      else if (ctx->program->stage == fragment_fs)
+         tmp = bld.tmp(dst.regClass());
+
+      if (instr->dest.ssa.bit_size == 1 && instr->intrinsic == nir_intrinsic_quad_broadcast) {
+         /* Special case for quad broadcast using SALU only. */
+         assert(src.regClass() == bld.lm && tmp.regClass() == bld.lm);
+
+         uint32_t half_mask = 0x11111111u << lane;
+         Operand mask_tmp = bld.lm.bytes() == 4
+                               ? Operand::c32(half_mask)
+                               : bld.pseudo(aco_opcode::p_create_vector, bld.def(bld.lm),
+                                            Operand::c32(half_mask), Operand::c32(half_mask));
+
+         src =
+            bld.sop2(Builder::s_and, bld.def(bld.lm), bld.def(s1, scc), src, Operand(exec, bld.lm));
+         src = bld.sop2(Builder::s_and, bld.def(bld.lm), bld.def(s1, scc), mask_tmp, src);
+         bld.sop1(Builder::s_wqm, Definition(tmp), src);
+      } else if (instr->dest.ssa.bit_size <= 32 || bool_use_valu) {
+         unsigned excess_bytes = bool_use_valu ? 0 : 4 - instr->dest.ssa.bit_size / 8;
+         Definition def = excess_bytes ? bld.def(v1) : Definition(tmp);
+
          if (ctx->program->chip_class >= GFX8)
-            src = bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(v1), src, dpp_ctrl);
-         else
-            src = bld.ds(aco_opcode::ds_swizzle_b32, bld.def(v1), src, dpp_ctrl);
-         Temp tmp = bld.vopc(aco_opcode::v_cmp_lg_u32, bld.def(bld.lm), Operand::zero(), src);
-         emit_wqm(bld, tmp, dst);
-      } else if (instr->dest.ssa.bit_size == 8) {
-         Temp tmp = bld.tmp(v1);
-         if (ctx->program->chip_class >= GFX8)
-            emit_wqm(bld, bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(v1), src, dpp_ctrl), tmp);
-         else
-            emit_wqm(bld, bld.ds(aco_opcode::ds_swizzle_b32, bld.def(v1), src, dpp_ctrl), tmp);
-         bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v3b), tmp);
-      } else if (instr->dest.ssa.bit_size == 16) {
-         Temp tmp = bld.tmp(v1);
-         if (ctx->program->chip_class >= GFX8)
-            emit_wqm(bld, bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(v1), src, dpp_ctrl), tmp);
-         else
-            emit_wqm(bld, bld.ds(aco_opcode::ds_swizzle_b32, bld.def(v1), src, dpp_ctrl), tmp);
-         bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp);
-      } else if (instr->dest.ssa.bit_size == 32) {
-         Temp tmp;
-         if (ctx->program->chip_class >= GFX8)
-            tmp = bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(v1), src, dpp_ctrl);
+            bld.vop1_dpp(aco_opcode::v_mov_b32, def, src, dpp_ctrl);
          else
-            tmp = bld.ds(aco_opcode::ds_swizzle_b32, bld.def(v1), src, dpp_ctrl);
-         emit_wqm(bld, tmp, dst);
+            bld.ds(aco_opcode::ds_swizzle_b32, def, src, (1 << 15) | dpp_ctrl);
+
+         if (excess_bytes)
+            bld.pseudo(aco_opcode::p_split_vector, Definition(tmp),
+                       bld.def(RegClass::get(tmp.type(), excess_bytes)), def.getTemp());
       } else if (instr->dest.ssa.bit_size == 64) {
          Temp lo = bld.tmp(v1), hi = bld.tmp(v1);
          bld.pseudo(aco_opcode::p_split_vector, Definition(lo), Definition(hi), src);
+
          if (ctx->program->chip_class >= GFX8) {
-            lo = emit_wqm(bld, bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(v1), lo, dpp_ctrl));
-            hi = emit_wqm(bld, bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(v1), hi, dpp_ctrl));
+            lo = bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(v1), lo, dpp_ctrl);
+            hi = bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(v1), hi, dpp_ctrl);
          } else {
-            lo = emit_wqm(bld, bld.ds(aco_opcode::ds_swizzle_b32, bld.def(v1), lo, dpp_ctrl));
-            hi = emit_wqm(bld, bld.ds(aco_opcode::ds_swizzle_b32, bld.def(v1), hi, dpp_ctrl));
+            lo = bld.ds(aco_opcode::ds_swizzle_b32, bld.def(v1), lo, (1 << 15) | dpp_ctrl);
+            hi = bld.ds(aco_opcode::ds_swizzle_b32, bld.def(v1), hi, (1 << 15) | dpp_ctrl);
          }
-         bld.pseudo(aco_opcode::p_create_vector, Definition(dst), lo, hi);
-         emit_split_vector(ctx, dst, 2);
+
+         bld.pseudo(aco_opcode::p_create_vector, Definition(tmp), lo, hi);
+         emit_split_vector(ctx, tmp, 2);
       } else {
-         isel_err(&instr->instr, "Unimplemented NIR instr bit size");
+         isel_err(&instr->instr, "Unimplemented NIR quad group instruction bit size.");
       }
+
+      if (tmp.id() != dst.id()) {
+         if (bool_use_valu)
+            tmp = bld.vopc(aco_opcode::v_cmp_lg_u32, bld.def(bld.lm), Operand::zero(), tmp);
+
+         /* Vulkan spec 9.25: Helper invocations must be active for quad group instructions. */
+         emit_wqm(bld, tmp, dst, true);
+      }
+
       break;
    }
    case nir_intrinsic_masked_swizzle_amd: {