Imported Upstream version 3.3.5
[platform/upstream/gnutls.git] / tests / mini-dtls-record.c
1 /*
2  * Copyright (C) 2012-2013 Free Software Foundation, Inc.
3  * Copyright (C) 2013 Nikos Mavrogiannopoulos
4  *
5  * Author: Nikos Mavrogiannopoulos
6  *
7  * This file is part of GnuTLS.
8  *
9  * GnuTLS is free software; you can redistribute it and/or modify it
10  * under the terms of the GNU General Public License as published by
11  * the Free Software Foundation; either version 3 of the License, or
12  * (at your option) any later version.
13  *
14  * GnuTLS is distributed in the hope that it will be useful, but
15  * WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * General Public License for more details.
18  *
19  * You should have received a copy of the GNU General Public License
20  * along with GnuTLS; if not, write to the Free Software Foundation,
21  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA
22  */
23
24 #ifdef HAVE_CONFIG_H
25 #include <config.h>
26 #endif
27
28 #include <stdio.h>
29 #include <stdlib.h>
30
31 #if defined(_WIN32)
32
/* This test relies on fork() and socketpair() (see start()), so the
 * _WIN32 build only provides this stub.  77 is presumably the automake
 * "test skipped" exit status. */
int main(void)
{
	exit(77);
}
37
38 #else
39
#include <string.h>
#include <errno.h>
#include <signal.h>
#include <stdint.h>
#include <sys/types.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/wait.h>
#include <arpa/inet.h>
#include <unistd.h>
#include <gnutls/gnutls.h>
#include <gnutls/dtls.h>
50
51 #include "utils.h"
52
/* Set by odd_push() once its scripted sequence is exhausted; the
 * server's send loop polls it to know when to stop.  Static storage,
 * so it starts out as 0. */
static int test_finished;

/* Defined further down; used on failure paths before its definition. */
static void terminate(void);
55
/* This program tests whether DTLS messages are received in the
 * expected sequence, that is, whether the record sequence numbers
 * returned by gnutls_record_recv_seq() correspond to the messages
 * that were actually received.
 */
60
61 /*
62 static void
63 tls_audit_log_func (gnutls_session_t session, const char *str)
64 {
65   fprintf (stderr, "|<%p>| %s", session, str);
66 }
67 */
68
/* gnutls debug-log callback for the server process: prefixes each
 * message with "server" and the log level, writes it to stderr. */
static void server_log_func(int level, const char *str)
{
	(void) fprintf(stderr, "server|<%d>| %s", level, str);
}
73
/* gnutls debug-log callback for the client process: prefixes each
 * message with "client" and the log level, writes it to stderr. */
static void client_log_func(int level, const char *str)
{
	(void) fprintf(stderr, "client|<%d>| %s", level, str);
}
78
/* PID returned by fork() in start(): the client's PID in the parent,
 * 0 inside the forked client itself.  Read by terminate(). */
static pid_t child = 0;
81
/* A test client/server app for DTLS duplicate packet detection.
 */

#define MAX_BUF 1024

/* Capacity of stored_messages[] / stored_sizes[].  odd_push() indexes
 * them both by arrival position and by the values in msg_seq[], so it
 * must exceed the number of entries in msg_seq[] and every index it
 * contains. */
#define MAX_SEQ 128

/* Script consumed by odd_push(): when the N-th outgoing record is
 * pushed, msg_seq[N] names which stored record goes on the wire.
 * Repeated values duplicate a record, missing values drop one.
 * Terminated by -1.  Read-only. */
static const int msg_seq[] =
    { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 16, 5, 32, 11, 11, 11, 11, 12,
	10, 13, 14, 15, 16, 17, 19, 20, 18, 22, 24, 23, 25, 26, 27, 29, 28,
	29, 29, 30, 31, 32, 33, 34, 35, 37, 36, 38, 39, 42, 37, 40, 41, 41,
	-1
};

/* Arrival index of the next record handed to odd_push(). */
static unsigned int current = 0;
/* Next script position odd_push() still has to put on the wire; it lags
 * behind `current` while a scripted record has not been stored yet. */
static unsigned int pos = 0;

/* Copies of every record handed to odd_push(), indexed by arrival
 * order, together with their lengths.  Internal to this test. */
static unsigned char *stored_messages[MAX_SEQ];
static unsigned int stored_sizes[MAX_SEQ];
101
102 static ssize_t odd_push(gnutls_transport_ptr_t tr, const void *data, size_t len)
103 {
104         ssize_t ret;
105         unsigned i;
106
107         if (msg_seq[current] == -1 || test_finished != 0) {
108                 test_finished = 1;
109                 return len;
110         }
111
112         stored_messages[current] = malloc(len);
113         memcpy(stored_messages[current], data, len);
114         stored_sizes[current] = len;
115
116         if (pos != current) {
117                 for (i = pos; i <= current; i++) {
118                         if (stored_messages[msg_seq[i]] != NULL) {
119                                 do {
120
121                                         ret =
122                                             send((long int)tr,
123                                                  stored_messages[msg_seq
124                                                                  [i]],
125                                                  stored_sizes[msg_seq[i]], 0);
126                                 }
127                                 while (ret == -1 && errno == EAGAIN);
128                                 pos++;
129                         } else
130                                 break;
131                 }
132         } else if (msg_seq[current] == (int)current) {
133                 do {
134                         ret = send((long int)tr, data, len, 0);
135                 }
136                 while (ret == -1 && errno == EAGAIN);
137
138                 current++;
139                 pos++;
140
141                 return ret;
142         } else if (stored_messages[msg_seq[current]] != NULL) {
143                 do {
144                         ret =
145                             send((long int)tr,
146                                  stored_messages[msg_seq[current]],
147                                  stored_sizes[msg_seq[current]], 0);
148                 }
149                 while (ret == -1 && errno == EAGAIN);
150                 current++;
151                 pos++;
152                 return ret;
153         }
154
155         current++;
156
157         return len;
158 }
159
160 static ssize_t n_push(gnutls_transport_ptr_t tr, const void *data, size_t len)
161 {
162         return send((unsigned long)tr, data, len, 0);
163 }
164
/* Record sequence numbers the client expects gnutls_record_recv_seq()
 * to report, in arrival order; terminated by -1.  The first five
 * messages are handshake records, so this list corresponds to the
 * msg_seq[] entries from index 5 onward (msg_seq+5).  Read-only. */
static const int recv_msg_seq[] =
    { 1, 2, 3, 4, 5, 6, 12, 28, 7, 8, 9, 10, 11, 13, 15, 16, 14, 18, 20,
	19, 21, 22, 23, 25, 24, 26, 27, 29, 30, 31, 33, 32, 34, 35, 38, 36, 37,
	-1
};
171
172 static void client(int fd)
173 {
174         gnutls_session_t session;
175         int ret;
176         char buffer[MAX_BUF + 1];
177         gnutls_anon_client_credentials_t anoncred;
178         unsigned char seq[8];
179         uint64_t useq;
180         unsigned current = 0;
181
182         memset(buffer, 0, sizeof(buffer));
183
184         /* Need to enable anonymous KX specifically. */
185
186 /*    gnutls_global_set_audit_log_function (tls_audit_log_func); */
187         global_init();
188
189         if (debug) {
190                 gnutls_global_set_log_function(client_log_func);
191                 gnutls_global_set_log_level(2);
192         }
193
194         gnutls_anon_allocate_client_credentials(&anoncred);
195
196         /* Initialize TLS session
197          */
198         gnutls_init(&session, GNUTLS_CLIENT | GNUTLS_DATAGRAM);
199         gnutls_heartbeat_enable(session, GNUTLS_HB_PEER_ALLOWED_TO_SEND);
200         gnutls_dtls_set_mtu(session, 1500);
201
202         /* Use default priorities */
203         gnutls_priority_set_direct(session,
204                                    "NONE:+VERS-DTLS1.0:+CIPHER-ALL:+MAC-ALL:+SIGN-ALL:+COMP-ALL:+ANON-ECDH:+CURVE-ALL",
205                                    NULL);
206
207         /* put the anonymous credentials to the current session
208          */
209         gnutls_credentials_set(session, GNUTLS_CRD_ANON, anoncred);
210
211         gnutls_transport_set_int(session, fd);
212
213         /* Perform the TLS handshake
214          */
215         do {
216                 ret = gnutls_handshake(session);
217         }
218         while (ret < 0 && gnutls_error_is_fatal(ret) == 0);
219
220         if (ret < 0) {
221                 fail("client: Handshake failed\n");
222                 gnutls_perror(ret);
223                 exit(1);
224         } else {
225                 if (debug)
226                         success("client: Handshake was completed\n");
227         }
228
229         gnutls_record_send(session, buffer, 1);
230
231         if (debug)
232                 success("client: DTLS version is: %s\n",
233                         gnutls_protocol_get_name
234                         (gnutls_protocol_get_version(session)));
235         do {
236                 ret =
237                     gnutls_record_recv_seq(session, buffer, sizeof(buffer),
238                                            seq);
239
240                 if (ret > 0) {
241                         useq =
242                             seq[3] | (seq[2] << 8) | (seq[1] << 16) |
243                             (seq[0] << 24);
244                         useq <<= 32;
245                         useq |=
246                             seq[7] | (seq[6] << 8) | (seq[5] << 16) |
247                             (seq[4] << 24);
248
249                         if (recv_msg_seq[current] == -1) {
250                                 fail("received message sequence differs\n");
251                                 terminate();
252                         }
253
254                         if ((uint32_t) recv_msg_seq[current] != (uint32_t) useq) {
255                                 fail("received message sequence differs (got: %u, expected: %u)\n", (unsigned)useq, (unsigned)recv_msg_seq[current]);
256                                 terminate();
257                         }
258
259                         current++;
260                 }
261         }
262         while ((ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED
263                 || ret > 0));
264
265         gnutls_bye(session, GNUTLS_SHUT_WR);
266
267         close(fd);
268
269         gnutls_deinit(session);
270
271         gnutls_anon_free_client_credentials(anoncred);
272
273         gnutls_global_deinit();
274 }
275
276 static void terminate(void)
277 {
278         int status;
279
280         kill(child, SIGTERM);
281         wait(&status);
282         exit(1);
283 }
284
285 static void server(int fd)
286 {
287         int ret;
288         gnutls_session_t session;
289         gnutls_anon_server_credentials_t anoncred;
290         char c;
291
292         global_init();
293
294         if (debug) {
295                 gnutls_global_set_log_function(server_log_func);
296                 gnutls_global_set_log_level(2);
297         }
298
299         gnutls_anon_allocate_server_credentials(&anoncred);
300
301         gnutls_init(&session, GNUTLS_SERVER | GNUTLS_DATAGRAM);
302         gnutls_transport_set_push_function(session, odd_push);
303         gnutls_heartbeat_enable(session, GNUTLS_HB_PEER_ALLOWED_TO_SEND);
304         gnutls_dtls_set_mtu(session, 1500);
305
306         /* avoid calling all the priority functions, since the defaults
307          * are adequate.
308          */
309         gnutls_priority_set_direct(session,
310                                    "NONE:+VERS-DTLS1.0:+CIPHER-ALL:+MAC-ALL:+SIGN-ALL:+COMP-ALL:+ANON-ECDH:+CURVE-ALL",
311                                    NULL);
312         gnutls_credentials_set(session, GNUTLS_CRD_ANON, anoncred);
313
314         gnutls_transport_set_int(session, fd);
315
316         do {
317                 ret = gnutls_handshake(session);
318         }
319         while (ret < 0 && gnutls_error_is_fatal(ret) == 0);
320         if (ret < 0) {
321                 close(fd);
322                 gnutls_deinit(session);
323                 fail("server: Handshake has failed (%s)\n\n",
324                      gnutls_strerror(ret));
325                 terminate();
326         }
327         if (debug)
328                 success("server: Handshake was completed\n");
329
330         if (debug)
331                 success("server: TLS version is: %s\n",
332                         gnutls_protocol_get_name
333                         (gnutls_protocol_get_version(session)));
334
335         gnutls_record_recv(session, &c, 1);
336         do {
337                 do {
338                         ret = gnutls_record_send(session, &c, 1);
339                 }
340                 while (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED);
341
342                 if (ret < 0) {
343                         fail("send: %s\n", gnutls_strerror(ret));
344                         terminate();
345                 }
346         }
347         while (test_finished == 0);
348
349         gnutls_transport_set_push_function(session, n_push);
350         do {
351                 ret = gnutls_bye(session, GNUTLS_SHUT_WR);
352         }
353         while (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED);
354
355         close(fd);
356         gnutls_deinit(session);
357
358         gnutls_anon_free_server_credentials(anoncred);
359
360         gnutls_global_deinit();
361
362         if (debug)
363                 success("server: finished\n");
364 }
365
366 static void start(void)
367 {
368         int fd[2];
369         int ret;
370
371         ret = socketpair(AF_UNIX, SOCK_STREAM, 0, fd);
372         if (ret < 0) {
373                 perror("socketpair");
374                 exit(1);
375         }
376
377         child = fork();
378         if (child < 0) {
379                 perror("fork");
380                 fail("fork");
381                 exit(1);
382         }
383
384         if (child) {
385                 int status;
386                 /* parent */
387                 close(fd[1]);
388                 server(fd[0]);
389                 wait(&status);
390                 if (WEXITSTATUS(status) != 0)
391                         fail("Child died with status %d\n",
392                              WEXITSTATUS(status));
393         } else {
394                 close(fd[0]);
395                 client(fd[1]);
396                 exit(0);
397         }
398 }
399
/* Test body, presumably invoked by the shared test driver from
 * utils.h: runs the forked DTLS client/server exchange. */
void doit(void)
{
	start();
	return;
}
404
405 #endif                          /* _WIN32 */