maint: Add .mailmap
[platform/upstream/nbd.git] / tests / run / nbd-tester-client.c
1 /*
2  * Test client to test the NBD server. Doesn't do anything useful, except
3  * checking that the server does, actually, work.
4  *
5  * Note that the only 'real' test is to check the client against a kernel. If
6  * it works here but does not work in the kernel, then that's most likely a bug
7  * in this program and/or in nbd-server.
8  *
9  * Copyright(c) 2006  Wouter Verhelst
10  *
11  * This program is Free Software; you can redistribute it and/or modify it
12  * under the terms of the GNU General Public License as published by the Free
13  * Software Foundation, in version 2.
14  *
15  * This program is distributed in the hope that it will be useful, but WITHOUT
16  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
17  * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
18  * more details.
19  *
20  * You should have received a copy of the GNU General Public License along with
21  * this program; if not, write to the Free Software Foundation, Inc., 51
22  * Franklin St, Fifth Floor, Boston, MA  02110-1301 USA
23  */
24 #include <stdlib.h>
25 #include <stdio.h>
26 #include <stdbool.h>
27 #include <string.h>
28 #include <sys/time.h>
29 #include <sys/types.h>
30 #include <sys/wait.h>
31 #include <sys/un.h>
32 #include <sys/socket.h>
33 #include <sys/stat.h>
34 #include <sys/mman.h>
35 #include <fcntl.h>
36 #include <syslog.h>
37 #include <unistd.h>
38 #include "config.h"
39 #include "lfs.h"
40 #include <netinet/in.h>
41 #include <glib.h>
42
43 #define MY_NAME "nbd-tester-client"
44 #include "cliserv.h"
45
46 #if HAVE_GNUTLS
47 #include "crypto-gnutls.h"
48 #endif
49
50 static gchar errstr[1024];
51 const static int errstr_len = 1023;
52
53 static uint64_t size;
54
55 static int looseordering = 0;
56
57 static gchar *transactionlog = "nbd-tester-client.tr";
58 static gchar *certfile = NULL;
59 static gchar *keyfile = NULL;
60 static gchar *cacertfile = NULL;
61 static gchar *tlshostname = NULL;
62
63 typedef enum {
64         CONNECTION_TYPE_INIT_PASSWD,
65         CONNECTION_TYPE_CLISERV,
66         CONNECTION_TYPE_FULL,
67 } CONNECTION_TYPE;
68
69 typedef enum {
70         CONNECTION_CLOSE_PROPERLY,
71         CONNECTION_CLOSE_FAST,
72 } CLOSE_TYPE;
73
74 struct reqcontext {
75         uint64_t seq;
76         char orighandle[8];
77         struct nbd_request req;
78         struct reqcontext *next;
79         struct reqcontext *prev;
80 };
81
82 struct rclist {
83         struct reqcontext *head;
84         struct reqcontext *tail;
85         int numitems;
86 };
87
88 struct chunk {
89         char *buffer;
90         char *readptr;
91         char *writeptr;
92         uint64_t space;
93         uint64_t length;
94         struct chunk *next;
95         struct chunk *prev;
96 };
97
98 struct chunklist {
99         struct chunk *head;
100         struct chunk *tail;
101         int numitems;
102 };
103
104 struct blkitem {
105         uint32_t seq;
106         int32_t inflightr;
107         int32_t inflightw;
108 };
109
110 void rclist_unlink(struct rclist *l, struct reqcontext *p)
111 {
112         if (p && l) {
113                 struct reqcontext *prev = p->prev;
114                 struct reqcontext *next = p->next;
115
116                 /* Fix link to previous */
117                 if (prev)
118                         prev->next = next;
119                 else
120                         l->head = next;
121
122                 if (next)
123                         next->prev = prev;
124                 else
125                         l->tail = prev;
126
127                 p->prev = NULL;
128                 p->next = NULL;
129                 l->numitems--;
130         }
131 }
132
133 /* Add a new list item to the tail */
134 void rclist_addtail(struct rclist *l, struct reqcontext *p)
135 {
136         if (!p || !l)
137                 return;
138         if (l->tail) {
139                 if (l->tail->next)
140                         g_warning("addtail found list tail has a next pointer");
141                 l->tail->next = p;
142                 p->next = NULL;
143                 p->prev = l->tail;
144                 l->tail = p;
145         } else {
146                 if (l->head)
147                         g_warning("addtail found no list tail but a list head");
148                 l->head = p;
149                 l->tail = p;
150                 p->prev = NULL;
151                 p->next = NULL;
152         }
153         l->numitems++;
154 }
155
156 void chunklist_unlink(struct chunklist *l, struct chunk *p)
157 {
158         if (p && l) {
159                 struct chunk *prev = p->prev;
160                 struct chunk *next = p->next;
161
162                 /* Fix link to previous */
163                 if (prev)
164                         prev->next = next;
165                 else
166                         l->head = next;
167
168                 if (next)
169                         next->prev = prev;
170                 else
171                         l->tail = prev;
172
173                 p->prev = NULL;
174                 p->next = NULL;
175                 l->numitems--;
176         }
177 }
178
179 /* Add a new list item to the tail */
180 void chunklist_addtail(struct chunklist *l, struct chunk *p)
181 {
182         if (!p || !l)
183                 return;
184         if (l->tail) {
185                 if (l->tail->next)
186                         g_warning("addtail found list tail has a next pointer");
187                 l->tail->next = p;
188                 p->next = NULL;
189                 p->prev = l->tail;
190                 l->tail = p;
191         } else {
192                 if (l->head)
193                         g_warning("addtail found no list tail but a list head");
194                 l->head = p;
195                 l->tail = p;
196                 p->prev = NULL;
197                 p->next = NULL;
198         }
199         l->numitems++;
200 }
201
202 /* Add some new bytes to a chunklist */
203 void addbuffer(struct chunklist *l, void *data, uint64_t len)
204 {
205         void *buf;
206         uint64_t size = 64 * 1024;
207         struct chunk *pchunk;
208
209         while (len > 0) {
210                 /* First see if there is a current chunk, and if it has space */
211                 if (l->tail && l->tail->space) {
212                         uint64_t towrite = len;
213                         if (towrite > l->tail->space)
214                                 towrite = l->tail->space;
215                         memcpy(l->tail->writeptr, data, towrite);
216                         l->tail->length += towrite;
217                         l->tail->space -= towrite;
218                         l->tail->writeptr += towrite;
219                         len -= towrite;
220                         data += towrite;
221                 }
222
223                 if (len > 0) {
224                         /* We still need to write more, so prepare a new chunk */
225                         if ((NULL == (buf = malloc(size)))
226                             || (NULL ==
227                                 (pchunk = calloc(1, sizeof(struct chunk))))) {
228                                 g_critical("Out of memory");
229                                 exit(1);
230                         }
231
232                         pchunk->buffer = buf;
233                         pchunk->readptr = buf;
234                         pchunk->writeptr = buf;
235                         pchunk->space = size;
236                         chunklist_addtail(l, pchunk);
237                 }
238         }
239
240 }
241
242 /* returns 0 on success, -1 on failure */
243 int writebuffer(int fd, struct chunklist *l)
244 {
245
246         struct chunk *pchunk = NULL;
247         int res;
248         if (!l)
249                 return 0;
250
251         while (!pchunk) {
252                 pchunk = l->head;
253                 if (!pchunk)
254                         return 0;
255                 if (!(pchunk->length) || !(pchunk->readptr)) {
256                         chunklist_unlink(l, pchunk);
257                         free(pchunk->buffer);
258                         free(pchunk);
259                         pchunk = NULL;
260                 }
261         }
262
263         /* OK we have a chunk with some data in */
264         res = write(fd, pchunk->readptr, pchunk->length);
265         if (res == 0)
266                 errno = EAGAIN;
267         if (res <= 0)
268                 return -1;
269         pchunk->length -= res;
270         pchunk->readptr += res;
271         if (!pchunk->length) {
272                 chunklist_unlink(l, pchunk);
273                 free(pchunk->buffer);
274                 free(pchunk);
275         }
276         return 0;
277 }
278
/*
 * Bit flags passed around as `testflags`.
 *
 * NOTE(review): the users of TEST_WRITE and TEST_FLUSH are outside this part
 * of the file; presumably they select write and flush testing.
 * TEST_EXPECT_ERROR makes setup_connection_common() ignore SIGPIPE;
 * TEST_HANDSHAKE makes it stop right after sending the client flags.
 */
#define TEST_WRITE              (1 << 0)
#define TEST_FLUSH              (1 << 1)
#define TEST_EXPECT_ERROR       (1 << 2)
#define TEST_HANDSHAKE          (1 << 3)
283
/*
 * Compute *result = *x - *y.
 *
 * The carry between the seconds and microseconds fields is worked out on a
 * private copy of *y, so the caller's *y is left untouched.
 *
 * Returns 1 if the difference is negative, 0 otherwise.
 */
int timeval_subtract(struct timeval *result, struct timeval *x,
                     struct timeval *y)
{
        struct timeval sub = *y;
        int negative;

        /* Borrow from the seconds so that x->tv_usec - sub.tv_usec >= 0 */
        if (x->tv_usec < sub.tv_usec) {
                int nsec = (sub.tv_usec - x->tv_usec) / 1000000 + 1;
                sub.tv_usec -= 1000000 * nsec;
                sub.tv_sec += nsec;
        }

        /* Move whole seconds hiding in the microsecond difference */
        if (x->tv_usec - sub.tv_usec > 1000000) {
                int nsec = (x->tv_usec - sub.tv_usec) / 1000000;
                sub.tv_usec += 1000000 * nsec;
                sub.tv_sec -= nsec;
        }

        /* Decide the sign before writing, in case result aliases x */
        negative = x->tv_sec < sub.tv_sec;

        result->tv_sec = x->tv_sec - sub.tv_sec;
        result->tv_usec = x->tv_usec - sub.tv_usec;

        return negative;
}
304
/*
 * Return the difference *x - *y in seconds, as a floating point number.
 * The subtraction itself is delegated to timeval_subtract().
 */
double timeval_diff_to_double(struct timeval *x, struct timeval *y)
{
        struct timeval delta;
        double seconds;
        double fraction;

        timeval_subtract(&delta, x, y);
        seconds = delta.tv_sec * 1.0;
        fraction = delta.tv_usec / 1000000.0;
        return seconds + fraction;
}
311
312 static inline int read_all(int f, void *buf, size_t len)
313 {
314         ssize_t res;
315         size_t retval = 0;
316
317         while (len > 0) {
318                 if ((res = read(f, buf, len)) <= 0) {
319                         if (!res)
320                                 errno = EAGAIN;
321                         snprintf(errstr, errstr_len, "Read failed: %s",
322                                  strerror(errno));
323                         return -1;
324                 }
325                 len -= res;
326                 buf += res;
327                 retval += res;
328         }
329         return retval;
330 }
331
332 static inline int write_all(int f, void *buf, size_t len)
333 {
334         ssize_t res;
335         size_t retval = 0;
336
337         while (len > 0) {
338                 if ((res = write(f, buf, len)) <= 0) {
339                         if (!res)
340                                 errno = EAGAIN;
341                         snprintf(errstr, errstr_len, "Write failed: %s",
342                                  strerror(errno));
343                         return -1;
344                 }
345                 len -= res;
346                 buf += res;
347                 retval += res;
348         }
349         return retval;
350 }
351
/*
 * Error output callback handed to tlssession_new(): formats the message to
 * stderr and returns what vfprintf() returned. `opaque` is not used.
 */
static int tlserrout (void *opaque, const char *format, va_list ap) {
        int printed;

        printed = vfprintf(stderr, format, ap);
        return printed;
}
355
/*
 * Convenience wrappers around read_all()/write_all().
 *
 * On failure they store the printf-style message `errmsg` in errstr and
 * jump to the label `whereto`; the *_RT variants additionally assign `rval`
 * to a variable called `retval`, which must exist in the calling function.
 *
 * Only a negative result counts as failure: read_all()/write_all() return
 * -1 for every error (including end of file), while 0 merely means that a
 * zero-length transfer was requested.
 *
 * Each macro expands to a single statement (do { } while (0)), so it is
 * safe inside an unbraced if/else and must be followed by a semicolon.
 */
#define READ_ALL_ERRCHK(f, buf, len, whereto, errmsg...) \
        do { \
                if ((read_all(f, buf, len)) < 0) { \
                        snprintf(errstr, errstr_len, ##errmsg); \
                        goto whereto; \
                } \
        } while (0)
#define READ_ALL_ERR_RT(f, buf, len, whereto, rval, errmsg...) \
        do { \
                if ((read_all(f, buf, len)) < 0) { \
                        snprintf(errstr, errstr_len, ##errmsg); \
                        retval = rval; \
                        goto whereto; \
                } \
        } while (0)

#define WRITE_ALL_ERRCHK(f, buf, len, whereto, errmsg...) \
        do { \
                if ((write_all(f, buf, len)) < 0) { \
                        snprintf(errstr, errstr_len, ##errmsg); \
                        goto whereto; \
                } \
        } while (0)
#define WRITE_ALL_ERR_RT(f, buf, len, whereto, rval, errmsg...) \
        do { \
                if ((write_all(f, buf, len)) < 0) { \
                        snprintf(errstr, errstr_len, ##errmsg); \
                        retval = rval; \
                        goto whereto; \
                } \
        } while (0)
361
362 int setup_connection_common(int sock, char *name, CONNECTION_TYPE ctype,
363                             int *serverflags, int testflags)
364 {
365         char buf[256];
366         u64 tmp64;
367         uint64_t mymagic = (name ? opts_magic : cliserv_magic);
368         uint32_t tmp32 = 0;
369         uint16_t handshakeflags = 0;
370         uint32_t negotiationflags = 0;
371
372         if (ctype < CONNECTION_TYPE_INIT_PASSWD)
373                 goto end;
374         READ_ALL_ERRCHK(sock, buf, strlen(INIT_PASSWD), err,
375                         "Could not read INIT_PASSWD: %s", strerror(errno));
376         if (strlen(buf) == 0) {
377                 snprintf(errstr, errstr_len, "Server closed connection");
378                 goto err;
379         }
380         if (strncmp(buf, INIT_PASSWD, strlen(INIT_PASSWD))) {
381                 snprintf(errstr, errstr_len, "INIT_PASSWD does not match");
382                 goto err;
383         }
384         if (ctype < CONNECTION_TYPE_CLISERV)
385                 goto end;
386         READ_ALL_ERRCHK(sock, &tmp64, sizeof(tmp64), err,
387                         "Could not read cliserv_magic: %s", strerror(errno));
388         tmp64 = ntohll(tmp64);
389         if (tmp64 != mymagic) {
390                 strncpy(errstr, "mymagic does not match", errstr_len);
391                 goto err;
392         }
393         if (ctype < CONNECTION_TYPE_FULL)
394                 goto end;
395         if (!name) {
396                 READ_ALL_ERRCHK(sock, &size, sizeof(size), err,
397                                 "Could not read size: %s", strerror(errno));
398                 size = ntohll(size);
399                 uint32_t flags;
400                 READ_ALL_ERRCHK(sock, &flags, sizeof(uint32_t), err,
401                                 "Could not read flags: %s", strerror(errno));
402                 flags = ntohl(flags);
403                 *serverflags = flags;
404                 READ_ALL_ERRCHK(sock, buf, 124, err, "Could not read data: %s",
405                                 strerror(errno));
406                 goto end;
407         }
408         /* handshake flags */
409         READ_ALL_ERRCHK(sock, &handshakeflags, sizeof(handshakeflags), err,
410                         "Could not read reserved field: %s", strerror(errno));
411         handshakeflags = ntohs(handshakeflags);
412         /* negotiation flags */
413         if (handshakeflags & NBD_FLAG_FIXED_NEWSTYLE)
414                 negotiationflags |= NBD_FLAG_C_FIXED_NEWSTYLE;
415         else if (keyfile) {
416                 snprintf(errstr, errstr_len, "Cannot negotiate TLS without NBD_FLAG_FIXED_NEWSTYLE");
417                 goto err;
418         }
419         negotiationflags = htonl(negotiationflags);
420         WRITE_ALL_ERRCHK(sock, &negotiationflags, sizeof(negotiationflags), err,
421                          "Could not write reserved field: %s", strerror(errno));
422         if (testflags & TEST_HANDSHAKE) {
423                 /* Server must support newstyle for this test */
424                 if (!(handshakeflags & NBD_FLAG_FIXED_NEWSTYLE)) {
425                         strncpy(errstr, "server does not support handshake", errstr_len);
426                         goto err;
427                 }
428                 goto end;
429         }
430 #if HAVE_GNUTLS
431         /* TLS */
432         if (keyfile) {
433                 int plainfd[2]; // [0] is used by the proxy, [1] is used by NBD
434                 tlssession_t *s = NULL;
435                 int ret;
436
437                 /* magic */
438                 tmp64 = htonll(opts_magic);
439                 WRITE_ALL_ERRCHK(sock, &tmp64, sizeof(tmp64), err,
440                                  "Could not write magic: %s", strerror(errno));
441                 /* starttls */
442                 tmp32 = htonl(NBD_OPT_STARTTLS);
443                 WRITE_ALL_ERRCHK(sock, &tmp32, sizeof(tmp32), err,
444                          "Could not write option: %s", strerror(errno));
445                 /* length of data */
446                 tmp32 = htonl(0);
447                 WRITE_ALL_ERRCHK(sock, &tmp32, sizeof(tmp32), err,
448                          "Could not write option length: %s", strerror(errno));
449
450                 READ_ALL_ERRCHK(sock, &tmp64, sizeof(tmp64), err,
451                                 "Could not read cliserv_magic: %s", strerror(errno));
452                 tmp64 = ntohll(tmp64);
453                 if (tmp64 != NBD_OPT_REPLY_MAGIC) {
454                         strncpy(errstr, "reply magic does not match", errstr_len);
455                         goto err;
456                 }
457                 READ_ALL_ERRCHK(sock, &tmp32, sizeof(tmp32), err,
458                                 "Could not read option type: %s", strerror(errno));
459                 tmp32 = ntohl(tmp32);
460                 if (tmp32 != NBD_OPT_STARTTLS) {
461                         strncpy(errstr, "Reply to wrong option", errstr_len);
462                         goto err;
463                 }
464                 READ_ALL_ERRCHK(sock, &tmp32, sizeof(tmp32), err,
465                                 "Could not read option reply type: %s", strerror(errno));
466                 tmp32 = ntohl(tmp32);
467                 if (tmp32 != NBD_REP_ACK) {
468                         if(tmp32 & NBD_REP_FLAG_ERROR) {
469                                 snprintf(errstr, errstr_len, "Received error %d", tmp32 & ~NBD_REP_FLAG_ERROR);
470                         } else {
471                                 snprintf(errstr, errstr_len, "Option reply type %d != NBD_REP_ACK", tmp32);
472                         }
473                         goto err;
474                 }
475                 READ_ALL_ERRCHK(sock, &tmp32, sizeof(tmp32), err,
476                                 "Could not read option data length: %s", strerror(errno));
477                 tmp32 = ntohl(tmp32);
478                 if (tmp32 != 0) {
479                         strncpy(errstr, "Option reply data length != 0", errstr_len);
480                         goto err;
481                 }
482
483                 s = tlssession_new(FALSE,
484                                    keyfile,
485                                    certfile,
486                                    cacertfile,
487                                    tlshostname,
488                                    !cacertfile || !tlshostname, // insecure flag
489 #ifdef DODBG
490                                    1, // debug
491 #else
492                                    0, // debug
493 #endif
494                                    NULL, // quitfn
495                                    tlserrout, // erroutfn
496                                    NULL // opaque
497                         );
498                 if (!s) {
499                         strncpy(errstr, "Cannot establish TLS session", errstr_len);
500                         goto err;
501                 }
502
503                 if (socketpair(AF_UNIX, SOCK_STREAM, 0, plainfd) < 0) {
504                         strncpy(errstr, "Cannot get socket pair", errstr_len);
505                         goto err;
506                 }
507
508                 if (set_nonblocking(plainfd[0], 0) <0 ||
509                     set_nonblocking(plainfd[1], 0) <0 ||
510                     set_nonblocking(sock, 0) <0) {
511                         close(plainfd[0]);
512                         close(plainfd[1]);
513                         strncpy(errstr, "Cannot set socket options", errstr_len);
514                         goto err;
515                 }
516
517                 ret = fork();
518                 if (ret < 0)
519                         err("Could not fork");
520                 else if (ret == 0) {
521                         // we are the child
522                         signal (SIGPIPE, SIG_IGN);
523                         close(plainfd[1]);
524                         tlssession_mainloop(sock, plainfd[0], s);
525                         close(sock);
526                         close(plainfd[0]);
527                         exit(0);
528                 }
529                 close(plainfd[0]);
530                 close(sock);
531                 sock = plainfd[1]; /* use the decrypted FD from now on */
532         }
533 #else
534         if (keyfile) {
535                 strncpy(errstr, "TLS requested but support not compiled in", errstr_len);
536                 goto err;
537         }
538 #endif
539         if(testflags & TEST_EXPECT_ERROR) {
540                 struct sigaction act;
541                 memset(&act, '0', sizeof act);
542                 act.sa_handler = SIG_IGN;
543                 sigaction(SIGPIPE, &act, NULL);
544         }
545         /* magic */
546         tmp64 = htonll(opts_magic);
547         WRITE_ALL_ERRCHK(sock, &tmp64, sizeof(tmp64), err,
548                          "Could not write magic: %s", strerror(errno));
549         /* name */
550         tmp32 = htonl(NBD_OPT_EXPORT_NAME);
551         WRITE_ALL_ERRCHK(sock, &tmp32, sizeof(tmp32), err,
552                          "Could not write option: %s", strerror(errno));
553         tmp32 = htonl((uint32_t) strlen(name));
554         WRITE_ALL_ERRCHK(sock, &tmp32, sizeof(tmp32), err,
555                          "Could not write name length: %s", strerror(errno));
556         WRITE_ALL_ERRCHK(sock, name, strlen(name), err,
557                          "Could not write name:: %s", strerror(errno));
558         READ_ALL_ERRCHK(sock, &size, sizeof(size), err,
559                         "Could not read size: %s", strerror(errno));
560         size = ntohll(size);
561         uint16_t flags;
562         READ_ALL_ERRCHK(sock, &flags, sizeof(uint16_t), err,
563                         "Could not read flags: %s", strerror(errno));
564         flags = ntohs(flags);
565         *serverflags = flags;
566         READ_ALL_ERRCHK(sock, buf, 124, err,
567                         "Could not read reserved zeroes: %s", strerror(errno));
568         goto end;
569 err:
570         close(sock);
571         sock = -1;
572 end:
573         return sock;
574 }
575
576 int setup_unix_connection(gchar * unixsock)
577 {
578         struct sockaddr_un addr;
579         int sock;
580
581         sock = 0;
582         if ((sock = socket(AF_UNIX, SOCK_STREAM, 0)) < 0) {
583                 strncpy(errstr, strerror(errno), errstr_len);
584                 goto err;
585         }
586
587         setmysockopt(sock);
588         memset(&addr, 0, sizeof(struct sockaddr_un));
589         addr.sun_family = AF_UNIX;
590         strncpy(addr.sun_path, unixsock, sizeof addr.sun_path);
591         addr.sun_path[sizeof(addr.sun_path)-1] = '\0';
592         if (connect(sock, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
593                 strncpy(errstr, strerror(errno), errstr_len);
594                 goto err_open;
595         }
596         goto end;
597 err_open:
598         close(sock);
599 err:
600         sock = -1;
601 end:
602         return sock;
603 }
604
605 int setup_inet_connection(gchar * hostname, int port)
606 {
607         int sock;
608         struct hostent *host;
609         struct sockaddr_in addr;
610
611         sock = 0;
612         if ((sock = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP)) < 0) {
613                 strncpy(errstr, strerror(errno), errstr_len);
614                 goto err;
615         }
616         setmysockopt(sock);
617         if (!(host = gethostbyname(hostname))) {
618                 strncpy(errstr, hstrerror(h_errno), errstr_len);
619                 goto err_open;
620         }
621         addr.sin_family = AF_INET;
622         addr.sin_port = htons(port);
623         addr.sin_addr.s_addr = *((int *)host->h_addr);
624         if ((connect(sock, (struct sockaddr *)&addr, sizeof(addr)) < 0)) {
625                 strncpy(errstr, strerror(errno), errstr_len);
626                 goto err_open;
627         }
628         goto end;
629 err_open:
630         close(sock);
631 err:
632         sock = -1;
633 end:
634         return sock;
635 }
636
637 int setup_inetd_connection(gchar **argv)
638 {
639         int sv[2], status;
640         pid_t child;
641
642         if (socketpair(AF_UNIX, SOCK_STREAM, 0, sv) == -1) {
643                 strncpy(errstr, strerror(errno), errstr_len);
644                 return -1;
645         }
646
647         child = vfork();
648         if (child == 0) {
649                 dup2(sv[0], 0);
650                 close(sv[0]);
651                 close(sv[1]);
652                 execvp(argv[0], argv);
653                 perror("execvp");
654                 _exit(-1);
655         } else if (child == -1) {
656                 close(sv[0]);
657                 close(sv[1]);
658                 strncpy(errstr, strerror(errno), errstr_len);
659                 return -1;
660         }
661
662         close(sv[0]);
663         if (waitpid(child, &status, WNOHANG)) {
664                 close(sv[1]);
665                 return -1;
666         }
667
668         setmysockopt(sv[1]);
669         return sv[1];
670 }
671
672 int close_connection(int sock, CLOSE_TYPE type)
673 {
674         struct nbd_request req;
675         u64 counter = 0;
676
677         switch (type) {
678         case CONNECTION_CLOSE_PROPERLY:
679                 req.magic = htonl(NBD_REQUEST_MAGIC);
680                 req.type = htonl(NBD_CMD_DISC);
681                 memcpy(&(req.handle), &(counter), sizeof(counter));
682                 counter++;
683                 req.from = 0;
684                 req.len = 0;
685                 if (write(sock, &req, sizeof(req)) < 0) {
686                         snprintf(errstr, errstr_len,
687                                  "Could not write to socket: %s",
688                                  strerror(errno));
689                         return -1;
690                 }
691         case CONNECTION_CLOSE_FAST:
692                 if (close(sock) < 0) {
693                         snprintf(errstr, errstr_len,
694                                  "Could not close socket: %s", strerror(errno));
695                         return -1;
696                 }
697                 break;
698         default:
699                 g_critical("Your compiler is on crack!");       /* or I am buggy */
700                 return -1;
701         }
702         return 0;
703 }
704
705 int read_packet_check_header(int sock, size_t datasize, long long int curhandle)
706 {
707         struct nbd_reply rep;
708         int retval = 0;
709         char buf[datasize];
710
711         READ_ALL_ERR_RT(sock, &rep, sizeof(rep), end, -1,
712                         "Could not read reply header: %s", strerror(errno));
713         rep.magic = ntohl(rep.magic);
714         rep.error = ntohl(rep.error);
715         if (rep.magic != NBD_REPLY_MAGIC) {
716                 snprintf(errstr, errstr_len,
717                          "Received package with incorrect reply_magic. Index of sent packages is %lld (0x%llX), received handle is %lld (0x%llX). Received magic 0x%lX, expected 0x%lX",
718                          (long long int)curhandle,
719                          (long long unsigned int)curhandle,
720                          (long long int)*((u64 *) rep.handle),
721                          (long long unsigned int)*((u64 *) rep.handle),
722                          (long unsigned int)rep.magic,
723                          (long unsigned int)NBD_REPLY_MAGIC);
724                 retval = -1;
725                 goto end;
726         }
727         if (rep.error) {
728                 snprintf(errstr, errstr_len,
729                          "Received error from server: %ld (0x%lX). Handle is %lld (0x%llX).",
730                          (long int)rep.error, (long unsigned int)rep.error,
731                          (long long int)(*((u64 *) rep.handle)),
732                          (long long unsigned int)*((u64 *) rep.handle));
733                 retval = -2;
734                 goto end;
735         }
736         if (datasize)
737                 READ_ALL_ERR_RT(sock, &buf, datasize, end, -1,
738                                 "Could not read data: %s", strerror(errno));
739
740 end:
741         return retval;
742 }
743
744 int oversize_test(char *name, int sock, char close_sock, int testflags)
745 {
746         int retval = 0;
747         struct nbd_request req;
748         struct nbd_reply rep;
749         int i = 0;
750         int serverflags = 0;
751         pid_t G_GNUC_UNUSED mypid = getpid();
752         char buf[((1024 * 1024) + sizeof(struct nbd_request) / 2) << 1];
753         bool got_err;
754
755         /* This should work */
756         if ((sock =
757                  setup_connection_common(sock, name,
758                                   CONNECTION_TYPE_FULL,
759                                   &serverflags, testflags)) < 0) {
760                 g_warning("Could not open socket: %s", errstr);
761                 retval = -1;
762                 goto err;
763         }
764         req.magic = htonl(NBD_REQUEST_MAGIC);
765         req.type = htonl(NBD_CMD_READ);
766         req.len = htonl(1024 * 1024);
767         memcpy(&(req.handle), &i, sizeof(i));
768         req.from = htonll(i);
769         WRITE_ALL_ERR_RT(sock, &req, sizeof(req), err, -1,
770                          "Could not write request: %s", strerror(errno));
771         printf("%d: testing oversized request: %d: ", getpid(), ntohl(req.len));
772         READ_ALL_ERR_RT(sock, &rep, sizeof(struct nbd_reply), err, -1,
773                         "Could not read reply header: %s", strerror(errno));
774         READ_ALL_ERR_RT(sock, &buf, ntohl(req.len), err, -1,
775                         "Could not read data: %s", strerror(errno));
776         if (rep.error) {
777                 snprintf(errstr, errstr_len, "Received unexpected error: %d",
778                          rep.error);
779                 retval = -1;
780                 goto err;
781         } else {
782                 printf("OK\n");
783         }
784         /* This probably should not work */
785         i++;
786         req.from = htonll(i);
787         req.len = htonl(ntohl(req.len) + sizeof(struct nbd_request) / 2);
788         WRITE_ALL_ERR_RT(sock, &req, sizeof(req), err, -1,
789                          "Could not write request: %s", strerror(errno));
790         printf("%d: testing oversized request: %d: ", getpid(), ntohl(req.len));
791         READ_ALL_ERR_RT(sock, &rep, sizeof(struct nbd_reply), err, -1,
792                         "Could not read reply header: %s", strerror(errno));
793         READ_ALL_ERR_RT(sock, &buf, ntohl(req.len), err, -1,
794                         "Could not read data: %s", strerror(errno));
795         if (rep.error) {
796                 printf("Received expected error\n");
797                 got_err = true;
798         } else {
799                 printf("OK\n");
800                 got_err = false;
801         }
802         /* ... unless this works, too */
803         i++;
804         req.from = htonll(i);
805         req.len = htonl(ntohl(req.len) << 1);
806         WRITE_ALL_ERR_RT(sock, &req, sizeof(req), err, -1,
807                          "Could not write request: %s", strerror(errno));
808         printf("%d: testing oversized request: %d: ", getpid(), ntohl(req.len));
809         READ_ALL_ERR_RT(sock, &rep, sizeof(struct nbd_reply), err, -1,
810                         "Could not read reply header: %s", strerror(errno));
811         READ_ALL_ERR_RT(sock, &buf, ntohl(req.len), err, -1,
812                         "Could not read data: %s", strerror(errno));
813         if (rep.error) {
814                 printf("error\n");
815         } else {
816                 printf("OK\n");
817         }
818         if ((rep.error && !got_err) || (!rep.error && got_err)) {
819                 printf("Received unexpected error\n");
820                 retval = -1;
821         }
822 err:
823         return retval;
824 }
825
826 int handshake_test(char *name, int sock, char close_sock, int testflags)
827 {
828         int retval = -1;
829         int serverflags = 0;
830         u64 tmp64;
831         uint32_t tmp32 = 0;
832
833         /* This should work */
834         if ((sock =
835                  setup_connection_common(sock, name,
836                                   CONNECTION_TYPE_FULL,
837                                   &serverflags, testflags)) < 0) {
838                 g_warning("Could not open socket: %s", errstr);
839                 goto err;
840         }
841
842         /* Intentionally throw an unknown option at the server */
843         tmp64 = htonll(opts_magic);
844         WRITE_ALL_ERRCHK(sock, &tmp64, sizeof(tmp64), err,
845                          "Could not write magic: %s", strerror(errno));
846         tmp32 = htonl(0x7654321);
847         WRITE_ALL_ERRCHK(sock, &tmp32, sizeof(tmp32), err,
848                          "Could not write option: %s", strerror(errno));
849         tmp32 = htonl((uint32_t) sizeof(tmp32));
850         WRITE_ALL_ERRCHK(sock, &tmp32, sizeof(tmp32), err,
851                          "Could not write option length: %s", strerror(errno));
852         WRITE_ALL_ERRCHK(sock, &tmp32, sizeof(tmp32), err,
853                          "Could not write option payload: %s", strerror(errno));
854         /* Expect proper error from server */
855         READ_ALL_ERRCHK(sock, &tmp64, sizeof(tmp64), err,
856                         "Could not read magic: %s", strerror(errno));
857         tmp64 = ntohll(tmp64);
858         if (tmp64 != 0x3e889045565a9LL) {
859                 strncpy(errstr, "magic does not match", errstr_len);
860                 goto err;
861         }
862         READ_ALL_ERRCHK(sock, &tmp32, sizeof(tmp32), err,
863                         "Could not read option: %s", strerror(errno));
864         tmp32 = ntohl(tmp32);
865         if (tmp32 != 0x7654321) {
866                 strncpy(errstr, "option does not match", errstr_len);
867                 goto err;
868         }
869         READ_ALL_ERRCHK(sock, &tmp32, sizeof(tmp32), err,
870                         "Could not read status: %s", strerror(errno));
871         tmp32 = ntohl(tmp32);
872         if (tmp32 != NBD_REP_ERR_UNSUP) {
873                 strncpy(errstr, "status does not match", errstr_len);
874                 goto err;
875         }
876         READ_ALL_ERRCHK(sock, &tmp32, sizeof(tmp32), err,
877                         "Could not read length: %s", strerror(errno));
878         tmp32 = ntohl(tmp32);
879         while (tmp32) {
880                 char buf[1024];
881                 size_t len = tmp32 < sizeof(buf) ? tmp32 : sizeof(buf);
882                 READ_ALL_ERRCHK(sock, buf, len, err,
883                                 "Could not read payload: %s", strerror(errno));
884                 tmp32 -= len;
885         }
886
887
888         /* Send NBD_OPT_ABORT to close the connection */
889         tmp64 = htonll(opts_magic);
890         WRITE_ALL_ERRCHK(sock, &tmp64, sizeof(tmp64), err,
891                          "Could not write magic: %s", strerror(errno));
892         tmp32 = htonl(NBD_OPT_ABORT);
893         WRITE_ALL_ERRCHK(sock, &tmp32, sizeof(tmp32), err,
894                          "Could not write option: %s", strerror(errno));
895         tmp32 = htonl((uint32_t) 0);
896         WRITE_ALL_ERRCHK(sock, &tmp32, sizeof(tmp32), err,
897                          "Could not write option length: %s", strerror(errno));
898
899         retval = 0;
900
901         g_message("Handshake test completed. No errors encountered.");
902 err:
903         return retval;
904 }
905
906 int throughput_test(char *name, int sock, char close_sock, int testflags)
907 {
908         long long int i;
909         char writebuf[1024];
910         struct nbd_request req;
911         int requests = 0;
912         fd_set set;
913         struct timeval tv;
914         struct timeval start;
915         struct timeval stop;
916         double timespan;
917         double speed;
918         char speedchar[2] = { '\0', '\0' };
919         int retval = 0;
920         int serverflags = 0;
921         signed int do_write = TRUE;
922         pid_t mypid = getpid();
923         char *print = getenv("NBD_TEST_SILENT");
924
925         if (!(testflags & TEST_WRITE))
926                 testflags &= ~TEST_FLUSH;
927
928         memset(writebuf, 'X', 1024);
929         size = 0;
930         if ((sock =
931                  setup_connection_common(sock, name,
932                                   CONNECTION_TYPE_FULL,
933                                   &serverflags, testflags)) < 0) {
934                 g_warning("Could not open socket: %s", errstr);
935                 if(testflags & TEST_EXPECT_ERROR) {
936                         g_message("Test failed, as expected");
937                         retval = 0;
938                 } else {
939                         retval = -1;
940                 }
941                 goto err;
942         }
943         if ((testflags & TEST_FLUSH)
944             && ((serverflags & (NBD_FLAG_SEND_FLUSH | NBD_FLAG_SEND_FUA))
945                 != (NBD_FLAG_SEND_FLUSH | NBD_FLAG_SEND_FUA))) {
946                 snprintf(errstr, errstr_len,
947                          "Server did not supply flush capability flags");
948                 retval = -1;
949                 goto err_open;
950         }
951         req.magic = htonl(NBD_REQUEST_MAGIC);
952         req.len = htonl(1024);
953         if (gettimeofday(&start, NULL) < 0) {
954                 retval = -1;
955                 snprintf(errstr, errstr_len, "Could not measure start time: %s",
956                          strerror(errno));
957                 goto err_open;
958         }
959         for (i = 0; i + 1024 <= size; i += 1024) {
960                 if (do_write) {
961                         int sendfua = (testflags & TEST_FLUSH)
962                             && (((i >> 10) & 15) == 3);
963                         int sendflush = (testflags & TEST_FLUSH)
964                             && (((i >> 10) & 15) == 11);
965                         req.type =
966                             htonl((testflags & TEST_WRITE) ? NBD_CMD_WRITE :
967                                   NBD_CMD_READ);
968                         if (sendfua)
969                                 req.type =
970                                     htonl(NBD_CMD_WRITE | NBD_CMD_FLAG_FUA);
971                         memcpy(&(req.handle), &i, sizeof(i));
972                         req.from = htonll(i);
973                         if (write_all(sock, &req, sizeof(req)) < 0) {
974                                 retval = -1;
975                                 goto err_open;
976                         }
977                         if (testflags & TEST_WRITE) {
978                                 if (write_all(sock, writebuf, 1024) < 0) {
979                                         retval = -1;
980                                         goto err_open;
981                                 }
982                         }
983                         ++requests;
984                         if (sendflush) {
985                                 long long int j = i ^ (1LL << 63);
986                                 req.type = htonl(NBD_CMD_FLUSH);
987                                 memcpy(&(req.handle), &j, sizeof(j));
988                                 req.from = 0;
989                                 req.len = 0;
990                                 if (write_all(sock, &req, sizeof(req)) < 0) {
991                                         retval = -1;
992                                         goto err_open;
993                                 }
994                                 req.len = htonl(1024);
995                                 ++requests;
996                         }
997                 }
998                 do {
999                         FD_ZERO(&set);
1000                         FD_SET(sock, &set);
1001                         tv.tv_sec = 0;
1002                         tv.tv_usec = 0;
1003                         select(sock + 1, &set, NULL, NULL, &tv);
1004                         if (FD_ISSET(sock, &set)) {
1005                                 /* Okay, there's something ready for
1006                                  * reading here */
1007                                 int rv;
1008                                 if ((rv =
1009                                      read_packet_check_header(sock,
1010                                                               (testflags &
1011                                                                TEST_WRITE) ? 0 :
1012                                                               1024, i)) < 0) {
1013                                         if (!(testflags & TEST_EXPECT_ERROR)
1014                                             || rv != -2) {
1015                                                 retval = -1;
1016                                         } else {
1017                                                 printf("\n");
1018                                         }
1019                                         goto err_open;
1020                                 } else {
1021                                         if (testflags & TEST_EXPECT_ERROR) {
1022                                                 retval = -1;
1023                                                 goto err_open;
1024                                         }
1025                                 }
1026                                 --requests;
1027                         }
1028                 } while (FD_ISSET(sock, &set));
1029                 /* Now wait until we can write again or until a second have
1030                  * passed, whichever comes first*/
1031                 FD_ZERO(&set);
1032                 FD_SET(sock, &set);
1033                 tv.tv_sec = 1;
1034                 tv.tv_usec = 0;
1035                 do_write = select(sock + 1, NULL, &set, NULL, &tv);
1036                 if (!do_write)
1037                         printf("Select finished\n");
1038                 if (do_write < 0) {
1039                         snprintf(errstr, errstr_len, "select: %s",
1040                                  strerror(errno));
1041                         retval = -1;
1042                         goto err_open;
1043                 }
1044                 if(print == NULL) {
1045                         printf("%d: Requests: %d  \r", (int)mypid, requests);
1046                 }
1047         }
1048         /* Now empty the read buffer */
1049         do {
1050                 FD_ZERO(&set);
1051                 FD_SET(sock, &set);
1052                 tv.tv_sec = 0;
1053                 tv.tv_usec = 0;
1054                 select(sock + 1, &set, NULL, NULL, &tv);
1055                 if (FD_ISSET(sock, &set)) {
1056                         /* Okay, there's something ready for
1057                          * reading here */
1058                         read_packet_check_header(sock,
1059                                                  (testflags & TEST_WRITE) ? 0 :
1060                                                  1024, i);
1061                         --requests;
1062                 }
1063                 if(print == NULL) {
1064                         printf("%d: Requests: %d  \r", (int)mypid, requests);
1065                 }
1066         } while (requests);
1067         printf("%d: Requests: %d  \n", (int)mypid, requests);
1068         if (gettimeofday(&stop, NULL) < 0) {
1069                 retval = -1;
1070                 snprintf(errstr, errstr_len, "Could not measure end time: %s",
1071                          strerror(errno));
1072                 goto err_open;
1073         }
1074         timespan = timeval_diff_to_double(&stop, &start);
1075         speed = size / timespan;
1076         if (speed > 1024) {
1077                 speed = speed / 1024.0;
1078                 speedchar[0] = 'K';
1079         }
1080         if (speed > 1024) {
1081                 speed = speed / 1024.0;
1082                 speedchar[0] = 'M';
1083         }
1084         if (speed > 1024) {
1085                 speed = speed / 1024.0;
1086                 speedchar[0] = 'G';
1087         }
1088         g_message
1089             ("%d: Throughput %s test (%s flushes) complete. Took %.3f seconds to complete, %.3f%sib/s",
1090              (int)getpid(), (testflags & TEST_WRITE) ? "write" : "read",
1091              (testflags & TEST_FLUSH) ? "with" : "without", timespan, speed,
1092              speedchar);
1093
1094 err_open:
1095         if (close_sock) {
1096                 close_connection(sock, CONNECTION_CLOSE_PROPERLY);
1097         }
1098 err:
1099         return retval;
1100 }
1101
/*
 * Fill the 512 byte buffer 'buf' with a deterministic pattern that depends
 * only on 'seq' and 'blknum'.
 *
 * A seq of 0 yields an all-zero (blank) block. Otherwise the first 64-bit
 * word is blknum ^ (seq << 32) ^ (seq >> 32) and every following word is a
 * hash of its predecessor mixed with seq. Words are stored in host byte
 * order.
 *
 * buf need not be aligned for uint64_t access.
 */
static inline void makebuf(char *buf, uint64_t seq, uint64_t blknum)
{
        uint64_t x = blknum ^ (seq << 32) ^ (seq >> 32);
        size_t i;

        if (!seq) {
                memset(buf, 0, 512);
                return;
        }
        for (i = 0; i < 512 / sizeof(x); i++) {
                unsigned int s;

                /* memcpy instead of a store through a cast pointer: buf may
                 * be a plain char array with no uint64_t alignment */
                memcpy(buf + i * sizeof(x), &x, sizeof(x));
                x += 0xFEEDA1ECDEADBEEFULL + i + (((uint64_t) i) << 56);
                s = x & 63;
                /* "& 63" keeps the right-shift count below 64 when s == 0;
                 * a 64-bit shift by 64 would be undefined behaviour */
                x = x ^ (x << s) ^ (x >> ((64 - s) & 63)) ^
                    0xAA55AA55AA55AA55ULL ^ seq;
        }
}
1125
/*
 * Compare the 512 byte block 'buf' with the pattern makebuf() generates
 * for (seq, blknum). Returns 0 when they are identical, -1 otherwise.
 */
static inline int checkbuf(char *buf, uint64_t seq, uint64_t blknum)
{
        /* 64 words of 8 bytes = 512 bytes, held as uint64_t storage */
        uint64_t expected[512 / sizeof(uint64_t)];

        makebuf((char *)expected, seq, blknum);
        if (memcmp(expected, buf, sizeof(expected)) != 0)
                return -1;
        return 0;
}
1132
/*
 * Debug helper: print 'text' followed by a decoded form of the request
 * type word 'command' (given in network byte order). Compiles to an empty
 * function unless DEBUG_COMMANDS is defined.
 */
static inline void dumpcommand(char *text, uint32_t command)
{
#ifdef DEBUG_COMMANDS
        uint32_t hostcmd = ntohl(command);
        uint32_t opcode = hostcmd & NBD_CMD_MASK_COMMAND;
        char *opname;
        char *fuatext;

        if (opcode == NBD_CMD_READ)
                opname = "NBD_CMD_READ";
        else if (opcode == NBD_CMD_WRITE)
                opname = "NBD_CMD_WRITE";
        else if (opcode == NBD_CMD_DISC)
                opname = "NBD_CMD_DISC";
        else if (opcode == NBD_CMD_FLUSH)
                opname = "NBD_CMD_FLUSH";
        else
                opname = "UNKNOWN";

        if (hostcmd & NBD_CMD_FLAG_FUA)
                fuatext = "FUA";
        else
                fuatext = "NONE";

        printf("%s: %s [%s] (0x%08x)\n", text, opname, fuatext, hostcmd);
#endif
}
1160
1161 /* return an unused handle */
1162 uint64_t getrandomhandle(GHashTable * phash)
1163 {
1164         uint64_t handle = 0;
1165         int i;
1166         do {
1167                 /* RAND_MAX may be as low as 2^15 */
1168                 for (i = 1; i <= 5; i++)
1169                         handle ^= random() ^ (handle << 15);
1170         } while (g_hash_table_lookup(phash, &handle));
1171         return handle;
1172 }
1173
1174 int integrity_test(char *name, int sock, char close_sock, int testflags)
1175 {
1176         struct nbd_reply rep;
1177         fd_set rset;
1178         fd_set wset;
1179         struct timeval tv;
1180         struct timeval start;
1181         struct timeval stop;
1182         double timespan;
1183         double speed;
1184         char speedchar[2] = { '\0', '\0' };
1185         int retval = -1;
1186         int serverflags = 0;
1187         pid_t G_GNUC_UNUSED mypid = getpid();
1188         int blkhashfd = -1;
1189         char *blkhashname = NULL;
1190         struct blkitem *blkhash = NULL;
1191         int logfd = -1;
1192         uint64_t seq = 1;
1193         uint64_t processed = 0;
1194         uint64_t printer = 0;
1195         char *do_print = getenv("NBD_TEST_SILENT");
1196         uint64_t xfer = 0;
1197         int readtransactionfile = 1;
1198         int blocked = 0;
1199         struct rclist txqueue = { NULL, NULL, 0 };
1200         struct rclist inflight = { NULL, NULL, 0 };
1201         struct chunklist txbuf = { NULL, NULL, 0 };
1202
1203         GHashTable *handlehash = g_hash_table_new(g_int64_hash, g_int64_equal);
1204
1205         size = 0;
1206         if ((sock =
1207                  setup_connection_common(sock, name,
1208                                   CONNECTION_TYPE_FULL,
1209                                   &serverflags, testflags)) < 0) {
1210                 g_warning("Could not open socket: %s", errstr);
1211                 goto err;
1212         }
1213
1214         if ((serverflags & (NBD_FLAG_SEND_FLUSH | NBD_FLAG_SEND_FUA))
1215             != (NBD_FLAG_SEND_FLUSH | NBD_FLAG_SEND_FUA))
1216                 g_warning
1217                     ("Server flags do not support FLUSH and FUA - these may error");
1218
1219 #ifdef HAVE_MKSTEMP
1220         blkhashname = strdup("/tmp/blkarray-XXXXXX");
1221         if (!blkhashname || (-1 == (blkhashfd = mkstemp(blkhashname)))) {
1222                 g_warning("Could not open temp file: %s", strerror(errno));
1223                 goto err;
1224         }
1225 #else
1226         /* use tmpnam here to avoid further feature test nightmare */
1227         if (-1 == (blkhashfd = open(blkhashname = strdup(tmpnam(NULL)),
1228                                     O_CREAT | O_RDWR,
1229                                     S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH))) {
1230                 g_warning("Could not open temp file: %s", strerror(errno));
1231                 goto err;
1232         }
1233 #endif
1234         /* Ensure space freed if we die */
1235         if (-1 == unlink(blkhashname)) {
1236                 g_warning("Could not unlink temp file: %s", strerror(errno));
1237                 goto err;
1238         }
1239
1240         if (-1 ==
1241             lseek(blkhashfd, (off_t) ((size >> 9) * sizeof(struct blkitem)),
1242                   SEEK_SET)) {
1243                 g_warning("Could not llseek temp file: %s", strerror(errno));
1244                 goto err;
1245         }
1246
1247         if (-1 == write(blkhashfd, "\0", 1)) {
1248                 g_warning("Could not write temp file: %s", strerror(errno));
1249                 goto err;
1250         }
1251
1252         if (NULL == (blkhash = mmap(NULL,
1253                                     (size >> 9) * sizeof(struct blkitem),
1254                                     PROT_READ | PROT_WRITE,
1255                                     MAP_SHARED, blkhashfd, 0))) {
1256                 g_warning("Could not mmap temp file: %s", strerror(errno));
1257                 goto err;
1258         }
1259
1260         if (-1 == (logfd = open(transactionlog, O_RDONLY))) {
1261                 g_warning("Could open log file: %s", strerror(errno));
1262                 goto err;
1263         }
1264
1265         if (gettimeofday(&start, NULL) < 0) {
1266                 snprintf(errstr, errstr_len, "Could not measure start time: %s",
1267                          strerror(errno));
1268                 goto err_open;
1269         }
1270
1271         while (readtransactionfile || txqueue.numitems || txbuf.numitems
1272                || inflight.numitems) {
1273                 int ret;
1274
1275                 uint32_t magic;
1276                 uint32_t command;
1277                 uint64_t from;
1278                 uint32_t len;
1279                 struct reqcontext *prc;
1280
1281                 *errstr = 0;
1282
1283                 FD_ZERO(&wset);
1284                 FD_ZERO(&rset);
1285                 if (readtransactionfile)
1286                         FD_SET(logfd, &rset);
1287                 if ((!blocked && txqueue.numitems) || txbuf.numitems)
1288                         FD_SET(sock, &wset);
1289                 if (inflight.numitems)
1290                         FD_SET(sock, &rset);
1291                 tv.tv_sec = 5;
1292                 tv.tv_usec = 0;
1293                 ret =
1294                     select(1 + ((sock > logfd) ? sock : logfd), &rset, &wset,
1295                            NULL, &tv);
1296                 if (ret == 0) {
1297                         snprintf(errstr, errstr_len,
1298                                  "Timeout reading from socket");
1299                         goto err_open;
1300                 } else if (ret < 0) {
1301                         g_warning("Could not mmap temp file: %s", errstr);
1302                         goto err;
1303                 }
1304                 /* We know we've got at least one thing to do here then */
1305
1306                 /* Get a command from the transaction log */
1307                 if (FD_ISSET(logfd, &rset)) {
1308
1309                         /* Read a request or reply from the transaction file */
1310                         READ_ALL_ERRCHK(logfd,
1311                                         &magic,
1312                                         sizeof(magic),
1313                                         err_open,
1314                                         "Could not read transaction log: %s",
1315                                         strerror(errno));
1316                         magic = ntohl(magic);
1317                         switch (magic) {
1318                         case NBD_REQUEST_MAGIC:
1319                                 if (NULL ==
1320                                     (prc =
1321                                      calloc(1, sizeof(struct reqcontext)))) {
1322                                         snprintf(errstr, errstr_len,
1323                                                  "Could not allocate request");
1324                                         goto err_open;
1325                                 }
1326                                 READ_ALL_ERRCHK(logfd,
1327                                                 sizeof(magic) +
1328                                                 (char *)&(prc->req),
1329                                                 sizeof(struct nbd_request) -
1330                                                 sizeof(magic), err_open,
1331                                                 "Could not read transaction log: %s",
1332                                                 strerror(errno));
1333                                 prc->req.magic = htonl(NBD_REQUEST_MAGIC);
1334                                 memcpy(prc->orighandle, prc->req.handle, 8);
1335                                 prc->seq = seq++;
1336                                 if ((ntohl(prc->req.type) &
1337                                      NBD_CMD_MASK_COMMAND) == NBD_CMD_DISC) {
1338                                         /* no more to read; don't enqueue as no reply
1339                                          * we will disconnect manually at the end
1340                                          */
1341                                         readtransactionfile = 0;
1342                                         free(prc);
1343                                 } else {
1344                                         dumpcommand("Enqueuing command",
1345                                                     prc->req.type);
1346                                         rclist_addtail(&txqueue, prc);
1347                                 }
1348                                 prc = NULL;
1349                                 break;
1350                         case NBD_REPLY_MAGIC:
1351                                 READ_ALL_ERRCHK(logfd,
1352                                                 sizeof(magic) + (char *)(&rep),
1353                                                 sizeof(struct nbd_reply) -
1354                                                 sizeof(magic), err_open,
1355                                                 "Could not read transaction log: %s",
1356                                                 strerror(errno));
1357
1358                                 if (rep.error) {
1359                                         snprintf(errstr, errstr_len,
1360                                                  "Transaction log file contained errored transaction");
1361                                         goto err_open;
1362                                 }
1363
1364                                 /* We do not need to consume data on a read reply as there is
1365                                  * none in the log */
1366                                 break;
1367                         default:
1368                                 snprintf(errstr, errstr_len,
1369                                          "Could not measure start time: %08x",
1370                                          magic);
1371                                 goto err_open;
1372                         }
1373                 }
1374
1375                 /* See if we have a write we can do */
1376                 if (FD_ISSET(sock, &wset)) {
1377                         if ((!(txqueue.head) && !(txbuf.head)) || blocked)
1378                                 g_warning
1379                                     ("Socket write FD set but we shouldn't have been interested");
1380
1381                         /* If there is no buffered data, generate some */
1382                         if (!blocked && !(txbuf.head)
1383                             && (NULL != (prc = txqueue.head))) {
1384                                 if (ntohl(prc->req.magic) != NBD_REQUEST_MAGIC) {
1385                                         g_warning
1386                                             ("Asked to write a request without a magic number");
1387                                         goto err_open;
1388                                 }
1389
1390                                 command = ntohl(prc->req.type);
1391                                 from = ntohll(prc->req.from);
1392                                 len = ntohl(prc->req.len);
1393
1394                                 /* First check whether we can touch this command at all. If this
1395                                  * command is a read, and there is an inflight write, OR if this
1396                                  * command is a write, and there is an inflight read or write, then
1397                                  * we need to leave the command alone and signal that we are blocked
1398                                  */
1399
1400                                 if (!looseordering) {
1401                                         uint64_t cfrom;
1402                                         uint32_t clen;
1403                                         cfrom = from;
1404                                         clen = len;
1405                                         while (clen > 0) {
1406                                                 uint64_t blknum = cfrom >> 9;
1407                                                 if (cfrom >= size) {
1408                                                         snprintf(errstr,
1409                                                                  errstr_len,
1410                                                                  "offset %llx beyond size %llx",
1411                                                                  (long long int)
1412                                                                  cfrom,
1413                                                                  (long long int)
1414                                                                  size);
1415                                                         goto err_open;
1416                                                 }
1417                                                 if (blkhash[blknum].inflightw ||
1418                                                     (blkhash[blknum].inflightr
1419                                                      &&
1420                                                      ((command &
1421                                                        NBD_CMD_MASK_COMMAND) ==
1422                                                       NBD_CMD_WRITE))) {
1423                                                         blocked = 1;
1424                                                         break;
1425                                                 }
1426                                                 cfrom += 512;
1427                                                 clen -= 512;
1428                                         }
1429                                 }
1430
1431                                 if (blocked)
1432                                         goto skipdequeue;
1433
1434                                 rclist_unlink(&txqueue, prc);
1435                                 rclist_addtail(&inflight, prc);
1436
1437                                 dumpcommand("Sending command", prc->req.type);
1438                                 /* we rewrite the handle as they otherwise may not be unique */
1439                                 *((uint64_t *) (prc->req.handle)) =
1440                                     getrandomhandle(handlehash);
1441                                 g_hash_table_insert(handlehash, prc->req.handle,
1442                                                     prc);
1443                                 addbuffer(&txbuf, &(prc->req),
1444                                           sizeof(struct nbd_request));
1445                                 switch (command & NBD_CMD_MASK_COMMAND) {
1446                                 case NBD_CMD_WRITE:
1447                                         xfer += len;
1448                                         while (len > 0) {
1449                                                 uint64_t blknum = from >> 9;
1450                                                 char dbuf[512];
1451                                                 if (from >= size) {
1452                                                         snprintf(errstr,
1453                                                                  errstr_len,
1454                                                                  "offset %llx beyond size %llx",
1455                                                                  (long long int)
1456                                                                  from,
1457                                                                  (long long int)
1458                                                                  size);
1459                                                         goto err_open;
1460                                                 }
1461                                                 (blkhash[blknum].inflightw)++;
1462                                                 /* work out what we should be writing */
1463                                                 makebuf(dbuf, prc->seq, blknum);
1464                                                 addbuffer(&txbuf, dbuf, 512);
1465                                                 from += 512;
1466                                                 len -= 512;
1467                                         }
1468                                         break;
1469                                 case NBD_CMD_READ:
1470                                         xfer += len;
1471                                         while (len > 0) {
1472                                                 uint64_t blknum = from >> 9;
1473                                                 if (from >= size) {
1474                                                         snprintf(errstr,
1475                                                                  errstr_len,
1476                                                                  "offset %llx beyond size %llx",
1477                                                                  (long long int)
1478                                                                  from,
1479                                                                  (long long int)
1480                                                                  size);
1481                                                         goto err_open;
1482                                                 }
1483                                                 (blkhash[blknum].inflightr)++;
1484                                                 from += 512;
1485                                                 len -= 512;
1486                                         }
1487                                         break;
1488                                 case NBD_CMD_DISC:
1489                                 case NBD_CMD_FLUSH:
1490                                         break;
1491                                 default:
1492                                         snprintf(errstr, errstr_len,
1493                                                  "Incomprehensible command: %08x",
1494                                                  command);
1495                                         goto err_open;
1496                                         break;
1497                                 }
1498
1499                                 prc = NULL;
1500                         }
1501 skipdequeue:
1502
1503                         /* there should be some now */
1504                         if (writebuffer(sock, &txbuf) < 0) {
1505                                 snprintf(errstr, errstr_len,
1506                                          "Failed to write to socket buffer: %s",
1507                                          strerror(errno));
1508                                 goto err_open;
1509                         }
1510
1511                 }
1512
1513                 /* See if there is a reply to be processed from the socket */
1514                 if (FD_ISSET(sock, &rset)) {
1515                         /* Okay, there's something ready for
1516                          * reading here */
1517
1518                         READ_ALL_ERRCHK(sock,
1519                                         &rep,
1520                                         sizeof(struct nbd_reply),
1521                                         err_open,
1522                                         "Could not read from server socket: %s",
1523                                         strerror(errno));
1524
1525                         if (rep.magic != htonl(NBD_REPLY_MAGIC)) {
1526                                 snprintf(errstr, errstr_len,
1527                                          "Bad magic from server");
1528                                 goto err_open;
1529                         }
1530
1531                         if (rep.error) {
1532                                 snprintf(errstr, errstr_len,
1533                                          "Server errored a transaction");
1534                                 goto err_open;
1535                         }
1536
1537                         uint64_t handle;
1538                         memcpy(&handle, rep.handle, 8);
1539                         prc = g_hash_table_lookup(handlehash, &handle);
1540                         if (!prc) {
1541                                 snprintf(errstr, errstr_len,
1542                                          "Unrecognised handle in reply: 0x%llX",
1543                                          *(long long unsigned int *)(rep.
1544                                                                      handle));
1545                                 goto err_open;
1546                         }
1547                         if (!g_hash_table_remove(handlehash, &handle)) {
1548                                 snprintf(errstr, errstr_len,
1549                                          "Could not remove handle from hash: 0x%llX",
1550                                          *(long long unsigned int *)(rep.
1551                                                                      handle));
1552                                 goto err_open;
1553                         }
1554
1555                         if (prc->req.magic != htonl(NBD_REQUEST_MAGIC)) {
1556                                 snprintf(errstr, errstr_len,
1557                                          "Bad magic in inflight data: %08x",
1558                                          prc->req.magic);
1559                                 goto err_open;
1560                         }
1561
1562                         dumpcommand("Processing reply to command",
1563                                     prc->req.type);
1564                         command = ntohl(prc->req.type);
1565                         from = ntohll(prc->req.from);
1566                         len = ntohl(prc->req.len);
1567
1568                         switch (command & NBD_CMD_MASK_COMMAND) {
1569                         case NBD_CMD_READ:
1570                                 while (len > 0) {
1571                                         uint64_t blknum = from >> 9;
1572                                         char dbuf[512];
1573                                         if (from >= size) {
1574                                                 snprintf(errstr, errstr_len,
1575                                                          "offset %llx beyond size %llx",
1576                                                          (long long int)from,
1577                                                          (long long int)size);
1578                                                 goto err_open;
1579                                         }
1580                                         READ_ALL_ERRCHK(sock,
1581                                                         dbuf,
1582                                                         512,
1583                                                         err_open,
1584                                                         "Could not read data: %s",
1585                                                         strerror(errno));
1586                                         if (--(blkhash[blknum].inflightr) < 0) {
1587                                                 snprintf(errstr, errstr_len,
1588                                                          "Received a read reply for offset %llx when not in flight",
1589                                                          (long long int)from);
1590                                                 goto err_open;
1591                                         }
1592                                         /* work out what we was written */
1593                                         if (checkbuf
1594                                             (dbuf, blkhash[blknum].seq,
1595                                              blknum)) {
1596                                                 snprintf(errstr, errstr_len,
1597                                                          "Bad reply data: I wanted blk %08x, seq %08x but I got (at a guess) blk %08x, seq %08x",
1598                                                          (unsigned int)blknum,
1599                                                          blkhash[blknum].seq,
1600                                                          ((uint32_t
1601                                                            *) (dbuf))[0],
1602                                                          ((uint32_t
1603                                                            *) (dbuf))[1]
1604                                                     );
1605                                                 goto err_open;
1606
1607                                         }
1608                                         from += 512;
1609                                         len -= 512;
1610                                 }
1611                                 break;
1612                         case NBD_CMD_WRITE:
1613                                 /* subsequent reads should get data with this seq */
1614                                 while (len > 0) {
1615                                         uint64_t blknum = from >> 9;
1616                                         if (--(blkhash[blknum].inflightw) < 0) {
1617                                                 snprintf(errstr, errstr_len,
1618                                                          "Received a write reply for offset %llx when not in flight",
1619                                                          (long long int)from);
1620                                                 goto err_open;
1621                                         }
1622                                         blkhash[blknum].seq =
1623                                             (uint32_t) (prc->seq);
1624                                         from += 512;
1625                                         len -= 512;
1626                                 }
1627                                 break;
1628                         default:
1629                                 break;
1630                         }
1631                         blocked = 0;
1632                         processed++;
1633                         rclist_unlink(&inflight, prc);
1634                         prc->req.magic = 0;     /* so a duplicate reply is detected */
1635                         free(prc);
1636                 }
1637
1638                 if ((do_print == NULL && !(printer++ % 5000))
1639                     || !(readtransactionfile || txqueue.numitems
1640                          || inflight.numitems))
1641                         printf
1642                             ("%d: Seq %08lld Queued: %08d Inflight: %08d Done: %08lld\r",
1643                              (int)mypid, (long long int)seq, txqueue.numitems,
1644                              inflight.numitems, (long long int)processed);
1645
1646         }
1647
1648         printf("\n");
1649
1650         if (gettimeofday(&stop, NULL) < 0) {
1651                 snprintf(errstr, errstr_len, "Could not measure end time: %s",
1652                          strerror(errno));
1653                 goto err_open;
1654         }
1655         timespan = timeval_diff_to_double(&stop, &start);
1656         speed = xfer / timespan;
1657         if (speed > 1024) {
1658                 speed = speed / 1024.0;
1659                 speedchar[0] = 'K';
1660         }
1661         if (speed > 1024) {
1662                 speed = speed / 1024.0;
1663                 speedchar[0] = 'M';
1664         }
1665         if (speed > 1024) {
1666                 speed = speed / 1024.0;
1667                 speedchar[0] = 'G';
1668         }
1669         g_message
1670             ("%d: Integrity %s test complete. Took %.3f seconds to complete, %.3f%sib/s",
1671              (int)getpid(), (testflags & TEST_WRITE) ? "write" : "read",
1672              timespan, speed, speedchar);
1673
1674         retval = 0;
1675
1676 err_open:
1677         if (close_sock) {
1678                 close_connection(sock, CONNECTION_CLOSE_PROPERLY);
1679         }
1680 err:
1681         if (size && blkhash)
1682                 munmap(blkhash, (size >> 9) * sizeof(struct blkitem));
1683
1684         if (blkhashfd != -1)
1685                 close(blkhashfd);
1686
1687         if (logfd != -1)
1688                 close(logfd);
1689
1690         if (blkhashname)
1691                 free(blkhashname);
1692
1693         if (*errstr)
1694                 g_warning("%s", errstr);
1695
1696         g_hash_table_destroy(handlehash);
1697
1698         return retval;
1699 }
1700
1701 void handle_nonopt(char *opt, gchar ** hostname, long int *p)
1702 {
1703         static int nonopt = 0;
1704
1705         switch (nonopt) {
1706         case 0:
1707                 *hostname = g_strdup(opt);
1708                 nonopt++;
1709                 break;
1710         case 1:
1711                 *p = (strtol(opt, NULL, 0));
1712                 if (*p == LONG_MIN || *p == LONG_MAX) {
1713                         g_critical("Could not parse port number: %s",
1714                                    strerror(errno));
1715                         exit(EXIT_FAILURE);
1716                 }
1717                 break;
1718         }
1719 }
1720
/*
 * Common signature of the test routines that main() can select. main() calls
 * the chosen routine as test(name, sock, TRUE, testflags) and treats a
 * negative return value as failure.
 */
typedef int (*testfunc) (char *, int, char, int);
1722
1723 int main(int argc, char **argv)
1724 {
1725         gchar *hostname = NULL, *unixsock = NULL;
1726         long int p = 10809;
1727         char *name = NULL;
1728         int sock = -1;
1729         int c;
1730         int testflags = 0;
1731         testfunc test = throughput_test;
1732
1733 #if HAVE_GNUTLS
1734         tlssession_init();
1735 #endif
1736
1737         /* Ignore SIGPIPE as we want to pick up the error from write() */
1738         signal(SIGPIPE, SIG_IGN);
1739
1740         errstr[errstr_len] = '\0';
1741
1742         if (argc < 3) {
1743                 g_message("%d: Not enough arguments", (int)getpid());
1744                 g_message("%d: Usage: %s <hostname> <port>", (int)getpid(),
1745                           argv[0]);
1746                 g_message("%d: Or: %s <hostname> -N <exportname> [<port>]",
1747                           (int)getpid(), argv[0]);
1748                 g_message("%d: Or: %s -u <unix socket> -N <exportname>",
1749                           (int)getpid(), argv[0]);
1750                 exit(EXIT_FAILURE);
1751         }
1752         logging(MY_NAME);
1753         while ((c = getopt(argc, argv, "FN:t:owfilu:hC:K:A:H:I")) >= 0) {
1754                 switch (c) {
1755                 case 1:
1756                         handle_nonopt(optarg, &hostname, &p);
1757                         break;
1758                 case 'N':
1759                         name = g_strdup(optarg);
1760                         break;
1761                 case 'F':
1762                         testflags |= TEST_EXPECT_ERROR;
1763                         break;
1764                 case 't':
1765                         transactionlog = g_strdup(optarg);
1766                         break;
1767                 case 'o':
1768                         test = oversize_test;
1769                         break;
1770                 case 'l':
1771                         looseordering = 1;
1772                         break;
1773                 case 'w':
1774                         testflags |= TEST_WRITE;
1775                         break;
1776                 case 'f':
1777                         testflags |= TEST_FLUSH;
1778                         break;
1779                 case 'I':
1780 #ifndef ISSERVER
1781                         err_nonfatal("inetd mode not supported without syslog support");
1782                         return 77;
1783 #endif
1784                         p = -1;
1785                         break;
1786                 case 'i':
1787                         test = integrity_test;
1788                         break;
1789                 case 'u':
1790                         unixsock = g_strdup(optarg);
1791                         break;
1792                 case 'h':
1793                         test = handshake_test;
1794                         testflags |= TEST_HANDSHAKE;
1795                         break;
1796 #if HAVE_GNUTLS
1797                 case 'C':
1798                         certfile=g_strdup(optarg);
1799                         break;
1800                 case 'K':
1801                         keyfile=g_strdup(optarg);
1802                         break;
1803                 case 'A':
1804                         cacertfile=g_strdup(optarg);
1805                         break;
1806                 case 'H':
1807                         tlshostname=g_strdup(optarg);
1808                         break;
1809 #else
1810                 case 'C':
1811                 case 'K':
1812                 case 'H':
1813                 case 'A':
1814                         g_warning("TLS support not compiled in");
1815                         /* Do not change this - looked for by test suite */
1816                         exit(77);
1817 #endif
1818                 }
1819         }
1820
1821         if (p != -1) {
1822                 while (optind < argc) {
1823                         handle_nonopt(argv[optind++], &hostname, &p);
1824                 }
1825         }
1826
1827         if (keyfile && !certfile)
1828                 certfile = g_strdup(keyfile);
1829
1830         if (!tlshostname && hostname)
1831                 tlshostname = g_strdup(hostname);
1832
1833         if (hostname != NULL) {
1834                 sock = setup_inet_connection(hostname, p);
1835         } else if (unixsock != NULL) {
1836                 sock = setup_unix_connection(unixsock);
1837         } else if (p == -1) {
1838                 sock = setup_inetd_connection(argv + optind);
1839         } else {
1840                 g_error("need a hostname, a unix domain socket or inetd-mode command line!");
1841                 return -1;
1842         }
1843
1844         if (sock == -1) {
1845                 g_warning("Could not establish a connection: %s", errstr);
1846                 exit(EXIT_FAILURE);
1847         }
1848
1849         if (test(name, sock, TRUE, testflags)
1850             < 0) {
1851                 g_warning("Could not run test: %s", errstr);
1852                 exit(EXIT_FAILURE);
1853         }
1854
1855         return 0;
1856 }