nir/lower_int64: fix shift lowering
[platform/upstream/mesa.git] / src / compiler / nir / nir_lower_int64.c
1 /*
2  * Copyright © 2016 Intel Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23
24 #include "nir.h"
25 #include "nir_builder.h"
26
27 #define COND_LOWER_OP(b, name, ...)                                   \
28         (b->shader->options->lower_int64_options &                    \
29          nir_lower_int64_op_to_options_mask(nir_op_##name)) ?         \
30         lower_##name##64(b, __VA_ARGS__) : nir_##name(b, __VA_ARGS__)
31
32 #define COND_LOWER_CMP(b, name, ...)                                  \
33         (b->shader->options->lower_int64_options &                    \
34          nir_lower_int64_op_to_options_mask(nir_op_##name)) ?         \
35         lower_int64_compare(b, nir_op_##name, __VA_ARGS__) :          \
36         nir_##name(b, __VA_ARGS__)
37
38 #define COND_LOWER_CAST(b, name, ...)                                 \
39         (b->shader->options->lower_int64_options &                    \
40          nir_lower_int64_op_to_options_mask(nir_op_##name)) ?         \
41         lower_##name(b, __VA_ARGS__) :                                \
42         nir_##name(b, __VA_ARGS__)
43
44 static nir_ssa_def *
45 lower_b2i64(nir_builder *b, nir_ssa_def *x)
46 {
47    return nir_pack_64_2x32_split(b, nir_b2i32(b, x), nir_imm_int(b, 0));
48 }
49
50 static nir_ssa_def *
51 lower_i2b(nir_builder *b, nir_ssa_def *x)
52 {
53    return nir_ine(b, nir_ior(b, nir_unpack_64_2x32_split_x(b, x),
54                                 nir_unpack_64_2x32_split_y(b, x)),
55                      nir_imm_int(b, 0));
56 }
57
58 static nir_ssa_def *
59 lower_i2i8(nir_builder *b, nir_ssa_def *x)
60 {
61    return nir_i2i8(b, nir_unpack_64_2x32_split_x(b, x));
62 }
63
64 static nir_ssa_def *
65 lower_i2i16(nir_builder *b, nir_ssa_def *x)
66 {
67    return nir_i2i16(b, nir_unpack_64_2x32_split_x(b, x));
68 }
69
70
71 static nir_ssa_def *
72 lower_i2i32(nir_builder *b, nir_ssa_def *x)
73 {
74    return nir_unpack_64_2x32_split_x(b, x);
75 }
76
77 static nir_ssa_def *
78 lower_i2i64(nir_builder *b, nir_ssa_def *x)
79 {
80    nir_ssa_def *x32 = x->bit_size == 32 ? x : nir_i2i32(b, x);
81    return nir_pack_64_2x32_split(b, x32, nir_ishr_imm(b, x32, 31));
82 }
83
84 static nir_ssa_def *
85 lower_u2u8(nir_builder *b, nir_ssa_def *x)
86 {
87    return nir_u2u8(b, nir_unpack_64_2x32_split_x(b, x));
88 }
89
90 static nir_ssa_def *
91 lower_u2u16(nir_builder *b, nir_ssa_def *x)
92 {
93    return nir_u2u16(b, nir_unpack_64_2x32_split_x(b, x));
94 }
95
96 static nir_ssa_def *
97 lower_u2u32(nir_builder *b, nir_ssa_def *x)
98 {
99    return nir_unpack_64_2x32_split_x(b, x);
100 }
101
102 static nir_ssa_def *
103 lower_u2u64(nir_builder *b, nir_ssa_def *x)
104 {
105    nir_ssa_def *x32 = x->bit_size == 32 ? x : nir_u2u32(b, x);
106    return nir_pack_64_2x32_split(b, x32, nir_imm_int(b, 0));
107 }
108
109 static nir_ssa_def *
110 lower_bcsel64(nir_builder *b, nir_ssa_def *cond, nir_ssa_def *x, nir_ssa_def *y)
111 {
112    nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
113    nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
114    nir_ssa_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
115    nir_ssa_def *y_hi = nir_unpack_64_2x32_split_y(b, y);
116
117    return nir_pack_64_2x32_split(b, nir_bcsel(b, cond, x_lo, y_lo),
118                                     nir_bcsel(b, cond, x_hi, y_hi));
119 }
120
121 static nir_ssa_def *
122 lower_inot64(nir_builder *b, nir_ssa_def *x)
123 {
124    nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
125    nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
126
127    return nir_pack_64_2x32_split(b, nir_inot(b, x_lo), nir_inot(b, x_hi));
128 }
129
130 static nir_ssa_def *
131 lower_iand64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
132 {
133    nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
134    nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
135    nir_ssa_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
136    nir_ssa_def *y_hi = nir_unpack_64_2x32_split_y(b, y);
137
138    return nir_pack_64_2x32_split(b, nir_iand(b, x_lo, y_lo),
139                                     nir_iand(b, x_hi, y_hi));
140 }
141
142 static nir_ssa_def *
143 lower_ior64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
144 {
145    nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
146    nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
147    nir_ssa_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
148    nir_ssa_def *y_hi = nir_unpack_64_2x32_split_y(b, y);
149
150    return nir_pack_64_2x32_split(b, nir_ior(b, x_lo, y_lo),
151                                     nir_ior(b, x_hi, y_hi));
152 }
153
154 static nir_ssa_def *
155 lower_ixor64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
156 {
157    nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
158    nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
159    nir_ssa_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
160    nir_ssa_def *y_hi = nir_unpack_64_2x32_split_y(b, y);
161
162    return nir_pack_64_2x32_split(b, nir_ixor(b, x_lo, y_lo),
163                                     nir_ixor(b, x_hi, y_hi));
164 }
165
166 static nir_ssa_def *
167 lower_ishl64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
168 {
169    /* Implemented as
170     *
171     * uint64_t lshift(uint64_t x, int c)
172     * {
173     *    c %= 64;
174     *
175     *    if (c == 0) return x;
176     *
177     *    uint32_t lo = LO(x), hi = HI(x);
178     *
179     *    if (c < 32) {
180     *       uint32_t lo_shifted = lo << c;
181     *       uint32_t hi_shifted = hi << c;
182     *       uint32_t lo_shifted_hi = lo >> abs(32 - c);
183     *       return pack_64(lo_shifted, hi_shifted | lo_shifted_hi);
184     *    } else {
185     *       uint32_t lo_shifted_hi = lo << abs(32 - c);
186     *       return pack_64(0, lo_shifted_hi);
187     *    }
188     * }
189     */
190    nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
191    nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
192    y = nir_iand_imm(b, y, 0x3f);
193
194    nir_ssa_def *reverse_count = nir_iabs(b, nir_iadd(b, y, nir_imm_int(b, -32)));
195    nir_ssa_def *lo_shifted = nir_ishl(b, x_lo, y);
196    nir_ssa_def *hi_shifted = nir_ishl(b, x_hi, y);
197    nir_ssa_def *lo_shifted_hi = nir_ushr(b, x_lo, reverse_count);
198
199    nir_ssa_def *res_if_lt_32 =
200       nir_pack_64_2x32_split(b, lo_shifted,
201                                 nir_ior(b, hi_shifted, lo_shifted_hi));
202    nir_ssa_def *res_if_ge_32 =
203       nir_pack_64_2x32_split(b, nir_imm_int(b, 0),
204                                 nir_ishl(b, x_lo, reverse_count));
205
206    return nir_bcsel(b, nir_ieq_imm(b, y, 0), x,
207                     nir_bcsel(b, nir_uge(b, y, nir_imm_int(b, 32)),
208                                  res_if_ge_32, res_if_lt_32));
209 }
210
211 static nir_ssa_def *
212 lower_ishr64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
213 {
214    /* Implemented as
215     *
216     * uint64_t arshift(uint64_t x, int c)
217     * {
218     *    c %= 64;
219     *
220     *    if (c == 0) return x;
221     *
222     *    uint32_t lo = LO(x);
223     *    int32_t  hi = HI(x);
224     *
225     *    if (c < 32) {
226     *       uint32_t lo_shifted = lo >> c;
227     *       uint32_t hi_shifted = hi >> c;
228     *       uint32_t hi_shifted_lo = hi << abs(32 - c);
229     *       return pack_64(hi_shifted, hi_shifted_lo | lo_shifted);
230     *    } else {
231     *       uint32_t hi_shifted = hi >> 31;
232     *       uint32_t hi_shifted_lo = hi >> abs(32 - c);
233     *       return pack_64(hi_shifted, hi_shifted_lo);
234     *    }
235     * }
236     */
237    nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
238    nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
239    y = nir_iand_imm(b, y, 0x3f);
240
241    nir_ssa_def *reverse_count = nir_iabs(b, nir_iadd(b, y, nir_imm_int(b, -32)));
242    nir_ssa_def *lo_shifted = nir_ushr(b, x_lo, y);
243    nir_ssa_def *hi_shifted = nir_ishr(b, x_hi, y);
244    nir_ssa_def *hi_shifted_lo = nir_ishl(b, x_hi, reverse_count);
245
246    nir_ssa_def *res_if_lt_32 =
247       nir_pack_64_2x32_split(b, nir_ior(b, lo_shifted, hi_shifted_lo),
248                                 hi_shifted);
249    nir_ssa_def *res_if_ge_32 =
250       nir_pack_64_2x32_split(b, nir_ishr(b, x_hi, reverse_count),
251                                 nir_ishr(b, x_hi, nir_imm_int(b, 31)));
252
253    return nir_bcsel(b, nir_ieq_imm(b, y, 0), x,
254                     nir_bcsel(b, nir_uge(b, y, nir_imm_int(b, 32)),
255                                  res_if_ge_32, res_if_lt_32));
256 }
257
258 static nir_ssa_def *
259 lower_ushr64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
260 {
261    /* Implemented as
262     *
263     * uint64_t rshift(uint64_t x, int c)
264     * {
265     *    c %= 64;
266     *
267     *    if (c == 0) return x;
268     *
269     *    uint32_t lo = LO(x), hi = HI(x);
270     *
271     *    if (c < 32) {
272     *       uint32_t lo_shifted = lo >> c;
273     *       uint32_t hi_shifted = hi >> c;
274     *       uint32_t hi_shifted_lo = hi << abs(32 - c);
275     *       return pack_64(hi_shifted, hi_shifted_lo | lo_shifted);
276     *    } else {
277     *       uint32_t hi_shifted_lo = hi >> abs(32 - c);
278     *       return pack_64(0, hi_shifted_lo);
279     *    }
280     * }
281     */
282
283    nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
284    nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
285    y = nir_iand_imm(b, y, 0x3f);
286
287    nir_ssa_def *reverse_count = nir_iabs(b, nir_iadd(b, y, nir_imm_int(b, -32)));
288    nir_ssa_def *lo_shifted = nir_ushr(b, x_lo, y);
289    nir_ssa_def *hi_shifted = nir_ushr(b, x_hi, y);
290    nir_ssa_def *hi_shifted_lo = nir_ishl(b, x_hi, reverse_count);
291
292    nir_ssa_def *res_if_lt_32 =
293       nir_pack_64_2x32_split(b, nir_ior(b, lo_shifted, hi_shifted_lo),
294                                 hi_shifted);
295    nir_ssa_def *res_if_ge_32 =
296       nir_pack_64_2x32_split(b, nir_ushr(b, x_hi, reverse_count),
297                                 nir_imm_int(b, 0));
298
299    return nir_bcsel(b, nir_ieq_imm(b, y, 0), x,
300                     nir_bcsel(b, nir_uge(b, y, nir_imm_int(b, 32)),
301                                  res_if_ge_32, res_if_lt_32));
302 }
303
304 static nir_ssa_def *
305 lower_iadd64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
306 {
307    nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
308    nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
309    nir_ssa_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
310    nir_ssa_def *y_hi = nir_unpack_64_2x32_split_y(b, y);
311
312    nir_ssa_def *res_lo = nir_iadd(b, x_lo, y_lo);
313    nir_ssa_def *carry = nir_b2i32(b, nir_ult(b, res_lo, x_lo));
314    nir_ssa_def *res_hi = nir_iadd(b, carry, nir_iadd(b, x_hi, y_hi));
315
316    return nir_pack_64_2x32_split(b, res_lo, res_hi);
317 }
318
319 static nir_ssa_def *
320 lower_isub64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
321 {
322    nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
323    nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
324    nir_ssa_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
325    nir_ssa_def *y_hi = nir_unpack_64_2x32_split_y(b, y);
326
327    nir_ssa_def *res_lo = nir_isub(b, x_lo, y_lo);
328    nir_ssa_def *borrow = nir_ineg(b, nir_b2i32(b, nir_ult(b, x_lo, y_lo)));
329    nir_ssa_def *res_hi = nir_iadd(b, nir_isub(b, x_hi, y_hi), borrow);
330
331    return nir_pack_64_2x32_split(b, res_lo, res_hi);
332 }
333
334 static nir_ssa_def *
335 lower_ineg64(nir_builder *b, nir_ssa_def *x)
336 {
337    /* Since isub is the same number of instructions (with better dependencies)
338     * as iadd, subtraction is actually more efficient for ineg than the usual
339     * 2's complement "flip the bits and add one".
340     */
341    return lower_isub64(b, nir_imm_int64(b, 0), x);
342 }
343
344 static nir_ssa_def *
345 lower_iabs64(nir_builder *b, nir_ssa_def *x)
346 {
347    nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
348    nir_ssa_def *x_is_neg = nir_ilt(b, x_hi, nir_imm_int(b, 0));
349    return nir_bcsel(b, x_is_neg, nir_ineg(b, x), x);
350 }
351
352 static nir_ssa_def *
353 lower_int64_compare(nir_builder *b, nir_op op, nir_ssa_def *x, nir_ssa_def *y)
354 {
355    nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
356    nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
357    nir_ssa_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
358    nir_ssa_def *y_hi = nir_unpack_64_2x32_split_y(b, y);
359
360    switch (op) {
361    case nir_op_ieq:
362       return nir_iand(b, nir_ieq(b, x_hi, y_hi), nir_ieq(b, x_lo, y_lo));
363    case nir_op_ine:
364       return nir_ior(b, nir_ine(b, x_hi, y_hi), nir_ine(b, x_lo, y_lo));
365    case nir_op_ult:
366       return nir_ior(b, nir_ult(b, x_hi, y_hi),
367                         nir_iand(b, nir_ieq(b, x_hi, y_hi),
368                                     nir_ult(b, x_lo, y_lo)));
369    case nir_op_ilt:
370       return nir_ior(b, nir_ilt(b, x_hi, y_hi),
371                         nir_iand(b, nir_ieq(b, x_hi, y_hi),
372                                     nir_ult(b, x_lo, y_lo)));
373       break;
374    case nir_op_uge:
375       /* Lower as !(x < y) in the hopes of better CSE */
376       return nir_inot(b, lower_int64_compare(b, nir_op_ult, x, y));
377    case nir_op_ige:
378       /* Lower as !(x < y) in the hopes of better CSE */
379       return nir_inot(b, lower_int64_compare(b, nir_op_ilt, x, y));
380    default:
381       unreachable("Invalid comparison");
382    }
383 }
384
385 static nir_ssa_def *
386 lower_umax64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
387 {
388    return nir_bcsel(b, lower_int64_compare(b, nir_op_ult, x, y), y, x);
389 }
390
391 static nir_ssa_def *
392 lower_imax64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
393 {
394    return nir_bcsel(b, lower_int64_compare(b, nir_op_ilt, x, y), y, x);
395 }
396
397 static nir_ssa_def *
398 lower_umin64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
399 {
400    return nir_bcsel(b, lower_int64_compare(b, nir_op_ult, x, y), x, y);
401 }
402
403 static nir_ssa_def *
404 lower_imin64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
405 {
406    return nir_bcsel(b, lower_int64_compare(b, nir_op_ilt, x, y), x, y);
407 }
408
409 static nir_ssa_def *
410 lower_mul_2x32_64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y,
411                   bool sign_extend)
412 {
413    nir_ssa_def *res_hi = sign_extend ? nir_imul_high(b, x, y)
414                                      : nir_umul_high(b, x, y);
415
416    return nir_pack_64_2x32_split(b, nir_imul(b, x, y), res_hi);
417 }
418
419 static nir_ssa_def *
420 lower_imul64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
421 {
422    nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
423    nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
424    nir_ssa_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
425    nir_ssa_def *y_hi = nir_unpack_64_2x32_split_y(b, y);
426
427    nir_ssa_def *mul_lo = nir_umul_2x32_64(b, x_lo, y_lo);
428    nir_ssa_def *res_hi = nir_iadd(b, nir_unpack_64_2x32_split_y(b, mul_lo),
429                          nir_iadd(b, nir_imul(b, x_lo, y_hi),
430                                      nir_imul(b, x_hi, y_lo)));
431
432    return nir_pack_64_2x32_split(b, nir_unpack_64_2x32_split_x(b, mul_lo),
433                                  res_hi);
434 }
435
436 static nir_ssa_def *
437 lower_mul_high64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y,
438                  bool sign_extend)
439 {
440    nir_ssa_def *x32[4], *y32[4];
441    x32[0] = nir_unpack_64_2x32_split_x(b, x);
442    x32[1] = nir_unpack_64_2x32_split_y(b, x);
443    if (sign_extend) {
444       x32[2] = x32[3] = nir_ishr_imm(b, x32[1], 31);
445    } else {
446       x32[2] = x32[3] = nir_imm_int(b, 0);
447    }
448
449    y32[0] = nir_unpack_64_2x32_split_x(b, y);
450    y32[1] = nir_unpack_64_2x32_split_y(b, y);
451    if (sign_extend) {
452       y32[2] = y32[3] = nir_ishr_imm(b, y32[1], 31);
453    } else {
454       y32[2] = y32[3] = nir_imm_int(b, 0);
455    }
456
457    nir_ssa_def *res[8] = { NULL, };
458
459    /* Yes, the following generates a pile of code.  However, we throw res[0]
460     * and res[1] away in the end and, if we're in the umul case, four of our
461     * eight dword operands will be constant zero and opt_algebraic will clean
462     * this up nicely.
463     */
464    for (unsigned i = 0; i < 4; i++) {
465       nir_ssa_def *carry = NULL;
466       for (unsigned j = 0; j < 4; j++) {
467          /* The maximum values of x32[i] and y32[j] are UINT32_MAX so the
468           * maximum value of tmp is UINT32_MAX * UINT32_MAX.  The maximum
469           * value that will fit in tmp is
470           *
471           *    UINT64_MAX = UINT32_MAX << 32 + UINT32_MAX
472           *               = UINT32_MAX * (UINT32_MAX + 1) + UINT32_MAX
473           *               = UINT32_MAX * UINT32_MAX + 2 * UINT32_MAX
474           *
475           * so we're guaranteed that we can add in two more 32-bit values
476           * without overflowing tmp.
477           */
478          nir_ssa_def *tmp = nir_umul_2x32_64(b, x32[i], y32[j]);
479
480          if (res[i + j])
481             tmp = nir_iadd(b, tmp, nir_u2u64(b, res[i + j]));
482          if (carry)
483             tmp = nir_iadd(b, tmp, carry);
484          res[i + j] = nir_u2u32(b, tmp);
485          carry = nir_ushr_imm(b, tmp, 32);
486       }
487       res[i + 4] = nir_u2u32(b, carry);
488    }
489
490    return nir_pack_64_2x32_split(b, res[2], res[3]);
491 }
492
493 static nir_ssa_def *
494 lower_isign64(nir_builder *b, nir_ssa_def *x)
495 {
496    nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
497    nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
498
499    nir_ssa_def *is_non_zero = nir_i2b(b, nir_ior(b, x_lo, x_hi));
500    nir_ssa_def *res_hi = nir_ishr_imm(b, x_hi, 31);
501    nir_ssa_def *res_lo = nir_ior(b, res_hi, nir_b2i32(b, is_non_zero));
502
503    return nir_pack_64_2x32_split(b, res_lo, res_hi);
504 }
505
506 static void
507 lower_udiv64_mod64(nir_builder *b, nir_ssa_def *n, nir_ssa_def *d,
508                    nir_ssa_def **q, nir_ssa_def **r)
509 {
510    /* TODO: We should specially handle the case where the denominator is a
511     * constant.  In that case, we should be able to reduce it to a multiply by
512     * a constant, some shifts, and an add.
513     */
514    nir_ssa_def *n_lo = nir_unpack_64_2x32_split_x(b, n);
515    nir_ssa_def *n_hi = nir_unpack_64_2x32_split_y(b, n);
516    nir_ssa_def *d_lo = nir_unpack_64_2x32_split_x(b, d);
517    nir_ssa_def *d_hi = nir_unpack_64_2x32_split_y(b, d);
518
519    nir_ssa_def *q_lo = nir_imm_zero(b, n->num_components, 32);
520    nir_ssa_def *q_hi = nir_imm_zero(b, n->num_components, 32);
521
522    nir_ssa_def *n_hi_before_if = n_hi;
523    nir_ssa_def *q_hi_before_if = q_hi;
524
525    /* If the upper 32 bits of denom are non-zero, it is impossible for shifts
526     * greater than 32 bits to occur.  If the upper 32 bits of the numerator
527     * are zero, it is impossible for (denom << [63, 32]) <= numer unless
528     * denom == 0.
529     */
530    nir_ssa_def *need_high_div =
531       nir_iand(b, nir_ieq_imm(b, d_hi, 0), nir_uge(b, n_hi, d_lo));
532    nir_push_if(b, nir_bany(b, need_high_div));
533    {
534       /* If we only have one component, then the bany above goes away and
535        * this is always true within the if statement.
536        */
537       if (n->num_components == 1)
538          need_high_div = nir_imm_true(b);
539
540       nir_ssa_def *log2_d_lo = nir_ufind_msb(b, d_lo);
541
542       for (int i = 31; i >= 0; i--) {
543          /* if ((d.x << i) <= n.y) {
544           *    n.y -= d.x << i;
545           *    quot.y |= 1U << i;
546           * }
547           */
548          nir_ssa_def *d_shift = nir_ishl(b, d_lo, nir_imm_int(b, i));
549          nir_ssa_def *new_n_hi = nir_isub(b, n_hi, d_shift);
550          nir_ssa_def *new_q_hi = nir_ior(b, q_hi, nir_imm_int(b, 1u << i));
551          nir_ssa_def *cond = nir_iand(b, need_high_div,
552                                          nir_uge(b, n_hi, d_shift));
553          if (i != 0) {
554             /* log2_d_lo is always <= 31, so we don't need to bother with it
555              * in the last iteration.
556              */
557             cond = nir_iand(b, cond,
558                                nir_ige(b, nir_imm_int(b, 31 - i), log2_d_lo));
559          }
560          n_hi = nir_bcsel(b, cond, new_n_hi, n_hi);
561          q_hi = nir_bcsel(b, cond, new_q_hi, q_hi);
562       }
563    }
564    nir_pop_if(b, NULL);
565    n_hi = nir_if_phi(b, n_hi, n_hi_before_if);
566    q_hi = nir_if_phi(b, q_hi, q_hi_before_if);
567
568    nir_ssa_def *log2_denom = nir_ufind_msb(b, d_hi);
569
570    n = nir_pack_64_2x32_split(b, n_lo, n_hi);
571    d = nir_pack_64_2x32_split(b, d_lo, d_hi);
572    for (int i = 31; i >= 0; i--) {
573       /* if ((d64 << i) <= n64) {
574        *    n64 -= d64 << i;
575        *    quot.x |= 1U << i;
576        * }
577        */
578       nir_ssa_def *d_shift = nir_ishl(b, d, nir_imm_int(b, i));
579       nir_ssa_def *new_n = nir_isub(b, n, d_shift);
580       nir_ssa_def *new_q_lo = nir_ior(b, q_lo, nir_imm_int(b, 1u << i));
581       nir_ssa_def *cond = nir_uge(b, n, d_shift);
582       if (i != 0) {
583          /* log2_denom is always <= 31, so we don't need to bother with it
584           * in the last iteration.
585           */
586          cond = nir_iand(b, cond,
587                             nir_ige(b, nir_imm_int(b, 31 - i), log2_denom));
588       }
589       n = nir_bcsel(b, cond, new_n, n);
590       q_lo = nir_bcsel(b, cond, new_q_lo, q_lo);
591    }
592
593    *q = nir_pack_64_2x32_split(b, q_lo, q_hi);
594    *r = n;
595 }
596
597 static nir_ssa_def *
598 lower_udiv64(nir_builder *b, nir_ssa_def *n, nir_ssa_def *d)
599 {
600    nir_ssa_def *q, *r;
601    lower_udiv64_mod64(b, n, d, &q, &r);
602    return q;
603 }
604
605 static nir_ssa_def *
606 lower_idiv64(nir_builder *b, nir_ssa_def *n, nir_ssa_def *d)
607 {
608    nir_ssa_def *n_hi = nir_unpack_64_2x32_split_y(b, n);
609    nir_ssa_def *d_hi = nir_unpack_64_2x32_split_y(b, d);
610
611    nir_ssa_def *negate = nir_ine(b, nir_ilt(b, n_hi, nir_imm_int(b, 0)),
612                                     nir_ilt(b, d_hi, nir_imm_int(b, 0)));
613    nir_ssa_def *q, *r;
614    lower_udiv64_mod64(b, nir_iabs(b, n), nir_iabs(b, d), &q, &r);
615    return nir_bcsel(b, negate, nir_ineg(b, q), q);
616 }
617
618 static nir_ssa_def *
619 lower_umod64(nir_builder *b, nir_ssa_def *n, nir_ssa_def *d)
620 {
621    nir_ssa_def *q, *r;
622    lower_udiv64_mod64(b, n, d, &q, &r);
623    return r;
624 }
625
626 static nir_ssa_def *
627 lower_imod64(nir_builder *b, nir_ssa_def *n, nir_ssa_def *d)
628 {
629    nir_ssa_def *n_hi = nir_unpack_64_2x32_split_y(b, n);
630    nir_ssa_def *d_hi = nir_unpack_64_2x32_split_y(b, d);
631    nir_ssa_def *n_is_neg = nir_ilt(b, n_hi, nir_imm_int(b, 0));
632    nir_ssa_def *d_is_neg = nir_ilt(b, d_hi, nir_imm_int(b, 0));
633
634    nir_ssa_def *q, *r;
635    lower_udiv64_mod64(b, nir_iabs(b, n), nir_iabs(b, d), &q, &r);
636
637    nir_ssa_def *rem = nir_bcsel(b, n_is_neg, nir_ineg(b, r), r);
638
639    return nir_bcsel(b, nir_ieq_imm(b, r, 0), nir_imm_int64(b, 0),
640           nir_bcsel(b, nir_ieq(b, n_is_neg, d_is_neg), rem,
641                        nir_iadd(b, rem, d)));
642 }
643
644 static nir_ssa_def *
645 lower_irem64(nir_builder *b, nir_ssa_def *n, nir_ssa_def *d)
646 {
647    nir_ssa_def *n_hi = nir_unpack_64_2x32_split_y(b, n);
648    nir_ssa_def *n_is_neg = nir_ilt(b, n_hi, nir_imm_int(b, 0));
649
650    nir_ssa_def *q, *r;
651    lower_udiv64_mod64(b, nir_iabs(b, n), nir_iabs(b, d), &q, &r);
652    return nir_bcsel(b, n_is_neg, nir_ineg(b, r), r);
653 }
654
655 static nir_ssa_def *
656 lower_extract(nir_builder *b, nir_op op, nir_ssa_def *x, nir_ssa_def *c)
657 {
658    assert(op == nir_op_extract_u8 || op == nir_op_extract_i8 ||
659           op == nir_op_extract_u16 || op == nir_op_extract_i16);
660
661    const int chunk = nir_src_as_uint(nir_src_for_ssa(c));
662    const int chunk_bits =
663       (op == nir_op_extract_u8 || op == nir_op_extract_i8) ? 8 : 16;
664    const int num_chunks_in_32 = 32 / chunk_bits;
665
666    nir_ssa_def *extract32;
667    if (chunk < num_chunks_in_32) {
668       extract32 = nir_build_alu(b, op, nir_unpack_64_2x32_split_x(b, x),
669                                    nir_imm_int(b, chunk),
670                                    NULL, NULL);
671    } else {
672       extract32 = nir_build_alu(b, op, nir_unpack_64_2x32_split_y(b, x),
673                                    nir_imm_int(b, chunk - num_chunks_in_32),
674                                    NULL, NULL);
675    }
676
677    if (op == nir_op_extract_i8 || op == nir_op_extract_i16)
678       return lower_i2i64(b, extract32);
679    else
680       return lower_u2u64(b, extract32);
681 }
682
683 static nir_ssa_def *
684 lower_ufind_msb64(nir_builder *b, nir_ssa_def *x)
685 {
686
687    nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
688    nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
689    nir_ssa_def *lo_count = nir_ufind_msb(b, x_lo);
690    nir_ssa_def *hi_count = nir_ufind_msb(b, x_hi);
691    nir_ssa_def *valid_hi_bits = nir_ine(b, x_hi, nir_imm_int(b, 0));
692    nir_ssa_def *hi_res = nir_iadd(b, nir_imm_intN_t(b, 32, 32), hi_count);
693    return nir_bcsel(b, valid_hi_bits, hi_res, lo_count);
694 }
695
696 static nir_ssa_def *
697 lower_2f(nir_builder *b, nir_ssa_def *x, unsigned dest_bit_size,
698          bool src_is_signed)
699 {
700    nir_ssa_def *x_sign = NULL;
701
702    if (src_is_signed) {
703       x_sign = nir_bcsel(b, COND_LOWER_CMP(b, ilt, x, nir_imm_int64(b, 0)),
704                          nir_imm_floatN_t(b, -1, dest_bit_size),
705                          nir_imm_floatN_t(b, 1, dest_bit_size));
706       x = COND_LOWER_OP(b, iabs, x);
707    }
708
709    nir_ssa_def *exp = COND_LOWER_OP(b, ufind_msb, x);
710    unsigned significand_bits;
711
712    switch (dest_bit_size) {
713    case 64:
714       significand_bits = 52;
715       break;
716    case 32:
717       significand_bits = 23;
718       break;
719    case 16:
720       significand_bits = 10;
721       break;
722    default:
723       unreachable("Invalid dest_bit_size");
724    }
725
726    nir_ssa_def *discard =
727       nir_imax(b, nir_isub(b, exp, nir_imm_int(b, significand_bits)),
728                   nir_imm_int(b, 0));
729    nir_ssa_def *significand = COND_LOWER_OP(b, ushr, x, discard);
730    if (significand_bits < 32)
731       significand = COND_LOWER_CAST(b, u2u32, significand);
732
733    /* Round-to-nearest-even implementation:
734     * - if the non-representable part of the significand is higher than half
735     *   the minimum representable significand, we round-up
736     * - if the non-representable part of the significand is equal to half the
737     *   minimum representable significand and the representable part of the
738     *   significand is odd, we round-up
739     * - in any other case, we round-down
740     */
741    nir_ssa_def *lsb_mask = COND_LOWER_OP(b, ishl, nir_imm_int64(b, 1), discard);
742    nir_ssa_def *rem_mask = COND_LOWER_OP(b, isub, lsb_mask, nir_imm_int64(b, 1));
743    nir_ssa_def *half = COND_LOWER_OP(b, ishr, lsb_mask, nir_imm_int(b, 1));
744    nir_ssa_def *rem = COND_LOWER_OP(b, iand, x, rem_mask);
745    nir_ssa_def *halfway = nir_iand(b, COND_LOWER_CMP(b, ieq, rem, half),
746                                    nir_ine(b, discard, nir_imm_int(b, 0)));
747    nir_ssa_def *is_odd = COND_LOWER_CMP(b, ine, nir_imm_int64(b, 0),
748                                          COND_LOWER_OP(b, iand, x, lsb_mask));
749    nir_ssa_def *round_up = nir_ior(b, COND_LOWER_CMP(b, ilt, half, rem),
750                                    nir_iand(b, halfway, is_odd));
751    if (significand_bits >= 32)
752       significand = COND_LOWER_OP(b, iadd, significand,
753                                   COND_LOWER_CAST(b, b2i64, round_up));
754    else
755       significand = nir_iadd(b, significand, nir_b2i32(b, round_up));
756
757    nir_ssa_def *res;
758
759    if (dest_bit_size == 64) {
760       /* Compute the left shift required to normalize the original
761        * unrounded input manually.
762        */
763       nir_ssa_def *shift =
764          nir_imax(b, nir_isub(b, nir_imm_int(b, significand_bits), exp),
765                   nir_imm_int(b, 0));
766       significand = COND_LOWER_OP(b, ishl, significand, shift);
767
768       /* Check whether normalization led to overflow of the available
769        * significand bits, which can only happen if round_up was true
770        * above, in which case we need to add carry to the exponent and
771        * discard an extra bit from the significand.  Note that we
772        * don't need to repeat the round-up logic again, since the LSB
773        * of the significand is guaranteed to be zero if there was
774        * overflow.
775        */
776       nir_ssa_def *carry = nir_b2i32(
777          b, nir_uge(b, nir_unpack_64_2x32_split_y(b, significand),
778                     nir_imm_int(b, 1 << (significand_bits - 31))));
779       significand = COND_LOWER_OP(b, ishr, significand, carry);
780       exp = nir_iadd(b, exp, carry);
781
782       /* Compute the biased exponent, taking care to handle a zero
783        * input correctly, which would have caused exp to be negative.
784        */
785       nir_ssa_def *biased_exp = nir_bcsel(b, nir_ilt(b, exp, nir_imm_int(b, 0)),
786                                           nir_imm_int(b, 0),
787                                           nir_iadd(b, exp, nir_imm_int(b, 1023)));
788
789       /* Pack the significand and exponent manually. */
790       nir_ssa_def *lo = nir_unpack_64_2x32_split_x(b, significand);
791       nir_ssa_def *hi = nir_bitfield_insert(
792          b, nir_unpack_64_2x32_split_y(b, significand),
793          biased_exp, nir_imm_int(b, 20), nir_imm_int(b, 11));
794
795       res = nir_pack_64_2x32_split(b, lo, hi);
796
797    } else if (dest_bit_size == 32) {
798       res = nir_fmul(b, nir_u2f32(b, significand),
799                      nir_fexp2(b, nir_u2f32(b, discard)));
800    } else {
801       res = nir_fmul(b, nir_u2f16(b, significand),
802                      nir_fexp2(b, nir_u2f16(b, discard)));
803    }
804
805    if (src_is_signed)
806       res = nir_fmul(b, res, x_sign);
807
808    return res;
809 }
810
811 static nir_ssa_def *
812 lower_f2(nir_builder *b, nir_ssa_def *x, bool dst_is_signed)
813 {
814    assert(x->bit_size == 16 || x->bit_size == 32 || x->bit_size == 64);
815    nir_ssa_def *x_sign = NULL;
816
817    if (dst_is_signed)
818       x_sign = nir_fsign(b, x);
819
820    x = nir_ftrunc(b, x);
821
822    if (dst_is_signed)
823       x = nir_fabs(b, x);
824
825    nir_ssa_def *res;
826    if (x->bit_size < 32) {
827       res = nir_pack_64_2x32_split(b, nir_f2u32(b, x), nir_imm_int(b, 0));
828    } else {
829       nir_ssa_def *div = nir_imm_floatN_t(b, 1ULL << 32, x->bit_size);
830       nir_ssa_def *res_hi = nir_f2u32(b, nir_fdiv(b, x, div));
831       nir_ssa_def *res_lo = nir_f2u32(b, nir_frem(b, x, div));
832       res = nir_pack_64_2x32_split(b, res_lo, res_hi);
833    }
834
835    if (dst_is_signed)
836       res = nir_bcsel(b, nir_flt(b, x_sign, nir_imm_floatN_t(b, 0, x->bit_size)),
837                       nir_ineg(b, res), res);
838
839    return res;
840 }
841
842 static nir_ssa_def *
843 lower_bit_count64(nir_builder *b, nir_ssa_def *x)
844 {
845    nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
846    nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
847    nir_ssa_def *lo_count = nir_bit_count(b, x_lo);
848    nir_ssa_def *hi_count = nir_bit_count(b, x_hi);
849    return nir_iadd(b, lo_count, hi_count);
850 }
851
852 nir_lower_int64_options
853 nir_lower_int64_op_to_options_mask(nir_op opcode)
854 {
855    switch (opcode) {
856    case nir_op_imul:
857    case nir_op_amul:
858       return nir_lower_imul64;
859    case nir_op_imul_2x32_64:
860    case nir_op_umul_2x32_64:
861       return nir_lower_imul_2x32_64;
862    case nir_op_imul_high:
863    case nir_op_umul_high:
864       return nir_lower_imul_high64;
865    case nir_op_isign:
866       return nir_lower_isign64;
867    case nir_op_udiv:
868    case nir_op_idiv:
869    case nir_op_umod:
870    case nir_op_imod:
871    case nir_op_irem:
872       return nir_lower_divmod64;
873    case nir_op_b2i64:
874    case nir_op_i2b1:
875    case nir_op_i2i8:
876    case nir_op_i2i16:
877    case nir_op_i2i32:
878    case nir_op_i2i64:
879    case nir_op_u2u8:
880    case nir_op_u2u16:
881    case nir_op_u2u32:
882    case nir_op_u2u64:
883    case nir_op_i2f64:
884    case nir_op_u2f64:
885    case nir_op_i2f32:
886    case nir_op_u2f32:
887    case nir_op_i2f16:
888    case nir_op_u2f16:
889    case nir_op_f2i64:
890    case nir_op_f2u64:
891    case nir_op_bcsel:
892       return nir_lower_mov64;
893    case nir_op_ieq:
894    case nir_op_ine:
895    case nir_op_ult:
896    case nir_op_ilt:
897    case nir_op_uge:
898    case nir_op_ige:
899       return nir_lower_icmp64;
900    case nir_op_iadd:
901    case nir_op_isub:
902       return nir_lower_iadd64;
903    case nir_op_imin:
904    case nir_op_imax:
905    case nir_op_umin:
906    case nir_op_umax:
907       return nir_lower_minmax64;
908    case nir_op_iabs:
909       return nir_lower_iabs64;
910    case nir_op_ineg:
911       return nir_lower_ineg64;
912    case nir_op_iand:
913    case nir_op_ior:
914    case nir_op_ixor:
915    case nir_op_inot:
916       return nir_lower_logic64;
917    case nir_op_ishl:
918    case nir_op_ishr:
919    case nir_op_ushr:
920       return nir_lower_shift64;
921    case nir_op_extract_u8:
922    case nir_op_extract_i8:
923    case nir_op_extract_u16:
924    case nir_op_extract_i16:
925       return nir_lower_extract64;
926    case nir_op_ufind_msb:
927       return nir_lower_ufind_msb64;
928    case nir_op_bit_count:
929       return nir_lower_bit_count64;
930    default:
931       return 0;
932    }
933 }
934
935 static nir_ssa_def *
936 lower_int64_alu_instr(nir_builder *b, nir_alu_instr *alu)
937 {
938    nir_ssa_def *src[4];
939    for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++)
940       src[i] = nir_ssa_for_alu_src(b, alu, i);
941
942    switch (alu->op) {
943    case nir_op_imul:
944    case nir_op_amul:
945       return lower_imul64(b, src[0], src[1]);
946    case nir_op_imul_2x32_64:
947       return lower_mul_2x32_64(b, src[0], src[1], true);
948    case nir_op_umul_2x32_64:
949       return lower_mul_2x32_64(b, src[0], src[1], false);
950    case nir_op_imul_high:
951       return lower_mul_high64(b, src[0], src[1], true);
952    case nir_op_umul_high:
953       return lower_mul_high64(b, src[0], src[1], false);
954    case nir_op_isign:
955       return lower_isign64(b, src[0]);
956    case nir_op_udiv:
957       return lower_udiv64(b, src[0], src[1]);
958    case nir_op_idiv:
959       return lower_idiv64(b, src[0], src[1]);
960    case nir_op_umod:
961       return lower_umod64(b, src[0], src[1]);
962    case nir_op_imod:
963       return lower_imod64(b, src[0], src[1]);
964    case nir_op_irem:
965       return lower_irem64(b, src[0], src[1]);
966    case nir_op_b2i64:
967       return lower_b2i64(b, src[0]);
968    case nir_op_i2b1:
969       return lower_i2b(b, src[0]);
970    case nir_op_i2i8:
971       return lower_i2i8(b, src[0]);
972    case nir_op_i2i16:
973       return lower_i2i16(b, src[0]);
974    case nir_op_i2i32:
975       return lower_i2i32(b, src[0]);
976    case nir_op_i2i64:
977       return lower_i2i64(b, src[0]);
978    case nir_op_u2u8:
979       return lower_u2u8(b, src[0]);
980    case nir_op_u2u16:
981       return lower_u2u16(b, src[0]);
982    case nir_op_u2u32:
983       return lower_u2u32(b, src[0]);
984    case nir_op_u2u64:
985       return lower_u2u64(b, src[0]);
986    case nir_op_bcsel:
987       return lower_bcsel64(b, src[0], src[1], src[2]);
988    case nir_op_ieq:
989    case nir_op_ine:
990    case nir_op_ult:
991    case nir_op_ilt:
992    case nir_op_uge:
993    case nir_op_ige:
994       return lower_int64_compare(b, alu->op, src[0], src[1]);
995    case nir_op_iadd:
996       return lower_iadd64(b, src[0], src[1]);
997    case nir_op_isub:
998       return lower_isub64(b, src[0], src[1]);
999    case nir_op_imin:
1000       return lower_imin64(b, src[0], src[1]);
1001    case nir_op_imax:
1002       return lower_imax64(b, src[0], src[1]);
1003    case nir_op_umin:
1004       return lower_umin64(b, src[0], src[1]);
1005    case nir_op_umax:
1006       return lower_umax64(b, src[0], src[1]);
1007    case nir_op_iabs:
1008       return lower_iabs64(b, src[0]);
1009    case nir_op_ineg:
1010       return lower_ineg64(b, src[0]);
1011    case nir_op_iand:
1012       return lower_iand64(b, src[0], src[1]);
1013    case nir_op_ior:
1014       return lower_ior64(b, src[0], src[1]);
1015    case nir_op_ixor:
1016       return lower_ixor64(b, src[0], src[1]);
1017    case nir_op_inot:
1018       return lower_inot64(b, src[0]);
1019    case nir_op_ishl:
1020       return lower_ishl64(b, src[0], src[1]);
1021    case nir_op_ishr:
1022       return lower_ishr64(b, src[0], src[1]);
1023    case nir_op_ushr:
1024       return lower_ushr64(b, src[0], src[1]);
1025    case nir_op_extract_u8:
1026    case nir_op_extract_i8:
1027    case nir_op_extract_u16:
1028    case nir_op_extract_i16:
1029       return lower_extract(b, alu->op, src[0], src[1]);
1030    case nir_op_ufind_msb:
1031       return lower_ufind_msb64(b, src[0]);
1032    case nir_op_bit_count:
1033       return lower_bit_count64(b, src[0]);
1034    case nir_op_i2f64:
1035    case nir_op_i2f32:
1036    case nir_op_i2f16:
1037       return lower_2f(b, src[0], nir_dest_bit_size(alu->dest.dest), true);
1038    case nir_op_u2f64:
1039    case nir_op_u2f32:
1040    case nir_op_u2f16:
1041       return lower_2f(b, src[0], nir_dest_bit_size(alu->dest.dest), false);
1042    case nir_op_f2i64:
1043    case nir_op_f2u64:
1044       return lower_f2(b, src[0], alu->op == nir_op_f2i64);
1045    default:
1046       unreachable("Invalid ALU opcode to lower");
1047    }
1048 }
1049
1050 static bool
1051 should_lower_int64_alu_instr(const nir_alu_instr *alu,
1052                              const nir_shader_compiler_options *options)
1053 {
1054    switch (alu->op) {
1055    case nir_op_i2b1:
1056    case nir_op_i2i8:
1057    case nir_op_i2i16:
1058    case nir_op_i2i32:
1059    case nir_op_u2u8:
1060    case nir_op_u2u16:
1061    case nir_op_u2u32:
1062       assert(alu->src[0].src.is_ssa);
1063       if (alu->src[0].src.ssa->bit_size != 64)
1064          return false;
1065       break;
1066    case nir_op_bcsel:
1067       assert(alu->src[1].src.is_ssa);
1068       assert(alu->src[2].src.is_ssa);
1069       assert(alu->src[1].src.ssa->bit_size ==
1070              alu->src[2].src.ssa->bit_size);
1071       if (alu->src[1].src.ssa->bit_size != 64)
1072          return false;
1073       break;
1074    case nir_op_ieq:
1075    case nir_op_ine:
1076    case nir_op_ult:
1077    case nir_op_ilt:
1078    case nir_op_uge:
1079    case nir_op_ige:
1080       assert(alu->src[0].src.is_ssa);
1081       assert(alu->src[1].src.is_ssa);
1082       assert(alu->src[0].src.ssa->bit_size ==
1083              alu->src[1].src.ssa->bit_size);
1084       if (alu->src[0].src.ssa->bit_size != 64)
1085          return false;
1086       break;
1087    case nir_op_ufind_msb:
1088    case nir_op_bit_count:
1089       assert(alu->src[0].src.is_ssa);
1090       if (alu->src[0].src.ssa->bit_size != 64)
1091          return false;
1092       break;
1093    case nir_op_amul:
1094       assert(alu->dest.dest.is_ssa);
1095       if (options->has_imul24)
1096          return false;
1097       if (alu->dest.dest.ssa.bit_size != 64)
1098          return false;
1099       break;
1100    case nir_op_i2f64:
1101    case nir_op_u2f64:
1102    case nir_op_i2f32:
1103    case nir_op_u2f32:
1104    case nir_op_i2f16:
1105    case nir_op_u2f16:
1106       assert(alu->src[0].src.is_ssa);
1107       if (alu->src[0].src.ssa->bit_size != 64)
1108          return false;
1109       break;
1110    case nir_op_f2u64:
1111    case nir_op_f2i64:
1112       FALLTHROUGH;
1113    default:
1114       assert(alu->dest.dest.is_ssa);
1115       if (alu->dest.dest.ssa.bit_size != 64)
1116          return false;
1117       break;
1118    }
1119
1120    unsigned mask = nir_lower_int64_op_to_options_mask(alu->op);
1121    return (options->lower_int64_options & mask) != 0;
1122 }
1123
1124 static nir_ssa_def *
1125 split_64bit_subgroup_op(nir_builder *b, const nir_intrinsic_instr *intrin)
1126 {
1127    const nir_intrinsic_info *info = &nir_intrinsic_infos[intrin->intrinsic];
1128
1129    /* This works on subgroup ops with a single 64-bit source which can be
1130     * trivially lowered by doing the exact same op on both halves.
1131     */
1132    assert(intrin->src[0].is_ssa && intrin->src[0].ssa->bit_size == 64);
1133    nir_ssa_def *split_src0[2] = {
1134       nir_unpack_64_2x32_split_x(b, intrin->src[0].ssa),
1135       nir_unpack_64_2x32_split_y(b, intrin->src[0].ssa),
1136    };
1137
1138    assert(info->has_dest && intrin->dest.is_ssa &&
1139           intrin->dest.ssa.bit_size == 64);
1140
1141    nir_ssa_def *res[2];
1142    for (unsigned i = 0; i < 2; i++) {
1143       nir_intrinsic_instr *split =
1144          nir_intrinsic_instr_create(b->shader, intrin->intrinsic);
1145       split->num_components = intrin->num_components;
1146       split->src[0] = nir_src_for_ssa(split_src0[i]);
1147
1148       /* Other sources must be less than 64 bits and get copied directly */
1149       for (unsigned j = 1; j < info->num_srcs; j++) {
1150          assert(intrin->src[j].is_ssa && intrin->src[j].ssa->bit_size < 64);
1151          split->src[j] = nir_src_for_ssa(intrin->src[j].ssa);
1152       }
1153
1154       /* Copy const indices, if any */
1155       memcpy(split->const_index, intrin->const_index,
1156              sizeof(intrin->const_index));
1157
1158       nir_ssa_dest_init(&split->instr, &split->dest,
1159                         intrin->dest.ssa.num_components, 32, NULL);
1160       nir_builder_instr_insert(b, &split->instr);
1161
1162       res[i] = &split->dest.ssa;
1163    }
1164
1165    return nir_pack_64_2x32_split(b, res[0], res[1]);
1166 }
1167
1168 static nir_ssa_def *
1169 build_vote_ieq(nir_builder *b, nir_ssa_def *x)
1170 {
1171    nir_intrinsic_instr *vote =
1172       nir_intrinsic_instr_create(b->shader, nir_intrinsic_vote_ieq);
1173    vote->src[0] = nir_src_for_ssa(x);
1174    vote->num_components = x->num_components;
1175    nir_ssa_dest_init(&vote->instr, &vote->dest, 1, 1, NULL);
1176    nir_builder_instr_insert(b, &vote->instr);
1177    return &vote->dest.ssa;
1178 }
1179
1180 static nir_ssa_def *
1181 lower_vote_ieq(nir_builder *b, nir_ssa_def *x)
1182 {
1183    return nir_iand(b, build_vote_ieq(b, nir_unpack_64_2x32_split_x(b, x)),
1184                       build_vote_ieq(b, nir_unpack_64_2x32_split_y(b, x)));
1185 }
1186
1187 static nir_ssa_def *
1188 build_scan_intrinsic(nir_builder *b, nir_intrinsic_op scan_op,
1189                      nir_op reduction_op, unsigned cluster_size,
1190                      nir_ssa_def *val)
1191 {
1192    nir_intrinsic_instr *scan =
1193       nir_intrinsic_instr_create(b->shader, scan_op);
1194    scan->num_components = val->num_components;
1195    scan->src[0] = nir_src_for_ssa(val);
1196    nir_intrinsic_set_reduction_op(scan, reduction_op);
1197    if (scan_op == nir_intrinsic_reduce)
1198       nir_intrinsic_set_cluster_size(scan, cluster_size);
1199    nir_ssa_dest_init(&scan->instr, &scan->dest,
1200                      val->num_components, val->bit_size, NULL);
1201    nir_builder_instr_insert(b, &scan->instr);
1202    return &scan->dest.ssa;
1203 }
1204
1205 static nir_ssa_def *
1206 lower_scan_iadd64(nir_builder *b, const nir_intrinsic_instr *intrin)
1207 {
1208    unsigned cluster_size =
1209       intrin->intrinsic == nir_intrinsic_reduce ?
1210       nir_intrinsic_cluster_size(intrin) : 0;
1211
1212    /* Split it into three chunks of no more than 24 bits each.  With 8 bits
1213     * of headroom, we're guaranteed that there will never be overflow in the
1214     * individual subgroup operations.  (Assuming, of course, a subgroup size
1215     * no larger than 256 which seems reasonable.)  We can then scan on each of
1216     * the chunks and add them back together at the end.
1217     */
1218    assert(intrin->src[0].is_ssa);
1219    nir_ssa_def *x = intrin->src[0].ssa;
1220    nir_ssa_def *x_low =
1221       nir_u2u32(b, nir_iand_imm(b, x, 0xffffff));
1222    nir_ssa_def *x_mid =
1223       nir_u2u32(b, nir_iand_imm(b, nir_ushr(b, x, nir_imm_int(b, 24)),
1224                                    0xffffff));
1225    nir_ssa_def *x_hi =
1226       nir_u2u32(b, nir_ushr(b, x, nir_imm_int(b, 48)));
1227
1228    nir_ssa_def *scan_low =
1229       build_scan_intrinsic(b, intrin->intrinsic, nir_op_iadd,
1230                               cluster_size, x_low);
1231    nir_ssa_def *scan_mid =
1232       build_scan_intrinsic(b, intrin->intrinsic, nir_op_iadd,
1233                               cluster_size, x_mid);
1234    nir_ssa_def *scan_hi =
1235       build_scan_intrinsic(b, intrin->intrinsic, nir_op_iadd,
1236                               cluster_size, x_hi);
1237
1238    scan_low = nir_u2u64(b, scan_low);
1239    scan_mid = nir_ishl(b, nir_u2u64(b, scan_mid), nir_imm_int(b, 24));
1240    scan_hi = nir_ishl(b, nir_u2u64(b, scan_hi), nir_imm_int(b, 48));
1241
1242    return nir_iadd(b, scan_hi, nir_iadd(b, scan_mid, scan_low));
1243 }
1244
1245 static bool
1246 should_lower_int64_intrinsic(const nir_intrinsic_instr *intrin,
1247                              const nir_shader_compiler_options *options)
1248 {
1249    switch (intrin->intrinsic) {
1250    case nir_intrinsic_read_invocation:
1251    case nir_intrinsic_read_first_invocation:
1252    case nir_intrinsic_shuffle:
1253    case nir_intrinsic_shuffle_xor:
1254    case nir_intrinsic_shuffle_up:
1255    case nir_intrinsic_shuffle_down:
1256    case nir_intrinsic_quad_broadcast:
1257    case nir_intrinsic_quad_swap_horizontal:
1258    case nir_intrinsic_quad_swap_vertical:
1259    case nir_intrinsic_quad_swap_diagonal:
1260       assert(intrin->dest.is_ssa);
1261       return intrin->dest.ssa.bit_size == 64 &&
1262              (options->lower_int64_options & nir_lower_subgroup_shuffle64);
1263
1264    case nir_intrinsic_vote_ieq:
1265       assert(intrin->src[0].is_ssa);
1266       return intrin->src[0].ssa->bit_size == 64 &&
1267              (options->lower_int64_options & nir_lower_vote_ieq64);
1268
1269    case nir_intrinsic_reduce:
1270    case nir_intrinsic_inclusive_scan:
1271    case nir_intrinsic_exclusive_scan:
1272       assert(intrin->dest.is_ssa);
1273       if (intrin->dest.ssa.bit_size != 64)
1274          return false;
1275
1276       switch (nir_intrinsic_reduction_op(intrin)) {
1277       case nir_op_iadd:
1278          return options->lower_int64_options & nir_lower_scan_reduce_iadd64;
1279       case nir_op_iand:
1280       case nir_op_ior:
1281       case nir_op_ixor:
1282          return options->lower_int64_options & nir_lower_scan_reduce_bitwise64;
1283       default:
1284          return false;
1285       }
1286       break;
1287
1288    default:
1289       return false;
1290    }
1291 }
1292
1293 static nir_ssa_def *
1294 lower_int64_intrinsic(nir_builder *b, nir_intrinsic_instr *intrin)
1295 {
1296    switch (intrin->intrinsic) {
1297    case nir_intrinsic_read_invocation:
1298    case nir_intrinsic_read_first_invocation:
1299    case nir_intrinsic_shuffle:
1300    case nir_intrinsic_shuffle_xor:
1301    case nir_intrinsic_shuffle_up:
1302    case nir_intrinsic_shuffle_down:
1303    case nir_intrinsic_quad_broadcast:
1304    case nir_intrinsic_quad_swap_horizontal:
1305    case nir_intrinsic_quad_swap_vertical:
1306    case nir_intrinsic_quad_swap_diagonal:
1307       return split_64bit_subgroup_op(b, intrin);
1308
1309    case nir_intrinsic_vote_ieq:
1310       assert(intrin->src[0].is_ssa);
1311       return lower_vote_ieq(b, intrin->src[0].ssa);
1312
1313    case nir_intrinsic_reduce:
1314    case nir_intrinsic_inclusive_scan:
1315    case nir_intrinsic_exclusive_scan:
1316       switch (nir_intrinsic_reduction_op(intrin)) {
1317       case nir_op_iadd:
1318          return lower_scan_iadd64(b, intrin);
1319       case nir_op_iand:
1320       case nir_op_ior:
1321       case nir_op_ixor:
1322          return split_64bit_subgroup_op(b, intrin);
1323       default:
1324          unreachable("Unsupported subgroup scan/reduce op");
1325       }
1326       break;
1327
1328    default:
1329       unreachable("Unsupported intrinsic");
1330    }
1331 }
1332
1333 static bool
1334 should_lower_int64_instr(const nir_instr *instr, const void *_options)
1335 {
1336    switch (instr->type) {
1337    case nir_instr_type_alu:
1338       return should_lower_int64_alu_instr(nir_instr_as_alu(instr), _options);
1339    case nir_instr_type_intrinsic:
1340       return should_lower_int64_intrinsic(nir_instr_as_intrinsic(instr),
1341                                           _options);
1342    default:
1343       return false;
1344    }
1345 }
1346
1347 static nir_ssa_def *
1348 lower_int64_instr(nir_builder *b, nir_instr *instr, void *_options)
1349 {
1350    switch (instr->type) {
1351    case nir_instr_type_alu:
1352       return lower_int64_alu_instr(b, nir_instr_as_alu(instr));
1353    case nir_instr_type_intrinsic:
1354       return lower_int64_intrinsic(b, nir_instr_as_intrinsic(instr));
1355    default:
1356       return NULL;
1357    }
1358 }
1359
1360 bool
1361 nir_lower_int64(nir_shader *shader)
1362 {
1363    return nir_shader_lower_instructions(shader, should_lower_int64_instr,
1364                                         lower_int64_instr,
1365                                         (void *)shader->options);
1366 }