Imported Upstream version 3.3.5
[platform/upstream/gnutls.git] / tests / mini-dtls-heartbeat.c
1 /*
2  * Copyright (C) 2012 Free Software Foundation, Inc.
3  *
4  * Author: Nikos Mavrogiannopoulos
5  *
6  * This file is part of GnuTLS.
7  *
8  * GnuTLS is free software; you can redistribute it and/or modify it
9  * under the terms of the GNU General Public License as published by
10  * the Free Software Foundation; either version 3 of the License, or
11  * (at your option) any later version.
12  *
13  * GnuTLS is distributed in the hope that it will be useful, but
14  * WITHOUT ANY WARRANTY; without even the implied warranty of
15  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
16  * General Public License for more details.
17  *
18  * You should have received a copy of the GNU General Public License
19  * along with GnuTLS; if not, write to the Free Software Foundation,
20  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA
21  */
22
23 #ifdef HAVE_CONFIG_H
24 #include <config.h>
25 #endif
26
27 #include <stdio.h>
28 #include <stdlib.h>
29
30 #if defined(_WIN32) || !defined(ENABLE_HEARTBEAT)
31
/* Heartbeat support is not available in this build (Windows, or
 * ENABLE_HEARTBEAT undefined): report the test as skipped.  77 is the
 * Automake test-driver exit code for "skipped". */
int main(void)
{
        return 77;
}
36
37 #else
38
#include <string.h>
#include <signal.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/wait.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <gnutls/gnutls.h>
#include <gnutls/dtls.h>
48
49 #include "utils.h"
50
/* Defined further down; client() and server() call it on fatal errors. */
static void terminate(void);

/* This program tests the DTLS heartbeat extension: one peer sends a
 * heartbeat ping and the other answers it with a pong.
 */

/* Write one line of GnuTLS debug output to stderr, tagged with the peer
 * role (e.g. "server" or "client") and the log level. */
static void log_with_role(const char *role, int level, const char *str)
{
        fprintf(stderr, "%s|<%d>| %s", role, level, str);
}

/* GnuTLS log callback installed by the server in debug mode. */
static void server_log_func(int level, const char *str)
{
        log_with_role("server", level, str);
}

/* GnuTLS log callback installed by the client in debug mode. */
static void client_log_func(int level, const char *str)
{
        log_with_role("client", level, str);
}
65
/* Return value of fork() in start(): the client's PID in the parent
 * process and 0 inside the client process itself.  terminate() reads it. */
static pid_t child;

/* Size of the application-data receive buffer used by both peers
 * (the buffers hold MAX_BUF + 1 bytes). */
enum { MAX_BUF = 1024 };

/* A very basic DTLS client, with anonymous authentication, that exchanges
 * heartbeats.
 */
73
74
75 static void client(int fd, int server_init)
76 {
77         gnutls_session_t session;
78         int ret, ret2;
79         char buffer[MAX_BUF + 1];
80         gnutls_anon_client_credentials_t anoncred;
81         /* Need to enable anonymous KX specifically. */
82
83         global_init();
84
85         if (debug) {
86                 gnutls_global_set_log_function(client_log_func);
87                 gnutls_global_set_log_level(4711);
88         }
89
90         gnutls_anon_allocate_client_credentials(&anoncred);
91
92         /* Initialize TLS session
93          */
94         gnutls_init(&session, GNUTLS_CLIENT | GNUTLS_DATAGRAM);
95         gnutls_heartbeat_enable(session, GNUTLS_HB_PEER_ALLOWED_TO_SEND);
96         gnutls_dtls_set_mtu(session, 1500);
97
98         /* Use default priorities */
99         gnutls_priority_set_direct(session,
100                                    "NONE:+VERS-DTLS1.0:+CIPHER-ALL:+MAC-ALL:+SIGN-ALL:+COMP-ALL:+ANON-ECDH:+CURVE-ALL",
101                                    NULL);
102
103         /* put the anonymous credentials to the current session
104          */
105         gnutls_credentials_set(session, GNUTLS_CRD_ANON, anoncred);
106
107         gnutls_transport_set_int(session, fd);
108
109         /* Perform the TLS handshake
110          */
111         do {
112                 ret = gnutls_handshake(session);
113         }
114         while (ret < 0 && gnutls_error_is_fatal(ret) == 0);
115
116         if (ret < 0) {
117                 fail("client: Handshake failed\n");
118                 gnutls_perror(ret);
119                 exit(1);
120         } else {
121                 if (debug)
122                         success("client: Handshake was completed\n");
123         }
124
125         if (debug)
126                 success("client: DTLS version is: %s\n",
127                         gnutls_protocol_get_name
128                         (gnutls_protocol_get_version(session)));
129
130         if (!server_init) {
131                 do {
132                         ret =
133                             gnutls_record_recv(session, buffer,
134                                                sizeof(buffer));
135
136                         if (ret == GNUTLS_E_HEARTBEAT_PING_RECEIVED) {
137                                 if (debug)
138                                         success
139                                             ("Ping received. Replying with pong.\n");
140                                 ret2 = gnutls_heartbeat_pong(session, 0);
141                                 if (ret2 < 0) {
142                                         fail("pong: %s\n",
143                                              gnutls_strerror(ret));
144                                         terminate();
145                                 }
146                         }
147                 }
148                 while (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED
149                        || ret == GNUTLS_E_HEARTBEAT_PING_RECEIVED);
150
151                 if (ret < 0) {
152                         fail("recv: %s\n", gnutls_strerror(ret));
153                         terminate();
154                 }
155         } else {
156                 do {
157                         ret =
158                             gnutls_heartbeat_ping(session, 256, 5,
159                                                   GNUTLS_HEARTBEAT_WAIT);
160
161                         if (debug)
162                                 success("Ping sent.\n");
163                 }
164                 while (ret == GNUTLS_E_AGAIN
165                        || ret == GNUTLS_E_INTERRUPTED);
166
167                 if (ret < 0) {
168                         fail("ping: %s\n", gnutls_strerror(ret));
169                         terminate();
170                 }
171         }
172
173         gnutls_bye(session, GNUTLS_SHUT_WR);
174
175         close(fd);
176
177         gnutls_deinit(session);
178
179         gnutls_anon_free_client_credentials(anoncred);
180
181         gnutls_global_deinit();
182 }
183
184
185
186 static gnutls_session_t initialize_tls_session(void)
187 {
188         gnutls_session_t session;
189
190         gnutls_init(&session, GNUTLS_SERVER | GNUTLS_DATAGRAM);
191         gnutls_heartbeat_enable(session, GNUTLS_HB_PEER_ALLOWED_TO_SEND);
192         gnutls_dtls_set_mtu(session, 1500);
193
194         /* avoid calling all the priority functions, since the defaults
195          * are adequate.
196          */
197         gnutls_priority_set_direct(session,
198                                    "NONE:+VERS-DTLS1.0:+CIPHER-ALL:+MAC-ALL:+SIGN-ALL:+COMP-ALL:+ANON-ECDH:+CURVE-ALL",
199                                    NULL);
200
201         return session;
202 }
203
204 static void terminate(void)
205 {
206         int status;
207
208         kill(child, SIGTERM);
209         wait(&status);
210         exit(1);
211 }
212
213 static void server(int fd, int server_init)
214 {
215         int ret, ret2;
216         char buffer[MAX_BUF + 1];
217         gnutls_session_t session;
218         gnutls_anon_server_credentials_t anoncred;
219         /* this must be called once in the program
220          */
221         global_init();
222
223         if (debug) {
224                 gnutls_global_set_log_function(server_log_func);
225                 gnutls_global_set_log_level(4711);
226         }
227
228         gnutls_anon_allocate_server_credentials(&anoncred);
229
230         session = initialize_tls_session();
231         gnutls_credentials_set(session, GNUTLS_CRD_ANON, anoncred);
232
233         gnutls_transport_set_int(session, fd);
234
235         do {
236                 ret = gnutls_handshake(session);
237         }
238         while (ret < 0 && gnutls_error_is_fatal(ret) == 0);
239         if (ret < 0) {
240                 close(fd);
241                 gnutls_deinit(session);
242                 fail("server: Handshake has failed (%s)\n\n",
243                      gnutls_strerror(ret));
244                 terminate();
245         }
246         if (debug)
247                 success("server: Handshake was completed\n");
248
249         if (debug)
250                 success("server: TLS version is: %s\n",
251                         gnutls_protocol_get_name
252                         (gnutls_protocol_get_version(session)));
253
254         /* see the Getting peer's information example */
255         /* print_info(session); */
256
257         if (server_init) {
258                 do {
259                         ret =
260                             gnutls_record_recv(session, buffer,
261                                                sizeof(buffer));
262
263                         if (ret == GNUTLS_E_HEARTBEAT_PING_RECEIVED) {
264                                 if (debug)
265                                         success
266                                             ("Ping received. Replying with pong.\n");
267                                 ret2 = gnutls_heartbeat_pong(session, 0);
268                                 if (ret2 < 0) {
269                                         fail("pong: %s\n",
270                                              gnutls_strerror(ret));
271                                         terminate();
272                                 }
273                         }
274                 }
275                 while (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED
276                        || ret == GNUTLS_E_HEARTBEAT_PING_RECEIVED);
277         } else {
278                 do {
279                         ret =
280                             gnutls_heartbeat_ping(session, 256, 5,
281                                                   GNUTLS_HEARTBEAT_WAIT);
282
283                         if (debug)
284                                 success("Ping sent.\n");
285                 }
286                 while (ret == GNUTLS_E_AGAIN
287                        || ret == GNUTLS_E_INTERRUPTED);
288
289                 if (ret < 0) {
290                         fail("ping: %s\n", gnutls_strerror(ret));
291                         terminate();
292                 }
293         }
294
295         /* do not wait for the peer to close the connection.
296          */
297         gnutls_bye(session, GNUTLS_SHUT_WR);
298
299         close(fd);
300         gnutls_deinit(session);
301
302         gnutls_anon_free_server_credentials(anoncred);
303
304         gnutls_global_deinit();
305
306         if (debug)
307                 success("server: finished\n");
308 }
309
310 static void start(int server_initiated)
311 {
312         int fd[2];
313         int ret;
314
315         ret = socketpair(AF_UNIX, SOCK_STREAM, 0, fd);
316         if (ret < 0) {
317                 perror("socketpair");
318                 exit(1);
319         }
320
321         child = fork();
322         if (child < 0) {
323                 perror("fork");
324                 fail("fork");
325                 exit(1);
326         }
327
328         if (child) {
329                 int status;
330                 /* parent */
331
332                 server(fd[0], server_initiated);
333                 wait(&status);
334                 if (WEXITSTATUS(status) != 0)
335                         fail("Child died with status %d\n",
336                              WEXITSTATUS(status));
337         } else {
338                 close(fd[0]);
339                 client(fd[1], server_initiated);
340                 exit(0);
341         }
342 }
343
/* Test entry point.  Runs the exchange in both directions:
 * mode 0 - the server sends the heartbeat ping and the client answers;
 * mode 1 - the client sends the ping and the server answers.
 */
void doit(void)
{
        int mode;

        for (mode = 0; mode <= 1; mode++)
                start(mode);
}
349
350 #endif                          /* _WIN32 */