scm: add SO_PASSPIDFD and SCM_PIDFD
author Alexander Mikhalitsyn <aleksandr.mikhalitsyn@canonical.com>
Thu, 8 Jun 2023 20:26:25 +0000 (22:26 +0200)
committer David S. Miller <davem@davemloft.net>
Mon, 12 Jun 2023 09:45:49 +0000 (10:45 +0100)
Implement SCM_PIDFD, a new CMSG type analogous to SCM_CREDENTIALS, but
carrying a pidfd instead of a plain pid, which frees programmers from
having to care about the PID reuse problem.

We mask the SO_PASSPIDFD feature if CONFIG_UNIX is not built in, because
it depends on the pidfd_prepare() API, which is not exported to kernel
modules.

The idea comes from the UAPI kernel group:
https://uapi-group.org/kernel-features/

Big thanks to Christian Brauner and Lennart Poettering for productive
discussions about this.

Cc: "David S. Miller" <davem@davemloft.net>
Cc: Eric Dumazet <edumazet@google.com>
Cc: Jakub Kicinski <kuba@kernel.org>
Cc: Paolo Abeni <pabeni@redhat.com>
Cc: Leon Romanovsky <leon@kernel.org>
Cc: David Ahern <dsahern@kernel.org>
Cc: Arnd Bergmann <arnd@arndb.de>
Cc: Kees Cook <keescook@chromium.org>
Cc: Christian Brauner <brauner@kernel.org>
Cc: Kuniyuki Iwashima <kuniyu@amazon.com>
Cc: Lennart Poettering <mzxreary@0pointer.de>
Cc: Luca Boccassi <bluca@debian.org>
Cc: linux-kernel@vger.kernel.org
Cc: netdev@vger.kernel.org
Cc: linux-arch@vger.kernel.org
Tested-by: Luca Boccassi <bluca@debian.org>
Reviewed-by: Kuniyuki Iwashima <kuniyu@amazon.com>
Reviewed-by: Christian Brauner <brauner@kernel.org>
Signed-off-by: Alexander Mikhalitsyn <aleksandr.mikhalitsyn@canonical.com>
Reviewed-by: Eric Dumazet <edumazet@google.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
12 files changed:
arch/alpha/include/uapi/asm/socket.h
arch/mips/include/uapi/asm/socket.h
arch/parisc/include/uapi/asm/socket.h
arch/sparc/include/uapi/asm/socket.h
include/linux/net.h
include/linux/socket.h
include/net/scm.h
include/uapi/asm-generic/socket.h
net/core/sock.c
net/mptcp/sockopt.c
net/unix/af_unix.c
tools/include/uapi/asm-generic/socket.h

index 739891b..ff31061 100644 (file)
 
 #define SO_RCVMARK             75
 
+#define SO_PASSPIDFD           76
+
 #if !defined(__KERNEL__)
 
 #if __BITS_PER_LONG == 64
index 18f3d95..762dcb8 100644 (file)
 
 #define SO_RCVMARK             75
 
+#define SO_PASSPIDFD           76
+
 #if !defined(__KERNEL__)
 
 #if __BITS_PER_LONG == 64
index f486d3d..df16a3e 100644 (file)
 
 #define SO_RCVMARK             0x4049
 
+#define SO_PASSPIDFD           0x404A
+
 #if !defined(__KERNEL__)
 
 #if __BITS_PER_LONG == 64
index 2fda57a..6e28478 100644 (file)
 
 #define SO_RCVMARK               0x0054
 
+#define SO_PASSPIDFD             0x0055
+
 #if !defined(__KERNEL__)
 
 
index 8defc8f..23324e9 100644 (file)
@@ -43,6 +43,7 @@ struct net;
 #define SOCK_PASSSEC           4
 #define SOCK_SUPPORT_ZC                5
 #define SOCK_CUSTOM_SOCKOPT    6
+#define SOCK_PASSPIDFD         7
 
 #ifndef ARCH_HAS_SOCKET_TYPES
 /**
index 3fd3436..5820470 100644 (file)
@@ -177,6 +177,7 @@ static inline size_t msg_data_left(struct msghdr *msg)
 #define        SCM_RIGHTS      0x01            /* rw: access rights (array of int) */
 #define SCM_CREDENTIALS 0x02           /* rw: struct ucred             */
 #define SCM_SECURITY   0x03            /* rw: security label           */
+#define SCM_PIDFD      0x04            /* ro: pidfd (int)              */
 
 struct ucred {
        __u32   pid;
index 585adc1..c67f765 100644 (file)
@@ -120,12 +120,44 @@ static inline bool scm_has_secdata(struct socket *sock)
 }
 #endif /* CONFIG_SECURITY_NETWORK */
 
+static __inline__ void scm_pidfd_recv(struct msghdr *msg, struct scm_cookie *scm)
+{
+       struct file *pidfd_file = NULL;
+       int pidfd;
+
+       /*
+        * put_cmsg() doesn't return an error if CMSG is truncated,
+        * that's why we need to opencode these checks here.
+        */
+       if ((msg->msg_controllen <= sizeof(struct cmsghdr)) ||
+           (msg->msg_controllen - sizeof(struct cmsghdr)) < sizeof(int)) {
+               msg->msg_flags |= MSG_CTRUNC;
+               return;
+       }
+
+       WARN_ON_ONCE(!scm->pid);
+       pidfd = pidfd_prepare(scm->pid, 0, &pidfd_file);
+
+       if (put_cmsg(msg, SOL_SOCKET, SCM_PIDFD, sizeof(int), &pidfd)) {
+               if (pidfd_file) {
+                       put_unused_fd(pidfd);
+                       fput(pidfd_file);
+               }
+
+               return;
+       }
+
+       if (pidfd_file)
+               fd_install(pidfd, pidfd_file);
+}
+
 static __inline__ void scm_recv(struct socket *sock, struct msghdr *msg,
                                struct scm_cookie *scm, int flags)
 {
        if (!msg->msg_control) {
-               if (test_bit(SOCK_PASSCRED, &sock->flags) || scm->fp ||
-                   scm_has_secdata(sock))
+               if (test_bit(SOCK_PASSCRED, &sock->flags) ||
+                   test_bit(SOCK_PASSPIDFD, &sock->flags) ||
+                   scm->fp || scm_has_secdata(sock))
                        msg->msg_flags |= MSG_CTRUNC;
                scm_destroy(scm);
                return;
@@ -141,6 +173,9 @@ static __inline__ void scm_recv(struct socket *sock, struct msghdr *msg,
                put_cmsg(msg, SOL_SOCKET, SCM_CREDENTIALS, sizeof(ucreds), &ucreds);
        }
 
+       if (test_bit(SOCK_PASSPIDFD, &sock->flags))
+               scm_pidfd_recv(msg, scm);
+
        scm_destroy_cred(scm);
 
        scm_passec(sock, msg, scm);
index 6382308..b76169f 100644 (file)
 
 #define SO_RCVMARK             75
 
+#define SO_PASSPIDFD           76
+
 #if !defined(__KERNEL__)
 
 #if __BITS_PER_LONG == 64 || (defined(__x86_64__) && defined(__ILP32__))
index 24f2761..ed4eb4b 100644 (file)
@@ -1246,6 +1246,13 @@ set_sndbuf:
                        clear_bit(SOCK_PASSCRED, &sock->flags);
                break;
 
+       case SO_PASSPIDFD:
+               if (valbool)
+                       set_bit(SOCK_PASSPIDFD, &sock->flags);
+               else
+                       clear_bit(SOCK_PASSPIDFD, &sock->flags);
+               break;
+
        case SO_TIMESTAMP_OLD:
        case SO_TIMESTAMP_NEW:
        case SO_TIMESTAMPNS_OLD:
@@ -1732,6 +1739,10 @@ int sk_getsockopt(struct sock *sk, int level, int optname,
                v.val = !!test_bit(SOCK_PASSCRED, &sock->flags);
                break;
 
+       case SO_PASSPIDFD:
+               v.val = !!test_bit(SOCK_PASSPIDFD, &sock->flags);
+               break;
+
        case SO_PEERCRED:
        {
                struct ucred peercred;
index d425886..e172a58 100644 (file)
@@ -355,6 +355,7 @@ static int mptcp_setsockopt_sol_socket(struct mptcp_sock *msk, int optname,
        case SO_BROADCAST:
        case SO_BSDCOMPAT:
        case SO_PASSCRED:
+       case SO_PASSPIDFD:
        case SO_PASSSEC:
        case SO_RXQ_OVFL:
        case SO_WIFI_STATUS:
index 653136d..c46c2f5 100644 (file)
@@ -1361,7 +1361,8 @@ static int unix_dgram_connect(struct socket *sock, struct sockaddr *addr,
                if (err)
                        goto out;
 
-               if (test_bit(SOCK_PASSCRED, &sock->flags) &&
+               if ((test_bit(SOCK_PASSCRED, &sock->flags) ||
+                    test_bit(SOCK_PASSPIDFD, &sock->flags)) &&
                    !unix_sk(sk)->addr) {
                        err = unix_autobind(sk);
                        if (err)
@@ -1469,7 +1470,8 @@ static int unix_stream_connect(struct socket *sock, struct sockaddr *uaddr,
        if (err)
                goto out;
 
-       if (test_bit(SOCK_PASSCRED, &sock->flags) && !u->addr) {
+       if ((test_bit(SOCK_PASSCRED, &sock->flags) ||
+            test_bit(SOCK_PASSPIDFD, &sock->flags)) && !u->addr) {
                err = unix_autobind(sk);
                if (err)
                        goto out;
@@ -1670,6 +1672,8 @@ static void unix_sock_inherit_flags(const struct socket *old,
 {
        if (test_bit(SOCK_PASSCRED, &old->flags))
                set_bit(SOCK_PASSCRED, &new->flags);
+       if (test_bit(SOCK_PASSPIDFD, &old->flags))
+               set_bit(SOCK_PASSPIDFD, &new->flags);
        if (test_bit(SOCK_PASSSEC, &old->flags))
                set_bit(SOCK_PASSSEC, &new->flags);
 }
@@ -1819,8 +1823,10 @@ static bool unix_passcred_enabled(const struct socket *sock,
                                  const struct sock *other)
 {
        return test_bit(SOCK_PASSCRED, &sock->flags) ||
+              test_bit(SOCK_PASSPIDFD, &sock->flags) ||
               !other->sk_socket ||
-              test_bit(SOCK_PASSCRED, &other->sk_socket->flags);
+              test_bit(SOCK_PASSCRED, &other->sk_socket->flags) ||
+              test_bit(SOCK_PASSPIDFD, &other->sk_socket->flags);
 }
 
 /*
@@ -1904,7 +1910,8 @@ static int unix_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
                        goto out;
        }
 
-       if (test_bit(SOCK_PASSCRED, &sock->flags) && !u->addr) {
+       if ((test_bit(SOCK_PASSCRED, &sock->flags) ||
+            test_bit(SOCK_PASSPIDFD, &sock->flags)) && !u->addr) {
                err = unix_autobind(sk);
                if (err)
                        goto out;
@@ -2718,7 +2725,8 @@ unlock:
                        /* Never glue messages from different writers */
                        if (!unix_skb_scm_eq(skb, &scm))
                                break;
-               } else if (test_bit(SOCK_PASSCRED, &sock->flags)) {
+               } else if (test_bit(SOCK_PASSCRED, &sock->flags) ||
+                          test_bit(SOCK_PASSPIDFD, &sock->flags)) {
                        /* Copy credentials */
                        scm_set_cred(&scm, UNIXCB(skb).pid, UNIXCB(skb).uid, UNIXCB(skb).gid);
                        unix_set_secdata(&scm, skb);
index 8756df1..fbbc4bf 100644 (file)
 
 #define SO_RCVMARK             75
 
+#define SO_PASSPIDFD           76
+
 #if !defined(__KERNEL__)
 
 #if __BITS_PER_LONG == 64 || (defined(__x86_64__) && defined(__ILP32__))