af_unix: Add OOB support
authorRao Shoaib <rao.shoaib@oracle.com>
Sun, 1 Aug 2021 07:57:07 +0000 (00:57 -0700)
committerDavid S. Miller <davem@davemloft.net>
Wed, 4 Aug 2021 08:55:52 +0000 (09:55 +0100)
This patch adds OOB support for AF_UNIX sockets.
The semantics is same as TCP.

The last byte of a message with the OOB flag is
treated as the OOB byte. The byte is separated into
a skb and a pointer to the skb is stored in unix_sock.
The pointer is used to enforce OOB semantics.

Signed-off-by: Rao Shoaib <rao.shoaib@oracle.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
include/net/af_unix.h
net/unix/Kconfig
net/unix/af_unix.c
tools/testing/selftests/Makefile
tools/testing/selftests/net/af_unix/Makefile [new file with mode: 0644]
tools/testing/selftests/net/af_unix/test_unix_oob.c [new file with mode: 0644]

index 435a2c3..4757d7f 100644 (file)
@@ -70,6 +70,9 @@ struct unix_sock {
        struct socket_wq        peer_wq;
        wait_queue_entry_t      peer_wake;
        struct scm_stat         scm_stat;
+#if IS_ENABLED(CONFIG_AF_UNIX_OOB)
+       struct sk_buff          *oob_skb;
+#endif
 };
 
 static inline struct unix_sock *unix_sk(const struct sock *sk)
index b6c4282..b7f8112 100644 (file)
@@ -25,6 +25,11 @@ config UNIX_SCM
        depends on UNIX
        default y
 
+config AF_UNIX_OOB
+       bool
+       depends on UNIX
+       default y
+
 config UNIX_DIAG
        tristate "UNIX: socket monitoring interface"
        depends on UNIX
index 256c4e3..ec02e70 100644 (file)
@@ -503,6 +503,12 @@ static void unix_sock_destructor(struct sock *sk)
 
        skb_queue_purge(&sk->sk_receive_queue);
 
+#if IS_ENABLED(CONFIG_AF_UNIX_OOB)
+       if (u->oob_skb) {
+               kfree_skb(u->oob_skb);
+               u->oob_skb = NULL;
+       }
+#endif
        WARN_ON(refcount_read(&sk->sk_wmem_alloc));
        WARN_ON(!sk_unhashed(sk));
        WARN_ON(sk->sk_socket);
@@ -1889,6 +1895,46 @@ out:
  */
 #define UNIX_SKB_FRAGS_SZ (PAGE_SIZE << get_order(32768))
 
+#if (IS_ENABLED(CONFIG_AF_UNIX_OOB))
+static int queue_oob(struct socket *sock, struct msghdr *msg, struct sock *other)
+{
+       struct unix_sock *ousk = unix_sk(other);
+       struct sk_buff *skb;
+       int err = 0;
+
+       skb = sock_alloc_send_skb(sock->sk, 1, msg->msg_flags & MSG_DONTWAIT, &err);
+
+       if (!skb)
+               return err;
+
+       skb_put(skb, 1);
+       skb->len = 1;
+       err = skb_copy_datagram_from_iter(skb, 0, &msg->msg_iter, 1);
+
+       if (err) {
+               kfree_skb(skb);
+               return err;
+       }
+
+       unix_state_lock(other);
+       maybe_add_creds(skb, sock, other);
+       skb_get(skb);
+
+       if (ousk->oob_skb)
+               kfree_skb(ousk->oob_skb);
+
+       ousk->oob_skb = skb;
+
+       scm_stat_add(other, skb);
+       skb_queue_tail(&other->sk_receive_queue, skb);
+       sk_send_sigurg(other);
+       unix_state_unlock(other);
+       other->sk_data_ready(other);
+
+       return err;
+}
+#endif
+
 static int unix_stream_sendmsg(struct socket *sock, struct msghdr *msg,
                               size_t len)
 {
@@ -1907,8 +1953,14 @@ static int unix_stream_sendmsg(struct socket *sock, struct msghdr *msg,
                return err;
 
        err = -EOPNOTSUPP;
-       if (msg->msg_flags&MSG_OOB)
-               goto out_err;
+       if (msg->msg_flags & MSG_OOB) {
+#if (IS_ENABLED(CONFIG_AF_UNIX_OOB))
+               if (len)
+                       len--;
+               else
+#endif
+                       goto out_err;
+       }
 
        if (msg->msg_namelen) {
                err = sk->sk_state == TCP_ESTABLISHED ? -EISCONN : -EOPNOTSUPP;
@@ -1973,6 +2025,15 @@ static int unix_stream_sendmsg(struct socket *sock, struct msghdr *msg,
                sent += size;
        }
 
+#if (IS_ENABLED(CONFIG_AF_UNIX_OOB))
+       if (msg->msg_flags & MSG_OOB) {
+               err = queue_oob(sock, msg, other);
+               if (err)
+                       goto out_err;
+               sent++;
+       }
+#endif
+
        scm_destroy(&scm);
 
        return sent;
@@ -2358,6 +2419,59 @@ struct unix_stream_read_state {
        unsigned int splice_flags;
 };
 
+#if IS_ENABLED(CONFIG_AF_UNIX_OOB)
+static int unix_stream_recv_urg(struct unix_stream_read_state *state)
+{
+       struct socket *sock = state->socket;
+       struct sock *sk = sock->sk;
+       struct unix_sock *u = unix_sk(sk);
+       int chunk = 1;
+
+       if (sock_flag(sk, SOCK_URGINLINE) || !u->oob_skb)
+               return -EINVAL;
+
+       chunk = state->recv_actor(u->oob_skb, 0, chunk, state);
+       if (chunk < 0)
+               return -EFAULT;
+
+       if (!(state->flags & MSG_PEEK)) {
+               UNIXCB(u->oob_skb).consumed += 1;
+               kfree_skb(u->oob_skb);
+               u->oob_skb = NULL;
+       }
+       state->msg->msg_flags |= MSG_OOB;
+       return 1;
+}
+
+static struct sk_buff *manage_oob(struct sk_buff *skb, struct sock *sk,
+                                 int flags, int copied)
+{
+       struct unix_sock *u = unix_sk(sk);
+
+       if (!unix_skb_len(skb) && !(flags & MSG_PEEK)) {
+               skb_unlink(skb, &sk->sk_receive_queue);
+               consume_skb(skb);
+               skb = NULL;
+       } else {
+               if (skb == u->oob_skb) {
+                       if (copied) {
+                               skb = NULL;
+                       } else if (sock_flag(sk, SOCK_URGINLINE)) {
+                               if (!(flags & MSG_PEEK)) {
+                                       u->oob_skb = NULL;
+                                       consume_skb(skb);
+                               }
+                       } else if (!(flags & MSG_PEEK)) {
+                               skb_unlink(skb, &sk->sk_receive_queue);
+                               consume_skb(skb);
+                               skb = skb_peek(&sk->sk_receive_queue);
+                       }
+               }
+       }
+       return skb;
+}
+#endif
+
 static int unix_stream_read_generic(struct unix_stream_read_state *state,
                                    bool freezable)
 {
@@ -2383,6 +2497,15 @@ static int unix_stream_read_generic(struct unix_stream_read_state *state,
 
        if (unlikely(flags & MSG_OOB)) {
                err = -EOPNOTSUPP;
+#if IS_ENABLED(CONFIG_AF_UNIX_OOB)
+               mutex_lock(&u->iolock);
+               unix_state_lock(sk);
+
+               err = unix_stream_recv_urg(state);
+
+               unix_state_unlock(sk);
+               mutex_unlock(&u->iolock);
+#endif
                goto out;
        }
 
@@ -2411,6 +2534,18 @@ redo:
                }
                last = skb = skb_peek(&sk->sk_receive_queue);
                last_len = last ? last->len : 0;
+
+#if IS_ENABLED(CONFIG_AF_UNIX_OOB)
+               if (skb) {
+                       skb = manage_oob(skb, sk, flags, copied);
+                       if (!skb) {
+                               unix_state_unlock(sk);
+                               if (copied)
+                                       break;
+                               goto redo;
+                       }
+               }
+#endif
 again:
                if (skb == NULL) {
                        if (copied >= target)
@@ -2746,6 +2881,20 @@ static int unix_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg)
        case SIOCUNIXFILE:
                err = unix_open_file(sk);
                break;
+#if IS_ENABLED(CONFIG_AF_UNIX_OOB)
+       case SIOCATMARK:
+               {
+                       struct sk_buff *skb;
+                       struct unix_sock *u = unix_sk(sk);
+                       int answ = 0;
+
+                       skb = skb_peek(&sk->sk_receive_queue);
+                       if (skb && skb == u->oob_skb)
+                               answ = 1;
+                       err = put_user(answ, (int __user *)arg);
+               }
+               break;
+#endif
        default:
                err = -ENOIOCTLCMD;
                break;
index fb010a3..da9e8b6 100644 (file)
@@ -38,6 +38,7 @@ TARGETS += mount_setattr
 TARGETS += mqueue
 TARGETS += nci
 TARGETS += net
+TARGETS += net/af_unix
 TARGETS += net/forwarding
 TARGETS += net/mptcp
 TARGETS += netfilter
diff --git a/tools/testing/selftests/net/af_unix/Makefile b/tools/testing/selftests/net/af_unix/Makefile
new file mode 100644 (file)
index 0000000..cfc7f4f
--- /dev/null
@@ -0,0 +1,5 @@
+##TEST_GEN_FILES := test_unix_oob
+TEST_PROGS := test_unix_oob
+include ../../lib.mk
+
+all: $(TEST_PROGS)
diff --git a/tools/testing/selftests/net/af_unix/test_unix_oob.c b/tools/testing/selftests/net/af_unix/test_unix_oob.c
new file mode 100644 (file)
index 0000000..0f3e376
--- /dev/null
@@ -0,0 +1,437 @@
+// SPDX-License-Identifier: GPL-2.0-or-later
+#include <stdio.h>
+#include <stdlib.h>
+#include <sys/socket.h>
+#include <arpa/inet.h>
+#include <unistd.h>
+#include <string.h>
+#include <fcntl.h>
+#include <sys/ioctl.h>
+#include <errno.h>
+#include <netinet/tcp.h>
+#include <sys/un.h>
+#include <sys/signal.h>
+#include <sys/poll.h>
+
+static int pipefd[2];
+static int signal_recvd;
+static pid_t producer_id;
+static char sock_name[32];
+
+static void sig_hand(int sn, siginfo_t *si, void *p)
+{
+       signal_recvd = sn;
+}
+
+static int set_sig_handler(int signal)
+{
+       struct sigaction sa;
+
+       sa.sa_sigaction = sig_hand;
+       sigemptyset(&sa.sa_mask);
+       sa.sa_flags = SA_SIGINFO | SA_RESTART;
+
+       return sigaction(signal, &sa, NULL);
+}
+
+static void set_filemode(int fd, int set)
+{
+       int flags = fcntl(fd, F_GETFL, 0);
+
+       if (set)
+               flags &= ~O_NONBLOCK;
+       else
+               flags |= O_NONBLOCK;
+       fcntl(fd, F_SETFL, flags);
+}
+
+static void signal_producer(int fd)
+{
+       char cmd;
+
+       cmd = 'S';
+       write(fd, &cmd, sizeof(cmd));
+}
+
+static void wait_for_signal(int fd)
+{
+       char buf[5];
+
+       read(fd, buf, 5);
+}
+
+static void die(int status)
+{
+       fflush(NULL);
+       unlink(sock_name);
+       kill(producer_id, SIGTERM);
+       exit(status);
+}
+
+int is_sioctatmark(int fd)
+{
+       int ans = -1;
+
+       if (ioctl(fd, SIOCATMARK, &ans, sizeof(ans)) < 0) {
+#ifdef DEBUG
+               perror("SIOCATMARK Failed");
+#endif
+       }
+       return ans;
+}
+
+void read_oob(int fd, char *c)
+{
+
+       *c = ' ';
+       if (recv(fd, c, sizeof(*c), MSG_OOB) < 0) {
+#ifdef DEBUG
+               perror("Reading MSG_OOB Failed");
+#endif
+       }
+}
+
+int read_data(int pfd, char *buf, int size)
+{
+       int len = 0;
+
+       memset(buf, size, '0');
+       len = read(pfd, buf, size);
+#ifdef DEBUG
+       if (len < 0)
+               perror("read failed");
+#endif
+       return len;
+}
+
+static void wait_for_data(int pfd, int event)
+{
+       struct pollfd pfds[1];
+
+       pfds[0].fd = pfd;
+       pfds[0].events = event;
+       poll(pfds, 1, -1);
+}
+
+void producer(struct sockaddr_un *consumer_addr)
+{
+       int cfd;
+       char buf[64];
+       int i;
+
+       memset(buf, 'x', sizeof(buf));
+       cfd = socket(AF_UNIX, SOCK_STREAM, 0);
+
+       wait_for_signal(pipefd[0]);
+       if (connect(cfd, (struct sockaddr *)consumer_addr,
+                    sizeof(struct sockaddr)) != 0) {
+               perror("Connect failed");
+               kill(0, SIGTERM);
+               exit(1);
+       }
+
+       for (i = 0; i < 2; i++) {
+               /* Test 1: Test for SIGURG and OOB */
+               wait_for_signal(pipefd[0]);
+               memset(buf, 'x', sizeof(buf));
+               buf[63] = '@';
+               send(cfd, buf, sizeof(buf), MSG_OOB);
+
+               wait_for_signal(pipefd[0]);
+
+               /* Test 2: Test for OOB being overwitten */
+               memset(buf, 'x', sizeof(buf));
+               buf[63] = '%';
+               send(cfd, buf, sizeof(buf), MSG_OOB);
+
+               memset(buf, 'x', sizeof(buf));
+               buf[63] = '#';
+               send(cfd, buf, sizeof(buf), MSG_OOB);
+
+               wait_for_signal(pipefd[0]);
+
+               /* Test 3: Test for SIOCATMARK */
+               memset(buf, 'x', sizeof(buf));
+               buf[63] = '@';
+               send(cfd, buf, sizeof(buf), MSG_OOB);
+
+               memset(buf, 'x', sizeof(buf));
+               buf[63] = '%';
+               send(cfd, buf, sizeof(buf), MSG_OOB);
+
+               memset(buf, 'x', sizeof(buf));
+               send(cfd, buf, sizeof(buf), 0);
+
+               wait_for_signal(pipefd[0]);
+
+               /* Test 4: Test for 1byte OOB msg */
+               memset(buf, 'x', sizeof(buf));
+               buf[0] = '@';
+               send(cfd, buf, 1, MSG_OOB);
+       }
+}
+
+int
+main(int argc, char **argv)
+{
+       int lfd, pfd;
+       struct sockaddr_un consumer_addr, paddr;
+       socklen_t len = sizeof(consumer_addr);
+       char buf[1024];
+       int on = 0;
+       char oob;
+       int flags;
+       int atmark;
+       char *tmp_file;
+
+       lfd = socket(AF_UNIX, SOCK_STREAM, 0);
+       memset(&consumer_addr, 0, sizeof(consumer_addr));
+       consumer_addr.sun_family = AF_UNIX;
+       sprintf(sock_name, "unix_oob_%d", getpid());
+       unlink(sock_name);
+       strcpy(consumer_addr.sun_path, sock_name);
+
+       if ((bind(lfd, (struct sockaddr *)&consumer_addr,
+                 sizeof(consumer_addr))) != 0) {
+               perror("socket bind failed");
+               exit(1);
+       }
+
+       pipe(pipefd);
+
+       listen(lfd, 1);
+
+       producer_id = fork();
+       if (producer_id == 0) {
+               producer(&consumer_addr);
+               exit(0);
+       }
+
+       set_sig_handler(SIGURG);
+       signal_producer(pipefd[1]);
+
+       pfd = accept(lfd, (struct sockaddr *) &paddr, &len);
+       fcntl(pfd, F_SETOWN, getpid());
+
+       signal_recvd = 0;
+       signal_producer(pipefd[1]);
+
+       /* Test 1:
+        * veriyf that SIGURG is
+        * delivered and 63 bytes are
+        * read and oob is '@'
+        */
+       wait_for_data(pfd, POLLIN | POLLPRI);
+       read_oob(pfd, &oob);
+       len = read_data(pfd, buf, 1024);
+       if (!signal_recvd || len != 63 || oob != '@') {
+               fprintf(stderr, "Test 1 failed sigurg %d len %d %c\n",
+                        signal_recvd, len, oob);
+                       die(1);
+       }
+
+       signal_recvd = 0;
+       signal_producer(pipefd[1]);
+
+       /* Test 2:
+        * Verify that the first OOB is over written by
+        * the 2nd one and the first OOB is returned as
+        * part of the read, and sigurg is received.
+        */
+       wait_for_data(pfd, POLLIN | POLLPRI);
+       len = 0;
+       while (len < 70)
+               len = recv(pfd, buf, 1024, MSG_PEEK);
+       len = read_data(pfd, buf, 1024);
+       read_oob(pfd, &oob);
+       if (!signal_recvd || len != 127 || oob != '#') {
+               fprintf(stderr, "Test 2 failed, sigurg %d len %d OOB %c\n",
+               signal_recvd, len, oob);
+               die(1);
+       }
+
+       signal_recvd = 0;
+       signal_producer(pipefd[1]);
+
+       /* Test 3:
+        * verify that 2nd oob over writes
+        * the first one and read breaks at
+        * oob boundary returning 127 bytes
+        * and sigurg is received and atmark
+        * is set.
+        * oob is '%' and second read returns
+        * 64 bytes.
+        */
+       len = 0;
+       wait_for_data(pfd, POLLIN | POLLPRI);
+       while (len < 150)
+               len = recv(pfd, buf, 1024, MSG_PEEK);
+       len = read_data(pfd, buf, 1024);
+       atmark = is_sioctatmark(pfd);
+       read_oob(pfd, &oob);
+
+       if (!signal_recvd || len != 127 || oob != '%' || atmark != 1) {
+               fprintf(stderr, "Test 3 failed, sigurg %d len %d OOB %c ",
+               "atmark %d\n", signal_recvd, len, oob, atmark);
+               die(1);
+       }
+
+       signal_recvd = 0;
+
+       len = read_data(pfd, buf, 1024);
+       if (len != 64) {
+               fprintf(stderr, "Test 3.1 failed, sigurg %d len %d OOB %c\n",
+                       signal_recvd, len, oob);
+               die(1);
+       }
+
+       signal_recvd = 0;
+       signal_producer(pipefd[1]);
+
+       /* Test 4:
+        * verify that a single byte
+        * oob message is delivered.
+        * set non blocking mode and
+        * check proper error is
+        * returned and sigurg is
+        * received and correct
+        * oob is read.
+        */
+
+       set_filemode(pfd, 0);
+
+       wait_for_data(pfd, POLLIN | POLLPRI);
+       len = read_data(pfd, buf, 1024);
+       if ((len == -1) && (errno == 11))
+               len = 0;
+
+       read_oob(pfd, &oob);
+
+       if (!signal_recvd || len != 0 || oob != '@') {
+               fprintf(stderr, "Test 4 failed, sigurg %d len %d OOB %c\n",
+                        signal_recvd, len, oob);
+               die(1);
+       }
+
+       set_filemode(pfd, 1);
+
+       /* Inline Testing */
+
+       on = 1;
+       if (setsockopt(pfd, SOL_SOCKET, SO_OOBINLINE, &on, sizeof(on))) {
+               perror("SO_OOBINLINE");
+               die(1);
+       }
+
+       signal_recvd = 0;
+       signal_producer(pipefd[1]);
+
+       /* Test 1 -- Inline:
+        * Check that SIGURG is
+        * delivered and 63 bytes are
+        * read and oob is '@'
+        */
+
+       wait_for_data(pfd, POLLIN | POLLPRI);
+       len = read_data(pfd, buf, 1024);
+
+       if (!signal_recvd || len != 63) {
+               fprintf(stderr, "Test 1 Inline failed, sigurg %d len %d\n",
+                       signal_recvd, len);
+               die(1);
+       }
+
+       len = read_data(pfd, buf, 1024);
+
+       if (len != 1) {
+               fprintf(stderr,
+                        "Test 1.1 Inline failed, sigurg %d len %d oob %c\n",
+                        signal_recvd, len, oob);
+               die(1);
+       }
+
+       signal_recvd = 0;
+       signal_producer(pipefd[1]);
+
+       /* Test 2 -- Inline:
+        * Verify that the first OOB is over written by
+        * the 2nd one and read breaks correctly on
+        * 2nd OOB boundary with the first OOB returned as
+        * part of the read, and sigurg is delivered and
+        * siocatmark returns true.
+        * next read returns one byte, the oob byte
+        * and siocatmark returns false.
+        */
+       len = 0;
+       wait_for_data(pfd, POLLIN | POLLPRI);
+       while (len < 70)
+               len = recv(pfd, buf, 1024, MSG_PEEK);
+       len = read_data(pfd, buf, 1024);
+       atmark = is_sioctatmark(pfd);
+       if (len != 127 || atmark != 1 || !signal_recvd) {
+               fprintf(stderr, "Test 2 Inline failed, len %d atmark %d\n",
+                        len, atmark);
+               die(1);
+       }
+
+       len = read_data(pfd, buf, 1024);
+       atmark = is_sioctatmark(pfd);
+       if (len != 1 || buf[0] != '#' || atmark == 1) {
+               fprintf(stderr, "Test 2.1 Inline failed, len %d data %c atmark %d\n",
+                       len, buf[0], atmark);
+               die(1);
+       }
+
+       signal_recvd = 0;
+       signal_producer(pipefd[1]);
+
+       /* Test 3 -- Inline:
+        * verify that 2nd oob over writes
+        * the first one and read breaks at
+        * oob boundary returning 127 bytes
+        * and sigurg is received and siocatmark
+        * is true after the read.
+        * subsequent read returns 65 bytes
+        * because of oob which should be '%'.
+        */
+       len = 0;
+       wait_for_data(pfd, POLLIN | POLLPRI);
+       while (len < 126)
+               len = recv(pfd, buf, 1024, MSG_PEEK);
+       len = read_data(pfd, buf, 1024);
+       atmark = is_sioctatmark(pfd);
+       if (!signal_recvd || len != 127 || !atmark) {
+               fprintf(stderr,
+                        "Test 3 Inline failed, sigurg %d len %d data %c\n",
+                        signal_recvd, len, buf[0]);
+               die(1);
+       }
+
+       len = read_data(pfd, buf, 1024);
+       atmark = is_sioctatmark(pfd);
+       if (len != 65 || buf[0] != '%' || atmark != 0) {
+               fprintf(stderr,
+                        "Test 3.1 Inline failed, len %d oob %c atmark %d\n",
+                        len, buf[0], atmark);
+               die(1);
+       }
+
+       signal_recvd = 0;
+       signal_producer(pipefd[1]);
+
+       /* Test 4 -- Inline:
+        * verify that a single
+        * byte oob message is delivered
+        * and read returns one byte, the oob
+        * byte and sigurg is received
+        */
+       wait_for_data(pfd, POLLIN | POLLPRI);
+       len = read_data(pfd, buf, 1024);
+       if (!signal_recvd || len != 1 || buf[0] != '@') {
+               fprintf(stderr,
+                       "Test 4 Inline failed, signal %d len %d data %c\n",
+               signal_recvd, len, buf[0]);
+               die(1);
+       }
+       die(0);
+}