Merge tag '6.6-rc4-ksmbd-server-fixes' of git://git.samba.org/ksmbd
[platform/kernel/linux-starfive.git] / net / core / sock_map.c
index 8f07fea..4292c2e 100644 (file)
@@ -18,7 +18,7 @@ struct bpf_stab {
        struct bpf_map map;
        struct sock **sks;
        struct sk_psock_progs progs;
-       raw_spinlock_t lock;
+       spinlock_t lock;
 };
 
 #define SOCK_CREATE_FLAG_MASK                          \
@@ -44,7 +44,7 @@ static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
                return ERR_PTR(-ENOMEM);
 
        bpf_map_init_from_attr(&stab->map, attr);
-       raw_spin_lock_init(&stab->lock);
+       spin_lock_init(&stab->lock);
 
        stab->sks = bpf_map_area_alloc((u64) stab->map.max_entries *
                                       sizeof(struct sock *),
@@ -411,7 +411,7 @@ static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
        struct sock *sk;
        int err = 0;
 
-       raw_spin_lock_bh(&stab->lock);
+       spin_lock_bh(&stab->lock);
        sk = *psk;
        if (!sk_test || sk_test == sk)
                sk = xchg(psk, NULL);
@@ -421,7 +421,7 @@ static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
        else
                err = -EINVAL;
 
-       raw_spin_unlock_bh(&stab->lock);
+       spin_unlock_bh(&stab->lock);
        return err;
 }
 
@@ -487,7 +487,7 @@ static int sock_map_update_common(struct bpf_map *map, u32 idx,
        psock = sk_psock(sk);
        WARN_ON_ONCE(!psock);
 
-       raw_spin_lock_bh(&stab->lock);
+       spin_lock_bh(&stab->lock);
        osk = stab->sks[idx];
        if (osk && flags == BPF_NOEXIST) {
                ret = -EEXIST;
@@ -501,10 +501,10 @@ static int sock_map_update_common(struct bpf_map *map, u32 idx,
        stab->sks[idx] = sk;
        if (osk)
                sock_map_unref(osk, &stab->sks[idx]);
-       raw_spin_unlock_bh(&stab->lock);
+       spin_unlock_bh(&stab->lock);
        return 0;
 out_unlock:
-       raw_spin_unlock_bh(&stab->lock);
+       spin_unlock_bh(&stab->lock);
        if (psock)
                sk_psock_put(sk, psock);
 out_free:
@@ -668,6 +668,8 @@ BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg *, msg,
        sk = __sock_map_lookup_elem(map, key);
        if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
                return SK_DROP;
+       if (!(flags & BPF_F_INGRESS) && !sk_is_tcp(sk))
+               return SK_DROP;
 
        msg->flags = flags;
        msg->sk_redir = sk;
@@ -835,7 +837,7 @@ struct bpf_shtab_elem {
 
 struct bpf_shtab_bucket {
        struct hlist_head head;
-       raw_spinlock_t lock;
+       spinlock_t lock;
 };
 
 struct bpf_shtab {
@@ -910,7 +912,7 @@ static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk,
         * is okay since it's going away only after RCU grace period.
         * However, we need to check whether it's still present.
         */
-       raw_spin_lock_bh(&bucket->lock);
+       spin_lock_bh(&bucket->lock);
        elem_probe = sock_hash_lookup_elem_raw(&bucket->head, elem->hash,
                                               elem->key, map->key_size);
        if (elem_probe && elem_probe == elem) {
@@ -918,7 +920,7 @@ static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk,
                sock_map_unref(elem->sk, elem);
                sock_hash_free_elem(htab, elem);
        }
-       raw_spin_unlock_bh(&bucket->lock);
+       spin_unlock_bh(&bucket->lock);
 }
 
 static long sock_hash_delete_elem(struct bpf_map *map, void *key)
@@ -932,7 +934,7 @@ static long sock_hash_delete_elem(struct bpf_map *map, void *key)
        hash = sock_hash_bucket_hash(key, key_size);
        bucket = sock_hash_select_bucket(htab, hash);
 
-       raw_spin_lock_bh(&bucket->lock);
+       spin_lock_bh(&bucket->lock);
        elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
        if (elem) {
                hlist_del_rcu(&elem->node);
@@ -940,7 +942,7 @@ static long sock_hash_delete_elem(struct bpf_map *map, void *key)
                sock_hash_free_elem(htab, elem);
                ret = 0;
        }
-       raw_spin_unlock_bh(&bucket->lock);
+       spin_unlock_bh(&bucket->lock);
        return ret;
 }
 
@@ -1000,7 +1002,7 @@ static int sock_hash_update_common(struct bpf_map *map, void *key,
        hash = sock_hash_bucket_hash(key, key_size);
        bucket = sock_hash_select_bucket(htab, hash);
 
-       raw_spin_lock_bh(&bucket->lock);
+       spin_lock_bh(&bucket->lock);
        elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
        if (elem && flags == BPF_NOEXIST) {
                ret = -EEXIST;
@@ -1026,10 +1028,10 @@ static int sock_hash_update_common(struct bpf_map *map, void *key,
                sock_map_unref(elem->sk, elem);
                sock_hash_free_elem(htab, elem);
        }
-       raw_spin_unlock_bh(&bucket->lock);
+       spin_unlock_bh(&bucket->lock);
        return 0;
 out_unlock:
-       raw_spin_unlock_bh(&bucket->lock);
+       spin_unlock_bh(&bucket->lock);
        sk_psock_put(sk, psock);
 out_free:
        sk_psock_free_link(link);
@@ -1115,7 +1117,7 @@ static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
 
        for (i = 0; i < htab->buckets_num; i++) {
                INIT_HLIST_HEAD(&htab->buckets[i].head);
-               raw_spin_lock_init(&htab->buckets[i].lock);
+               spin_lock_init(&htab->buckets[i].lock);
        }
 
        return &htab->map;
@@ -1147,11 +1149,11 @@ static void sock_hash_free(struct bpf_map *map)
                 * exists, psock exists and holds a ref to socket. That
                 * lets us to grab a socket ref too.
                 */
-               raw_spin_lock_bh(&bucket->lock);
+               spin_lock_bh(&bucket->lock);
                hlist_for_each_entry(elem, &bucket->head, node)
                        sock_hold(elem->sk);
                hlist_move_list(&bucket->head, &unlink_list);
-               raw_spin_unlock_bh(&bucket->lock);
+               spin_unlock_bh(&bucket->lock);
 
                /* Process removed entries out of atomic context to
                 * block for socket lock before deleting the psock's
@@ -1267,6 +1269,8 @@ BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg *, msg,
        sk = __sock_hash_lookup_elem(map, key);
        if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
                return SK_DROP;
+       if (!(flags & BPF_F_INGRESS) && !sk_is_tcp(sk))
+               return SK_DROP;
 
        msg->flags = flags;
        msg->sk_redir = sk;