inet: introduce inet->inet_flags
authorEric Dumazet <edumazet@google.com>
Wed, 16 Aug 2023 08:15:33 +0000 (08:15 +0000)
committerDavid S. Miller <davem@davemloft.net>
Wed, 16 Aug 2023 10:09:16 +0000 (11:09 +0100)
Various inet fields are currently racy.

do_ip_setsockopt() and do_ip_getsockopt() are mostly holding
the socket lock, but some (fast) paths do not.

Use a new inet->inet_flags to hold atomic bits in the series.

Remove inet->cmsg_flags, and use instead 9 bits from inet_flags.

Signed-off-by: Eric Dumazet <edumazet@google.com>
Acked-by: Soheil Hassas Yeganeh <soheil@google.com>
Reviewed-by: Simon Horman <horms@kernel.org>
Signed-off-by: David S. Miller <davem@davemloft.net>
include/net/inet_sock.h
net/ipv4/ip_sockglue.c
net/ipv4/ping.c
net/ipv4/raw.c
net/ipv4/udp.c
net/ipv6/datagram.c
net/ipv6/udp.c
net/l2tp/l2tp_ip.c

index 0bb32bf..e3b35b0 100644 (file)
@@ -194,6 +194,7 @@ struct rtable;
  * @inet_rcv_saddr - Bound local IPv4 addr
  * @inet_dport - Destination port
  * @inet_num - Local port
+ * @inet_flags - various atomic flags
  * @inet_saddr - Sending source
  * @uc_ttl - Unicast TTL
  * @inet_sport - Source port
@@ -218,11 +219,11 @@ struct inet_sock {
 #define inet_dport             sk.__sk_common.skc_dport
 #define inet_num               sk.__sk_common.skc_num
 
+       unsigned long           inet_flags;
        __be32                  inet_saddr;
        __s16                   uc_ttl;
-       __u16                   cmsg_flags;
-       struct ip_options_rcu __rcu     *inet_opt;
        __be16                  inet_sport;
+       struct ip_options_rcu __rcu     *inet_opt;
        __u16                   inet_id;
 
        __u8                    tos;
@@ -259,16 +260,48 @@ struct inet_sock {
 #define IPCORK_OPT     1       /* ip-options has been held in ipcork.opt */
 #define IPCORK_ALLFRAG 2       /* always fragment (for ipv6 for now) */
 
+enum {
+       INET_FLAGS_PKTINFO      = 0,
+       INET_FLAGS_TTL          = 1,
+       INET_FLAGS_TOS          = 2,
+       INET_FLAGS_RECVOPTS     = 3,
+       INET_FLAGS_RETOPTS      = 4,
+       INET_FLAGS_PASSSEC      = 5,
+       INET_FLAGS_ORIGDSTADDR  = 6,
+       INET_FLAGS_CHECKSUM     = 7,
+       INET_FLAGS_RECVFRAGSIZE = 8,
+};
+
 /* cmsg flags for inet */
-#define IP_CMSG_PKTINFO                BIT(0)
-#define IP_CMSG_TTL            BIT(1)
-#define IP_CMSG_TOS            BIT(2)
-#define IP_CMSG_RECVOPTS       BIT(3)
-#define IP_CMSG_RETOPTS                BIT(4)
-#define IP_CMSG_PASSSEC                BIT(5)
-#define IP_CMSG_ORIGDSTADDR    BIT(6)
-#define IP_CMSG_CHECKSUM       BIT(7)
-#define IP_CMSG_RECVFRAGSIZE   BIT(8)
+#define IP_CMSG_PKTINFO                BIT(INET_FLAGS_PKTINFO)
+#define IP_CMSG_TTL            BIT(INET_FLAGS_TTL)
+#define IP_CMSG_TOS            BIT(INET_FLAGS_TOS)
+#define IP_CMSG_RECVOPTS       BIT(INET_FLAGS_RECVOPTS)
+#define IP_CMSG_RETOPTS                BIT(INET_FLAGS_RETOPTS)
+#define IP_CMSG_PASSSEC                BIT(INET_FLAGS_PASSSEC)
+#define IP_CMSG_ORIGDSTADDR    BIT(INET_FLAGS_ORIGDSTADDR)
+#define IP_CMSG_CHECKSUM       BIT(INET_FLAGS_CHECKSUM)
+#define IP_CMSG_RECVFRAGSIZE   BIT(INET_FLAGS_RECVFRAGSIZE)
+
+#define IP_CMSG_ALL    (IP_CMSG_PKTINFO | IP_CMSG_TTL |                \
+                        IP_CMSG_TOS | IP_CMSG_RECVOPTS |               \
+                        IP_CMSG_RETOPTS | IP_CMSG_PASSSEC |            \
+                        IP_CMSG_ORIGDSTADDR | IP_CMSG_CHECKSUM |       \
+                        IP_CMSG_RECVFRAGSIZE)
+
+static inline unsigned long inet_cmsg_flags(const struct inet_sock *inet)
+{
+       return READ_ONCE(inet->inet_flags) & IP_CMSG_ALL;
+}
+
+#define inet_test_bit(nr, sk)                  \
+       test_bit(INET_FLAGS_##nr, &inet_sk(sk)->inet_flags)
+#define inet_set_bit(nr, sk)                   \
+       set_bit(INET_FLAGS_##nr, &inet_sk(sk)->inet_flags)
+#define inet_clear_bit(nr, sk)                 \
+       clear_bit(INET_FLAGS_##nr, &inet_sk(sk)->inet_flags)
+#define inet_assign_bit(nr, sk, val)           \
+       assign_bit(INET_FLAGS_##nr, &inet_sk(sk)->inet_flags, val)
 
 static inline bool sk_is_inet(struct sock *sk)
 {
index d41bce8..66f55f3 100644 (file)
@@ -171,8 +171,10 @@ static void ip_cmsg_recv_dstaddr(struct msghdr *msg, struct sk_buff *skb)
 void ip_cmsg_recv_offset(struct msghdr *msg, struct sock *sk,
                         struct sk_buff *skb, int tlen, int offset)
 {
-       struct inet_sock *inet = inet_sk(sk);
-       unsigned int flags = inet->cmsg_flags;
+       unsigned long flags = inet_cmsg_flags(inet_sk(sk));
+
+       if (!flags)
+               return;
 
        /* Ordered by supposed usage frequency */
        if (flags & IP_CMSG_PKTINFO) {
@@ -568,7 +570,7 @@ int ip_recv_error(struct sock *sk, struct msghdr *msg, int len, int *addr_len)
        if (ipv4_datagram_support_cmsg(sk, skb, serr->ee.ee_origin)) {
                sin->sin_family = AF_INET;
                sin->sin_addr.s_addr = ip_hdr(skb)->saddr;
-               if (inet_sk(sk)->cmsg_flags)
+               if (inet_cmsg_flags(inet_sk(sk)))
                        ip_cmsg_recv(msg, skb);
        }
 
@@ -635,7 +637,7 @@ EXPORT_SYMBOL(ip_sock_set_mtu_discover);
 void ip_sock_set_pktinfo(struct sock *sk)
 {
        lock_sock(sk);
-       inet_sk(sk)->cmsg_flags |= IP_CMSG_PKTINFO;
+       inet_set_bit(PKTINFO, sk);
        release_sock(sk);
 }
 EXPORT_SYMBOL(ip_sock_set_pktinfo);
@@ -990,67 +992,43 @@ int do_ip_setsockopt(struct sock *sk, int level, int optname,
                break;
        }
        case IP_PKTINFO:
-               if (val)
-                       inet->cmsg_flags |= IP_CMSG_PKTINFO;
-               else
-                       inet->cmsg_flags &= ~IP_CMSG_PKTINFO;
+               inet_assign_bit(PKTINFO, sk, val);
                break;
        case IP_RECVTTL:
-               if (val)
-                       inet->cmsg_flags |=  IP_CMSG_TTL;
-               else
-                       inet->cmsg_flags &= ~IP_CMSG_TTL;
+               inet_assign_bit(TTL, sk, val);
                break;
        case IP_RECVTOS:
-               if (val)
-                       inet->cmsg_flags |=  IP_CMSG_TOS;
-               else
-                       inet->cmsg_flags &= ~IP_CMSG_TOS;
+               inet_assign_bit(TOS, sk, val);
                break;
        case IP_RECVOPTS:
-               if (val)
-                       inet->cmsg_flags |=  IP_CMSG_RECVOPTS;
-               else
-                       inet->cmsg_flags &= ~IP_CMSG_RECVOPTS;
+               inet_assign_bit(RECVOPTS, sk, val);
                break;
        case IP_RETOPTS:
-               if (val)
-                       inet->cmsg_flags |= IP_CMSG_RETOPTS;
-               else
-                       inet->cmsg_flags &= ~IP_CMSG_RETOPTS;
+               inet_assign_bit(RETOPTS, sk, val);
                break;
        case IP_PASSSEC:
-               if (val)
-                       inet->cmsg_flags |= IP_CMSG_PASSSEC;
-               else
-                       inet->cmsg_flags &= ~IP_CMSG_PASSSEC;
+               inet_assign_bit(PASSSEC, sk, val);
                break;
        case IP_RECVORIGDSTADDR:
-               if (val)
-                       inet->cmsg_flags |= IP_CMSG_ORIGDSTADDR;
-               else
-                       inet->cmsg_flags &= ~IP_CMSG_ORIGDSTADDR;
+               inet_assign_bit(ORIGDSTADDR, sk, val);
                break;
        case IP_CHECKSUM:
                if (val) {
-                       if (!(inet->cmsg_flags & IP_CMSG_CHECKSUM)) {
+                       if (!(inet_test_bit(CHECKSUM, sk))) {
                                inet_inc_convert_csum(sk);
-                               inet->cmsg_flags |= IP_CMSG_CHECKSUM;
+                               inet_set_bit(CHECKSUM, sk);
                        }
                } else {
-                       if (inet->cmsg_flags & IP_CMSG_CHECKSUM) {
+                       if (inet_test_bit(CHECKSUM, sk)) {
                                inet_dec_convert_csum(sk);
-                               inet->cmsg_flags &= ~IP_CMSG_CHECKSUM;
+                               inet_clear_bit(CHECKSUM, sk);
                        }
                }
                break;
        case IP_RECVFRAGSIZE:
                if (sk->sk_type != SOCK_RAW && sk->sk_type != SOCK_DGRAM)
                        goto e_inval;
-               if (val)
-                       inet->cmsg_flags |= IP_CMSG_RECVFRAGSIZE;
-               else
-                       inet->cmsg_flags &= ~IP_CMSG_RECVFRAGSIZE;
+               inet_assign_bit(RECVFRAGSIZE, sk, val);
                break;
        case IP_TOS:    /* This sets both TOS and Precedence */
                __ip_sock_set_tos(sk, val);
@@ -1415,7 +1393,7 @@ e_inval:
 void ipv4_pktinfo_prepare(const struct sock *sk, struct sk_buff *skb)
 {
        struct in_pktinfo *pktinfo = PKTINFO_SKB_CB(skb);
-       bool prepare = (inet_sk(sk)->cmsg_flags & IP_CMSG_PKTINFO) ||
+       bool prepare = inet_test_bit(PKTINFO, sk) ||
                       ipv6_sk_rxinfo(sk);
 
        if (prepare && skb_rtable(skb)) {
@@ -1601,31 +1579,31 @@ int do_ip_getsockopt(struct sock *sk, int level, int optname,
                return 0;
        }
        case IP_PKTINFO:
-               val = (inet->cmsg_flags & IP_CMSG_PKTINFO) != 0;
+               val = inet_test_bit(PKTINFO, sk);
                break;
        case IP_RECVTTL:
-               val = (inet->cmsg_flags & IP_CMSG_TTL) != 0;
+               val = inet_test_bit(TTL, sk);
                break;
        case IP_RECVTOS:
-               val = (inet->cmsg_flags & IP_CMSG_TOS) != 0;
+               val = inet_test_bit(TOS, sk);
                break;
        case IP_RECVOPTS:
-               val = (inet->cmsg_flags & IP_CMSG_RECVOPTS) != 0;
+               val = inet_test_bit(RECVOPTS, sk);
                break;
        case IP_RETOPTS:
-               val = (inet->cmsg_flags & IP_CMSG_RETOPTS) != 0;
+               val = inet_test_bit(RETOPTS, sk);
                break;
        case IP_PASSSEC:
-               val = (inet->cmsg_flags & IP_CMSG_PASSSEC) != 0;
+               val = inet_test_bit(PASSSEC, sk);
                break;
        case IP_RECVORIGDSTADDR:
-               val = (inet->cmsg_flags & IP_CMSG_ORIGDSTADDR) != 0;
+               val = inet_test_bit(ORIGDSTADDR, sk);
                break;
        case IP_CHECKSUM:
-               val = (inet->cmsg_flags & IP_CMSG_CHECKSUM) != 0;
+               val = inet_test_bit(CHECKSUM, sk);
                break;
        case IP_RECVFRAGSIZE:
-               val = (inet->cmsg_flags & IP_CMSG_RECVFRAGSIZE) != 0;
+               val = inet_test_bit(RECVFRAGSIZE, sk);
                break;
        case IP_TOS:
                val = inet->tos;
@@ -1737,7 +1715,7 @@ int do_ip_getsockopt(struct sock *sk, int level, int optname,
                msg.msg_controllen = len;
                msg.msg_flags = in_compat_syscall() ? MSG_CMSG_COMPAT : 0;
 
-               if (inet->cmsg_flags & IP_CMSG_PKTINFO) {
+               if (inet_test_bit(PKTINFO, sk)) {
                        struct in_pktinfo info;
 
                        info.ipi_addr.s_addr = inet->inet_rcv_saddr;
@@ -1745,11 +1723,11 @@ int do_ip_getsockopt(struct sock *sk, int level, int optname,
                        info.ipi_ifindex = inet->mc_index;
                        put_cmsg(&msg, SOL_IP, IP_PKTINFO, sizeof(info), &info);
                }
-               if (inet->cmsg_flags & IP_CMSG_TTL) {
+               if (inet_test_bit(TTL, sk)) {
                        int hlim = inet->mc_ttl;
                        put_cmsg(&msg, SOL_IP, IP_TTL, sizeof(hlim), &hlim);
                }
-               if (inet->cmsg_flags & IP_CMSG_TOS) {
+               if (inet_test_bit(TOS, sk)) {
                        int tos = inet->rcv_tos;
                        put_cmsg(&msg, SOL_IP, IP_TOS, sizeof(tos), &tos);
                }
index 25dd78c..7e8702c 100644 (file)
@@ -894,7 +894,7 @@ int ping_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int flags,
                        *addr_len = sizeof(*sin);
                }
 
-               if (isk->cmsg_flags)
+               if (inet_cmsg_flags(isk))
                        ip_cmsg_recv(msg, skb);
 
 #if IS_ENABLED(CONFIG_IPV6)
@@ -921,7 +921,8 @@ int ping_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int flags,
                if (skb->protocol == htons(ETH_P_IPV6) &&
                    inet6_sk(sk)->rxopt.all)
                        pingv6_ops.ip6_datagram_recv_specific_ctl(sk, msg, skb);
-               else if (skb->protocol == htons(ETH_P_IP) && isk->cmsg_flags)
+               else if (skb->protocol == htons(ETH_P_IP) &&
+                        inet_cmsg_flags(isk))
                        ip_cmsg_recv(msg, skb);
 #endif
        } else {
index cb381f5..e6e813f 100644 (file)
@@ -767,7 +767,7 @@ static int raw_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
                memset(&sin->sin_zero, 0, sizeof(sin->sin_zero));
                *addr_len = sizeof(*sin);
        }
-       if (inet->cmsg_flags)
+       if (inet_cmsg_flags(inet))
                ip_cmsg_recv(msg, skb);
        if (flags & MSG_TRUNC)
                copied = skb->len;
index 3e2f29c..4b79113 100644 (file)
@@ -1870,7 +1870,7 @@ try_again:
        if (udp_sk(sk)->gro_enabled)
                udp_cmsg_recv(msg, sk, skb);
 
-       if (inet->cmsg_flags)
+       if (inet_cmsg_flags(inet))
                ip_cmsg_recv_offset(msg, sk, skb, sizeof(struct udphdr), off);
 
        err = copied;
index d80d602..41ebc4e 100644 (file)
@@ -524,7 +524,7 @@ int ipv6_recv_error(struct sock *sk, struct msghdr *msg, int len, int *addr_len)
                } else {
                        ipv6_addr_set_v4mapped(ip_hdr(skb)->saddr,
                                               &sin->sin6_addr);
-                       if (inet_sk(sk)->cmsg_flags)
+                       if (inet_cmsg_flags(inet_sk(sk)))
                                ip_cmsg_recv(msg, skb);
                }
        }
index 1ea01b0..ebc6ae4 100644 (file)
@@ -420,7 +420,7 @@ try_again:
                ip6_datagram_recv_common_ctl(sk, msg, skb);
 
        if (is_udp4) {
-               if (inet->cmsg_flags)
+               if (inet_cmsg_flags(inet))
                        ip_cmsg_recv_offset(msg, sk, skb,
                                            sizeof(struct udphdr), off);
        } else {
index f9073bc..9a2a9ed 100644 (file)
@@ -552,7 +552,7 @@ static int l2tp_ip_recvmsg(struct sock *sk, struct msghdr *msg,
                memset(&sin->sin_zero, 0, sizeof(sin->sin_zero));
                *addr_len = sizeof(*sin);
        }
-       if (inet->cmsg_flags)
+       if (inet_cmsg_flags(inet))
                ip_cmsg_recv(msg, skb);
        if (flags & MSG_TRUNC)
                copied = skb->len;