inet: move inet->recverr to inet->inet_flags
[platform/kernel/linux-starfive.git] / net / ipv4 / ip_sockglue.c
index 8e97d8d..8283d86 100644 (file)
@@ -171,8 +171,10 @@ static void ip_cmsg_recv_dstaddr(struct msghdr *msg, struct sk_buff *skb)
 void ip_cmsg_recv_offset(struct msghdr *msg, struct sock *sk,
                         struct sk_buff *skb, int tlen, int offset)
 {
-       struct inet_sock *inet = inet_sk(sk);
-       unsigned int flags = inet->cmsg_flags;
+       unsigned long flags = inet_cmsg_flags(inet_sk(sk));
+
+       if (!flags)
+               return;
 
        /* Ordered by supposed usage frequency */
        if (flags & IP_CMSG_PKTINFO) {
@@ -444,12 +446,11 @@ EXPORT_SYMBOL_GPL(ip_icmp_error);
 
 void ip_local_error(struct sock *sk, int err, __be32 daddr, __be16 port, u32 info)
 {
-       struct inet_sock *inet = inet_sk(sk);
        struct sock_exterr_skb *serr;
        struct iphdr *iph;
        struct sk_buff *skb;
 
-       if (!inet->recverr)
+       if (!inet_test_bit(RECVERR, sk))
                return;
 
        skb = alloc_skb(sizeof(struct iphdr), GFP_ATOMIC);
@@ -568,7 +569,7 @@ int ip_recv_error(struct sock *sk, struct msghdr *msg, int len, int *addr_len)
        if (ipv4_datagram_support_cmsg(sk, skb, serr->ee.ee_origin)) {
                sin->sin_family = AF_INET;
                sin->sin_addr.s_addr = ip_hdr(skb)->saddr;
-               if (inet_sk(sk)->cmsg_flags)
+               if (inet_cmsg_flags(inet_sk(sk)))
                        ip_cmsg_recv(msg, skb);
        }
 
@@ -592,7 +593,7 @@ void __ip_sock_set_tos(struct sock *sk, int val)
        }
        if (inet_sk(sk)->tos != val) {
                inet_sk(sk)->tos = val;
-               sk->sk_priority = rt_tos2priority(val);
+               WRITE_ONCE(sk->sk_priority, rt_tos2priority(val));
                sk_dst_reset(sk);
        }
 }
@@ -615,9 +616,7 @@ EXPORT_SYMBOL(ip_sock_set_freebind);
 
 void ip_sock_set_recverr(struct sock *sk)
 {
-       lock_sock(sk);
-       inet_sk(sk)->recverr = true;
-       release_sock(sk);
+       inet_set_bit(RECVERR, sk);
 }
 EXPORT_SYMBOL(ip_sock_set_recverr);
 
@@ -634,9 +633,7 @@ EXPORT_SYMBOL(ip_sock_set_mtu_discover);
 
 void ip_sock_set_pktinfo(struct sock *sk)
 {
-       lock_sock(sk);
-       inet_sk(sk)->cmsg_flags |= IP_CMSG_PKTINFO;
-       release_sock(sk);
+       inet_set_bit(PKTINFO, sk);
 }
 EXPORT_SYMBOL(ip_sock_set_pktinfo);
 
@@ -950,6 +947,41 @@ int do_ip_setsockopt(struct sock *sk, int level, int optname,
        if (ip_mroute_opt(optname))
                return ip_mroute_setsockopt(sk, optname, optval, optlen);
 
+       /* Handle options that can be set without locking the socket. */
+       switch (optname) {
+       case IP_PKTINFO:
+               inet_assign_bit(PKTINFO, sk, val);
+               return 0;
+       case IP_RECVTTL:
+               inet_assign_bit(TTL, sk, val);
+               return 0;
+       case IP_RECVTOS:
+               inet_assign_bit(TOS, sk, val);
+               return 0;
+       case IP_RECVOPTS:
+               inet_assign_bit(RECVOPTS, sk, val);
+               return 0;
+       case IP_RETOPTS:
+               inet_assign_bit(RETOPTS, sk, val);
+               return 0;
+       case IP_PASSSEC:
+               inet_assign_bit(PASSSEC, sk, val);
+               return 0;
+       case IP_RECVORIGDSTADDR:
+               inet_assign_bit(ORIGDSTADDR, sk, val);
+               return 0;
+       case IP_RECVFRAGSIZE:
+               if (sk->sk_type != SOCK_RAW && sk->sk_type != SOCK_DGRAM)
+                       return -EINVAL;
+               inet_assign_bit(RECVFRAGSIZE, sk, val);
+               return 0;
+       case IP_RECVERR:
+               inet_assign_bit(RECVERR, sk, val);
+               if (!val)
+                       skb_queue_purge(&sk->sk_error_queue);
+               return 0;
+       }
+
        err = 0;
        if (needs_rtnl)
                rtnl_lock();
@@ -989,69 +1021,19 @@ int do_ip_setsockopt(struct sock *sk, int level, int optname,
                        kfree_rcu(old, rcu);
                break;
        }
-       case IP_PKTINFO:
-               if (val)
-                       inet->cmsg_flags |= IP_CMSG_PKTINFO;
-               else
-                       inet->cmsg_flags &= ~IP_CMSG_PKTINFO;
-               break;
-       case IP_RECVTTL:
-               if (val)
-                       inet->cmsg_flags |=  IP_CMSG_TTL;
-               else
-                       inet->cmsg_flags &= ~IP_CMSG_TTL;
-               break;
-       case IP_RECVTOS:
-               if (val)
-                       inet->cmsg_flags |=  IP_CMSG_TOS;
-               else
-                       inet->cmsg_flags &= ~IP_CMSG_TOS;
-               break;
-       case IP_RECVOPTS:
-               if (val)
-                       inet->cmsg_flags |=  IP_CMSG_RECVOPTS;
-               else
-                       inet->cmsg_flags &= ~IP_CMSG_RECVOPTS;
-               break;
-       case IP_RETOPTS:
-               if (val)
-                       inet->cmsg_flags |= IP_CMSG_RETOPTS;
-               else
-                       inet->cmsg_flags &= ~IP_CMSG_RETOPTS;
-               break;
-       case IP_PASSSEC:
-               if (val)
-                       inet->cmsg_flags |= IP_CMSG_PASSSEC;
-               else
-                       inet->cmsg_flags &= ~IP_CMSG_PASSSEC;
-               break;
-       case IP_RECVORIGDSTADDR:
-               if (val)
-                       inet->cmsg_flags |= IP_CMSG_ORIGDSTADDR;
-               else
-                       inet->cmsg_flags &= ~IP_CMSG_ORIGDSTADDR;
-               break;
        case IP_CHECKSUM:
                if (val) {
-                       if (!(inet->cmsg_flags & IP_CMSG_CHECKSUM)) {
+                       if (!(inet_test_bit(CHECKSUM, sk))) {
                                inet_inc_convert_csum(sk);
-                               inet->cmsg_flags |= IP_CMSG_CHECKSUM;
+                               inet_set_bit(CHECKSUM, sk);
                        }
                } else {
-                       if (inet->cmsg_flags & IP_CMSG_CHECKSUM) {
+                       if (inet_test_bit(CHECKSUM, sk)) {
                                inet_dec_convert_csum(sk);
-                               inet->cmsg_flags &= ~IP_CMSG_CHECKSUM;
+                               inet_clear_bit(CHECKSUM, sk);
                        }
                }
                break;
-       case IP_RECVFRAGSIZE:
-               if (sk->sk_type != SOCK_RAW && sk->sk_type != SOCK_DGRAM)
-                       goto e_inval;
-               if (val)
-                       inet->cmsg_flags |= IP_CMSG_RECVFRAGSIZE;
-               else
-                       inet->cmsg_flags &= ~IP_CMSG_RECVFRAGSIZE;
-               break;
        case IP_TOS:    /* This sets both TOS and Precedence */
                __ip_sock_set_tos(sk, val);
                break;
@@ -1084,11 +1066,6 @@ int do_ip_setsockopt(struct sock *sk, int level, int optname,
                        goto e_inval;
                inet->pmtudisc = val;
                break;
-       case IP_RECVERR:
-               inet->recverr = !!val;
-               if (!val)
-                       skb_queue_purge(&sk->sk_error_queue);
-               break;
        case IP_RECVERR_RFC4884:
                if (val < 0 || val > 1)
                        goto e_inval;
@@ -1415,7 +1392,7 @@ e_inval:
 void ipv4_pktinfo_prepare(const struct sock *sk, struct sk_buff *skb)
 {
        struct in_pktinfo *pktinfo = PKTINFO_SKB_CB(skb);
-       bool prepare = (inet_sk(sk)->cmsg_flags & IP_CMSG_PKTINFO) ||
+       bool prepare = inet_test_bit(PKTINFO, sk) ||
                       ipv6_sk_rxinfo(sk);
 
        if (prepare && skb_rtable(skb)) {
@@ -1566,6 +1543,40 @@ int do_ip_getsockopt(struct sock *sk, int level, int optname,
        if (len < 0)
                return -EINVAL;
 
+       /* Handle options that can be read without locking the socket. */
+       switch (optname) {
+       case IP_PKTINFO:
+               val = inet_test_bit(PKTINFO, sk);
+               goto copyval;
+       case IP_RECVTTL:
+               val = inet_test_bit(TTL, sk);
+               goto copyval;
+       case IP_RECVTOS:
+               val = inet_test_bit(TOS, sk);
+               goto copyval;
+       case IP_RECVOPTS:
+               val = inet_test_bit(RECVOPTS, sk);
+               goto copyval;
+       case IP_RETOPTS:
+               val = inet_test_bit(RETOPTS, sk);
+               goto copyval;
+       case IP_PASSSEC:
+               val = inet_test_bit(PASSSEC, sk);
+               goto copyval;
+       case IP_RECVORIGDSTADDR:
+               val = inet_test_bit(ORIGDSTADDR, sk);
+               goto copyval;
+       case IP_CHECKSUM:
+               val = inet_test_bit(CHECKSUM, sk);
+               goto copyval;
+       case IP_RECVFRAGSIZE:
+               val = inet_test_bit(RECVFRAGSIZE, sk);
+               goto copyval;
+       case IP_RECVERR:
+               val = inet_test_bit(RECVERR, sk);
+               goto copyval;
+       }
+
        if (needs_rtnl)
                rtnl_lock();
        sockopt_lock_sock(sk);
@@ -1600,33 +1611,6 @@ int do_ip_getsockopt(struct sock *sk, int level, int optname,
                        return -EFAULT;
                return 0;
        }
-       case IP_PKTINFO:
-               val = (inet->cmsg_flags & IP_CMSG_PKTINFO) != 0;
-               break;
-       case IP_RECVTTL:
-               val = (inet->cmsg_flags & IP_CMSG_TTL) != 0;
-               break;
-       case IP_RECVTOS:
-               val = (inet->cmsg_flags & IP_CMSG_TOS) != 0;
-               break;
-       case IP_RECVOPTS:
-               val = (inet->cmsg_flags & IP_CMSG_RECVOPTS) != 0;
-               break;
-       case IP_RETOPTS:
-               val = (inet->cmsg_flags & IP_CMSG_RETOPTS) != 0;
-               break;
-       case IP_PASSSEC:
-               val = (inet->cmsg_flags & IP_CMSG_PASSSEC) != 0;
-               break;
-       case IP_RECVORIGDSTADDR:
-               val = (inet->cmsg_flags & IP_CMSG_ORIGDSTADDR) != 0;
-               break;
-       case IP_CHECKSUM:
-               val = (inet->cmsg_flags & IP_CMSG_CHECKSUM) != 0;
-               break;
-       case IP_RECVFRAGSIZE:
-               val = (inet->cmsg_flags & IP_CMSG_RECVFRAGSIZE) != 0;
-               break;
        case IP_TOS:
                val = inet->tos;
                break;
@@ -1665,9 +1649,6 @@ int do_ip_getsockopt(struct sock *sk, int level, int optname,
                }
                break;
        }
-       case IP_RECVERR:
-               val = inet->recverr;
-               break;
        case IP_RECVERR_RFC4884:
                val = inet->recverr_rfc4884;
                break;
@@ -1737,7 +1718,7 @@ int do_ip_getsockopt(struct sock *sk, int level, int optname,
                msg.msg_controllen = len;
                msg.msg_flags = in_compat_syscall() ? MSG_CMSG_COMPAT : 0;
 
-               if (inet->cmsg_flags & IP_CMSG_PKTINFO) {
+               if (inet_test_bit(PKTINFO, sk)) {
                        struct in_pktinfo info;
 
                        info.ipi_addr.s_addr = inet->inet_rcv_saddr;
@@ -1745,11 +1726,11 @@ int do_ip_getsockopt(struct sock *sk, int level, int optname,
                        info.ipi_ifindex = inet->mc_index;
                        put_cmsg(&msg, SOL_IP, IP_PKTINFO, sizeof(info), &info);
                }
-               if (inet->cmsg_flags & IP_CMSG_TTL) {
+               if (inet_test_bit(TTL, sk)) {
                        int hlim = inet->mc_ttl;
                        put_cmsg(&msg, SOL_IP, IP_TTL, sizeof(hlim), &hlim);
                }
-               if (inet->cmsg_flags & IP_CMSG_TOS) {
+               if (inet_test_bit(TOS, sk)) {
                        int tos = inet->rcv_tos;
                        put_cmsg(&msg, SOL_IP, IP_TOS, sizeof(tos), &tos);
                }
@@ -1776,7 +1757,7 @@ int do_ip_getsockopt(struct sock *sk, int level, int optname,
                return -ENOPROTOOPT;
        }
        sockopt_release_sock(sk);
-
+copyval:
        if (len < sizeof(int) && len > 0 && val >= 0 && val <= 255) {
                unsigned char ucval = (unsigned char)val;
                len = 1;