af_packet: Don't send zero-byte data in packet_sendmsg_spkt().
[platform/kernel/linux-rpi.git] / net / packet / af_packet.c
index 497193f..640d94e 100644 (file)
@@ -2033,7 +2033,7 @@ retry:
                goto retry;
        }
 
-       if (!dev_validate_header(dev, skb->data, len)) {
+       if (!dev_validate_header(dev, skb->data, len) || !skb->len) {
                err = -EINVAL;
                goto out_unlock;
        }
@@ -2090,18 +2090,18 @@ static unsigned int run_filter(struct sk_buff *skb,
 }
 
 static int packet_rcv_vnet(struct msghdr *msg, const struct sk_buff *skb,
-                          size_t *len)
+                          size_t *len, int vnet_hdr_sz)
 {
-       struct virtio_net_hdr vnet_hdr;
+       struct virtio_net_hdr_mrg_rxbuf vnet_hdr = { .num_buffers = 0 };
 
-       if (*len < sizeof(vnet_hdr))
+       if (*len < vnet_hdr_sz)
                return -EINVAL;
-       *len -= sizeof(vnet_hdr);
+       *len -= vnet_hdr_sz;
 
-       if (virtio_net_hdr_from_skb(skb, &vnet_hdr, vio_le(), true, 0))
+       if (virtio_net_hdr_from_skb(skb, (struct virtio_net_hdr *)&vnet_hdr, vio_le(), true, 0))
                return -EINVAL;
 
-       return memcpy_to_msg(msg, (void *)&vnet_hdr, sizeof(vnet_hdr));
+       return memcpy_to_msg(msg, (void *)&vnet_hdr, vnet_hdr_sz);
 }
 
 /*
@@ -2250,7 +2250,7 @@ static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev,
        __u32 ts_status;
        bool is_drop_n_account = false;
        unsigned int slot_id = 0;
-       bool do_vnet = false;
+       int vnet_hdr_sz = 0;
 
        /* struct tpacket{2,3}_hdr is aligned to a multiple of TPACKET_ALIGNMENT.
         * We may add members to them until current aligned size without forcing
@@ -2308,10 +2308,9 @@ static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev,
                netoff = TPACKET_ALIGN(po->tp_hdrlen +
                                       (maclen < 16 ? 16 : maclen)) +
                                       po->tp_reserve;
-               if (packet_sock_flag(po, PACKET_SOCK_HAS_VNET_HDR)) {
-                       netoff += sizeof(struct virtio_net_hdr);
-                       do_vnet = true;
-               }
+               vnet_hdr_sz = READ_ONCE(po->vnet_hdr_sz);
+               if (vnet_hdr_sz)
+                       netoff += vnet_hdr_sz;
                macoff = netoff - maclen;
        }
        if (netoff > USHRT_MAX) {
@@ -2337,7 +2336,7 @@ static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev,
                        snaplen = po->rx_ring.frame_size - macoff;
                        if ((int)snaplen < 0) {
                                snaplen = 0;
-                               do_vnet = false;
+                               vnet_hdr_sz = 0;
                        }
                }
        } else if (unlikely(macoff + snaplen >
@@ -2351,7 +2350,7 @@ static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev,
                if (unlikely((int)snaplen < 0)) {
                        snaplen = 0;
                        macoff = GET_PBDQC_FROM_RB(&po->rx_ring)->max_frame_len;
-                       do_vnet = false;
+                       vnet_hdr_sz = 0;
                }
        }
        spin_lock(&sk->sk_receive_queue.lock);
@@ -2367,7 +2366,7 @@ static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev,
                __set_bit(slot_id, po->rx_ring.rx_owner_map);
        }
 
-       if (do_vnet &&
+       if (vnet_hdr_sz &&
            virtio_net_hdr_from_skb(skb, h.raw + macoff -
                                    sizeof(struct virtio_net_hdr),
                                    vio_le(), true, 0)) {
@@ -2551,16 +2550,26 @@ static int __packet_snd_vnet_parse(struct virtio_net_hdr *vnet_hdr, size_t len)
 }
 
 static int packet_snd_vnet_parse(struct msghdr *msg, size_t *len,
-                                struct virtio_net_hdr *vnet_hdr)
+                                struct virtio_net_hdr *vnet_hdr, int vnet_hdr_sz)
 {
-       if (*len < sizeof(*vnet_hdr))
+       int ret;
+
+       if (*len < vnet_hdr_sz)
                return -EINVAL;
-       *len -= sizeof(*vnet_hdr);
+       *len -= vnet_hdr_sz;
 
        if (!copy_from_iter_full(vnet_hdr, sizeof(*vnet_hdr), &msg->msg_iter))
                return -EFAULT;
 
-       return __packet_snd_vnet_parse(vnet_hdr, *len);
+       ret = __packet_snd_vnet_parse(vnet_hdr, *len);
+       if (ret)
+               return ret;
+
+       /* move iter to point to the start of mac header */
+       if (vnet_hdr_sz != sizeof(struct virtio_net_hdr))
+               iov_iter_advance(&msg->msg_iter, vnet_hdr_sz - sizeof(struct virtio_net_hdr));
+
+       return 0;
 }
 
 static int tpacket_fill_skb(struct packet_sock *po, struct sk_buff *skb,
@@ -2622,8 +2631,8 @@ static int tpacket_fill_skb(struct packet_sock *po, struct sk_buff *skb,
                nr_frags = skb_shinfo(skb)->nr_frags;
 
                if (unlikely(nr_frags >= MAX_SKB_FRAGS)) {
-                       pr_err("Packet exceed the number of skb frags(%lu)\n",
-                              MAX_SKB_FRAGS);
+                       pr_err("Packet exceed the number of skb frags(%u)\n",
+                              (unsigned int)MAX_SKB_FRAGS);
                        return -EFAULT;
                }
 
@@ -2722,6 +2731,7 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
        void *ph;
        DECLARE_SOCKADDR(struct sockaddr_ll *, saddr, msg->msg_name);
        bool need_wait = !(msg->msg_flags & MSG_DONTWAIT);
+       int vnet_hdr_sz = READ_ONCE(po->vnet_hdr_sz);
        unsigned char *addr = NULL;
        int tp_len, size_max;
        void *data;
@@ -2779,8 +2789,7 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
        size_max = po->tx_ring.frame_size
                - (po->tp_hdrlen - sizeof(struct sockaddr_ll));
 
-       if ((size_max > dev->mtu + reserve + VLAN_HLEN) &&
-           !packet_sock_flag(po, PACKET_SOCK_HAS_VNET_HDR))
+       if ((size_max > dev->mtu + reserve + VLAN_HLEN) && !vnet_hdr_sz)
                size_max = dev->mtu + reserve + VLAN_HLEN;
 
        reinit_completion(&po->skb_completion);
@@ -2809,10 +2818,10 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
                status = TP_STATUS_SEND_REQUEST;
                hlen = LL_RESERVED_SPACE(dev);
                tlen = dev->needed_tailroom;
-               if (packet_sock_flag(po, PACKET_SOCK_HAS_VNET_HDR)) {
+               if (vnet_hdr_sz) {
                        vnet_hdr = data;
-                       data += sizeof(*vnet_hdr);
-                       tp_len -= sizeof(*vnet_hdr);
+                       data += vnet_hdr_sz;
+                       tp_len -= vnet_hdr_sz;
                        if (tp_len < 0 ||
                            __packet_snd_vnet_parse(vnet_hdr, tp_len)) {
                                tp_len = -EINVAL;
@@ -2837,7 +2846,7 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
                                          addr, hlen, copylen, &sockc);
                if (likely(tp_len >= 0) &&
                    tp_len > dev->mtu + reserve &&
-                   !packet_sock_flag(po, PACKET_SOCK_HAS_VNET_HDR) &&
+                   !vnet_hdr_sz &&
                    !packet_extra_vlan_len_allowed(dev, skb))
                        tp_len = -EMSGSIZE;
 
@@ -2856,7 +2865,7 @@ tpacket_error:
                        }
                }
 
-               if (packet_sock_flag(po, PACKET_SOCK_HAS_VNET_HDR)) {
+               if (vnet_hdr_sz) {
                        if (virtio_net_hdr_to_skb(skb, vnet_hdr, vio_le())) {
                                tp_len = -EINVAL;
                                goto tpacket_error;
@@ -2946,7 +2955,7 @@ static int packet_snd(struct socket *sock, struct msghdr *msg, size_t len)
        struct virtio_net_hdr vnet_hdr = { 0 };
        int offset = 0;
        struct packet_sock *po = pkt_sk(sk);
-       bool has_vnet_hdr = false;
+       int vnet_hdr_sz = READ_ONCE(po->vnet_hdr_sz);
        int hlen, tlen, linear;
        int extra_len = 0;
 
@@ -2990,11 +2999,10 @@ static int packet_snd(struct socket *sock, struct msghdr *msg, size_t len)
 
        if (sock->type == SOCK_RAW)
                reserve = dev->hard_header_len;
-       if (packet_sock_flag(po, PACKET_SOCK_HAS_VNET_HDR)) {
-               err = packet_snd_vnet_parse(msg, &len, &vnet_hdr);
+       if (vnet_hdr_sz) {
+               err = packet_snd_vnet_parse(msg, &len, &vnet_hdr, vnet_hdr_sz);
                if (err)
                        goto out_unlock;
-               has_vnet_hdr = true;
        }
 
        if (unlikely(sock_flag(sk, SOCK_NOFCS))) {
@@ -3064,11 +3072,11 @@ static int packet_snd(struct socket *sock, struct msghdr *msg, size_t len)
 
        packet_parse_headers(skb, sock);
 
-       if (has_vnet_hdr) {
+       if (vnet_hdr_sz) {
                err = virtio_net_hdr_to_skb(skb, &vnet_hdr, vio_le());
                if (err)
                        goto out_free;
-               len += sizeof(vnet_hdr);
+               len += vnet_hdr_sz;
                virtio_net_hdr_set_proto(skb, &vnet_hdr);
        }
 
@@ -3408,7 +3416,7 @@ static int packet_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
        struct sock *sk = sock->sk;
        struct sk_buff *skb;
        int copied, err;
-       int vnet_hdr_len = 0;
+       int vnet_hdr_len = READ_ONCE(pkt_sk(sk)->vnet_hdr_sz);
        unsigned int origlen = 0;
 
        err = -EINVAL;
@@ -3449,11 +3457,10 @@ static int packet_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
 
        packet_rcv_try_clear_pressure(pkt_sk(sk));
 
-       if (packet_sock_flag(pkt_sk(sk), PACKET_SOCK_HAS_VNET_HDR)) {
-               err = packet_rcv_vnet(msg, skb, &len);
+       if (vnet_hdr_len) {
+               err = packet_rcv_vnet(msg, skb, &len, vnet_hdr_len);
                if (err)
                        goto out_free;
-               vnet_hdr_len = sizeof(struct virtio_net_hdr);
        }
 
        /* You lose any data beyond the buffer you gave. If it worries
@@ -3915,8 +3922,9 @@ packet_setsockopt(struct socket *sock, int level, int optname, sockptr_t optval,
                return 0;
        }
        case PACKET_VNET_HDR:
+       case PACKET_VNET_HDR_SZ:
        {
-               int val;
+               int val, hdr_len;
 
                if (sock->type != SOCK_RAW)
                        return -EINVAL;
@@ -3925,11 +3933,19 @@ packet_setsockopt(struct socket *sock, int level, int optname, sockptr_t optval,
                if (copy_from_sockptr(&val, optval, sizeof(val)))
                        return -EFAULT;
 
+               if (optname == PACKET_VNET_HDR_SZ) {
+                       if (val && val != sizeof(struct virtio_net_hdr) &&
+                           val != sizeof(struct virtio_net_hdr_mrg_rxbuf))
+                               return -EINVAL;
+                       hdr_len = val;
+               } else {
+                       hdr_len = val ? sizeof(struct virtio_net_hdr) : 0;
+               }
                lock_sock(sk);
                if (po->rx_ring.pg_vec || po->tx_ring.pg_vec) {
                        ret = -EBUSY;
                } else {
-                       packet_sock_flag_set(po, PACKET_SOCK_HAS_VNET_HDR, val);
+                       WRITE_ONCE(po->vnet_hdr_sz, hdr_len);
                        ret = 0;
                }
                release_sock(sk);
@@ -4062,7 +4078,10 @@ static int packet_getsockopt(struct socket *sock, int level, int optname,
                val = packet_sock_flag(po, PACKET_SOCK_ORIGDEV);
                break;
        case PACKET_VNET_HDR:
-               val = packet_sock_flag(po, PACKET_SOCK_HAS_VNET_HDR);
+               val = !!READ_ONCE(po->vnet_hdr_sz);
+               break;
+       case PACKET_VNET_HDR_SZ:
+               val = READ_ONCE(po->vnet_hdr_sz);
                break;
        case PACKET_VERSION:
                val = po->tp_version;