inet: add READ_ONCE(sk->sk_bound_dev_if) in INET_MATCH()
authorEric Dumazet <edumazet@google.com>
Thu, 12 May 2022 16:56:01 +0000 (09:56 -0700)
committerGreg Kroah-Hartman <gregkh@linuxfoundation.org>
Wed, 17 Aug 2022 12:23:35 +0000 (14:23 +0200)
[ Upstream commit 4915d50e300e96929d2462041d6f6c6f061167fd ]

INET_MATCH() runs without holding a lock on the socket.

We probably need to annotate most reads.

This patch makes INET_MATCH() an inline function
to ease our changes.

v2:

We remove the 32bit version of it, as modern compilers
should generate the same code really, no need to
try to be smarter.

Also make 'struct net *net' the first argument.

Signed-off-by: Eric Dumazet <edumazet@google.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
Signed-off-by: Sasha Levin <sashal@kernel.org>
include/net/inet_hashtables.h
include/net/sock.h
net/ipv4/inet_hashtables.c
net/ipv4/udp.c

index 749bb1e..825ad1d 100644 (file)
@@ -295,7 +295,6 @@ static inline struct sock *inet_lookup_listener(struct net *net,
        ((__force __portpair)(((__u32)(__dport) << 16) | (__force __u32)(__be16)(__sport)))
 #endif
 
-#if (BITS_PER_LONG == 64)
 #ifdef __BIG_ENDIAN
 #define INET_ADDR_COOKIE(__name, __saddr, __daddr) \
        const __addrpair __name = (__force __addrpair) ( \
@@ -307,24 +306,22 @@ static inline struct sock *inet_lookup_listener(struct net *net,
                                   (((__force __u64)(__be32)(__daddr)) << 32) | \
                                   ((__force __u64)(__be32)(__saddr)))
 #endif /* __BIG_ENDIAN */
-#define INET_MATCH(__sk, __net, __cookie, __saddr, __daddr, __ports, __dif, __sdif) \
-       (((__sk)->sk_portpair == (__ports))                     &&      \
-        ((__sk)->sk_addrpair == (__cookie))                    &&      \
-        (((__sk)->sk_bound_dev_if == (__dif))                  ||      \
-         ((__sk)->sk_bound_dev_if == (__sdif)))                &&      \
-        net_eq(sock_net(__sk), (__net)))
-#else /* 32-bit arch */
-#define INET_ADDR_COOKIE(__name, __saddr, __daddr) \
-       const int __name __deprecated __attribute__((unused))
-
-#define INET_MATCH(__sk, __net, __cookie, __saddr, __daddr, __ports, __dif, __sdif) \
-       (((__sk)->sk_portpair == (__ports))             &&              \
-        ((__sk)->sk_daddr      == (__saddr))           &&              \
-        ((__sk)->sk_rcv_saddr  == (__daddr))           &&              \
-        (((__sk)->sk_bound_dev_if == (__dif))          ||              \
-         ((__sk)->sk_bound_dev_if == (__sdif)))        &&              \
-        net_eq(sock_net(__sk), (__net)))
-#endif /* 64-bit arch */
+
+static inline bool INET_MATCH(struct net *net, const struct sock *sk,
+                             const __addrpair cookie, const __portpair ports,
+                             int dif, int sdif)
+{
+       int bound_dev_if;
+
+       if (!net_eq(sock_net(sk), net) ||
+           sk->sk_portpair != ports ||
+           sk->sk_addrpair != cookie)
+               return false;
+
+       /* Paired with WRITE_ONCE() from sock_bindtoindex_locked() */
+       bound_dev_if = READ_ONCE(sk->sk_bound_dev_if);
+       return bound_dev_if == dif || bound_dev_if == sdif;
+}
 
 /* Sockets in TCP_CLOSE state are _always_ taken out of the hash, so we need
  * not check it for lookups anymore, thanks Alexey. -DaveM
index e0a88bb..49a6315 100644 (file)
@@ -161,9 +161,6 @@ typedef __u64 __bitwise __addrpair;
  *     for struct sock and struct inet_timewait_sock.
  */
 struct sock_common {
-       /* skc_daddr and skc_rcv_saddr must be grouped on a 8 bytes aligned
-        * address on 64bit arches : cf INET_MATCH()
-        */
        union {
                __addrpair      skc_addrpair;
                struct {
index 342f3df..7c502c4 100644 (file)
@@ -410,13 +410,11 @@ begin:
        sk_nulls_for_each_rcu(sk, node, &head->chain) {
                if (sk->sk_hash != hash)
                        continue;
-               if (likely(INET_MATCH(sk, net, acookie,
-                                     saddr, daddr, ports, dif, sdif))) {
+               if (likely(INET_MATCH(net, sk, acookie, ports, dif, sdif))) {
                        if (unlikely(!refcount_inc_not_zero(&sk->sk_refcnt)))
                                goto out;
-                       if (unlikely(!INET_MATCH(sk, net, acookie,
-                                                saddr, daddr, ports,
-                                                dif, sdif))) {
+                       if (unlikely(!INET_MATCH(net, sk, acookie,
+                                                ports, dif, sdif))) {
                                sock_gen_put(sk);
                                goto begin;
                        }
@@ -465,8 +463,7 @@ static int __inet_check_established(struct inet_timewait_death_row *death_row,
                if (sk2->sk_hash != hash)
                        continue;
 
-               if (likely(INET_MATCH(sk2, net, acookie,
-                                        saddr, daddr, ports, dif, sdif))) {
+               if (likely(INET_MATCH(net, sk2, acookie, ports, dif, sdif))) {
                        if (sk2->sk_state == TCP_TIME_WAIT) {
                                tw = inet_twsk(sk2);
                                if (twsk_unique(sk, sk2, twp))
@@ -532,9 +529,7 @@ static bool inet_ehash_lookup_by_sk(struct sock *sk,
                if (esk->sk_hash != sk->sk_hash)
                        continue;
                if (sk->sk_family == AF_INET) {
-                       if (unlikely(INET_MATCH(esk, net, acookie,
-                                               sk->sk_daddr,
-                                               sk->sk_rcv_saddr,
+                       if (unlikely(INET_MATCH(net, esk, acookie,
                                                ports, dif, sdif))) {
                                return true;
                        }
index 4ad4daa..efef7ba 100644 (file)
@@ -2554,8 +2554,7 @@ static struct sock *__udp4_lib_demux_lookup(struct net *net,
        struct sock *sk;
 
        udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) {
-               if (INET_MATCH(sk, net, acookie, rmt_addr,
-                              loc_addr, ports, dif, sdif))
+               if (INET_MATCH(net, sk, acookie, ports, dif, sdif))
                        return sk;
                /* Only check first socket in chain */
                break;