Imported Upstream version 3.0.30
[platform/upstream/gnutls.git] / tests / mini-dtls-rehandshake.c
1 /*
2  * Copyright (C) 2012 Free Software Foundation, Inc.
3  *
4  * Author: Nikos Mavrogiannopoulos
5  *
6  * This file is part of GnuTLS.
7  *
8  * GnuTLS is free software; you can redistribute it and/or modify it
9  * under the terms of the GNU General Public License as published by
10  * the Free Software Foundation; either version 3 of the License, or
11  * (at your option) any later version.
12  *
13  * GnuTLS is distributed in the hope that it will be useful, but
14  * WITHOUT ANY WARRANTY; without even the implied warranty of
15  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
16  * General Public License for more details.
17  *
18  * You should have received a copy of the GNU General Public License
19  * along with GnuTLS; if not, write to the Free Software Foundation,
20  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA
21  */
22
23 #ifdef HAVE_CONFIG_H
24 #include <config.h>
25 #endif
26
27 #include <stdio.h>
28 #include <stdlib.h>
29
30 #if defined(_WIN32)
31
32 int main()
33 {
34   exit(77);
35 }
36
37 #else
38
39 #include <string.h>
40 #include <sys/types.h>
41 #include <netinet/in.h>
42 #include <sys/socket.h>
43 #include <sys/wait.h>
44 #include <arpa/inet.h>
45 #include <unistd.h>
46 #include <gnutls/gnutls.h>
47 #include <gnutls/dtls.h>
48
49 #include "utils.h"
50
51 static void terminate(void);
52
53 /* This program tests the rehandshake in DTLS
54  */
55
56 static void
57 server_log_func (int level, const char *str)
58 {
59   fprintf (stderr, "server|<%d>| %s", level, str);
60 }
61
62 static void
63 client_log_func (int level, const char *str)
64 {
65   fprintf (stderr, "client|<%d>| %s", level, str);
66 }
67
68 /* A very basic TLS client, with anonymous authentication.
69  */
70
71 #define MAX_BUF 1024
72 #define MSG "Hello TLS"
73
74 gnutls_session_t session;
75
76 static ssize_t
77 push (gnutls_transport_ptr_t tr, const void *data, size_t len)
78 {
79 int fd = (long int)tr;
80
81   return send(fd, data, len, 0);
82 }
83
84 static void
85 client (int fd, int server_init)
86 {
87   int ret;
88   char buffer[MAX_BUF + 1];
89   gnutls_anon_client_credentials_t anoncred;
90   /* Need to enable anonymous KX specifically. */
91
92   gnutls_global_init ();
93
94   if (debug)
95     {
96       gnutls_global_set_log_function (client_log_func);
97       gnutls_global_set_log_level (4711);
98     }
99
100   gnutls_anon_allocate_client_credentials (&anoncred);
101
102   /* Initialize TLS session
103    */
104   gnutls_init (&session, GNUTLS_CLIENT|GNUTLS_DATAGRAM);
105   gnutls_dtls_set_mtu( session, 1500);
106
107   /* Use default priorities */
108   gnutls_priority_set_direct (session, "NONE:+VERS-DTLS1.0:+CIPHER-ALL:+MAC-ALL:+SIGN-ALL:+COMP-ALL:+ANON-ECDH:+CURVE-ALL", NULL);
109
110   /* put the anonymous credentials to the current session
111    */
112   gnutls_credentials_set (session, GNUTLS_CRD_ANON, anoncred);
113
114   gnutls_transport_set_ptr (session, (gnutls_transport_ptr_t) fd);
115   gnutls_transport_set_push_function (session, push);
116
117   /* Perform the TLS handshake
118    */
119   do 
120     {
121       ret = gnutls_handshake (session);
122     }
123   while (ret < 0 && gnutls_error_is_fatal(ret) == 0);
124
125   if (ret < 0)
126     {
127       fail ("client: Handshake failed\n");
128       gnutls_perror (ret);
129       exit(1);
130     }
131   else
132     {
133       if (debug)
134         success ("client: Handshake was completed\n");
135     }
136
137   if (debug)
138     success ("client: TLS version is: %s\n",
139              gnutls_protocol_get_name (gnutls_protocol_get_version
140                                        (session)));
141
142   if (!server_init)
143     {
144       if (debug) success("Initiating client rehandshake\n");
145       do 
146         {
147           ret = gnutls_handshake (session);
148         }
149       while (ret < 0 && gnutls_error_is_fatal(ret) == 0);
150
151       if (ret < 0)
152         {
153           fail ("2nd client gnutls_handshake: %s\n", gnutls_strerror(ret));
154           terminate();
155         }
156     }
157   else
158     {
159       do {
160         ret = gnutls_record_recv (session, buffer, MAX_BUF);
161       } while (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED);
162     }
163
164   if (ret == 0)
165     {
166       if (debug)
167         success ("client: Peer has closed the TLS connection\n");
168       goto end;
169     }
170   else if (ret < 0)
171     {
172       if (server_init && ret == GNUTLS_E_REHANDSHAKE)
173         {
174           if (debug) success("Initiating rehandshake due to server request\n");
175           do 
176             {
177               ret = gnutls_handshake (session);
178             }
179           while (ret < 0 && gnutls_error_is_fatal(ret) == 0);
180         }
181
182       if (ret != 0)
183         {
184           fail ("client: Error: %s\n", gnutls_strerror (ret));
185           exit(1);
186         }
187     }
188
189   do {
190     ret = gnutls_record_send (session, MSG, strlen (MSG));
191   } while (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED);
192   gnutls_bye (session, GNUTLS_SHUT_WR);
193
194 end:
195
196   close (fd);
197
198   gnutls_deinit (session);
199
200   gnutls_anon_free_client_credentials (anoncred);
201
202   gnutls_global_deinit ();
203 }
204
205
206 /* These are global */
207 gnutls_anon_server_credentials_t anoncred;
208 pid_t child;
209
210 static gnutls_session_t
211 initialize_tls_session (void)
212 {
213   gnutls_session_t session;
214
215   gnutls_init (&session, GNUTLS_SERVER|GNUTLS_DATAGRAM);
216   gnutls_dtls_set_mtu( session, 1500);
217
218   /* avoid calling all the priority functions, since the defaults
219    * are adequate.
220    */
221   gnutls_priority_set_direct (session, "NONE:+VERS-DTLS1.0:+CIPHER-ALL:+MAC-ALL:+SIGN-ALL:+COMP-ALL:+ANON-ECDH:+CURVE-ALL", NULL);
222
223   gnutls_credentials_set (session, GNUTLS_CRD_ANON, anoncred);
224
225   return session;
226 }
227
228 static void terminate(void)
229 {
230 int status;
231
232   kill(child, SIGTERM);
233   wait(&status);
234   exit(1);
235 }
236
237 static void
238 server (int fd, int server_init)
239 {
240 int ret;
241 char buffer[MAX_BUF + 1];
242   /* this must be called once in the program
243    */
244   gnutls_global_init ();
245
246   if (debug)
247     {
248       gnutls_global_set_log_function (server_log_func);
249       gnutls_global_set_log_level (4711);
250     }
251
252   gnutls_anon_allocate_server_credentials (&anoncred);
253
254   session = initialize_tls_session ();
255
256   gnutls_transport_set_ptr (session, (gnutls_transport_ptr_t) fd);
257   gnutls_transport_set_push_function (session, push);
258
259   do 
260     {
261       ret = gnutls_handshake (session);
262     }
263   while (ret < 0 && gnutls_error_is_fatal(ret) == 0);
264   if (ret < 0)
265     {
266       close (fd);
267       gnutls_deinit (session);
268       fail ("server: Handshake has failed (%s)\n\n", gnutls_strerror (ret));
269       terminate();
270     }
271   if (debug)
272     success ("server: Handshake was completed\n");
273
274   if (debug)
275     success ("server: TLS version is: %s\n",
276              gnutls_protocol_get_name (gnutls_protocol_get_version
277                                        (session)));
278
279   /* see the Getting peer's information example */
280   /* print_info(session); */
281
282   if (server_init)
283     {
284       if (debug) success("server: Sending dummy packet\n");
285       ret = gnutls_rehandshake(session);
286       if (ret < 0)
287         {
288           fail ("gnutls_rehandshake: %s\n", gnutls_strerror(ret));
289           terminate();
290         }
291
292       if (debug) success("server: Initiating rehandshake\n");
293       do 
294         {
295           ret = gnutls_handshake (session);
296         }
297       while (ret < 0 && gnutls_error_is_fatal(ret) == 0);
298
299       if (ret < 0)
300         {
301           fail ("server: 2nd gnutls_handshake: %s\n", gnutls_strerror(ret));
302           terminate();
303         }
304     }
305
306   for (;;)
307     {
308       memset (buffer, 0, MAX_BUF + 1);
309
310       do {
311         ret = gnutls_record_recv (session, buffer, MAX_BUF);
312       } while (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED);
313
314       if (ret == 0)
315         {
316           if (debug)
317             success ("server: Peer has closed the GnuTLS connection\n");
318           break;
319         }
320       else if (ret < 0)
321         {
322           if (!server_init && ret == GNUTLS_E_REHANDSHAKE)
323             {
324               if (debug) success("Initiating rehandshake due to client request\n");
325               do 
326                 {
327                   ret = gnutls_handshake (session);
328                 }
329               while (ret < 0 && gnutls_error_is_fatal(ret) == 0);
330               if (ret == 0) break;
331             }
332
333           fail ("server: Received corrupted data(%s). Closing...\n", gnutls_strerror(ret));
334           terminate();
335         }
336       else if (ret > 0)
337         {
338           /* echo data back to the client
339            */
340           do {
341             ret = gnutls_record_send (session, buffer, strlen (buffer));
342           } while (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED);
343         }
344     }
345   
346
347   /* do not wait for the peer to close the connection.
348    */
349   gnutls_bye (session, GNUTLS_SHUT_WR);
350
351   close (fd);
352   gnutls_deinit (session);
353
354   gnutls_anon_free_server_credentials (anoncred);
355
356   gnutls_global_deinit ();
357
358   if (debug)
359     success ("server: finished\n");
360 }
361
362 static void start (int server_initiated)
363 {
364   int fd[2];
365   int ret;
366   
367   ret = socketpair(AF_UNIX, SOCK_DGRAM, 0, fd);
368   if (ret < 0)
369     {
370       perror("socketpair");
371       exit(1);
372     }
373
374   child = fork ();
375   if (child < 0)
376     {
377       perror ("fork");
378       fail ("fork");
379       exit(1);
380     }
381
382   if (child)
383     {
384       int status;
385       /* parent */
386
387       server (fd[0], server_initiated);
388       wait (&status);
389       if (WEXITSTATUS(status) != 0)
390         fail("Child died with status %d\n", WEXITSTATUS(status));
391     }
392   else 
393     {
394       close(fd[0]);
395       client (fd[1], server_initiated);
396       exit(0);
397     }
398 }
399
400 void
401 doit (void)
402 {
403   start(0);
404   start(1);
405 }
406
407 #endif /* _WIN32 */