// SPDX-License-Identifier: GPL-2.0-only

/*
 *  HID-BPF support for Linux
 *
 *  Copyright (c) 2022 Benjamin Tissoires
 */

#include <linux/bitops.h>
#include <linux/btf.h>
#include <linux/btf_ids.h>
#include <linux/circ_buf.h>
#include <linux/filter.h>
#include <linux/hid.h>
#include <linux/hid_bpf.h>
#include <linux/init.h>
#include <linux/module.h>
#include <linux/workqueue.h>
#include "hid_bpf_dispatch.h"
#include "entrypoints/entrypoints.lskel.h"

#define HID_BPF_MAX_PROGS 1024 /* keep this in sync with preloaded bpf,
                                * needs to be a power of 2 as we use it as
                                * a circular buffer
                                */

#define NEXT(idx) (((idx) + 1) & (HID_BPF_MAX_PROGS - 1))
#define PREV(idx) (((idx) - 1) & (HID_BPF_MAX_PROGS - 1))
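
/*
 * NEXT()/PREV() rely on HID_BPF_MAX_PROGS being a power of two: masking with
 * (HID_BPF_MAX_PROGS - 1) wraps the index around, e.g. NEXT(1023) == 0 and
 * PREV(0) == 1023 with the current size of 1024.
 */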

/*
 * represents one attached program stored in the hid jump table
 */
struct hid_bpf_prog_entry {
        struct bpf_prog *prog;
        struct hid_device *hdev;
        enum hid_bpf_prog_type type;
        u16 idx;
};

struct hid_bpf_jmp_table {
        struct bpf_map *map;
        struct hid_bpf_prog_entry entries[HID_BPF_MAX_PROGS]; /* compacted list, circular buffer */
        int tail, head;
        struct bpf_prog *progs[HID_BPF_MAX_PROGS]; /* idx -> progs mapping */
        unsigned long enabled[BITS_TO_LONGS(HID_BPF_MAX_PROGS)];
};
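
/*
 * Each attached program is tracked in three places: entries[] is the
 * circular, compacted list of attachments between tail and head, progs[]
 * maps a jump table index back to its bpf_prog, and the enabled bitmap
 * marks which indices are currently allowed to run.
 */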

#define FOR_ENTRIES(__i, __start, __end) \
        for (__i = __start; CIRC_CNT(__end, __i, HID_BPF_MAX_PROGS); __i = NEXT(__i))
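
/*
 * FOR_ENTRIES() walks the circular entries[] list from __start until __i
 * catches up with __end: CIRC_CNT() (from <linux/circ_buf.h>) returns the
 * number of slots between the two indices and stops the loop at zero.
 */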

static struct hid_bpf_jmp_table jmp_table;

static DEFINE_MUTEX(hid_bpf_attach_lock);               /* held when attaching/detaching programs */

static void hid_bpf_release_progs(struct work_struct *work);

static DECLARE_WORK(release_work, hid_bpf_release_progs);

BTF_ID_LIST(hid_bpf_btf_ids)
BTF_ID(func, hid_bpf_device_event)                      /* HID_BPF_PROG_TYPE_DEVICE_EVENT */
BTF_ID(func, hid_bpf_rdesc_fixup)                       /* HID_BPF_PROG_TYPE_RDESC_FIXUP */
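
/*
 * The order of hid_bpf_btf_ids[] must match enum hid_bpf_prog_type:
 * hid_bpf_get_prog_attach_type() below maps a program's attach_btf_id to
 * its type simply by index.
 */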

static int hid_bpf_max_programs(enum hid_bpf_prog_type type)
{
        switch (type) {
        case HID_BPF_PROG_TYPE_DEVICE_EVENT:
                return HID_BPF_MAX_PROGS_PER_DEV;
        case HID_BPF_PROG_TYPE_RDESC_FIXUP:
                return 1;
        default:
                return -EINVAL;
        }
}

static int hid_bpf_program_count(struct hid_device *hdev,
                                 struct bpf_prog *prog,
                                 enum hid_bpf_prog_type type)
{
        int i, n = 0;

        if (type >= HID_BPF_PROG_TYPE_MAX)
                return -EINVAL;

        FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
                struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];

                if (type != HID_BPF_PROG_TYPE_UNDEF && entry->type != type)
                        continue;

                if (hdev && entry->hdev != hdev)
                        continue;

                if (prog && entry->prog != prog)
                        continue;

                n++;
        }

        return n;
}

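/*
 * Placeholder that does nothing by itself: the preloaded hid_tail_call
 * entrypoint (see entrypoints/entrypoints.lskel.h) is expected to attach to
 * this __weak noinline stub and bpf_tail_call() into the hid_jmp_table
 * prog_array using the ctx->index set by hid_bpf_prog_run() below.
 */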
__weak noinline int __hid_bpf_tail_call(struct hid_bpf_ctx *ctx)
{
        return 0;
}

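/*
 * Run all enabled programs of @type attached to @hdev, in list order.
 * A negative return from a program aborts the chain; a positive return is
 * stored in ctx.retval and passed on to the caller.
 */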
int hid_bpf_prog_run(struct hid_device *hdev, enum hid_bpf_prog_type type,
                     struct hid_bpf_ctx_kern *ctx_kern)
{
        struct hid_bpf_prog_list *prog_list;
        int i, idx, err = 0;

        rcu_read_lock();
        prog_list = rcu_dereference(hdev->bpf.progs[type]);

        if (!prog_list)
                goto out_unlock;

        for (i = 0; i < prog_list->prog_cnt; i++) {
                idx = prog_list->prog_idx[i];

                if (!test_bit(idx, jmp_table.enabled))
                        continue;

                ctx_kern->ctx.index = idx;
                err = __hid_bpf_tail_call(&ctx_kern->ctx);
                if (err < 0)
                        break;
                if (err)
                        ctx_kern->ctx.retval = err;
        }

 out_unlock:
        rcu_read_unlock();

        return err;
}

/*
 * assign the list of programs attached to a given hid device.
 */
static void __hid_bpf_set_hdev_progs(struct hid_device *hdev, struct hid_bpf_prog_list *new_list,
                                     enum hid_bpf_prog_type type)
{
        struct hid_bpf_prog_list *old_list;

        spin_lock(&hdev->bpf.progs_lock);
        old_list = rcu_dereference_protected(hdev->bpf.progs[type],
                                             lockdep_is_held(&hdev->bpf.progs_lock));
        rcu_assign_pointer(hdev->bpf.progs[type], new_list);
        spin_unlock(&hdev->bpf.progs_lock);
        synchronize_rcu();

        kfree(old_list);
}

/*
 * allocate and populate the list of programs attached to a given hid device.
 *
 * Must be called under lock.
 */
static int hid_bpf_populate_hdev(struct hid_device *hdev, enum hid_bpf_prog_type type)
{
        struct hid_bpf_prog_list *new_list;
        int i;

        if (type >= HID_BPF_PROG_TYPE_MAX || !hdev)
                return -EINVAL;

        if (hdev->bpf.destroyed)
                return 0;

        new_list = kzalloc(sizeof(*new_list), GFP_KERNEL);
        if (!new_list)
                return -ENOMEM;

        FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
                struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];

                if (entry->type == type && entry->hdev == hdev &&
                    test_bit(entry->idx, jmp_table.enabled))
                        new_list->prog_idx[new_list->prog_cnt++] = entry->idx;
        }

        __hid_bpf_set_hdev_progs(hdev, new_list, type);

        return 0;
}

static void __hid_bpf_do_release_prog(int map_fd, unsigned int idx)
{
        skel_map_delete_elem(map_fd, &idx);
        jmp_table.progs[idx] = NULL;
}

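/*
 * Deferred from release_work: detach every disabled entry from its device,
 * delete the corresponding slots from the BPF prog_array, then compact the
 * circular entries[] list so that tail..head only contains enabled programs.
 */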
static void hid_bpf_release_progs(struct work_struct *work)
{
        int i, j, n, map_fd = -1;

        if (!jmp_table.map)
                return;

        /* retrieve a fd of our prog_array map in BPF */
        map_fd = skel_map_get_fd_by_id(jmp_table.map->id);
        if (map_fd < 0)
                return;

        mutex_lock(&hid_bpf_attach_lock); /* protects against attaching new programs */

        /* detach unused progs from HID devices */
        FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
                struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];
                enum hid_bpf_prog_type type;
                struct hid_device *hdev;

                if (test_bit(entry->idx, jmp_table.enabled))
                        continue;

                /* we have an attached prog */
                if (entry->hdev) {
                        hdev = entry->hdev;
                        type = entry->type;

                        hid_bpf_populate_hdev(hdev, type);

                        /* mark all other disabled progs from hdev of the given type as detached */
                        FOR_ENTRIES(j, i, jmp_table.head) {
                                struct hid_bpf_prog_entry *next;

                                next = &jmp_table.entries[j];

                                if (test_bit(next->idx, jmp_table.enabled))
                                        continue;

                                if (next->hdev == hdev && next->type == type)
                                        next->hdev = NULL;
                        }

                        /* if type was rdesc fixup, reconnect device */
                        if (type == HID_BPF_PROG_TYPE_RDESC_FIXUP)
                                hid_bpf_reconnect(hdev);
                }
        }

        /* remove all unused progs from the jump table */
        FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
                struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];

                if (test_bit(entry->idx, jmp_table.enabled))
                        continue;

                if (entry->prog)
                        __hid_bpf_do_release_prog(map_fd, entry->idx);
        }

        /* compact the entry list */
        n = jmp_table.tail;
        FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
                struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];

                if (!test_bit(entry->idx, jmp_table.enabled))
                        continue;

                jmp_table.entries[n] = jmp_table.entries[i];
                n = NEXT(n);
        }

        jmp_table.head = n;

        mutex_unlock(&hid_bpf_attach_lock);

        if (map_fd >= 0)
                close_fd(map_fd);
}

static void hid_bpf_release_prog_at(int idx)
{
        int map_fd = -1;

        /* retrieve a fd of our prog_array map in BPF */
        map_fd = skel_map_get_fd_by_id(jmp_table.map->id);
        if (map_fd < 0)
                return;

        __hid_bpf_do_release_prog(map_fd, idx);

        close_fd(map_fd);
}

/*
 * Insert the given BPF program represented by its fd in the jmp table.
 * Returns the index in the jump table or a negative error.
 */
static int hid_bpf_insert_prog(int prog_fd, struct bpf_prog *prog)
{
        int i, index = -1, map_fd = -1, err = -EINVAL;

        /* retrieve a fd of our prog_array map in BPF */
        map_fd = skel_map_get_fd_by_id(jmp_table.map->id);

        if (map_fd < 0) {
                err = -EINVAL;
                goto out;
        }

        /* find the first available index in the jmp_table */
        for (i = 0; i < HID_BPF_MAX_PROGS; i++) {
                if (!jmp_table.progs[i] && index < 0) {
                        /* mark the index as used */
                        jmp_table.progs[i] = prog;
                        index = i;
                        __set_bit(i, jmp_table.enabled);
                }
        }
        if (index < 0) {
                err = -ENOMEM;
                goto out;
        }

        /* insert the program in the jump table */
        err = skel_map_update_elem(map_fd, &index, &prog_fd, 0);
        if (err)
                goto out;

        /* return the index */
        err = index;

 out:
        if (err < 0)
                __hid_bpf_do_release_prog(map_fd, index);
        if (map_fd >= 0)
                close_fd(map_fd);
        return err;
}

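/*
 * Translate a program fd into its HID-BPF program type by matching the
 * program's attach_btf_id against hid_bpf_btf_ids[]; returns
 * HID_BPF_PROG_TYPE_UNDEF when the program targets none of our hooks.
 */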
int hid_bpf_get_prog_attach_type(int prog_fd)
{
        struct bpf_prog *prog = NULL;
        int i;
        int prog_type = HID_BPF_PROG_TYPE_UNDEF;

        prog = bpf_prog_get(prog_fd);
        if (IS_ERR(prog))
                return PTR_ERR(prog);

        for (i = 0; i < HID_BPF_PROG_TYPE_MAX; i++) {
                if (hid_bpf_btf_ids[i] == prog->aux->attach_btf_id) {
                        prog_type = i;
                        break;
                }
        }

        bpf_prog_put(prog);

        return prog_type;
}

static void hid_bpf_link_release(struct bpf_link *link)
{
        struct hid_bpf_link *hid_link =
                container_of(link, struct hid_bpf_link, link);

        __clear_bit(hid_link->hid_table_index, jmp_table.enabled);
        schedule_work(&release_work);
}

static void hid_bpf_link_dealloc(struct bpf_link *link)
{
        struct hid_bpf_link *hid_link =
                container_of(link, struct hid_bpf_link, link);

        kfree(hid_link);
}

static void hid_bpf_link_show_fdinfo(const struct bpf_link *link,
                                     struct seq_file *seq)
{
        seq_printf(seq,
                   "attach_type:\tHID-BPF\n");
}

static const struct bpf_link_ops hid_bpf_link_lops = {
        .release = hid_bpf_link_release,
        .dealloc = hid_bpf_link_dealloc,
        .show_fdinfo = hid_bpf_link_show_fdinfo,
};

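/*
 * Attach flow: take a reference on the program, allocate its bpf_link,
 * enforce the per-device limit for the type, reserve a slot in the jump
 * table, then append the entry at head (or prepend at tail when
 * HID_BPF_FLAG_INSERT_HEAD is set) and republish the device's program list.
 */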
/* called from syscall */
noinline int
__hid_bpf_attach_prog(struct hid_device *hdev, enum hid_bpf_prog_type prog_type,
                      int prog_fd, __u32 flags)
{
        struct bpf_link_primer link_primer;
        struct hid_bpf_link *link;
        struct bpf_prog *prog = NULL;
        struct hid_bpf_prog_entry *prog_entry;
        int cnt, err = -EINVAL, prog_table_idx = -1;

        /* take a ref on the prog itself */
        prog = bpf_prog_get(prog_fd);
        if (IS_ERR(prog))
                return PTR_ERR(prog);

        mutex_lock(&hid_bpf_attach_lock);

        link = kzalloc(sizeof(*link), GFP_USER);
        if (!link) {
                err = -ENOMEM;
                goto err_unlock;
        }

        bpf_link_init(&link->link, BPF_LINK_TYPE_UNSPEC,
                      &hid_bpf_link_lops, prog);

        /* do not attach too many programs to a given HID device */
        cnt = hid_bpf_program_count(hdev, NULL, prog_type);
        if (cnt < 0) {
                err = cnt;
                goto err_unlock;
        }

        if (cnt >= hid_bpf_max_programs(prog_type)) {
                err = -E2BIG;
                goto err_unlock;
        }

        prog_table_idx = hid_bpf_insert_prog(prog_fd, prog);
        /* if the jmp table is full, abort */
        if (prog_table_idx < 0) {
                err = prog_table_idx;
                goto err_unlock;
        }

        if (flags & HID_BPF_FLAG_INSERT_HEAD) {
                /* take the previous prog_entry slot */
                jmp_table.tail = PREV(jmp_table.tail);
                prog_entry = &jmp_table.entries[jmp_table.tail];
        } else {
                /* take the next prog_entry slot */
                prog_entry = &jmp_table.entries[jmp_table.head];
                jmp_table.head = NEXT(jmp_table.head);
        }

        /* we steal the ref here */
        prog_entry->prog = prog;
        prog_entry->idx = prog_table_idx;
        prog_entry->hdev = hdev;
        prog_entry->type = prog_type;

        /* finally store the index in the device list */
        err = hid_bpf_populate_hdev(hdev, prog_type);
        if (err) {
                hid_bpf_release_prog_at(prog_table_idx);
                goto err_unlock;
        }

        link->hid_table_index = prog_table_idx;

        err = bpf_link_prime(&link->link, &link_primer);
        if (err)
                goto err_unlock;

        mutex_unlock(&hid_bpf_attach_lock);

        return bpf_link_settle(&link_primer);

 err_unlock:
        mutex_unlock(&hid_bpf_attach_lock);

        bpf_prog_put(prog);
        kfree(link);

        return err;
}

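/*
 * Called when a HID device goes away: disable every program still attached
 * to it, drop the per-device RCU lists, and let release_work reclaim the
 * now-unused jump table slots.
 */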
void __hid_bpf_destroy_device(struct hid_device *hdev)
{
        int type, i;
        struct hid_bpf_prog_list *prog_list;

        rcu_read_lock();

        for (type = 0; type < HID_BPF_PROG_TYPE_MAX; type++) {
                prog_list = rcu_dereference(hdev->bpf.progs[type]);

                if (!prog_list)
                        continue;

                for (i = 0; i < prog_list->prog_cnt; i++)
                        __clear_bit(prog_list->prog_idx[i], jmp_table.enabled);
        }

        rcu_read_unlock();

        for (type = 0; type < HID_BPF_PROG_TYPE_MAX; type++)
                __hid_bpf_set_hdev_progs(hdev, NULL, type);

        /* schedule release of all detached progs */
        schedule_work(&release_work);
}

#define HID_BPF_PROGS_COUNT 1

static struct bpf_link *links[HID_BPF_PROGS_COUNT];
static struct entrypoints_bpf *skel;

void hid_bpf_free_links_and_skel(void)
{
        int i;

        /* the following is enough to release all programs attached to hid */
        if (jmp_table.map)
                bpf_map_put_with_uref(jmp_table.map);

        for (i = 0; i < ARRAY_SIZE(links); i++) {
                if (!IS_ERR_OR_NULL(links[i]))
                        bpf_link_put(links[i]);
        }
        entrypoints_bpf__destroy(skel);
}

#define ATTACH_AND_STORE_LINK(__name) do {                                      \
        err = entrypoints_bpf__##__name##__attach(skel);                        \
        if (err)                                                                \
                goto out;                                                       \
                                                                                \
        links[idx] = bpf_link_get_from_fd(skel->links.__name##_fd);             \
        if (IS_ERR(links[idx])) {                                               \
                err = PTR_ERR(links[idx]);                                      \
                goto out;                                                       \
        }                                                                       \
                                                                                \
        /* Avoid taking over stdin/stdout/stderr of init process. Zeroing out   \
         * makes skel_closenz() a no-op later in entrypoints_bpf__destroy().    \
         */                                                                     \
        close_fd(skel->links.__name##_fd);                                      \
        skel->links.__name##_fd = 0;                                            \
        idx++;                                                                  \
} while (0)

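/*
 * Set up the preloaded light skeleton: open and load it, take a reference
 * (with uref) on the hid_jmp_table prog_array so the map and its programs
 * stay alive until hid_bpf_free_links_and_skel(), then attach the
 * hid_tail_call entrypoint through ATTACH_AND_STORE_LINK().
 */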
int hid_bpf_preload_skel(void)
{
        int err, idx = 0;

        skel = entrypoints_bpf__open();
        if (!skel)
                return -ENOMEM;

        err = entrypoints_bpf__load(skel);
        if (err)
                goto out;

        jmp_table.map = bpf_map_get_with_uref(skel->maps.hid_jmp_table.map_fd);
        if (IS_ERR(jmp_table.map)) {
                err = PTR_ERR(jmp_table.map);
                goto out;
        }

        ATTACH_AND_STORE_LINK(hid_tail_call);

        return 0;
out:
        hid_bpf_free_links_and_skel();
        return err;
}