Merge remote-tracking branch 'gvdb/master'
[platform/upstream/glib.git] / gio / tests / socket-client.c
1 #include <gio/gio.h>
2 #include <gio/gunixsocketaddress.h>
3 #include <glib.h>
4 #include <stdlib.h>
5 #include <stdio.h>
6 #include <string.h>
7
8 #include "gtlsconsoleinteraction.h"
9
10 GMainLoop *loop;
11
12 gboolean verbose = FALSE;
13 gboolean non_blocking = FALSE;
14 gboolean use_udp = FALSE;
15 int cancel_timeout = 0;
16 int read_timeout = 0;
17 gboolean unix_socket = FALSE;
18 gboolean tls = FALSE;
19
20 static GOptionEntry cmd_entries[] = {
21   {"cancel", 'c', 0, G_OPTION_ARG_INT, &cancel_timeout,
22    "Cancel any op after the specified amount of seconds", NULL},
23   {"udp", 'u', 0, G_OPTION_ARG_NONE, &use_udp,
24    "Use udp instead of tcp", NULL},
25   {"verbose", 'v', 0, G_OPTION_ARG_NONE, &verbose,
26    "Be verbose", NULL},
27   {"non-blocking", 'n', 0, G_OPTION_ARG_NONE, &non_blocking,
28    "Enable non-blocking i/o", NULL},
29 #ifdef G_OS_UNIX
30   {"unix", 'U', 0, G_OPTION_ARG_NONE, &unix_socket,
31    "Use a unix socket instead of IP", NULL},
32 #endif
33   {"timeout", 't', 0, G_OPTION_ARG_INT, &read_timeout,
34    "Time out reads after the specified number of seconds", NULL},
35   {"tls", 'T', 0, G_OPTION_ARG_NONE, &tls,
36    "Use TLS (SSL)", NULL},
37   {NULL}
38 };
39
40 #include "socket-common.c"
41
42 static gboolean
43 accept_certificate (GTlsClientConnection *conn, GTlsCertificate *cert,
44                     GTlsCertificateFlags errors, gpointer user_data)
45 {
46   g_print ("Certificate would have been rejected ( ");
47   if (errors & G_TLS_CERTIFICATE_UNKNOWN_CA)
48     g_print ("unknown-ca ");
49   if (errors & G_TLS_CERTIFICATE_BAD_IDENTITY)
50     g_print ("bad-identity ");
51   if (errors & G_TLS_CERTIFICATE_NOT_ACTIVATED)
52     g_print ("not-activated ");
53   if (errors & G_TLS_CERTIFICATE_EXPIRED)
54     g_print ("expired ");
55   if (errors & G_TLS_CERTIFICATE_REVOKED)
56     g_print ("revoked ");
57   if (errors & G_TLS_CERTIFICATE_INSECURE)
58     g_print ("insecure ");
59   g_print (") but accepting anyway.\n");
60
61   return TRUE;
62 }
63
64 static GTlsCertificate *
65 lookup_client_certificate (GTlsClientConnection *conn, GError **error)
66 {
67   GList *l, *accepted;
68   GList *c, *certificates;
69   GTlsDatabase *database;
70   GTlsCertificate *certificate = NULL;
71   GTlsConnection *base;
72
73   accepted = g_tls_client_connection_get_accepted_cas (conn);
74   for (l = accepted; l != NULL; l = g_list_next (l))
75     {
76       base = G_TLS_CONNECTION (conn);
77       database = g_tls_connection_get_database (base);
78       certificates = g_tls_database_lookup_certificates_issued_by (database, l->data,
79                                                                    g_tls_connection_get_interaction (base),
80                                                                    G_TLS_DATABASE_LOOKUP_KEYPAIR,
81                                                                    NULL, error);
82       if (error && *error)
83         break;
84
85       if (certificates)
86           certificate = g_object_ref (certificates->data);
87
88       for (c = certificates; c != NULL; c = g_list_next (c))
89         g_object_unref (c->data);
90       g_list_free (certificates);
91     }
92
93   for (l = accepted; l != NULL; l = g_list_next (l))
94     g_byte_array_unref (l->data);
95   g_list_free (accepted);
96
97   if (certificate == NULL && error && !*error)
98     g_set_error_literal (error, G_TLS_ERROR, G_TLS_ERROR_CERTIFICATE_REQUIRED,
99                          "Server requested a certificate, but could not find relevant certificate in database.");
100   return certificate;
101 }
102
103 static gboolean
104 make_connection (const char *argument, GTlsCertificate *certificate, GCancellable *cancellable,
105                  GSocket **socket, GSocketAddress **address, GIOStream **connection,
106                  GInputStream **istream, GOutputStream **ostream, GError **error)
107 {
108   GSocketType socket_type;
109   GSocketFamily socket_family;
110   GSocketAddressEnumerator *enumerator;
111   GSocketConnectable *connectable;
112   GSocketAddress *src_address;
113   GTlsInteraction *interaction;
114   GError *err = NULL;
115
116   if (use_udp)
117     socket_type = G_SOCKET_TYPE_DATAGRAM;
118   else
119     socket_type = G_SOCKET_TYPE_STREAM;
120
121   if (unix_socket)
122     socket_family = G_SOCKET_FAMILY_UNIX;
123   else
124     socket_family = G_SOCKET_FAMILY_IPV4;
125
126   *socket = g_socket_new (socket_family, socket_type, 0, error);
127   if (*socket == NULL)
128     return FALSE;
129
130   if (read_timeout)
131     g_socket_set_timeout (*socket, read_timeout);
132
133   if (unix_socket)
134     {
135       GSocketAddress *addr;
136
137       addr = socket_address_from_string (argument);
138       if (addr == NULL)
139         {
140           g_set_error (error, G_IO_ERROR, G_IO_ERROR_FAILED,
141                        "Could not parse '%s' as unix socket name", argument);
142           return FALSE;
143         }
144       connectable = G_SOCKET_CONNECTABLE (addr);
145     }
146   else
147     {
148       connectable = g_network_address_parse (argument, 7777, error);
149       if (connectable == NULL)
150         return FALSE;
151     }
152
153   enumerator = g_socket_connectable_enumerate (connectable);
154   while (TRUE)
155     {
156       *address = g_socket_address_enumerator_next (enumerator, cancellable, error);
157       if (*address == NULL)
158         {
159           if (error == NULL)
160             g_set_error_literal (error, G_IO_ERROR, G_IO_ERROR_FAILED,
161                                  "No more addresses to try");
162           return FALSE;
163         }
164
165       if (g_socket_connect (*socket, *address, cancellable, &err))
166         break;
167       g_message ("Connection to %s failed: %s, trying next\n", socket_address_to_string (*address), err->message);
168       g_clear_error (&err);
169
170       g_object_unref (*address);
171     }
172   g_object_unref (enumerator);
173
174   g_print ("Connected to %s\n",
175            socket_address_to_string (*address));
176
177   src_address = g_socket_get_local_address (*socket, error);
178   if (!src_address)
179     {
180       g_prefix_error (error, "Error getting local address: ");
181       return FALSE;
182     }
183
184   g_print ("local address: %s\n",
185            socket_address_to_string (src_address));
186   g_object_unref (src_address);
187
188   if (use_udp)
189     {
190       *connection = NULL;
191       *istream = NULL;
192       *ostream = NULL;
193     }
194   else
195     *connection = G_IO_STREAM (g_socket_connection_factory_create_connection (*socket));
196
197   if (tls)
198     {
199       GIOStream *tls_conn;
200
201       tls_conn = g_tls_client_connection_new (*connection, connectable, error);
202       if (!tls_conn)
203         {
204           g_prefix_error (error, "Could not create TLS connection: ");
205           return FALSE;
206         }
207
208       g_signal_connect (tls_conn, "accept-certificate",
209                         G_CALLBACK (accept_certificate), NULL);
210
211       interaction = g_tls_console_interaction_new ();
212       g_tls_connection_set_interaction (G_TLS_CONNECTION (tls_conn), interaction);
213       g_object_unref (interaction);
214
215       if (certificate)
216         g_tls_connection_set_certificate (G_TLS_CONNECTION (tls_conn), certificate);
217
218       g_object_unref (*connection);
219       *connection = G_IO_STREAM (tls_conn);
220
221       if (!g_tls_connection_handshake (G_TLS_CONNECTION (tls_conn),
222                                        cancellable, error))
223         {
224           g_prefix_error (error, "Error during TLS handshake: ");
225           return FALSE;
226         }
227     }
228   g_object_unref (connectable);
229
230   if (*connection)
231     {
232       *istream = g_io_stream_get_input_stream (*connection);
233       *ostream = g_io_stream_get_output_stream (*connection);
234     }
235
236   return TRUE;
237 }
238
239 int
240 main (int argc,
241       char *argv[])
242 {
243   GSocket *socket;
244   GSocketAddress *address;
245   GError *error = NULL;
246   GOptionContext *context;
247   GCancellable *cancellable;
248   GIOStream *connection;
249   GInputStream *istream;
250   GOutputStream *ostream;
251   GSocketAddress *src_address;
252   GTlsCertificate *certificate = NULL;
253   gint i;
254
255   g_thread_init (NULL);
256
257   g_type_init ();
258
259   address = NULL;
260   connection = NULL;
261
262   context = g_option_context_new (" <hostname>[:port] - Test GSocket client stuff");
263   g_option_context_add_main_entries (context, cmd_entries, NULL);
264   if (!g_option_context_parse (context, &argc, &argv, &error))
265     {
266       g_printerr ("%s: %s\n", argv[0], error->message);
267       return 1;
268     }
269
270   if (argc != 2)
271     {
272       g_printerr ("%s: %s\n", argv[0], "Need to specify hostname / unix socket name");
273       return 1;
274     }
275
276   if (use_udp && tls)
277     {
278       g_printerr ("DTLS (TLS over UDP) is not supported");
279       return 1;
280     }
281
282   if (cancel_timeout)
283     {
284       cancellable = g_cancellable_new ();
285       g_thread_create (cancel_thread, cancellable, FALSE, NULL);
286     }
287   else
288     {
289       cancellable = NULL;
290     }
291
292   loop = g_main_loop_new (NULL, FALSE);
293
294   for (i = 0; i < 2; i++)
295     {
296       if (make_connection (argv[1], certificate, cancellable, &socket, &address,
297                            &connection, &istream, &ostream, &error))
298           break;
299
300       if (g_error_matches (error, G_TLS_ERROR, G_TLS_ERROR_CERTIFICATE_REQUIRED))
301         {
302           g_clear_error (&error);
303           certificate = lookup_client_certificate (G_TLS_CLIENT_CONNECTION (connection), &error);
304           if (certificate != NULL)
305             continue;
306         }
307
308       g_printerr ("%s: %s", argv[0], error->message);
309       return 1;
310     }
311
312   /* TODO: Test non-blocking connect/handshake */
313   if (non_blocking)
314     g_socket_set_blocking (socket, FALSE);
315
316   while (TRUE)
317     {
318       gchar buffer[4096];
319       gssize size;
320       gsize to_send;
321
322       if (fgets (buffer, sizeof buffer, stdin) == NULL)
323         break;
324
325       to_send = strlen (buffer);
326       while (to_send > 0)
327         {
328           if (use_udp)
329             {
330               ensure_socket_condition (socket, G_IO_OUT, cancellable);
331               size = g_socket_send_to (socket, address,
332                                        buffer, to_send,
333                                        cancellable, &error);
334             }
335           else
336             {
337               ensure_connection_condition (connection, G_IO_OUT, cancellable);
338               size = g_output_stream_write (ostream,
339                                             buffer, to_send,
340                                             cancellable, &error);
341             }
342
343           if (size < 0)
344             {
345               if (g_error_matches (error,
346                                    G_IO_ERROR,
347                                    G_IO_ERROR_WOULD_BLOCK))
348                 {
349                   g_print ("socket send would block, handling\n");
350                   g_error_free (error);
351                   error = NULL;
352                   continue;
353                 }
354               else
355                 {
356                   g_printerr ("Error sending to socket: %s\n",
357                               error->message);
358                   return 1;
359                 }
360             }
361
362           g_print ("sent %" G_GSSIZE_FORMAT " bytes of data\n", size);
363
364           if (size == 0)
365             {
366               g_printerr ("Unexpected short write\n");
367               return 1;
368             }
369
370           to_send -= size;
371         }
372
373       if (use_udp)
374         {
375           ensure_socket_condition (socket, G_IO_IN, cancellable);
376           size = g_socket_receive_from (socket, &src_address,
377                                         buffer, sizeof buffer,
378                                         cancellable, &error);
379         }
380       else
381         {
382           ensure_connection_condition (connection, G_IO_IN, cancellable);
383           size = g_input_stream_read (istream,
384                                       buffer, sizeof buffer,
385                                       cancellable, &error);
386         }
387
388       if (size < 0)
389         {
390           g_printerr ("Error receiving from socket: %s\n",
391                       error->message);
392           return 1;
393         }
394
395       if (size == 0)
396         break;
397
398       g_print ("received %" G_GSSIZE_FORMAT " bytes of data", size);
399       if (use_udp)
400         g_print (" from %s", socket_address_to_string (src_address));
401       g_print ("\n");
402
403       if (verbose)
404         g_print ("-------------------------\n"
405                  "%.*s"
406                  "-------------------------\n",
407                  (int)size, buffer);
408
409     }
410
411   g_print ("closing socket\n");
412
413   if (connection)
414     {
415       if (!g_io_stream_close (connection, cancellable, &error))
416         {
417           g_printerr ("Error closing connection: %s\n",
418                       error->message);
419           return 1;
420         }
421       g_object_unref (connection);
422     }
423   else
424     {
425       if (!g_socket_close (socket, &error))
426         {
427           g_printerr ("Error closing master socket: %s\n",
428                       error->message);
429           return 1;
430         }
431     }
432
433   g_object_unref (socket);
434   g_object_unref (address);
435
436   return 0;
437 }