Merge 'v2.2.0' into tizen_next_qemu_2.2
[sdk/emulator/qemu.git] / nbd.c
1 /*
2  *  Copyright (C) 2005  Anthony Liguori <anthony@codemonkey.ws>
3  *
4  *  Network Block Device
5  *
6  *  This program is free software; you can redistribute it and/or modify
7  *  it under the terms of the GNU General Public License as published by
8  *  the Free Software Foundation; under version 2 of the License.
9  *
10  *  This program is distributed in the hope that it will be useful,
11  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
12  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13  *  GNU General Public License for more details.
14  *
15  *  You should have received a copy of the GNU General Public License
16  *  along with this program; if not, see <http://www.gnu.org/licenses/>.
17  */
18
19 #include "block/nbd.h"
20 #include "block/block.h"
21 #include "block/block_int.h"
22
23 #include "block/coroutine.h"
24
25 #include <errno.h>
26 #include <string.h>
27 #ifndef _WIN32
28 #include <sys/ioctl.h>
29 #endif
30 #if defined(__sun__) || defined(__HAIKU__)
31 #include <sys/ioccom.h>
32 #endif
33 #include <ctype.h>
34 #include <inttypes.h>
35
36 #ifdef __linux__
37 #include <linux/fs.h>
38 #endif
39
40 #include "qemu/sockets.h"
41 #include "qemu/queue.h"
42 #include "qemu/main-loop.h"
43
44 //#define DEBUG_NBD
45
46 #ifdef DEBUG_NBD
47 #define TRACE(msg, ...) do { \
48     LOG(msg, ## __VA_ARGS__); \
49 } while(0)
50 #else
51 #define TRACE(msg, ...) \
52     do { } while (0)
53 #endif
54
55 #define LOG(msg, ...) do { \
56     fprintf(stderr, "%s:%s():L%d: " msg "\n", \
57             __FILE__, __FUNCTION__, __LINE__, ## __VA_ARGS__); \
58 } while(0)
59
60 /* This is all part of the "official" NBD API.
61  *
62  * The most up-to-date documentation is available at:
63  * https://github.com/yoe/nbd/blob/master/doc/proto.txt
64  */
65
66 #define NBD_REQUEST_SIZE        (4 + 4 + 8 + 8 + 4)
67 #define NBD_REPLY_SIZE          (4 + 4 + 8)
68 #define NBD_REQUEST_MAGIC       0x25609513
69 #define NBD_REPLY_MAGIC         0x67446698
70 #define NBD_OPTS_MAGIC          0x49484156454F5054LL
71 #define NBD_CLIENT_MAGIC        0x0000420281861253LL
72 #define NBD_REP_MAGIC           0x3e889045565a9LL
73
74 #define NBD_SET_SOCK            _IO(0xab, 0)
75 #define NBD_SET_BLKSIZE         _IO(0xab, 1)
76 #define NBD_SET_SIZE            _IO(0xab, 2)
77 #define NBD_DO_IT               _IO(0xab, 3)
78 #define NBD_CLEAR_SOCK          _IO(0xab, 4)
79 #define NBD_CLEAR_QUE           _IO(0xab, 5)
80 #define NBD_PRINT_DEBUG         _IO(0xab, 6)
81 #define NBD_SET_SIZE_BLOCKS     _IO(0xab, 7)
82 #define NBD_DISCONNECT          _IO(0xab, 8)
83 #define NBD_SET_TIMEOUT         _IO(0xab, 9)
84 #define NBD_SET_FLAGS           _IO(0xab, 10)
85
86 #define NBD_OPT_EXPORT_NAME     (1)
87 #define NBD_OPT_ABORT           (2)
88 #define NBD_OPT_LIST            (3)
89
90 /* Definitions for opaque data types */
91
92 typedef struct NBDRequest NBDRequest;
93
94 struct NBDRequest {
95     QSIMPLEQ_ENTRY(NBDRequest) entry;
96     NBDClient *client;
97     uint8_t *data;
98 };
99
100 struct NBDExport {
101     int refcount;
102     void (*close)(NBDExport *exp);
103
104     BlockDriverState *bs;
105     char *name;
106     off_t dev_offset;
107     off_t size;
108     uint32_t nbdflags;
109     QTAILQ_HEAD(, NBDClient) clients;
110     QTAILQ_ENTRY(NBDExport) next;
111
112     AioContext *ctx;
113 };
114
115 static QTAILQ_HEAD(, NBDExport) exports = QTAILQ_HEAD_INITIALIZER(exports);
116
117 struct NBDClient {
118     int refcount;
119     void (*close)(NBDClient *client);
120
121     NBDExport *exp;
122     int sock;
123
124     Coroutine *recv_coroutine;
125
126     CoMutex send_lock;
127     Coroutine *send_coroutine;
128
129     bool can_read;
130
131     QTAILQ_ENTRY(NBDClient) next;
132     int nb_requests;
133     bool closing;
134 };
135
136 /* That's all folks */
137
138 static void nbd_set_handlers(NBDClient *client);
139 static void nbd_unset_handlers(NBDClient *client);
140 static void nbd_update_can_read(NBDClient *client);
141
142 ssize_t nbd_wr_sync(int fd, void *buffer, size_t size, bool do_read)
143 {
144     size_t offset = 0;
145     int err;
146
147     if (qemu_in_coroutine()) {
148         if (do_read) {
149             return qemu_co_recv(fd, buffer, size);
150         } else {
151             return qemu_co_send(fd, buffer, size);
152         }
153     }
154
155     while (offset < size) {
156         ssize_t len;
157
158         if (do_read) {
159             len = qemu_recv(fd, buffer + offset, size - offset, 0);
160         } else {
161             len = send(fd, buffer + offset, size - offset, 0);
162         }
163
164         if (len < 0) {
165             err = socket_error();
166
167             /* recoverable error */
168             if (err == EINTR || (offset > 0 && (err == EAGAIN || err == EWOULDBLOCK))) {
169                 continue;
170             }
171
172             /* unrecoverable error */
173             return -err;
174         }
175
176         /* eof */
177         if (len == 0) {
178             break;
179         }
180
181         offset += len;
182     }
183
184     return offset;
185 }
186
187 static ssize_t read_sync(int fd, void *buffer, size_t size)
188 {
189     /* Sockets are kept in blocking mode in the negotiation phase.  After
190      * that, a non-readable socket simply means that another thread stole
191      * our request/reply.  Synchronization is done with recv_coroutine, so
192      * that this is coroutine-safe.
193      */
194     return nbd_wr_sync(fd, buffer, size, true);
195 }
196
197 static ssize_t write_sync(int fd, void *buffer, size_t size)
198 {
199     int ret;
200     do {
201         /* For writes, we do expect the socket to be writable.  */
202         ret = nbd_wr_sync(fd, buffer, size, false);
203     } while (ret == -EAGAIN);
204     return ret;
205 }
206
207 /* Basic flow for negotiation
208
209    Server         Client
210    Negotiate
211
212    or
213
214    Server         Client
215    Negotiate #1
216                   Option
217    Negotiate #2
218
219    ----
220
221    followed by
222
223    Server         Client
224                   Request
225    Response
226                   Request
227    Response
228                   ...
229    ...
230                   Request (type == 2)
231
232 */
233
234 static int nbd_send_rep(int csock, uint32_t type, uint32_t opt)
235 {
236     uint64_t magic;
237     uint32_t len;
238
239     magic = cpu_to_be64(NBD_REP_MAGIC);
240     if (write_sync(csock, &magic, sizeof(magic)) != sizeof(magic)) {
241         LOG("write failed (rep magic)");
242         return -EINVAL;
243     }
244     opt = cpu_to_be32(opt);
245     if (write_sync(csock, &opt, sizeof(opt)) != sizeof(opt)) {
246         LOG("write failed (rep opt)");
247         return -EINVAL;
248     }
249     type = cpu_to_be32(type);
250     if (write_sync(csock, &type, sizeof(type)) != sizeof(type)) {
251         LOG("write failed (rep type)");
252         return -EINVAL;
253     }
254     len = cpu_to_be32(0);
255     if (write_sync(csock, &len, sizeof(len)) != sizeof(len)) {
256         LOG("write failed (rep data length)");
257         return -EINVAL;
258     }
259     return 0;
260 }
261
262 static int nbd_send_rep_list(int csock, NBDExport *exp)
263 {
264     uint64_t magic, name_len;
265     uint32_t opt, type, len;
266
267     name_len = strlen(exp->name);
268     magic = cpu_to_be64(NBD_REP_MAGIC);
269     if (write_sync(csock, &magic, sizeof(magic)) != sizeof(magic)) {
270         LOG("write failed (magic)");
271         return -EINVAL;
272      }
273     opt = cpu_to_be32(NBD_OPT_LIST);
274     if (write_sync(csock, &opt, sizeof(opt)) != sizeof(opt)) {
275         LOG("write failed (opt)");
276         return -EINVAL;
277     }
278     type = cpu_to_be32(NBD_REP_SERVER);
279     if (write_sync(csock, &type, sizeof(type)) != sizeof(type)) {
280         LOG("write failed (reply type)");
281         return -EINVAL;
282     }
283     len = cpu_to_be32(name_len + sizeof(len));
284     if (write_sync(csock, &len, sizeof(len)) != sizeof(len)) {
285         LOG("write failed (length)");
286         return -EINVAL;
287     }
288     len = cpu_to_be32(name_len);
289     if (write_sync(csock, &len, sizeof(len)) != sizeof(len)) {
290         LOG("write failed (length)");
291         return -EINVAL;
292     }
293     if (write_sync(csock, exp->name, name_len) != name_len) {
294         LOG("write failed (buffer)");
295         return -EINVAL;
296     }
297     return 0;
298 }
299
300 static int nbd_handle_list(NBDClient *client, uint32_t length)
301 {
302     int csock;
303     NBDExport *exp;
304
305     csock = client->sock;
306     if (length) {
307         return nbd_send_rep(csock, NBD_REP_ERR_INVALID, NBD_OPT_LIST);
308     }
309
310     /* For each export, send a NBD_REP_SERVER reply. */
311     QTAILQ_FOREACH(exp, &exports, next) {
312         if (nbd_send_rep_list(csock, exp)) {
313             return -EINVAL;
314         }
315     }
316     /* Finish with a NBD_REP_ACK. */
317     return nbd_send_rep(csock, NBD_REP_ACK, NBD_OPT_LIST);
318 }
319
320 static int nbd_handle_export_name(NBDClient *client, uint32_t length)
321 {
322     int rc = -EINVAL, csock = client->sock;
323     char name[256];
324
325     /* Client sends:
326         [20 ..  xx]   export name (length bytes)
327      */
328     TRACE("Checking length");
329     if (length > 255) {
330         LOG("Bad length received");
331         goto fail;
332     }
333     if (read_sync(csock, name, length) != length) {
334         LOG("read failed");
335         goto fail;
336     }
337     name[length] = '\0';
338
339     client->exp = nbd_export_find(name);
340     if (!client->exp) {
341         LOG("export not found");
342         goto fail;
343     }
344
345     QTAILQ_INSERT_TAIL(&client->exp->clients, client, next);
346     nbd_export_get(client->exp);
347     rc = 0;
348 fail:
349     return rc;
350 }
351
352 static int nbd_receive_options(NBDClient *client)
353 {
354     while (1) {
355         int csock = client->sock;
356         uint32_t tmp, length;
357         uint64_t magic;
358
359         /* Client sends:
360             [ 0 ..   3]   client flags
361             [ 4 ..  11]   NBD_OPTS_MAGIC
362             [12 ..  15]   NBD option
363             [16 ..  19]   length
364             ...           Rest of request
365         */
366
367         if (read_sync(csock, &tmp, sizeof(tmp)) != sizeof(tmp)) {
368             LOG("read failed");
369             return -EINVAL;
370         }
371         TRACE("Checking client flags");
372         tmp = be32_to_cpu(tmp);
373         if (tmp != 0 && tmp != NBD_FLAG_C_FIXED_NEWSTYLE) {
374             LOG("Bad client flags received");
375             return -EINVAL;
376         }
377
378         if (read_sync(csock, &magic, sizeof(magic)) != sizeof(magic)) {
379             LOG("read failed");
380             return -EINVAL;
381         }
382         TRACE("Checking opts magic");
383         if (magic != be64_to_cpu(NBD_OPTS_MAGIC)) {
384             LOG("Bad magic received");
385             return -EINVAL;
386         }
387
388         if (read_sync(csock, &tmp, sizeof(tmp)) != sizeof(tmp)) {
389             LOG("read failed");
390             return -EINVAL;
391         }
392
393         if (read_sync(csock, &length, sizeof(length)) != sizeof(length)) {
394             LOG("read failed");
395             return -EINVAL;
396         }
397         length = be32_to_cpu(length);
398
399         TRACE("Checking option");
400         switch (be32_to_cpu(tmp)) {
401         case NBD_OPT_LIST:
402             if (nbd_handle_list(client, length) < 0) {
403                 return 1;
404             }
405             break;
406
407         case NBD_OPT_ABORT:
408             return -EINVAL;
409
410         case NBD_OPT_EXPORT_NAME:
411             return nbd_handle_export_name(client, length);
412
413         default:
414             tmp = be32_to_cpu(tmp);
415             LOG("Unsupported option 0x%x", tmp);
416             nbd_send_rep(client->sock, NBD_REP_ERR_UNSUP, tmp);
417             return -EINVAL;
418         }
419     }
420 }
421
422 static int nbd_send_negotiate(NBDClient *client)
423 {
424     int csock = client->sock;
425     char buf[8 + 8 + 8 + 128];
426     int rc;
427     const int myflags = (NBD_FLAG_HAS_FLAGS | NBD_FLAG_SEND_TRIM |
428                          NBD_FLAG_SEND_FLUSH | NBD_FLAG_SEND_FUA);
429
430     /* Negotiation header without options:
431         [ 0 ..   7]   passwd       ("NBDMAGIC")
432         [ 8 ..  15]   magic        (NBD_CLIENT_MAGIC)
433         [16 ..  23]   size
434         [24 ..  25]   server flags (0)
435         [26 ..  27]   export flags
436         [28 .. 151]   reserved     (0)
437
438        Negotiation header with options, part 1:
439         [ 0 ..   7]   passwd       ("NBDMAGIC")
440         [ 8 ..  15]   magic        (NBD_OPTS_MAGIC)
441         [16 ..  17]   server flags (0)
442
443        part 2 (after options are sent):
444         [18 ..  25]   size
445         [26 ..  27]   export flags
446         [28 .. 151]   reserved     (0)
447      */
448
449     qemu_set_block(csock);
450     rc = -EINVAL;
451
452     TRACE("Beginning negotiation.");
453     memset(buf, 0, sizeof(buf));
454     memcpy(buf, "NBDMAGIC", 8);
455     if (client->exp) {
456         assert ((client->exp->nbdflags & ~65535) == 0);
457         cpu_to_be64w((uint64_t*)(buf + 8), NBD_CLIENT_MAGIC);
458         cpu_to_be64w((uint64_t*)(buf + 16), client->exp->size);
459         cpu_to_be16w((uint16_t*)(buf + 26), client->exp->nbdflags | myflags);
460     } else {
461         cpu_to_be64w((uint64_t*)(buf + 8), NBD_OPTS_MAGIC);
462         cpu_to_be16w((uint16_t *)(buf + 16), NBD_FLAG_FIXED_NEWSTYLE);
463     }
464
465     if (client->exp) {
466         if (write_sync(csock, buf, sizeof(buf)) != sizeof(buf)) {
467             LOG("write failed");
468             goto fail;
469         }
470     } else {
471         if (write_sync(csock, buf, 18) != 18) {
472             LOG("write failed");
473             goto fail;
474         }
475         rc = nbd_receive_options(client);
476         if (rc != 0) {
477             LOG("option negotiation failed");
478             goto fail;
479         }
480
481         assert ((client->exp->nbdflags & ~65535) == 0);
482         cpu_to_be64w((uint64_t*)(buf + 18), client->exp->size);
483         cpu_to_be16w((uint16_t*)(buf + 26), client->exp->nbdflags | myflags);
484         if (write_sync(csock, buf + 18, sizeof(buf) - 18) != sizeof(buf) - 18) {
485             LOG("write failed");
486             goto fail;
487         }
488     }
489
490     TRACE("Negotiation succeeded.");
491     rc = 0;
492 fail:
493     qemu_set_nonblock(csock);
494     return rc;
495 }
496
497 int nbd_receive_negotiate(int csock, const char *name, uint32_t *flags,
498                           off_t *size, size_t *blocksize)
499 {
500     char buf[256];
501     uint64_t magic, s;
502     uint16_t tmp;
503     int rc;
504
505     TRACE("Receiving negotiation.");
506
507     rc = -EINVAL;
508
509     if (read_sync(csock, buf, 8) != 8) {
510         LOG("read failed");
511         goto fail;
512     }
513
514     buf[8] = '\0';
515     if (strlen(buf) == 0) {
516         LOG("server connection closed");
517         goto fail;
518     }
519
520     TRACE("Magic is %c%c%c%c%c%c%c%c",
521           qemu_isprint(buf[0]) ? buf[0] : '.',
522           qemu_isprint(buf[1]) ? buf[1] : '.',
523           qemu_isprint(buf[2]) ? buf[2] : '.',
524           qemu_isprint(buf[3]) ? buf[3] : '.',
525           qemu_isprint(buf[4]) ? buf[4] : '.',
526           qemu_isprint(buf[5]) ? buf[5] : '.',
527           qemu_isprint(buf[6]) ? buf[6] : '.',
528           qemu_isprint(buf[7]) ? buf[7] : '.');
529
530     if (memcmp(buf, "NBDMAGIC", 8) != 0) {
531         LOG("Invalid magic received");
532         goto fail;
533     }
534
535     if (read_sync(csock, &magic, sizeof(magic)) != sizeof(magic)) {
536         LOG("read failed");
537         goto fail;
538     }
539     magic = be64_to_cpu(magic);
540     TRACE("Magic is 0x%" PRIx64, magic);
541
542     if (name) {
543         uint32_t reserved = 0;
544         uint32_t opt;
545         uint32_t namesize;
546
547         TRACE("Checking magic (opts_magic)");
548         if (magic != NBD_OPTS_MAGIC) {
549             LOG("Bad magic received");
550             goto fail;
551         }
552         if (read_sync(csock, &tmp, sizeof(tmp)) != sizeof(tmp)) {
553             LOG("flags read failed");
554             goto fail;
555         }
556         *flags = be16_to_cpu(tmp) << 16;
557         /* reserved for future use */
558         if (write_sync(csock, &reserved, sizeof(reserved)) !=
559             sizeof(reserved)) {
560             LOG("write failed (reserved)");
561             goto fail;
562         }
563         /* write the export name */
564         magic = cpu_to_be64(magic);
565         if (write_sync(csock, &magic, sizeof(magic)) != sizeof(magic)) {
566             LOG("write failed (magic)");
567             goto fail;
568         }
569         opt = cpu_to_be32(NBD_OPT_EXPORT_NAME);
570         if (write_sync(csock, &opt, sizeof(opt)) != sizeof(opt)) {
571             LOG("write failed (opt)");
572             goto fail;
573         }
574         namesize = cpu_to_be32(strlen(name));
575         if (write_sync(csock, &namesize, sizeof(namesize)) !=
576             sizeof(namesize)) {
577             LOG("write failed (namesize)");
578             goto fail;
579         }
580         if (write_sync(csock, (char*)name, strlen(name)) != strlen(name)) {
581             LOG("write failed (name)");
582             goto fail;
583         }
584     } else {
585         TRACE("Checking magic (cli_magic)");
586
587         if (magic != NBD_CLIENT_MAGIC) {
588             LOG("Bad magic received");
589             goto fail;
590         }
591     }
592
593     if (read_sync(csock, &s, sizeof(s)) != sizeof(s)) {
594         LOG("read failed");
595         goto fail;
596     }
597     *size = be64_to_cpu(s);
598     *blocksize = 1024;
599     TRACE("Size is %" PRIu64, *size);
600
601     if (!name) {
602         if (read_sync(csock, flags, sizeof(*flags)) != sizeof(*flags)) {
603             LOG("read failed (flags)");
604             goto fail;
605         }
606         *flags = be32_to_cpup(flags);
607     } else {
608         if (read_sync(csock, &tmp, sizeof(tmp)) != sizeof(tmp)) {
609             LOG("read failed (tmp)");
610             goto fail;
611         }
612         *flags |= be32_to_cpu(tmp);
613     }
614     if (read_sync(csock, &buf, 124) != 124) {
615         LOG("read failed (buf)");
616         goto fail;
617     }
618     rc = 0;
619
620 fail:
621     return rc;
622 }
623
624 #ifdef __linux__
625 int nbd_init(int fd, int csock, uint32_t flags, off_t size, size_t blocksize)
626 {
627     TRACE("Setting NBD socket");
628
629     if (ioctl(fd, NBD_SET_SOCK, csock) < 0) {
630         int serrno = errno;
631         LOG("Failed to set NBD socket");
632         return -serrno;
633     }
634
635     TRACE("Setting block size to %lu", (unsigned long)blocksize);
636
637     if (ioctl(fd, NBD_SET_BLKSIZE, blocksize) < 0) {
638         int serrno = errno;
639         LOG("Failed setting NBD block size");
640         return -serrno;
641     }
642
643         TRACE("Setting size to %zd block(s)", (size_t)(size / blocksize));
644
645     if (ioctl(fd, NBD_SET_SIZE_BLOCKS, size / blocksize) < 0) {
646         int serrno = errno;
647         LOG("Failed setting size (in blocks)");
648         return -serrno;
649     }
650
651     if (ioctl(fd, NBD_SET_FLAGS, flags) < 0) {
652         if (errno == ENOTTY) {
653             int read_only = (flags & NBD_FLAG_READ_ONLY) != 0;
654             TRACE("Setting readonly attribute");
655
656             if (ioctl(fd, BLKROSET, (unsigned long) &read_only) < 0) {
657                 int serrno = errno;
658                 LOG("Failed setting read-only attribute");
659                 return -serrno;
660             }
661         } else {
662             int serrno = errno;
663             LOG("Failed setting flags");
664             return -serrno;
665         }
666     }
667
668     TRACE("Negotiation ended");
669
670     return 0;
671 }
672
673 int nbd_disconnect(int fd)
674 {
675     ioctl(fd, NBD_CLEAR_QUE);
676     ioctl(fd, NBD_DISCONNECT);
677     ioctl(fd, NBD_CLEAR_SOCK);
678     return 0;
679 }
680
681 int nbd_client(int fd)
682 {
683     int ret;
684     int serrno;
685
686     TRACE("Doing NBD loop");
687
688     ret = ioctl(fd, NBD_DO_IT);
689     if (ret < 0 && errno == EPIPE) {
690         /* NBD_DO_IT normally returns EPIPE when someone has disconnected
691          * the socket via NBD_DISCONNECT.  We do not want to return 1 in
692          * that case.
693          */
694         ret = 0;
695     }
696     serrno = errno;
697
698     TRACE("NBD loop returned %d: %s", ret, strerror(serrno));
699
700     TRACE("Clearing NBD queue");
701     ioctl(fd, NBD_CLEAR_QUE);
702
703     TRACE("Clearing NBD socket");
704     ioctl(fd, NBD_CLEAR_SOCK);
705
706     errno = serrno;
707     return ret;
708 }
709 #else
710 int nbd_init(int fd, int csock, uint32_t flags, off_t size, size_t blocksize)
711 {
712     return -ENOTSUP;
713 }
714
715 int nbd_disconnect(int fd)
716 {
717     return -ENOTSUP;
718 }
719
720 int nbd_client(int fd)
721 {
722     return -ENOTSUP;
723 }
724 #endif
725
726 ssize_t nbd_send_request(int csock, struct nbd_request *request)
727 {
728     uint8_t buf[NBD_REQUEST_SIZE];
729     ssize_t ret;
730
731     cpu_to_be32w((uint32_t*)buf, NBD_REQUEST_MAGIC);
732     cpu_to_be32w((uint32_t*)(buf + 4), request->type);
733     cpu_to_be64w((uint64_t*)(buf + 8), request->handle);
734     cpu_to_be64w((uint64_t*)(buf + 16), request->from);
735     cpu_to_be32w((uint32_t*)(buf + 24), request->len);
736
737     TRACE("Sending request to client: "
738           "{ .from = %" PRIu64", .len = %u, .handle = %" PRIu64", .type=%i}",
739           request->from, request->len, request->handle, request->type);
740
741     ret = write_sync(csock, buf, sizeof(buf));
742     if (ret < 0) {
743         return ret;
744     }
745
746     if (ret != sizeof(buf)) {
747         LOG("writing to socket failed");
748         return -EINVAL;
749     }
750     return 0;
751 }
752
753 static ssize_t nbd_receive_request(int csock, struct nbd_request *request)
754 {
755     uint8_t buf[NBD_REQUEST_SIZE];
756     uint32_t magic;
757     ssize_t ret;
758
759     ret = read_sync(csock, buf, sizeof(buf));
760     if (ret < 0) {
761         return ret;
762     }
763
764     if (ret != sizeof(buf)) {
765         LOG("read failed");
766         return -EINVAL;
767     }
768
769     /* Request
770        [ 0 ..  3]   magic   (NBD_REQUEST_MAGIC)
771        [ 4 ..  7]   type    (0 == READ, 1 == WRITE)
772        [ 8 .. 15]   handle
773        [16 .. 23]   from
774        [24 .. 27]   len
775      */
776
777     magic = be32_to_cpup((uint32_t*)buf);
778     request->type  = be32_to_cpup((uint32_t*)(buf + 4));
779     request->handle = be64_to_cpup((uint64_t*)(buf + 8));
780     request->from  = be64_to_cpup((uint64_t*)(buf + 16));
781     request->len   = be32_to_cpup((uint32_t*)(buf + 24));
782
783     TRACE("Got request: "
784           "{ magic = 0x%x, .type = %d, from = %" PRIu64" , len = %u }",
785           magic, request->type, request->from, request->len);
786
787     if (magic != NBD_REQUEST_MAGIC) {
788         LOG("invalid magic (got 0x%x)", magic);
789         return -EINVAL;
790     }
791     return 0;
792 }
793
794 ssize_t nbd_receive_reply(int csock, struct nbd_reply *reply)
795 {
796     uint8_t buf[NBD_REPLY_SIZE];
797     uint32_t magic;
798     ssize_t ret;
799
800     ret = read_sync(csock, buf, sizeof(buf));
801     if (ret < 0) {
802         return ret;
803     }
804
805     if (ret != sizeof(buf)) {
806         LOG("read failed");
807         return -EINVAL;
808     }
809
810     /* Reply
811        [ 0 ..  3]    magic   (NBD_REPLY_MAGIC)
812        [ 4 ..  7]    error   (0 == no error)
813        [ 7 .. 15]    handle
814      */
815
816     magic = be32_to_cpup((uint32_t*)buf);
817     reply->error  = be32_to_cpup((uint32_t*)(buf + 4));
818     reply->handle = be64_to_cpup((uint64_t*)(buf + 8));
819
820     TRACE("Got reply: "
821           "{ magic = 0x%x, .error = %d, handle = %" PRIu64" }",
822           magic, reply->error, reply->handle);
823
824     if (magic != NBD_REPLY_MAGIC) {
825         LOG("invalid magic (got 0x%x)", magic);
826         return -EINVAL;
827     }
828     return 0;
829 }
830
831 static ssize_t nbd_send_reply(int csock, struct nbd_reply *reply)
832 {
833     uint8_t buf[NBD_REPLY_SIZE];
834     ssize_t ret;
835
836     /* Reply
837        [ 0 ..  3]    magic   (NBD_REPLY_MAGIC)
838        [ 4 ..  7]    error   (0 == no error)
839        [ 7 .. 15]    handle
840      */
841     cpu_to_be32w((uint32_t*)buf, NBD_REPLY_MAGIC);
842     cpu_to_be32w((uint32_t*)(buf + 4), reply->error);
843     cpu_to_be64w((uint64_t*)(buf + 8), reply->handle);
844
845     TRACE("Sending response to client");
846
847     ret = write_sync(csock, buf, sizeof(buf));
848     if (ret < 0) {
849         return ret;
850     }
851
852     if (ret != sizeof(buf)) {
853         LOG("writing to socket failed");
854         return -EINVAL;
855     }
856     return 0;
857 }
858
859 #define MAX_NBD_REQUESTS 16
860
861 void nbd_client_get(NBDClient *client)
862 {
863     client->refcount++;
864 }
865
866 void nbd_client_put(NBDClient *client)
867 {
868     if (--client->refcount == 0) {
869         /* The last reference should be dropped by client->close,
870          * which is called by nbd_client_close.
871          */
872         assert(client->closing);
873
874         nbd_unset_handlers(client);
875         close(client->sock);
876         client->sock = -1;
877         if (client->exp) {
878             QTAILQ_REMOVE(&client->exp->clients, client, next);
879             nbd_export_put(client->exp);
880         }
881         g_free(client);
882     }
883 }
884
885 void nbd_client_close(NBDClient *client)
886 {
887     if (client->closing) {
888         return;
889     }
890
891     client->closing = true;
892
893     /* Force requests to finish.  They will drop their own references,
894      * then we'll close the socket and free the NBDClient.
895      */
896     shutdown(client->sock, 2);
897
898     /* Also tell the client, so that they release their reference.  */
899     if (client->close) {
900         client->close(client);
901     }
902 }
903
904 static NBDRequest *nbd_request_get(NBDClient *client)
905 {
906     NBDRequest *req;
907
908     assert(client->nb_requests <= MAX_NBD_REQUESTS - 1);
909     client->nb_requests++;
910     nbd_update_can_read(client);
911
912     req = g_slice_new0(NBDRequest);
913     nbd_client_get(client);
914     req->client = client;
915     return req;
916 }
917
918 static void nbd_request_put(NBDRequest *req)
919 {
920     NBDClient *client = req->client;
921
922     if (req->data) {
923         qemu_vfree(req->data);
924     }
925     g_slice_free(NBDRequest, req);
926
927     client->nb_requests--;
928     nbd_update_can_read(client);
929     nbd_client_put(client);
930 }
931
932 static void bs_aio_attached(AioContext *ctx, void *opaque)
933 {
934     NBDExport *exp = opaque;
935     NBDClient *client;
936
937     TRACE("Export %s: Attaching clients to AIO context %p\n", exp->name, ctx);
938
939     exp->ctx = ctx;
940
941     QTAILQ_FOREACH(client, &exp->clients, next) {
942         nbd_set_handlers(client);
943     }
944 }
945
946 static void bs_aio_detach(void *opaque)
947 {
948     NBDExport *exp = opaque;
949     NBDClient *client;
950
951     TRACE("Export %s: Detaching clients from AIO context %p\n", exp->name, exp->ctx);
952
953     QTAILQ_FOREACH(client, &exp->clients, next) {
954         nbd_unset_handlers(client);
955     }
956
957     exp->ctx = NULL;
958 }
959
960 NBDExport *nbd_export_new(BlockDriverState *bs, off_t dev_offset,
961                           off_t size, uint32_t nbdflags,
962                           void (*close)(NBDExport *))
963 {
964     NBDExport *exp = g_malloc0(sizeof(NBDExport));
965     exp->refcount = 1;
966     QTAILQ_INIT(&exp->clients);
967     exp->bs = bs;
968     exp->dev_offset = dev_offset;
969     exp->nbdflags = nbdflags;
970     exp->size = size == -1 ? bdrv_getlength(bs) : size;
971     exp->close = close;
972     exp->ctx = bdrv_get_aio_context(bs);
973     bdrv_ref(bs);
974     bdrv_add_aio_context_notifier(bs, bs_aio_attached, bs_aio_detach, exp);
975     /*
976      * NBD exports are used for non-shared storage migration.  Make sure
977      * that BDRV_O_INCOMING is cleared and the image is ready for write
978      * access since the export could be available before migration handover.
979      */
980     bdrv_invalidate_cache(bs, NULL);
981     return exp;
982 }
983
984 NBDExport *nbd_export_find(const char *name)
985 {
986     NBDExport *exp;
987     QTAILQ_FOREACH(exp, &exports, next) {
988         if (strcmp(name, exp->name) == 0) {
989             return exp;
990         }
991     }
992
993     return NULL;
994 }
995
996 void nbd_export_set_name(NBDExport *exp, const char *name)
997 {
998     if (exp->name == name) {
999         return;
1000     }
1001
1002     nbd_export_get(exp);
1003     if (exp->name != NULL) {
1004         g_free(exp->name);
1005         exp->name = NULL;
1006         QTAILQ_REMOVE(&exports, exp, next);
1007         nbd_export_put(exp);
1008     }
1009     if (name != NULL) {
1010         nbd_export_get(exp);
1011         exp->name = g_strdup(name);
1012         QTAILQ_INSERT_TAIL(&exports, exp, next);
1013     }
1014     nbd_export_put(exp);
1015 }
1016
1017 void nbd_export_close(NBDExport *exp)
1018 {
1019     NBDClient *client, *next;
1020
1021     nbd_export_get(exp);
1022     QTAILQ_FOREACH_SAFE(client, &exp->clients, next, next) {
1023         nbd_client_close(client);
1024     }
1025     nbd_export_set_name(exp, NULL);
1026     nbd_export_put(exp);
1027     if (exp->bs) {
1028         bdrv_remove_aio_context_notifier(exp->bs, bs_aio_attached,
1029                                          bs_aio_detach, exp);
1030         bdrv_unref(exp->bs);
1031         exp->bs = NULL;
1032     }
1033 }
1034
1035 void nbd_export_get(NBDExport *exp)
1036 {
1037     assert(exp->refcount > 0);
1038     exp->refcount++;
1039 }
1040
1041 void nbd_export_put(NBDExport *exp)
1042 {
1043     assert(exp->refcount > 0);
1044     if (exp->refcount == 1) {
1045         nbd_export_close(exp);
1046     }
1047
1048     if (--exp->refcount == 0) {
1049         assert(exp->name == NULL);
1050
1051         if (exp->close) {
1052             exp->close(exp);
1053         }
1054
1055         g_free(exp);
1056     }
1057 }
1058
1059 BlockDriverState *nbd_export_get_blockdev(NBDExport *exp)
1060 {
1061     return exp->bs;
1062 }
1063
1064 void nbd_export_close_all(void)
1065 {
1066     NBDExport *exp, *next;
1067
1068     QTAILQ_FOREACH_SAFE(exp, &exports, next, next) {
1069         nbd_export_close(exp);
1070     }
1071 }
1072
1073 static ssize_t nbd_co_send_reply(NBDRequest *req, struct nbd_reply *reply,
1074                                  int len)
1075 {
1076     NBDClient *client = req->client;
1077     int csock = client->sock;
1078     ssize_t rc, ret;
1079
1080     qemu_co_mutex_lock(&client->send_lock);
1081     client->send_coroutine = qemu_coroutine_self();
1082     nbd_set_handlers(client);
1083
1084     if (!len) {
1085         rc = nbd_send_reply(csock, reply);
1086     } else {
1087         socket_set_cork(csock, 1);
1088         rc = nbd_send_reply(csock, reply);
1089         if (rc >= 0) {
1090             ret = qemu_co_send(csock, req->data, len);
1091             if (ret != len) {
1092                 rc = -EIO;
1093             }
1094         }
1095         socket_set_cork(csock, 0);
1096     }
1097
1098     client->send_coroutine = NULL;
1099     nbd_set_handlers(client);
1100     qemu_co_mutex_unlock(&client->send_lock);
1101     return rc;
1102 }
1103
1104 static ssize_t nbd_co_receive_request(NBDRequest *req, struct nbd_request *request)
1105 {
1106     NBDClient *client = req->client;
1107     int csock = client->sock;
1108     uint32_t command;
1109     ssize_t rc;
1110
1111     client->recv_coroutine = qemu_coroutine_self();
1112     nbd_update_can_read(client);
1113
1114     rc = nbd_receive_request(csock, request);
1115     if (rc < 0) {
1116         if (rc != -EAGAIN) {
1117             rc = -EIO;
1118         }
1119         goto out;
1120     }
1121
1122     if (request->len > NBD_MAX_BUFFER_SIZE) {
1123         LOG("len (%u) is larger than max len (%u)",
1124             request->len, NBD_MAX_BUFFER_SIZE);
1125         rc = -EINVAL;
1126         goto out;
1127     }
1128
1129     if ((request->from + request->len) < request->from) {
1130         LOG("integer overflow detected! "
1131             "you're probably being attacked");
1132         rc = -EINVAL;
1133         goto out;
1134     }
1135
1136     TRACE("Decoding type");
1137
1138     command = request->type & NBD_CMD_MASK_COMMAND;
1139     if (command == NBD_CMD_READ || command == NBD_CMD_WRITE) {
1140         req->data = qemu_blockalign(client->exp->bs, request->len);
1141     }
1142     if (command == NBD_CMD_WRITE) {
1143         TRACE("Reading %u byte(s)", request->len);
1144
1145         if (qemu_co_recv(csock, req->data, request->len) != request->len) {
1146             LOG("reading from socket failed");
1147             rc = -EIO;
1148             goto out;
1149         }
1150     }
1151     rc = 0;
1152
1153 out:
1154     client->recv_coroutine = NULL;
1155     nbd_update_can_read(client);
1156
1157     return rc;
1158 }
1159
1160 static void nbd_trip(void *opaque)
1161 {
1162     NBDClient *client = opaque;
1163     NBDExport *exp = client->exp;
1164     NBDRequest *req;
1165     struct nbd_request request;
1166     struct nbd_reply reply;
1167     ssize_t ret;
1168     uint32_t command;
1169
1170     TRACE("Reading request.");
1171     if (client->closing) {
1172         return;
1173     }
1174
1175     req = nbd_request_get(client);
1176     ret = nbd_co_receive_request(req, &request);
1177     if (ret == -EAGAIN) {
1178         goto done;
1179     }
1180     if (ret == -EIO) {
1181         goto out;
1182     }
1183
1184     reply.handle = request.handle;
1185     reply.error = 0;
1186
1187     if (ret < 0) {
1188         reply.error = -ret;
1189         goto error_reply;
1190     }
1191     command = request.type & NBD_CMD_MASK_COMMAND;
1192     if (command != NBD_CMD_DISC && (request.from + request.len) > exp->size) {
1193             LOG("From: %" PRIu64 ", Len: %u, Size: %" PRIu64
1194             ", Offset: %" PRIu64 "\n",
1195                     request.from, request.len,
1196                     (uint64_t)exp->size, (uint64_t)exp->dev_offset);
1197         LOG("requested operation past EOF--bad client?");
1198         goto invalid_request;
1199     }
1200
1201     switch (command) {
1202     case NBD_CMD_READ:
1203         TRACE("Request type is READ");
1204
1205         if (request.type & NBD_CMD_FLAG_FUA) {
1206             ret = bdrv_co_flush(exp->bs);
1207             if (ret < 0) {
1208                 LOG("flush failed");
1209                 reply.error = -ret;
1210                 goto error_reply;
1211             }
1212         }
1213
1214         ret = bdrv_read(exp->bs, (request.from + exp->dev_offset) / 512,
1215                         req->data, request.len / 512);
1216         if (ret < 0) {
1217             LOG("reading from file failed");
1218             reply.error = -ret;
1219             goto error_reply;
1220         }
1221
1222         TRACE("Read %u byte(s)", request.len);
1223         if (nbd_co_send_reply(req, &reply, request.len) < 0)
1224             goto out;
1225         break;
1226     case NBD_CMD_WRITE:
1227         TRACE("Request type is WRITE");
1228
1229         if (exp->nbdflags & NBD_FLAG_READ_ONLY) {
1230             TRACE("Server is read-only, return error");
1231             reply.error = EROFS;
1232             goto error_reply;
1233         }
1234
1235         TRACE("Writing to device");
1236
1237         ret = bdrv_write(exp->bs, (request.from + exp->dev_offset) / 512,
1238                          req->data, request.len / 512);
1239         if (ret < 0) {
1240             LOG("writing to file failed");
1241             reply.error = -ret;
1242             goto error_reply;
1243         }
1244
1245         if (request.type & NBD_CMD_FLAG_FUA) {
1246             ret = bdrv_co_flush(exp->bs);
1247             if (ret < 0) {
1248                 LOG("flush failed");
1249                 reply.error = -ret;
1250                 goto error_reply;
1251             }
1252         }
1253
1254         if (nbd_co_send_reply(req, &reply, 0) < 0) {
1255             goto out;
1256         }
1257         break;
1258     case NBD_CMD_DISC:
1259         TRACE("Request type is DISCONNECT");
1260         errno = 0;
1261         goto out;
1262     case NBD_CMD_FLUSH:
1263         TRACE("Request type is FLUSH");
1264
1265         ret = bdrv_co_flush(exp->bs);
1266         if (ret < 0) {
1267             LOG("flush failed");
1268             reply.error = -ret;
1269         }
1270         if (nbd_co_send_reply(req, &reply, 0) < 0) {
1271             goto out;
1272         }
1273         break;
1274     case NBD_CMD_TRIM:
1275         TRACE("Request type is TRIM");
1276         ret = bdrv_co_discard(exp->bs, (request.from + exp->dev_offset) / 512,
1277                               request.len / 512);
1278         if (ret < 0) {
1279             LOG("discard failed");
1280             reply.error = -ret;
1281         }
1282         if (nbd_co_send_reply(req, &reply, 0) < 0) {
1283             goto out;
1284         }
1285         break;
1286     default:
1287         LOG("invalid request type (%u) received", request.type);
1288     invalid_request:
1289         reply.error = -EINVAL;
1290     error_reply:
1291         if (nbd_co_send_reply(req, &reply, 0) < 0) {
1292             goto out;
1293         }
1294         break;
1295     }
1296
1297     TRACE("Request/Reply complete");
1298
1299 done:
1300     nbd_request_put(req);
1301     return;
1302
1303 out:
1304     nbd_request_put(req);
1305     nbd_client_close(client);
1306 }
1307
1308 static void nbd_read(void *opaque)
1309 {
1310     NBDClient *client = opaque;
1311
1312     if (client->recv_coroutine) {
1313         qemu_coroutine_enter(client->recv_coroutine, NULL);
1314     } else {
1315         qemu_coroutine_enter(qemu_coroutine_create(nbd_trip), client);
1316     }
1317 }
1318
1319 static void nbd_restart_write(void *opaque)
1320 {
1321     NBDClient *client = opaque;
1322
1323     qemu_coroutine_enter(client->send_coroutine, NULL);
1324 }
1325
1326 static void nbd_set_handlers(NBDClient *client)
1327 {
1328     if (client->exp && client->exp->ctx) {
1329         aio_set_fd_handler(client->exp->ctx, client->sock,
1330                            client->can_read ? nbd_read : NULL,
1331                            client->send_coroutine ? nbd_restart_write : NULL,
1332                            client);
1333     }
1334 }
1335
1336 static void nbd_unset_handlers(NBDClient *client)
1337 {
1338     if (client->exp && client->exp->ctx) {
1339         aio_set_fd_handler(client->exp->ctx, client->sock, NULL, NULL, NULL);
1340     }
1341 }
1342
1343 static void nbd_update_can_read(NBDClient *client)
1344 {
1345     bool can_read = client->recv_coroutine ||
1346                     client->nb_requests < MAX_NBD_REQUESTS;
1347
1348     if (can_read != client->can_read) {
1349         client->can_read = can_read;
1350         nbd_set_handlers(client);
1351
1352         /* There is no need to invoke aio_notify(), since aio_set_fd_handler()
1353          * in nbd_set_handlers() will have taken care of that */
1354     }
1355 }
1356
1357 NBDClient *nbd_client_new(NBDExport *exp, int csock,
1358                           void (*close)(NBDClient *))
1359 {
1360     NBDClient *client;
1361     client = g_malloc0(sizeof(NBDClient));
1362     client->refcount = 1;
1363     client->exp = exp;
1364     client->sock = csock;
1365     client->can_read = true;
1366     if (nbd_send_negotiate(client)) {
1367         g_free(client);
1368         return NULL;
1369     }
1370     client->close = close;
1371     qemu_co_mutex_init(&client->send_lock);
1372     nbd_set_handlers(client);
1373
1374     if (exp) {
1375         QTAILQ_INSERT_TAIL(&exp->clients, client, next);
1376         nbd_export_get(exp);
1377     }
1378     return client;
1379 }