Merge branch 'tcp-udp-fix-memory-leaks-and-data-races-around-ipv6_addrform'
authorJakub Kicinski <kuba@kernel.org>
Thu, 13 Oct 2022 00:50:40 +0000 (17:50 -0700)
committerJakub Kicinski <kuba@kernel.org>
Thu, 13 Oct 2022 00:50:40 +0000 (17:50 -0700)
Kuniyuki Iwashima says:

====================
tcp/udp: Fix memory leaks and data races around IPV6_ADDRFORM.

This series fixes some memory leaks and data races caused in the
same scenario where one thread converts an IPv6 socket into IPv4
with IPV6_ADDRFORM and another accesses the socket concurrently.

  v4: https://lore.kernel.org/netdev/20221004171802.40968-1-kuniyu@amazon.com/
  v3 (Resend): https://lore.kernel.org/netdev/20221003154425.49458-1-kuniyu@amazon.com/
  v3: https://lore.kernel.org/netdev/20220929012542.55424-1-kuniyu@amazon.com/
  v2: https://lore.kernel.org/netdev/20220928002741.64237-1-kuniyu@amazon.com/
  v1: https://lore.kernel.org/netdev/20220927161209.32939-1-kuniyu@amazon.com/
====================

Link: https://lore.kernel.org/r/20221006185349.74777-1-kuniyu@amazon.com
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
14 files changed:
include/net/ipv6.h
include/net/udp.h
include/net/udplite.h
net/core/sock.c
net/ipv4/af_inet.c
net/ipv4/tcp.c
net/ipv4/udp.c
net/ipv4/udplite.c
net/ipv6/af_inet6.c
net/ipv6/ipv6_sockglue.c
net/ipv6/tcp_ipv6.c
net/ipv6/udp.c
net/ipv6/udp_impl.h
net/ipv6/udplite.c

index d664ba5812d87cf64802c25c7a3175f8a30c32e6..37943ba3a73c5c6a5124fa6cf2199c7987abb11d 100644 (file)
@@ -1182,6 +1182,8 @@ void ipv6_icmp_error(struct sock *sk, struct sk_buff *skb, int err, __be16 port,
 void ipv6_local_error(struct sock *sk, int err, struct flowi6 *fl6, u32 info);
 void ipv6_local_rxpmtu(struct sock *sk, struct flowi6 *fl6, u32 mtu);
 
+void inet6_cleanup_sock(struct sock *sk);
+void inet6_sock_destruct(struct sock *sk);
 int inet6_release(struct socket *sock);
 int inet6_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len);
 int inet6_getname(struct socket *sock, struct sockaddr *uaddr,
index 5ee88ddf79c3fe61c007d501dcd5f124ff517642..fee053bcd17c6ddc7cea6c055e587f4321e800e0 100644 (file)
@@ -247,7 +247,7 @@ static inline bool udp_sk_bound_dev_eq(struct net *net, int bound_dev_if,
 }
 
 /* net/ipv4/udp.c */
-void udp_destruct_sock(struct sock *sk);
+void udp_destruct_common(struct sock *sk);
 void skb_consume_udp(struct sock *sk, struct sk_buff *skb, int len);
 int __udp_enqueue_schedule_skb(struct sock *sk, struct sk_buff *skb);
 void udp_skb_destructor(struct sock *sk, struct sk_buff *skb);
index 0143b373602ec547e5f133f9c02ae4e539ecfcfa..299c14ce2bb949029624710798810919cccb7934 100644 (file)
@@ -25,14 +25,6 @@ static __inline__ int udplite_getfrag(void *from, char *to, int  offset,
        return copy_from_iter_full(to, len, &msg->msg_iter) ? 0 : -EFAULT;
 }
 
-/* Designate sk as UDP-Lite socket */
-static inline int udplite_sk_init(struct sock *sk)
-{
-       udp_init_sock(sk);
-       udp_sk(sk)->pcflag = UDPLITE_BIT;
-       return 0;
-}
-
 /*
  *     Checksumming routines
  */
index eeb6cbac6f4998dbc41fc686e7e882135e45b9e3..a3ba0358c77c0e44db1cfbaeb420f8b80ad7cf98 100644 (file)
@@ -3610,7 +3610,8 @@ int sock_common_getsockopt(struct socket *sock, int level, int optname,
 {
        struct sock *sk = sock->sk;
 
-       return sk->sk_prot->getsockopt(sk, level, optname, optval, optlen);
+       /* IPV6_ADDRFORM can change sk->sk_prot under us. */
+       return READ_ONCE(sk->sk_prot)->getsockopt(sk, level, optname, optval, optlen);
 }
 EXPORT_SYMBOL(sock_common_getsockopt);
 
@@ -3636,7 +3637,8 @@ int sock_common_setsockopt(struct socket *sock, int level, int optname,
 {
        struct sock *sk = sock->sk;
 
-       return sk->sk_prot->setsockopt(sk, level, optname, optval, optlen);
+       /* IPV6_ADDRFORM can change sk->sk_prot under us. */
+       return READ_ONCE(sk->sk_prot)->setsockopt(sk, level, optname, optval, optlen);
 }
 EXPORT_SYMBOL(sock_common_setsockopt);
 
index e2c21938234552675c366aa9d034c39470a07411..3dd02396517df599cf4ff3b9ab8463ea959770a1 100644 (file)
@@ -558,22 +558,27 @@ int inet_dgram_connect(struct socket *sock, struct sockaddr *uaddr,
                       int addr_len, int flags)
 {
        struct sock *sk = sock->sk;
+       const struct proto *prot;
        int err;
 
        if (addr_len < sizeof(uaddr->sa_family))
                return -EINVAL;
+
+       /* IPV6_ADDRFORM can change sk->sk_prot under us. */
+       prot = READ_ONCE(sk->sk_prot);
+
        if (uaddr->sa_family == AF_UNSPEC)
-               return sk->sk_prot->disconnect(sk, flags);
+               return prot->disconnect(sk, flags);
 
        if (BPF_CGROUP_PRE_CONNECT_ENABLED(sk)) {
-               err = sk->sk_prot->pre_connect(sk, uaddr, addr_len);
+               err = prot->pre_connect(sk, uaddr, addr_len);
                if (err)
                        return err;
        }
 
        if (data_race(!inet_sk(sk)->inet_num) && inet_autobind(sk))
                return -EAGAIN;
-       return sk->sk_prot->connect(sk, uaddr, addr_len);
+       return prot->connect(sk, uaddr, addr_len);
 }
 EXPORT_SYMBOL(inet_dgram_connect);
 
@@ -734,10 +739,11 @@ EXPORT_SYMBOL(inet_stream_connect);
 int inet_accept(struct socket *sock, struct socket *newsock, int flags,
                bool kern)
 {
-       struct sock *sk1 = sock->sk;
+       struct sock *sk1 = sock->sk, *sk2;
        int err = -EINVAL;
-       struct sock *sk2 = sk1->sk_prot->accept(sk1, flags, &err, kern);
 
+       /* IPV6_ADDRFORM can change sk->sk_prot under us. */
+       sk2 = READ_ONCE(sk1->sk_prot)->accept(sk1, flags, &err, kern);
        if (!sk2)
                goto do_err;
 
@@ -825,12 +831,15 @@ ssize_t inet_sendpage(struct socket *sock, struct page *page, int offset,
                      size_t size, int flags)
 {
        struct sock *sk = sock->sk;
+       const struct proto *prot;
 
        if (unlikely(inet_send_prepare(sk)))
                return -EAGAIN;
 
-       if (sk->sk_prot->sendpage)
-               return sk->sk_prot->sendpage(sk, page, offset, size, flags);
+       /* IPV6_ADDRFORM can change sk->sk_prot under us. */
+       prot = READ_ONCE(sk->sk_prot);
+       if (prot->sendpage)
+               return prot->sendpage(sk, page, offset, size, flags);
        return sock_no_sendpage(sock, page, offset, size, flags);
 }
 EXPORT_SYMBOL(inet_sendpage);
index 0c51abeee172c3726af47907e8564de77b6f86ce..f8232811a5be17ec7652ff47ffde6341b2a76d1e 100644 (file)
@@ -3796,8 +3796,9 @@ int tcp_setsockopt(struct sock *sk, int level, int optname, sockptr_t optval,
        const struct inet_connection_sock *icsk = inet_csk(sk);
 
        if (level != SOL_TCP)
-               return icsk->icsk_af_ops->setsockopt(sk, level, optname,
-                                                    optval, optlen);
+               /* Paired with WRITE_ONCE() in do_ipv6_setsockopt() and tcp_v6_connect() */
+               return READ_ONCE(icsk->icsk_af_ops)->setsockopt(sk, level, optname,
+                                                               optval, optlen);
        return do_tcp_setsockopt(sk, level, optname, optval, optlen);
 }
 EXPORT_SYMBOL(tcp_setsockopt);
@@ -4396,8 +4397,9 @@ int tcp_getsockopt(struct sock *sk, int level, int optname, char __user *optval,
        struct inet_connection_sock *icsk = inet_csk(sk);
 
        if (level != SOL_TCP)
-               return icsk->icsk_af_ops->getsockopt(sk, level, optname,
-                                                    optval, optlen);
+               /* Paired with WRITE_ONCE() in do_ipv6_setsockopt() and tcp_v6_connect() */
+               return READ_ONCE(icsk->icsk_af_ops)->getsockopt(sk, level, optname,
+                                                               optval, optlen);
        return do_tcp_getsockopt(sk, level, optname, USER_SOCKPTR(optval),
                                 USER_SOCKPTR(optlen));
 }
index d63118ce5900678004aa12d3b72526a001be2e52..8126f67d18b3410fa1af7ea738ff2f8b9d7826e1 100644 (file)
@@ -1598,7 +1598,7 @@ drop:
 }
 EXPORT_SYMBOL_GPL(__udp_enqueue_schedule_skb);
 
-void udp_destruct_sock(struct sock *sk)
+void udp_destruct_common(struct sock *sk)
 {
        /* reclaim completely the forward allocated memory */
        struct udp_sock *up = udp_sk(sk);
@@ -1611,10 +1611,14 @@ void udp_destruct_sock(struct sock *sk)
                kfree_skb(skb);
        }
        udp_rmem_release(sk, total, 0, true);
+}
+EXPORT_SYMBOL_GPL(udp_destruct_common);
 
+static void udp_destruct_sock(struct sock *sk)
+{
+       udp_destruct_common(sk);
        inet_sock_destruct(sk);
 }
-EXPORT_SYMBOL_GPL(udp_destruct_sock);
 
 int udp_init_sock(struct sock *sk)
 {
@@ -1622,7 +1626,6 @@ int udp_init_sock(struct sock *sk)
        sk->sk_destruct = udp_destruct_sock;
        return 0;
 }
-EXPORT_SYMBOL_GPL(udp_init_sock);
 
 void skb_consume_udp(struct sock *sk, struct sk_buff *skb, int len)
 {
index 6e08a76ae1e7e13905fa13ea12e075b94308a8ff..e0c9cc39b81e38df3f83d22a886b2f793c7b732b 100644 (file)
 struct udp_table       udplite_table __read_mostly;
 EXPORT_SYMBOL(udplite_table);
 
+/* Designate sk as UDP-Lite socket */
+static int udplite_sk_init(struct sock *sk)
+{
+       udp_init_sock(sk);
+       udp_sk(sk)->pcflag = UDPLITE_BIT;
+       return 0;
+}
+
 static int udplite_rcv(struct sk_buff *skb)
 {
        return __udp4_lib_rcv(skb, &udplite_table, IPPROTO_UDPLITE);
index d40b7d60e00eed6c84bac5762a887d56c85ba6cc..0241910049825ba6e67ac66e3569bdca4512640d 100644 (file)
@@ -109,6 +109,12 @@ static __inline__ struct ipv6_pinfo *inet6_sk_generic(struct sock *sk)
        return (struct ipv6_pinfo *)(((u8 *)sk) + offset);
 }
 
+void inet6_sock_destruct(struct sock *sk)
+{
+       inet6_cleanup_sock(sk);
+       inet_sock_destruct(sk);
+}
+
 static int inet6_create(struct net *net, struct socket *sock, int protocol,
                        int kern)
 {
@@ -201,7 +207,7 @@ lookup_protocol:
                        inet->hdrincl = 1;
        }
 
-       sk->sk_destruct         = inet_sock_destruct;
+       sk->sk_destruct         = inet6_sock_destruct;
        sk->sk_family           = PF_INET6;
        sk->sk_protocol         = protocol;
 
@@ -510,6 +516,12 @@ void inet6_destroy_sock(struct sock *sk)
 }
 EXPORT_SYMBOL_GPL(inet6_destroy_sock);
 
+void inet6_cleanup_sock(struct sock *sk)
+{
+       inet6_destroy_sock(sk);
+}
+EXPORT_SYMBOL_GPL(inet6_cleanup_sock);
+
 /*
  *     This does both peername and sockname.
  */
index 2d2f4dd9e5dfa8278f5dbad0bfd5a2e16a77406d..532f4478c88402b3967241e7399b5385d688db90 100644 (file)
@@ -419,15 +419,18 @@ int do_ipv6_setsockopt(struct sock *sk, int level, int optname,
                rtnl_lock();
        sockopt_lock_sock(sk);
 
+       /* Another thread has converted the socket into IPv4 with
+        * IPV6_ADDRFORM concurrently.
+        */
+       if (unlikely(sk->sk_family != AF_INET6))
+               goto unlock;
+
        switch (optname) {
 
        case IPV6_ADDRFORM:
                if (optlen < sizeof(int))
                        goto e_inval;
                if (val == PF_INET) {
-                       struct ipv6_txoptions *opt;
-                       struct sk_buff *pktopt;
-
                        if (sk->sk_type == SOCK_RAW)
                                break;
 
@@ -458,7 +461,6 @@ int do_ipv6_setsockopt(struct sock *sk, int level, int optname,
                                break;
                        }
 
-                       fl6_free_socklist(sk);
                        __ipv6_sock_mc_close(sk);
                        __ipv6_sock_ac_close(sk);
 
@@ -475,9 +477,10 @@ int do_ipv6_setsockopt(struct sock *sk, int level, int optname,
                                sock_prot_inuse_add(net, sk->sk_prot, -1);
                                sock_prot_inuse_add(net, &tcp_prot, 1);
 
-                               /* Paired with READ_ONCE(sk->sk_prot) in net/ipv6/af_inet6.c */
+                               /* Paired with READ_ONCE(sk->sk_prot) in inet6_stream_ops */
                                WRITE_ONCE(sk->sk_prot, &tcp_prot);
-                               icsk->icsk_af_ops = &ipv4_specific;
+                               /* Paired with READ_ONCE() in tcp_(get|set)sockopt() */
+                               WRITE_ONCE(icsk->icsk_af_ops, &ipv4_specific);
                                sk->sk_socket->ops = &inet_stream_ops;
                                sk->sk_family = PF_INET;
                                tcp_sync_mss(sk, icsk->icsk_pmtu_cookie);
@@ -490,19 +493,19 @@ int do_ipv6_setsockopt(struct sock *sk, int level, int optname,
                                sock_prot_inuse_add(net, sk->sk_prot, -1);
                                sock_prot_inuse_add(net, prot, 1);
 
-                               /* Paired with READ_ONCE(sk->sk_prot) in net/ipv6/af_inet6.c */
+                               /* Paired with READ_ONCE(sk->sk_prot) in inet6_dgram_ops */
                                WRITE_ONCE(sk->sk_prot, prot);
                                sk->sk_socket->ops = &inet_dgram_ops;
                                sk->sk_family = PF_INET;
                        }
-                       opt = xchg((__force struct ipv6_txoptions **)&np->opt,
-                                  NULL);
-                       if (opt) {
-                               atomic_sub(opt->tot_len, &sk->sk_omem_alloc);
-                               txopt_put(opt);
-                       }
-                       pktopt = xchg(&np->pktoptions, NULL);
-                       kfree_skb(pktopt);
+
+                       /* Disable all options not to allocate memory anymore,
+                        * but there is still a race.  See the lockless path
+                        * in udpv6_sendmsg() and ipv6_local_rxpmtu().
+                        */
+                       np->rxopt.all = 0;
+
+                       inet6_cleanup_sock(sk);
 
                        /*
                         * ... and add it to the refcnt debug socks count
@@ -994,6 +997,7 @@ done:
                break;
        }
 
+unlock:
        sockopt_release_sock(sk);
        if (needs_rtnl)
                rtnl_unlock();
index a8adda623da15f9d396257c53d267c935e995be0..2a3f9296df1e505b40e925c31b0d2aa2a2327cfd 100644 (file)
@@ -238,7 +238,8 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
                sin.sin_port = usin->sin6_port;
                sin.sin_addr.s_addr = usin->sin6_addr.s6_addr32[3];
 
-               icsk->icsk_af_ops = &ipv6_mapped;
+               /* Paired with READ_ONCE() in tcp_(get|set)sockopt() */
+               WRITE_ONCE(icsk->icsk_af_ops, &ipv6_mapped);
                if (sk_is_mptcp(sk))
                        mptcpv6_handle_mapped(sk, true);
                sk->sk_backlog_rcv = tcp_v4_do_rcv;
@@ -250,7 +251,8 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
 
                if (err) {
                        icsk->icsk_ext_hdr_len = exthdrlen;
-                       icsk->icsk_af_ops = &ipv6_specific;
+                       /* Paired with READ_ONCE() in tcp_(get|set)sockopt() */
+                       WRITE_ONCE(icsk->icsk_af_ops, &ipv6_specific);
                        if (sk_is_mptcp(sk))
                                mptcpv6_handle_mapped(sk, false);
                        sk->sk_backlog_rcv = tcp_v6_do_rcv;
index 91e795bb9ade610c138e003ea2ffe326fa39def8..8d09f0ea5b8c70df643a9ddd892624fba9d08f5f 100644 (file)
 #include <trace/events/skb.h>
 #include "udp_impl.h"
 
+static void udpv6_destruct_sock(struct sock *sk)
+{
+       udp_destruct_common(sk);
+       inet6_sock_destruct(sk);
+}
+
+int udpv6_init_sock(struct sock *sk)
+{
+       skb_queue_head_init(&udp_sk(sk)->reader_queue);
+       sk->sk_destruct = udpv6_destruct_sock;
+       return 0;
+}
+
 static u32 udp6_ehashfn(const struct net *net,
                        const struct in6_addr *laddr,
                        const u16 lport,
@@ -1733,7 +1746,7 @@ struct proto udpv6_prot = {
        .connect                = ip6_datagram_connect,
        .disconnect             = udp_disconnect,
        .ioctl                  = udp_ioctl,
-       .init                   = udp_init_sock,
+       .init                   = udpv6_init_sock,
        .destroy                = udpv6_destroy_sock,
        .setsockopt             = udpv6_setsockopt,
        .getsockopt             = udpv6_getsockopt,
index 4251e49d32a0d067a282359e98eab0e8f67218fd..0590f566379d7d07dfdd1b0ae808b9d8964eb5aa 100644 (file)
@@ -12,6 +12,7 @@ int __udp6_lib_rcv(struct sk_buff *, struct udp_table *, int);
 int __udp6_lib_err(struct sk_buff *, struct inet6_skb_parm *, u8, u8, int,
                   __be32, struct udp_table *);
 
+int udpv6_init_sock(struct sock *sk);
 int udp_v6_get_port(struct sock *sk, unsigned short snum);
 void udp_v6_rehash(struct sock *sk);
 
index b707258562597ebddf5e0d75e6415b6f967cca33..67eaf3ca14cea71fad76d88414483c0d1d2321b9 100644 (file)
 #include <linux/proc_fs.h>
 #include "udp_impl.h"
 
+static int udplitev6_sk_init(struct sock *sk)
+{
+       udpv6_init_sock(sk);
+       udp_sk(sk)->pcflag = UDPLITE_BIT;
+       return 0;
+}
+
 static int udplitev6_rcv(struct sk_buff *skb)
 {
        return __udp6_lib_rcv(skb, &udplite_table, IPPROTO_UDPLITE);
@@ -38,7 +45,7 @@ struct proto udplitev6_prot = {
        .connect           = ip6_datagram_connect,
        .disconnect        = udp_disconnect,
        .ioctl             = udp_ioctl,
-       .init              = udplite_sk_init,
+       .init              = udplitev6_sk_init,
        .destroy           = udpv6_destroy_sock,
        .setsockopt        = udpv6_setsockopt,
        .getsockopt        = udpv6_getsockopt,