selftests/bpf: Fix erroneous bitmask operation
[platform/kernel/linux-rpi.git] / tools / testing / selftests / bpf / progs / iters.c
1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2023 Meta Platforms, Inc. and affiliates. */
3
4 #include <stdbool.h>
5 #include <linux/bpf.h>
6 #include <bpf/bpf_helpers.h>
7 #include "bpf_misc.h"
8
9 #define ARRAY_SIZE(x) (sizeof(x) / sizeof((x)[0]))
10
11 static volatile int zero = 0;
12
13 int my_pid;
14 int arr[256];
15 int small_arr[16] SEC(".data.small_arr");
16
17 #ifdef REAL_TEST
18 #define MY_PID_GUARD() if (my_pid != (bpf_get_current_pid_tgid() >> 32)) return 0
19 #else
20 #define MY_PID_GUARD() ({ })
21 #endif
22
/* Negative test: i starts from the opaque value of `zero` and is incremented
 * once per iterator element, so after the loop it has no usable bounds.
 * Indexing small_arr with it is annotated (__failure/__msg) as a verifier
 * rejection with the message below. The leading '?' in SEC() keeps libbpf
 * from auto-loading the program.
 */
SEC("?raw_tp")
__failure __msg("math between map_value pointer and register with unbounded min value is not allowed")
int iter_err_unsafe_c_loop(const void *ctx)
{
	struct bpf_iter_num it;
	int *v, i = zero; /* obscure initial value of i */

	MY_PID_GUARD();

	bpf_iter_num_new(&it, 0, 1000);
	while ((v = bpf_iter_num_next(&it))) {
		i++;
	}
	bpf_iter_num_destroy(&it);

	small_arr[i] = 123; /* invalid */

	return 0;
}
42
/* Assembly twin of iter_err_unsafe_c_loop. asm volatile pins the exact
 * instruction sequence: r6 counts loop iterations starting from the opaque
 * value of `zero`, is scaled by 4 (r2 <<= 2) and added to the small_arr
 * pointer. The store through that unbounded offset is annotated as a
 * verifier failure with "unbounded memory access".
 */
SEC("?raw_tp")
__failure __msg("unbounded memory access")
int iter_err_unsafe_asm_loop(const void *ctx)
{
	struct bpf_iter_num it;

	MY_PID_GUARD();

	asm volatile (
		"r6 = %[zero];" /* iteration counter */
		"r1 = %[it];" /* iterator state */
		"r2 = 0;"
		"r3 = 1000;"
		/* NOTE(review): the C tests call bpf_iter_num_new() with three
		 * arguments (r1-r3); r4 looks unused here -- confirm intent.
		 */
		"r4 = 1;"
		"call %[bpf_iter_num_new];"
	"loop:"
		"r1 = %[it];"
		"call %[bpf_iter_num_next];"
		"if r0 == 0 goto out;" /* NULL from next() ends the loop */
		"r6 += 1;"
		"goto loop;"
	"out:"
		"r1 = %[it];"
		"call %[bpf_iter_num_destroy];"
		"r1 = %[small_arr];"
		"r2 = r6;"
		"r2 <<= 2;" /* element index -> byte offset (4-byte ints) */
		"r1 += r2;"
		"*(u32 *)(r1 + 0) = r6;" /* invalid */
		:
		: [it]"r"(&it),
		  [small_arr]"p"(small_arr),
		  [zero]"p"(zero),
		  __imm(bpf_iter_num_new),
		  __imm(bpf_iter_num_next),
		  __imm(bpf_iter_num_destroy)
		: __clobber_common, "r6"
	);

	return 0;
}
84
/* Positive test: the canonical while-loop over a numeric iterator covering
 * [0, 3), with explicit new/destroy calls around it.
 */
SEC("raw_tp")
__success
int iter_while_loop(const void *ctx)
{
	struct bpf_iter_num it;
	int *v;

	MY_PID_GUARD();

	bpf_iter_num_new(&it, 0, 3);
	/* next() returns NULL once the range is exhausted */
	while ((v = bpf_iter_num_next(&it))) {
		bpf_printk("ITER_BASIC: E1 VAL: v=%d", *v);
	}
	bpf_iter_num_destroy(&it);

	return 0;
}
102
/* Same as iter_while_loop, but destruction comes from the cleanup
 * attribute: the compiler calls bpf_iter_num_destroy(&it) when `it` goes
 * out of scope, instead of the program calling it explicitly.
 */
SEC("raw_tp")
__success
int iter_while_loop_auto_cleanup(const void *ctx)
{
	__attribute__((cleanup(bpf_iter_num_destroy))) struct bpf_iter_num it;
	int *v;

	MY_PID_GUARD();

	bpf_iter_num_new(&it, 0, 3);
	while ((v = bpf_iter_num_next(&it))) {
		bpf_printk("ITER_BASIC: E1 VAL: v=%d", *v);
	}
	/* (!) no explicit bpf_iter_num_destroy() */

	return 0;
}
120
/* Positive test: the iterator driven from a classic for-loop header, with
 * next() called both in the init and the step expression. Range is [5, 10).
 */
SEC("raw_tp")
__success
int iter_for_loop(const void *ctx)
{
	struct bpf_iter_num it;
	int *v;

	MY_PID_GUARD();

	bpf_iter_num_new(&it, 5, 10);
	for (v = bpf_iter_num_next(&it); v; v = bpf_iter_num_next(&it)) {
		bpf_printk("ITER_BASIC: E2 VAL: v=%d", *v);
	}
	bpf_iter_num_destroy(&it);

	return 0;
}
138
/* Positive test for the bpf_for_each() convenience macro (defined outside
 * this file). Here v is an int pointer, presumably to each value of the
 * num iterator over [5, 10).
 */
SEC("raw_tp")
__success
int iter_bpf_for_each_macro(const void *ctx)
{
	int *v;

	MY_PID_GUARD();

	bpf_for_each(num, v, 5, 10) {
		bpf_printk("ITER_BASIC: E2 VAL: v=%d", *v);
	}

	return 0;
}
153
/* Positive test for the bpf_for() macro (defined outside this file), which
 * assigns each value (presumably 5..9) directly to the int variable i.
 */
SEC("raw_tp")
__success
int iter_bpf_for_macro(const void *ctx)
{
	int i;

	MY_PID_GUARD();

	bpf_for(i, 5, 10) {
		bpf_printk("ITER_BASIC: E2 VAL: v=%d", i);
	}

	return 0;
}
168
/* A counted loop runs three times over an iterator with only two elements
 * ([0, 2)). The third next() therefore returns NULL, printed as -1.
 * NOTE(review): despite the function name, `#pragma nounroll` explicitly
 * keeps the compiler from unrolling this loop.
 */
SEC("raw_tp")
__success
int iter_pragma_unroll_loop(const void *ctx)
{
	struct bpf_iter_num it;
	int *v, i;

	MY_PID_GUARD();

	bpf_iter_num_new(&it, 0, 2);
#pragma nounroll
	for (i = 0; i < 3; i++) {
		v = bpf_iter_num_next(&it);
		bpf_printk("ITER_BASIC: E3 VAL: i=%d v=%d", i, v ? *v : -1);
	}
	bpf_iter_num_destroy(&it);

	return 0;
}
188
/* Four straight-line next() calls on an iterator over [100, 200), with no
 * loop at all; every result is NULL-checked before being dereferenced.
 * Only the last printk ends with "\n".
 */
SEC("raw_tp")
__success
int iter_manual_unroll_loop(const void *ctx)
{
	struct bpf_iter_num it;
	int *v;

	MY_PID_GUARD();

	bpf_iter_num_new(&it, 100, 200);
	v = bpf_iter_num_next(&it);
	bpf_printk("ITER_BASIC: E4 VAL: v=%d", v ? *v : -1);
	v = bpf_iter_num_next(&it);
	bpf_printk("ITER_BASIC: E4 VAL: v=%d", v ? *v : -1);
	v = bpf_iter_num_next(&it);
	bpf_printk("ITER_BASIC: E4 VAL: v=%d", v ? *v : -1);
	v = bpf_iter_num_next(&it);
	bpf_printk("ITER_BASIC: E4 VAL: v=%d\n", v ? *v : -1);
	bpf_iter_num_destroy(&it);

	return 0;
}
211
/* Runs the while, for, nounroll-counted and manually unrolled patterns back
 * to back in one program. A single `it` stack slot is reused, with a full
 * new/destroy pair around each pattern.
 */
SEC("raw_tp")
__success
int iter_multiple_sequential_loops(const void *ctx)
{
	struct bpf_iter_num it;
	int *v, i;

	MY_PID_GUARD();

	/* while-loop over [0, 3) */
	bpf_iter_num_new(&it, 0, 3);
	while ((v = bpf_iter_num_next(&it))) {
		bpf_printk("ITER_BASIC: E1 VAL: v=%d", *v);
	}
	bpf_iter_num_destroy(&it);

	/* for-loop over [5, 10) */
	bpf_iter_num_new(&it, 5, 10);
	for (v = bpf_iter_num_next(&it); v; v = bpf_iter_num_next(&it)) {
		bpf_printk("ITER_BASIC: E2 VAL: v=%d", *v);
	}
	bpf_iter_num_destroy(&it);

	/* three next() calls over a two-element range; last one yields NULL */
	bpf_iter_num_new(&it, 0, 2);
#pragma nounroll
	for (i = 0; i < 3; i++) {
		v = bpf_iter_num_next(&it);
		bpf_printk("ITER_BASIC: E3 VAL: i=%d v=%d", i, v ? *v : -1);
	}
	bpf_iter_num_destroy(&it);

	/* straight-line calls, no loop */
	bpf_iter_num_new(&it, 100, 200);
	v = bpf_iter_num_next(&it);
	bpf_printk("ITER_BASIC: E4 VAL: v=%d", v ? *v : -1);
	v = bpf_iter_num_next(&it);
	bpf_printk("ITER_BASIC: E4 VAL: v=%d", v ? *v : -1);
	v = bpf_iter_num_next(&it);
	bpf_printk("ITER_BASIC: E4 VAL: v=%d", v ? *v : -1);
	v = bpf_iter_num_next(&it);
	bpf_printk("ITER_BASIC: E4 VAL: v=%d\n", v ? *v : -1);
	bpf_iter_num_destroy(&it);

	return 0;
}
254
255 SEC("raw_tp")
256 __success
257 int iter_limit_cond_break_loop(const void *ctx)
258 {
259         struct bpf_iter_num it;
260         int *v, i = 0, sum = 0;
261
262         MY_PID_GUARD();
263
264         bpf_iter_num_new(&it, 0, 10);
265         while ((v = bpf_iter_num_next(&it))) {
266                 bpf_printk("ITER_SIMPLE: i=%d v=%d", i, *v);
267                 sum += *v;
268
269                 i++;
270                 if (i > 3)
271                         break;
272         }
273         bpf_iter_num_destroy(&it);
274
275         bpf_printk("ITER_SIMPLE: sum=%d\n", sum);
276
277         return 0;
278 }
279
/* Counter-in-loop test: i feeds a data-dependent branch inside the iterator
 * body. Starting i from the opaque `zero` is what lets the verifier converge
 * (see the comments below).
 */
SEC("raw_tp")
__success
int iter_obfuscate_counter(const void *ctx)
{
	struct bpf_iter_num it;
	int *v, sum = 0;
	/* Make i's initial value unknowable for verifier to prevent it from
	 * pruning if/else branch inside the loop body and marking i as precise.
	 */
	int i = zero;

	MY_PID_GUARD();

	bpf_iter_num_new(&it, 0, 10);
	while ((v = bpf_iter_num_next(&it))) {
		int x;

		i += 1;

		/* If we initialized i as `int i = 0;` above, verifier would
		 * track that i becomes 1 on first iteration after increment
		 * above, and here verifier would eagerly prune else branch
		 * and mark i as precise, ruining open-coded iterator logic
		 * completely, as each next iteration would have a different
		 * *precise* value of i, and thus there would be no
		 * convergence of state. This would result in reaching maximum
		 * instruction limit, no matter what the limit is.
		 */
		if (i == 1)
			x = 123;
		else
			x = i * 3 + 1;

		bpf_printk("ITER_OBFUSCATE_COUNTER: i=%d v=%d x=%d", i, *v, x);

		sum += x;
	}
	bpf_iter_num_destroy(&it);

	bpf_printk("ITER_OBFUSCATE_COUNTER: sum=%d\n", sum);

	return 0;
}
323
/* Keeps a pointer returned by next() past the end of the loop and
 * dereferences it while the iterator is still alive (destroy comes last).
 * barrier_var() is defined outside this file; presumably it stops the
 * compiler from rewriting elem into something other than the raw pointer.
 */
SEC("raw_tp")
__success
int iter_search_loop(const void *ctx)
{
	struct bpf_iter_num it;
	int *v, *elem = NULL;
	bool found = false;

	MY_PID_GUARD();

	bpf_iter_num_new(&it, 0, 10);

	while ((v = bpf_iter_num_next(&it))) {
		bpf_printk("ITER_SEARCH_LOOP: v=%d", *v);

		if (*v == 2) {
			found = true;
			elem = v;
			barrier_var(elem);
		}
	}

	/* should fail to verify if bpf_iter_num_destroy() is here */

	if (found)
		/* here found element will be wrong, we should have copied
		 * value to a variable, but here we want to make sure we can
		 * access memory after the loop anyways
		 */
		bpf_printk("ITER_SEARCH_LOOP: FOUND IT = %d!\n", *elem);
	else
		bpf_printk("ITER_SEARCH_LOOP: NOT FOUND IT!\n");

	bpf_iter_num_destroy(&it);

	return 0;
}
361
362 SEC("raw_tp")
363 __success
364 int iter_array_fill(const void *ctx)
365 {
366         int sum, i;
367
368         MY_PID_GUARD();
369
370         bpf_for(i, 0, ARRAY_SIZE(arr)) {
371                 arr[i] = i * 2;
372         }
373
374         sum = 0;
375         bpf_for(i, 0, ARRAY_SIZE(arr)) {
376                 sum += arr[i];
377         }
378
379         bpf_printk("ITER_ARRAY_FILL: sum=%d (should be %d)\n", sum, 255 * 256);
380
381         return 0;
382 }
383
/* 4x5 matrix plus per-row and per-column sum buffers, shared by the nested
 * and subprogram iterator tests in this file.
 */
static int arr2d[4][5];
static int arr2d_row_sums[4];
static int arr2d_col_sums[5];
387
388 SEC("raw_tp")
389 __success
390 int iter_nested_iters(const void *ctx)
391 {
392         int sum, row, col;
393
394         MY_PID_GUARD();
395
396         bpf_for(row, 0, ARRAY_SIZE(arr2d)) {
397                 bpf_for( col, 0, ARRAY_SIZE(arr2d[0])) {
398                         arr2d[row][col] = row * col;
399                 }
400         }
401
402         /* zero-initialize sums */
403         sum = 0;
404         bpf_for(row, 0, ARRAY_SIZE(arr2d)) {
405                 arr2d_row_sums[row] = 0;
406         }
407         bpf_for(col, 0, ARRAY_SIZE(arr2d[0])) {
408                 arr2d_col_sums[col] = 0;
409         }
410
411         /* calculate sums */
412         bpf_for(row, 0, ARRAY_SIZE(arr2d)) {
413                 bpf_for(col, 0, ARRAY_SIZE(arr2d[0])) {
414                         sum += arr2d[row][col];
415                         arr2d_row_sums[row] += arr2d[row][col];
416                         arr2d_col_sums[col] += arr2d[row][col];
417                 }
418         }
419
420         bpf_printk("ITER_NESTED_ITERS: total sum=%d", sum);
421         bpf_for(row, 0, ARRAY_SIZE(arr2d)) {
422                 bpf_printk("ITER_NESTED_ITERS: row #%d sum=%d", row, arr2d_row_sums[row]);
423         }
424         bpf_for(col, 0, ARRAY_SIZE(arr2d[0])) {
425                 bpf_printk("ITER_NESTED_ITERS: col #%d sum=%d%s",
426                            col, arr2d_col_sums[col],
427                            col == ARRAY_SIZE(arr2d[0]) - 1 ? "\n" : "");
428         }
429
430         return 0;
431 }
432
/* Five levels of nested bpf_repeat(10). The outermost body always hits
 * break after its first pass, so (if bpf_repeat(N) runs N times) the value
 * returned is 10^4.
 */
SEC("raw_tp")
__success
int iter_nested_deeply_iters(const void *ctx)
{
	int sum = 0;

	MY_PID_GUARD();

	bpf_repeat(10) {
		bpf_repeat(10) {
			bpf_repeat(10) {
				bpf_repeat(10) {
					bpf_repeat(10) {
						sum += 1;
					}
				}
			}
		}
		/* validate that we can break from inside bpf_repeat() */
		break;
	}

	return sum;
}
457
458 static __noinline void fill_inner_dimension(int row)
459 {
460         int col;
461
462         bpf_for(col, 0, ARRAY_SIZE(arr2d[0])) {
463                 arr2d[row][col] = row * col;
464         }
465 }
466
467 static __noinline int sum_inner_dimension(int row)
468 {
469         int sum = 0, col;
470
471         bpf_for(col, 0, ARRAY_SIZE(arr2d[0])) {
472                 sum += arr2d[row][col];
473                 arr2d_row_sums[row] += arr2d[row][col];
474                 arr2d_col_sums[col] += arr2d[row][col];
475         }
476
477         return sum;
478 }
479
/* Same computation as iter_nested_iters, but the inner loops live in the
 * __noinline subprograms fill_inner_dimension() and sum_inner_dimension(),
 * so iterators are used both in the main program and in called subprograms.
 */
SEC("raw_tp")
__success
int iter_subprog_iters(const void *ctx)
{
	int sum, row, col;

	MY_PID_GUARD();

	bpf_for(row, 0, ARRAY_SIZE(arr2d)) {
		fill_inner_dimension(row);
	}

	/* zero-initialize sums */
	sum = 0;
	bpf_for(row, 0, ARRAY_SIZE(arr2d)) {
		arr2d_row_sums[row] = 0;
	}
	bpf_for(col, 0, ARRAY_SIZE(arr2d[0])) {
		arr2d_col_sums[col] = 0;
	}

	/* calculate sums */
	bpf_for(row, 0, ARRAY_SIZE(arr2d)) {
		sum += sum_inner_dimension(row);
	}

	bpf_printk("ITER_SUBPROG_ITERS: total sum=%d", sum);
	bpf_for(row, 0, ARRAY_SIZE(arr2d)) {
		bpf_printk("ITER_SUBPROG_ITERS: row #%d sum=%d",
			   row, arr2d_row_sums[row]);
	}
	bpf_for(col, 0, ARRAY_SIZE(arr2d[0])) {
		bpf_printk("ITER_SUBPROG_ITERS: col #%d sum=%d%s",
			   col, arr2d_col_sums[col],
			   col == ARRAY_SIZE(arr2d[0]) - 1 ? "\n" : "");
	}

	return 0;
}
519
/* Array map with int keys and int values (1000 entries). The
 * "too permissive" tests below look up key 0 in it to get
 * map_value_or_null pointers.
 */
struct {
	__uint(type, BPF_MAP_TYPE_ARRAY);
	__type(key, int);
	__type(value, int);
	__uint(max_entries, 1000);
} arr_map SEC(".maps");
526
/* Negative test: map_val is valid before the loop, but if the bpf_repeat()
 * body runs at least once it is overwritten with NULL. After the loop the
 * verifier has to assume map_val may be a plain scalar 0, so the store is
 * annotated as rejected with "invalid mem access 'scalar'".
 */
SEC("?raw_tp")
__failure __msg("invalid mem access 'scalar'")
int iter_err_too_permissive1(const void *ctx)
{
	int *map_val = NULL;
	int key = 0;

	MY_PID_GUARD();

	map_val = bpf_map_lookup_elem(&arr_map, &key);
	if (!map_val)
		return 0;

	bpf_repeat(1000000) {
		map_val = NULL;
	}

	*map_val = 123;

	return 0;
}
548
/* Negative test: the loop body replaces the NULL-checked pointer with a
 * fresh, unchecked lookup result. After the loop map_val may be NULL, so the
 * store is annotated as rejected with "invalid mem access 'map_value_or_null'".
 */
SEC("?raw_tp")
__failure __msg("invalid mem access 'map_value_or_null'")
int iter_err_too_permissive2(const void *ctx)
{
	int *map_val = NULL;
	int key = 0;

	MY_PID_GUARD();

	map_val = bpf_map_lookup_elem(&arr_map, &key);
	if (!map_val)
		return 0;

	bpf_repeat(1000000) {
		map_val = bpf_map_lookup_elem(&arr_map, &key);
	}

	*map_val = 123;

	return 0;
}
570
/* Negative test: `found` only shows that the loop body ran, not that the
 * lookup succeeded. The guarded store therefore still goes through a
 * possibly-NULL pointer and is annotated as rejected.
 */
SEC("?raw_tp")
__failure __msg("invalid mem access 'map_value_or_null'")
int iter_err_too_permissive3(const void *ctx)
{
	int *map_val = NULL;
	int key = 0;
	bool found = false;

	MY_PID_GUARD();

	bpf_repeat(1000000) {
		map_val = bpf_map_lookup_elem(&arr_map, &key);
		found = true;
	}

	if (found)
		*map_val = 123;

	return 0;
}
591
/* Positive counterpart to iter_err_too_permissive3: `found` is set only on
 * the path where map_val was just checked non-NULL, and the loop exits
 * right away. So every path that reaches the store with found == true has a
 * valid pointer, and the program is annotated __success.
 */
SEC("raw_tp")
__success
int iter_tricky_but_fine(const void *ctx)
{
	int *map_val = NULL;
	int key = 0;
	bool found = false;

	MY_PID_GUARD();

	bpf_repeat(1000000) {
		map_val = bpf_map_lookup_elem(&arr_map, &key);
		if (map_val) {
			found = true;
			break;
		}
	}

	if (found)
		*map_val = 123;

	return 0;
}
615
616 #define __bpf_memzero(p, sz) bpf_probe_read_kernel((p), (sz), 0)
617
618 SEC("raw_tp")
619 __success
620 int iter_stack_array_loop(const void *ctx)
621 {
622         long arr1[16], arr2[16], sum = 0;
623         int i;
624
625         MY_PID_GUARD();
626
627         /* zero-init arr1 and arr2 in such a way that verifier doesn't know
628          * it's all zeros; if we don't do that, we'll make BPF verifier track
629          * all combination of zero/non-zero stack slots for arr1/arr2, which
630          * will lead to O(2^(ARRAY_SIZE(arr1)+ARRAY_SIZE(arr2))) different
631          * states
632          */
633         __bpf_memzero(arr1, sizeof(arr1));
634         __bpf_memzero(arr2, sizeof(arr1));
635
636         /* validate that we can break and continue when using bpf_for() */
637         bpf_for(i, 0, ARRAY_SIZE(arr1)) {
638                 if (i & 1) {
639                         arr1[i] = i;
640                         continue;
641                 } else {
642                         arr2[i] = i;
643                         break;
644                 }
645         }
646
647         bpf_for(i, 0, ARRAY_SIZE(arr1)) {
648                 sum += arr1[i] + arr2[i];
649         }
650
651         return sum;
652 }
653
654 static __noinline void fill(struct bpf_iter_num *it, int *arr, __u32 n, int mul)
655 {
656         int *t, i;
657
658         while ((t = bpf_iter_num_next(it))) {
659                 i = *t;
660                 if (i >= n)
661                         break;
662                 arr[i] =  i * mul;
663         }
664 }
665
666 static __noinline int sum(struct bpf_iter_num *it, int *arr, __u32 n)
667 {
668         int *t, i, sum = 0;;
669
670         while ((t = bpf_iter_num_next(it))) {
671                 i = *t;
672                 if (i >= n)
673                         break;
674                 sum += arr[i];
675         }
676
677         return sum;
678 }
679
/* Passes a pointer to a live iterator into the fill() and sum()
 * subprograms. The iterator is created and destroyed in this function
 * around each call, for two stack arrays of different sizes.
 */
SEC("raw_tp")
__success
int iter_pass_iter_ptr_to_subprog(const void *ctx)
{
	int arr1[16], arr2[32];
	struct bpf_iter_num it;
	int n, sum1, sum2;

	MY_PID_GUARD();

	/* fill arr1 */
	n = ARRAY_SIZE(arr1);
	bpf_iter_num_new(&it, 0, n);
	fill(&it, arr1, n, 2);
	bpf_iter_num_destroy(&it);

	/* fill arr2 */
	n = ARRAY_SIZE(arr2);
	bpf_iter_num_new(&it, 0, n);
	fill(&it, arr2, n, 10);
	bpf_iter_num_destroy(&it);

	/* sum arr1 */
	n = ARRAY_SIZE(arr1);
	bpf_iter_num_new(&it, 0, n);
	sum1 = sum(&it, arr1, n);
	bpf_iter_num_destroy(&it);

	/* sum arr2 */
	n = ARRAY_SIZE(arr2);
	bpf_iter_num_new(&it, 0, n);
	sum2 = sum(&it, arr2, n);
	bpf_iter_num_destroy(&it);

	bpf_printk("sum1=%d, sum2=%d", sum1, sum2);

	return 0;
}
718
/* License string for the BPF object, placed in the "license" section. */
char _license[] SEC("license") = "GPL";