gio/tests/giomodule.c: Use G_MODULE_SUFFIX
[platform/upstream/glib.git] / gio / tests / socket-client.c
index 6068034..62dda89 100644 (file)
@@ -5,6 +5,8 @@
 #include <stdio.h>
 #include <string.h>
 
+#include "gtlsconsoleinteraction.h"
+
 GMainLoop *loop;
 
 gboolean verbose = FALSE;
@@ -38,8 +40,10 @@ static GOptionEntry cmd_entries[] = {
 #include "socket-common.c"
 
 static gboolean
-accept_certificate (GTlsClientConnection *conn, GTlsCertificate *cert,
-                   GTlsCertificateFlags errors, gpointer user_data)
+accept_certificate (GTlsClientConnection *conn,
+                   GTlsCertificate      *cert,
+                   GTlsCertificateFlags  errors,
+                   gpointer              user_data)
 {
   g_print ("Certificate would have been rejected ( ");
   if (errors & G_TLS_CERTIFICATE_UNKNOWN_CA)
@@ -59,59 +63,64 @@ accept_certificate (GTlsClientConnection *conn, GTlsCertificate *cert,
   return TRUE;
 }
 
-int
-main (int argc,
-      char *argv[])
+static GTlsCertificate *
+lookup_client_certificate (GTlsClientConnection  *conn,
+                          GError               **error)
 {
-  GSocket *socket;
-  GSocketAddress *src_address;
-  GSocketAddress *address;
-  GSocketType socket_type;
-  GSocketFamily socket_family;
-  GError *error = NULL;
-  GOptionContext *context;
-  GCancellable *cancellable;
-  GSocketAddressEnumerator *enumerator;
-  GSocketConnectable *connectable;
-  GIOStream *connection;
-  GInputStream *istream;
-  GOutputStream *ostream;
-
-  g_thread_init (NULL);
-
-  g_type_init ();
-
-  context = g_option_context_new (" <hostname>[:port] - Test GSocket client stuff");
-  g_option_context_add_main_entries (context, cmd_entries, NULL);
-  if (!g_option_context_parse (context, &argc, &argv, &error))
+  GList *l, *accepted;
+  GList *c, *certificates;
+  GTlsDatabase *database;
+  GTlsCertificate *certificate = NULL;
+  GTlsConnection *base;
+
+  accepted = g_tls_client_connection_get_accepted_cas (conn);
+  for (l = accepted; l != NULL; l = g_list_next (l))
     {
-      g_printerr ("%s: %s\n", argv[0], error->message);
-      return 1;
+      base = G_TLS_CONNECTION (conn);
+      database = g_tls_connection_get_database (base);
+      certificates = g_tls_database_lookup_certificates_issued_by (database, l->data,
+                                                                   g_tls_connection_get_interaction (base),
+                                                                   G_TLS_DATABASE_LOOKUP_KEYPAIR,
+                                                                   NULL, error);
+      if (error && *error)
+        break;
+
+      if (certificates)
+          certificate = g_object_ref (certificates->data);
+
+      for (c = certificates; c != NULL; c = g_list_next (c))
+        g_object_unref (c->data);
+      g_list_free (certificates);
     }
 
-  if (argc != 2)
-    {
-      g_printerr ("%s: %s\n", argv[0], "Need to specify hostname / unix socket name");
-      return 1;
-    }
-
-  if (use_udp && tls)
-    {
-      g_printerr ("DTLS (TLS over UDP) is not supported");
-      return 1;
-    }
+  for (l = accepted; l != NULL; l = g_list_next (l))
+    g_byte_array_unref (l->data);
+  g_list_free (accepted);
 
-  if (cancel_timeout)
-    {
-      cancellable = g_cancellable_new ();
-      g_thread_create (cancel_thread, cancellable, FALSE, NULL);
-    }
-  else
-    {
-      cancellable = NULL;
-    }
+  if (certificate == NULL && error && !*error)
+    g_set_error_literal (error, G_TLS_ERROR, G_TLS_ERROR_CERTIFICATE_REQUIRED,
+                         "Server requested a certificate, but could not find relevant certificate in database.");
+  return certificate;
+}
 
-  loop = g_main_loop_new (NULL, FALSE);
+static gboolean
+make_connection (const char       *argument,
+                GTlsCertificate  *certificate,
+                GCancellable     *cancellable,
+                GSocket         **socket,
+                GSocketAddress  **address,
+                GIOStream       **connection,
+                GInputStream    **istream,
+                GOutputStream   **ostream,
+                GError          **error)
+{
+  GSocketType socket_type;
+  GSocketFamily socket_family;
+  GSocketAddressEnumerator *enumerator;
+  GSocketConnectable *connectable;
+  GSocketAddress *src_address;
+  GTlsInteraction *interaction;
+  GError *err = NULL;
 
   if (use_udp)
     socket_type = G_SOCKET_TYPE_DATAGRAM;
@@ -123,116 +132,188 @@ main (int argc,
   else
     socket_family = G_SOCKET_FAMILY_IPV4;
 
-  socket = g_socket_new (socket_family, socket_type, 0, &error);
-  if (socket == NULL)
-    {
-      g_printerr ("%s: %s\n", argv[0], error->message);
-      return 1;
-    }
+  *socket = g_socket_new (socket_family, socket_type, 0, error);
+  if (*socket == NULL)
+    return FALSE;
 
   if (read_timeout)
-    g_socket_set_timeout (socket, read_timeout);
+    g_socket_set_timeout (*socket, read_timeout);
 
   if (unix_socket)
     {
       GSocketAddress *addr;
 
-      addr = socket_address_from_string (argv[1]);
+      addr = socket_address_from_string (argument);
       if (addr == NULL)
-       {
-         g_printerr ("%s: Could not parse '%s' as unix socket name\n", argv[0], argv[1]);
-         return 1;
-       }
+        {
+          g_set_error (error, G_IO_ERROR, G_IO_ERROR_FAILED,
+                       "Could not parse '%s' as unix socket name", argument);
+          return FALSE;
+        }
       connectable = G_SOCKET_CONNECTABLE (addr);
     }
   else
     {
-      connectable = g_network_address_parse (argv[1], 7777, &error);
+      connectable = g_network_address_parse (argument, 7777, error);
       if (connectable == NULL)
-       {
-         g_printerr ("%s: %s\n", argv[0], error->message);
-         return 1;
-       }
+        return FALSE;
     }
 
   enumerator = g_socket_connectable_enumerate (connectable);
   while (TRUE)
     {
-      address = g_socket_address_enumerator_next (enumerator, cancellable, &error);
-      if (address == NULL)
-       {
-         if (error == NULL)
-           g_printerr ("%s: No more addresses to try\n", argv[0]);
-         else
-           g_printerr ("%s: %s\n", argv[0], error->message);
-         return 1;
-       }
-
-      if (g_socket_connect (socket, address, cancellable, &error))
-       break;
-      g_printerr ("%s: Connection to %s failed: %s, trying next\n", argv[0], socket_address_to_string (address), error->message);
-      g_error_free (error);
-      error = NULL;
-
-      g_object_unref (address);
+      *address = g_socket_address_enumerator_next (enumerator, cancellable, error);
+      if (*address == NULL)
+        {
+          if (error != NULL && *error == NULL)
+            g_set_error_literal (error, G_IO_ERROR, G_IO_ERROR_FAILED,
+                                 "No more addresses to try");
+          return FALSE;
+        }
+
+      if (g_socket_connect (*socket, *address, cancellable, &err))
+        break;
+      g_message ("Connection to %s failed: %s, trying next\n", socket_address_to_string (*address), err->message);
+      g_clear_error (&err);
+
+      g_object_unref (*address);
     }
   g_object_unref (enumerator);
 
   g_print ("Connected to %s\n",
-          socket_address_to_string (address));
+           socket_address_to_string (*address));
 
-  src_address = g_socket_get_local_address (socket, &error);
+  src_address = g_socket_get_local_address (*socket, error);
   if (!src_address)
     {
-      g_printerr ("Error getting local address: %s\n",
-                 error->message);
-      return 1;
+      g_prefix_error (error, "Error getting local address: ");
+      return FALSE;
     }
+
   g_print ("local address: %s\n",
-          socket_address_to_string (src_address));
+           socket_address_to_string (src_address));
   g_object_unref (src_address);
 
   if (use_udp)
     {
-      connection = NULL;
-      istream = NULL;
-      ostream = NULL;
+      *connection = NULL;
+      *istream = NULL;
+      *ostream = NULL;
     }
   else
-    connection = G_IO_STREAM (g_socket_connection_factory_create_connection (socket));
+    *connection = G_IO_STREAM (g_socket_connection_factory_create_connection (*socket));
 
   if (tls)
     {
       GIOStream *tls_conn;
 
-      tls_conn = g_tls_client_connection_new (connection, connectable, &error);
+      tls_conn = g_tls_client_connection_new (*connection, connectable, error);
       if (!tls_conn)
-       {
-         g_printerr ("Could not create TLS connection: %s\n",
-                     error->message);
-         return 1;
-       }
+        {
+          g_prefix_error (error, "Could not create TLS connection: ");
+          return FALSE;
+        }
 
       g_signal_connect (tls_conn, "accept-certificate",
-                       G_CALLBACK (accept_certificate), NULL);
+                        G_CALLBACK (accept_certificate), NULL);
 
-      if (!g_tls_connection_handshake (G_TLS_CONNECTION (tls_conn),
-                                      cancellable, &error))
-       {
-         g_printerr ("Error during TLS handshake: %s\n",
-                     error->message);
-         return 1;
-       }
+      interaction = g_tls_console_interaction_new ();
+      g_tls_connection_set_interaction (G_TLS_CONNECTION (tls_conn), interaction);
+      g_object_unref (interaction);
 
-      g_object_unref (connection);
-      connection = G_IO_STREAM (tls_conn);
+      if (certificate)
+        g_tls_connection_set_certificate (G_TLS_CONNECTION (tls_conn), certificate);
+
+      g_object_unref (*connection);
+      *connection = G_IO_STREAM (tls_conn);
+
+      if (!g_tls_connection_handshake (G_TLS_CONNECTION (tls_conn),
+                                       cancellable, error))
+        {
+          g_prefix_error (error, "Error during TLS handshake: ");
+          return FALSE;
+        }
     }
   g_object_unref (connectable);
 
-  if (connection)
+  if (*connection)
     {
-      istream = g_io_stream_get_input_stream (connection);
-      ostream = g_io_stream_get_output_stream (connection);
+      *istream = g_io_stream_get_input_stream (*connection);
+      *ostream = g_io_stream_get_output_stream (*connection);
+    }
+
+  return TRUE;
+}
+
+int
+main (int argc,
+      char *argv[])
+{
+  GSocket *socket;
+  GSocketAddress *address;
+  GError *error = NULL;
+  GOptionContext *context;
+  GCancellable *cancellable;
+  GIOStream *connection;
+  GInputStream *istream;
+  GOutputStream *ostream;
+  GSocketAddress *src_address;
+  GTlsCertificate *certificate = NULL;
+  gint i;
+
+  address = NULL;
+  connection = NULL;
+
+  context = g_option_context_new (" <hostname>[:port] - Test GSocket client stuff");
+  g_option_context_add_main_entries (context, cmd_entries, NULL);
+  if (!g_option_context_parse (context, &argc, &argv, &error))
+    {
+      g_printerr ("%s: %s\n", argv[0], error->message);
+      return 1;
+    }
+
+  if (argc != 2)
+    {
+      g_printerr ("%s: %s\n", argv[0], "Need to specify hostname / unix socket name");
+      return 1;
+    }
+
+  if (use_udp && tls)
+    {
+      g_printerr ("DTLS (TLS over UDP) is not supported");
+      return 1;
+    }
+
+  if (cancel_timeout)
+    {
+      GThread *thread;
+      cancellable = g_cancellable_new ();
+      thread = g_thread_new ("cancel", cancel_thread, cancellable);
+      g_thread_unref (thread);
+    }
+  else
+    {
+      cancellable = NULL;
+    }
+
+  loop = g_main_loop_new (NULL, FALSE);
+
+  for (i = 0; i < 2; i++)
+    {
+      if (make_connection (argv[1], certificate, cancellable, &socket, &address,
+                           &connection, &istream, &ostream, &error))
+          break;
+
+      if (g_error_matches (error, G_TLS_ERROR, G_TLS_ERROR_CERTIFICATE_REQUIRED))
+        {
+          g_clear_error (&error);
+          certificate = lookup_client_certificate (G_TLS_CLIENT_CONNECTION (connection), &error);
+          if (certificate != NULL)
+            continue;
+        }
+
+      g_printerr ("%s: %s", argv[0], error->message);
+      return 1;
     }
 
   /* TODO: Test non-blocking connect/handshake */