nir/zink: use sysvals in `nir_create_passthrough_gs`
[platform/upstream/mesa.git] / src / compiler / nir / nir_passthrough_gs.c
1 /*
2  * Copyright © 2022 Collabora Ltd.
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21  * SOFTWARE.
22  */
23
24 #include "nir.h"
25 #include "nir_xfb_info.h"
26 #include "nir_builder.h"
27 #include "util/u_memory.h"
28
29 static unsigned int
30 gs_in_prim_for_topology(enum shader_prim prim)
31 {
32    switch (prim) {
33    case SHADER_PRIM_QUADS:
34       return SHADER_PRIM_LINES_ADJACENCY;
35    default:
36       return prim;
37    }
38 }
39
40 static enum shader_prim
41 gs_out_prim_for_topology(enum shader_prim prim)
42 {
43    switch (prim) {
44    case SHADER_PRIM_POINTS:
45       return SHADER_PRIM_POINTS;
46    case SHADER_PRIM_LINES:
47    case SHADER_PRIM_LINE_LOOP:
48    case SHADER_PRIM_LINES_ADJACENCY:
49    case SHADER_PRIM_LINE_STRIP_ADJACENCY:
50    case SHADER_PRIM_LINE_STRIP:
51       return SHADER_PRIM_LINE_STRIP;
52    case SHADER_PRIM_TRIANGLES:
53    case SHADER_PRIM_TRIANGLE_STRIP:
54    case SHADER_PRIM_TRIANGLE_FAN:
55    case SHADER_PRIM_TRIANGLES_ADJACENCY:
56    case SHADER_PRIM_TRIANGLE_STRIP_ADJACENCY:
57    case SHADER_PRIM_POLYGON:
58       return SHADER_PRIM_TRIANGLE_STRIP;
59    case SHADER_PRIM_QUADS:
60    case SHADER_PRIM_QUAD_STRIP:
61    case SHADER_PRIM_PATCHES:
62    default:
63       return SHADER_PRIM_QUADS;
64    }
65 }
66
67 static unsigned int
68 vertices_for_prim(enum shader_prim prim)
69 {
70    switch (prim) {
71    case SHADER_PRIM_POINTS:
72       return 1;
73    case SHADER_PRIM_LINES:
74    case SHADER_PRIM_LINE_LOOP:
75    case SHADER_PRIM_LINES_ADJACENCY:
76    case SHADER_PRIM_LINE_STRIP_ADJACENCY:
77    case SHADER_PRIM_LINE_STRIP:
78       return 2;
79    case SHADER_PRIM_TRIANGLES:
80    case SHADER_PRIM_TRIANGLE_STRIP:
81    case SHADER_PRIM_TRIANGLE_FAN:
82    case SHADER_PRIM_TRIANGLES_ADJACENCY:
83    case SHADER_PRIM_TRIANGLE_STRIP_ADJACENCY:
84    case SHADER_PRIM_POLYGON:
85       return 3;
86    case SHADER_PRIM_QUADS:
87    case SHADER_PRIM_QUAD_STRIP:
88       return 4;
89    case SHADER_PRIM_PATCHES:
90    default:
91       unreachable("unsupported primitive for gs input");
92    }
93 }
94
95 static unsigned int
96 array_size_for_prim(enum shader_prim prim)
97 {
98    switch (prim) {
99    case SHADER_PRIM_POINTS:
100       return 1;
101    case SHADER_PRIM_LINES:
102    case SHADER_PRIM_LINE_LOOP:
103    case SHADER_PRIM_LINE_STRIP:
104       return 2;
105    case SHADER_PRIM_LINES_ADJACENCY:
106    case SHADER_PRIM_LINE_STRIP_ADJACENCY:
107       return 4;
108    case SHADER_PRIM_TRIANGLES:
109    case SHADER_PRIM_TRIANGLE_STRIP:
110    case SHADER_PRIM_TRIANGLE_FAN:
111    case SHADER_PRIM_POLYGON:
112       return 3;
113    case SHADER_PRIM_TRIANGLES_ADJACENCY:
114    case SHADER_PRIM_TRIANGLE_STRIP_ADJACENCY:
115       return 6;
116    case SHADER_PRIM_QUADS:
117    case SHADER_PRIM_QUAD_STRIP:
118       return 4;
119    case SHADER_PRIM_PATCHES:
120    default:
121       unreachable("unsupported primitive for gs input");
122    }
123 }
124
125 static void
126 copy_vars(nir_builder *b, nir_deref_instr *dst, nir_deref_instr *src)
127 {
128    assert(glsl_get_bare_type(dst->type) == glsl_get_bare_type(src->type));
129    if (glsl_type_is_struct(dst->type)) {
130       for (unsigned i = 0; i < glsl_get_length(dst->type); ++i) {
131          copy_vars(b, nir_build_deref_struct(b, dst, i), nir_build_deref_struct(b, src, i));
132       }
133    } else if (glsl_type_is_array_or_matrix(dst->type)) {
134       unsigned count = glsl_type_is_array(dst->type) ? glsl_array_size(dst->type) : glsl_get_matrix_columns(dst->type);
135       for (unsigned i = 0; i < count; i++) {
136          copy_vars(b, nir_build_deref_array_imm(b, dst, i), nir_build_deref_array_imm(b, src, i));
137       }
138    } else {
139       nir_ssa_def *load = nir_load_deref(b, src);
140       nir_store_deref(b, dst, load, BITFIELD_MASK(load->num_components));
141    }
142 }
143
144 /*
145  * A helper to create a passthrough GS shader for drivers that needs to lower
146  * some rendering tasks to the GS.
147  */
148
149 nir_shader *
150 nir_create_passthrough_gs(const nir_shader_compiler_options *options,
151                           const nir_shader *prev_stage,
152                           enum shader_prim primitive_type,
153                           bool emulate_edgeflags,
154                           bool force_line_strip_out)
155 {
156    unsigned int vertices_out = vertices_for_prim(primitive_type);
157    emulate_edgeflags = emulate_edgeflags && (prev_stage->info.outputs_written & VARYING_BIT_EDGE);
158    bool needs_closing = (force_line_strip_out || emulate_edgeflags) && vertices_out >= 3;
159    enum shader_prim original_our_prim = gs_out_prim_for_topology(primitive_type);
160    nir_builder b = nir_builder_init_simple_shader(MESA_SHADER_GEOMETRY,
161                                                   options,
162                                                   "gs passthrough");
163
164    nir_shader *nir = b.shader;
165    nir->info.gs.input_primitive = gs_in_prim_for_topology(primitive_type);
166    nir->info.gs.output_primitive = (force_line_strip_out || emulate_edgeflags) ?
167       SHADER_PRIM_LINE_STRIP : original_our_prim;
168    nir->info.gs.vertices_in = vertices_out;
169    nir->info.gs.vertices_out = needs_closing ? vertices_out + 1 : vertices_out;
170    nir->info.gs.invocations = 1;
171    nir->info.gs.active_stream_mask = 1;
172
173    nir->info.has_transform_feedback_varyings = prev_stage->info.has_transform_feedback_varyings;
174    memcpy(nir->info.xfb_stride, prev_stage->info.xfb_stride, sizeof(prev_stage->info.xfb_stride));
175    if (prev_stage->xfb_info) {
176       nir->xfb_info = mem_dup(prev_stage->xfb_info, sizeof(nir_xfb_info));
177    }
178
179    bool handle_flat = nir->info.gs.output_primitive == SHADER_PRIM_LINE_STRIP &&
180                       nir->info.gs.output_primitive != original_our_prim;
181    nir_variable *in_vars[VARYING_SLOT_MAX];
182    nir_variable *out_vars[VARYING_SLOT_MAX];
183    unsigned num_inputs = 0, num_outputs = 0;
184
185    /* Create input/output variables. */
186    nir_foreach_shader_out_variable(var, prev_stage) {
187       assert(!var->data.patch);
188
189       char name[100];
190       if (var->name)
191          snprintf(name, sizeof(name), "in_%s", var->name);
192       else
193          snprintf(name, sizeof(name), "in_%d", var->data.driver_location);
194
195       nir_variable *in = nir_variable_create(nir, nir_var_shader_in,
196                                              glsl_array_type(var->type,
197                                                              array_size_for_prim(primitive_type),
198                                                              false),
199                                              name);
200       in->data.location = var->data.location;
201       in->data.location_frac = var->data.location_frac;
202       in->data.driver_location = var->data.driver_location;
203       in->data.interpolation = var->data.interpolation;
204       in->data.compact = var->data.compact;
205
206       in_vars[num_inputs++] = in;
207
208       nir->num_inputs++;
209       if (in->data.location == VARYING_SLOT_EDGE)
210          continue;
211
212       if (var->data.location != VARYING_SLOT_POS)
213          nir->num_outputs++;
214
215       if (var->name)
216          snprintf(name, sizeof(name), "out_%s", var->name);
217       else
218          snprintf(name, sizeof(name), "out_%d", var->data.driver_location);
219
220       nir_variable *out = nir_variable_create(nir, nir_var_shader_out,
221                                               var->type, name);
222       out->data.location = var->data.location;
223       out->data.location_frac = var->data.location_frac;
224       out->data.driver_location = var->data.driver_location;
225       out->data.interpolation = var->data.interpolation;
226       out->data.compact = var->data.compact;
227       out->data.is_xfb = var->data.is_xfb;
228       out->data.is_xfb_only = var->data.is_xfb_only;
229       out->data.explicit_xfb_buffer = var->data.explicit_xfb_buffer;
230       out->data.explicit_xfb_stride = var->data.explicit_xfb_stride;
231       out->data.xfb = var->data.xfb;
232       out->data.offset = var->data.offset;
233
234       out_vars[num_outputs++] = out;
235    }
236
237    unsigned int start_vert = 0;
238    unsigned int end_vert = vertices_out;
239    unsigned int vert_step = 1;
240    switch (primitive_type) {
241    case PIPE_PRIM_LINES_ADJACENCY:
242    case PIPE_PRIM_LINE_STRIP_ADJACENCY:
243       start_vert = 1;
244       end_vert += 1;
245       break;
246    case PIPE_PRIM_TRIANGLES_ADJACENCY:
247    case PIPE_PRIM_TRIANGLE_STRIP_ADJACENCY:
248       end_vert = 5;
249       vert_step = 2;
250       break;
251    default:
252       break;
253    }
254
255    nir_variable *edge_var = nir_find_variable_with_location(nir, nir_var_shader_in, VARYING_SLOT_EDGE);
256    nir_ssa_def *flat_interp_mask_def = nir_load_flat_mask(&b);
257    nir_ssa_def *last_pv_vert_def = nir_load_provoking_last(&b);
258    last_pv_vert_def = nir_ine_imm(&b, last_pv_vert_def, 0);
259    nir_ssa_def *start_vert_index = nir_imm_int(&b, start_vert);
260    nir_ssa_def *end_vert_index = nir_imm_int(&b, end_vert - 1);
261    nir_ssa_def *pv_vert_index = nir_bcsel(&b, last_pv_vert_def, end_vert_index, start_vert_index);
262    for (unsigned i = start_vert; i < end_vert || needs_closing; i += vert_step) {
263       int idx = i < end_vert ? i : start_vert;
264       /* Copy inputs to outputs. */
265       for (unsigned j = 0, oj = 0, of = 0; j < num_inputs; ++j) {
266          if (in_vars[j]->data.location == VARYING_SLOT_EDGE) {
267             continue;
268          }
269          /* no need to use copy_var to save a lower pass */
270          nir_ssa_def *index;
271          if (in_vars[j]->data.location == VARYING_SLOT_POS || !handle_flat)
272             index = nir_imm_int(&b, idx);
273          else {
274             unsigned mask = 1u << (of++);
275             index = nir_bcsel(&b, nir_ieq_imm(&b, nir_iand_imm(&b, flat_interp_mask_def, mask), 0), nir_imm_int(&b, idx), pv_vert_index);
276          }
277          nir_deref_instr *value = nir_build_deref_array(&b, nir_build_deref_var(&b, in_vars[j]), index);
278          copy_vars(&b, nir_build_deref_var(&b, out_vars[oj]), value);
279          ++oj;
280       }
281       nir_emit_vertex(&b, 0);
282       if (emulate_edgeflags) {
283          nir_ssa_def *edge_value = nir_channel(&b, nir_load_array_var_imm(&b, edge_var, idx), 0);
284          nir_if *edge_if = nir_push_if(&b, nir_fneu(&b, edge_value, nir_imm_float(&b, 1.0)));
285          nir_end_primitive(&b, 0);
286          nir_pop_if(&b, edge_if);
287       }
288       if (i >= end_vert)
289          break;
290    }
291
292    nir_end_primitive(&b, 0);
293    nir_shader_gather_info(nir, nir_shader_get_entrypoint(nir));
294    nir_validate_shader(nir, "in nir_create_passthrough_gs");
295
296    return nir;
297 }