g_thread_init() must be called before any other GLib function.
[platform/upstream/libsoup.git] / tests / ssl-test.c
1 #include <gnutls/gnutls.h>
2 #include <glib.h>
3 #include <stdio.h>
4 #include <stdlib.h>
5 #include <string.h>
6 #include <unistd.h>
7 #include <netinet/in.h>
8 #include <sys/socket.h>
9
10 #include "libsoup/soup-address.h"
11 #include "libsoup/soup-socket.h"
12 #include "libsoup/soup-ssl.h"
13
/* BUFSIZE: number of bytes each test phase sends to the server and reads
 * back. DH_BITS: prime size the server thread passes to
 * gnutls_dh_set_prime_bits(). */
enum {
	BUFSIZE = 1024,
	DH_BITS = 1024
};
16
17 GMainLoop *loop;
18 gnutls_dh_params_t dh_params;
19
20 /* SERVER */
21
22 /* Read @bufsize bytes into @buf from @session. */
23 static void
24 server_read (gnutls_session_t session, char *buf, int bufsize)
25 {
26         int total, nread;
27
28         total = 0;
29         while (total < bufsize) {
30                 nread = gnutls_record_recv (session, buf + total,
31                                             bufsize - total);
32                 if (nread <= 0)
33                         g_error ("server read failed at position %d", total);
34                 total += nread;
35         }
36 }
37
38 /* Write @bufsize bytes from @buf to @session, forcing 3 rehandshakes
39  * along the way. (We do an odd number of rehandshakes to make sure
40  * they occur at weird times relative to the client's read buffer
41  * size.)
42  */
43 static void
44 server_write (gnutls_session_t session, char *buf, int bufsize)
45 {
46         int total, nwrote;
47         int next_rehandshake = bufsize / 3;
48
49         total = 0;
50         while (total < bufsize) {
51                 if (total >= next_rehandshake) {
52                         if (gnutls_rehandshake (session) < 0)
53                                 g_error ("client refused rehandshake at position %d", total);
54                         if (gnutls_handshake (session) < 0)
55                                 g_error ("server rehandshake failed at position %d", total);
56                         next_rehandshake = MIN (bufsize, next_rehandshake + bufsize / 3);
57                 }
58
59                 nwrote = gnutls_record_send (session, buf + total,
60                                              next_rehandshake - total);
61                 if (nwrote <= 0)
62                         g_error ("server write failed at position %d: %d", total, nwrote);
63                 total += nwrote;
64         }
65 }
66
67 const char *ssl_cert_file = SRCDIR "/test-cert.pem";
68 const char *ssl_key_file = SRCDIR "/test-key.pem";
69
70 static gpointer
71 server_thread (gpointer user_data)
72 {
73         int listener = GPOINTER_TO_INT (user_data), client;
74         gnutls_certificate_credentials creds;
75         gnutls_session_t session;
76         struct sockaddr_in sin;
77         int len;
78         char buf[BUFSIZE];
79         int status;
80
81         gnutls_certificate_allocate_credentials (&creds);
82         if (gnutls_certificate_set_x509_key_file (creds,
83                                                   ssl_cert_file, ssl_key_file,
84                                                   GNUTLS_X509_FMT_PEM) != 0) {
85                 g_error ("Failed to set SSL certificate and key files "
86                          "(%s, %s).", ssl_cert_file, ssl_key_file);
87         }
88         gnutls_certificate_set_dh_params (creds, dh_params);
89
90         /* Create a new session */
91         gnutls_init (&session, GNUTLS_SERVER);
92         gnutls_set_default_priority (session);
93         gnutls_credentials_set (session, GNUTLS_CRD_CERTIFICATE, creds);
94         gnutls_dh_set_prime_bits (session, DH_BITS);
95
96         /* Wait for client thread to connect */
97         len = sizeof (sin);
98         client = accept (listener, (struct sockaddr *) &sin, (void *)&len);
99         gnutls_transport_set_ptr (session, GINT_TO_POINTER (client));
100
101         /* Initial handshake */
102         status = gnutls_handshake (session);
103         if (status < 0)
104                 g_error ("initial handshake failed: %d", status);
105
106         /* Synchronous client test. */
107         server_read (session, buf, BUFSIZE);
108         server_write (session, buf, BUFSIZE);
109
110         /* Async client test. */
111         server_read (session, buf, BUFSIZE);
112         server_write (session, buf, BUFSIZE);
113
114         /* That's all, folks. */
115         gnutls_bye (session, GNUTLS_SHUT_WR);
116         gnutls_deinit (session);
117         close (client);
118         gnutls_certificate_free_credentials (creds);
119
120         return NULL;
121 }
122
123 /* async client code */
124
125 typedef struct {
126         char writebuf[BUFSIZE], readbuf[BUFSIZE];
127         int total;
128 } AsyncData;
129
130 static void
131 async_read (SoupSocket *sock, gpointer user_data)
132 {
133         AsyncData *data = user_data;
134         SoupSocketIOStatus status;
135         gsize n;
136         GError *error = NULL;
137
138         do {
139                 status = soup_socket_read (sock, data->readbuf + data->total,
140                                            BUFSIZE - data->total, &n,
141                                            NULL, &error);
142                 if (status == SOUP_SOCKET_OK)
143                         data->total += n;
144         } while (status == SOUP_SOCKET_OK && data->total < BUFSIZE);
145
146         if (status == SOUP_SOCKET_ERROR || status == SOUP_SOCKET_EOF) {
147                 g_error ("Async read got status %d: %s", status,
148                          error ? error->message : "(unknown)");
149         } else if (status == SOUP_SOCKET_WOULD_BLOCK)
150                 return;
151
152         if (memcmp (data->writebuf, data->readbuf, BUFSIZE) != 0)
153                 g_error ("Sync read didn't match write");
154
155         g_free (data);
156         g_main_loop_quit (loop);
157 }
158
159 static void
160 async_write (SoupSocket *sock, gpointer user_data)
161 {
162         AsyncData *data = user_data;
163         SoupSocketIOStatus status;
164         gsize n;
165         GError *error = NULL;
166
167         do {
168                 status = soup_socket_write (sock, data->writebuf + data->total,
169                                             BUFSIZE - data->total, &n,
170                                             NULL, &error);
171                 if (status == SOUP_SOCKET_OK)
172                         data->total += n;
173         } while (status == SOUP_SOCKET_OK && data->total < BUFSIZE);
174
175         if (status == SOUP_SOCKET_ERROR || status == SOUP_SOCKET_EOF) {
176                 g_error ("Async write got status %d: %s", status,
177                          error ? error->message : "(unknown)");
178         } else if (status == SOUP_SOCKET_WOULD_BLOCK)
179                 return;
180
181         data->total = 0;
182         async_read (sock, user_data);
183 }
184
185 static gboolean
186 start_writing (gpointer user_data)
187 {
188         SoupSocket *sock = user_data;
189         AsyncData *data;
190         int i;
191
192         data = g_new (AsyncData, 1);
193         for (i = 0; i < BUFSIZE; i++)
194                 data->writebuf[i] = i & 0xFF;
195         data->total = 0;
196
197         g_signal_connect (sock, "writable",
198                           G_CALLBACK (async_write), data);
199         g_signal_connect (sock, "readable",
200                           G_CALLBACK (async_read), data);
201
202         async_write (sock, data);
203         return FALSE;
204 }
205
/* GnuTLS log level from the -d command-line option. The default of 0
 * leaves GnuTLS logging switched off. */
int debug;

/* GnuTLS log callback, installed by main() when debug is non-zero. Copies
 * each message to stderr unchanged; the level argument is not used. */
static void
debug_log (int level, const char *str)
{
	(void) level;
	fprintf (stderr, "%s", str);
}
213
214 int
215 main (int argc, char **argv)
216 {
217         int opt, listener, sin_len, port, i;
218         struct sockaddr_in sin;
219         GThread *server;
220         char writebuf[BUFSIZE], readbuf[BUFSIZE];
221         SoupAddress *addr;
222         SoupSSLCredentials *creds;
223         SoupSocket *sock;
224         gsize n, total;
225         SoupSocketIOStatus status;
226         GError *error = NULL;
227
228         g_thread_init (NULL);
229         g_type_init ();
230
231         while ((opt = getopt (argc, argv, "c:d:k:")) != -1) {
232                 switch (opt) {
233                 case 'c':
234                         ssl_cert_file = optarg;
235                         break;
236                 case 'd':
237                         debug = atoi (optarg);
238                         break;
239                 case 'k':
240                         ssl_key_file = optarg;
241                         break;
242
243                 case '?':
244                         fprintf (stderr, "Usage: %s [-d debuglevel] [-c ssl-cert-file] [-k ssl-key-file]\n",
245                                  argv[0]);
246                         break;
247                 }
248         }
249
250         if (debug) {
251                 gnutls_global_set_log_function (debug_log);
252                 gnutls_global_set_log_level (debug);
253         }
254
255         /* Create server socket */
256         listener = socket (AF_INET, SOCK_STREAM, 0);
257         if (listener == -1) {
258                 perror ("creating listening socket");
259                 exit (1);
260         }
261
262         memset (&sin, 0, sizeof (sin));
263         sin.sin_family = AF_INET;
264         sin.sin_addr.s_addr = INADDR_ANY;
265
266         if (bind (listener, (struct sockaddr *) &sin, sizeof (sin))  == -1) {
267                 perror ("binding listening socket");
268                 exit (1);
269         }
270
271         if (listen (listener, 1) == -1) {
272                 perror ("listening on socket");
273                 exit (1);
274         }
275
276         sin_len = sizeof (sin);
277         getsockname (listener, (struct sockaddr *)&sin, (void *)&sin_len);
278         port = ntohs (sin.sin_port);
279
280         /* Create the client */
281         addr = soup_address_new ("127.0.0.1", port);
282         creds = soup_ssl_get_client_credentials (NULL);
283         sock = soup_socket_new (SOUP_SOCKET_REMOTE_ADDRESS, addr,
284                                 SOUP_SOCKET_FLAG_NONBLOCKING, FALSE,
285                                 SOUP_SOCKET_SSL_CREDENTIALS, creds,
286                                 NULL);
287         g_object_unref (addr);
288         status = soup_socket_connect_sync (sock, NULL);
289         if (status != SOUP_STATUS_OK) {
290                 g_error ("Could not create client socket: %s",
291                          soup_status_get_phrase (status));
292         }
293
294         soup_socket_start_ssl (sock, NULL);
295
296         /* Now spawn server thread */
297         server = g_thread_create (server_thread, GINT_TO_POINTER (listener),
298                                   TRUE, NULL);
299
300         /* Synchronous client test */
301         for (i = 0; i < BUFSIZE; i++)
302                 writebuf[i] = i & 0xFF;
303
304         total = 0;
305         while (total < BUFSIZE) {
306                 status = soup_socket_write (sock, writebuf + total,
307                                             BUFSIZE - total, &n,
308                                             NULL, &error);
309                 if (status != SOUP_SOCKET_OK)
310                         g_error ("Sync write got status %d: %s", status,
311                                  error ? error->message : "(unknown)");
312                 total += n;
313         }
314
315         total = 0;
316         while (total < BUFSIZE) {
317                 status = soup_socket_read (sock, readbuf + total,
318                                            BUFSIZE - total, &n,
319                                            NULL, &error);
320                 if (status != SOUP_SOCKET_OK)
321                         g_error ("Sync read got status %d: %s", status,
322                                  error ? error->message : "(unknown)");
323                 total += n;
324         }
325
326         if (memcmp (writebuf, readbuf, BUFSIZE) != 0)
327                 g_error ("Sync read didn't match write");
328
329         printf ("SYNCHRONOUS SSL TEST PASSED\n");
330
331         /* Switch socket to async and do it again */
332
333         g_object_set (sock,
334                       SOUP_SOCKET_FLAG_NONBLOCKING, TRUE,
335                       NULL);
336
337         g_idle_add (start_writing, sock);
338         loop = g_main_loop_new (NULL, TRUE);
339         g_main_loop_run (loop);
340         g_main_loop_unref (loop);
341         g_main_context_unref (g_main_context_default ());
342
343         printf ("ASYNCHRONOUS SSL TEST PASSED\n");
344
345         g_object_unref (sock);
346         soup_ssl_free_client_credentials (creds);
347         g_thread_join (server);
348
349         /* Success */
350         return 0;
351 }