uploaded spice-vdagent
[platform/adaptation/emulator/spice-vdagent.git] / src / udscs.c
1 /*  udscs.c Unix Domain Socket Client Server framework. A framework for quickly
2     creating select() based servers capable of handling multiple clients and
3     matching select() based clients using variable size messages.
4
5     Copyright 2010 Red Hat, Inc.
6
7     Red Hat Authors:
8     Hans de Goede <hdegoede@redhat.com>
9
10     This program is free software: you can redistribute it and/or modify
11     it under the terms of the GNU General Public License as published by
12     the Free Software Foundation, either version 3 of the License, or   
13     (at your option) any later version.
14
15     This program is distributed in the hope that it will be useful,
16     but WITHOUT ANY WARRANTY; without even the implied warranty of 
17     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the  
18     GNU General Public License for more details.
19
20     You should have received a copy of the GNU General Public License
21     along with this program.  If not, see <http://www.gnu.org/licenses/>.
22 */
23 #ifdef HAVE_CONFIG_H
24 #include <config.h>
25 #endif
26
27 #include <stdio.h>
28 #include <stdlib.h>
29 #include <syslog.h>
30 #include <unistd.h>
31 #include <errno.h>
32 #include <sys/socket.h>
33 #include <sys/un.h>
34 #include "udscs.h"
35
36 struct udscs_buf {
37     uint8_t *buf;
38     size_t pos;
39     size_t size;
40     
41     struct udscs_buf *next;
42 };
43
44 struct udscs_connection {
45     int fd;
46     const char * const *type_to_string;
47     int no_types;
48     int debug;
49     struct ucred peer_cred;
50     void *user_data;
51
52     /* Read stuff, single buffer, separate header and data buffer */
53     int header_read;
54     struct udscs_message_header header;
55     struct udscs_buf data;
56
57     /* Writes are stored in a linked list of buffers, with both the header
58        + data for a single message in 1 buffer. */
59     struct udscs_buf *write_buf;
60
61     /* Callbacks */
62     udscs_read_callback read_callback;
63     udscs_disconnect_callback disconnect_callback;
64
65     struct udscs_connection *next;
66     struct udscs_connection *prev;
67 };
68
69 struct udscs_server {
70     int fd;
71     const char * const *type_to_string;
72     int no_types;
73     int debug;
74     struct udscs_connection connections_head;
75     udscs_connect_callback connect_callback;
76     udscs_read_callback read_callback;
77     udscs_disconnect_callback disconnect_callback;
78 };
79
80 static void udscs_do_write(struct udscs_connection **connp);
81 static void udscs_do_read(struct udscs_connection **connp);
82
83
84 struct udscs_server *udscs_create_server(const char *socketname,
85     udscs_connect_callback connect_callback,
86     udscs_read_callback read_callback,
87     udscs_disconnect_callback disconnect_callback,
88     const char * const type_to_string[], int no_types, int debug)
89 {
90     int c;
91     struct sockaddr_un address;
92     struct udscs_server *server;
93
94     server = calloc(1, sizeof(*server));
95     if (!server)
96         return NULL;
97
98     server->type_to_string = type_to_string;
99     server->no_types = no_types;
100     server->debug = debug;
101
102     server->fd = socket(PF_UNIX, SOCK_STREAM, 0);
103     if (server->fd == -1) {
104         syslog(LOG_ERR, "creating unix domain socket: %m");
105         free(server);
106         return NULL;
107     }
108
109     address.sun_family = AF_UNIX;
110     snprintf(address.sun_path, sizeof(address.sun_path), "%s", socketname);
111     c = bind(server->fd, (struct sockaddr *)&address, sizeof(address));
112     if (c != 0) {
113         syslog(LOG_ERR, "bind %s: %m", socketname);
114         free(server);
115         return NULL;
116     }
117
118     c = listen(server->fd, 5);
119     if (c != 0) {
120         syslog(LOG_ERR, "listen: %m");
121         free(server);
122         return NULL;
123     }
124
125     server->connect_callback = connect_callback;
126     server->read_callback = read_callback;
127     server->disconnect_callback = disconnect_callback;
128
129     return server;
130 }
131
132 void udscs_destroy_server(struct udscs_server *server)
133 {
134     struct udscs_connection *conn, *next_conn;
135
136     if (!server)
137         return;
138
139     conn = server->connections_head.next;
140     while (conn) {
141         next_conn = conn->next;
142         udscs_destroy_connection(&conn);
143         conn = next_conn;
144     }
145     close(server->fd);
146     free(server);
147 }
148
149 struct udscs_connection *udscs_connect(const char *socketname,
150     udscs_read_callback read_callback,
151     udscs_disconnect_callback disconnect_callback,
152     const char * const type_to_string[], int no_types, int debug)
153 {
154     int c;
155     struct sockaddr_un address;
156     struct udscs_connection *conn;
157
158     conn = calloc(1, sizeof(*conn));
159     if (!conn)
160         return NULL;
161
162     conn->type_to_string = type_to_string;
163     conn->no_types = no_types;
164     conn->debug = debug;
165
166     conn->fd = socket(PF_UNIX, SOCK_STREAM, 0);
167     if (conn->fd == -1) {
168         syslog(LOG_ERR, "creating unix domain socket: %m");
169         free(conn);
170         return NULL;
171     }
172
173     address.sun_family = AF_UNIX;
174     snprintf(address.sun_path, sizeof(address.sun_path), "%s", socketname);
175     c = connect(conn->fd, (struct sockaddr *)&address, sizeof(address));
176     if (c != 0) {
177         if (conn->debug) {
178             syslog(LOG_DEBUG, "connect %s: %m", socketname);
179         }
180         free(conn);
181         return NULL;
182     }
183
184     conn->read_callback = read_callback;
185     conn->disconnect_callback = disconnect_callback;
186
187     if (conn->debug)
188         syslog(LOG_DEBUG, "%p connected to %s", conn, socketname);
189
190     return conn;
191 }
192
193 void udscs_destroy_connection(struct udscs_connection **connp)
194 {
195     struct udscs_buf *wbuf, *next_wbuf;
196     struct udscs_connection *conn = *connp;
197
198     if (!conn)
199         return;
200
201     if (conn->disconnect_callback)
202         conn->disconnect_callback(conn);
203
204     wbuf = conn->write_buf;
205     while (wbuf) {
206         next_wbuf = wbuf->next;
207         free(wbuf->buf);
208         free(wbuf);
209         wbuf = next_wbuf;
210     }
211
212     free(conn->data.buf);
213
214     if (conn->next)
215         conn->next->prev = conn->prev;
216     if (conn->prev)
217         conn->prev->next = conn->next;
218
219     close(conn->fd);
220
221     if (conn->debug)
222         syslog(LOG_DEBUG, "%p disconnected", conn);
223
224     free(conn);
225     *connp = NULL;
226 }
227
228 struct ucred udscs_get_peer_cred(struct udscs_connection *conn)
229 {
230     return conn->peer_cred;
231 }
232
233 int udscs_server_fill_fds(struct udscs_server *server, fd_set *readfds,
234         fd_set *writefds)
235 {
236     struct udscs_connection *conn;
237     int nfds = server->fd + 1;
238
239     if (!server)
240         return -1;
241
242     FD_SET(server->fd, readfds);
243
244     conn = server->connections_head.next;
245     while (conn) {
246         int conn_nfds = udscs_client_fill_fds(conn, readfds, writefds);
247         if (conn_nfds > nfds)
248             nfds = conn_nfds;
249
250         conn = conn->next;
251     }
252
253     return nfds;
254 }
255
256 int udscs_client_fill_fds(struct udscs_connection *conn, fd_set *readfds,
257         fd_set *writefds)
258 {
259     if (!conn)
260         return -1;
261
262     FD_SET(conn->fd, readfds);
263     if (conn->write_buf)
264         FD_SET(conn->fd, writefds);
265
266     return conn->fd + 1;
267 }
268
269 static void udscs_server_accept(struct udscs_server *server) {
270     struct udscs_connection *new_conn, *conn;
271     struct sockaddr_un address;
272     socklen_t length = sizeof(address);
273     int r, fd;
274
275     fd = accept(server->fd, (struct sockaddr *)&address, &length);
276     if (fd == -1) {
277         if (errno == EINTR)
278             return;
279         syslog(LOG_ERR, "accept: %m");
280         return;
281     }
282
283     new_conn = calloc(1, sizeof(*conn));
284     if (!new_conn) {
285         syslog(LOG_ERR, "out of memory, disconnecting new client");
286         close(fd);
287         return;
288     }
289
290     new_conn->fd = fd;
291     new_conn->type_to_string = server->type_to_string;
292     new_conn->no_types = server->no_types;
293     new_conn->debug = server->debug;
294     new_conn->read_callback = server->read_callback;
295     new_conn->disconnect_callback = server->disconnect_callback;
296
297     length = sizeof(new_conn->peer_cred);
298     r = getsockopt(fd, SOL_SOCKET, SO_PEERCRED, &new_conn->peer_cred, &length);
299     if (r != 0) {
300         syslog(LOG_ERR, "Could not get peercred, disconnecting new client");
301         close(fd);
302         free(new_conn);
303         return;
304     }
305
306     conn = &server->connections_head;
307     while (conn->next)
308         conn = conn->next;
309
310     new_conn->prev = conn;
311     conn->next = new_conn;
312
313     if (server->debug)
314         syslog(LOG_DEBUG, "new client accepted: %p, pid: %d",
315                new_conn, (int)new_conn->peer_cred.pid);
316
317     if (server->connect_callback)
318         server->connect_callback(new_conn);
319 }
320
321 void udscs_server_handle_fds(struct udscs_server *server, fd_set *readfds,
322         fd_set *writefds)
323 {
324     struct udscs_connection *conn, *next_conn;
325
326     if (!server)
327         return;
328
329     if (FD_ISSET(server->fd, readfds))
330         udscs_server_accept(server);
331
332     conn = server->connections_head.next;
333     while (conn) {
334         /* conn maybe destroyed by udscs_client_handle_fds (when disconnected),
335            so get the next connection first. */
336         next_conn = conn->next;
337         udscs_client_handle_fds(&conn, readfds, writefds);
338         conn = next_conn;
339     }
340 }
341
342 void udscs_client_handle_fds(struct udscs_connection **connp, fd_set *readfds,
343         fd_set *writefds)
344 {
345     if (!*connp)
346         return;
347
348     if (FD_ISSET((*connp)->fd, readfds))
349         udscs_do_read(connp);
350
351     if (*connp && FD_ISSET((*connp)->fd, writefds))
352         udscs_do_write(connp);
353 }
354
355 int udscs_write(struct udscs_connection *conn, uint32_t type, uint32_t arg1,
356     uint32_t arg2, const uint8_t *data, uint32_t size)
357 {
358     struct udscs_buf *wbuf, *new_wbuf;
359     struct udscs_message_header header;
360
361     new_wbuf = malloc(sizeof(*new_wbuf));
362     if (!new_wbuf)
363         return -1;
364
365     new_wbuf->pos = 0;
366     new_wbuf->size = sizeof(header) + size;
367     new_wbuf->next = NULL;
368     new_wbuf->buf = malloc(new_wbuf->size);
369     if (!new_wbuf->buf) {
370         free(new_wbuf);
371         return -1;
372     }
373
374     header.type = type;
375     header.arg1 = arg1;
376     header.arg2 = arg2;
377     header.size = size;
378
379     memcpy(new_wbuf->buf, &header, sizeof(header));
380     memcpy(new_wbuf->buf + sizeof(header), data, size);
381
382     if (conn->debug) {
383         if (type < conn->no_types)
384             syslog(LOG_DEBUG, "%p sent %s, arg1: %u, arg2: %u, size %u",
385                    conn, conn->type_to_string[type], arg1, arg2, size);
386         else
387             syslog(LOG_DEBUG,
388                    "%p sent invalid message %u, arg1: %u, arg2: %u, size %u",
389                    conn, type, arg1, arg2, size);
390     }
391
392     if (!conn->write_buf) {
393         conn->write_buf = new_wbuf;
394         return 0;
395     }
396
397     /* maybe we should limit the write_buf stack depth ? */
398     wbuf = conn->write_buf;
399     while (wbuf->next)
400         wbuf = wbuf->next;
401
402     wbuf->next = new_wbuf;
403
404     return 0;
405 }
406
407 int udscs_server_write_all(struct udscs_server *server,
408         uint32_t type, uint32_t arg1, uint32_t arg2,
409         const uint8_t *data, uint32_t size)
410 {
411     struct udscs_connection *conn;
412
413     conn = server->connections_head.next;
414     while (conn) {
415         if (udscs_write(conn, type, arg1, arg2, data, size))
416             return -1;
417         conn = conn->next;
418     }
419
420     return 0;
421 }
422
423 int udscs_server_for_all_clients(struct udscs_server *server,
424     udscs_for_all_clients_callback func, void *priv)
425 {
426     int r = 0;
427     struct udscs_connection *conn, *next_conn;
428
429     if (!server)
430         return 0;
431
432     conn = server->connections_head.next;
433     while (conn) {
434         /* Get next conn as func may destroy the current conn */
435         next_conn = conn->next;
436         r += func(&conn, priv);
437         conn = next_conn;
438     }
439     return r;
440 }
441
442 static void udscs_read_complete(struct udscs_connection **connp)
443 {
444     struct udscs_connection *conn = *connp;
445
446     if (conn->debug) {
447         if (conn->header.type < conn->no_types)
448             syslog(LOG_DEBUG,
449                    "%p received %s, arg1: %u, arg2: %u, size %u",
450                    conn, conn->type_to_string[conn->header.type],
451                    conn->header.arg1, conn->header.arg2, conn->header.size);
452         else
453             syslog(LOG_DEBUG,
454                "%p received invalid message %u, arg1: %u, arg2: %u, size %u",
455                conn, conn->header.type, conn->header.arg1, conn->header.arg2,
456                conn->header.size);
457     }
458
459     if (conn->read_callback) {
460         conn->read_callback(connp, &conn->header, conn->data.buf);
461         if (!*connp) /* Was the connection disconnected by the callback ? */
462             return;
463     }
464
465     conn->header_read = 0;
466     memset(&conn->data, 0, sizeof(conn->data));
467 }
468
469 static void udscs_do_read(struct udscs_connection **connp)
470 {
471     ssize_t n;
472     size_t to_read;
473     uint8_t *dest;
474     struct udscs_connection *conn = *connp;
475
476     if (conn->header_read < sizeof(conn->header)) {
477         to_read = sizeof(conn->header) - conn->header_read;
478         dest = (uint8_t *)&conn->header + conn->header_read;
479     } else {
480         to_read = conn->data.size - conn->data.pos;
481         dest = conn->data.buf + conn->data.pos;
482     }
483
484     n = read(conn->fd, dest, to_read);
485     if (n < 0) {
486         if (errno == EINTR)
487             return;
488         syslog(LOG_ERR, "reading unix domain socket: %m, disconnecting %p",
489                conn);
490     }
491     if (n <= 0) {
492         udscs_destroy_connection(connp);
493         return;
494     }
495
496     if (conn->header_read < sizeof(conn->header)) {
497         conn->header_read += n;
498         if (conn->header_read == sizeof(conn->header)) {
499             if (conn->header.size == 0) {
500                 udscs_read_complete(connp);
501                 return;
502             }
503             conn->data.pos = 0;
504             conn->data.size = conn->header.size;
505             conn->data.buf = malloc(conn->data.size);
506             if (!conn->data.buf) {
507                 syslog(LOG_ERR, "out of memory, disconnecting %p", conn);
508                 udscs_destroy_connection(connp);
509                 return;
510             }
511         }
512     } else {
513         conn->data.pos += n;
514         if (conn->data.pos == conn->data.size)
515             udscs_read_complete(connp);
516     }
517 }
518
519 static void udscs_do_write(struct udscs_connection **connp)
520 {
521     ssize_t n;
522     size_t to_write;
523     struct udscs_connection *conn = *connp;
524
525     struct udscs_buf* wbuf = conn->write_buf;
526     if (!wbuf) {
527         syslog(LOG_ERR,
528                "%p do_write called on a connection without a write buf ?!",
529                conn);
530         return;
531     }
532
533     to_write = wbuf->size - wbuf->pos;
534     n = write(conn->fd, wbuf->buf + wbuf->pos, to_write);
535     if (n < 0) {
536         if (errno == EINTR)
537             return;
538         syslog(LOG_ERR, "writing to unix domain socket: %m, disconnecting %p",
539                conn);
540         udscs_destroy_connection(connp);
541         return;
542     }
543
544     wbuf->pos += n;
545     if (wbuf->pos == wbuf->size) {
546         conn->write_buf = wbuf->next;
547         free(wbuf->buf);
548         free(wbuf);
549     }
550 }
551
552 void udscs_set_user_data(struct udscs_connection *conn, void *data)
553 {
554     conn->user_data = data;
555 }
556
557 void *udscs_get_user_data(struct udscs_connection *conn)
558 {
559     if (!conn)
560         return NULL;
561
562     return conn->user_data;
563 }