Imported Upstream version 3.3.5
[platform/upstream/gnutls.git] / tests / mini-dtls-rehandshake.c
1 /*
2  * Copyright (C) 2012 Free Software Foundation, Inc.
3  *
4  * Author: Nikos Mavrogiannopoulos
5  *
6  * This file is part of GnuTLS.
7  *
8  * GnuTLS is free software; you can redistribute it and/or modify it
9  * under the terms of the GNU General Public License as published by
10  * the Free Software Foundation; either version 3 of the License, or
11  * (at your option) any later version.
12  *
13  * GnuTLS is distributed in the hope that it will be useful, but
14  * WITHOUT ANY WARRANTY; without even the implied warranty of
15  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
16  * General Public License for more details.
17  *
18  * You should have received a copy of the GNU General Public License
19  * along with GnuTLS; if not, write to the Free Software Foundation,
20  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA
21  */
22
23 #ifdef HAVE_CONFIG_H
24 #include <config.h>
25 #endif
26
27 #include <stdio.h>
28 #include <stdlib.h>
29
30 #if defined(_WIN32)
31
/* Windows stub: the test is skipped here, presumably because the real
 * test below relies on fork() and socketpair(). Exit status 77
 * conventionally means "skipped" to the automake test harness. */
int main(void)
{
	return 77;
}
36
37 #else
38
#include <string.h>
#include <signal.h>
#include <sys/types.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/wait.h>
#include <arpa/inet.h>
#include <unistd.h>
#include <gnutls/gnutls.h>
#include <gnutls/dtls.h>

#include "utils.h"
50
/* Aborts the test; defined further down. Declared here because client(),
 * which is defined before it, calls it. */
static void terminate(void);
52
/* This program tests DTLS rehandshakes, both client-initiated and
 * server-initiated.
 */
55
/* GnuTLS log sink installed in the server process when debugging:
 * tags each library message with "server" and its log level and
 * writes it to stderr. */
static void
server_log_func(int level, const char *str)
{
	(void) fprintf(stderr, "server|<%d>| %s", level, str);
}
60
/* GnuTLS log sink installed in the client process when debugging:
 * tags each library message with "client" and its log level and
 * writes it to stderr. */
static void
client_log_func(int level, const char *str)
{
	(void) fprintf(stderr, "client|<%d>| %s", level, str);
}
65
/* The client side of this test (run in the forked child process) is a
 * very basic DTLS client using anonymous authentication.
 */
68
/* Capacity, in bytes, of the application-data buffers used by both peers. */
enum { MAX_BUF = 1024 };

/* Application payload the client sends after a server-initiated
 * rehandshake; the server echoes it back. */
#define MSG "Hello TLS"
71
/* DTLS session of the current process: client() initializes and uses it
 * in the forked child, server() in the parent. */
gnutls_session_t session;
73
74 static ssize_t
75 push(gnutls_transport_ptr_t tr, const void *data, size_t len)
76 {
77         int fd = (long int) tr;
78
79         return send(fd, data, len, 0);
80 }
81
82 static void client(int fd, int server_init)
83 {
84         int ret;
85         char buffer[MAX_BUF + 1];
86         gnutls_anon_client_credentials_t anoncred;
87         /* Need to enable anonymous KX specifically. */
88
89         global_init();
90
91         if (debug) {
92                 gnutls_global_set_log_function(client_log_func);
93                 gnutls_global_set_log_level(4711);
94         }
95
96         gnutls_anon_allocate_client_credentials(&anoncred);
97
98         /* Initialize TLS session
99          */
100         gnutls_init(&session, GNUTLS_CLIENT | GNUTLS_DATAGRAM);
101         gnutls_dtls_set_mtu(session, 1500);
102
103         /* Use default priorities */
104         gnutls_priority_set_direct(session,
105                                    "NONE:+VERS-DTLS1.0:+CIPHER-ALL:+MAC-ALL:+SIGN-ALL:+COMP-ALL:+ANON-ECDH:+CURVE-ALL",
106                                    NULL);
107
108         /* put the anonymous credentials to the current session
109          */
110         gnutls_credentials_set(session, GNUTLS_CRD_ANON, anoncred);
111
112         gnutls_transport_set_int(session, fd);
113         gnutls_transport_set_push_function(session, push);
114
115         /* Perform the TLS handshake
116          */
117         do {
118                 ret = gnutls_handshake(session);
119         }
120         while (ret < 0 && gnutls_error_is_fatal(ret) == 0);
121
122         if (ret < 0) {
123                 fail("client: Handshake failed\n");
124                 gnutls_perror(ret);
125                 exit(1);
126         } else {
127                 if (debug)
128                         success("client: Handshake was completed\n");
129         }
130
131         if (debug)
132                 success("client: TLS version is: %s\n",
133                         gnutls_protocol_get_name
134                         (gnutls_protocol_get_version(session)));
135
136         if (!server_init) {
137                 sleep(60);
138                 if (debug)
139                         success("Initiating client rehandshake\n");
140                 do {
141                         ret = gnutls_handshake(session);
142                 }
143                 while (ret < 0 && gnutls_error_is_fatal(ret) == 0);
144
145                 if (ret < 0) {
146                         fail("2nd client gnutls_handshake: %s\n",
147                              gnutls_strerror(ret));
148                         terminate();
149                 }
150         } else {
151                 do {
152                         ret = gnutls_record_recv(session, buffer, MAX_BUF);
153                 } while (ret == GNUTLS_E_AGAIN
154                          || ret == GNUTLS_E_INTERRUPTED);
155         }
156
157         if (ret == 0) {
158                 if (debug)
159                         success
160                             ("client: Peer has closed the TLS connection\n");
161                 goto end;
162         } else if (ret < 0) {
163                 if (server_init && ret == GNUTLS_E_REHANDSHAKE) {
164                         if (debug)
165                                 success
166                                     ("Initiating rehandshake due to server request\n");
167                         do {
168                                 ret = gnutls_handshake(session);
169                         }
170                         while (ret < 0 && gnutls_error_is_fatal(ret) == 0);
171                 }
172
173                 if (ret != 0) {
174                         fail("client: Error: %s\n", gnutls_strerror(ret));
175                         exit(1);
176                 }
177         }
178
179         do {
180                 ret = gnutls_record_send(session, MSG, strlen(MSG));
181         } while (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED);
182         gnutls_bye(session, GNUTLS_SHUT_WR);
183
184       end:
185
186         close(fd);
187
188         gnutls_deinit(session);
189
190         gnutls_anon_free_client_credentials(anoncred);
191
192         gnutls_global_deinit();
193 }
194
195
196 /* These are global */
197 gnutls_anon_server_credentials_t anoncred;
198 pid_t child;
199
200 static gnutls_session_t initialize_tls_session(void)
201 {
202         gnutls_session_t session;
203
204         gnutls_init(&session, GNUTLS_SERVER | GNUTLS_DATAGRAM);
205         gnutls_dtls_set_mtu(session, 1500);
206
207         /* avoid calling all the priority functions, since the defaults
208          * are adequate.
209          */
210         gnutls_priority_set_direct(session,
211                                    "NONE:+VERS-DTLS1.0:+CIPHER-ALL:+MAC-ALL:+SIGN-ALL:+COMP-ALL:+ANON-ECDH:+CURVE-ALL",
212                                    NULL);
213
214         gnutls_credentials_set(session, GNUTLS_CRD_ANON, anoncred);
215
216         return session;
217 }
218
219 static void terminate(void)
220 {
221         int status;
222
223         kill(child, SIGTERM);
224         wait(&status);
225         exit(1);
226 }
227
228 static void server(int fd, int server_init)
229 {
230         int ret;
231         char buffer[MAX_BUF + 1];
232         /* this must be called once in the program
233          */
234         global_init();
235
236         if (debug) {
237                 gnutls_global_set_log_function(server_log_func);
238                 gnutls_global_set_log_level(4711);
239         }
240
241         gnutls_anon_allocate_server_credentials(&anoncred);
242
243         session = initialize_tls_session();
244
245         gnutls_transport_set_int(session, fd);
246         gnutls_transport_set_push_function(session, push);
247
248         do {
249                 ret = gnutls_handshake(session);
250         }
251         while (ret < 0 && gnutls_error_is_fatal(ret) == 0);
252         if (ret < 0) {
253                 close(fd);
254                 gnutls_deinit(session);
255                 fail("server: Handshake has failed (%s)\n\n",
256                      gnutls_strerror(ret));
257                 terminate();
258         }
259         if (debug)
260                 success("server: Handshake was completed\n");
261
262         if (debug)
263                 success("server: TLS version is: %s\n",
264                         gnutls_protocol_get_name
265                         (gnutls_protocol_get_version(session)));
266
267         /* see the Getting peer's information example */
268         /* print_info(session); */
269
270         if (server_init) {
271                 if (debug)
272                         success("server: Sending dummy packet\n");
273                 ret = gnutls_rehandshake(session);
274                 if (ret < 0) {
275                         fail("gnutls_rehandshake: %s\n",
276                              gnutls_strerror(ret));
277                         terminate();
278                 }
279
280                 if (debug)
281                         success("server: Initiating rehandshake\n");
282                 do {
283                         ret = gnutls_handshake(session);
284                 }
285                 while (ret < 0 && gnutls_error_is_fatal(ret) == 0);
286
287                 if (ret < 0) {
288                         fail("server: 2nd gnutls_handshake: %s\n",
289                              gnutls_strerror(ret));
290                         terminate();
291                 }
292         }
293
294         for (;;) {
295                 memset(buffer, 0, MAX_BUF + 1);
296
297                 do {
298                         ret = gnutls_record_recv(session, buffer, MAX_BUF);
299                 } while (ret == GNUTLS_E_AGAIN
300                          || ret == GNUTLS_E_INTERRUPTED);
301
302                 if (ret == 0) {
303                         if (debug)
304                                 success
305                                     ("server: Peer has closed the GnuTLS connection\n");
306                         break;
307                 } else if (ret < 0) {
308                         if (!server_init && ret == GNUTLS_E_REHANDSHAKE) {
309                                 if (debug)
310                                         success
311                                             ("Initiating rehandshake due to client request\n");
312                                 do {
313                                         ret = gnutls_handshake(session);
314                                 }
315                                 while (ret < 0
316                                        && gnutls_error_is_fatal(ret) == 0);
317                                 if (ret == 0)
318                                         break;
319                         }
320
321                         fail("server: Received corrupted data(%s). Closing...\n", gnutls_strerror(ret));
322                         terminate();
323                 } else if (ret > 0) {
324                         /* echo data back to the client
325                          */
326                         do {
327                                 ret =
328                                     gnutls_record_send(session, buffer,
329                                                        strlen(buffer));
330                         } while (ret == GNUTLS_E_AGAIN
331                                  || ret == GNUTLS_E_INTERRUPTED);
332                 }
333         }
334
335
336         /* do not wait for the peer to close the connection.
337          */
338         gnutls_bye(session, GNUTLS_SHUT_WR);
339
340         close(fd);
341         gnutls_deinit(session);
342
343         gnutls_anon_free_server_credentials(anoncred);
344
345         gnutls_global_deinit();
346
347         if (debug)
348                 success("server: finished\n");
349 }
350
351 static void start(int server_initiated)
352 {
353         int fd[2];
354         int ret;
355
356         ret = socketpair(AF_UNIX, SOCK_STREAM, 0, fd);
357         if (ret < 0) {
358                 perror("socketpair");
359                 exit(1);
360         }
361
362         child = fork();
363         if (child < 0) {
364                 perror("fork");
365                 fail("fork");
366                 exit(1);
367         }
368
369         if (child) {
370                 int status;
371                 /* parent */
372
373                 server(fd[0], server_initiated);
374                 wait(&status);
375                 if (WEXITSTATUS(status) != 0)
376                         fail("Child died with status %d\n",
377                              WEXITSTATUS(status));
378         } else {
379                 close(fd[0]);
380                 client(fd[1], server_initiated);
381                 exit(0);
382         }
383 }
384
/* Test entry point: runs the client-initiated rehandshake scenario
 * first (start(0)), then the server-initiated one (start(1)). */
void doit(void)
{
	int server_initiated;

	for (server_initiated = 0; server_initiated < 2; server_initiated++)
		start(server_initiated);
}
390
391 #endif                          /* _WIN32 */