agx: Add interpolateAtOffset lowering pass
authorAlyssa Rosenzweig <alyssa@rosenzweig.io>
Fri, 4 Aug 2023 13:39:55 +0000 (09:39 -0400)
committerMarge Bot <emma+marge@anholt.net>
Fri, 11 Aug 2023 09:50:11 +0000 (09:50 +0000)
Add a lowering pass that lowers interpolation to math on the coefficient
registers. This handles interpolateAtOffset, as well as flat shading as an easy
special case.

Signed-off-by: Alyssa Rosenzweig <alyssa@rosenzweig.io>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/24498>

src/asahi/compiler/agx_nir.h
src/asahi/compiler/agx_nir_lower_interpolation.c [new file with mode: 0644]
src/asahi/compiler/meson.build

index 1984599..db1ddea 100644 (file)
@@ -9,6 +9,7 @@
 
 struct nir_shader;
 
+bool agx_nir_lower_interpolation(struct nir_shader *s);
 bool agx_nir_opt_ixor_bcsel(struct nir_shader *shader);
 bool agx_nir_lower_algebraic_late(struct nir_shader *shader);
 bool agx_nir_fuse_algebraic_late(struct nir_shader *shader);
diff --git a/src/asahi/compiler/agx_nir_lower_interpolation.c b/src/asahi/compiler/agx_nir_lower_interpolation.c
new file mode 100644 (file)
index 0000000..afc006d
--- /dev/null
@@ -0,0 +1,164 @@
+/*
+ * Copyright 2023 Valve Corporation
+ * SPDX-License-Identifier: MIT
+ */
+
+#include "compiler/shader_enums.h"
+#include "agx_nir.h"
+#include "nir.h"
+#include "nir_builder.h"
+#include "nir_builder_opcodes.h"
+#include "nir_intrinsics.h"
+#include "nir_intrinsics_indices.h"
+
+/*
+ * In AGX, the values of fragment shader inputs are represented as coefficient
+ * vectors <A, B, C>, which are dotted with <x, y, 1> to perform interpolation.
+ * x and y are relative to the tile. In other words, A and B are the
+ * screen-space partial derivatives of the input, and C is the value at the
+ * corner of the tile.
+ *
+ * For some interpolation modes, the dot product happens in the iterator
+ * hardware. Other modes are implemented in this file, by lowering to math on
+ * the coefficient vectors.
+ */
+
+/* XXX: It's not clear what this is for, but seems necessary */
+static nir_ssa_def *
+cf_valid(nir_builder *b, nir_ssa_def *cf)
+{
+   nir_ssa_def *bit =
+      nir_ieq_imm(b, nir_iand_imm(b, nir_channel(b, cf, 0), 1), 0);
+
+   /* XXX: Apple's compiler actually checks that the significand is nonzero and
+    * the exponent is 0 or 1. This is probably a typo -- it doesn't make any
+    * logical sense.  Presumably they just meant to check for denorms, so let's
+    * do that. Either way the tests pass.
+    */
+   nir_ssa_def *cf01 = nir_trim_vector(b, cf, 2);
+   return nir_ior(b, bit, nir_fisnormal(b, cf01));
+}
+
+static nir_ssa_def *
+interpolate_at_offset(nir_builder *b, nir_ssa_def *cf, nir_ssa_def *offset,
+                      bool perspective)
+{
+   /* Get the coordinate of the pixel within the tile */
+   nir_ssa_def *pixel_coords = nir_load_pixel_coord(b);
+   nir_ssa_def *tile_offs = nir_umod_imm(b, pixel_coords, 32);
+
+   /* Convert to float, getting the center of the pixel */
+   nir_ssa_def *center = nir_fadd_imm(b, nir_u2f32(b, tile_offs), 0.5);
+
+   /* Calculate the location to interpolate. offset is defined relative to the
+    * center of the pixel and is a float.
+    */
+   nir_ssa_def *pos = nir_fadd(b, center, nir_f2f32(b, offset));
+
+   /* Interpolate with the given coefficients */
+   nir_ssa_def *interp = nir_ffma(b, nir_channel(b, pos, 1),
+                                  nir_channel(b, cf, 1), nir_channel(b, cf, 2));
+
+   interp = nir_ffma(b, nir_channel(b, pos, 0), nir_channel(b, cf, 0), interp);
+
+   /* Divide by RHW. This load will be lowered recursively. */
+   if (perspective) {
+      nir_ssa_def *bary = nir_load_barycentric_at_offset(
+         b, 32, offset, .interp_mode = INTERP_MODE_NOPERSPECTIVE);
+
+      nir_ssa_def *rhw = nir_load_interpolated_input(
+         b, 1, 32, bary, nir_imm_int(b, 0), .component = 3,
+         .io_semantics = {
+            .location = VARYING_SLOT_POS,
+            .num_slots = 1,
+         });
+
+      interp = nir_fdiv(b, interp, rhw);
+   }
+
+   /* Replace invalid interpolations with the constant channel */
+   return nir_bcsel(b, cf_valid(b, cf), interp, nir_channel(b, cf, 2));
+}
+
+static nir_ssa_def *
+interpolate_flat(nir_builder *b, nir_ssa_def *coefficients)
+{
+   /* Same value anywhere, so just take the constant (affine) component */
+   return nir_channel(b, coefficients, 2);
+}
+
+static enum glsl_interp_mode
+interp_mode_for_load(nir_intrinsic_instr *load)
+{
+   if (load->intrinsic == nir_intrinsic_load_input)
+      return INTERP_MODE_FLAT;
+   else
+      return nir_intrinsic_interp_mode(nir_src_as_intrinsic(load->src[0]));
+}
+
+static bool
+needs_lower(const nir_instr *instr, UNUSED const void *_)
+{
+   if (instr->type != nir_instr_type_intrinsic)
+      return false;
+
+   const nir_intrinsic_instr *load = nir_instr_as_intrinsic(instr);
+
+   /* at_offset barycentrics need to be lowered */
+   if (load->intrinsic == nir_intrinsic_load_interpolated_input) {
+      return (nir_src_as_intrinsic(load->src[0])->intrinsic ==
+              nir_intrinsic_load_barycentric_at_offset);
+   }
+
+   /* Flat shading always lowered */
+   return (load->intrinsic == nir_intrinsic_load_input);
+}
+
+static nir_ssa_def *
+interpolate_channel(nir_builder *b, nir_intrinsic_instr *load, unsigned channel)
+{
+   nir_io_semantics sem = nir_intrinsic_io_semantics(load);
+
+   /* Indirect varyings not supported, just bias the location */
+   sem.location += nir_src_as_uint(*nir_get_io_offset_src(load));
+   sem.num_slots = 1;
+
+   nir_ssa_def *coefficients = nir_load_coefficients_agx(
+      b, .component = nir_intrinsic_component(load) + channel,
+      .interp_mode = interp_mode_for_load(load), .io_semantics = sem);
+
+   if (load->intrinsic == nir_intrinsic_load_input) {
+      assert(nir_dest_bit_size(load->dest) == 32);
+      return interpolate_flat(b, coefficients);
+   } else {
+      nir_intrinsic_instr *bary = nir_src_as_intrinsic(load->src[0]);
+
+      nir_ssa_def *interp = interpolate_at_offset(
+         b, coefficients, bary->src[0].ssa,
+         nir_intrinsic_interp_mode(bary) != INTERP_MODE_NOPERSPECTIVE);
+
+      return nir_f2fN(b, interp, nir_dest_bit_size(load->dest));
+   }
+}
+
+static nir_ssa_def *
+lower(nir_builder *b, nir_instr *instr, void *data)
+{
+   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
+
+   /* Each component is loaded separated */
+   nir_ssa_def *values[NIR_MAX_VEC_COMPONENTS] = {NULL};
+   for (unsigned i = 0; i < nir_dest_num_components(intr->dest); ++i) {
+      values[i] = interpolate_channel(b, intr, i);
+   }
+
+   return nir_vec(b, values, nir_dest_num_components(intr->dest));
+}
+
+bool
+agx_nir_lower_interpolation(nir_shader *s)
+{
+   assert(s->info.stage == MESA_SHADER_FRAGMENT);
+
+   return nir_shader_lower_instructions(s, needs_lower, lower, NULL);
+}
index 57770de..9fbc904 100644 (file)
@@ -13,6 +13,7 @@ libasahi_agx_files = files(
   'agx_nir_lower_sample_mask.c',
   'agx_nir_lower_discard_zs_emit.c',
   'agx_nir_lower_texture.c',
+  'agx_nir_lower_interpolation.c',
   'agx_nir_lower_load_mask.c',
   'agx_nir_lower_shared_bitsize.c',
   'agx_nir_lower_ubo.c',