net/core/bpf_sk_storage.c
1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2019 Facebook  */
3 #include <linux/rculist.h>
4 #include <linux/list.h>
5 #include <linux/hash.h>
6 #include <linux/types.h>
7 #include <linux/spinlock.h>
8 #include <linux/bpf.h>
9 #include <net/bpf_sk_storage.h>
10 #include <net/sock.h>
11 #include <uapi/linux/sock_diag.h>
12 #include <uapi/linux/btf.h>
13
14 #define SK_STORAGE_CREATE_FLAG_MASK                                     \
15         (BPF_F_NO_PREALLOC | BPF_F_CLONE)
16
17 struct bucket {
18         struct hlist_head list;
19         raw_spinlock_t lock;
20 };
21
22 /* The map is not the primary owner of a bpf_sk_storage_elem.
23  * Instead, the sk->sk_bpf_storage is.
24  *
25  * The map (bpf_sk_storage_map) serves two purposes:
26  * 1. Define the size of the "sk local storage".  It is
27  *    the map's value_size.
28  *
29  * 2. Maintain a list to keep track of all elems such
30  *    that they can be cleaned up during the map destruction.
31  *
32  * When a bpf local storage is being looked up for a
33  * particular sk,  the "bpf_map" pointer is actually used
34  * as the "key" to search in the list of elem in
35  * sk->sk_bpf_storage.
36  *
37  * Hence, consider sk->sk_bpf_storage as the mini-map
38  * with the "bpf_map" pointer as the searching key.
39  */
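/* Illustrative sketch, not part of this file: how a BPF program would
 * typically declare and use a sk storage map with libbpf conventions.
 * The map name, value type and program section below are made-up
 * examples (assumes <linux/bpf.h> and <bpf/bpf_helpers.h>):
 *
 *	struct my_val {
 *		__u64 pkt_cnt;
 *	};
 *
 *	struct {
 *		__uint(type, BPF_MAP_TYPE_SK_STORAGE);
 *		__uint(map_flags, BPF_F_NO_PREALLOC);
 *		__type(key, int);
 *		__type(value, struct my_val);
 *	} sk_stg_map SEC(".maps");
 *
 *	SEC("cgroup_skb/egress")
 *	int count_egress(struct __sk_buff *skb)
 *	{
 *		struct bpf_sock *sk = skb->sk;
 *		struct my_val *val;
 *
 *		if (!sk)
 *			return 1;
 *		sk = bpf_sk_fullsock(sk);
 *		if (!sk)
 *			return 1;
 *		val = bpf_sk_storage_get(&sk_stg_map, sk, NULL,
 *					 BPF_SK_STORAGE_GET_F_CREATE);
 *		if (val)
 *			__sync_fetch_and_add(&val->pkt_cnt, 1);
 *		return 1;
 *	}
 */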
40 struct bpf_sk_storage_map {
41         struct bpf_map map;
42         /* Lookup elem does not require accessing the map.
43          *
44          * Updating/Deleting requires a bucket lock to
45          * link/unlink the elem from the map.  There are
46          * multiple buckets to reduce contention.
47          */
48         struct bucket *buckets;
49         u32 bucket_log;
50         u16 elem_size;
51         u16 cache_idx;
52 };
53
54 struct bpf_sk_storage_data {
55         /* smap is used as the searching key when looking up
56          * from sk->sk_bpf_storage.
57          *
58          * Put it in the same cacheline as the data to minimize
59          * the number of cacheline accesses during the cache-hit case.
60          */
61         struct bpf_sk_storage_map __rcu *smap;
62         u8 data[] __aligned(8);
63 };
64
65 /* Linked to bpf_sk_storage and bpf_sk_storage_map */
66 struct bpf_sk_storage_elem {
67         struct hlist_node map_node;     /* Linked to bpf_sk_storage_map */
68         struct hlist_node snode;        /* Linked to bpf_sk_storage */
69         struct bpf_sk_storage __rcu *sk_storage;
70         struct rcu_head rcu;
71         /* 8 bytes hole */
72          * The data is stored in another cacheline to minimize
73          * the number of cacheline accesses during a cache hit.
74          */
75         struct bpf_sk_storage_data sdata ____cacheline_aligned;
76 };
77
78 #define SELEM(_SDATA) container_of((_SDATA), struct bpf_sk_storage_elem, sdata)
79 #define SDATA(_SELEM) (&(_SELEM)->sdata)
80 #define BPF_SK_STORAGE_CACHE_SIZE       16
81
82 static DEFINE_SPINLOCK(cache_idx_lock);
83 static u64 cache_idx_usage_counts[BPF_SK_STORAGE_CACHE_SIZE];
84
85 struct bpf_sk_storage {
86         struct bpf_sk_storage_data __rcu *cache[BPF_SK_STORAGE_CACHE_SIZE];
87         struct hlist_head list; /* List of bpf_sk_storage_elem */
88         struct sock *sk;        /* The sk that owns the above "list" of
89                                  * bpf_sk_storage_elem.
90                                  */
91         struct rcu_head rcu;
92         raw_spinlock_t lock;    /* Protect adding/removing from the "list" */
93 };
94
95 static struct bucket *select_bucket(struct bpf_sk_storage_map *smap,
96                                     struct bpf_sk_storage_elem *selem)
97 {
98         return &smap->buckets[hash_ptr(selem, smap->bucket_log)];
99 }
100
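/* Descriptive note: omem_charge() below accounts new storage against
 * sk->sk_omem_alloc, bounded by the net.core.optmem_max sysctl, mirroring
 * the check done in sock_kmalloc().
 */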
101 static int omem_charge(struct sock *sk, unsigned int size)
102 {
103         /* same check as in sock_kmalloc() */
104         if (size <= sysctl_optmem_max &&
105             atomic_read(&sk->sk_omem_alloc) + size < sysctl_optmem_max) {
106                 atomic_add(size, &sk->sk_omem_alloc);
107                 return 0;
108         }
109
110         return -ENOMEM;
111 }
112
113 static bool selem_linked_to_sk(const struct bpf_sk_storage_elem *selem)
114 {
115         return !hlist_unhashed(&selem->snode);
116 }
117
118 static bool selem_linked_to_map(const struct bpf_sk_storage_elem *selem)
119 {
120         return !hlist_unhashed(&selem->map_node);
121 }
122
123 static struct bpf_sk_storage_elem *selem_alloc(struct bpf_sk_storage_map *smap,
124                                                struct sock *sk, void *value,
125                                                bool charge_omem)
126 {
127         struct bpf_sk_storage_elem *selem;
128
129         if (charge_omem && omem_charge(sk, smap->elem_size))
130                 return NULL;
131
132         selem = kzalloc(smap->elem_size, GFP_ATOMIC | __GFP_NOWARN);
133         if (selem) {
134                 if (value)
135                         memcpy(SDATA(selem)->data, value, smap->map.value_size);
136                 return selem;
137         }
138
139         if (charge_omem)
140                 atomic_sub(smap->elem_size, &sk->sk_omem_alloc);
141
142         return NULL;
143 }
144
145 /* sk_storage->lock must be held and selem->sk_storage == sk_storage.
146  * The caller must ensure selem->smap is still valid to be
147  * dereferenced for its smap->elem_size and smap->cache_idx.
148  */
149 static bool __selem_unlink_sk(struct bpf_sk_storage *sk_storage,
150                               struct bpf_sk_storage_elem *selem,
151                               bool uncharge_omem)
152 {
153         struct bpf_sk_storage_map *smap;
154         bool free_sk_storage;
155         struct sock *sk;
156
157         smap = rcu_dereference(SDATA(selem)->smap);
158         sk = sk_storage->sk;
159
160         /* All uncharging on sk->sk_omem_alloc must be done first.
161          * sk may be freed once the last selem is unlinked from sk_storage.
162          */
163         if (uncharge_omem)
164                 atomic_sub(smap->elem_size, &sk->sk_omem_alloc);
165
166         free_sk_storage = hlist_is_singular_node(&selem->snode,
167                                                  &sk_storage->list);
168         if (free_sk_storage) {
169                 atomic_sub(sizeof(struct bpf_sk_storage), &sk->sk_omem_alloc);
170                 sk_storage->sk = NULL;
171                 /* After this RCU_INIT, sk may be freed and cannot be used */
172                 RCU_INIT_POINTER(sk->sk_bpf_storage, NULL);
173
174                 /* sk_storage is not freed now.  sk_storage->lock is
175                  * still held and raw_spin_unlock_bh(&sk_storage->lock)
176                  * will be done by the caller.
177                  *
178                  * Although the unlock will be done under
179          * rcu_read_lock(),  it is more intuitive to
180                  * read if kfree_rcu(sk_storage, rcu) is done
181                  * after the raw_spin_unlock_bh(&sk_storage->lock).
182                  *
183                  * Hence, a "bool free_sk_storage" is returned
184                  * to the caller which then calls the kfree_rcu()
185                  * after unlock.
186                  */
187         }
188         hlist_del_init_rcu(&selem->snode);
189         if (rcu_access_pointer(sk_storage->cache[smap->cache_idx]) ==
190             SDATA(selem))
191                 RCU_INIT_POINTER(sk_storage->cache[smap->cache_idx], NULL);
192
193         kfree_rcu(selem, rcu);
194
195         return free_sk_storage;
196 }
197
198 static void selem_unlink_sk(struct bpf_sk_storage_elem *selem)
199 {
200         struct bpf_sk_storage *sk_storage;
201         bool free_sk_storage = false;
202
203         if (unlikely(!selem_linked_to_sk(selem)))
204                 /* selem has already been unlinked from sk */
205                 return;
206
207         sk_storage = rcu_dereference(selem->sk_storage);
208         raw_spin_lock_bh(&sk_storage->lock);
209         if (likely(selem_linked_to_sk(selem)))
210                 free_sk_storage = __selem_unlink_sk(sk_storage, selem, true);
211         raw_spin_unlock_bh(&sk_storage->lock);
212
213         if (free_sk_storage)
214                 kfree_rcu(sk_storage, rcu);
215 }
216
217 static void __selem_link_sk(struct bpf_sk_storage *sk_storage,
218                             struct bpf_sk_storage_elem *selem)
219 {
220         RCU_INIT_POINTER(selem->sk_storage, sk_storage);
221         hlist_add_head(&selem->snode, &sk_storage->list);
222 }
223
224 static void selem_unlink_map(struct bpf_sk_storage_elem *selem)
225 {
226         struct bpf_sk_storage_map *smap;
227         struct bucket *b;
228
229         if (unlikely(!selem_linked_to_map(selem)))
230                 /* selem has already been unlinked from smap */
231                 return;
232
233         smap = rcu_dereference(SDATA(selem)->smap);
234         b = select_bucket(smap, selem);
235         raw_spin_lock_bh(&b->lock);
236         if (likely(selem_linked_to_map(selem)))
237                 hlist_del_init_rcu(&selem->map_node);
238         raw_spin_unlock_bh(&b->lock);
239 }
240
241 static void selem_link_map(struct bpf_sk_storage_map *smap,
242                            struct bpf_sk_storage_elem *selem)
243 {
244         struct bucket *b = select_bucket(smap, selem);
245
246         raw_spin_lock_bh(&b->lock);
247         RCU_INIT_POINTER(SDATA(selem)->smap, smap);
248         hlist_add_head_rcu(&selem->map_node, &b->list);
249         raw_spin_unlock_bh(&b->lock);
250 }
251
252 static void selem_unlink(struct bpf_sk_storage_elem *selem)
253 {
254         /* Always unlink from map before unlinking from sk_storage
255          * because selem will be freed once it is successfully unlinked from
256          * the sk_storage.
257          */
258         selem_unlink_map(selem);
259         selem_unlink_sk(selem);
260 }
261
262 static struct bpf_sk_storage_data *
263 __sk_storage_lookup(struct bpf_sk_storage *sk_storage,
264                     struct bpf_sk_storage_map *smap,
265                     bool cacheit_lockit)
266 {
267         struct bpf_sk_storage_data *sdata;
268         struct bpf_sk_storage_elem *selem;
269
270         /* Fast path (cache hit) */
271         sdata = rcu_dereference(sk_storage->cache[smap->cache_idx]);
272         if (sdata && rcu_access_pointer(sdata->smap) == smap)
273                 return sdata;
274
275         /* Slow path (cache miss) */
276         hlist_for_each_entry_rcu(selem, &sk_storage->list, snode)
277                 if (rcu_access_pointer(SDATA(selem)->smap) == smap)
278                         break;
279
280         if (!selem)
281                 return NULL;
282
283         sdata = SDATA(selem);
284         if (cacheit_lockit) {
285                 /* spinlock is needed to avoid racing with a
286                  * parallel delete.  Otherwise, publishing an already
287                  * deleted sdata to the cache will become a use-after-free
288                  * problem in the next __sk_storage_lookup().
289                  */
290                 raw_spin_lock_bh(&sk_storage->lock);
291                 if (selem_linked_to_sk(selem))
292                         rcu_assign_pointer(sk_storage->cache[smap->cache_idx],
293                                            sdata);
294                 raw_spin_unlock_bh(&sk_storage->lock);
295         }
296
297         return sdata;
298 }
299
300 static struct bpf_sk_storage_data *
301 sk_storage_lookup(struct sock *sk, struct bpf_map *map, bool cacheit_lockit)
302 {
303         struct bpf_sk_storage *sk_storage;
304         struct bpf_sk_storage_map *smap;
305
306         sk_storage = rcu_dereference(sk->sk_bpf_storage);
307         if (!sk_storage)
308                 return NULL;
309
310         smap = (struct bpf_sk_storage_map *)map;
311         return __sk_storage_lookup(sk_storage, smap, cacheit_lockit);
312 }
313
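/* Descriptive note: the update flags handled by check_flags() below follow
 * the generic bpf map update semantics: BPF_ANY (0) creates or updates,
 * BPF_NOEXIST (1) creates only, BPF_EXIST (2) updates only, and BPF_F_LOCK
 * may be OR'ed in when the map value contains a struct bpf_spin_lock.
 */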
314 static int check_flags(const struct bpf_sk_storage_data *old_sdata,
315                        u64 map_flags)
316 {
317         if (old_sdata && (map_flags & ~BPF_F_LOCK) == BPF_NOEXIST)
318                 /* elem already exists */
319                 return -EEXIST;
320
321         if (!old_sdata && (map_flags & ~BPF_F_LOCK) == BPF_EXIST)
322                 /* elem doesn't exist, cannot update it */
323                 return -ENOENT;
324
325         return 0;
326 }
327
328 static int sk_storage_alloc(struct sock *sk,
329                             struct bpf_sk_storage_map *smap,
330                             struct bpf_sk_storage_elem *first_selem)
331 {
332         struct bpf_sk_storage *prev_sk_storage, *sk_storage;
333         int err;
334
335         err = omem_charge(sk, sizeof(*sk_storage));
336         if (err)
337                 return err;
338
339         sk_storage = kzalloc(sizeof(*sk_storage), GFP_ATOMIC | __GFP_NOWARN);
340         if (!sk_storage) {
341                 err = -ENOMEM;
342                 goto uncharge;
343         }
344         INIT_HLIST_HEAD(&sk_storage->list);
345         raw_spin_lock_init(&sk_storage->lock);
346         sk_storage->sk = sk;
347
348         __selem_link_sk(sk_storage, first_selem);
349         selem_link_map(smap, first_selem);
350         /* Publish sk_storage to sk.  sk->sk_lock cannot be acquired.
351          * Hence, an atomic cmpxchg() is used to set sk->sk_bpf_storage
352          * from NULL to the newly allocated sk_storage ptr.
353          *
354          * From now on, the sk->sk_bpf_storage pointer is protected
355          * by the sk_storage->lock.  Hence,  when freeing
356          * the sk->sk_bpf_storage, the sk_storage->lock must
357          * be held before setting sk->sk_bpf_storage to NULL.
358          */
359         prev_sk_storage = cmpxchg((struct bpf_sk_storage **)&sk->sk_bpf_storage,
360                                   NULL, sk_storage);
361         if (unlikely(prev_sk_storage)) {
362                 selem_unlink_map(first_selem);
363                 err = -EAGAIN;
364                 goto uncharge;
365
366                 /* Note that even though first_selem was linked to smap's
367                  * bucket->list, first_selem can be freed immediately
368                  * (instead of kfree_rcu) because
369                  * bpf_sk_storage_map_free() does a
370                  * synchronize_rcu() before walking the bucket->list.
371                  * Hence, no one is accessing selem from the
372                  * bucket->list under rcu_read_lock().
373                  */
374         }
375
376         return 0;
377
378 uncharge:
379         kfree(sk_storage);
380         atomic_sub(sizeof(*sk_storage), &sk->sk_omem_alloc);
381         return err;
382 }
383
384 /* sk cannot be going away because the caller is linking a new elem
385  * to sk->sk_bpf_storage. (i.e. sk->sk_refcnt cannot be 0).
386  * Otherwise, it will become a leak (and other memory issues
387  * during map destruction).
388  */
389 static struct bpf_sk_storage_data *sk_storage_update(struct sock *sk,
390                                                      struct bpf_map *map,
391                                                      void *value,
392                                                      u64 map_flags)
393 {
394         struct bpf_sk_storage_data *old_sdata = NULL;
395         struct bpf_sk_storage_elem *selem;
396         struct bpf_sk_storage *sk_storage;
397         struct bpf_sk_storage_map *smap;
398         int err;
399
400         /* BPF_EXIST and BPF_NOEXIST cannot be both set */
401         if (unlikely((map_flags & ~BPF_F_LOCK) > BPF_EXIST) ||
402             /* BPF_F_LOCK can only be used in a value with spin_lock */
403             unlikely((map_flags & BPF_F_LOCK) && !map_value_has_spin_lock(map)))
404                 return ERR_PTR(-EINVAL);
405
406         smap = (struct bpf_sk_storage_map *)map;
407         sk_storage = rcu_dereference(sk->sk_bpf_storage);
408         if (!sk_storage || hlist_empty(&sk_storage->list)) {
409                 /* Very first elem for this sk */
410                 err = check_flags(NULL, map_flags);
411                 if (err)
412                         return ERR_PTR(err);
413
414                 selem = selem_alloc(smap, sk, value, true);
415                 if (!selem)
416                         return ERR_PTR(-ENOMEM);
417
418                 err = sk_storage_alloc(sk, smap, selem);
419                 if (err) {
420                         kfree(selem);
421                         atomic_sub(smap->elem_size, &sk->sk_omem_alloc);
422                         return ERR_PTR(err);
423                 }
424
425                 return SDATA(selem);
426         }
427
428         if ((map_flags & BPF_F_LOCK) && !(map_flags & BPF_NOEXIST)) {
429                 /* Hoping to find an old_sdata to do an inline update
430                  * such that it can avoid taking the sk_storage->lock
431                  * and changing the lists.
432                  */
433                 old_sdata = __sk_storage_lookup(sk_storage, smap, false);
434                 err = check_flags(old_sdata, map_flags);
435                 if (err)
436                         return ERR_PTR(err);
437                 if (old_sdata && selem_linked_to_sk(SELEM(old_sdata))) {
438                         copy_map_value_locked(map, old_sdata->data,
439                                               value, false);
440                         return old_sdata;
441                 }
442         }
443
444         raw_spin_lock_bh(&sk_storage->lock);
445
446         /* Recheck sk_storage->list under sk_storage->lock */
447         if (unlikely(hlist_empty(&sk_storage->list))) {
448                 /* A parallel del is happening and sk_storage is going
449                  * away.  It has just been checked, so this is very
450                  * unlikely.  Return instead of retrying to keep things
451                  * simple.
452                  */
453                 err = -EAGAIN;
454                 goto unlock_err;
455         }
456
457         old_sdata = __sk_storage_lookup(sk_storage, smap, false);
458         err = check_flags(old_sdata, map_flags);
459         if (err)
460                 goto unlock_err;
461
462         if (old_sdata && (map_flags & BPF_F_LOCK)) {
463                 copy_map_value_locked(map, old_sdata->data, value, false);
464                 selem = SELEM(old_sdata);
465                 goto unlock;
466         }
467
468         /* sk_storage->lock is held.  Hence, we are sure
469          * the old_sdata can be unlinked and uncharged successfully
470          * later.  So, instead of charging the new selem now and then
471          * uncharging the old selem later (which could cause an
472          * unnecessary charge failure), no charge is taken here at all
473          * when old_sdata exists (the "!old_sdata" check), and
474          * old_sdata will then not be uncharged in __selem_unlink_sk().
475          */
476         selem = selem_alloc(smap, sk, value, !old_sdata);
477         if (!selem) {
478                 err = -ENOMEM;
479                 goto unlock_err;
480         }
481
482         /* First, link the new selem to the map */
483         selem_link_map(smap, selem);
484
485         /* Second, link (and publish) the new selem to sk_storage */
486         __selem_link_sk(sk_storage, selem);
487
488         /* Third, remove old selem, SELEM(old_sdata) */
489         if (old_sdata) {
490                 selem_unlink_map(SELEM(old_sdata));
491                 __selem_unlink_sk(sk_storage, SELEM(old_sdata), false);
492         }
493
494 unlock:
495         raw_spin_unlock_bh(&sk_storage->lock);
496         return SDATA(selem);
497
498 unlock_err:
499         raw_spin_unlock_bh(&sk_storage->lock);
500         return ERR_PTR(err);
501 }
502
503 static int sk_storage_delete(struct sock *sk, struct bpf_map *map)
504 {
505         struct bpf_sk_storage_data *sdata;
506
507         sdata = sk_storage_lookup(sk, map, false);
508         if (!sdata)
509                 return -ENOENT;
510
511         selem_unlink(SELEM(sdata));
512
513         return 0;
514 }
515
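/* Descriptive note: each sk storage map holds one of the
 * BPF_SK_STORAGE_CACHE_SIZE cache slots for its lifetime.  cache_idx_get()
 * below picks the least-used slot so that co-existing maps are spread
 * across slots and are less likely to evict each other's cached sdata in
 * sk_storage->cache[].
 */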
516 static u16 cache_idx_get(void)
517 {
518         u64 min_usage = U64_MAX;
519         u16 i, res = 0;
520
521         spin_lock(&cache_idx_lock);
522
523         for (i = 0; i < BPF_SK_STORAGE_CACHE_SIZE; i++) {
524                 if (cache_idx_usage_counts[i] < min_usage) {
525                         min_usage = cache_idx_usage_counts[i];
526                         res = i;
527
528                         /* Found a free cache_idx */
529                         if (!min_usage)
530                                 break;
531                 }
532         }
533         cache_idx_usage_counts[res]++;
534
535         spin_unlock(&cache_idx_lock);
536
537         return res;
538 }
539
540 static void cache_idx_free(u16 idx)
541 {
542         spin_lock(&cache_idx_lock);
543         cache_idx_usage_counts[idx]--;
544         spin_unlock(&cache_idx_lock);
545 }
546
547 /* Called by __sk_destruct() & bpf_sk_storage_clone() */
548 void bpf_sk_storage_free(struct sock *sk)
549 {
550         struct bpf_sk_storage_elem *selem;
551         struct bpf_sk_storage *sk_storage;
552         bool free_sk_storage = false;
553         struct hlist_node *n;
554
555         rcu_read_lock();
556         sk_storage = rcu_dereference(sk->sk_bpf_storage);
557         if (!sk_storage) {
558                 rcu_read_unlock();
559                 return;
560         }
561
562         /* Neither the bpf_prog nor the bpf-map's syscall
563          * could be modifying the sk_storage->list now.
564          * Thus, no elem can be added-to or deleted-from the
565          * sk_storage->list by the bpf_prog or by the bpf-map's syscall.
566          *
567          * It only races with bpf_sk_storage_map_free() when
568          * unlinking elems from the sk_storage->list and from
569          * the map's bucket->list.
570          */
571         raw_spin_lock_bh(&sk_storage->lock);
572         hlist_for_each_entry_safe(selem, n, &sk_storage->list, snode) {
573                 /* Always unlink from map before unlinking from
574                  * sk_storage.
575                  */
576                 selem_unlink_map(selem);
577                 free_sk_storage = __selem_unlink_sk(sk_storage, selem, true);
578         }
579         raw_spin_unlock_bh(&sk_storage->lock);
580         rcu_read_unlock();
581
582         if (free_sk_storage)
583                 kfree_rcu(sk_storage, rcu);
584 }
585
586 static void bpf_sk_storage_map_free(struct bpf_map *map)
587 {
588         struct bpf_sk_storage_elem *selem;
589         struct bpf_sk_storage_map *smap;
590         struct bucket *b;
591         unsigned int i;
592
593         smap = (struct bpf_sk_storage_map *)map;
594
595         cache_idx_free(smap->cache_idx);
596
597         /* Note that this map might be concurrently cloned from
598          * bpf_sk_storage_clone. Wait for any existing bpf_sk_storage_clone
599          * RCU read section to finish before proceeding. New RCU
600          * read sections should be prevented via bpf_map_inc_not_zero.
601          */
602         synchronize_rcu();
603
604         /* bpf prog and the userspace can no longer access this map
605          * now.  No new selem (of this map) can be added
606          * to the sk->sk_bpf_storage or to the map bucket's list.
607          *
608          * The elem of this map can be cleaned up here
609          * or
610          * by bpf_sk_storage_free() during __sk_destruct().
611          */
612         for (i = 0; i < (1U << smap->bucket_log); i++) {
613                 b = &smap->buckets[i];
614
615                 rcu_read_lock();
616                 /* No one is adding to b->list now */
617                 while ((selem = hlist_entry_safe(rcu_dereference_raw(hlist_first_rcu(&b->list)),
618                                                  struct bpf_sk_storage_elem,
619                                                  map_node))) {
620                         selem_unlink(selem);
621                         cond_resched_rcu();
622                 }
623                 rcu_read_unlock();
624         }
625
626         /* bpf_sk_storage_free() may still need to access the map.
627          * e.g. bpf_sk_storage_free() has unlinked selem from the map
628          * which then made the above while((selem = ...)) loop
629          * exit immediately.
630          *
631          * However, the bpf_sk_storage_free() still needs to access
632          * the smap->elem_size to do the uncharging in
633          * __selem_unlink_sk().
634          *
635          * Hence, wait another rcu grace period for the
636          * bpf_sk_storage_free() to finish.
637          */
638         synchronize_rcu();
639
640         kvfree(smap->buckets);
641         kfree(map);
642 }
643
644 /* U16_MAX is much more than enough for sk local storage
645  * considering a tcp_sock is ~2k.
646  */
647 #define MAX_VALUE_SIZE                                                  \
648         min_t(u32,                                                      \
649               (KMALLOC_MAX_SIZE - MAX_BPF_STACK - sizeof(struct bpf_sk_storage_elem)), \
650               (U16_MAX - sizeof(struct bpf_sk_storage_elem)))
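/* A rough worked example (assumption: 64-bit build with 64-byte
 * cachelines, where sizeof(struct bpf_sk_storage_elem) comes to 128
 * because sdata is ____cacheline_aligned): MAX_VALUE_SIZE evaluates to
 * min(KMALLOC_MAX_SIZE - 512 - 128, 65535 - 128), i.e. the U16_MAX term
 * is the limiting one (65407 bytes) on typical configurations.
 */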
651
652 static int bpf_sk_storage_map_alloc_check(union bpf_attr *attr)
653 {
654         if (attr->map_flags & ~SK_STORAGE_CREATE_FLAG_MASK ||
655             !(attr->map_flags & BPF_F_NO_PREALLOC) ||
656             attr->max_entries ||
657             attr->key_size != sizeof(int) || !attr->value_size ||
658             /* Enforce BTF for userspace sk dumping */
659             !attr->btf_key_type_id || !attr->btf_value_type_id)
660                 return -EINVAL;
661
662         if (!bpf_capable())
663                 return -EPERM;
664
665         if (attr->value_size > MAX_VALUE_SIZE)
666                 return -E2BIG;
667
668         return 0;
669 }
670
671 static struct bpf_map *bpf_sk_storage_map_alloc(union bpf_attr *attr)
672 {
673         struct bpf_sk_storage_map *smap;
674         unsigned int i;
675         u32 nbuckets;
676         u64 cost;
677         int ret;
678
679         smap = kzalloc(sizeof(*smap), GFP_USER | __GFP_NOWARN);
680         if (!smap)
681                 return ERR_PTR(-ENOMEM);
682         bpf_map_init_from_attr(&smap->map, attr);
683
684         nbuckets = roundup_pow_of_two(num_possible_cpus());
685         /* Use at least 2 buckets: select_bucket() has undefined behavior with 1 bucket */
686         nbuckets = max_t(u32, 2, nbuckets);
687         smap->bucket_log = ilog2(nbuckets);
688         cost = sizeof(*smap->buckets) * nbuckets + sizeof(*smap);
689
690         ret = bpf_map_charge_init(&smap->map.memory, cost);
691         if (ret < 0) {
692                 kfree(smap);
693                 return ERR_PTR(ret);
694         }
695
696         smap->buckets = kvcalloc(sizeof(*smap->buckets), nbuckets,
697                                  GFP_USER | __GFP_NOWARN);
698         if (!smap->buckets) {
699                 bpf_map_charge_finish(&smap->map.memory);
700                 kfree(smap);
701                 return ERR_PTR(-ENOMEM);
702         }
703
704         for (i = 0; i < nbuckets; i++) {
705                 INIT_HLIST_HEAD(&smap->buckets[i].list);
706                 raw_spin_lock_init(&smap->buckets[i].lock);
707         }
708
709         smap->elem_size = sizeof(struct bpf_sk_storage_elem) + attr->value_size;
710         smap->cache_idx = cache_idx_get();
711
712         return &smap->map;
713 }
714
715 static int notsupp_get_next_key(struct bpf_map *map, void *key,
716                                 void *next_key)
717 {
718         return -ENOTSUPP;
719 }
720
721 static int bpf_sk_storage_map_check_btf(const struct bpf_map *map,
722                                         const struct btf *btf,
723                                         const struct btf_type *key_type,
724                                         const struct btf_type *value_type)
725 {
726         u32 int_data;
727
728         if (BTF_INFO_KIND(key_type->info) != BTF_KIND_INT)
729                 return -EINVAL;
730
731         int_data = *(u32 *)(key_type + 1);
732         if (BTF_INT_BITS(int_data) != 32 || BTF_INT_OFFSET(int_data))
733                 return -EINVAL;
734
735         return 0;
736 }
737
738 static void *bpf_fd_sk_storage_lookup_elem(struct bpf_map *map, void *key)
739 {
740         struct bpf_sk_storage_data *sdata;
741         struct socket *sock;
742         int fd, err;
743
744         fd = *(int *)key;
745         sock = sockfd_lookup(fd, &err);
746         if (sock) {
747                 sdata = sk_storage_lookup(sock->sk, map, true);
748                 sockfd_put(sock);
749                 return sdata ? sdata->data : NULL;
750         }
751
752         return ERR_PTR(err);
753 }
754
755 static int bpf_fd_sk_storage_update_elem(struct bpf_map *map, void *key,
756                                          void *value, u64 map_flags)
757 {
758         struct bpf_sk_storage_data *sdata;
759         struct socket *sock;
760         int fd, err;
761
762         fd = *(int *)key;
763         sock = sockfd_lookup(fd, &err);
764         if (sock) {
765                 sdata = sk_storage_update(sock->sk, map, value, map_flags);
766                 sockfd_put(sock);
767                 return PTR_ERR_OR_ZERO(sdata);
768         }
769
770         return err;
771 }
772
773 static int bpf_fd_sk_storage_delete_elem(struct bpf_map *map, void *key)
774 {
775         struct socket *sock;
776         int fd, err;
777
778         fd = *(int *)key;
779         sock = sockfd_lookup(fd, &err);
780         if (sock) {
781                 err = sk_storage_delete(sock->sk, map);
782                 sockfd_put(sock);
783                 return err;
784         }
785
786         return err;
787 }
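/* Illustrative user-space sketch (assumptions: libbpf, a map_fd for a
 * BPF_MAP_TYPE_SK_STORAGE map, a sock_fd for an open socket, and the
 * made-up struct my_val matching the map's value type).  The syscall-side
 * handlers above are keyed by a socket fd:
 *
 *	struct my_val val = {};
 *
 *	// create, failing if storage already exists for this socket
 *	bpf_map_update_elem(map_fd, &sock_fd, &val, BPF_NOEXIST);
 *
 *	// read it back
 *	bpf_map_lookup_elem(map_fd, &sock_fd, &val);
 *
 *	// drop this socket's storage for the map
 *	bpf_map_delete_elem(map_fd, &sock_fd);
 */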
788
789 static struct bpf_sk_storage_elem *
790 bpf_sk_storage_clone_elem(struct sock *newsk,
791                           struct bpf_sk_storage_map *smap,
792                           struct bpf_sk_storage_elem *selem)
793 {
794         struct bpf_sk_storage_elem *copy_selem;
795
796         copy_selem = selem_alloc(smap, newsk, NULL, true);
797         if (!copy_selem)
798                 return NULL;
799
800         if (map_value_has_spin_lock(&smap->map))
801                 copy_map_value_locked(&smap->map, SDATA(copy_selem)->data,
802                                       SDATA(selem)->data, true);
803         else
804                 copy_map_value(&smap->map, SDATA(copy_selem)->data,
805                                SDATA(selem)->data);
806
807         return copy_selem;
808 }
809
810 int bpf_sk_storage_clone(const struct sock *sk, struct sock *newsk)
811 {
812         struct bpf_sk_storage *new_sk_storage = NULL;
813         struct bpf_sk_storage *sk_storage;
814         struct bpf_sk_storage_elem *selem;
815         int ret = 0;
816
817         RCU_INIT_POINTER(newsk->sk_bpf_storage, NULL);
818
819         rcu_read_lock();
820         sk_storage = rcu_dereference(sk->sk_bpf_storage);
821
822         if (!sk_storage || hlist_empty(&sk_storage->list))
823                 goto out;
824
825         hlist_for_each_entry_rcu(selem, &sk_storage->list, snode) {
826                 struct bpf_sk_storage_elem *copy_selem;
827                 struct bpf_sk_storage_map *smap;
828                 struct bpf_map *map;
829
830                 smap = rcu_dereference(SDATA(selem)->smap);
831                 if (!(smap->map.map_flags & BPF_F_CLONE))
832                         continue;
833
834                 /* Note that for lockless listeners adding a new element
835                  * here can race with cleanup in bpf_sk_storage_map_free.
836                  * Try to grab map refcnt to make sure that it's still
837                  * alive and prevent concurrent removal.
838                  */
839                 map = bpf_map_inc_not_zero(&smap->map);
840                 if (IS_ERR(map))
841                         continue;
842
843                 copy_selem = bpf_sk_storage_clone_elem(newsk, smap, selem);
844                 if (!copy_selem) {
845                         ret = -ENOMEM;
846                         bpf_map_put(map);
847                         goto out;
848                 }
849
850                 if (new_sk_storage) {
851                         selem_link_map(smap, copy_selem);
852                         __selem_link_sk(new_sk_storage, copy_selem);
853                 } else {
854                         ret = sk_storage_alloc(newsk, smap, copy_selem);
855                         if (ret) {
856                                 kfree(copy_selem);
857                                 atomic_sub(smap->elem_size,
858                                            &newsk->sk_omem_alloc);
859                                 bpf_map_put(map);
860                                 goto out;
861                         }
862
863                         new_sk_storage = rcu_dereference(copy_selem->sk_storage);
864                 }
865                 bpf_map_put(map);
866         }
867
868 out:
869         rcu_read_unlock();
870
871         /* In case of an error, don't free anything explicitly here, the
872          * caller is responsible for calling bpf_sk_storage_free.
873          */
874
875         return ret;
876 }
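/* Illustrative sketch: storage is only copied to a child socket when the
 * map was created with BPF_F_CLONE.  BPF program side (made-up names,
 * libbpf conventions):
 *
 *	struct {
 *		__uint(type, BPF_MAP_TYPE_SK_STORAGE);
 *		__uint(map_flags, BPF_F_NO_PREALLOC | BPF_F_CLONE);
 *		__type(key, int);
 *		__type(value, struct my_val);
 *	} sk_stg_clone_map SEC(".maps");
 *
 * Without BPF_F_CLONE, a socket returned by accept() starts with no
 * storage for this map even if its listener had some.
 */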
877
878 BPF_CALL_4(bpf_sk_storage_get, struct bpf_map *, map, struct sock *, sk,
879            void *, value, u64, flags)
880 {
881         struct bpf_sk_storage_data *sdata;
882
883         if (flags > BPF_SK_STORAGE_GET_F_CREATE)
884                 return (unsigned long)NULL;
885
886         sdata = sk_storage_lookup(sk, map, true);
887         if (sdata)
888                 return (unsigned long)sdata->data;
889
890         if (flags == BPF_SK_STORAGE_GET_F_CREATE &&
891             /* Cannot add a new elem to a sk that is going away.
892              * Otherwise, the new elem may become a leak
893              * (and also other memory issues during map
894              *  destruction).
895              */
896             refcount_inc_not_zero(&sk->sk_refcnt)) {
897                 sdata = sk_storage_update(sk, map, value, BPF_NOEXIST);
898                 /* sk must be a fullsock (guaranteed by verifier),
899                  * so sock_gen_put() is unnecessary.
900                  */
901                 sock_put(sk);
902                 return IS_ERR(sdata) ?
903                         (unsigned long)NULL : (unsigned long)sdata->data;
904         }
905
906         return (unsigned long)NULL;
907 }
908
909 BPF_CALL_2(bpf_sk_storage_delete, struct bpf_map *, map, struct sock *, sk)
910 {
911         if (refcount_inc_not_zero(&sk->sk_refcnt)) {
912                 int err;
913
914                 err = sk_storage_delete(sk, map);
915                 sock_put(sk);
916                 return err;
917         }
918
919         return -ENOENT;
920 }
921
922 static int sk_storage_map_btf_id;
923 const struct bpf_map_ops sk_storage_map_ops = {
924         .map_alloc_check = bpf_sk_storage_map_alloc_check,
925         .map_alloc = bpf_sk_storage_map_alloc,
926         .map_free = bpf_sk_storage_map_free,
927         .map_get_next_key = notsupp_get_next_key,
928         .map_lookup_elem = bpf_fd_sk_storage_lookup_elem,
929         .map_update_elem = bpf_fd_sk_storage_update_elem,
930         .map_delete_elem = bpf_fd_sk_storage_delete_elem,
931         .map_check_btf = bpf_sk_storage_map_check_btf,
932         .map_btf_name = "bpf_sk_storage_map",
933         .map_btf_id = &sk_storage_map_btf_id,
934 };
935
936 const struct bpf_func_proto bpf_sk_storage_get_proto = {
937         .func           = bpf_sk_storage_get,
938         .gpl_only       = false,
939         .ret_type       = RET_PTR_TO_MAP_VALUE_OR_NULL,
940         .arg1_type      = ARG_CONST_MAP_PTR,
941         .arg2_type      = ARG_PTR_TO_SOCKET,
942         .arg3_type      = ARG_PTR_TO_MAP_VALUE_OR_NULL,
943         .arg4_type      = ARG_ANYTHING,
944 };
945
946 const struct bpf_func_proto bpf_sk_storage_delete_proto = {
947         .func           = bpf_sk_storage_delete,
948         .gpl_only       = false,
949         .ret_type       = RET_INTEGER,
950         .arg1_type      = ARG_CONST_MAP_PTR,
951         .arg2_type      = ARG_PTR_TO_SOCKET,
952 };
953
954 struct bpf_sk_storage_diag {
955         u32 nr_maps;
956         struct bpf_map *maps[];
957 };
958
959 /* The reply will be like:
960  * INET_DIAG_BPF_SK_STORAGES (nla_nest)
961  *      SK_DIAG_BPF_STORAGE (nla_nest)
962  *              SK_DIAG_BPF_STORAGE_MAP_ID (nla_put_u32)
963  *              SK_DIAG_BPF_STORAGE_MAP_VALUE (nla_reserve_64bit)
964  *      SK_DIAG_BPF_STORAGE (nla_nest)
965  *              SK_DIAG_BPF_STORAGE_MAP_ID (nla_put_u32)
966  *              SK_DIAG_BPF_STORAGE_MAP_VALUE (nla_reserve_64bit)
967  *      ....
968  */
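/* Illustrative user-space parsing sketch for the reply layout above
 * (assumption: libmnl; "stgs" points at the INET_DIAG_BPF_SK_STORAGES
 * attribute of an inet_diag reply):
 *
 *	const struct nlattr *stg, *attr;
 *
 *	mnl_attr_for_each_nested(stg, stgs) {
 *		__u32 map_id = 0;
 *		void *value = NULL;
 *
 *		if (mnl_attr_get_type(stg) != SK_DIAG_BPF_STORAGE)
 *			continue;
 *		mnl_attr_for_each_nested(attr, stg) {
 *			switch (mnl_attr_get_type(attr)) {
 *			case SK_DIAG_BPF_STORAGE_MAP_ID:
 *				map_id = mnl_attr_get_u32(attr);
 *				break;
 *			case SK_DIAG_BPF_STORAGE_MAP_VALUE:
 *				value = mnl_attr_get_payload(attr);
 *				break;
 *			}
 *		}
 *		// map_id/value now describe one storage entry for this sk
 *	}
 */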
969 static int nla_value_size(u32 value_size)
970 {
971         /* SK_DIAG_BPF_STORAGE (nla_nest)
972          *      SK_DIAG_BPF_STORAGE_MAP_ID (nla_put_u32)
973          *      SK_DIAG_BPF_STORAGE_MAP_VALUE (nla_reserve_64bit)
974          */
975         return nla_total_size(0) + nla_total_size(sizeof(u32)) +
976                 nla_total_size_64bit(value_size);
977 }
978
979 void bpf_sk_storage_diag_free(struct bpf_sk_storage_diag *diag)
980 {
981         u32 i;
982
983         if (!diag)
984                 return;
985
986         for (i = 0; i < diag->nr_maps; i++)
987                 bpf_map_put(diag->maps[i]);
988
989         kfree(diag);
990 }
991 EXPORT_SYMBOL_GPL(bpf_sk_storage_diag_free);
992
993 static bool diag_check_dup(const struct bpf_sk_storage_diag *diag,
994                            const struct bpf_map *map)
995 {
996         u32 i;
997
998         for (i = 0; i < diag->nr_maps; i++) {
999                 if (diag->maps[i] == map)
1000                         return true;
1001         }
1002
1003         return false;
1004 }
1005
1006 struct bpf_sk_storage_diag *
1007 bpf_sk_storage_diag_alloc(const struct nlattr *nla_stgs)
1008 {
1009         struct bpf_sk_storage_diag *diag;
1010         struct nlattr *nla;
1011         u32 nr_maps = 0;
1012         int rem, err;
1013
1014         /* bpf_sk_storage_map is currently limited to bpf_capable(),
1015          * the same check that the map_alloc_check() side does.
1016          */
1017         if (!bpf_capable())
1018                 return ERR_PTR(-EPERM);
1019
1020         nla_for_each_nested(nla, nla_stgs, rem) {
1021                 if (nla_type(nla) == SK_DIAG_BPF_STORAGE_REQ_MAP_FD)
1022                         nr_maps++;
1023         }
1024
1025         diag = kzalloc(sizeof(*diag) + sizeof(diag->maps[0]) * nr_maps,
1026                        GFP_KERNEL);
1027         if (!diag)
1028                 return ERR_PTR(-ENOMEM);
1029
1030         nla_for_each_nested(nla, nla_stgs, rem) {
1031                 struct bpf_map *map;
1032                 int map_fd;
1033
1034                 if (nla_type(nla) != SK_DIAG_BPF_STORAGE_REQ_MAP_FD)
1035                         continue;
1036
1037                 map_fd = nla_get_u32(nla);
1038                 map = bpf_map_get(map_fd);
1039                 if (IS_ERR(map)) {
1040                         err = PTR_ERR(map);
1041                         goto err_free;
1042                 }
1043                 if (map->map_type != BPF_MAP_TYPE_SK_STORAGE) {
1044                         bpf_map_put(map);
1045                         err = -EINVAL;
1046                         goto err_free;
1047                 }
1048                 if (diag_check_dup(diag, map)) {
1049                         bpf_map_put(map);
1050                         err = -EEXIST;
1051                         goto err_free;
1052                 }
1053                 diag->maps[diag->nr_maps++] = map;
1054         }
1055
1056         return diag;
1057
1058 err_free:
1059         bpf_sk_storage_diag_free(diag);
1060         return ERR_PTR(err);
1061 }
1062 EXPORT_SYMBOL_GPL(bpf_sk_storage_diag_alloc);
1063
1064 static int diag_get(struct bpf_sk_storage_data *sdata, struct sk_buff *skb)
1065 {
1066         struct nlattr *nla_stg, *nla_value;
1067         struct bpf_sk_storage_map *smap;
1068
1069         /* The value cannot exceed the max nlattr payload */
1070         BUILD_BUG_ON(U16_MAX - NLA_HDRLEN < MAX_VALUE_SIZE);
1071
1072         nla_stg = nla_nest_start(skb, SK_DIAG_BPF_STORAGE);
1073         if (!nla_stg)
1074                 return -EMSGSIZE;
1075
1076         smap = rcu_dereference(sdata->smap);
1077         if (nla_put_u32(skb, SK_DIAG_BPF_STORAGE_MAP_ID, smap->map.id))
1078                 goto errout;
1079
1080         nla_value = nla_reserve_64bit(skb, SK_DIAG_BPF_STORAGE_MAP_VALUE,
1081                                       smap->map.value_size,
1082                                       SK_DIAG_BPF_STORAGE_PAD);
1083         if (!nla_value)
1084                 goto errout;
1085
1086         if (map_value_has_spin_lock(&smap->map))
1087                 copy_map_value_locked(&smap->map, nla_data(nla_value),
1088                                       sdata->data, true);
1089         else
1090                 copy_map_value(&smap->map, nla_data(nla_value), sdata->data);
1091
1092         nla_nest_end(skb, nla_stg);
1093         return 0;
1094
1095 errout:
1096         nla_nest_cancel(skb, nla_stg);
1097         return -EMSGSIZE;
1098 }
1099
1100 static int bpf_sk_storage_diag_put_all(struct sock *sk, struct sk_buff *skb,
1101                                        int stg_array_type,
1102                                        unsigned int *res_diag_size)
1103 {
1104         /* stg_array_type (e.g. INET_DIAG_BPF_SK_STORAGES) */
1105         unsigned int diag_size = nla_total_size(0);
1106         struct bpf_sk_storage *sk_storage;
1107         struct bpf_sk_storage_elem *selem;
1108         struct bpf_sk_storage_map *smap;
1109         struct nlattr *nla_stgs;
1110         unsigned int saved_len;
1111         int err = 0;
1112
1113         rcu_read_lock();
1114
1115         sk_storage = rcu_dereference(sk->sk_bpf_storage);
1116         if (!sk_storage || hlist_empty(&sk_storage->list)) {
1117                 rcu_read_unlock();
1118                 return 0;
1119         }
1120
1121         nla_stgs = nla_nest_start(skb, stg_array_type);
1122         if (!nla_stgs)
1123                 /* Continue to learn diag_size */
1124                 err = -EMSGSIZE;
1125
1126         saved_len = skb->len;
1127         hlist_for_each_entry_rcu(selem, &sk_storage->list, snode) {
1128                 smap = rcu_dereference(SDATA(selem)->smap);
1129                 diag_size += nla_value_size(smap->map.value_size);
1130
1131                 if (nla_stgs && diag_get(SDATA(selem), skb))
1132                         /* Continue to learn diag_size */
1133                         err = -EMSGSIZE;
1134         }
1135
1136         rcu_read_unlock();
1137
1138         if (nla_stgs) {
1139                 if (saved_len == skb->len)
1140                         nla_nest_cancel(skb, nla_stgs);
1141                 else
1142                         nla_nest_end(skb, nla_stgs);
1143         }
1144
1145         if (diag_size == nla_total_size(0)) {
1146                 *res_diag_size = 0;
1147                 return 0;
1148         }
1149
1150         *res_diag_size = diag_size;
1151         return err;
1152 }
1153
1154 int bpf_sk_storage_diag_put(struct bpf_sk_storage_diag *diag,
1155                             struct sock *sk, struct sk_buff *skb,
1156                             int stg_array_type,
1157                             unsigned int *res_diag_size)
1158 {
1159         /* stg_array_type (e.g. INET_DIAG_BPF_SK_STORAGES) */
1160         unsigned int diag_size = nla_total_size(0);
1161         struct bpf_sk_storage *sk_storage;
1162         struct bpf_sk_storage_data *sdata;
1163         struct nlattr *nla_stgs;
1164         unsigned int saved_len;
1165         int err = 0;
1166         u32 i;
1167
1168         *res_diag_size = 0;
1169
1170         /* No map has been specified.  Dump all. */
1171         if (!diag->nr_maps)
1172                 return bpf_sk_storage_diag_put_all(sk, skb, stg_array_type,
1173                                                    res_diag_size);
1174
1175         rcu_read_lock();
1176         sk_storage = rcu_dereference(sk->sk_bpf_storage);
1177         if (!sk_storage || hlist_empty(&sk_storage->list)) {
1178                 rcu_read_unlock();
1179                 return 0;
1180         }
1181
1182         nla_stgs = nla_nest_start(skb, stg_array_type);
1183         if (!nla_stgs)
1184                 /* Continue to learn diag_size */
1185                 err = -EMSGSIZE;
1186
1187         saved_len = skb->len;
1188         for (i = 0; i < diag->nr_maps; i++) {
1189                 sdata = __sk_storage_lookup(sk_storage,
1190                                 (struct bpf_sk_storage_map *)diag->maps[i],
1191                                 false);
1192
1193                 if (!sdata)
1194                         continue;
1195
1196                 diag_size += nla_value_size(diag->maps[i]->value_size);
1197
1198                 if (nla_stgs && diag_get(sdata, skb))
1199                         /* Continue to learn diag_size */
1200                         err = -EMSGSIZE;
1201         }
1202         rcu_read_unlock();
1203
1204         if (nla_stgs) {
1205                 if (saved_len == skb->len)
1206                         nla_nest_cancel(skb, nla_stgs);
1207                 else
1208                         nla_nest_end(skb, nla_stgs);
1209         }
1210
1211         if (diag_size == nla_total_size(0)) {
1212                 *res_diag_size = 0;
1213                 return 0;
1214         }
1215
1216         *res_diag_size = diag_size;
1217         return err;
1218 }
1219 EXPORT_SYMBOL_GPL(bpf_sk_storage_diag_put);