Merge remote-tracking branch 'afaerber/qom-cpu' into staging
[sdk/emulator/qemu.git] / nbd.c
1 /*
2  *  Copyright (C) 2005  Anthony Liguori <anthony@codemonkey.ws>
3  *
4  *  Network Block Device
5  *
6  *  This program is free software; you can redistribute it and/or modify
7  *  it under the terms of the GNU General Public License as published by
8  *  the Free Software Foundation; under version 2 of the License.
9  *
10  *  This program is distributed in the hope that it will be useful,
11  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
12  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13  *  GNU General Public License for more details.
14  *
15  *  You should have received a copy of the GNU General Public License
16  *  along with this program; if not, see <http://www.gnu.org/licenses/>.
17  */
18
19 #include "nbd.h"
20 #include "block.h"
21
22 #include "qemu-coroutine.h"
23
24 #include <errno.h>
25 #include <string.h>
26 #ifndef _WIN32
27 #include <sys/ioctl.h>
28 #endif
29 #if defined(__sun__) || defined(__HAIKU__)
30 #include <sys/ioccom.h>
31 #endif
32 #include <ctype.h>
33 #include <inttypes.h>
34
35 #ifdef __linux__
36 #include <linux/fs.h>
37 #endif
38
39 #include "qemu_socket.h"
40 #include "qemu-queue.h"
41
42 //#define DEBUG_NBD
43
44 #ifdef DEBUG_NBD
45 #define TRACE(msg, ...) do { \
46     LOG(msg, ## __VA_ARGS__); \
47 } while(0)
48 #else
49 #define TRACE(msg, ...) \
50     do { } while (0)
51 #endif
52
53 #define LOG(msg, ...) do { \
54     fprintf(stderr, "%s:%s():L%d: " msg "\n", \
55             __FILE__, __FUNCTION__, __LINE__, ## __VA_ARGS__); \
56 } while(0)
57
58 /* This is all part of the "official" NBD API */
59
60 #define NBD_REQUEST_SIZE        (4 + 4 + 8 + 8 + 4)
61 #define NBD_REPLY_SIZE          (4 + 4 + 8)
62 #define NBD_REQUEST_MAGIC       0x25609513
63 #define NBD_REPLY_MAGIC         0x67446698
64 #define NBD_OPTS_MAGIC          0x49484156454F5054LL
65 #define NBD_CLIENT_MAGIC        0x0000420281861253LL
66
67 #define NBD_SET_SOCK            _IO(0xab, 0)
68 #define NBD_SET_BLKSIZE         _IO(0xab, 1)
69 #define NBD_SET_SIZE            _IO(0xab, 2)
70 #define NBD_DO_IT               _IO(0xab, 3)
71 #define NBD_CLEAR_SOCK          _IO(0xab, 4)
72 #define NBD_CLEAR_QUE           _IO(0xab, 5)
73 #define NBD_PRINT_DEBUG         _IO(0xab, 6)
74 #define NBD_SET_SIZE_BLOCKS     _IO(0xab, 7)
75 #define NBD_DISCONNECT          _IO(0xab, 8)
76 #define NBD_SET_TIMEOUT         _IO(0xab, 9)
77 #define NBD_SET_FLAGS           _IO(0xab, 10)
78
79 #define NBD_OPT_EXPORT_NAME     (1 << 0)
80
81 /* Definitions for opaque data types */
82
83 typedef struct NBDRequest NBDRequest;
84
85 struct NBDRequest {
86     QSIMPLEQ_ENTRY(NBDRequest) entry;
87     NBDClient *client;
88     uint8_t *data;
89 };
90
91 struct NBDExport {
92     int refcount;
93     void (*close)(NBDExport *exp);
94
95     BlockDriverState *bs;
96     char *name;
97     off_t dev_offset;
98     off_t size;
99     uint32_t nbdflags;
100     QTAILQ_HEAD(, NBDClient) clients;
101     QSIMPLEQ_HEAD(, NBDRequest) requests;
102     QTAILQ_ENTRY(NBDExport) next;
103 };
104
105 static QTAILQ_HEAD(, NBDExport) exports = QTAILQ_HEAD_INITIALIZER(exports);
106
107 struct NBDClient {
108     int refcount;
109     void (*close)(NBDClient *client);
110
111     NBDExport *exp;
112     int sock;
113
114     Coroutine *recv_coroutine;
115
116     CoMutex send_lock;
117     Coroutine *send_coroutine;
118
119     QTAILQ_ENTRY(NBDClient) next;
120     int nb_requests;
121     bool closing;
122 };
123
124 /* That's all folks */
125
126 ssize_t nbd_wr_sync(int fd, void *buffer, size_t size, bool do_read)
127 {
128     size_t offset = 0;
129     int err;
130
131     if (qemu_in_coroutine()) {
132         if (do_read) {
133             return qemu_co_recv(fd, buffer, size);
134         } else {
135             return qemu_co_send(fd, buffer, size);
136         }
137     }
138
139     while (offset < size) {
140         ssize_t len;
141
142         if (do_read) {
143             len = qemu_recv(fd, buffer + offset, size - offset, 0);
144         } else {
145             len = send(fd, buffer + offset, size - offset, 0);
146         }
147
148         if (len < 0) {
149             err = socket_error();
150
151             /* recoverable error */
152             if (err == EINTR || (offset > 0 && err == EAGAIN)) {
153                 continue;
154             }
155
156             /* unrecoverable error */
157             return -err;
158         }
159
160         /* eof */
161         if (len == 0) {
162             break;
163         }
164
165         offset += len;
166     }
167
168     return offset;
169 }
170
171 static ssize_t read_sync(int fd, void *buffer, size_t size)
172 {
173     /* Sockets are kept in blocking mode in the negotiation phase.  After
174      * that, a non-readable socket simply means that another thread stole
175      * our request/reply.  Synchronization is done with recv_coroutine, so
176      * that this is coroutine-safe.
177      */
178     return nbd_wr_sync(fd, buffer, size, true);
179 }
180
181 static ssize_t write_sync(int fd, void *buffer, size_t size)
182 {
183     int ret;
184     do {
185         /* For writes, we do expect the socket to be writable.  */
186         ret = nbd_wr_sync(fd, buffer, size, false);
187     } while (ret == -EAGAIN);
188     return ret;
189 }
190
191 static void combine_addr(char *buf, size_t len, const char* address,
192                          uint16_t port)
193 {
194     /* If the address-part contains a colon, it's an IPv6 IP so needs [] */
195     if (strstr(address, ":")) {
196         snprintf(buf, len, "[%s]:%u", address, port);
197     } else {
198         snprintf(buf, len, "%s:%u", address, port);
199     }
200 }
201
202 int tcp_socket_outgoing(const char *address, uint16_t port)
203 {
204     char address_and_port[128];
205     combine_addr(address_and_port, 128, address, port);
206     return tcp_socket_outgoing_spec(address_and_port);
207 }
208
209 int tcp_socket_outgoing_spec(const char *address_and_port)
210 {
211     return inet_connect(address_and_port, true, NULL, NULL);
212 }
213
214 int tcp_socket_incoming(const char *address, uint16_t port)
215 {
216     char address_and_port[128];
217     combine_addr(address_and_port, 128, address, port);
218     return tcp_socket_incoming_spec(address_and_port);
219 }
220
221 int tcp_socket_incoming_spec(const char *address_and_port)
222 {
223     char *ostr  = NULL;
224     int olen = 0;
225     return inet_listen(address_and_port, ostr, olen, SOCK_STREAM, 0, NULL);
226 }
227
228 int unix_socket_incoming(const char *path)
229 {
230     char *ostr = NULL;
231     int olen = 0;
232
233     return unix_listen(path, ostr, olen);
234 }
235
236 int unix_socket_outgoing(const char *path)
237 {
238     return unix_connect(path);
239 }
240
241 /* Basic flow for negotiation
242
243    Server         Client
244    Negotiate
245
246    or
247
248    Server         Client
249    Negotiate #1
250                   Option
251    Negotiate #2
252
253    ----
254
255    followed by
256
257    Server         Client
258                   Request
259    Response
260                   Request
261    Response
262                   ...
263    ...
264                   Request (type == 2)
265
266 */
267
268 static int nbd_receive_options(NBDClient *client)
269 {
270     int csock = client->sock;
271     char name[256];
272     uint32_t tmp, length;
273     uint64_t magic;
274     int rc;
275
276     /* Client sends:
277         [ 0 ..   3]   reserved (0)
278         [ 4 ..  11]   NBD_OPTS_MAGIC
279         [12 ..  15]   NBD_OPT_EXPORT_NAME
280         [16 ..  19]   length
281         [20 ..  xx]   export name (length bytes)
282      */
283
284     rc = -EINVAL;
285     if (read_sync(csock, &tmp, sizeof(tmp)) != sizeof(tmp)) {
286         LOG("read failed");
287         goto fail;
288     }
289     TRACE("Checking reserved");
290     if (tmp != 0) {
291         LOG("Bad reserved received");
292         goto fail;
293     }
294
295     if (read_sync(csock, &magic, sizeof(magic)) != sizeof(magic)) {
296         LOG("read failed");
297         goto fail;
298     }
299     TRACE("Checking reserved");
300     if (magic != be64_to_cpu(NBD_OPTS_MAGIC)) {
301         LOG("Bad magic received");
302         goto fail;
303     }
304
305     if (read_sync(csock, &tmp, sizeof(tmp)) != sizeof(tmp)) {
306         LOG("read failed");
307         goto fail;
308     }
309     TRACE("Checking option");
310     if (tmp != be32_to_cpu(NBD_OPT_EXPORT_NAME)) {
311         LOG("Bad option received");
312         goto fail;
313     }
314
315     if (read_sync(csock, &length, sizeof(length)) != sizeof(length)) {
316         LOG("read failed");
317         goto fail;
318     }
319     TRACE("Checking length");
320     length = be32_to_cpu(length);
321     if (length > 255) {
322         LOG("Bad length received");
323         goto fail;
324     }
325     if (read_sync(csock, name, length) != length) {
326         LOG("read failed");
327         goto fail;
328     }
329     name[length] = '\0';
330
331     client->exp = nbd_export_find(name);
332     if (!client->exp) {
333         LOG("export not found");
334         goto fail;
335     }
336
337     QTAILQ_INSERT_TAIL(&client->exp->clients, client, next);
338     nbd_export_get(client->exp);
339
340     TRACE("Option negotiation succeeded.");
341     rc = 0;
342 fail:
343     return rc;
344 }
345
346 static int nbd_send_negotiate(NBDClient *client)
347 {
348     int csock = client->sock;
349     char buf[8 + 8 + 8 + 128];
350     int rc;
351     const int myflags = (NBD_FLAG_HAS_FLAGS | NBD_FLAG_SEND_TRIM |
352                          NBD_FLAG_SEND_FLUSH | NBD_FLAG_SEND_FUA);
353
354     /* Negotiation header without options:
355         [ 0 ..   7]   passwd       ("NBDMAGIC")
356         [ 8 ..  15]   magic        (NBD_CLIENT_MAGIC)
357         [16 ..  23]   size
358         [24 ..  25]   server flags (0)
359         [24 ..  27]   export flags
360         [28 .. 151]   reserved     (0)
361
362        Negotiation header with options, part 1:
363         [ 0 ..   7]   passwd       ("NBDMAGIC")
364         [ 8 ..  15]   magic        (NBD_OPTS_MAGIC)
365         [16 ..  17]   server flags (0)
366
367        part 2 (after options are sent):
368         [18 ..  25]   size
369         [26 ..  27]   export flags
370         [28 .. 151]   reserved     (0)
371      */
372
373     socket_set_block(csock);
374     rc = -EINVAL;
375
376     TRACE("Beginning negotiation.");
377     memcpy(buf, "NBDMAGIC", 8);
378     if (client->exp) {
379         assert ((client->exp->nbdflags & ~65535) == 0);
380         cpu_to_be64w((uint64_t*)(buf + 8), NBD_CLIENT_MAGIC);
381         cpu_to_be64w((uint64_t*)(buf + 16), client->exp->size);
382         cpu_to_be16w((uint16_t*)(buf + 26), client->exp->nbdflags | myflags);
383     } else {
384         cpu_to_be64w((uint64_t*)(buf + 8), NBD_OPTS_MAGIC);
385     }
386     memset(buf + 28, 0, 124);
387
388     if (client->exp) {
389         if (write_sync(csock, buf, sizeof(buf)) != sizeof(buf)) {
390             LOG("write failed");
391             goto fail;
392         }
393     } else {
394         if (write_sync(csock, buf, 18) != 18) {
395             LOG("write failed");
396             goto fail;
397         }
398         rc = nbd_receive_options(client);
399         if (rc < 0) {
400             LOG("option negotiation failed");
401             goto fail;
402         }
403
404         assert ((client->exp->nbdflags & ~65535) == 0);
405         cpu_to_be64w((uint64_t*)(buf + 18), client->exp->size);
406         cpu_to_be16w((uint16_t*)(buf + 26), client->exp->nbdflags | myflags);
407         if (write_sync(csock, buf + 18, sizeof(buf) - 18) != sizeof(buf) - 18) {
408             LOG("write failed");
409             goto fail;
410         }
411     }
412
413     TRACE("Negotiation succeeded.");
414     rc = 0;
415 fail:
416     socket_set_nonblock(csock);
417     return rc;
418 }
419
420 int nbd_receive_negotiate(int csock, const char *name, uint32_t *flags,
421                           off_t *size, size_t *blocksize)
422 {
423     char buf[256];
424     uint64_t magic, s;
425     uint16_t tmp;
426     int rc;
427
428     TRACE("Receiving negotiation.");
429
430     socket_set_block(csock);
431     rc = -EINVAL;
432
433     if (read_sync(csock, buf, 8) != 8) {
434         LOG("read failed");
435         goto fail;
436     }
437
438     buf[8] = '\0';
439     if (strlen(buf) == 0) {
440         LOG("server connection closed");
441         goto fail;
442     }
443
444     TRACE("Magic is %c%c%c%c%c%c%c%c",
445           qemu_isprint(buf[0]) ? buf[0] : '.',
446           qemu_isprint(buf[1]) ? buf[1] : '.',
447           qemu_isprint(buf[2]) ? buf[2] : '.',
448           qemu_isprint(buf[3]) ? buf[3] : '.',
449           qemu_isprint(buf[4]) ? buf[4] : '.',
450           qemu_isprint(buf[5]) ? buf[5] : '.',
451           qemu_isprint(buf[6]) ? buf[6] : '.',
452           qemu_isprint(buf[7]) ? buf[7] : '.');
453
454     if (memcmp(buf, "NBDMAGIC", 8) != 0) {
455         LOG("Invalid magic received");
456         goto fail;
457     }
458
459     if (read_sync(csock, &magic, sizeof(magic)) != sizeof(magic)) {
460         LOG("read failed");
461         goto fail;
462     }
463     magic = be64_to_cpu(magic);
464     TRACE("Magic is 0x%" PRIx64, magic);
465
466     if (name) {
467         uint32_t reserved = 0;
468         uint32_t opt;
469         uint32_t namesize;
470
471         TRACE("Checking magic (opts_magic)");
472         if (magic != NBD_OPTS_MAGIC) {
473             LOG("Bad magic received");
474             goto fail;
475         }
476         if (read_sync(csock, &tmp, sizeof(tmp)) != sizeof(tmp)) {
477             LOG("flags read failed");
478             goto fail;
479         }
480         *flags = be16_to_cpu(tmp) << 16;
481         /* reserved for future use */
482         if (write_sync(csock, &reserved, sizeof(reserved)) !=
483             sizeof(reserved)) {
484             LOG("write failed (reserved)");
485             goto fail;
486         }
487         /* write the export name */
488         magic = cpu_to_be64(magic);
489         if (write_sync(csock, &magic, sizeof(magic)) != sizeof(magic)) {
490             LOG("write failed (magic)");
491             goto fail;
492         }
493         opt = cpu_to_be32(NBD_OPT_EXPORT_NAME);
494         if (write_sync(csock, &opt, sizeof(opt)) != sizeof(opt)) {
495             LOG("write failed (opt)");
496             goto fail;
497         }
498         namesize = cpu_to_be32(strlen(name));
499         if (write_sync(csock, &namesize, sizeof(namesize)) !=
500             sizeof(namesize)) {
501             LOG("write failed (namesize)");
502             goto fail;
503         }
504         if (write_sync(csock, (char*)name, strlen(name)) != strlen(name)) {
505             LOG("write failed (name)");
506             goto fail;
507         }
508     } else {
509         TRACE("Checking magic (cli_magic)");
510
511         if (magic != NBD_CLIENT_MAGIC) {
512             LOG("Bad magic received");
513             goto fail;
514         }
515     }
516
517     if (read_sync(csock, &s, sizeof(s)) != sizeof(s)) {
518         LOG("read failed");
519         goto fail;
520     }
521     *size = be64_to_cpu(s);
522     *blocksize = 1024;
523     TRACE("Size is %" PRIu64, *size);
524
525     if (!name) {
526         if (read_sync(csock, flags, sizeof(*flags)) != sizeof(*flags)) {
527             LOG("read failed (flags)");
528             goto fail;
529         }
530         *flags = be32_to_cpup(flags);
531     } else {
532         if (read_sync(csock, &tmp, sizeof(tmp)) != sizeof(tmp)) {
533             LOG("read failed (tmp)");
534             goto fail;
535         }
536         *flags |= be32_to_cpu(tmp);
537     }
538     if (read_sync(csock, &buf, 124) != 124) {
539         LOG("read failed (buf)");
540         goto fail;
541     }
542     rc = 0;
543
544 fail:
545     socket_set_nonblock(csock);
546     return rc;
547 }
548
549 #ifdef __linux__
550 int nbd_init(int fd, int csock, uint32_t flags, off_t size, size_t blocksize)
551 {
552     TRACE("Setting NBD socket");
553
554     if (ioctl(fd, NBD_SET_SOCK, csock) < 0) {
555         int serrno = errno;
556         LOG("Failed to set NBD socket");
557         return -serrno;
558     }
559
560     TRACE("Setting block size to %lu", (unsigned long)blocksize);
561
562     if (ioctl(fd, NBD_SET_BLKSIZE, blocksize) < 0) {
563         int serrno = errno;
564         LOG("Failed setting NBD block size");
565         return -serrno;
566     }
567
568         TRACE("Setting size to %zd block(s)", (size_t)(size / blocksize));
569
570     if (ioctl(fd, NBD_SET_SIZE_BLOCKS, size / blocksize) < 0) {
571         int serrno = errno;
572         LOG("Failed setting size (in blocks)");
573         return -serrno;
574     }
575
576     if (flags & NBD_FLAG_READ_ONLY) {
577         int read_only = 1;
578         TRACE("Setting readonly attribute");
579
580         if (ioctl(fd, BLKROSET, (unsigned long) &read_only) < 0) {
581             int serrno = errno;
582             LOG("Failed setting read-only attribute");
583             return -serrno;
584         }
585     }
586
587     if (ioctl(fd, NBD_SET_FLAGS, flags) < 0
588         && errno != ENOTTY) {
589         int serrno = errno;
590         LOG("Failed setting flags");
591         return -serrno;
592     }
593
594     TRACE("Negotiation ended");
595
596     return 0;
597 }
598
599 int nbd_disconnect(int fd)
600 {
601     ioctl(fd, NBD_CLEAR_QUE);
602     ioctl(fd, NBD_DISCONNECT);
603     ioctl(fd, NBD_CLEAR_SOCK);
604     return 0;
605 }
606
607 int nbd_client(int fd)
608 {
609     int ret;
610     int serrno;
611
612     TRACE("Doing NBD loop");
613
614     ret = ioctl(fd, NBD_DO_IT);
615     if (ret < 0 && errno == EPIPE) {
616         /* NBD_DO_IT normally returns EPIPE when someone has disconnected
617          * the socket via NBD_DISCONNECT.  We do not want to return 1 in
618          * that case.
619          */
620         ret = 0;
621     }
622     serrno = errno;
623
624     TRACE("NBD loop returned %d: %s", ret, strerror(serrno));
625
626     TRACE("Clearing NBD queue");
627     ioctl(fd, NBD_CLEAR_QUE);
628
629     TRACE("Clearing NBD socket");
630     ioctl(fd, NBD_CLEAR_SOCK);
631
632     errno = serrno;
633     return ret;
634 }
635 #else
636 int nbd_init(int fd, int csock, uint32_t flags, off_t size, size_t blocksize)
637 {
638     return -ENOTSUP;
639 }
640
641 int nbd_disconnect(int fd)
642 {
643     return -ENOTSUP;
644 }
645
646 int nbd_client(int fd)
647 {
648     return -ENOTSUP;
649 }
650 #endif
651
652 ssize_t nbd_send_request(int csock, struct nbd_request *request)
653 {
654     uint8_t buf[NBD_REQUEST_SIZE];
655     ssize_t ret;
656
657     cpu_to_be32w((uint32_t*)buf, NBD_REQUEST_MAGIC);
658     cpu_to_be32w((uint32_t*)(buf + 4), request->type);
659     cpu_to_be64w((uint64_t*)(buf + 8), request->handle);
660     cpu_to_be64w((uint64_t*)(buf + 16), request->from);
661     cpu_to_be32w((uint32_t*)(buf + 24), request->len);
662
663     TRACE("Sending request to client: "
664           "{ .from = %" PRIu64", .len = %u, .handle = %" PRIu64", .type=%i}",
665           request->from, request->len, request->handle, request->type);
666
667     ret = write_sync(csock, buf, sizeof(buf));
668     if (ret < 0) {
669         return ret;
670     }
671
672     if (ret != sizeof(buf)) {
673         LOG("writing to socket failed");
674         return -EINVAL;
675     }
676     return 0;
677 }
678
679 static ssize_t nbd_receive_request(int csock, struct nbd_request *request)
680 {
681     uint8_t buf[NBD_REQUEST_SIZE];
682     uint32_t magic;
683     ssize_t ret;
684
685     ret = read_sync(csock, buf, sizeof(buf));
686     if (ret < 0) {
687         return ret;
688     }
689
690     if (ret != sizeof(buf)) {
691         LOG("read failed");
692         return -EINVAL;
693     }
694
695     /* Request
696        [ 0 ..  3]   magic   (NBD_REQUEST_MAGIC)
697        [ 4 ..  7]   type    (0 == READ, 1 == WRITE)
698        [ 8 .. 15]   handle
699        [16 .. 23]   from
700        [24 .. 27]   len
701      */
702
703     magic = be32_to_cpup((uint32_t*)buf);
704     request->type  = be32_to_cpup((uint32_t*)(buf + 4));
705     request->handle = be64_to_cpup((uint64_t*)(buf + 8));
706     request->from  = be64_to_cpup((uint64_t*)(buf + 16));
707     request->len   = be32_to_cpup((uint32_t*)(buf + 24));
708
709     TRACE("Got request: "
710           "{ magic = 0x%x, .type = %d, from = %" PRIu64" , len = %u }",
711           magic, request->type, request->from, request->len);
712
713     if (magic != NBD_REQUEST_MAGIC) {
714         LOG("invalid magic (got 0x%x)", magic);
715         return -EINVAL;
716     }
717     return 0;
718 }
719
720 ssize_t nbd_receive_reply(int csock, struct nbd_reply *reply)
721 {
722     uint8_t buf[NBD_REPLY_SIZE];
723     uint32_t magic;
724     ssize_t ret;
725
726     ret = read_sync(csock, buf, sizeof(buf));
727     if (ret < 0) {
728         return ret;
729     }
730
731     if (ret != sizeof(buf)) {
732         LOG("read failed");
733         return -EINVAL;
734     }
735
736     /* Reply
737        [ 0 ..  3]    magic   (NBD_REPLY_MAGIC)
738        [ 4 ..  7]    error   (0 == no error)
739        [ 7 .. 15]    handle
740      */
741
742     magic = be32_to_cpup((uint32_t*)buf);
743     reply->error  = be32_to_cpup((uint32_t*)(buf + 4));
744     reply->handle = be64_to_cpup((uint64_t*)(buf + 8));
745
746     TRACE("Got reply: "
747           "{ magic = 0x%x, .error = %d, handle = %" PRIu64" }",
748           magic, reply->error, reply->handle);
749
750     if (magic != NBD_REPLY_MAGIC) {
751         LOG("invalid magic (got 0x%x)", magic);
752         return -EINVAL;
753     }
754     return 0;
755 }
756
757 static ssize_t nbd_send_reply(int csock, struct nbd_reply *reply)
758 {
759     uint8_t buf[NBD_REPLY_SIZE];
760     ssize_t ret;
761
762     /* Reply
763        [ 0 ..  3]    magic   (NBD_REPLY_MAGIC)
764        [ 4 ..  7]    error   (0 == no error)
765        [ 7 .. 15]    handle
766      */
767     cpu_to_be32w((uint32_t*)buf, NBD_REPLY_MAGIC);
768     cpu_to_be32w((uint32_t*)(buf + 4), reply->error);
769     cpu_to_be64w((uint64_t*)(buf + 8), reply->handle);
770
771     TRACE("Sending response to client");
772
773     ret = write_sync(csock, buf, sizeof(buf));
774     if (ret < 0) {
775         return ret;
776     }
777
778     if (ret != sizeof(buf)) {
779         LOG("writing to socket failed");
780         return -EINVAL;
781     }
782     return 0;
783 }
784
785 #define MAX_NBD_REQUESTS 16
786
787 void nbd_client_get(NBDClient *client)
788 {
789     client->refcount++;
790 }
791
792 void nbd_client_put(NBDClient *client)
793 {
794     if (--client->refcount == 0) {
795         /* The last reference should be dropped by client->close,
796          * which is called by nbd_client_close.
797          */
798         assert(client->closing);
799
800         qemu_set_fd_handler2(client->sock, NULL, NULL, NULL, NULL);
801         close(client->sock);
802         client->sock = -1;
803         if (client->exp) {
804             QTAILQ_REMOVE(&client->exp->clients, client, next);
805             nbd_export_put(client->exp);
806         }
807         g_free(client);
808     }
809 }
810
811 void nbd_client_close(NBDClient *client)
812 {
813     if (client->closing) {
814         return;
815     }
816
817     client->closing = true;
818
819     /* Force requests to finish.  They will drop their own references,
820      * then we'll close the socket and free the NBDClient.
821      */
822     shutdown(client->sock, 2);
823
824     /* Also tell the client, so that they release their reference.  */
825     if (client->close) {
826         client->close(client);
827     }
828 }
829
830 static NBDRequest *nbd_request_get(NBDClient *client)
831 {
832     NBDRequest *req;
833     NBDExport *exp = client->exp;
834
835     assert(client->nb_requests <= MAX_NBD_REQUESTS - 1);
836     client->nb_requests++;
837
838     if (QSIMPLEQ_EMPTY(&exp->requests)) {
839         req = g_malloc0(sizeof(NBDRequest));
840         req->data = qemu_blockalign(exp->bs, NBD_BUFFER_SIZE);
841     } else {
842         req = QSIMPLEQ_FIRST(&exp->requests);
843         QSIMPLEQ_REMOVE_HEAD(&exp->requests, entry);
844     }
845     nbd_client_get(client);
846     req->client = client;
847     return req;
848 }
849
850 static void nbd_request_put(NBDRequest *req)
851 {
852     NBDClient *client = req->client;
853     QSIMPLEQ_INSERT_HEAD(&client->exp->requests, req, entry);
854     if (client->nb_requests-- == MAX_NBD_REQUESTS) {
855         qemu_notify_event();
856     }
857     nbd_client_put(client);
858 }
859
860 NBDExport *nbd_export_new(BlockDriverState *bs, off_t dev_offset,
861                           off_t size, uint32_t nbdflags,
862                           void (*close)(NBDExport *))
863 {
864     NBDExport *exp = g_malloc0(sizeof(NBDExport));
865     QSIMPLEQ_INIT(&exp->requests);
866     exp->refcount = 1;
867     QTAILQ_INIT(&exp->clients);
868     exp->bs = bs;
869     exp->dev_offset = dev_offset;
870     exp->nbdflags = nbdflags;
871     exp->size = size == -1 ? bdrv_getlength(bs) : size;
872     exp->close = close;
873     return exp;
874 }
875
876 NBDExport *nbd_export_find(const char *name)
877 {
878     NBDExport *exp;
879     QTAILQ_FOREACH(exp, &exports, next) {
880         if (strcmp(name, exp->name) == 0) {
881             return exp;
882         }
883     }
884
885     return NULL;
886 }
887
888 void nbd_export_set_name(NBDExport *exp, const char *name)
889 {
890     if (exp->name == name) {
891         return;
892     }
893
894     nbd_export_get(exp);
895     if (exp->name != NULL) {
896         g_free(exp->name);
897         exp->name = NULL;
898         QTAILQ_REMOVE(&exports, exp, next);
899         nbd_export_put(exp);
900     }
901     if (name != NULL) {
902         nbd_export_get(exp);
903         exp->name = g_strdup(name);
904         QTAILQ_INSERT_TAIL(&exports, exp, next);
905     }
906     nbd_export_put(exp);
907 }
908
909 void nbd_export_close(NBDExport *exp)
910 {
911     NBDClient *client, *next;
912
913     nbd_export_get(exp);
914     QTAILQ_FOREACH_SAFE(client, &exp->clients, next, next) {
915         nbd_client_close(client);
916     }
917     nbd_export_set_name(exp, NULL);
918     nbd_export_put(exp);
919 }
920
921 void nbd_export_get(NBDExport *exp)
922 {
923     assert(exp->refcount > 0);
924     exp->refcount++;
925 }
926
927 void nbd_export_put(NBDExport *exp)
928 {
929     assert(exp->refcount > 0);
930     if (exp->refcount == 1) {
931         nbd_export_close(exp);
932     }
933
934     if (--exp->refcount == 0) {
935         assert(exp->name == NULL);
936
937         if (exp->close) {
938             exp->close(exp);
939         }
940
941         while (!QSIMPLEQ_EMPTY(&exp->requests)) {
942             NBDRequest *first = QSIMPLEQ_FIRST(&exp->requests);
943             QSIMPLEQ_REMOVE_HEAD(&exp->requests, entry);
944             qemu_vfree(first->data);
945             g_free(first);
946         }
947
948         g_free(exp);
949     }
950 }
951
952 BlockDriverState *nbd_export_get_blockdev(NBDExport *exp)
953 {
954     return exp->bs;
955 }
956
957 void nbd_export_close_all(void)
958 {
959     NBDExport *exp, *next;
960
961     QTAILQ_FOREACH_SAFE(exp, &exports, next, next) {
962         nbd_export_close(exp);
963     }
964 }
965
966 static int nbd_can_read(void *opaque);
967 static void nbd_read(void *opaque);
968 static void nbd_restart_write(void *opaque);
969
970 static ssize_t nbd_co_send_reply(NBDRequest *req, struct nbd_reply *reply,
971                                  int len)
972 {
973     NBDClient *client = req->client;
974     int csock = client->sock;
975     ssize_t rc, ret;
976
977     qemu_co_mutex_lock(&client->send_lock);
978     qemu_set_fd_handler2(csock, nbd_can_read, nbd_read,
979                          nbd_restart_write, client);
980     client->send_coroutine = qemu_coroutine_self();
981
982     if (!len) {
983         rc = nbd_send_reply(csock, reply);
984     } else {
985         socket_set_cork(csock, 1);
986         rc = nbd_send_reply(csock, reply);
987         if (rc >= 0) {
988             ret = qemu_co_send(csock, req->data, len);
989             if (ret != len) {
990                 rc = -EIO;
991             }
992         }
993         socket_set_cork(csock, 0);
994     }
995
996     client->send_coroutine = NULL;
997     qemu_set_fd_handler2(csock, nbd_can_read, nbd_read, NULL, client);
998     qemu_co_mutex_unlock(&client->send_lock);
999     return rc;
1000 }
1001
1002 static ssize_t nbd_co_receive_request(NBDRequest *req, struct nbd_request *request)
1003 {
1004     NBDClient *client = req->client;
1005     int csock = client->sock;
1006     ssize_t rc;
1007
1008     client->recv_coroutine = qemu_coroutine_self();
1009     rc = nbd_receive_request(csock, request);
1010     if (rc < 0) {
1011         if (rc != -EAGAIN) {
1012             rc = -EIO;
1013         }
1014         goto out;
1015     }
1016
1017     if (request->len > NBD_BUFFER_SIZE) {
1018         LOG("len (%u) is larger than max len (%u)",
1019             request->len, NBD_BUFFER_SIZE);
1020         rc = -EINVAL;
1021         goto out;
1022     }
1023
1024     if ((request->from + request->len) < request->from) {
1025         LOG("integer overflow detected! "
1026             "you're probably being attacked");
1027         rc = -EINVAL;
1028         goto out;
1029     }
1030
1031     TRACE("Decoding type");
1032
1033     if ((request->type & NBD_CMD_MASK_COMMAND) == NBD_CMD_WRITE) {
1034         TRACE("Reading %u byte(s)", request->len);
1035
1036         if (qemu_co_recv(csock, req->data, request->len) != request->len) {
1037             LOG("reading from socket failed");
1038             rc = -EIO;
1039             goto out;
1040         }
1041     }
1042     rc = 0;
1043
1044 out:
1045     client->recv_coroutine = NULL;
1046     return rc;
1047 }
1048
1049 static void nbd_trip(void *opaque)
1050 {
1051     NBDClient *client = opaque;
1052     NBDExport *exp = client->exp;
1053     NBDRequest *req;
1054     struct nbd_request request;
1055     struct nbd_reply reply;
1056     ssize_t ret;
1057
1058     TRACE("Reading request.");
1059     if (client->closing) {
1060         return;
1061     }
1062
1063     req = nbd_request_get(client);
1064     ret = nbd_co_receive_request(req, &request);
1065     if (ret == -EAGAIN) {
1066         goto done;
1067     }
1068     if (ret == -EIO) {
1069         goto out;
1070     }
1071
1072     reply.handle = request.handle;
1073     reply.error = 0;
1074
1075     if (ret < 0) {
1076         reply.error = -ret;
1077         goto error_reply;
1078     }
1079
1080     if ((request.from + request.len) > exp->size) {
1081             LOG("From: %" PRIu64 ", Len: %u, Size: %" PRIu64
1082             ", Offset: %" PRIu64 "\n",
1083                     request.from, request.len,
1084                     (uint64_t)exp->size, (uint64_t)exp->dev_offset);
1085         LOG("requested operation past EOF--bad client?");
1086         goto invalid_request;
1087     }
1088
1089     switch (request.type & NBD_CMD_MASK_COMMAND) {
1090     case NBD_CMD_READ:
1091         TRACE("Request type is READ");
1092
1093         if (request.type & NBD_CMD_FLAG_FUA) {
1094             ret = bdrv_co_flush(exp->bs);
1095             if (ret < 0) {
1096                 LOG("flush failed");
1097                 reply.error = -ret;
1098                 goto error_reply;
1099             }
1100         }
1101
1102         ret = bdrv_read(exp->bs, (request.from + exp->dev_offset) / 512,
1103                         req->data, request.len / 512);
1104         if (ret < 0) {
1105             LOG("reading from file failed");
1106             reply.error = -ret;
1107             goto error_reply;
1108         }
1109
1110         TRACE("Read %u byte(s)", request.len);
1111         if (nbd_co_send_reply(req, &reply, request.len) < 0)
1112             goto out;
1113         break;
1114     case NBD_CMD_WRITE:
1115         TRACE("Request type is WRITE");
1116
1117         if (exp->nbdflags & NBD_FLAG_READ_ONLY) {
1118             TRACE("Server is read-only, return error");
1119             reply.error = EROFS;
1120             goto error_reply;
1121         }
1122
1123         TRACE("Writing to device");
1124
1125         ret = bdrv_write(exp->bs, (request.from + exp->dev_offset) / 512,
1126                          req->data, request.len / 512);
1127         if (ret < 0) {
1128             LOG("writing to file failed");
1129             reply.error = -ret;
1130             goto error_reply;
1131         }
1132
1133         if (request.type & NBD_CMD_FLAG_FUA) {
1134             ret = bdrv_co_flush(exp->bs);
1135             if (ret < 0) {
1136                 LOG("flush failed");
1137                 reply.error = -ret;
1138                 goto error_reply;
1139             }
1140         }
1141
1142         if (nbd_co_send_reply(req, &reply, 0) < 0) {
1143             goto out;
1144         }
1145         break;
1146     case NBD_CMD_DISC:
1147         TRACE("Request type is DISCONNECT");
1148         errno = 0;
1149         goto out;
1150     case NBD_CMD_FLUSH:
1151         TRACE("Request type is FLUSH");
1152
1153         ret = bdrv_co_flush(exp->bs);
1154         if (ret < 0) {
1155             LOG("flush failed");
1156             reply.error = -ret;
1157         }
1158         if (nbd_co_send_reply(req, &reply, 0) < 0) {
1159             goto out;
1160         }
1161         break;
1162     case NBD_CMD_TRIM:
1163         TRACE("Request type is TRIM");
1164         ret = bdrv_co_discard(exp->bs, (request.from + exp->dev_offset) / 512,
1165                               request.len / 512);
1166         if (ret < 0) {
1167             LOG("discard failed");
1168             reply.error = -ret;
1169         }
1170         if (nbd_co_send_reply(req, &reply, 0) < 0) {
1171             goto out;
1172         }
1173         break;
1174     default:
1175         LOG("invalid request type (%u) received", request.type);
1176     invalid_request:
1177         reply.error = -EINVAL;
1178     error_reply:
1179         if (nbd_co_send_reply(req, &reply, 0) < 0) {
1180             goto out;
1181         }
1182         break;
1183     }
1184
1185     TRACE("Request/Reply complete");
1186
1187 done:
1188     nbd_request_put(req);
1189     return;
1190
1191 out:
1192     nbd_request_put(req);
1193     nbd_client_close(client);
1194 }
1195
1196 static int nbd_can_read(void *opaque)
1197 {
1198     NBDClient *client = opaque;
1199
1200     return client->recv_coroutine || client->nb_requests < MAX_NBD_REQUESTS;
1201 }
1202
1203 static void nbd_read(void *opaque)
1204 {
1205     NBDClient *client = opaque;
1206
1207     if (client->recv_coroutine) {
1208         qemu_coroutine_enter(client->recv_coroutine, NULL);
1209     } else {
1210         qemu_coroutine_enter(qemu_coroutine_create(nbd_trip), client);
1211     }
1212 }
1213
1214 static void nbd_restart_write(void *opaque)
1215 {
1216     NBDClient *client = opaque;
1217
1218     qemu_coroutine_enter(client->send_coroutine, NULL);
1219 }
1220
1221 NBDClient *nbd_client_new(NBDExport *exp, int csock,
1222                           void (*close)(NBDClient *))
1223 {
1224     NBDClient *client;
1225     client = g_malloc0(sizeof(NBDClient));
1226     client->refcount = 1;
1227     client->exp = exp;
1228     client->sock = csock;
1229     if (nbd_send_negotiate(client) < 0) {
1230         g_free(client);
1231         return NULL;
1232     }
1233     client->close = close;
1234     qemu_co_mutex_init(&client->send_lock);
1235     qemu_set_fd_handler2(csock, nbd_can_read, nbd_read, NULL, client);
1236
1237     if (exp) {
1238         QTAILQ_INSERT_TAIL(&exp->clients, client, next);
1239         nbd_export_get(exp);
1240     }
1241     return client;
1242 }