zink: Use unified atomics
author Alyssa Rosenzweig <alyssa@rosenzweig.io>
Tue, 9 May 2023 13:02:00 +0000 (09:02 -0400)
committer Marge Bot <emma+marge@anholt.net>
Fri, 12 May 2023 20:39:46 +0000 (20:39 +0000)
Signed-off-by: Alyssa Rosenzweig <alyssa@rosenzweig.io>
Reviewed-by: Jesse Natalie <jenatali@microsoft.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/22914>

src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c
src/gallium/drivers/zink/zink_compiler.c

index b162643..a66e185 100644 (file)
@@ -227,14 +227,9 @@ emit_access_decorations(struct ntv_context *ctx, nir_variable *var, SpvId var_id
 }
 
 static SpvOp
-get_atomic_op(struct ntv_context *ctx, unsigned bit_size, nir_intrinsic_op op)
+get_atomic_op(struct ntv_context *ctx, unsigned bit_size, nir_atomic_op op)
 {
    switch (op) {
-#define CASE_ATOMIC_OP(type) \
-   case nir_intrinsic_deref_atomic_##type: \
-   case nir_intrinsic_image_deref_atomic_##type: \
-   case nir_intrinsic_shared_atomic_##type
-
 #define ATOMIC_FCAP(NAME) \
    do {\
       if (bit_size == 16) \
@@ -245,41 +240,41 @@ get_atomic_op(struct ntv_context *ctx, unsigned bit_size, nir_intrinsic_op op)
          spirv_builder_emit_cap(&ctx->builder, SpvCapabilityAtomicFloat64##NAME##EXT); \
    } while (0)
 
-   CASE_ATOMIC_OP(fadd):
+   case nir_atomic_op_fadd:
       ATOMIC_FCAP(Add);
       if (bit_size == 16)
          spirv_builder_emit_extension(&ctx->builder, "SPV_EXT_shader_atomic_float16_add");
       else
          spirv_builder_emit_extension(&ctx->builder, "SPV_EXT_shader_atomic_float_add");
       return SpvOpAtomicFAddEXT;
-   CASE_ATOMIC_OP(fmax):
+   case nir_atomic_op_fmax:
       ATOMIC_FCAP(MinMax);
       spirv_builder_emit_extension(&ctx->builder, "SPV_EXT_shader_atomic_float_min_max");
       return SpvOpAtomicFMaxEXT;
-   CASE_ATOMIC_OP(fmin):
+   case nir_atomic_op_fmin:
       ATOMIC_FCAP(MinMax);
       spirv_builder_emit_extension(&ctx->builder, "SPV_EXT_shader_atomic_float_min_max");
       return SpvOpAtomicFMinEXT;
 
-   CASE_ATOMIC_OP(add):
+   case nir_atomic_op_iadd:
       return SpvOpAtomicIAdd;
-   CASE_ATOMIC_OP(umin):
+   case nir_atomic_op_umin:
       return SpvOpAtomicUMin;
-   CASE_ATOMIC_OP(imin):
+   case nir_atomic_op_imin:
       return SpvOpAtomicSMin;
-   CASE_ATOMIC_OP(umax):
+   case nir_atomic_op_umax:
       return SpvOpAtomicUMax;
-   CASE_ATOMIC_OP(imax):
+   case nir_atomic_op_imax:
       return SpvOpAtomicSMax;
-   CASE_ATOMIC_OP(and):
+   case nir_atomic_op_iand:
       return SpvOpAtomicAnd;
-   CASE_ATOMIC_OP(or):
+   case nir_atomic_op_ior:
       return SpvOpAtomicOr;
-   CASE_ATOMIC_OP(xor):
+   case nir_atomic_op_ixor:
       return SpvOpAtomicXor;
-   CASE_ATOMIC_OP(exchange):
+   case nir_atomic_op_xchg:
       return SpvOpAtomicExchange;
-   CASE_ATOMIC_OP(comp_swap):
+   case nir_atomic_op_cmpxchg:
       return SpvOpAtomicCompareExchange;
    default:
       debug_printf("%s - ", nir_intrinsic_infos[op].name);
@@ -288,21 +283,6 @@ get_atomic_op(struct ntv_context *ctx, unsigned bit_size, nir_intrinsic_op op)
    return 0;
 }
 
-static bool
-atomic_op_is_float(nir_intrinsic_op op)
-{
-   switch (op) {
-   CASE_ATOMIC_OP(fadd):
-   CASE_ATOMIC_OP(fmax):
-   CASE_ATOMIC_OP(fmin):
-      return true;
-   default:
-      break;
-   }
-   return false;
-}
-#undef CASE_ATOMIC_OP
-
 static SpvId
 emit_float_const(struct ntv_context *ctx, int bit_size, double value)
 {
@@ -2882,7 +2862,7 @@ static void
 handle_atomic_op(struct ntv_context *ctx, nir_intrinsic_instr *intr, SpvId ptr, SpvId param, SpvId param2, nir_alu_type type)
 {
    SpvId dest_type = get_dest_type(ctx, &intr->dest, type);
-   SpvId result = emit_atomic(ctx, get_atomic_op(ctx, nir_dest_bit_size(intr->dest), intr->intrinsic), dest_type, ptr, param, param2);
+   SpvId result = emit_atomic(ctx, get_atomic_op(ctx, nir_dest_bit_size(intr->dest), nir_intrinsic_atomic_op(intr)), dest_type, ptr, param, param2);
    assert(result);
    store_dest(ctx, &intr->dest, result, type);
 }
@@ -2898,10 +2878,11 @@ emit_deref_atomic_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
    if (nir_src_bit_size(intr->src[1]) == 64)
       spirv_builder_emit_cap(&ctx->builder, SpvCapabilityInt64Atomics);
 
-   if (intr->intrinsic == nir_intrinsic_deref_atomic_comp_swap)
+   if (intr->intrinsic == nir_intrinsic_deref_atomic_swap)
       param2 = get_src(ctx, &intr->src[2]);
 
-   handle_atomic_op(ctx, intr, ptr, param, param2, atomic_op_is_float(intr->intrinsic) ? nir_type_float : nir_type_uint32);
+   nir_alu_type op_type = nir_atomic_op_type(nir_intrinsic_atomic_op(intr));
+   handle_atomic_op(ctx, intr, ptr, param, param2, op_type == nir_type_float ? nir_type_float : nir_type_uint32);
 }
 
 static void
@@ -2922,10 +2903,11 @@ emit_shared_atomic_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
       spirv_builder_emit_cap(&ctx->builder, SpvCapabilityInt64Atomics);
    SpvId param2 = 0;
 
-   if (intr->intrinsic == nir_intrinsic_shared_atomic_comp_swap)
+   if (intr->intrinsic == nir_intrinsic_shared_atomic_swap)
       param2 = get_src(ctx, &intr->src[2]);
 
-   handle_atomic_op(ctx, intr, ptr, param, param2, atomic_op_is_float(intr->intrinsic) ? nir_type_float : nir_type_uint32);
+   nir_alu_type op_type = nir_atomic_op_type(nir_intrinsic_atomic_op(intr));
+   handle_atomic_op(ctx, intr, ptr, param, param2, op_type == nir_type_float ? nir_type_float : nir_type_uint32);
 }
 
 static void
@@ -3127,7 +3109,7 @@ emit_image_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
    SpvId cast_type = get_dest_type(ctx, &intr->dest, ntype);
    param = emit_bitcast(ctx, cast_type, param);
 
-   if (intr->intrinsic == nir_intrinsic_image_deref_atomic_comp_swap) {
+   if (intr->intrinsic == nir_intrinsic_image_deref_atomic_swap) {
       param2 = get_src(ctx, &intr->src[4]);
       param2 = emit_bitcast(ctx, cast_type, param2);
    }
@@ -3406,36 +3388,13 @@ emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
                                         SpvMemorySemanticsAcquireReleaseMask);
       break;
 
-   case nir_intrinsic_deref_atomic_fadd:
-   case nir_intrinsic_deref_atomic_fmin:
-   case nir_intrinsic_deref_atomic_fmax:
-   case nir_intrinsic_deref_atomic_fcomp_swap:
-   case nir_intrinsic_deref_atomic_add:
-   case nir_intrinsic_deref_atomic_umin:
-   case nir_intrinsic_deref_atomic_imin:
-   case nir_intrinsic_deref_atomic_umax:
-   case nir_intrinsic_deref_atomic_imax:
-   case nir_intrinsic_deref_atomic_and:
-   case nir_intrinsic_deref_atomic_or:
-   case nir_intrinsic_deref_atomic_xor:
-   case nir_intrinsic_deref_atomic_exchange:
-   case nir_intrinsic_deref_atomic_comp_swap:
+   case nir_intrinsic_deref_atomic:
+   case nir_intrinsic_deref_atomic_swap:
       emit_deref_atomic_intrinsic(ctx, intr);
       break;
 
-   case nir_intrinsic_shared_atomic_fadd:
-   case nir_intrinsic_shared_atomic_fmin:
-   case nir_intrinsic_shared_atomic_fmax:
-   case nir_intrinsic_shared_atomic_add:
-   case nir_intrinsic_shared_atomic_umin:
-   case nir_intrinsic_shared_atomic_imin:
-   case nir_intrinsic_shared_atomic_umax:
-   case nir_intrinsic_shared_atomic_imax:
-   case nir_intrinsic_shared_atomic_and:
-   case nir_intrinsic_shared_atomic_or:
-   case nir_intrinsic_shared_atomic_xor:
-   case nir_intrinsic_shared_atomic_exchange:
-   case nir_intrinsic_shared_atomic_comp_swap:
+   case nir_intrinsic_shared_atomic:
+   case nir_intrinsic_shared_atomic_swap:
       emit_shared_atomic_intrinsic(ctx, intr);
       break;
 
@@ -3465,16 +3424,8 @@ emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
       emit_image_deref_samples(ctx, intr);
       break;
 
-   case nir_intrinsic_image_deref_atomic_add:
-   case nir_intrinsic_image_deref_atomic_umin:
-   case nir_intrinsic_image_deref_atomic_imin:
-   case nir_intrinsic_image_deref_atomic_umax:
-   case nir_intrinsic_image_deref_atomic_imax:
-   case nir_intrinsic_image_deref_atomic_and:
-   case nir_intrinsic_image_deref_atomic_or:
-   case nir_intrinsic_image_deref_atomic_xor:
-   case nir_intrinsic_image_deref_atomic_exchange:
-   case nir_intrinsic_image_deref_atomic_comp_swap:
+   case nir_intrinsic_image_deref_atomic:
+   case nir_intrinsic_image_deref_atomic_swap:
       emit_image_intrinsic(ctx, intr);
       break;
 
index 41a7016..a74af59 100644 (file)
@@ -2151,17 +2151,8 @@ rewrite_bo_access_instr(nir_builder *b, nir_instr *instr, void *data)
    nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
    b->cursor = nir_before_instr(instr);
    switch (intr->intrinsic) {
-   case nir_intrinsic_ssbo_atomic_fadd:
-   case nir_intrinsic_ssbo_atomic_add:
-   case nir_intrinsic_ssbo_atomic_umin:
-   case nir_intrinsic_ssbo_atomic_imin:
-   case nir_intrinsic_ssbo_atomic_umax:
-   case nir_intrinsic_ssbo_atomic_imax:
-   case nir_intrinsic_ssbo_atomic_and:
-   case nir_intrinsic_ssbo_atomic_or:
-   case nir_intrinsic_ssbo_atomic_xor:
-   case nir_intrinsic_ssbo_atomic_exchange:
-   case nir_intrinsic_ssbo_atomic_comp_swap: {
+   case nir_intrinsic_ssbo_atomic:
+   case nir_intrinsic_ssbo_atomic_swap: {
       /* convert offset to uintN_t[idx] */
       nir_ssa_def *offset = nir_udiv_imm(b, intr->src[1].ssa, nir_dest_bit_size(intr->dest) / 8);
       nir_instr_rewrite_src_ssa(instr, &intr->src[1], offset);
@@ -2322,52 +2313,12 @@ rewrite_atomic_ssbo_instr(nir_builder *b, nir_instr *instr, struct bo_vars *bo)
 {
    nir_intrinsic_op op;
    nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
-   switch (intr->intrinsic) {
-   case nir_intrinsic_ssbo_atomic_fadd:
-      op = nir_intrinsic_deref_atomic_fadd;
-      break;
-   case nir_intrinsic_ssbo_atomic_fmin:
-      op = nir_intrinsic_deref_atomic_fmin;
-      break;
-   case nir_intrinsic_ssbo_atomic_fmax:
-      op = nir_intrinsic_deref_atomic_fmax;
-      break;
-   case nir_intrinsic_ssbo_atomic_fcomp_swap:
-      op = nir_intrinsic_deref_atomic_fcomp_swap;
-      break;
-   case nir_intrinsic_ssbo_atomic_add:
-      op = nir_intrinsic_deref_atomic_add;
-      break;
-   case nir_intrinsic_ssbo_atomic_umin:
-      op = nir_intrinsic_deref_atomic_umin;
-      break;
-   case nir_intrinsic_ssbo_atomic_imin:
-      op = nir_intrinsic_deref_atomic_imin;
-      break;
-   case nir_intrinsic_ssbo_atomic_umax:
-      op = nir_intrinsic_deref_atomic_umax;
-      break;
-   case nir_intrinsic_ssbo_atomic_imax:
-      op = nir_intrinsic_deref_atomic_imax;
-      break;
-   case nir_intrinsic_ssbo_atomic_and:
-      op = nir_intrinsic_deref_atomic_and;
-      break;
-   case nir_intrinsic_ssbo_atomic_or:
-      op = nir_intrinsic_deref_atomic_or;
-      break;
-   case nir_intrinsic_ssbo_atomic_xor:
-      op = nir_intrinsic_deref_atomic_xor;
-      break;
-   case nir_intrinsic_ssbo_atomic_exchange:
-      op = nir_intrinsic_deref_atomic_exchange;
-      break;
-   case nir_intrinsic_ssbo_atomic_comp_swap:
-      op = nir_intrinsic_deref_atomic_comp_swap;
-      break;
-   default:
+   if (intr->intrinsic == nir_intrinsic_ssbo_atomic)
+      op = nir_intrinsic_deref_atomic;
+   else if (intr->intrinsic == nir_intrinsic_ssbo_atomic_swap)
+      op = nir_intrinsic_deref_atomic_swap;
+   else
       unreachable("unknown intrinsic");
-   }
    nir_ssa_def *offset = intr->src[1].ssa;
    nir_src *src = &intr->src[0];
    nir_variable *var = get_bo_var(b->shader, bo, true, src, nir_dest_bit_size(intr->dest));
@@ -2385,6 +2336,7 @@ rewrite_atomic_ssbo_instr(nir_builder *b, nir_instr *instr, struct bo_vars *bo)
       nir_deref_instr *deref_arr = nir_build_deref_array(b, deref_struct, offset);
       nir_intrinsic_instr *new_instr = nir_intrinsic_instr_create(b->shader, op);
       nir_ssa_dest_init(&new_instr->instr, &new_instr->dest, 1, nir_dest_bit_size(intr->dest), "");
+      nir_intrinsic_set_atomic_op(new_instr, nir_intrinsic_atomic_op(intr));
       new_instr->src[0] = nir_src_for_ssa(&deref_arr->dest.ssa);
       /* deref ops have no offset src, so copy the srcs after it */
       for (unsigned i = 2; i < nir_intrinsic_infos[intr->intrinsic].num_srcs; i++)
@@ -2414,20 +2366,8 @@ remove_bo_access_instr(nir_builder *b, nir_instr *instr, void *data)
    nir_src *src;
    bool ssbo = true;
    switch (intr->intrinsic) {
-   case nir_intrinsic_ssbo_atomic_fadd:
-   case nir_intrinsic_ssbo_atomic_fmin:
-   case nir_intrinsic_ssbo_atomic_fmax:
-   case nir_intrinsic_ssbo_atomic_fcomp_swap:
-   case nir_intrinsic_ssbo_atomic_add:
-   case nir_intrinsic_ssbo_atomic_umin:
-   case nir_intrinsic_ssbo_atomic_imin:
-   case nir_intrinsic_ssbo_atomic_umax:
-   case nir_intrinsic_ssbo_atomic_imax:
-   case nir_intrinsic_ssbo_atomic_and:
-   case nir_intrinsic_ssbo_atomic_or:
-   case nir_intrinsic_ssbo_atomic_xor:
-   case nir_intrinsic_ssbo_atomic_exchange:
-   case nir_intrinsic_ssbo_atomic_comp_swap:
+   case nir_intrinsic_ssbo_atomic:
+   case nir_intrinsic_ssbo_atomic_swap:
       rewrite_atomic_ssbo_instr(b, instr, bo);
       return true;
    case nir_intrinsic_store_ssbo:
@@ -4019,20 +3959,8 @@ analyze_io(struct zink_shader *zs, nir_shader *shader)
             ret = true;
             break;
          }
-         case nir_intrinsic_ssbo_atomic_fadd:
-         case nir_intrinsic_ssbo_atomic_add:
-         case nir_intrinsic_ssbo_atomic_imin:
-         case nir_intrinsic_ssbo_atomic_umin:
-         case nir_intrinsic_ssbo_atomic_imax:
-         case nir_intrinsic_ssbo_atomic_umax:
-         case nir_intrinsic_ssbo_atomic_and:
-         case nir_intrinsic_ssbo_atomic_or:
-         case nir_intrinsic_ssbo_atomic_xor:
-         case nir_intrinsic_ssbo_atomic_exchange:
-         case nir_intrinsic_ssbo_atomic_comp_swap:
-         case nir_intrinsic_ssbo_atomic_fmin:
-         case nir_intrinsic_ssbo_atomic_fmax:
-         case nir_intrinsic_ssbo_atomic_fcomp_swap:
+         case nir_intrinsic_ssbo_atomic:
+         case nir_intrinsic_ssbo_atomic_swap:
          case nir_intrinsic_load_ssbo:
             zs->ssbos_used |= get_src_mask_ssbo(shader->info.num_ssbos, intrin->src[0]);
             break;
@@ -4134,21 +4062,8 @@ lower_bindless_instr(nir_builder *b, nir_instr *in, void *data)
 
    /* convert bindless intrinsics to deref intrinsics */
    switch (instr->intrinsic) {
-   OP_SWAP(atomic_add)
-   OP_SWAP(atomic_and)
-   OP_SWAP(atomic_comp_swap)
-   OP_SWAP(atomic_dec_wrap)
-   OP_SWAP(atomic_exchange)
-   OP_SWAP(atomic_fadd)
-   OP_SWAP(atomic_fmax)
-   OP_SWAP(atomic_fmin)
-   OP_SWAP(atomic_imax)
-   OP_SWAP(atomic_imin)
-   OP_SWAP(atomic_inc_wrap)
-   OP_SWAP(atomic_or)
-   OP_SWAP(atomic_umax)
-   OP_SWAP(atomic_umin)
-   OP_SWAP(atomic_xor)
+   OP_SWAP(atomic)
+   OP_SWAP(atomic_swap)
    OP_SWAP(format)
    OP_SWAP(load)
    OP_SWAP(order)
@@ -4409,17 +4324,8 @@ scan_nir(struct zink_screen *screen, nir_shader *shader, struct zink_shader *zs)
             if (intr->intrinsic == nir_intrinsic_image_deref_load ||
                 intr->intrinsic == nir_intrinsic_image_deref_sparse_load ||
                 intr->intrinsic == nir_intrinsic_image_deref_store ||
-                intr->intrinsic == nir_intrinsic_image_deref_atomic_add ||
-                intr->intrinsic == nir_intrinsic_image_deref_atomic_imin ||
-                intr->intrinsic == nir_intrinsic_image_deref_atomic_umin ||
-                intr->intrinsic == nir_intrinsic_image_deref_atomic_imax ||
-                intr->intrinsic == nir_intrinsic_image_deref_atomic_umax ||
-                intr->intrinsic == nir_intrinsic_image_deref_atomic_and ||
-                intr->intrinsic == nir_intrinsic_image_deref_atomic_or ||
-                intr->intrinsic == nir_intrinsic_image_deref_atomic_xor ||
-                intr->intrinsic == nir_intrinsic_image_deref_atomic_exchange ||
-                intr->intrinsic == nir_intrinsic_image_deref_atomic_comp_swap ||
-                intr->intrinsic == nir_intrinsic_image_deref_atomic_fadd ||
+                intr->intrinsic == nir_intrinsic_image_deref_atomic ||
+                intr->intrinsic == nir_intrinsic_image_deref_atomic_swap ||
                 intr->intrinsic == nir_intrinsic_image_deref_size ||
                 intr->intrinsic == nir_intrinsic_image_deref_samples ||
                 intr->intrinsic == nir_intrinsic_image_deref_format ||
@@ -4441,9 +4347,10 @@ scan_nir(struct zink_screen *screen, nir_shader *shader, struct zink_shader *zs)
             static bool warned = false;
             if (!screen->info.have_EXT_shader_atomic_float && !screen->is_cpu && !warned) {
                switch (intr->intrinsic) {
-               case nir_intrinsic_image_deref_atomic_add: {
+               case nir_intrinsic_image_deref_atomic: {
                   nir_variable *var = nir_intrinsic_get_var(intr, 0);
-                  if (util_format_is_float(var->data.image.format))
+                  if (nir_intrinsic_atomic_op(intr) == nir_atomic_op_iadd &&
+                      util_format_is_float(var->data.image.format))
                      fprintf(stderr, "zink: Vulkan driver missing VK_EXT_shader_atomic_float but attempting to do atomic ops!\n");
                   break;
                }
@@ -4660,17 +4567,8 @@ type_image(nir_shader *nir, nir_variable *var)
             if (intr->intrinsic == nir_intrinsic_image_deref_load ||
                intr->intrinsic == nir_intrinsic_image_deref_sparse_load ||
                intr->intrinsic == nir_intrinsic_image_deref_store ||
-               intr->intrinsic == nir_intrinsic_image_deref_atomic_add ||
-               intr->intrinsic == nir_intrinsic_image_deref_atomic_imin ||
-               intr->intrinsic == nir_intrinsic_image_deref_atomic_umin ||
-               intr->intrinsic == nir_intrinsic_image_deref_atomic_imax ||
-               intr->intrinsic == nir_intrinsic_image_deref_atomic_umax ||
-               intr->intrinsic == nir_intrinsic_image_deref_atomic_and ||
-               intr->intrinsic == nir_intrinsic_image_deref_atomic_or ||
-               intr->intrinsic == nir_intrinsic_image_deref_atomic_xor ||
-               intr->intrinsic == nir_intrinsic_image_deref_atomic_exchange ||
-               intr->intrinsic == nir_intrinsic_image_deref_atomic_comp_swap ||
-               intr->intrinsic == nir_intrinsic_image_deref_atomic_fadd ||
+               intr->intrinsic == nir_intrinsic_image_deref_atomic ||
+               intr->intrinsic == nir_intrinsic_image_deref_atomic_swap ||
                intr->intrinsic == nir_intrinsic_image_deref_samples ||
                intr->intrinsic == nir_intrinsic_image_deref_format ||
                intr->intrinsic == nir_intrinsic_image_deref_order) {
@@ -4970,6 +4868,10 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir,
                                           nir_lower_terminate_if_to_cf));
    NIR_PASS_V(nir, nir_lower_fragcolor,
          nir->info.fs.color_is_dual_source ? 1 : 8);
+
+   /* Temporary stop gap until glsl-to-nir produces unified atomics */
+   NIR_PASS_V(nir, nir_lower_legacy_atomics);
+
    NIR_PASS_V(nir, lower_64bit_vertex_attribs);
    bool needs_size = analyze_io(ret, nir);
    NIR_PASS_V(nir, unbreak_bos, ret, needs_size);