net/tcp: Merge TCP-MD5 inbound callbacks
authorDmitry Safonov <dima@arista.com>
Wed, 23 Feb 2022 17:57:40 +0000 (17:57 +0000)
committerJakub Kicinski <kuba@kernel.org>
Fri, 25 Feb 2022 05:43:53 +0000 (21:43 -0800)
The functions do essentially the same work to verify TCP-MD5 sign.
Code can be merged into one family-independent function in order to
reduce copy'n'paste and generated code.
Later with TCP-AO option added, this will allow to create one function
that's responsible for segment verification, that will have all the
different checks for MD5/AO/non-signed packets, which in turn will help
to see checks for all corner-cases in one function, rather than spread
around different families and functions.

Cc: Eric Dumazet <edumazet@google.com>
Cc: Hideaki YOSHIFUJI <yoshfuji@linux-ipv6.org>
Signed-off-by: Dmitry Safonov <dima@arista.com>
Reviewed-by: David Ahern <dsahern@kernel.org>
Link: https://lore.kernel.org/r/20220223175740.452397-1-dima@arista.com
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
include/net/tcp.h
net/ipv4/tcp.c
net/ipv4/tcp_ipv4.c
net/ipv6/tcp_ipv6.c

index 04f4650..479a277 100644 (file)
@@ -1674,6 +1674,11 @@ tcp_md5_do_lookup(const struct sock *sk, int l3index,
                return NULL;
        return __tcp_md5_do_lookup(sk, l3index, addr, family);
 }
+bool tcp_inbound_md5_hash(const struct sock *sk, const struct sk_buff *skb,
+                         enum skb_drop_reason *reason,
+                         const void *saddr, const void *daddr,
+                         int family, int dif, int sdif);
+
 
 #define tcp_twsk_md5_key(twsk) ((twsk)->tw_md5_key)
 #else
@@ -1683,6 +1688,14 @@ tcp_md5_do_lookup(const struct sock *sk, int l3index,
 {
        return NULL;
 }
+static inline bool tcp_inbound_md5_hash(const struct sock *sk,
+                                       const struct sk_buff *skb,
+                                       enum skb_drop_reason *reason,
+                                       const void *saddr, const void *daddr,
+                                       int family, int dif, int sdif)
+{
+       return false;
+}
 #define tcp_twsk_md5_key(twsk) NULL
 #endif
 
index 760e822..68f1236 100644 (file)
@@ -4431,6 +4431,76 @@ int tcp_md5_hash_key(struct tcp_md5sig_pool *hp, const struct tcp_md5sig_key *ke
 }
 EXPORT_SYMBOL(tcp_md5_hash_key);
 
+/* Called with rcu_read_lock() */
+bool tcp_inbound_md5_hash(const struct sock *sk, const struct sk_buff *skb,
+                         enum skb_drop_reason *reason,
+                         const void *saddr, const void *daddr,
+                         int family, int dif, int sdif)
+{
+       /*
+        * This gets called for each TCP segment that arrives
+        * so we want to be efficient.
+        * We have 3 drop cases:
+        * o No MD5 hash and one expected.
+        * o MD5 hash and we're not expecting one.
+        * o MD5 hash and its wrong.
+        */
+       const __u8 *hash_location = NULL;
+       struct tcp_md5sig_key *hash_expected;
+       const struct tcphdr *th = tcp_hdr(skb);
+       struct tcp_sock *tp = tcp_sk(sk);
+       int genhash, l3index;
+       u8 newhash[16];
+
+       /* sdif set, means packet ingressed via a device
+        * in an L3 domain and dif is set to the l3mdev
+        */
+       l3index = sdif ? dif : 0;
+
+       hash_expected = tcp_md5_do_lookup(sk, l3index, saddr, family);
+       hash_location = tcp_parse_md5sig_option(th);
+
+       /* We've parsed the options - do we have a hash? */
+       if (!hash_expected && !hash_location)
+               return false;
+
+       if (hash_expected && !hash_location) {
+               *reason = SKB_DROP_REASON_TCP_MD5NOTFOUND;
+               NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5NOTFOUND);
+               return true;
+       }
+
+       if (!hash_expected && hash_location) {
+               *reason = SKB_DROP_REASON_TCP_MD5UNEXPECTED;
+               NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5UNEXPECTED);
+               return true;
+       }
+
+       /* check the signature */
+       genhash = tp->af_specific->calc_md5_hash(newhash, hash_expected,
+                                                NULL, skb);
+
+       if (genhash || memcmp(hash_location, newhash, 16) != 0) {
+               *reason = SKB_DROP_REASON_TCP_MD5FAILURE;
+               NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5FAILURE);
+               if (family == AF_INET) {
+                       net_info_ratelimited("MD5 Hash failed for (%pI4, %d)->(%pI4, %d)%s L3 index %d\n",
+                                       saddr, ntohs(th->source),
+                                       daddr, ntohs(th->dest),
+                                       genhash ? " tcp_v4_calc_md5_hash failed"
+                                       : "", l3index);
+               } else {
+                       net_info_ratelimited("MD5 Hash %s for [%pI6c]:%u->[%pI6c]:%u L3 index %d\n",
+                                       genhash ? "failed" : "mismatch",
+                                       saddr, ntohs(th->source),
+                                       daddr, ntohs(th->dest), l3index);
+               }
+               return true;
+       }
+       return false;
+}
+EXPORT_SYMBOL(tcp_inbound_md5_hash);
+
 #endif
 
 void tcp_done(struct sock *sk)
index d42824a..411357a 100644 (file)
@@ -1409,76 +1409,6 @@ EXPORT_SYMBOL(tcp_v4_md5_hash_skb);
 
 #endif
 
-/* Called with rcu_read_lock() */
-static bool tcp_v4_inbound_md5_hash(const struct sock *sk,
-                                   const struct sk_buff *skb,
-                                   int dif, int sdif,
-                                   enum skb_drop_reason *reason)
-{
-#ifdef CONFIG_TCP_MD5SIG
-       /*
-        * This gets called for each TCP segment that arrives
-        * so we want to be efficient.
-        * We have 3 drop cases:
-        * o No MD5 hash and one expected.
-        * o MD5 hash and we're not expecting one.
-        * o MD5 hash and its wrong.
-        */
-       const __u8 *hash_location = NULL;
-       struct tcp_md5sig_key *hash_expected;
-       const struct iphdr *iph = ip_hdr(skb);
-       const struct tcphdr *th = tcp_hdr(skb);
-       const union tcp_md5_addr *addr;
-       unsigned char newhash[16];
-       int genhash, l3index;
-
-       /* sdif set, means packet ingressed via a device
-        * in an L3 domain and dif is set to the l3mdev
-        */
-       l3index = sdif ? dif : 0;
-
-       addr = (union tcp_md5_addr *)&iph->saddr;
-       hash_expected = tcp_md5_do_lookup(sk, l3index, addr, AF_INET);
-       hash_location = tcp_parse_md5sig_option(th);
-
-       /* We've parsed the options - do we have a hash? */
-       if (!hash_expected && !hash_location)
-               return false;
-
-       if (hash_expected && !hash_location) {
-               *reason = SKB_DROP_REASON_TCP_MD5NOTFOUND;
-               NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5NOTFOUND);
-               return true;
-       }
-
-       if (!hash_expected && hash_location) {
-               *reason = SKB_DROP_REASON_TCP_MD5UNEXPECTED;
-               NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5UNEXPECTED);
-               return true;
-       }
-
-       /* Okay, so this is hash_expected and hash_location -
-        * so we need to calculate the checksum.
-        */
-       genhash = tcp_v4_md5_hash_skb(newhash,
-                                     hash_expected,
-                                     NULL, skb);
-
-       if (genhash || memcmp(hash_location, newhash, 16) != 0) {
-               *reason = SKB_DROP_REASON_TCP_MD5FAILURE;
-               NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5FAILURE);
-               net_info_ratelimited("MD5 Hash failed for (%pI4, %d)->(%pI4, %d)%s L3 index %d\n",
-                                    &iph->saddr, ntohs(th->source),
-                                    &iph->daddr, ntohs(th->dest),
-                                    genhash ? " tcp_v4_calc_md5_hash failed"
-                                    : "", l3index);
-               return true;
-       }
-       return false;
-#endif
-       return false;
-}
-
 static void tcp_v4_init_req(struct request_sock *req,
                            const struct sock *sk_listener,
                            struct sk_buff *skb)
@@ -2035,8 +1965,9 @@ process:
                struct sock *nsk;
 
                sk = req->rsk_listener;
-               if (unlikely(tcp_v4_inbound_md5_hash(sk, skb, dif, sdif,
-                                                    &drop_reason))) {
+               if (unlikely(tcp_inbound_md5_hash(sk, skb, &drop_reason,
+                                                 &iph->saddr, &iph->daddr,
+                                                 AF_INET, dif, sdif))) {
                        sk_drops_add(sk, skb);
                        reqsk_put(req);
                        goto discard_it;
@@ -2110,7 +2041,8 @@ process:
                goto discard_and_relse;
        }
 
-       if (tcp_v4_inbound_md5_hash(sk, skb, dif, sdif, &drop_reason))
+       if (tcp_inbound_md5_hash(sk, skb, &drop_reason, &iph->saddr,
+                                &iph->daddr, AF_INET, dif, sdif))
                goto discard_and_relse;
 
        nf_reset_ct(skb);
index 749de85..e98af86 100644 (file)
@@ -773,61 +773,6 @@ clear_hash_noput:
 
 #endif
 
-static bool tcp_v6_inbound_md5_hash(const struct sock *sk,
-                                   const struct sk_buff *skb,
-                                   int dif, int sdif,
-                                   enum skb_drop_reason *reason)
-{
-#ifdef CONFIG_TCP_MD5SIG
-       const __u8 *hash_location = NULL;
-       struct tcp_md5sig_key *hash_expected;
-       const struct ipv6hdr *ip6h = ipv6_hdr(skb);
-       const struct tcphdr *th = tcp_hdr(skb);
-       int genhash, l3index;
-       u8 newhash[16];
-
-       /* sdif set, means packet ingressed via a device
-        * in an L3 domain and dif is set to the l3mdev
-        */
-       l3index = sdif ? dif : 0;
-
-       hash_expected = tcp_v6_md5_do_lookup(sk, &ip6h->saddr, l3index);
-       hash_location = tcp_parse_md5sig_option(th);
-
-       /* We've parsed the options - do we have a hash? */
-       if (!hash_expected && !hash_location)
-               return false;
-
-       if (hash_expected && !hash_location) {
-               *reason = SKB_DROP_REASON_TCP_MD5NOTFOUND;
-               NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5NOTFOUND);
-               return true;
-       }
-
-       if (!hash_expected && hash_location) {
-               *reason = SKB_DROP_REASON_TCP_MD5UNEXPECTED;
-               NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5UNEXPECTED);
-               return true;
-       }
-
-       /* check the signature */
-       genhash = tcp_v6_md5_hash_skb(newhash,
-                                     hash_expected,
-                                     NULL, skb);
-
-       if (genhash || memcmp(hash_location, newhash, 16) != 0) {
-               *reason = SKB_DROP_REASON_TCP_MD5FAILURE;
-               NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5FAILURE);
-               net_info_ratelimited("MD5 Hash %s for [%pI6c]:%u->[%pI6c]:%u L3 index %d\n",
-                                    genhash ? "failed" : "mismatch",
-                                    &ip6h->saddr, ntohs(th->source),
-                                    &ip6h->daddr, ntohs(th->dest), l3index);
-               return true;
-       }
-#endif
-       return false;
-}
-
 static void tcp_v6_init_req(struct request_sock *req,
                            const struct sock *sk_listener,
                            struct sk_buff *skb)
@@ -1687,8 +1632,8 @@ process:
                struct sock *nsk;
 
                sk = req->rsk_listener;
-               if (tcp_v6_inbound_md5_hash(sk, skb, dif, sdif,
-                                           &drop_reason)) {
+               if (tcp_inbound_md5_hash(sk, skb, &drop_reason, &hdr->saddr,
+                                        &hdr->daddr, AF_INET6, dif, sdif)) {
                        sk_drops_add(sk, skb);
                        reqsk_put(req);
                        goto discard_it;
@@ -1759,7 +1704,8 @@ process:
                goto discard_and_relse;
        }
 
-       if (tcp_v6_inbound_md5_hash(sk, skb, dif, sdif, &drop_reason))
+       if (tcp_inbound_md5_hash(sk, skb, &drop_reason, &hdr->saddr,
+                                &hdr->daddr, AF_INET6, dif, sdif))
                goto discard_and_relse;
 
        if (tcp_filter(sk, skb)) {