Bluetooth: Write host suggested default le data length
[platform/kernel/linux-starfive.git] / net / mctp / route.c
index f9a80b8..f51a05e 100644 (file)
@@ -147,6 +147,7 @@ static struct mctp_sk_key *mctp_key_alloc(struct mctp_sock *msk,
        key->valid = true;
        spin_lock_init(&key->lock);
        refcount_set(&key->refs, 1);
        key->valid = true;
        spin_lock_init(&key->lock);
        refcount_set(&key->refs, 1);
+       sock_hold(key->sk);
 
        return key;
 }
 
        return key;
 }
@@ -165,6 +166,7 @@ void mctp_key_unref(struct mctp_sk_key *key)
        mctp_dev_release_key(key->dev, key);
        spin_unlock_irqrestore(&key->lock, flags);
 
        mctp_dev_release_key(key->dev, key);
        spin_unlock_irqrestore(&key->lock, flags);
 
+       sock_put(key->sk);
        kfree(key);
 }
 
        kfree(key);
 }
 
@@ -177,6 +179,11 @@ static int mctp_key_add(struct mctp_sk_key *key, struct mctp_sock *msk)
 
        spin_lock_irqsave(&net->mctp.keys_lock, flags);
 
 
        spin_lock_irqsave(&net->mctp.keys_lock, flags);
 
+       if (sock_flag(&msk->sk, SOCK_DEAD)) {
+               rc = -EINVAL;
+               goto out_unlock;
+       }
+
        hlist_for_each_entry(tmp, &net->mctp.keys, hlist) {
                if (mctp_key_match(tmp, key->local_addr, key->peer_addr,
                                   key->tag)) {
        hlist_for_each_entry(tmp, &net->mctp.keys, hlist) {
                if (mctp_key_match(tmp, key->local_addr, key->peer_addr,
                                   key->tag)) {
@@ -198,6 +205,7 @@ static int mctp_key_add(struct mctp_sk_key *key, struct mctp_sock *msk)
                hlist_add_head(&key->sklist, &msk->keys);
        }
 
                hlist_add_head(&key->sklist, &msk->keys);
        }
 
+out_unlock:
        spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
 
        return rc;
        spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
 
        return rc;
@@ -315,8 +323,8 @@ static int mctp_frag_queue(struct mctp_sk_key *key, struct sk_buff *skb)
 
 static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
 {
 
 static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
 {
+       struct mctp_sk_key *key, *any_key = NULL;
        struct net *net = dev_net(skb->dev);
        struct net *net = dev_net(skb->dev);
-       struct mctp_sk_key *key;
        struct mctp_sock *msk;
        struct mctp_hdr *mh;
        unsigned long f;
        struct mctp_sock *msk;
        struct mctp_hdr *mh;
        unsigned long f;
@@ -361,13 +369,11 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
                         * key for reassembly - we'll create a more specific
                         * one for future packets if required (ie, !EOM).
                         */
                         * key for reassembly - we'll create a more specific
                         * one for future packets if required (ie, !EOM).
                         */
-                       key = mctp_lookup_key(net, skb, MCTP_ADDR_ANY, &f);
-                       if (key) {
-                               msk = container_of(key->sk,
+                       any_key = mctp_lookup_key(net, skb, MCTP_ADDR_ANY, &f);
+                       if (any_key) {
+                               msk = container_of(any_key->sk,
                                                   struct mctp_sock, sk);
                                                   struct mctp_sock, sk);
-                               spin_unlock_irqrestore(&key->lock, f);
-                               mctp_key_unref(key);
-                               key = NULL;
+                               spin_unlock_irqrestore(&any_key->lock, f);
                        }
                }
 
                        }
                }
 
@@ -419,14 +425,14 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
                         * this function.
                         */
                        rc = mctp_key_add(key, msk);
                         * this function.
                         */
                        rc = mctp_key_add(key, msk);
-                       if (rc) {
-                               kfree(key);
-                       } else {
+                       if (!rc)
                                trace_mctp_key_acquire(key);
 
                                trace_mctp_key_acquire(key);
 
-                               /* we don't need to release key->lock on exit */
-                               mctp_key_unref(key);
-                       }
+                       /* we don't need to release key->lock on exit, so
+                        * clean up here and suppress the unlock via
+                        * setting to NULL
+                        */
+                       mctp_key_unref(key);
                        key = NULL;
 
                } else {
                        key = NULL;
 
                } else {
@@ -473,6 +479,8 @@ out_unlock:
                spin_unlock_irqrestore(&key->lock, f);
                mctp_key_unref(key);
        }
                spin_unlock_irqrestore(&key->lock, f);
                mctp_key_unref(key);
        }
+       if (any_key)
+               mctp_key_unref(any_key);
 out:
        if (rc)
                kfree_skb(skb);
 out:
        if (rc)
                kfree_skb(skb);