microsoft/compiler: Implement a few basic wave/subgroup intrinsics
authorJesse Natalie <jenatali@microsoft.com>
Wed, 18 Jan 2023 22:12:03 +0000 (14:12 -0800)
committerMarge Bot <emma+marge@anholt.net>
Fri, 20 Jan 2023 18:50:57 +0000 (18:50 +0000)
These are the ones that map perfectly between SPIR-V and DXIL that
are in the "basic" extension group (except for read-lane-first,
but we'll use with some lowering shortly).

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/20777>

src/microsoft/compiler/dxil_function.c
src/microsoft/compiler/nir_to_dxil.c

index 3a34eec..4fdc8e3 100644 (file)
@@ -93,6 +93,10 @@ static struct  predefined_func_descr predefined_funcs[] = {
 {"dx.op.createHandleFromBinding", "@", "i#ib", DXIL_ATTR_KIND_READ_NONE},
 {"dx.op.annotateHandle", "@", "i@P", DXIL_ATTR_KIND_READ_NONE},
 {"dx.op.isHelperLane", "b", "i", DXIL_ATTR_KIND_READ_ONLY},
+{"dx.op.waveIsFirstLane", "b", "i", DXIL_ATTR_KIND_NO_UNWIND},
+{"dx.op.waveGetLaneIndex", "i", "i", DXIL_ATTR_KIND_READ_NONE},
+{"dx.op.waveGetLaneCount", "i", "i", DXIL_ATTR_KIND_READ_NONE},
+{"dx.op.waveReadLaneFirst", "O", "iO", DXIL_ATTR_KIND_NO_UNWIND},
 };
 
 struct func_descr {
index 03d0038..b2f2ebc 100644 (file)
@@ -329,6 +329,11 @@ enum dxil_intr {
    DXIL_INTR_OUTPUT_CONTROL_POINT_ID = 107,
    DXIL_INTR_PRIMITIVE_ID = 108,
 
+   DXIL_INTR_WAVE_IS_FIRST_LANE = 110,
+   DXIL_INTR_WAVE_GET_LANE_INDEX = 111,
+   DXIL_INTR_WAVE_GET_LANE_COUNT = 112,
+   DXIL_INTR_WAVE_READ_LANE_FIRST = 118,
+
    DXIL_INTR_LEGACY_F32TOF16 = 130,
    DXIL_INTR_LEGACY_F16TOF32 = 131,
 
@@ -4395,6 +4400,26 @@ emit_load_sample_id(struct ntd_context *ctx, nir_intrinsic_instr *intr)
 }
 
 static bool
+emit_read_first_invocation(struct ntd_context *ctx, nir_intrinsic_instr *intr)
+{
+   ctx->mod.feats.wave_ops = 1;
+   const struct dxil_func *func = dxil_get_function(&ctx->mod, "dx.op.waveReadLaneFirst",
+                                                    get_overload(nir_type_uint, intr->dest.ssa.bit_size));
+   const struct dxil_value *args[] = {
+      dxil_module_get_int32_const(&ctx->mod, DXIL_INTR_WAVE_READ_LANE_FIRST),
+      get_src(ctx, intr->src, 0, nir_type_uint),
+   };
+   if (!func || !args[0] || !args[1])
+      return false;
+
+   const struct dxil_value *ret = dxil_emit_call(&ctx->mod, func, args, ARRAY_SIZE(args));
+   if (!ret)
+      return false;
+   store_dest_value(ctx, &intr->dest, 0, ret);
+   return true;
+}
+
+static bool
 emit_intrinsic(struct ntd_context *ctx, nir_intrinsic_instr *intr)
 {
    switch (intr->intrinsic) {
@@ -4588,6 +4613,21 @@ emit_intrinsic(struct ntd_context *ctx, nir_intrinsic_instr *intr)
    case nir_intrinsic_is_helper_invocation:
       return emit_load_unary_external_function(
          ctx, intr, "dx.op.isHelperLane", DXIL_INTR_IS_HELPER_LANE, DXIL_I32);
+   case nir_intrinsic_elect:
+      ctx->mod.feats.wave_ops = 1;
+      return emit_load_unary_external_function(
+         ctx, intr, "dx.op.waveIsFirstLane", DXIL_INTR_WAVE_IS_FIRST_LANE, DXIL_NONE);
+   case nir_intrinsic_load_subgroup_size:
+      ctx->mod.feats.wave_ops = 1;
+      return emit_load_unary_external_function(
+         ctx, intr, "dx.op.waveGetLaneCount", DXIL_INTR_WAVE_GET_LANE_COUNT, DXIL_NONE);
+   case nir_intrinsic_load_subgroup_invocation:
+      ctx->mod.feats.wave_ops = 1;
+      return emit_load_unary_external_function(
+         ctx, intr, "dx.op.waveGetLaneIndex", DXIL_INTR_WAVE_GET_LANE_INDEX, DXIL_NONE);
+
+   case nir_intrinsic_read_first_invocation:
+      return emit_read_first_invocation(ctx, intr);
 
    case nir_intrinsic_load_num_workgroups:
    case nir_intrinsic_load_workgroup_size: