gio/tests/giomodule.c: Use G_MODULE_SUFFIX
[platform/upstream/glib.git] / gio / tests / socket-client.c
index 7b14de8..62dda89 100644 (file)
@@ -5,16 +5,17 @@
 #include <stdio.h>
 #include <string.h>
 
-#include "socket-common.c"
+#include "gtlsconsoleinteraction.h"
 
 GMainLoop *loop;
 
 gboolean verbose = FALSE;
 gboolean non_blocking = FALSE;
 gboolean use_udp = FALSE;
-gboolean use_source = FALSE;
 int cancel_timeout = 0;
+int read_timeout = 0;
 gboolean unix_socket = FALSE;
+gboolean tls = FALSE;
 
 static GOptionEntry cmd_entries[] = {
   {"cancel", 'c', 0, G_OPTION_ARG_INT, &cancel_timeout,
@@ -25,66 +26,223 @@ static GOptionEntry cmd_entries[] = {
    "Be verbose", NULL},
   {"non-blocking", 'n', 0, G_OPTION_ARG_NONE, &non_blocking,
    "Enable non-blocking i/o", NULL},
-  {"use-source", 's', 0, G_OPTION_ARG_NONE, &use_source,
-   "Use GSource to wait for non-blocking i/o", NULL},
+#ifdef G_OS_UNIX
   {"unix", 'U', 0, G_OPTION_ARG_NONE, &unix_socket,
    "Use a unix socket instead of IP", NULL},
+#endif
+  {"timeout", 't', 0, G_OPTION_ARG_INT, &read_timeout,
+   "Time out reads after the specified number of seconds", NULL},
+  {"tls", 'T', 0, G_OPTION_ARG_NONE, &tls,
+   "Use TLS (SSL)", NULL},
   {NULL}
 };
 
+#include "socket-common.c"
+
 static gboolean
-source_ready (gpointer data,
-             GIOCondition condition)
+accept_certificate (GTlsClientConnection *conn,
+                   GTlsCertificate      *cert,
+                   GTlsCertificateFlags  errors,
+                   gpointer              user_data)
 {
-  g_main_loop_quit (loop);
-  return FALSE;
+  g_print ("Certificate would have been rejected ( ");
+  if (errors & G_TLS_CERTIFICATE_UNKNOWN_CA)
+    g_print ("unknown-ca ");
+  if (errors & G_TLS_CERTIFICATE_BAD_IDENTITY)
+    g_print ("bad-identity ");
+  if (errors & G_TLS_CERTIFICATE_NOT_ACTIVATED)
+    g_print ("not-activated ");
+  if (errors & G_TLS_CERTIFICATE_EXPIRED)
+    g_print ("expired ");
+  if (errors & G_TLS_CERTIFICATE_REVOKED)
+    g_print ("revoked ");
+  if (errors & G_TLS_CERTIFICATE_INSECURE)
+    g_print ("insecure ");
+  g_print (") but accepting anyway.\n");
+
+  return TRUE;
 }
 
-static void
-ensure_condition (GSocket *socket,
-                 const char *where,
-                 GCancellable *cancellable,
-                 GIOCondition condition)
+static GTlsCertificate *
+lookup_client_certificate (GTlsClientConnection  *conn,
+                          GError               **error)
 {
-  GError *error = NULL;
-  GSource *source;
+  GList *l, *accepted;
+  GList *c, *certificates;
+  GTlsDatabase *database;
+  GTlsCertificate *certificate = NULL;
+  GTlsConnection *base;
+
+  accepted = g_tls_client_connection_get_accepted_cas (conn);
+  for (l = accepted; l != NULL; l = g_list_next (l))
+    {
+      base = G_TLS_CONNECTION (conn);
+      database = g_tls_connection_get_database (base);
+      certificates = g_tls_database_lookup_certificates_issued_by (database, l->data,
+                                                                   g_tls_connection_get_interaction (base),
+                                                                   G_TLS_DATABASE_LOOKUP_KEYPAIR,
+                                                                   NULL, error);
+      if (error && *error)
+        break;
+
+      if (certificates)
+          certificate = g_object_ref (certificates->data);
+
+      for (c = certificates; c != NULL; c = g_list_next (c))
+        g_object_unref (c->data);
+      g_list_free (certificates);
+    }
+
+  for (l = accepted; l != NULL; l = g_list_next (l))
+    g_byte_array_unref (l->data);
+  g_list_free (accepted);
+
+  if (certificate == NULL && error && !*error)
+    g_set_error_literal (error, G_TLS_ERROR, G_TLS_ERROR_CERTIFICATE_REQUIRED,
+                         "Server requested a certificate, but could not find relevant certificate in database.");
+  return certificate;
+}
+
+static gboolean
+make_connection (const char       *argument,
+                GTlsCertificate  *certificate,
+                GCancellable     *cancellable,
+                GSocket         **socket,
+                GSocketAddress  **address,
+                GIOStream       **connection,
+                GInputStream    **istream,
+                GOutputStream   **ostream,
+                GError          **error)
+{
+  GSocketType socket_type;
+  GSocketFamily socket_family;
+  GSocketAddressEnumerator *enumerator;
+  GSocketConnectable *connectable;
+  GSocketAddress *src_address;
+  GTlsInteraction *interaction;
+  GError *err = NULL;
+
+  if (use_udp)
+    socket_type = G_SOCKET_TYPE_DATAGRAM;
+  else
+    socket_type = G_SOCKET_TYPE_STREAM;
+
+  if (unix_socket)
+    socket_family = G_SOCKET_FAMILY_UNIX;
+  else
+    socket_family = G_SOCKET_FAMILY_IPV4;
 
-  if (!non_blocking)
-    return;
+  *socket = g_socket_new (socket_family, socket_type, 0, error);
+  if (*socket == NULL)
+    return FALSE;
 
-  if (use_source)
+  if (read_timeout)
+    g_socket_set_timeout (*socket, read_timeout);
+
+  if (unix_socket)
     {
-      source = g_socket_create_source (socket,
-                                       condition,
-                                       cancellable);
-      g_source_set_callback (source,
-                             (GSourceFunc) source_ready,
-                            NULL, NULL);
-      g_source_attach (source, NULL);
-      g_source_unref (source);
-      g_main_loop_run (loop);
+      GSocketAddress *addr;
+
+      addr = socket_address_from_string (argument);
+      if (addr == NULL)
+        {
+          g_set_error (error, G_IO_ERROR, G_IO_ERROR_FAILED,
+                       "Could not parse '%s' as unix socket name", argument);
+          return FALSE;
+        }
+      connectable = G_SOCKET_CONNECTABLE (addr);
     }
   else
     {
-      if (!g_socket_condition_wait (socket, condition, cancellable, &error))
-       {
-         g_printerr ("condition wait error for %s: %s\n",
-                     where,
-                     error->message);
-         exit (1);
-       }
+      connectable = g_network_address_parse (argument, 7777, error);
+      if (connectable == NULL)
+        return FALSE;
     }
-}
 
-static gpointer
-cancel_thread (gpointer data)
-{
-  GCancellable *cancellable = data;
+  enumerator = g_socket_connectable_enumerate (connectable);
+  while (TRUE)
+    {
+      *address = g_socket_address_enumerator_next (enumerator, cancellable, error);
+      if (*address == NULL)
+        {
+          if (error != NULL && *error == NULL)
+            g_set_error_literal (error, G_IO_ERROR, G_IO_ERROR_FAILED,
+                                 "No more addresses to try");
+          return FALSE;
+        }
+
+      if (g_socket_connect (*socket, *address, cancellable, &err))
+        break;
+      g_message ("Connection to %s failed: %s, trying next\n", socket_address_to_string (*address), err->message);
+      g_clear_error (&err);
+
+      g_object_unref (*address);
+    }
+  g_object_unref (enumerator);
+
+  g_print ("Connected to %s\n",
+           socket_address_to_string (*address));
 
-  g_usleep (1000*1000*cancel_timeout);
-  g_print ("Cancelling\n");
-  g_cancellable_cancel (cancellable);
-  return NULL;
+  src_address = g_socket_get_local_address (*socket, error);
+  if (!src_address)
+    {
+      g_prefix_error (error, "Error getting local address: ");
+      return FALSE;
+    }
+
+  g_print ("local address: %s\n",
+           socket_address_to_string (src_address));
+  g_object_unref (src_address);
+
+  if (use_udp)
+    {
+      *connection = NULL;
+      *istream = NULL;
+      *ostream = NULL;
+    }
+  else
+    *connection = G_IO_STREAM (g_socket_connection_factory_create_connection (*socket));
+
+  if (tls)
+    {
+      GIOStream *tls_conn;
+
+      tls_conn = g_tls_client_connection_new (*connection, connectable, error);
+      if (!tls_conn)
+        {
+          g_prefix_error (error, "Could not create TLS connection: ");
+          return FALSE;
+        }
+
+      g_signal_connect (tls_conn, "accept-certificate",
+                        G_CALLBACK (accept_certificate), NULL);
+
+      interaction = g_tls_console_interaction_new ();
+      g_tls_connection_set_interaction (G_TLS_CONNECTION (tls_conn), interaction);
+      g_object_unref (interaction);
+
+      if (certificate)
+        g_tls_connection_set_certificate (G_TLS_CONNECTION (tls_conn), certificate);
+
+      g_object_unref (*connection);
+      *connection = G_IO_STREAM (tls_conn);
+
+      if (!g_tls_connection_handshake (G_TLS_CONNECTION (tls_conn),
+                                       cancellable, error))
+        {
+          g_prefix_error (error, "Error during TLS handshake: ");
+          return FALSE;
+        }
+    }
+  g_object_unref (connectable);
+
+  if (*connection)
+    {
+      *istream = g_io_stream_get_input_stream (*connection);
+      *ostream = g_io_stream_get_output_stream (*connection);
+    }
+
+  return TRUE;
 }
 
 int
@@ -92,19 +250,19 @@ main (int argc,
       char *argv[])
 {
   GSocket *socket;
-  GSocketAddress *src_address;
   GSocketAddress *address;
-  GSocketType socket_type;
-  GSocketFamily socket_family;
   GError *error = NULL;
   GOptionContext *context;
   GCancellable *cancellable;
-  GSocketAddressEnumerator *enumerator;
-  GSocketConnectable *connectable;
-
-  g_thread_init (NULL);
+  GIOStream *connection;
+  GInputStream *istream;
+  GOutputStream *ostream;
+  GSocketAddress *src_address;
+  GTlsCertificate *certificate = NULL;
+  gint i;
 
-  g_type_init ();
+  address = NULL;
+  connection = NULL;
 
   context = g_option_context_new (" <hostname>[:port] - Test GSocket client stuff");
   g_option_context_add_main_entries (context, cmd_entries, NULL);
@@ -120,10 +278,18 @@ main (int argc,
       return 1;
     }
 
+  if (use_udp && tls)
+    {
+      g_printerr ("DTLS (TLS over UDP) is not supported");
+      return 1;
+    }
+
   if (cancel_timeout)
     {
+      GThread *thread;
       cancellable = g_cancellable_new ();
-      g_thread_create (cancel_thread, cancellable, FALSE, NULL);
+      thread = g_thread_new ("cancel", cancel_thread, cancellable);
+      g_thread_unref (thread);
     }
   else
     {
@@ -132,90 +298,31 @@ main (int argc,
 
   loop = g_main_loop_new (NULL, FALSE);
 
-  if (use_udp)
-    socket_type = G_SOCKET_TYPE_DATAGRAM;
-  else
-    socket_type = G_SOCKET_TYPE_STREAM;
-
-  if (unix_socket)
-    socket_family = G_SOCKET_FAMILY_UNIX;
-  else
-    socket_family = G_SOCKET_FAMILY_IPV4;
-
-  socket = g_socket_new (socket_family, socket_type, 0, &error);
-  if (socket == NULL)
+  for (i = 0; i < 2; i++)
     {
-      g_printerr ("%s: %s\n", argv[0], error->message);
+      if (make_connection (argv[1], certificate, cancellable, &socket, &address,
+                           &connection, &istream, &ostream, &error))
+          break;
+
+      if (g_error_matches (error, G_TLS_ERROR, G_TLS_ERROR_CERTIFICATE_REQUIRED))
+        {
+          g_clear_error (&error);
+          certificate = lookup_client_certificate (G_TLS_CLIENT_CONNECTION (connection), &error);
+          if (certificate != NULL)
+            continue;
+        }
+
+      g_printerr ("%s: %s", argv[0], error->message);
       return 1;
     }
 
-  if (unix_socket)
-    {
-      GSocketAddress *addr;
-
-      addr = socket_address_from_string (argv[1]);
-      if (addr == NULL)
-       {
-         g_printerr ("%s: Could not parse '%s' as unix socket name\n", argv[0], argv[1]);
-         return 1;
-       }
-      connectable = G_SOCKET_CONNECTABLE (addr);
-    }
-  else
-    {
-      connectable = g_network_address_parse (argv[1], 7777, &error);
-      if (connectable == NULL)
-       {
-         g_printerr ("%s: %s\n", argv[0], error->message);
-         return 1;
-       }
-    }
-
-  enumerator = g_socket_connectable_enumerate (connectable);
-  while (TRUE)
-    {
-      address = g_socket_address_enumerator_next (enumerator, cancellable, &error);
-      if (address == NULL)
-       {
-         if (error == NULL)
-           g_printerr ("%s: No more addresses to try\n", argv[0]);
-         else
-           g_printerr ("%s: %s\n", argv[0], error->message);
-         return 1;
-       }
-
-      if (g_socket_connect (socket, address, cancellable, &error))
-       break;
-      g_printerr ("%s: Connection to %s failed: %s, trying next\n", argv[0], socket_address_to_string (address), error->message);
-      g_error_free (error);
-      error = NULL;
-
-      g_object_unref (address);
-    }
-  g_object_unref (enumerator);
-  g_object_unref (connectable);
-
-  g_print ("Connected to %s\n",
-          socket_address_to_string (address));
-
-  /* TODO: Test non-blocking connect */
+  /* TODO: Test non-blocking connect/handshake */
   if (non_blocking)
     g_socket_set_blocking (socket, FALSE);
 
-  src_address = g_socket_get_local_address (socket, &error);
-  if (!src_address)
-    {
-      g_printerr ("Error getting local address: %s\n",
-                 error->message);
-      return 1;
-    }
-  g_print ("local address: %s\n",
-          socket_address_to_string (src_address));
-  g_object_unref (src_address);
-
   while (TRUE)
     {
-      gchar buffer[4096] = { };
+      gchar buffer[4096];
       gssize size;
       gsize to_send;
 
@@ -225,14 +332,20 @@ main (int argc,
       to_send = strlen (buffer);
       while (to_send > 0)
        {
-         ensure_condition (socket, "send", cancellable, G_IO_OUT);
          if (use_udp)
-           size = g_socket_send_to (socket, address,
-                                    buffer, to_send,
-                                    cancellable, &error);
+           {
+             ensure_socket_condition (socket, G_IO_OUT, cancellable);
+             size = g_socket_send_to (socket, address,
+                                      buffer, to_send,
+                                      cancellable, &error);
+           }
          else
-           size = g_socket_send (socket, buffer, to_send,
-                                 cancellable, &error);
+           {
+             ensure_connection_condition (connection, G_IO_OUT, cancellable);
+             size = g_output_stream_write (ostream,
+                                           buffer, to_send,
+                                           cancellable, &error);
+           }
 
          if (size < 0)
            {
@@ -264,14 +377,20 @@ main (int argc,
          to_send -= size;
        }
 
-      ensure_condition (socket, "receive", cancellable, G_IO_IN);
       if (use_udp)
-       size = g_socket_receive_from (socket, &src_address,
+       {
+         ensure_socket_condition (socket, G_IO_IN, cancellable);
+         size = g_socket_receive_from (socket, &src_address,
+                                       buffer, sizeof buffer,
+                                       cancellable, &error);
+       }
+      else
+       {
+         ensure_connection_condition (connection, G_IO_IN, cancellable);
+         size = g_input_stream_read (istream,
                                      buffer, sizeof buffer,
                                      cancellable, &error);
-      else
-       size = g_socket_receive (socket, buffer, sizeof buffer,
-                                cancellable, &error);
+       }
 
       if (size < 0)
        {
@@ -298,15 +417,28 @@ main (int argc,
 
   g_print ("closing socket\n");
 
-  if (!g_socket_close (socket, &error))
+  if (connection)
     {
-      g_printerr ("Error closing master socket: %s\n",
-                 error->message);
-      return 1;
+      if (!g_io_stream_close (connection, cancellable, &error))
+       {
+         g_printerr ("Error closing connection: %s\n",
+                     error->message);
+         return 1;
+       }
+      g_object_unref (connection);
+    }
+  else
+    {
+      if (!g_socket_close (socket, &error))
+       {
+         g_printerr ("Error closing master socket: %s\n",
+                     error->message);
+         return 1;
+       }
     }
 
-  g_object_unref (G_OBJECT (socket));
-  g_object_unref (G_OBJECT (address));
+  g_object_unref (socket);
+  g_object_unref (address);
 
   return 0;
 }