mctp: Use output netdev to allocate skb headroom
authorMatt Johnston <matt@codeconstruct.com.au>
Fri, 1 Apr 2022 02:48:44 +0000 (10:48 +0800)
committerDavid S. Miller <davem@davemloft.net>
Fri, 1 Apr 2022 11:04:15 +0000 (12:04 +0100)
Previously the skb was allocated with headroom MCTP_HEADER_MAXLEN,
but that isn't sufficient if we are using devs that are not MCTP
specific.

This also adds a check that the smctp_halen provided to sendmsg for
extended addressing is the correct size for the netdev.

Fixes: 833ef3b91de6 ("mctp: Populate socket implementation")
Reported-by: Matthew Rinaldi <mjrinal@g.clemson.edu>
Signed-off-by: Matt Johnston <matt@codeconstruct.com.au>
Signed-off-by: David S. Miller <davem@davemloft.net>
include/net/mctp.h
net/mctp/af_mctp.c
net/mctp/route.c

index d37268f..82800d5 100644 (file)
@@ -36,8 +36,6 @@ struct mctp_hdr {
 #define MCTP_HDR_TAG_SHIFT     0
 #define MCTP_HDR_TAG_MASK      GENMASK(2, 0)
 
-#define MCTP_HEADER_MAXLEN     4
-
 #define MCTP_INITIAL_DEFAULT_NET       1
 
 static inline bool mctp_address_unicast(mctp_eid_t eid)
index f0702d9..e22b0cb 100644 (file)
@@ -93,13 +93,13 @@ out_release:
 static int mctp_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
 {
        DECLARE_SOCKADDR(struct sockaddr_mctp *, addr, msg->msg_name);
-       const int hlen = MCTP_HEADER_MAXLEN + sizeof(struct mctp_hdr);
        int rc, addrlen = msg->msg_namelen;
        struct sock *sk = sock->sk;
        struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
        struct mctp_skb_cb *cb;
        struct mctp_route *rt;
-       struct sk_buff *skb;
+       struct sk_buff *skb = NULL;
+       int hlen;
 
        if (addr) {
                const u8 tagbits = MCTP_TAG_MASK | MCTP_TAG_OWNER |
@@ -129,6 +129,34 @@ static int mctp_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
        if (addr->smctp_network == MCTP_NET_ANY)
                addr->smctp_network = mctp_default_net(sock_net(sk));
 
+       /* direct addressing */
+       if (msk->addr_ext && addrlen >= sizeof(struct sockaddr_mctp_ext)) {
+               DECLARE_SOCKADDR(struct sockaddr_mctp_ext *,
+                                extaddr, msg->msg_name);
+               struct net_device *dev;
+
+               rc = -EINVAL;
+               rcu_read_lock();
+               dev = dev_get_by_index_rcu(sock_net(sk), extaddr->smctp_ifindex);
+               /* check for correct halen */
+               if (dev && extaddr->smctp_halen == dev->addr_len) {
+                       hlen = LL_RESERVED_SPACE(dev) + sizeof(struct mctp_hdr);
+                       rc = 0;
+               }
+               rcu_read_unlock();
+               if (rc)
+                       goto err_free;
+               rt = NULL;
+       } else {
+               rt = mctp_route_lookup(sock_net(sk), addr->smctp_network,
+                                      addr->smctp_addr.s_addr);
+               if (!rt) {
+                       rc = -EHOSTUNREACH;
+                       goto err_free;
+               }
+               hlen = LL_RESERVED_SPACE(rt->dev->dev) + sizeof(struct mctp_hdr);
+       }
+
        skb = sock_alloc_send_skb(sk, hlen + 1 + len,
                                  msg->msg_flags & MSG_DONTWAIT, &rc);
        if (!skb)
@@ -147,8 +175,8 @@ static int mctp_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
        cb = __mctp_cb(skb);
        cb->net = addr->smctp_network;
 
-       /* direct addressing */
-       if (msk->addr_ext && addrlen >= sizeof(struct sockaddr_mctp_ext)) {
+       if (!rt) {
+               /* fill extended address in cb */
                DECLARE_SOCKADDR(struct sockaddr_mctp_ext *,
                                 extaddr, msg->msg_name);
 
@@ -159,17 +187,9 @@ static int mctp_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
                }
 
                cb->ifindex = extaddr->smctp_ifindex;
+               /* smctp_halen is checked above */
                cb->halen = extaddr->smctp_halen;
                memcpy(cb->haddr, extaddr->smctp_haddr, cb->halen);
-
-               rt = NULL;
-       } else {
-               rt = mctp_route_lookup(sock_net(sk), addr->smctp_network,
-                                      addr->smctp_addr.s_addr);
-               if (!rt) {
-                       rc = -EHOSTUNREACH;
-                       goto err_free;
-               }
        }
 
        rc = mctp_local_output(sk, rt, skb, addr->smctp_addr.s_addr,
index ee548c4..3b24b8d 100644 (file)
@@ -503,6 +503,11 @@ static int mctp_route_output(struct mctp_route *route, struct sk_buff *skb)
 
        if (cb->ifindex) {
                /* direct route; use the hwaddr we stashed in sendmsg */
+               if (cb->halen != skb->dev->addr_len) {
+                       /* sanity check, sendmsg should have already caught this */
+                       kfree_skb(skb);
+                       return -EMSGSIZE;
+               }
                daddr = cb->haddr;
        } else {
                /* If lookup fails let the device handle daddr==NULL */
@@ -756,7 +761,7 @@ static int mctp_do_fragment_route(struct mctp_route *rt, struct sk_buff *skb,
 {
        const unsigned int hlen = sizeof(struct mctp_hdr);
        struct mctp_hdr *hdr, *hdr2;
-       unsigned int pos, size;
+       unsigned int pos, size, headroom;
        struct sk_buff *skb2;
        int rc;
        u8 seq;
@@ -770,6 +775,9 @@ static int mctp_do_fragment_route(struct mctp_route *rt, struct sk_buff *skb,
                return -EMSGSIZE;
        }
 
+       /* keep same headroom as the original skb */
+       headroom = skb_headroom(skb);
+
        /* we've got the header */
        skb_pull(skb, hlen);
 
@@ -777,7 +785,7 @@ static int mctp_do_fragment_route(struct mctp_route *rt, struct sk_buff *skb,
                /* size of message payload */
                size = min(mtu - hlen, skb->len - pos);
 
-               skb2 = alloc_skb(MCTP_HEADER_MAXLEN + hlen + size, GFP_KERNEL);
+               skb2 = alloc_skb(headroom + hlen + size, GFP_KERNEL);
                if (!skb2) {
                        rc = -ENOMEM;
                        break;
@@ -793,7 +801,7 @@ static int mctp_do_fragment_route(struct mctp_route *rt, struct sk_buff *skb,
                        skb_set_owner_w(skb2, skb->sk);
 
                /* establish packet */
-               skb_reserve(skb2, MCTP_HEADER_MAXLEN);
+               skb_reserve(skb2, headroom);
                skb_reset_network_header(skb2);
                skb_put(skb2, hlen + size);
                skb2->transport_header = skb2->network_header + hlen;