inet: read sk->sk_family once in inet_recv_error()
[platform/kernel/linux-starfive.git] / net / socket.c
index c8b08b3..8d83c4b 100644 (file)
@@ -737,6 +737,14 @@ static inline int sock_sendmsg_nosec(struct socket *sock, struct msghdr *msg)
        return ret;
 }
 
+static int __sock_sendmsg(struct socket *sock, struct msghdr *msg)
+{
+       int err = security_socket_sendmsg(sock, msg,
+                                         msg_data_left(msg));
+
+       return err ?: sock_sendmsg_nosec(sock, msg);
+}
+
 /**
  *     sock_sendmsg - send a message through @sock
  *     @sock: socket
@@ -747,10 +755,21 @@ static inline int sock_sendmsg_nosec(struct socket *sock, struct msghdr *msg)
  */
 int sock_sendmsg(struct socket *sock, struct msghdr *msg)
 {
-       int err = security_socket_sendmsg(sock, msg,
-                                         msg_data_left(msg));
+       struct sockaddr_storage *save_addr = (struct sockaddr_storage *)msg->msg_name;
+       struct sockaddr_storage address;
+       int save_len = msg->msg_namelen;
+       int ret;
 
-       return err ?: sock_sendmsg_nosec(sock, msg);
+       if (msg->msg_name) {
+               memcpy(&address, msg->msg_name, msg->msg_namelen);
+               msg->msg_name = &address;
+       }
+
+       ret = __sock_sendmsg(sock, msg);
+       msg->msg_name = save_addr;
+       msg->msg_namelen = save_len;
+
+       return ret;
 }
 EXPORT_SYMBOL(sock_sendmsg);
 
@@ -1138,7 +1157,7 @@ static ssize_t sock_write_iter(struct kiocb *iocb, struct iov_iter *from)
        if (sock->type == SOCK_SEQPACKET)
                msg.msg_flags |= MSG_EOR;
 
-       res = sock_sendmsg(sock, &msg);
+       res = __sock_sendmsg(sock, &msg);
        *from = msg.msg_iter;
        return res;
 }
@@ -2174,7 +2193,7 @@ int __sys_sendto(int fd, void __user *buff, size_t len, unsigned int flags,
        if (sock->file->f_flags & O_NONBLOCK)
                flags |= MSG_DONTWAIT;
        msg.msg_flags = flags;
-       err = sock_sendmsg(sock, &msg);
+       err = __sock_sendmsg(sock, &msg);
 
 out_put:
        fput_light(sock->file, fput_needed);
@@ -2538,7 +2557,7 @@ static int ____sys_sendmsg(struct socket *sock, struct msghdr *msg_sys,
                err = sock_sendmsg_nosec(sock, msg_sys);
                goto out_freectl;
        }
-       err = sock_sendmsg(sock, msg_sys);
+       err = __sock_sendmsg(sock, msg_sys);
        /*
         * If this is sendmmsg() and sending to current destination address was
         * successful, remember it.
@@ -3499,7 +3518,12 @@ static long compat_sock_ioctl(struct file *file, unsigned int cmd,
 
 int kernel_bind(struct socket *sock, struct sockaddr *addr, int addrlen)
 {
-       return READ_ONCE(sock->ops)->bind(sock, addr, addrlen);
+       struct sockaddr_storage address;
+
+       memcpy(&address, addr, addrlen);
+
+       return READ_ONCE(sock->ops)->bind(sock, (struct sockaddr *)&address,
+                                         addrlen);
 }
 EXPORT_SYMBOL(kernel_bind);