sctp: hold endpoint before calling cb in sctp_transport_lookup_process
authorXin Long <lucien.xin@gmail.com>
Fri, 31 Dec 2021 23:37:37 +0000 (18:37 -0500)
committerGreg Kroah-Hartman <gregkh@linuxfoundation.org>
Tue, 11 Jan 2022 14:35:14 +0000 (15:35 +0100)
commit f9d31c4cf4c11ff10317f038b9c6f7c3bda6cdd4 upstream.

The same fix in commit 5ec7d18d1813 ("sctp: use call_rcu to free endpoint")
is also needed for dumping one asoc and sock after the lookup.

Fixes: 86fdb3448cc1 ("sctp: ensure ep is not destroyed before doing the dump")
Signed-off-by: Xin Long <lucien.xin@gmail.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
Signed-off-by: Greg Kroah-Hartman <gregkh@linuxfoundation.org>
include/net/sctp/sctp.h
net/sctp/diag.c
net/sctp/socket.c

index d314a18..3ae61ce 100644 (file)
@@ -112,8 +112,7 @@ struct sctp_transport *sctp_transport_get_next(struct net *net,
                        struct rhashtable_iter *iter);
 struct sctp_transport *sctp_transport_get_idx(struct net *net,
                        struct rhashtable_iter *iter, int pos);
-int sctp_transport_lookup_process(int (*cb)(struct sctp_transport *, void *),
-                                 struct net *net,
+int sctp_transport_lookup_process(sctp_callback_t cb, struct net *net,
                                  const union sctp_addr *laddr,
                                  const union sctp_addr *paddr, void *p);
 int sctp_transport_traverse_process(sctp_callback_t cb, sctp_callback_t cb_done,
index a7d6231..034e2c7 100644 (file)
@@ -245,48 +245,44 @@ static size_t inet_assoc_attr_size(struct sctp_association *asoc)
                + 64;
 }
 
-static int sctp_tsp_dump_one(struct sctp_transport *tsp, void *p)
+static int sctp_sock_dump_one(struct sctp_endpoint *ep, struct sctp_transport *tsp, void *p)
 {
        struct sctp_association *assoc = tsp->asoc;
-       struct sock *sk = tsp->asoc->base.sk;
        struct sctp_comm_param *commp = p;
-       struct sk_buff *in_skb = commp->skb;
+       struct sock *sk = ep->base.sk;
        const struct inet_diag_req_v2 *req = commp->r;
-       const struct nlmsghdr *nlh = commp->nlh;
-       struct net *net = sock_net(in_skb->sk);
+       struct sk_buff *skb = commp->skb;
        struct sk_buff *rep;
        int err;
 
        err = sock_diag_check_cookie(sk, req->id.idiag_cookie);
        if (err)
-               goto out;
+               return err;
 
-       err = -ENOMEM;
        rep = nlmsg_new(inet_assoc_attr_size(assoc), GFP_KERNEL);
        if (!rep)
-               goto out;
+               return -ENOMEM;
 
        lock_sock(sk);
-       if (sk != assoc->base.sk) {
-               release_sock(sk);
-               sk = assoc->base.sk;
-               lock_sock(sk);
+       if (ep != assoc->ep) {
+               err = -EAGAIN;
+               goto out;
        }
-       err = inet_sctp_diag_fill(sk, assoc, rep, req,
-                                 sk_user_ns(NETLINK_CB(in_skb).sk),
-                                 NETLINK_CB(in_skb).portid,
-                                 nlh->nlmsg_seq, 0, nlh,
-                                 commp->net_admin);
-       release_sock(sk);
+
+       err = inet_sctp_diag_fill(sk, assoc, rep, req, sk_user_ns(NETLINK_CB(skb).sk),
+                                 NETLINK_CB(skb).portid, commp->nlh->nlmsg_seq, 0,
+                                 commp->nlh, commp->net_admin);
        if (err < 0) {
                WARN_ON(err == -EMSGSIZE);
-               kfree_skb(rep);
                goto out;
        }
+       release_sock(sk);
 
-       err = nlmsg_unicast(net->diag_nlsk, rep, NETLINK_CB(in_skb).portid);
+       return nlmsg_unicast(sock_net(skb->sk)->diag_nlsk, rep, NETLINK_CB(skb).portid);
 
 out:
+       release_sock(sk);
+       kfree_skb(rep);
        return err;
 }
 
@@ -429,15 +425,15 @@ static void sctp_diag_get_info(struct sock *sk, struct inet_diag_msg *r,
 static int sctp_diag_dump_one(struct netlink_callback *cb,
                              const struct inet_diag_req_v2 *req)
 {
-       struct sk_buff *in_skb = cb->skb;
-       struct net *net = sock_net(in_skb->sk);
+       struct sk_buff *skb = cb->skb;
+       struct net *net = sock_net(skb->sk);
        const struct nlmsghdr *nlh = cb->nlh;
        union sctp_addr laddr, paddr;
        struct sctp_comm_param commp = {
-               .skb = in_skb,
+               .skb = skb,
                .r = req,
                .nlh = nlh,
-               .net_admin = netlink_net_capable(in_skb, CAP_NET_ADMIN),
+               .net_admin = netlink_net_capable(skb, CAP_NET_ADMIN),
        };
 
        if (req->sdiag_family == AF_INET) {
@@ -460,7 +456,7 @@ static int sctp_diag_dump_one(struct netlink_callback *cb,
                paddr.v6.sin6_family = AF_INET6;
        }
 
-       return sctp_transport_lookup_process(sctp_tsp_dump_one,
+       return sctp_transport_lookup_process(sctp_sock_dump_one,
                                             net, &laddr, &paddr, &commp);
 }
 
index d2215d2..6b3c322 100644 (file)
@@ -5317,23 +5317,31 @@ int sctp_for_each_endpoint(int (*cb)(struct sctp_endpoint *, void *),
 }
 EXPORT_SYMBOL_GPL(sctp_for_each_endpoint);
 
-int sctp_transport_lookup_process(int (*cb)(struct sctp_transport *, void *),
-                                 struct net *net,
+int sctp_transport_lookup_process(sctp_callback_t cb, struct net *net,
                                  const union sctp_addr *laddr,
                                  const union sctp_addr *paddr, void *p)
 {
        struct sctp_transport *transport;
-       int err;
+       struct sctp_endpoint *ep;
+       int err = -ENOENT;
 
        rcu_read_lock();
        transport = sctp_addrs_lookup_transport(net, laddr, paddr);
+       if (!transport) {
+               rcu_read_unlock();
+               return err;
+       }
+       ep = transport->asoc->ep;
+       if (!sctp_endpoint_hold(ep)) { /* asoc can be peeled off */
+               sctp_transport_put(transport);
+               rcu_read_unlock();
+               return err;
+       }
        rcu_read_unlock();
-       if (!transport)
-               return -ENOENT;
 
-       err = cb(transport, p);
+       err = cb(ep, transport, p);
+       sctp_endpoint_put(ep);
        sctp_transport_put(transport);
-
        return err;
 }
 EXPORT_SYMBOL_GPL(sctp_transport_lookup_process);