Merge tag 'md-fixes-20231003' of https://git.kernel.org/pub/scm/linux/kernel/git...
[platform/kernel/linux-starfive.git] / kernel / static_call_inline.c
1 // SPDX-License-Identifier: GPL-2.0
2 #include <linux/init.h>
3 #include <linux/static_call.h>
4 #include <linux/bug.h>
5 #include <linux/smp.h>
6 #include <linux/sort.h>
7 #include <linux/slab.h>
8 #include <linux/module.h>
9 #include <linux/cpu.h>
10 #include <linux/processor.h>
11 #include <asm/sections.h>
12
13 extern struct static_call_site __start_static_call_sites[],
14                                __stop_static_call_sites[];
15 extern struct static_call_tramp_key __start_static_call_tramp_key[],
16                                     __stop_static_call_tramp_key[];
17
18 static int static_call_initialized;
19
20 /*
21  * Must be called before early_initcall() to be effective.
22  */
23 void static_call_force_reinit(void)
24 {
25         if (WARN_ON_ONCE(!static_call_initialized))
26                 return;
27
28         static_call_initialized++;
29 }
30
31 /* mutex to protect key modules/sites */
32 static DEFINE_MUTEX(static_call_mutex);
33
34 static void static_call_lock(void)
35 {
36         mutex_lock(&static_call_mutex);
37 }
38
39 static void static_call_unlock(void)
40 {
41         mutex_unlock(&static_call_mutex);
42 }
43
44 static inline void *static_call_addr(struct static_call_site *site)
45 {
46         return (void *)((long)site->addr + (long)&site->addr);
47 }
48
49 static inline unsigned long __static_call_key(const struct static_call_site *site)
50 {
51         return (long)site->key + (long)&site->key;
52 }
53
54 static inline struct static_call_key *static_call_key(const struct static_call_site *site)
55 {
56         return (void *)(__static_call_key(site) & ~STATIC_CALL_SITE_FLAGS);
57 }
58
59 /* These assume the key is word-aligned. */
60 static inline bool static_call_is_init(struct static_call_site *site)
61 {
62         return __static_call_key(site) & STATIC_CALL_SITE_INIT;
63 }
64
65 static inline bool static_call_is_tail(struct static_call_site *site)
66 {
67         return __static_call_key(site) & STATIC_CALL_SITE_TAIL;
68 }
69
70 static inline void static_call_set_init(struct static_call_site *site)
71 {
72         site->key = (__static_call_key(site) | STATIC_CALL_SITE_INIT) -
73                     (long)&site->key;
74 }
75
/* sort() comparator: order call sites by the address of their key. */
static int static_call_site_cmp(const void *_a, const void *_b)
{
	const struct static_call_key *ka = static_call_key(_a);
	const struct static_call_key *kb = static_call_key(_b);

	if (ka == kb)
		return 0;

	return ka < kb ? -1 : 1;
}
91
92 static void static_call_site_swap(void *_a, void *_b, int size)
93 {
94         long delta = (unsigned long)_a - (unsigned long)_b;
95         struct static_call_site *a = _a;
96         struct static_call_site *b = _b;
97         struct static_call_site tmp = *a;
98
99         a->addr = b->addr  - delta;
100         a->key  = b->key   - delta;
101
102         b->addr = tmp.addr + delta;
103         b->key  = tmp.key  + delta;
104 }
105
106 static inline void static_call_sort_entries(struct static_call_site *start,
107                                             struct static_call_site *stop)
108 {
109         sort(start, stop - start, sizeof(struct static_call_site),
110              static_call_site_cmp, static_call_site_swap);
111 }
112
113 static inline bool static_call_key_has_mods(struct static_call_key *key)
114 {
115         return !(key->type & 1);
116 }
117
118 static inline struct static_call_mod *static_call_key_next(struct static_call_key *key)
119 {
120         if (!static_call_key_has_mods(key))
121                 return NULL;
122
123         return key->mods;
124 }
125
126 static inline struct static_call_site *static_call_key_sites(struct static_call_key *key)
127 {
128         if (static_call_key_has_mods(key))
129                 return NULL;
130
131         return (struct static_call_site *)(key->type & ~1);
132 }
133
134 void __static_call_update(struct static_call_key *key, void *tramp, void *func)
135 {
136         struct static_call_site *site, *stop;
137         struct static_call_mod *site_mod, first;
138
139         cpus_read_lock();
140         static_call_lock();
141
142         if (key->func == func)
143                 goto done;
144
145         key->func = func;
146
147         arch_static_call_transform(NULL, tramp, func, false);
148
149         /*
150          * If uninitialized, we'll not update the callsites, but they still
151          * point to the trampoline and we just patched that.
152          */
153         if (WARN_ON_ONCE(!static_call_initialized))
154                 goto done;
155
156         first = (struct static_call_mod){
157                 .next = static_call_key_next(key),
158                 .mod = NULL,
159                 .sites = static_call_key_sites(key),
160         };
161
162         for (site_mod = &first; site_mod; site_mod = site_mod->next) {
163                 bool init = system_state < SYSTEM_RUNNING;
164                 struct module *mod = site_mod->mod;
165
166                 if (!site_mod->sites) {
167                         /*
168                          * This can happen if the static call key is defined in
169                          * a module which doesn't use it.
170                          *
171                          * It also happens in the has_mods case, where the
172                          * 'first' entry has no sites associated with it.
173                          */
174                         continue;
175                 }
176
177                 stop = __stop_static_call_sites;
178
179                 if (mod) {
180 #ifdef CONFIG_MODULES
181                         stop = mod->static_call_sites +
182                                mod->num_static_call_sites;
183                         init = mod->state == MODULE_STATE_COMING;
184 #endif
185                 }
186
187                 for (site = site_mod->sites;
188                      site < stop && static_call_key(site) == key; site++) {
189                         void *site_addr = static_call_addr(site);
190
191                         if (!init && static_call_is_init(site))
192                                 continue;
193
194                         if (!kernel_text_address((unsigned long)site_addr)) {
195                                 /*
196                                  * This skips patching built-in __exit, which
197                                  * is part of init_section_contains() but is
198                                  * not part of kernel_text_address().
199                                  *
200                                  * Skipping built-in __exit is fine since it
201                                  * will never be executed.
202                                  */
203                                 WARN_ONCE(!static_call_is_init(site),
204                                           "can't patch static call site at %pS",
205                                           site_addr);
206                                 continue;
207                         }
208
209                         arch_static_call_transform(site_addr, NULL, func,
210                                                    static_call_is_tail(site));
211                 }
212         }
213
214 done:
215         static_call_unlock();
216         cpus_read_unlock();
217 }
218 EXPORT_SYMBOL_GPL(__static_call_update);
219
220 static int __static_call_init(struct module *mod,
221                               struct static_call_site *start,
222                               struct static_call_site *stop)
223 {
224         struct static_call_site *site;
225         struct static_call_key *key, *prev_key = NULL;
226         struct static_call_mod *site_mod;
227
228         if (start == stop)
229                 return 0;
230
231         static_call_sort_entries(start, stop);
232
233         for (site = start; site < stop; site++) {
234                 void *site_addr = static_call_addr(site);
235
236                 if ((mod && within_module_init((unsigned long)site_addr, mod)) ||
237                     (!mod && init_section_contains(site_addr, 1)))
238                         static_call_set_init(site);
239
240                 key = static_call_key(site);
241                 if (key != prev_key) {
242                         prev_key = key;
243
244                         /*
245                          * For vmlinux (!mod) avoid the allocation by storing
246                          * the sites pointer in the key itself. Also see
247                          * __static_call_update()'s @first.
248                          *
249                          * This allows architectures (eg. x86) to call
250                          * static_call_init() before memory allocation works.
251                          */
252                         if (!mod) {
253                                 key->sites = site;
254                                 key->type |= 1;
255                                 goto do_transform;
256                         }
257
258                         site_mod = kzalloc(sizeof(*site_mod), GFP_KERNEL);
259                         if (!site_mod)
260                                 return -ENOMEM;
261
262                         /*
263                          * When the key has a direct sites pointer, extract
264                          * that into an explicit struct static_call_mod, so we
265                          * can have a list of modules.
266                          */
267                         if (static_call_key_sites(key)) {
268                                 site_mod->mod = NULL;
269                                 site_mod->next = NULL;
270                                 site_mod->sites = static_call_key_sites(key);
271
272                                 key->mods = site_mod;
273
274                                 site_mod = kzalloc(sizeof(*site_mod), GFP_KERNEL);
275                                 if (!site_mod)
276                                         return -ENOMEM;
277                         }
278
279                         site_mod->mod = mod;
280                         site_mod->sites = site;
281                         site_mod->next = static_call_key_next(key);
282                         key->mods = site_mod;
283                 }
284
285 do_transform:
286                 arch_static_call_transform(site_addr, NULL, key->func,
287                                 static_call_is_tail(site));
288         }
289
290         return 0;
291 }
292
293 static int addr_conflict(struct static_call_site *site, void *start, void *end)
294 {
295         unsigned long addr = (unsigned long)static_call_addr(site);
296
297         if (addr <= (unsigned long)end &&
298             addr + CALL_INSN_SIZE > (unsigned long)start)
299                 return 1;
300
301         return 0;
302 }
303
304 static int __static_call_text_reserved(struct static_call_site *iter_start,
305                                        struct static_call_site *iter_stop,
306                                        void *start, void *end, bool init)
307 {
308         struct static_call_site *iter = iter_start;
309
310         while (iter < iter_stop) {
311                 if (init || !static_call_is_init(iter)) {
312                         if (addr_conflict(iter, start, end))
313                                 return 1;
314                 }
315                 iter++;
316         }
317
318         return 0;
319 }
320
321 #ifdef CONFIG_MODULES
322
323 static int __static_call_mod_text_reserved(void *start, void *end)
324 {
325         struct module *mod;
326         int ret;
327
328         preempt_disable();
329         mod = __module_text_address((unsigned long)start);
330         WARN_ON_ONCE(__module_text_address((unsigned long)end) != mod);
331         if (!try_module_get(mod))
332                 mod = NULL;
333         preempt_enable();
334
335         if (!mod)
336                 return 0;
337
338         ret = __static_call_text_reserved(mod->static_call_sites,
339                         mod->static_call_sites + mod->num_static_call_sites,
340                         start, end, mod->state == MODULE_STATE_COMING);
341
342         module_put(mod);
343
344         return ret;
345 }
346
347 static unsigned long tramp_key_lookup(unsigned long addr)
348 {
349         struct static_call_tramp_key *start = __start_static_call_tramp_key;
350         struct static_call_tramp_key *stop = __stop_static_call_tramp_key;
351         struct static_call_tramp_key *tramp_key;
352
353         for (tramp_key = start; tramp_key != stop; tramp_key++) {
354                 unsigned long tramp;
355
356                 tramp = (long)tramp_key->tramp + (long)&tramp_key->tramp;
357                 if (tramp == addr)
358                         return (long)tramp_key->key + (long)&tramp_key->key;
359         }
360
361         return 0;
362 }
363
364 static int static_call_add_module(struct module *mod)
365 {
366         struct static_call_site *start = mod->static_call_sites;
367         struct static_call_site *stop = start + mod->num_static_call_sites;
368         struct static_call_site *site;
369
370         for (site = start; site != stop; site++) {
371                 unsigned long s_key = __static_call_key(site);
372                 unsigned long addr = s_key & ~STATIC_CALL_SITE_FLAGS;
373                 unsigned long key;
374
375                 /*
376                  * Is the key is exported, 'addr' points to the key, which
377                  * means modules are allowed to call static_call_update() on
378                  * it.
379                  *
380                  * Otherwise, the key isn't exported, and 'addr' points to the
381                  * trampoline so we need to lookup the key.
382                  *
383                  * We go through this dance to prevent crazy modules from
384                  * abusing sensitive static calls.
385                  */
386                 if (!kernel_text_address(addr))
387                         continue;
388
389                 key = tramp_key_lookup(addr);
390                 if (!key) {
391                         pr_warn("Failed to fixup __raw_static_call() usage at: %ps\n",
392                                 static_call_addr(site));
393                         return -EINVAL;
394                 }
395
396                 key |= s_key & STATIC_CALL_SITE_FLAGS;
397                 site->key = key - (long)&site->key;
398         }
399
400         return __static_call_init(mod, start, stop);
401 }
402
403 static void static_call_del_module(struct module *mod)
404 {
405         struct static_call_site *start = mod->static_call_sites;
406         struct static_call_site *stop = mod->static_call_sites +
407                                         mod->num_static_call_sites;
408         struct static_call_key *key, *prev_key = NULL;
409         struct static_call_mod *site_mod, **prev;
410         struct static_call_site *site;
411
412         for (site = start; site < stop; site++) {
413                 key = static_call_key(site);
414                 if (key == prev_key)
415                         continue;
416
417                 prev_key = key;
418
419                 for (prev = &key->mods, site_mod = key->mods;
420                      site_mod && site_mod->mod != mod;
421                      prev = &site_mod->next, site_mod = site_mod->next)
422                         ;
423
424                 if (!site_mod)
425                         continue;
426
427                 *prev = site_mod->next;
428                 kfree(site_mod);
429         }
430 }
431
432 static int static_call_module_notify(struct notifier_block *nb,
433                                      unsigned long val, void *data)
434 {
435         struct module *mod = data;
436         int ret = 0;
437
438         cpus_read_lock();
439         static_call_lock();
440
441         switch (val) {
442         case MODULE_STATE_COMING:
443                 ret = static_call_add_module(mod);
444                 if (ret) {
445                         WARN(1, "Failed to allocate memory for static calls");
446                         static_call_del_module(mod);
447                 }
448                 break;
449         case MODULE_STATE_GOING:
450                 static_call_del_module(mod);
451                 break;
452         }
453
454         static_call_unlock();
455         cpus_read_unlock();
456
457         return notifier_from_errno(ret);
458 }
459
460 static struct notifier_block static_call_module_nb = {
461         .notifier_call = static_call_module_notify,
462 };
463
464 #else
465
static inline int __static_call_mod_text_reserved(void *start, void *end)
{
	/* Without CONFIG_MODULES there is no module text to check. */
	return 0;
}
470
471 #endif /* CONFIG_MODULES */
472
473 int static_call_text_reserved(void *start, void *end)
474 {
475         bool init = system_state < SYSTEM_RUNNING;
476         int ret = __static_call_text_reserved(__start_static_call_sites,
477                         __stop_static_call_sites, start, end, init);
478
479         if (ret)
480                 return ret;
481
482         return __static_call_mod_text_reserved(start, end);
483 }
484
485 int __init static_call_init(void)
486 {
487         int ret;
488
489         /* See static_call_force_reinit(). */
490         if (static_call_initialized == 1)
491                 return 0;
492
493         cpus_read_lock();
494         static_call_lock();
495         ret = __static_call_init(NULL, __start_static_call_sites,
496                                  __stop_static_call_sites);
497         static_call_unlock();
498         cpus_read_unlock();
499
500         if (ret) {
501                 pr_err("Failed to allocate memory for static_call!\n");
502                 BUG();
503         }
504
505 #ifdef CONFIG_MODULES
506         if (!static_call_initialized)
507                 register_module_notifier(&static_call_module_nb);
508 #endif
509
510         static_call_initialized = 1;
511         return 0;
512 }
513 early_initcall(static_call_init);
514
515 #ifdef CONFIG_STATIC_CALL_SELFTEST
516
/* Selftest targets: func_a adds one, func_b adds two. */
static int func_a(int x)
{
	return 1 + x;
}

static int func_b(int x)
{
	return 2 + x;
}
526
527 DEFINE_STATIC_CALL(sc_selftest, func_a);
528
529 static struct static_call_data {
530       int (*func)(int);
531       int val;
532       int expect;
533 } static_call_data [] __initdata = {
534       { NULL,   2, 3 },
535       { func_b, 2, 4 },
536       { func_a, 2, 3 }
537 };
538
539 static int __init test_static_call_init(void)
540 {
541       int i;
542
543       for (i = 0; i < ARRAY_SIZE(static_call_data); i++ ) {
544               struct static_call_data *scd = &static_call_data[i];
545
546               if (scd->func)
547                       static_call_update(sc_selftest, scd->func);
548
549               WARN_ON(static_call(sc_selftest)(scd->val) != scd->expect);
550       }
551
552       return 0;
553 }
554 early_initcall(test_static_call_init);
555
556 #endif /* CONFIG_STATIC_CALL_SELFTEST */