nbd: track clients into NBDExport
[sdk/emulator/qemu.git] / nbd.c
1 /*
2  *  Copyright (C) 2005  Anthony Liguori <anthony@codemonkey.ws>
3  *
4  *  Network Block Device
5  *
6  *  This program is free software; you can redistribute it and/or modify
7  *  it under the terms of the GNU General Public License as published by
8  *  the Free Software Foundation; under version 2 of the License.
9  *
10  *  This program is distributed in the hope that it will be useful,
11  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
12  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13  *  GNU General Public License for more details.
14  *
15  *  You should have received a copy of the GNU General Public License
16  *  along with this program; if not, see <http://www.gnu.org/licenses/>.
17  */
18
19 #include "nbd.h"
20 #include "block.h"
21
22 #include "qemu-coroutine.h"
23
24 #include <errno.h>
25 #include <string.h>
26 #ifndef _WIN32
27 #include <sys/ioctl.h>
28 #endif
29 #if defined(__sun__) || defined(__HAIKU__)
30 #include <sys/ioccom.h>
31 #endif
32 #include <ctype.h>
33 #include <inttypes.h>
34
35 #ifdef __linux__
36 #include <linux/fs.h>
37 #endif
38
39 #include "qemu_socket.h"
40 #include "qemu-queue.h"
41
42 //#define DEBUG_NBD
43
44 #ifdef DEBUG_NBD
45 #define TRACE(msg, ...) do { \
46     LOG(msg, ## __VA_ARGS__); \
47 } while(0)
48 #else
49 #define TRACE(msg, ...) \
50     do { } while (0)
51 #endif
52
53 #define LOG(msg, ...) do { \
54     fprintf(stderr, "%s:%s():L%d: " msg "\n", \
55             __FILE__, __FUNCTION__, __LINE__, ## __VA_ARGS__); \
56 } while(0)
57
58 /* This is all part of the "official" NBD API */
59
60 #define NBD_REQUEST_SIZE        (4 + 4 + 8 + 8 + 4)
61 #define NBD_REPLY_SIZE          (4 + 4 + 8)
62 #define NBD_REQUEST_MAGIC       0x25609513
63 #define NBD_REPLY_MAGIC         0x67446698
64 #define NBD_OPTS_MAGIC          0x49484156454F5054LL
65 #define NBD_CLIENT_MAGIC        0x0000420281861253LL
66
67 #define NBD_SET_SOCK            _IO(0xab, 0)
68 #define NBD_SET_BLKSIZE         _IO(0xab, 1)
69 #define NBD_SET_SIZE            _IO(0xab, 2)
70 #define NBD_DO_IT               _IO(0xab, 3)
71 #define NBD_CLEAR_SOCK          _IO(0xab, 4)
72 #define NBD_CLEAR_QUE           _IO(0xab, 5)
73 #define NBD_PRINT_DEBUG         _IO(0xab, 6)
74 #define NBD_SET_SIZE_BLOCKS     _IO(0xab, 7)
75 #define NBD_DISCONNECT          _IO(0xab, 8)
76 #define NBD_SET_TIMEOUT         _IO(0xab, 9)
77 #define NBD_SET_FLAGS           _IO(0xab, 10)
78
79 #define NBD_OPT_EXPORT_NAME     (1 << 0)
80
81 /* Definitions for opaque data types */
82
83 typedef struct NBDRequest NBDRequest;
84
85 struct NBDRequest {
86     QSIMPLEQ_ENTRY(NBDRequest) entry;
87     NBDClient *client;
88     uint8_t *data;
89 };
90
91 struct NBDExport {
92     int refcount;
93     BlockDriverState *bs;
94     off_t dev_offset;
95     off_t size;
96     uint32_t nbdflags;
97     QTAILQ_HEAD(, NBDClient) clients;
98     QSIMPLEQ_HEAD(, NBDRequest) requests;
99 };
100
101 struct NBDClient {
102     int refcount;
103     void (*close)(NBDClient *client);
104
105     NBDExport *exp;
106     int sock;
107
108     Coroutine *recv_coroutine;
109
110     CoMutex send_lock;
111     Coroutine *send_coroutine;
112
113     QTAILQ_ENTRY(NBDClient) next;
114     int nb_requests;
115     bool closing;
116 };
117
118 /* That's all folks */
119
120 ssize_t nbd_wr_sync(int fd, void *buffer, size_t size, bool do_read)
121 {
122     size_t offset = 0;
123     int err;
124
125     if (qemu_in_coroutine()) {
126         if (do_read) {
127             return qemu_co_recv(fd, buffer, size);
128         } else {
129             return qemu_co_send(fd, buffer, size);
130         }
131     }
132
133     while (offset < size) {
134         ssize_t len;
135
136         if (do_read) {
137             len = qemu_recv(fd, buffer + offset, size - offset, 0);
138         } else {
139             len = send(fd, buffer + offset, size - offset, 0);
140         }
141
142         if (len < 0) {
143             err = socket_error();
144
145             /* recoverable error */
146             if (err == EINTR || (offset > 0 && err == EAGAIN)) {
147                 continue;
148             }
149
150             /* unrecoverable error */
151             return -err;
152         }
153
154         /* eof */
155         if (len == 0) {
156             break;
157         }
158
159         offset += len;
160     }
161
162     return offset;
163 }
164
165 static ssize_t read_sync(int fd, void *buffer, size_t size)
166 {
167     /* Sockets are kept in blocking mode in the negotiation phase.  After
168      * that, a non-readable socket simply means that another thread stole
169      * our request/reply.  Synchronization is done with recv_coroutine, so
170      * that this is coroutine-safe.
171      */
172     return nbd_wr_sync(fd, buffer, size, true);
173 }
174
175 static ssize_t write_sync(int fd, void *buffer, size_t size)
176 {
177     int ret;
178     do {
179         /* For writes, we do expect the socket to be writable.  */
180         ret = nbd_wr_sync(fd, buffer, size, false);
181     } while (ret == -EAGAIN);
182     return ret;
183 }
184
185 static void combine_addr(char *buf, size_t len, const char* address,
186                          uint16_t port)
187 {
188     /* If the address-part contains a colon, it's an IPv6 IP so needs [] */
189     if (strstr(address, ":")) {
190         snprintf(buf, len, "[%s]:%u", address, port);
191     } else {
192         snprintf(buf, len, "%s:%u", address, port);
193     }
194 }
195
196 int tcp_socket_outgoing(const char *address, uint16_t port)
197 {
198     char address_and_port[128];
199     combine_addr(address_and_port, 128, address, port);
200     return tcp_socket_outgoing_spec(address_and_port);
201 }
202
203 int tcp_socket_outgoing_spec(const char *address_and_port)
204 {
205     return inet_connect(address_and_port, true, NULL, NULL);
206 }
207
208 int tcp_socket_incoming(const char *address, uint16_t port)
209 {
210     char address_and_port[128];
211     combine_addr(address_and_port, 128, address, port);
212     return tcp_socket_incoming_spec(address_and_port);
213 }
214
215 int tcp_socket_incoming_spec(const char *address_and_port)
216 {
217     char *ostr  = NULL;
218     int olen = 0;
219     return inet_listen(address_and_port, ostr, olen, SOCK_STREAM, 0, NULL);
220 }
221
222 int unix_socket_incoming(const char *path)
223 {
224     char *ostr = NULL;
225     int olen = 0;
226
227     return unix_listen(path, ostr, olen);
228 }
229
230 int unix_socket_outgoing(const char *path)
231 {
232     return unix_connect(path);
233 }
234
235 /* Basic flow
236
237    Server         Client
238
239    Negotiate
240                   Request
241    Response
242                   Request
243    Response
244                   ...
245    ...
246                   Request (type == 2)
247 */
248
249 static int nbd_send_negotiate(NBDClient *client)
250 {
251     int csock = client->sock;
252     char buf[8 + 8 + 8 + 128];
253     int rc;
254
255     /* Negotiate
256         [ 0 ..   7]   passwd   ("NBDMAGIC")
257         [ 8 ..  15]   magic    (NBD_CLIENT_MAGIC)
258         [16 ..  23]   size
259         [24 ..  27]   flags
260         [28 .. 151]   reserved (0)
261      */
262
263     socket_set_block(csock);
264     rc = -EINVAL;
265
266     TRACE("Beginning negotiation.");
267     memcpy(buf, "NBDMAGIC", 8);
268     cpu_to_be64w((uint64_t*)(buf + 8), NBD_CLIENT_MAGIC);
269     cpu_to_be64w((uint64_t*)(buf + 16), client->exp->size);
270     cpu_to_be32w((uint32_t*)(buf + 24),
271                  client->exp->nbdflags | NBD_FLAG_HAS_FLAGS | NBD_FLAG_SEND_TRIM |
272                  NBD_FLAG_SEND_FLUSH | NBD_FLAG_SEND_FUA);
273     memset(buf + 28, 0, 124);
274
275     if (write_sync(csock, buf, sizeof(buf)) != sizeof(buf)) {
276         LOG("write failed");
277         goto fail;
278     }
279
280     TRACE("Negotiation succeeded.");
281     rc = 0;
282 fail:
283     socket_set_nonblock(csock);
284     return rc;
285 }
286
287 int nbd_receive_negotiate(int csock, const char *name, uint32_t *flags,
288                           off_t *size, size_t *blocksize)
289 {
290     char buf[256];
291     uint64_t magic, s;
292     uint16_t tmp;
293     int rc;
294
295     TRACE("Receiving negotiation.");
296
297     socket_set_block(csock);
298     rc = -EINVAL;
299
300     if (read_sync(csock, buf, 8) != 8) {
301         LOG("read failed");
302         goto fail;
303     }
304
305     buf[8] = '\0';
306     if (strlen(buf) == 0) {
307         LOG("server connection closed");
308         goto fail;
309     }
310
311     TRACE("Magic is %c%c%c%c%c%c%c%c",
312           qemu_isprint(buf[0]) ? buf[0] : '.',
313           qemu_isprint(buf[1]) ? buf[1] : '.',
314           qemu_isprint(buf[2]) ? buf[2] : '.',
315           qemu_isprint(buf[3]) ? buf[3] : '.',
316           qemu_isprint(buf[4]) ? buf[4] : '.',
317           qemu_isprint(buf[5]) ? buf[5] : '.',
318           qemu_isprint(buf[6]) ? buf[6] : '.',
319           qemu_isprint(buf[7]) ? buf[7] : '.');
320
321     if (memcmp(buf, "NBDMAGIC", 8) != 0) {
322         LOG("Invalid magic received");
323         goto fail;
324     }
325
326     if (read_sync(csock, &magic, sizeof(magic)) != sizeof(magic)) {
327         LOG("read failed");
328         goto fail;
329     }
330     magic = be64_to_cpu(magic);
331     TRACE("Magic is 0x%" PRIx64, magic);
332
333     if (name) {
334         uint32_t reserved = 0;
335         uint32_t opt;
336         uint32_t namesize;
337
338         TRACE("Checking magic (opts_magic)");
339         if (magic != NBD_OPTS_MAGIC) {
340             LOG("Bad magic received");
341             goto fail;
342         }
343         if (read_sync(csock, &tmp, sizeof(tmp)) != sizeof(tmp)) {
344             LOG("flags read failed");
345             goto fail;
346         }
347         *flags = be16_to_cpu(tmp) << 16;
348         /* reserved for future use */
349         if (write_sync(csock, &reserved, sizeof(reserved)) !=
350             sizeof(reserved)) {
351             LOG("write failed (reserved)");
352             goto fail;
353         }
354         /* write the export name */
355         magic = cpu_to_be64(magic);
356         if (write_sync(csock, &magic, sizeof(magic)) != sizeof(magic)) {
357             LOG("write failed (magic)");
358             goto fail;
359         }
360         opt = cpu_to_be32(NBD_OPT_EXPORT_NAME);
361         if (write_sync(csock, &opt, sizeof(opt)) != sizeof(opt)) {
362             LOG("write failed (opt)");
363             goto fail;
364         }
365         namesize = cpu_to_be32(strlen(name));
366         if (write_sync(csock, &namesize, sizeof(namesize)) !=
367             sizeof(namesize)) {
368             LOG("write failed (namesize)");
369             goto fail;
370         }
371         if (write_sync(csock, (char*)name, strlen(name)) != strlen(name)) {
372             LOG("write failed (name)");
373             goto fail;
374         }
375     } else {
376         TRACE("Checking magic (cli_magic)");
377
378         if (magic != NBD_CLIENT_MAGIC) {
379             LOG("Bad magic received");
380             goto fail;
381         }
382     }
383
384     if (read_sync(csock, &s, sizeof(s)) != sizeof(s)) {
385         LOG("read failed");
386         goto fail;
387     }
388     *size = be64_to_cpu(s);
389     *blocksize = 1024;
390     TRACE("Size is %" PRIu64, *size);
391
392     if (!name) {
393         if (read_sync(csock, flags, sizeof(*flags)) != sizeof(*flags)) {
394             LOG("read failed (flags)");
395             goto fail;
396         }
397         *flags = be32_to_cpup(flags);
398     } else {
399         if (read_sync(csock, &tmp, sizeof(tmp)) != sizeof(tmp)) {
400             LOG("read failed (tmp)");
401             goto fail;
402         }
403         *flags |= be32_to_cpu(tmp);
404     }
405     if (read_sync(csock, &buf, 124) != 124) {
406         LOG("read failed (buf)");
407         goto fail;
408     }
409     rc = 0;
410
411 fail:
412     socket_set_nonblock(csock);
413     return rc;
414 }
415
416 #ifdef __linux__
417 int nbd_init(int fd, int csock, uint32_t flags, off_t size, size_t blocksize)
418 {
419     TRACE("Setting NBD socket");
420
421     if (ioctl(fd, NBD_SET_SOCK, csock) < 0) {
422         int serrno = errno;
423         LOG("Failed to set NBD socket");
424         return -serrno;
425     }
426
427     TRACE("Setting block size to %lu", (unsigned long)blocksize);
428
429     if (ioctl(fd, NBD_SET_BLKSIZE, blocksize) < 0) {
430         int serrno = errno;
431         LOG("Failed setting NBD block size");
432         return -serrno;
433     }
434
435         TRACE("Setting size to %zd block(s)", (size_t)(size / blocksize));
436
437     if (ioctl(fd, NBD_SET_SIZE_BLOCKS, size / blocksize) < 0) {
438         int serrno = errno;
439         LOG("Failed setting size (in blocks)");
440         return -serrno;
441     }
442
443     if (flags & NBD_FLAG_READ_ONLY) {
444         int read_only = 1;
445         TRACE("Setting readonly attribute");
446
447         if (ioctl(fd, BLKROSET, (unsigned long) &read_only) < 0) {
448             int serrno = errno;
449             LOG("Failed setting read-only attribute");
450             return -serrno;
451         }
452     }
453
454     if (ioctl(fd, NBD_SET_FLAGS, flags) < 0
455         && errno != ENOTTY) {
456         int serrno = errno;
457         LOG("Failed setting flags");
458         return -serrno;
459     }
460
461     TRACE("Negotiation ended");
462
463     return 0;
464 }
465
466 int nbd_disconnect(int fd)
467 {
468     ioctl(fd, NBD_CLEAR_QUE);
469     ioctl(fd, NBD_DISCONNECT);
470     ioctl(fd, NBD_CLEAR_SOCK);
471     return 0;
472 }
473
474 int nbd_client(int fd)
475 {
476     int ret;
477     int serrno;
478
479     TRACE("Doing NBD loop");
480
481     ret = ioctl(fd, NBD_DO_IT);
482     if (ret < 0 && errno == EPIPE) {
483         /* NBD_DO_IT normally returns EPIPE when someone has disconnected
484          * the socket via NBD_DISCONNECT.  We do not want to return 1 in
485          * that case.
486          */
487         ret = 0;
488     }
489     serrno = errno;
490
491     TRACE("NBD loop returned %d: %s", ret, strerror(serrno));
492
493     TRACE("Clearing NBD queue");
494     ioctl(fd, NBD_CLEAR_QUE);
495
496     TRACE("Clearing NBD socket");
497     ioctl(fd, NBD_CLEAR_SOCK);
498
499     errno = serrno;
500     return ret;
501 }
502 #else
503 int nbd_init(int fd, int csock, uint32_t flags, off_t size, size_t blocksize)
504 {
505     return -ENOTSUP;
506 }
507
508 int nbd_disconnect(int fd)
509 {
510     return -ENOTSUP;
511 }
512
513 int nbd_client(int fd)
514 {
515     return -ENOTSUP;
516 }
517 #endif
518
519 ssize_t nbd_send_request(int csock, struct nbd_request *request)
520 {
521     uint8_t buf[NBD_REQUEST_SIZE];
522     ssize_t ret;
523
524     cpu_to_be32w((uint32_t*)buf, NBD_REQUEST_MAGIC);
525     cpu_to_be32w((uint32_t*)(buf + 4), request->type);
526     cpu_to_be64w((uint64_t*)(buf + 8), request->handle);
527     cpu_to_be64w((uint64_t*)(buf + 16), request->from);
528     cpu_to_be32w((uint32_t*)(buf + 24), request->len);
529
530     TRACE("Sending request to client: "
531           "{ .from = %" PRIu64", .len = %u, .handle = %" PRIu64", .type=%i}",
532           request->from, request->len, request->handle, request->type);
533
534     ret = write_sync(csock, buf, sizeof(buf));
535     if (ret < 0) {
536         return ret;
537     }
538
539     if (ret != sizeof(buf)) {
540         LOG("writing to socket failed");
541         return -EINVAL;
542     }
543     return 0;
544 }
545
546 static ssize_t nbd_receive_request(int csock, struct nbd_request *request)
547 {
548     uint8_t buf[NBD_REQUEST_SIZE];
549     uint32_t magic;
550     ssize_t ret;
551
552     ret = read_sync(csock, buf, sizeof(buf));
553     if (ret < 0) {
554         return ret;
555     }
556
557     if (ret != sizeof(buf)) {
558         LOG("read failed");
559         return -EINVAL;
560     }
561
562     /* Request
563        [ 0 ..  3]   magic   (NBD_REQUEST_MAGIC)
564        [ 4 ..  7]   type    (0 == READ, 1 == WRITE)
565        [ 8 .. 15]   handle
566        [16 .. 23]   from
567        [24 .. 27]   len
568      */
569
570     magic = be32_to_cpup((uint32_t*)buf);
571     request->type  = be32_to_cpup((uint32_t*)(buf + 4));
572     request->handle = be64_to_cpup((uint64_t*)(buf + 8));
573     request->from  = be64_to_cpup((uint64_t*)(buf + 16));
574     request->len   = be32_to_cpup((uint32_t*)(buf + 24));
575
576     TRACE("Got request: "
577           "{ magic = 0x%x, .type = %d, from = %" PRIu64" , len = %u }",
578           magic, request->type, request->from, request->len);
579
580     if (magic != NBD_REQUEST_MAGIC) {
581         LOG("invalid magic (got 0x%x)", magic);
582         return -EINVAL;
583     }
584     return 0;
585 }
586
587 ssize_t nbd_receive_reply(int csock, struct nbd_reply *reply)
588 {
589     uint8_t buf[NBD_REPLY_SIZE];
590     uint32_t magic;
591     ssize_t ret;
592
593     ret = read_sync(csock, buf, sizeof(buf));
594     if (ret < 0) {
595         return ret;
596     }
597
598     if (ret != sizeof(buf)) {
599         LOG("read failed");
600         return -EINVAL;
601     }
602
603     /* Reply
604        [ 0 ..  3]    magic   (NBD_REPLY_MAGIC)
605        [ 4 ..  7]    error   (0 == no error)
606        [ 7 .. 15]    handle
607      */
608
609     magic = be32_to_cpup((uint32_t*)buf);
610     reply->error  = be32_to_cpup((uint32_t*)(buf + 4));
611     reply->handle = be64_to_cpup((uint64_t*)(buf + 8));
612
613     TRACE("Got reply: "
614           "{ magic = 0x%x, .error = %d, handle = %" PRIu64" }",
615           magic, reply->error, reply->handle);
616
617     if (magic != NBD_REPLY_MAGIC) {
618         LOG("invalid magic (got 0x%x)", magic);
619         return -EINVAL;
620     }
621     return 0;
622 }
623
624 static ssize_t nbd_send_reply(int csock, struct nbd_reply *reply)
625 {
626     uint8_t buf[NBD_REPLY_SIZE];
627     ssize_t ret;
628
629     /* Reply
630        [ 0 ..  3]    magic   (NBD_REPLY_MAGIC)
631        [ 4 ..  7]    error   (0 == no error)
632        [ 7 .. 15]    handle
633      */
634     cpu_to_be32w((uint32_t*)buf, NBD_REPLY_MAGIC);
635     cpu_to_be32w((uint32_t*)(buf + 4), reply->error);
636     cpu_to_be64w((uint64_t*)(buf + 8), reply->handle);
637
638     TRACE("Sending response to client");
639
640     ret = write_sync(csock, buf, sizeof(buf));
641     if (ret < 0) {
642         return ret;
643     }
644
645     if (ret != sizeof(buf)) {
646         LOG("writing to socket failed");
647         return -EINVAL;
648     }
649     return 0;
650 }
651
652 #define MAX_NBD_REQUESTS 16
653
654 void nbd_client_get(NBDClient *client)
655 {
656     client->refcount++;
657 }
658
659 void nbd_client_put(NBDClient *client)
660 {
661     if (--client->refcount == 0) {
662         /* The last reference should be dropped by client->close,
663          * which is called by nbd_client_close.
664          */
665         assert(client->closing);
666
667         qemu_set_fd_handler2(client->sock, NULL, NULL, NULL, NULL);
668         close(client->sock);
669         client->sock = -1;
670         QTAILQ_REMOVE(&client->exp->clients, client, next);
671         nbd_export_put(client->exp);
672         g_free(client);
673     }
674 }
675
676 void nbd_client_close(NBDClient *client)
677 {
678     if (client->closing) {
679         return;
680     }
681
682     client->closing = true;
683
684     /* Force requests to finish.  They will drop their own references,
685      * then we'll close the socket and free the NBDClient.
686      */
687     shutdown(client->sock, 2);
688
689     /* Also tell the client, so that they release their reference.  */
690     if (client->close) {
691         client->close(client);
692     }
693 }
694
695 static NBDRequest *nbd_request_get(NBDClient *client)
696 {
697     NBDRequest *req;
698     NBDExport *exp = client->exp;
699
700     assert(client->nb_requests <= MAX_NBD_REQUESTS - 1);
701     client->nb_requests++;
702
703     if (QSIMPLEQ_EMPTY(&exp->requests)) {
704         req = g_malloc0(sizeof(NBDRequest));
705         req->data = qemu_blockalign(exp->bs, NBD_BUFFER_SIZE);
706     } else {
707         req = QSIMPLEQ_FIRST(&exp->requests);
708         QSIMPLEQ_REMOVE_HEAD(&exp->requests, entry);
709     }
710     nbd_client_get(client);
711     req->client = client;
712     return req;
713 }
714
715 static void nbd_request_put(NBDRequest *req)
716 {
717     NBDClient *client = req->client;
718     QSIMPLEQ_INSERT_HEAD(&client->exp->requests, req, entry);
719     if (client->nb_requests-- == MAX_NBD_REQUESTS) {
720         qemu_notify_event();
721     }
722     nbd_client_put(client);
723 }
724
725 NBDExport *nbd_export_new(BlockDriverState *bs, off_t dev_offset,
726                           off_t size, uint32_t nbdflags)
727 {
728     NBDExport *exp = g_malloc0(sizeof(NBDExport));
729     QSIMPLEQ_INIT(&exp->requests);
730     exp->refcount = 1;
731     QTAILQ_INIT(&exp->clients);
732     exp->bs = bs;
733     exp->dev_offset = dev_offset;
734     exp->nbdflags = nbdflags;
735     exp->size = size == -1 ? bdrv_getlength(bs) : size;
736     return exp;
737 }
738
739 void nbd_export_close(NBDExport *exp)
740 {
741     NBDClient *client, *next;
742
743     nbd_export_get(exp);
744     QTAILQ_FOREACH_SAFE(client, &exp->clients, next, next) {
745         nbd_client_close(client);
746     }
747     nbd_export_put(exp);
748 }
749
750 void nbd_export_get(NBDExport *exp)
751 {
752     assert(exp->refcount > 0);
753     exp->refcount++;
754 }
755
756 void nbd_export_put(NBDExport *exp)
757 {
758     assert(exp->refcount > 0);
759     if (exp->refcount == 1) {
760         nbd_export_close(exp);
761     }
762
763     if (--exp->refcount == 0) {
764         while (!QSIMPLEQ_EMPTY(&exp->requests)) {
765             NBDRequest *first = QSIMPLEQ_FIRST(&exp->requests);
766             QSIMPLEQ_REMOVE_HEAD(&exp->requests, entry);
767             qemu_vfree(first->data);
768             g_free(first);
769         }
770
771         g_free(exp);
772     }
773 }
774
775 static int nbd_can_read(void *opaque);
776 static void nbd_read(void *opaque);
777 static void nbd_restart_write(void *opaque);
778
779 static ssize_t nbd_co_send_reply(NBDRequest *req, struct nbd_reply *reply,
780                                  int len)
781 {
782     NBDClient *client = req->client;
783     int csock = client->sock;
784     ssize_t rc, ret;
785
786     qemu_co_mutex_lock(&client->send_lock);
787     qemu_set_fd_handler2(csock, nbd_can_read, nbd_read,
788                          nbd_restart_write, client);
789     client->send_coroutine = qemu_coroutine_self();
790
791     if (!len) {
792         rc = nbd_send_reply(csock, reply);
793     } else {
794         socket_set_cork(csock, 1);
795         rc = nbd_send_reply(csock, reply);
796         if (rc >= 0) {
797             ret = qemu_co_send(csock, req->data, len);
798             if (ret != len) {
799                 rc = -EIO;
800             }
801         }
802         socket_set_cork(csock, 0);
803     }
804
805     client->send_coroutine = NULL;
806     qemu_set_fd_handler2(csock, nbd_can_read, nbd_read, NULL, client);
807     qemu_co_mutex_unlock(&client->send_lock);
808     return rc;
809 }
810
811 static ssize_t nbd_co_receive_request(NBDRequest *req, struct nbd_request *request)
812 {
813     NBDClient *client = req->client;
814     int csock = client->sock;
815     ssize_t rc;
816
817     client->recv_coroutine = qemu_coroutine_self();
818     rc = nbd_receive_request(csock, request);
819     if (rc < 0) {
820         if (rc != -EAGAIN) {
821             rc = -EIO;
822         }
823         goto out;
824     }
825
826     if (request->len > NBD_BUFFER_SIZE) {
827         LOG("len (%u) is larger than max len (%u)",
828             request->len, NBD_BUFFER_SIZE);
829         rc = -EINVAL;
830         goto out;
831     }
832
833     if ((request->from + request->len) < request->from) {
834         LOG("integer overflow detected! "
835             "you're probably being attacked");
836         rc = -EINVAL;
837         goto out;
838     }
839
840     TRACE("Decoding type");
841
842     if ((request->type & NBD_CMD_MASK_COMMAND) == NBD_CMD_WRITE) {
843         TRACE("Reading %u byte(s)", request->len);
844
845         if (qemu_co_recv(csock, req->data, request->len) != request->len) {
846             LOG("reading from socket failed");
847             rc = -EIO;
848             goto out;
849         }
850     }
851     rc = 0;
852
853 out:
854     client->recv_coroutine = NULL;
855     return rc;
856 }
857
858 static void nbd_trip(void *opaque)
859 {
860     NBDClient *client = opaque;
861     NBDExport *exp = client->exp;
862     NBDRequest *req;
863     struct nbd_request request;
864     struct nbd_reply reply;
865     ssize_t ret;
866
867     TRACE("Reading request.");
868     if (client->closing) {
869         return;
870     }
871
872     req = nbd_request_get(client);
873     ret = nbd_co_receive_request(req, &request);
874     if (ret == -EAGAIN) {
875         goto done;
876     }
877     if (ret == -EIO) {
878         goto out;
879     }
880
881     reply.handle = request.handle;
882     reply.error = 0;
883
884     if (ret < 0) {
885         reply.error = -ret;
886         goto error_reply;
887     }
888
889     if ((request.from + request.len) > exp->size) {
890             LOG("From: %" PRIu64 ", Len: %u, Size: %" PRIu64
891             ", Offset: %" PRIu64 "\n",
892                     request.from, request.len,
893                     (uint64_t)exp->size, (uint64_t)exp->dev_offset);
894         LOG("requested operation past EOF--bad client?");
895         goto invalid_request;
896     }
897
898     switch (request.type & NBD_CMD_MASK_COMMAND) {
899     case NBD_CMD_READ:
900         TRACE("Request type is READ");
901
902         if (request.type & NBD_CMD_FLAG_FUA) {
903             ret = bdrv_co_flush(exp->bs);
904             if (ret < 0) {
905                 LOG("flush failed");
906                 reply.error = -ret;
907                 goto error_reply;
908             }
909         }
910
911         ret = bdrv_read(exp->bs, (request.from + exp->dev_offset) / 512,
912                         req->data, request.len / 512);
913         if (ret < 0) {
914             LOG("reading from file failed");
915             reply.error = -ret;
916             goto error_reply;
917         }
918
919         TRACE("Read %u byte(s)", request.len);
920         if (nbd_co_send_reply(req, &reply, request.len) < 0)
921             goto out;
922         break;
923     case NBD_CMD_WRITE:
924         TRACE("Request type is WRITE");
925
926         if (exp->nbdflags & NBD_FLAG_READ_ONLY) {
927             TRACE("Server is read-only, return error");
928             reply.error = EROFS;
929             goto error_reply;
930         }
931
932         TRACE("Writing to device");
933
934         ret = bdrv_write(exp->bs, (request.from + exp->dev_offset) / 512,
935                          req->data, request.len / 512);
936         if (ret < 0) {
937             LOG("writing to file failed");
938             reply.error = -ret;
939             goto error_reply;
940         }
941
942         if (request.type & NBD_CMD_FLAG_FUA) {
943             ret = bdrv_co_flush(exp->bs);
944             if (ret < 0) {
945                 LOG("flush failed");
946                 reply.error = -ret;
947                 goto error_reply;
948             }
949         }
950
951         if (nbd_co_send_reply(req, &reply, 0) < 0) {
952             goto out;
953         }
954         break;
955     case NBD_CMD_DISC:
956         TRACE("Request type is DISCONNECT");
957         errno = 0;
958         goto out;
959     case NBD_CMD_FLUSH:
960         TRACE("Request type is FLUSH");
961
962         ret = bdrv_co_flush(exp->bs);
963         if (ret < 0) {
964             LOG("flush failed");
965             reply.error = -ret;
966         }
967         if (nbd_co_send_reply(req, &reply, 0) < 0) {
968             goto out;
969         }
970         break;
971     case NBD_CMD_TRIM:
972         TRACE("Request type is TRIM");
973         ret = bdrv_co_discard(exp->bs, (request.from + exp->dev_offset) / 512,
974                               request.len / 512);
975         if (ret < 0) {
976             LOG("discard failed");
977             reply.error = -ret;
978         }
979         if (nbd_co_send_reply(req, &reply, 0) < 0) {
980             goto out;
981         }
982         break;
983     default:
984         LOG("invalid request type (%u) received", request.type);
985     invalid_request:
986         reply.error = -EINVAL;
987     error_reply:
988         if (nbd_co_send_reply(req, &reply, 0) < 0) {
989             goto out;
990         }
991         break;
992     }
993
994     TRACE("Request/Reply complete");
995
996 done:
997     nbd_request_put(req);
998     return;
999
1000 out:
1001     nbd_request_put(req);
1002     nbd_client_close(client);
1003 }
1004
1005 static int nbd_can_read(void *opaque)
1006 {
1007     NBDClient *client = opaque;
1008
1009     return client->recv_coroutine || client->nb_requests < MAX_NBD_REQUESTS;
1010 }
1011
1012 static void nbd_read(void *opaque)
1013 {
1014     NBDClient *client = opaque;
1015
1016     if (client->recv_coroutine) {
1017         qemu_coroutine_enter(client->recv_coroutine, NULL);
1018     } else {
1019         qemu_coroutine_enter(qemu_coroutine_create(nbd_trip), client);
1020     }
1021 }
1022
1023 static void nbd_restart_write(void *opaque)
1024 {
1025     NBDClient *client = opaque;
1026
1027     qemu_coroutine_enter(client->send_coroutine, NULL);
1028 }
1029
1030 NBDClient *nbd_client_new(NBDExport *exp, int csock,
1031                           void (*close)(NBDClient *))
1032 {
1033     NBDClient *client;
1034     client = g_malloc0(sizeof(NBDClient));
1035     client->refcount = 1;
1036     client->exp = exp;
1037     client->sock = csock;
1038     if (nbd_send_negotiate(client) < 0) {
1039         g_free(client);
1040         return NULL;
1041     }
1042     client->close = close;
1043     qemu_co_mutex_init(&client->send_lock);
1044     qemu_set_fd_handler2(csock, nbd_can_read, nbd_read, NULL, client);
1045
1046     QTAILQ_INSERT_TAIL(&exp->clients, client, next);
1047     nbd_export_get(exp);
1048     return client;
1049 }