Support marshalling of file descriptors
authorKristian Høgsberg <krh@bitplanet.net>
Fri, 27 Aug 2010 01:49:44 +0000 (21:49 -0400)
committerKristian Høgsberg <krh@bitplanet.net>
Sat, 28 Aug 2010 00:30:22 +0000 (20:30 -0400)
connection.c
scanner.c

index 576becf..71e39e6 100644 (file)
@@ -47,6 +47,8 @@ struct wl_buffer {
 
 struct wl_connection {
        struct wl_buffer in, out;
+       struct wl_buffer fds_in, fds_out;
+       int fds_in_tail;
        int fd;
        void *data;
        wl_connection_update_func_t update;
@@ -169,13 +171,68 @@ void
 wl_connection_consume(struct wl_connection *connection, size_t size)
 {
        connection->in.tail += size;
+       connection->fds_in.tail = connection->fds_in_tail;
 }
 
-int wl_connection_data(struct wl_connection *connection, uint32_t mask)
+static void
+build_cmsg(struct wl_buffer *buffer, char *data, int *clen)
+{
+       struct cmsghdr *cmsg;
+       size_t size;
+
+       size = buffer->head - buffer->tail;
+       if (size > 0) {
+               cmsg = (struct cmsghdr *) data;
+               cmsg->cmsg_level = SOL_SOCKET;
+               cmsg->cmsg_type = SCM_RIGHTS;
+               cmsg->cmsg_len = CMSG_LEN(size);
+               wl_buffer_copy(buffer, CMSG_DATA(cmsg), size);
+               *clen = cmsg->cmsg_len;
+       } else {
+               *clen = 0;
+       }
+}
+
+static void
+close_fds(struct wl_buffer *buffer)
+{
+       int fds[32], i, count;
+       size_t size;
+
+       size = buffer->head - buffer->tail;
+       if (size == 0)
+               return;
+
+       wl_buffer_copy(buffer, fds, size);
+       count = size / 4;
+       for (i = 0; i < count; i++)
+               close(fds[i]);
+       buffer->tail += size;
+}
+
+static void
+decode_cmsg(struct wl_buffer *buffer, struct msghdr *msg)
+{
+       struct cmsghdr *cmsg;
+       size_t size;
+
+       for (cmsg = CMSG_FIRSTHDR(msg); cmsg != NULL;
+            cmsg = CMSG_NXTHDR(msg, cmsg)) {
+               if (cmsg->cmsg_level == SOL_SOCKET &&
+                   cmsg->cmsg_type == SCM_RIGHTS) {
+                       size = cmsg->cmsg_len - CMSG_LEN(0);
+                       wl_buffer_put(buffer, CMSG_DATA(cmsg), size);
+               }
+       }
+}
+
+int
+wl_connection_data(struct wl_connection *connection, uint32_t mask)
 {
        struct iovec iov[2];
        struct msghdr msg;
-       int len, count;
+       char cmsg[128];
+       int len, count, clen;
 
        if (mask & WL_CONNECTION_READABLE) {
                wl_buffer_put_iov(&connection->in, iov, &count);
@@ -184,8 +241,9 @@ int wl_connection_data(struct wl_connection *connection, uint32_t mask)
                msg.msg_namelen = 0;
                msg.msg_iov = iov;
                msg.msg_iovlen = count;
-               msg.msg_control = NULL;
-               msg.msg_controllen = 0;
+               msg.msg_control = cmsg;
+               msg.msg_controllen = sizeof cmsg;
+               msg.msg_flags = 0;
 
                do {
                        len = recvmsg(connection->fd, &msg, 0);
@@ -201,28 +259,37 @@ int wl_connection_data(struct wl_connection *connection, uint32_t mask)
                        return -1;
                }
 
+               decode_cmsg(&connection->fds_in, &msg);
+
                connection->in.head += len;
        }       
 
        if (mask & WL_CONNECTION_WRITABLE) {
                wl_buffer_get_iov(&connection->out, iov, &count);
 
+               build_cmsg(&connection->fds_out, cmsg, &clen);
+
                msg.msg_name = NULL;
                msg.msg_namelen = 0;
                msg.msg_iov = iov;
                msg.msg_iovlen = count;
-               msg.msg_control = NULL;
-               msg.msg_controllen = 0;
+               msg.msg_control = cmsg;
+               msg.msg_controllen = clen;
+               msg.msg_flags = 0;
 
                do {
                        len = sendmsg(connection->fd, &msg, 0);
                } while (len < 0 && errno == EINTR);
 
                if (len < 0) {
-                       fprintf(stderr, "write error for connection %p: %m\n", connection);
+                       fprintf(stderr,
+                               "write error for connection %p, fd %d: %m\n",
+                               connection, connection->fd);
                        return -1;
                }
 
+               close_fds(&connection->fds_out);
+
                connection->out.tail += len;
                if (connection->out.tail == connection->out.head)
                        connection->update(connection,
@@ -254,9 +321,10 @@ wl_connection_vmarshal(struct wl_connection *connection,
 {
        struct wl_object *object;
        uint32_t args[32], length, *p, size;
+       int32_t dup_fd;
        struct wl_array *array;
        const char *s;
-       int i, count;
+       int i, count, fd;
 
        count = strlen(message->signature);
        assert(count <= ARRAY_LENGTH(args));
@@ -290,6 +358,16 @@ wl_connection_vmarshal(struct wl_connection *connection,
                        memcpy(p, array->data, array->size);
                        p = (void *) p + array->size;
                        break;
+               case 'h':
+                       fd = va_arg(ap, int);
+                       dup_fd = dup(fd);
+                       if (dup_fd < 0) {
+                               fprintf(stderr, "dup failed: %m");
+                               abort();
+                       }
+                       wl_buffer_put(&connection->fds_out,
+                                     &dup_fd, sizeof dup_fd);
+                       break;
                default:
                        assert(0);
                        break;
@@ -313,7 +391,7 @@ wl_connection_demarshal(struct wl_connection *connection,
        ffi_type *types[20];
        ffi_cif cif;
        uint32_t *p, *next, *end, result, length;
-       int i, count, ret = 0;
+       int i, count, fds_tail, ret = 0;
        union {
                uint32_t uint32;
                char *string;
@@ -347,6 +425,7 @@ wl_connection_demarshal(struct wl_connection *connection,
        wl_connection_copy(connection, buffer, size);
        p = &buffer[2];
        end = (uint32_t *) ((char *) (p + size));
+       fds_tail = connection->fds_in.tail;
        for (i = 2; i < count; i++) {
                if (p + 1 > end) {
                        printf("message too short, "
@@ -441,6 +520,13 @@ wl_connection_demarshal(struct wl_connection *connection,
                        memcpy(values[i].array->data, p, length);
                        p = next;
                        break;
+               case 'h':
+                       types[i] = &ffi_type_uint32;
+                       wl_buffer_copy(&connection->fds_in,
+                                      &values[i].uint32,
+                                      sizeof values[i].uint32);
+                       connection->fds_in.tail += sizeof values[i].uint32;
+                       break;
                default:
                        printf("unknown type\n");
                        assert(0);
@@ -452,6 +538,11 @@ wl_connection_demarshal(struct wl_connection *connection,
        ffi_prep_cif(&cif, FFI_DEFAULT_ABI, count, &ffi_type_uint32, types);
        ffi_call(&cif, func, &result, args);
 
+       /* Slight hack here.  We store the tail of fds_in here and
+        * consume will set fds_in.tail to that value */
+       connection->fds_in_tail = connection->fds_in.tail;
+       connection->fds_in.tail = fds_tail;
+
  out:
        count = i;
        for (i = 2; i < count; i++) {
index 0fa5079..b4d7668 100644 (file)
--- a/scanner.c
+++ b/scanner.c
@@ -82,7 +82,8 @@ enum arg_type {
        UNSIGNED,
        STRING,
        OBJECT,
-       ARRAY
+       ARRAY,
+       FD
 };
 
 struct arg {
@@ -189,6 +190,8 @@ start_element(void *data, const char *element_name, const char **atts)
                        arg->type = STRING;
                else if (strcmp(type, "array") == 0)
                        arg->type = ARRAY;
+               else if (strcmp(type, "fd") == 0)
+                       arg->type = FD;
                else if (strcmp(type, "new_id") == 0) {
                        if (interface_name == NULL) {
                                fprintf(stderr, "no interface name given\n");
@@ -236,7 +239,8 @@ emit_type(struct arg *a)
        switch (a->type) {
        default:
        case INT:
-               printf("int32_t ");
+       case FD:
+               printf("int ");
                break;
        case NEW_ID:
        case UNSIGNED:
@@ -536,6 +540,9 @@ emit_messages(struct wl_list *message_list,
                        case ARRAY:
                                printf("a");
                                break;
+                       case FD:
+                               printf("h");
+                               break;
                        }
                }
                printf("\" },\n");