tracing/boot: Fix a hist trigger dependency for boot time tracing
[platform/kernel/linux-rpi.git] / kernel / static_call.c
index 84565c2..dc5665b 100644 (file)
@@ -12,6 +12,8 @@
 
 extern struct static_call_site __start_static_call_sites[],
                               __stop_static_call_sites[];
+extern struct static_call_tramp_key __start_static_call_tramp_key[],
+                                   __stop_static_call_tramp_key[];
 
 static bool static_call_initialized;
 
@@ -33,27 +35,30 @@ static inline void *static_call_addr(struct static_call_site *site)
        return (void *)((long)site->addr + (long)&site->addr);
 }
 
+static inline unsigned long __static_call_key(const struct static_call_site *site)
+{
+       return (long)site->key + (long)&site->key;
+}
 
 static inline struct static_call_key *static_call_key(const struct static_call_site *site)
 {
-       return (struct static_call_key *)
-               (((long)site->key + (long)&site->key) & ~STATIC_CALL_SITE_FLAGS);
+       return (void *)(__static_call_key(site) & ~STATIC_CALL_SITE_FLAGS);
 }
 
 /* These assume the key is word-aligned. */
 static inline bool static_call_is_init(struct static_call_site *site)
 {
-       return ((long)site->key + (long)&site->key) & STATIC_CALL_SITE_INIT;
+       return __static_call_key(site) & STATIC_CALL_SITE_INIT;
 }
 
 static inline bool static_call_is_tail(struct static_call_site *site)
 {
-       return ((long)site->key + (long)&site->key) & STATIC_CALL_SITE_TAIL;
+       return __static_call_key(site) & STATIC_CALL_SITE_TAIL;
 }
 
 static inline void static_call_set_init(struct static_call_site *site)
 {
-       site->key = ((long)static_call_key(site) | STATIC_CALL_SITE_INIT) -
+       site->key = (__static_call_key(site) | STATIC_CALL_SITE_INIT) -
                    (long)&site->key;
 }
 
@@ -144,6 +149,7 @@ void __static_call_update(struct static_call_key *key, void *tramp, void *func)
        };
 
        for (site_mod = &first; site_mod; site_mod = site_mod->next) {
+               bool init = system_state < SYSTEM_RUNNING;
                struct module *mod = site_mod->mod;
 
                if (!site_mod->sites) {
@@ -159,36 +165,38 @@ void __static_call_update(struct static_call_key *key, void *tramp, void *func)
 
                stop = __stop_static_call_sites;
 
-#ifdef CONFIG_MODULES
                if (mod) {
+#ifdef CONFIG_MODULES
                        stop = mod->static_call_sites +
                               mod->num_static_call_sites;
-               }
+                       init = mod->state == MODULE_STATE_COMING;
 #endif
+               }
 
                for (site = site_mod->sites;
                     site < stop && static_call_key(site) == key; site++) {
                        void *site_addr = static_call_addr(site);
 
-                       if (static_call_is_init(site)) {
-                               /*
-                                * Don't write to call sites which were in
-                                * initmem and have since been freed.
-                                */
-                               if (!mod && system_state >= SYSTEM_RUNNING)
-                                       continue;
-                               if (mod && !within_module_init((unsigned long)site_addr, mod))
-                                       continue;
-                       }
+                       if (!init && static_call_is_init(site))
+                               continue;
 
                        if (!kernel_text_address((unsigned long)site_addr)) {
-                               WARN_ONCE(1, "can't patch static call site at %pS",
+                               /*
+                                * This skips patching built-in __exit, which
+                                * is part of init_section_contains() but is
+                                * not part of kernel_text_address().
+                                *
+                                * Skipping built-in __exit is fine since it
+                                * will never be executed.
+                                */
+                               WARN_ONCE(!static_call_is_init(site),
+                                         "can't patch static call site at %pS",
                                          site_addr);
                                continue;
                        }
 
                        arch_static_call_transform(site_addr, NULL, func,
-                               static_call_is_tail(site));
+                                                  static_call_is_tail(site));
                }
        }
 
@@ -284,13 +292,15 @@ static int addr_conflict(struct static_call_site *site, void *start, void *end)
 
 static int __static_call_text_reserved(struct static_call_site *iter_start,
                                       struct static_call_site *iter_stop,
-                                      void *start, void *end)
+                                      void *start, void *end, bool init)
 {
        struct static_call_site *iter = iter_start;
 
        while (iter < iter_stop) {
-               if (addr_conflict(iter, start, end))
-                       return 1;
+               if (init || !static_call_is_init(iter)) {
+                       if (addr_conflict(iter, start, end))
+                               return 1;
+               }
                iter++;
        }
 
@@ -316,17 +326,67 @@ static int __static_call_mod_text_reserved(void *start, void *end)
 
        ret = __static_call_text_reserved(mod->static_call_sites,
                        mod->static_call_sites + mod->num_static_call_sites,
-                       start, end);
+                       start, end, mod->state == MODULE_STATE_COMING);
 
        module_put(mod);
 
        return ret;
 }
 
+static unsigned long tramp_key_lookup(unsigned long addr)
+{
+       struct static_call_tramp_key *start = __start_static_call_tramp_key;
+       struct static_call_tramp_key *stop = __stop_static_call_tramp_key;
+       struct static_call_tramp_key *tramp_key;
+
+       for (tramp_key = start; tramp_key != stop; tramp_key++) {
+               unsigned long tramp;
+
+               tramp = (long)tramp_key->tramp + (long)&tramp_key->tramp;
+               if (tramp == addr)
+                       return (long)tramp_key->key + (long)&tramp_key->key;
+       }
+
+       return 0;
+}
+
 static int static_call_add_module(struct module *mod)
 {
-       return __static_call_init(mod, mod->static_call_sites,
-                                 mod->static_call_sites + mod->num_static_call_sites);
+       struct static_call_site *start = mod->static_call_sites;
+       struct static_call_site *stop = start + mod->num_static_call_sites;
+       struct static_call_site *site;
+
+       for (site = start; site != stop; site++) {
+               unsigned long s_key = __static_call_key(site);
+               unsigned long addr = s_key & ~STATIC_CALL_SITE_FLAGS;
+               unsigned long key;
+
+               /*
+                * Is the key is exported, 'addr' points to the key, which
+                * means modules are allowed to call static_call_update() on
+                * it.
+                *
+                * Otherwise, the key isn't exported, and 'addr' points to the
+                * trampoline so we need to lookup the key.
+                *
+                * We go through this dance to prevent crazy modules from
+                * abusing sensitive static calls.
+                */
+               if (!kernel_text_address(addr))
+                       continue;
+
+               key = tramp_key_lookup(addr);
+               if (!key) {
+                       pr_warn("Failed to fixup __raw_static_call() usage at: %ps\n",
+                               static_call_addr(site));
+                       return -EINVAL;
+               }
+
+               key |= s_key & STATIC_CALL_SITE_FLAGS;
+               site->key = key - (long)&site->key;
+       }
+
+       return __static_call_init(mod, start, stop);
 }
 
 static void static_call_del_module(struct module *mod)
@@ -401,8 +461,9 @@ static inline int __static_call_mod_text_reserved(void *start, void *end)
 
 int static_call_text_reserved(void *start, void *end)
 {
+       bool init = system_state < SYSTEM_RUNNING;
        int ret = __static_call_text_reserved(__start_static_call_sites,
-                       __stop_static_call_sites, start, end);
+                       __stop_static_call_sites, start, end, init);
 
        if (ret)
                return ret;