// SPDX-License-Identifier: GPL-2.0-only

/*
 *  HID-BPF support for Linux
 *
 *  Copyright (c) 2022 Benjamin Tissoires
 */
#include <linux/bitops.h>
#include <linux/btf.h>
#include <linux/btf_ids.h>
#include <linux/circ_buf.h>
#include <linux/filter.h>
#include <linux/hid.h>
#include <linux/hid_bpf.h>
#include <linux/init.h>
#include <linux/module.h>
#include <linux/workqueue.h>
#include "hid_bpf_dispatch.h"
#include "entrypoints/entrypoints.lskel.h"
#define HID_BPF_MAX_PROGS 1024 /* keep this in sync with preloaded bpf,
				* needs to be a power of 2 as we use it as
				* a circular buffer
				*/
#define NEXT(idx) (((idx) + 1) & (HID_BPF_MAX_PROGS - 1))
#define PREV(idx) (((idx) - 1) & (HID_BPF_MAX_PROGS - 1))
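/*
 * Because HID_BPF_MAX_PROGS is a power of 2, masking with
 * (HID_BPF_MAX_PROGS - 1) wraps the index around the circular buffer
 * without a modulo; PREV(0) wraps to HID_BPF_MAX_PROGS - 1.
 */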
/*
 * represents one attached program stored in the hid jump table
 */
struct hid_bpf_prog_entry {
	struct bpf_prog *prog;
	struct hid_device *hdev;
	enum hid_bpf_prog_type type;
	u16 idx;
};
struct hid_bpf_jmp_table {
	struct bpf_map *map;
	struct hid_bpf_prog_entry entries[HID_BPF_MAX_PROGS]; /* compacted list, circular buffer */
	int tail, head;
	struct bpf_prog *progs[HID_BPF_MAX_PROGS]; /* idx -> progs mapping */
	unsigned long enabled[BITS_TO_LONGS(HID_BPF_MAX_PROGS)];
};
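/*
 * Two views are kept in sync here: entries[] is the ordered list of
 * attached programs (a circular buffer between tail and head), while
 * progs[] and enabled[] map a stable jump-table index to its bpf_prog
 * and to whether that slot is currently active.
 */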
#define FOR_ENTRIES(__i, __start, __end) \
	for (__i = __start; CIRC_CNT(__end, __i, HID_BPF_MAX_PROGS); __i = NEXT(__i))
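/*
 * FOR_ENTRIES() walks the circular buffer from __start until __i catches
 * up with __end: CIRC_CNT() from <linux/circ_buf.h> yields the number of
 * occupied slots between the two indices, so the loop stops once that
 * count reaches zero.
 */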
static struct hid_bpf_jmp_table jmp_table;

static DEFINE_MUTEX(hid_bpf_attach_lock);	/* held when attaching/detaching programs */

static void hid_bpf_release_progs(struct work_struct *work);

static DECLARE_WORK(release_work, hid_bpf_release_progs);
BTF_ID_LIST(hid_bpf_btf_ids)
BTF_ID(func, hid_bpf_device_event)	/* HID_BPF_PROG_TYPE_DEVICE_EVENT */
BTF_ID(func, hid_bpf_rdesc_fixup)	/* HID_BPF_PROG_TYPE_RDESC_FIXUP */
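/*
 * The order of this list must match enum hid_bpf_prog_type:
 * hid_bpf_get_prog_attach_type() below maps a program's attach_btf_id to
 * its HID-BPF program type through its position in this array.
 */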
static int hid_bpf_max_programs(enum hid_bpf_prog_type type)
{
	switch (type) {
	case HID_BPF_PROG_TYPE_DEVICE_EVENT:
		return HID_BPF_MAX_PROGS_PER_DEV;
	case HID_BPF_PROG_TYPE_RDESC_FIXUP:
		return 1;	/* only one rdesc fixup makes sense per device */
	default:
		return -EINVAL;
	}
}
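/*
 * Count the programs in the jump table matching the given filters; a NULL
 * hdev or prog (or HID_BPF_PROG_TYPE_UNDEF) acts as a wildcard.
 */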
static int hid_bpf_program_count(struct hid_device *hdev,
				 struct bpf_prog *prog,
				 enum hid_bpf_prog_type type)
{
	int i, n = 0;

	if (type >= HID_BPF_PROG_TYPE_MAX)
		return -EINVAL;

	FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
		struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];

		if (type != HID_BPF_PROG_TYPE_UNDEF && entry->type != type)
			continue;

		if (hdev && entry->hdev != hdev)
			continue;

		if (prog && entry->prog != prog)
			continue;

		n++;
	}

	return n;
}
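/*
 * Weak stub: at preload time the "hid_tail_call" entrypoint BPF program is
 * attached over this function (see hid_bpf_preload_skel() below) and
 * tail-calls into the hid_jmp_table prog_array at ctx->index.
 */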
__weak noinline int __hid_bpf_tail_call(struct hid_bpf_ctx *ctx)
{
	return 0;
}
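/*
 * Run all enabled programs of @type attached to @hdev, in attach order:
 * a negative return value aborts the chain, a positive one is stored as
 * the context retval before moving on to the next program.
 */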
int hid_bpf_prog_run(struct hid_device *hdev, enum hid_bpf_prog_type type,
		     struct hid_bpf_ctx_kern *ctx_kern)
{
	struct hid_bpf_prog_list *prog_list;
	int i, idx, err = 0;

	rcu_read_lock();
	prog_list = rcu_dereference(hdev->bpf.progs[type]);

	if (!prog_list)
		goto out_unlock;

	for (i = 0; i < prog_list->prog_cnt; i++) {
		idx = prog_list->prog_idx[i];

		if (!test_bit(idx, jmp_table.enabled))
			continue;

		ctx_kern->ctx.index = idx;
		err = __hid_bpf_tail_call(&ctx_kern->ctx);
		if (err < 0)
			break;
		if (err)
			ctx_kern->ctx.retval = err;
	}

out_unlock:
	rcu_read_unlock();

	return err;
}
/*
 * assign the list of programs attached to a given hid device.
 */
static void __hid_bpf_set_hdev_progs(struct hid_device *hdev, struct hid_bpf_prog_list *new_list,
				     enum hid_bpf_prog_type type)
{
	struct hid_bpf_prog_list *old_list;

	spin_lock(&hdev->bpf.progs_lock);
	old_list = rcu_dereference_protected(hdev->bpf.progs[type],
					     lockdep_is_held(&hdev->bpf.progs_lock));
	rcu_assign_pointer(hdev->bpf.progs[type], new_list);
	spin_unlock(&hdev->bpf.progs_lock);
	synchronize_rcu();

	kfree(old_list);
}
/*
 * allocate and populate the list of programs attached to a given hid device.
 *
 * Must be called under lock.
 */
static int hid_bpf_populate_hdev(struct hid_device *hdev, enum hid_bpf_prog_type type)
{
	struct hid_bpf_prog_list *new_list;
	int i;

	if (type >= HID_BPF_PROG_TYPE_MAX || !hdev)
		return -EINVAL;

	if (hdev->bpf.destroyed)
		return 0;

	new_list = kzalloc(sizeof(*new_list), GFP_KERNEL);
	if (!new_list)
		return -ENOMEM;

	FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
		struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];

		if (entry->type == type && entry->hdev == hdev &&
		    test_bit(entry->idx, jmp_table.enabled))
			new_list->prog_idx[new_list->prog_cnt++] = entry->idx;
	}

	__hid_bpf_set_hdev_progs(hdev, new_list, type);

	return 0;
}
static void __hid_bpf_do_release_prog(int map_fd, unsigned int idx)
{
	skel_map_delete_elem(map_fd, &idx);
	jmp_table.progs[idx] = NULL;
}
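/*
 * Deferred-release worker: hid_bpf_link_release() and
 * __hid_bpf_destroy_device() only clear bits in jmp_table.enabled and
 * schedule this work, which performs the actual teardown under the
 * attach mutex.
 */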
static void hid_bpf_release_progs(struct work_struct *work)
{
	int i, j, n, map_fd = -1;

	if (!jmp_table.map)
		return;

	/* retrieve a fd of our prog_array map in BPF */
	map_fd = skel_map_get_fd_by_id(jmp_table.map->id);
	if (map_fd < 0)
		return;

	mutex_lock(&hid_bpf_attach_lock); /* protects against attaching new programs */

	/* detach unused progs from HID devices */
	FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
		struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];
		enum hid_bpf_prog_type type;
		struct hid_device *hdev;

		if (test_bit(entry->idx, jmp_table.enabled))
			continue;

		/* we have an attached prog */
		if (entry->hdev) {
			hdev = entry->hdev;
			type = entry->type;

			hid_bpf_populate_hdev(hdev, type);

			/* mark all other disabled progs from hdev of the given type as detached */
			FOR_ENTRIES(j, i, jmp_table.head) {
				struct hid_bpf_prog_entry *next;

				next = &jmp_table.entries[j];

				if (test_bit(next->idx, jmp_table.enabled))
					continue;

				if (next->hdev == hdev && next->type == type)
					next->hdev = NULL;
			}

			/* if type was rdesc fixup, reconnect device */
			if (type == HID_BPF_PROG_TYPE_RDESC_FIXUP)
				hid_bpf_reconnect(hdev);
		}
	}

	/* remove all unused progs from the jump table */
	FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
		struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];

		if (test_bit(entry->idx, jmp_table.enabled))
			continue;

		if (entry->prog)
			__hid_bpf_do_release_prog(map_fd, entry->idx);
	}

	/* compact the entry list */
	n = jmp_table.tail;
	FOR_ENTRIES(i, jmp_table.tail, jmp_table.head) {
		struct hid_bpf_prog_entry *entry = &jmp_table.entries[i];

		if (!test_bit(entry->idx, jmp_table.enabled))
			continue;

		jmp_table.entries[n] = jmp_table.entries[i];
		n = NEXT(n);
	}

	jmp_table.head = n;

	mutex_unlock(&hid_bpf_attach_lock);

	close_fd(map_fd);
}
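/*
 * Release a single program slot; used to undo hid_bpf_insert_prog() when
 * attaching fails after the jump table has already been updated.
 */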
static void hid_bpf_release_prog_at(int idx)
{
	int map_fd = -1;

	/* retrieve a fd of our prog_array map in BPF */
	map_fd = skel_map_get_fd_by_id(jmp_table.map->id);
	if (map_fd < 0)
		return;

	__hid_bpf_do_release_prog(map_fd, idx);

	close_fd(map_fd);
}
/*
 * Insert the given BPF program represented by its fd in the jmp table.
 * Returns the index in the jump table or a negative error.
 */
static int hid_bpf_insert_prog(int prog_fd, struct bpf_prog *prog)
{
	int i, index = -1, map_fd = -1, err = -EINVAL;

	/* retrieve a fd of our prog_array map in BPF */
	map_fd = skel_map_get_fd_by_id(jmp_table.map->id);
	if (map_fd < 0)
		goto out;

	/* find the first available index in the jmp_table */
	for (i = 0; i < HID_BPF_MAX_PROGS; i++) {
		if (!jmp_table.progs[i] && index < 0) {
			/* mark the index as used */
			jmp_table.progs[i] = prog;
			index = i;
			__set_bit(i, jmp_table.enabled);
		}
	}
	if (index < 0) {
		err = -ENOMEM;
		goto out;
	}

	/* insert the program in the jump table */
	err = skel_map_update_elem(map_fd, &index, &prog_fd, 0);
	if (err)
		goto out;

	/* return the index */
	err = index;

out:
	if (err < 0 && index >= 0)
		__hid_bpf_do_release_prog(map_fd, index);
	if (map_fd >= 0)
		close_fd(map_fd);

	return err;
}
int hid_bpf_get_prog_attach_type(int prog_fd)
{
	struct bpf_prog *prog = NULL;
	int i;
	int prog_type = HID_BPF_PROG_TYPE_UNDEF;

	prog = bpf_prog_get(prog_fd);
	if (IS_ERR(prog))
		return PTR_ERR(prog);

	for (i = 0; i < HID_BPF_PROG_TYPE_MAX; i++) {
		if (hid_bpf_btf_ids[i] == prog->aux->attach_btf_id) {
			prog_type = i;
			break;
		}
	}

	bpf_prog_put(prog);

	return prog_type;
}
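/*
 * bpf_link callbacks: detaching is lazy. ->release only marks the jump
 * table slot as disabled and schedules release_work for the heavier
 * teardown; ->dealloc frees the link once its last file reference drops.
 */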
static void hid_bpf_link_release(struct bpf_link *link)
{
	struct hid_bpf_link *hid_link =
		container_of(link, struct hid_bpf_link, link);

	__clear_bit(hid_link->hid_table_index, jmp_table.enabled);
	schedule_work(&release_work);
}
static void hid_bpf_link_dealloc(struct bpf_link *link)
{
	struct hid_bpf_link *hid_link =
		container_of(link, struct hid_bpf_link, link);

	kfree(hid_link);
}
static void hid_bpf_link_show_fdinfo(const struct bpf_link *link,
				     struct seq_file *seq)
{
	seq_printf(seq,
		   "attach_type:\tHID-BPF\n");
}
static const struct bpf_link_ops hid_bpf_link_lops = {
	.release = hid_bpf_link_release,
	.dealloc = hid_bpf_link_dealloc,
	.show_fdinfo = hid_bpf_link_show_fdinfo,
};
/* called from syscall */
int
__hid_bpf_attach_prog(struct hid_device *hdev, enum hid_bpf_prog_type prog_type,
		      int prog_fd, __u32 flags)
{
	struct bpf_link_primer link_primer;
	struct hid_bpf_link *link;
	struct bpf_prog *prog = NULL;
	struct hid_bpf_prog_entry *prog_entry;
	int cnt, err = -EINVAL, prog_table_idx = -1;

	/* take a ref on the prog itself */
	prog = bpf_prog_get(prog_fd);
	if (IS_ERR(prog))
		return PTR_ERR(prog);

	mutex_lock(&hid_bpf_attach_lock);

	link = kzalloc(sizeof(*link), GFP_USER);
	if (!link) {
		err = -ENOMEM;
		goto err_unlock;
	}

	bpf_link_init(&link->link, BPF_LINK_TYPE_UNSPEC,
		      &hid_bpf_link_lops, prog);

	/* do not attach too many programs to a given HID device */
	cnt = hid_bpf_program_count(hdev, NULL, prog_type);
	if (cnt < 0) {
		err = cnt;
		goto err_unlock;
	}

	if (cnt >= hid_bpf_max_programs(prog_type)) {
		err = -E2BIG;
		goto err_unlock;
	}

	prog_table_idx = hid_bpf_insert_prog(prog_fd, prog);
	/* if the jmp table is full, abort */
	if (prog_table_idx < 0) {
		err = prog_table_idx;
		goto err_unlock;
	}

	if (flags & HID_BPF_FLAG_INSERT_HEAD) {
		/* take the previous prog_entry slot */
		jmp_table.tail = PREV(jmp_table.tail);
		prog_entry = &jmp_table.entries[jmp_table.tail];
	} else {
		/* take the next prog_entry slot */
		prog_entry = &jmp_table.entries[jmp_table.head];
		jmp_table.head = NEXT(jmp_table.head);
	}

	/* we steal the ref here */
	prog_entry->prog = prog;
	prog_entry->idx = prog_table_idx;
	prog_entry->hdev = hdev;
	prog_entry->type = prog_type;

	/* finally store the index in the device list */
	err = hid_bpf_populate_hdev(hdev, prog_type);
	if (err) {
		hid_bpf_release_prog_at(prog_table_idx);
		goto err_unlock;
	}

	link->hid_table_index = prog_table_idx;

	err = bpf_link_prime(&link->link, &link_primer);
	if (err)
		goto err_unlock;

	mutex_unlock(&hid_bpf_attach_lock);

	return bpf_link_settle(&link_primer);

err_unlock:
	mutex_unlock(&hid_bpf_attach_lock);

	kfree(link);
	bpf_prog_put(prog);

	return err;
}
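/*
 * Called when a HID device is about to go away: disable every program
 * attached to it, unpublish the per-device program lists, and let
 * release_work reclaim the jump table slots.
 */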
void __hid_bpf_destroy_device(struct hid_device *hdev)
{
	int type, i;
	struct hid_bpf_prog_list *prog_list;

	rcu_read_lock();

	for (type = 0; type < HID_BPF_PROG_TYPE_MAX; type++) {
		prog_list = rcu_dereference(hdev->bpf.progs[type]);

		if (!prog_list)
			continue;

		for (i = 0; i < prog_list->prog_cnt; i++)
			__clear_bit(prog_list->prog_idx[i], jmp_table.enabled);
	}

	rcu_read_unlock();

	for (type = 0; type < HID_BPF_PROG_TYPE_MAX; type++)
		__hid_bpf_set_hdev_progs(hdev, NULL, type);

	/* schedule release of all detached progs */
	schedule_work(&release_work);
}
#define HID_BPF_PROGS_COUNT 1

static struct bpf_link *links[HID_BPF_PROGS_COUNT];
static struct entrypoints_bpf *skel;
void hid_bpf_free_links_and_skel(void)
{
	int i;

	/* the following is enough to release all programs attached to hid */
	if (!IS_ERR_OR_NULL(jmp_table.map))
		bpf_map_put_with_uref(jmp_table.map);

	for (i = 0; i < ARRAY_SIZE(links); i++) {
		if (!IS_ERR_OR_NULL(links[i]))
			bpf_link_put(links[i]);
	}
	entrypoints_bpf__destroy(skel);
}
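/*
 * Attach one entrypoint of the light skeleton and pin its bpf_link so the
 * program stays alive after the attach fd is closed below.
 */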
#define ATTACH_AND_STORE_LINK(__name) do {				\
	err = entrypoints_bpf__##__name##__attach(skel);		\
	if (err)							\
		goto out;						\
									\
	links[idx] = bpf_link_get_from_fd(skel->links.__name##_fd);	\
	if (IS_ERR(links[idx])) {					\
		err = PTR_ERR(links[idx]);				\
		goto out;						\
	}								\
									\
	/* Avoid taking over stdin/stdout/stderr of init process.	\
	 * Zeroing out makes skel_closenz() a no-op later in		\
	 * entrypoints_bpf__destroy().					\
	 */								\
	close_fd(skel->links.__name##_fd);				\
	skel->links.__name##_fd = 0;					\
	idx++;								\
} while (0)
int hid_bpf_preload_skel(void)
{
	int err, idx = 0;

	skel = entrypoints_bpf__open();
	if (!skel)
		return -ENOMEM;

	err = entrypoints_bpf__load(skel);
	if (err)
		goto out;

	jmp_table.map = bpf_map_get_with_uref(skel->maps.hid_jmp_table.map_fd);
	if (IS_ERR(jmp_table.map)) {
		err = PTR_ERR(jmp_table.map);
		goto out;
	}

	ATTACH_AND_STORE_LINK(hid_tail_call);

	return 0;
out:
	hid_bpf_free_links_and_skel();
	return err;
}