ip, ip6: Fix splice to raw and ping sockets
[platform/kernel/linux-starfive.git] / net / ipv6 / ip6_output.c
index c314fdd..1e8c90e 100644 (file)
@@ -42,6 +42,7 @@
 #include <net/sock.h>
 #include <net/snmp.h>
 
+#include <net/gso.h>
 #include <net/ipv6.h>
 #include <net/ndisc.h>
 #include <net/protocol.h>
@@ -116,7 +117,7 @@ static int ip6_finish_output2(struct net *net, struct sock *sk, struct sk_buff *
                        return res;
        }
 
-       rcu_read_lock_bh();
+       rcu_read_lock();
        nexthop = rt6_nexthop((struct rt6_info *)dst, daddr);
        neigh = __ipv6_neigh_lookup_noref(dev, nexthop);
 
@@ -124,7 +125,7 @@ static int ip6_finish_output2(struct net *net, struct sock *sk, struct sk_buff *
                if (unlikely(!neigh))
                        neigh = __neigh_create(&nd_tbl, nexthop, dev, false);
                if (IS_ERR(neigh)) {
-                       rcu_read_unlock_bh();
+                       rcu_read_unlock();
                        IP6_INC_STATS(net, idev, IPSTATS_MIB_OUTNOROUTES);
                        kfree_skb_reason(skb, SKB_DROP_REASON_NEIGH_CREATEFAIL);
                        return -EINVAL;
@@ -132,7 +133,7 @@ static int ip6_finish_output2(struct net *net, struct sock *sk, struct sk_buff *
        }
        sock_confirm_neigh(skb, neigh);
        ret = neigh_output(neigh, skb, false);
-       rcu_read_unlock_bh();
+       rcu_read_unlock();
        return ret;
 }
 
@@ -1150,11 +1151,11 @@ static int ip6_dst_lookup_tail(struct net *net, const struct sock *sk,
         * dst entry of the nexthop router
         */
        rt = (struct rt6_info *) *dst;
-       rcu_read_lock_bh();
+       rcu_read_lock();
        n = __ipv6_neigh_lookup_noref(rt->dst.dev,
                                      rt6_nexthop(rt, &fl6->daddr));
-       err = n && !(n->nud_state & NUD_VALID) ? -EINVAL : 0;
-       rcu_read_unlock_bh();
+       err = n && !(READ_ONCE(n->nud_state) & NUD_VALID) ? -EINVAL : 0;
+       rcu_read_unlock();
 
        if (err) {
                struct inet6_ifaddr *ifp;
@@ -1500,7 +1501,7 @@ static int __ip6_append_data(struct sock *sk,
        mtu = cork->gso_size ? IP6_MAX_MTU : cork->fragsize;
        orig_mtu = mtu;
 
-       if (cork->tx_flags & SKBTX_ANY_SW_TSTAMP &&
+       if (cork->tx_flags & SKBTX_ANY_TSTAMP &&
            sk->sk_tsflags & SOF_TIMESTAMPING_OPT_ID)
                tskey = atomic_inc_return(&sk->sk_tskey) - 1;
 
@@ -1589,6 +1590,15 @@ emsgsize:
                                skb_zcopy_set(skb, uarg, &extra_uref);
                        }
                }
+       } else if ((flags & MSG_SPLICE_PAGES) && length) {
+               if (inet_sk(sk)->hdrincl)
+                       return -EPERM;
+               if (rt->dst.dev->features & NETIF_F_SG &&
+                   getfrag == ip_generic_getfrag)
+                       /* We need an empty buffer to attach stuff to */
+                       paged = true;
+               else
+                       flags &= ~MSG_SPLICE_PAGES;
        }
 
        /*
@@ -1778,6 +1788,15 @@ alloc_new_skb:
                                err = -EFAULT;
                                goto error;
                        }
+               } else if (flags & MSG_SPLICE_PAGES) {
+                       struct msghdr *msg = from;
+
+                       err = skb_splice_from_iter(skb, &msg->msg_iter, copy,
+                                                  sk->sk_allocation);
+                       if (err < 0)
+                               goto error;
+                       copy = err;
+                       wmem_alloc_delta += copy;
                } else if (!zc) {
                        int i = skb_shinfo(skb)->nr_frags;
 
@@ -1965,8 +1984,13 @@ struct sk_buff *__ip6_make_skb(struct sock *sk,
        IP6_UPD_PO_STATS(net, rt->rt6i_idev, IPSTATS_MIB_OUT, skb->len);
        if (proto == IPPROTO_ICMPV6) {
                struct inet6_dev *idev = ip6_dst_idev(skb_dst(skb));
+               u8 icmp6_type;
 
-               ICMP6MSGOUT_INC_STATS(net, idev, icmp6_hdr(skb)->icmp6_type);
+               if (sk->sk_socket->type == SOCK_RAW && !inet_sk(sk)->hdrincl)
+                       icmp6_type = fl6->fl6_icmp_type;
+               else
+                       icmp6_type = icmp6_hdr(skb)->icmp6_type;
+               ICMP6MSGOUT_INC_STATS(net, idev, icmp6_type);
                ICMP6_INC_STATS(net, idev, ICMP6_MIB_OUTMSGS);
        }