2003-03-31 Havoc Pennington <hp@redhat.com>
[platform/upstream/dbus.git] / dbus / dbus-transport.c
index 8c6c7f1..96bc6b6 100644 (file)
@@ -82,16 +82,19 @@ live_messages_size_notify (DBusCounter *counter,
  * @param transport the transport being created.
  * @param vtable the subclass vtable.
  * @param server #TRUE if this transport is on the server side of a connection
+ * @param address the address of the transport
  * @returns #TRUE on success.
  */
 dbus_bool_t
 _dbus_transport_init_base (DBusTransport             *transport,
                            const DBusTransportVTable *vtable,
-                           dbus_bool_t                server)
+                           dbus_bool_t                server,
+                           const DBusString          *address)
 {
   DBusMessageLoader *loader;
   DBusAuth *auth;
   DBusCounter *counter;
+  char *address_copy;
   
   loader = _dbus_message_loader_new ();
   if (loader == NULL)
@@ -113,6 +116,24 @@ _dbus_transport_init_base (DBusTransport             *transport,
       _dbus_auth_unref (auth);
       _dbus_message_loader_unref (loader);
       return FALSE;
+    }  
+  
+  if (server)
+    {
+      _dbus_assert (address == NULL);
+      address_copy = NULL;
+    }
+  else
+    {
+      _dbus_assert (address != NULL);
+
+      if (!_dbus_string_copy_data (address, &address_copy))
+        {
+          _dbus_counter_unref (counter);
+          _dbus_auth_unref (auth);
+          _dbus_message_loader_unref (loader);
+          return FALSE;
+        }
     }
   
   transport->refcount = 1;
@@ -126,7 +147,12 @@ _dbus_transport_init_base (DBusTransport             *transport,
   transport->send_credentials_pending = !server;
   transport->receive_credentials_pending = server;
   transport->is_server = server;
-
+  transport->address = address_copy;
+  
+  transport->unix_user_function = NULL;
+  transport->unix_user_data = NULL;
+  transport->free_unix_user_data = NULL;
+  
   /* Try to default to something that won't totally hose the system,
    * but doesn't impose too much of a limitation.
    */
@@ -140,6 +166,9 @@ _dbus_transport_init_base (DBusTransport             *transport,
                             transport->max_live_messages_size,
                             live_messages_size_notify,
                             transport);
+
+  if (transport->address)
+    _dbus_verbose ("Initialized transport on address %s\n", transport->address);
   
   return TRUE;
 }
@@ -155,12 +184,16 @@ _dbus_transport_finalize_base (DBusTransport *transport)
 {
   if (!transport->disconnected)
     _dbus_transport_disconnect (transport);
+
+  if (transport->free_unix_user_data != NULL)
+    (* transport->free_unix_user_data) (transport->unix_user_data);
   
   _dbus_message_loader_unref (transport->loader);
   _dbus_auth_unref (transport->auth);
   _dbus_counter_set_notify (transport->live_messages_size,
                             0, NULL, NULL);
   _dbus_counter_unref (transport->live_messages_size);
+  dbus_free (transport->address);
 }
 
 /**
@@ -171,21 +204,29 @@ _dbus_transport_finalize_base (DBusTransport *transport)
  * DBusResultCode is a bit limiting here.
  * 
  * @param address the address.
- * @param result location to store reason for failure.
+ * @param error location to store reason for failure.
  * @returns new transport of #NULL on failure.
  */
 DBusTransport*
 _dbus_transport_open (const char     *address,
-                      DBusResultCode *result)
+                      DBusError      *error)
 {
   DBusTransport *transport;
   DBusAddressEntry **entries;
   int len, i;
+  const char *address_problem_type;
+  const char *address_problem_field;
+  const char *address_problem_other;
+
+  _DBUS_ASSERT_ERROR_IS_CLEAR (error);
   
-  if (!dbus_parse_address (address, &entries, &len, result))
+  if (!dbus_parse_address (address, &entries, &len, error))
     return NULL;
 
   transport = NULL;
+  address_problem_type = NULL;
+  address_problem_field = NULL;
+  address_problem_other = NULL;
   
   for (i = 0; i < len; i++)
     {
@@ -196,9 +237,13 @@ _dbus_transport_open (const char     *address,
          const char *path = dbus_address_entry_get_value (entries[i], "path");
 
          if (path == NULL)
-           goto bad_address;
+            {
+              address_problem_type = "unix";
+              address_problem_field = "path";              
+              goto bad_address;
+            }
 
-         transport = _dbus_transport_new_for_domain_socket (path, FALSE, result);
+         transport = _dbus_transport_new_for_domain_socket (path, error);
        }
       else if (strcmp (method, "tcp") == 0)
        {
@@ -208,17 +253,24 @@ _dbus_transport_open (const char     *address,
           long lport;
           dbus_bool_t sresult;
           
-         if (port == NULL)
-           goto bad_address;
+          if (port == NULL)
+            {
+              address_problem_type = "tcp";
+              address_problem_field = "port";
+              goto bad_address;
+            }
 
           _dbus_string_init_const (&str, port);
           sresult = _dbus_string_parse_int (&str, 0, &lport, NULL);
           _dbus_string_free (&str);
           
           if (sresult == FALSE || lport <= 0 || lport > 65535)
-            goto bad_address;
+            {
+              address_problem_other = "Port is not an integer between 0 and 65535";
+              goto bad_address;
+            }
           
-         transport = _dbus_transport_new_for_tcp_socket (host, lport, FALSE, result);
+         transport = _dbus_transport_new_for_tcp_socket (host, lport, error);
        }
 #ifdef DBUS_BUILD_TESTS
       else if (strcmp (method, "debug") == 0)
@@ -226,22 +278,33 @@ _dbus_transport_open (const char     *address,
          const char *name = dbus_address_entry_get_value (entries[i], "name");
 
          if (name == NULL)
-           goto bad_address;
+            {
+              address_problem_type = "debug";
+              address_problem_field = "name";
+              goto bad_address;
+            }
 
-         transport = _dbus_transport_debug_client_new (name, result);
+         transport = _dbus_transport_debug_client_new (name, error);
        }
       else if (strcmp (method, "debug-pipe") == 0)
        {
          const char *name = dbus_address_entry_get_value (entries[i], "name");
 
-         if (name == NULL)
-           goto bad_address;
+          if (name == NULL)
+            {
+              address_problem_type = "debug-pipe";
+              address_problem_field = "name";
+              goto bad_address;
+            }
 
-         transport = _dbus_transport_debug_pipe_new (name, result);
+         transport = _dbus_transport_debug_pipe_new (name, error);
        }
 #endif
       else
-       goto bad_address;
+        {
+          address_problem_other = "Unknown address type (examples of valid types are \"unix\" and \"tcp\")";
+          goto bad_address;
+        }
 
       if (transport)
        break;    
@@ -252,7 +315,15 @@ _dbus_transport_open (const char     *address,
 
  bad_address:
   dbus_address_entries_free (entries);
-  dbus_set_result (result, DBUS_RESULT_BAD_ADDRESS);
+
+  if (address_problem_type != NULL)
+    dbus_set_error (error, DBUS_ERROR_BAD_ADDRESS,
+                    "Address of type %s was missing argument %s",
+                    address_problem_type, address_problem_field);
+  else
+    dbus_set_error (error, DBUS_ERROR_BAD_ADDRESS,
+                    "Could not parse address: %s",
+                    address_problem_other);
 
   return NULL;
 }
@@ -265,6 +336,8 @@ _dbus_transport_open (const char     *address,
 void
 _dbus_transport_ref (DBusTransport *transport)
 {
+  _dbus_assert (transport->refcount > 0);
+  
   transport->refcount += 1;
 }
 
@@ -306,14 +379,12 @@ _dbus_transport_disconnect (DBusTransport *transport)
   if (transport->disconnected)
     return;
 
-  _dbus_transport_ref (transport);
   (* transport->vtable->disconnect) (transport);
   
   transport->disconnected = TRUE;
 
-  _dbus_connection_notify_disconnected (transport->connection);
-  
-  _dbus_transport_unref (transport);
+  if (transport->connection)
+    _dbus_connection_notify_disconnected (transport->connection);
 }
 
 /**
@@ -334,6 +405,8 @@ _dbus_transport_get_is_connected (DBusTransport *transport)
  * Returns #TRUE if we have been authenticated.  Will return #TRUE
  * even if the transport is disconnected.
  *
+ * @todo needs to drop connection->mutex when calling the unix_user_function
+ *
  * @param transport the transport
  * @returns whether we're authenticated
  */
@@ -344,10 +417,12 @@ _dbus_transport_get_is_authenticated (DBusTransport *transport)
     return TRUE;
   else
     {
+      dbus_bool_t maybe_authenticated;
+      
       if (transport->disconnected)
         return FALSE;
       
-      transport->authenticated =
+      maybe_authenticated =
         (!(transport->send_credentials_pending ||
            transport->receive_credentials_pending)) &&
         _dbus_auth_do_work (transport->auth) == DBUS_AUTH_STATE_AUTHENTICATED;
@@ -360,64 +435,106 @@ _dbus_transport_get_is_authenticated (DBusTransport *transport)
        * Or they may give certain identities extra privileges.
        */
       
-      if (transport->authenticated && transport->is_server)
+      if (maybe_authenticated && transport->is_server)
         {
           DBusCredentials auth_identity;
-          DBusCredentials our_identity;
 
-          _dbus_credentials_from_current_process (&our_identity);
           _dbus_auth_get_identity (transport->auth, &auth_identity);
-          
-          if (!_dbus_credentials_match (&our_identity,
-                                        &auth_identity))
+
+          if (transport->unix_user_function != NULL)
             {
-              _dbus_verbose ("Client authorized as UID %d but our UID is %d, disconnecting\n",
-                             auth_identity.uid, our_identity.uid);
-              _dbus_transport_disconnect (transport);
-              return FALSE;
+              /* FIXME we hold the connection lock here and should drop it */
+              if (!(* transport->unix_user_function) (transport->connection,
+                                                      auth_identity.uid,
+                                                      transport->unix_user_data))
+                {
+                  _dbus_verbose ("Client UID %d was rejected, disconnecting\n",
+                                 auth_identity.uid);
+                  _dbus_transport_disconnect (transport);
+                  return FALSE;
+                }
+              else
+                {
+                  _dbus_verbose ("Client UID %d authorized\n", auth_identity.uid);
+                }
             }
           else
             {
-              _dbus_verbose ("Client authorized as UID %d matching our UID %d\n",
-                             auth_identity.uid, our_identity.uid);
+              DBusCredentials our_identity;
+              
+              _dbus_credentials_from_current_process (&our_identity);
+              
+              if (!_dbus_credentials_match (&our_identity,
+                                            &auth_identity))
+                {
+                  _dbus_verbose ("Client authorized as UID %d but our UID is %d, disconnecting\n",
+                                 auth_identity.uid, our_identity.uid);
+                  _dbus_transport_disconnect (transport);
+                  return FALSE;
+                }
+              else
+                {
+                  _dbus_verbose ("Client authorized as UID %d matching our UID %d\n",
+                                 auth_identity.uid, our_identity.uid);
+                }
             }
         }
+
+      transport->authenticated = maybe_authenticated;
       
       return transport->authenticated;
     }
 }
 
 /**
+ * Gets the address of a transport. It will be
+ * #NULL for a server-side transport.
+ *
+ * @param transport the transport
+ * @returns transport's address
+ */
+const char*
+_dbus_transport_get_address (DBusTransport *transport)
+{
+  return transport->address;
+}
+
+/**
  * Handles a watch by reading data, writing data, or disconnecting
  * the transport, as appropriate for the given condition.
  *
  * @param transport the transport.
  * @param watch the watch.
  * @param condition the current state of the watched file descriptor.
+ * @returns #FALSE if not enough memory to fully handle the watch
  */
-void
+dbus_bool_t
 _dbus_transport_handle_watch (DBusTransport           *transport,
                               DBusWatch               *watch,
                               unsigned int             condition)
 {
+  dbus_bool_t retval;
+  
   _dbus_assert (transport->vtable->handle_watch != NULL);
 
   if (transport->disconnected)
-    return;
+    return TRUE;
 
   if (dbus_watch_get_fd (watch) < 0)
     {
       _dbus_warn ("Tried to handle an invalidated watch; this watch should have been removed\n");
-      return;
+      return TRUE;
     }
   
   _dbus_watch_sanitize_condition (watch, &condition);
 
   _dbus_transport_ref (transport);
   _dbus_watch_ref (watch);
-  (* transport->vtable->handle_watch) (transport, watch, condition);
+  retval = (* transport->vtable->handle_watch) (transport, watch, condition);
   _dbus_watch_unref (watch);
   _dbus_transport_unref (transport);
+
+  return retval;
 }
 
 /**
@@ -506,6 +623,177 @@ _dbus_transport_do_iteration (DBusTransport  *transport,
   _dbus_transport_unref (transport);
 }
 
+static dbus_bool_t
+recover_unused_bytes (DBusTransport *transport)
+{
+  if (_dbus_auth_do_work (transport->auth) != DBUS_AUTH_STATE_AUTHENTICATED_WITH_UNUSED_BYTES)
+    return TRUE;
+  
+  if (_dbus_auth_needs_decoding (transport->auth))
+    {
+      DBusString plaintext;
+      const DBusString *encoded;
+      DBusString *buffer;
+      int orig_len;
+      
+      if (!_dbus_string_init (&plaintext))
+        goto nomem;
+      
+      _dbus_auth_get_unused_bytes (transport->auth,
+                                   &encoded);
+
+      if (!_dbus_auth_decode_data (transport->auth,
+                                   encoded, &plaintext))
+        {
+          _dbus_string_free (&plaintext);
+          goto nomem;
+        }
+      
+      _dbus_message_loader_get_buffer (transport->loader,
+                                       &buffer);
+      
+      orig_len = _dbus_string_get_length (buffer);
+      
+      if (!_dbus_string_move (&plaintext, 0, buffer,
+                              orig_len))
+        {
+          _dbus_string_free (&plaintext);
+          goto nomem;
+        }
+      
+      _dbus_verbose (" %d unused bytes sent to message loader\n", 
+                     _dbus_string_get_length (buffer) -
+                     orig_len);
+      
+      _dbus_message_loader_return_buffer (transport->loader,
+                                          buffer,
+                                          _dbus_string_get_length (buffer) -
+                                          orig_len);
+
+      _dbus_auth_delete_unused_bytes (transport->auth);
+      
+      _dbus_string_free (&plaintext);
+    }
+  else
+    {
+      const DBusString *bytes;
+      DBusString *buffer;
+      int orig_len;
+      dbus_bool_t succeeded;
+
+      _dbus_message_loader_get_buffer (transport->loader,
+                                       &buffer);
+                
+      orig_len = _dbus_string_get_length (buffer);
+                
+      _dbus_auth_get_unused_bytes (transport->auth,
+                                   &bytes);
+
+      succeeded = TRUE;
+      if (!_dbus_string_copy (bytes, 0, buffer, _dbus_string_get_length (buffer)))
+        succeeded = FALSE;
+      
+      _dbus_verbose (" %d unused bytes sent to message loader\n", 
+                     _dbus_string_get_length (buffer) -
+                     orig_len);
+      
+      _dbus_message_loader_return_buffer (transport->loader,
+                                          buffer,
+                                          _dbus_string_get_length (buffer) -
+                                          orig_len);
+
+      if (succeeded)
+        _dbus_auth_delete_unused_bytes (transport->auth);
+      else
+        goto nomem;
+    }
+
+  return TRUE;
+
+ nomem:
+  _dbus_verbose ("Not enough memory to transfer unused bytes from auth conversation\n");
+  return FALSE;
+}
+
+/**
+ * Reports our current dispatch status (whether there's buffered
+ * data to be queued as messages, or not, or we need memory).
+ *
+ * @param transport the transport
+ * @returns current status
+ */
+DBusDispatchStatus
+_dbus_transport_get_dispatch_status (DBusTransport *transport)
+{
+  if (_dbus_counter_get_value (transport->live_messages_size) >= transport->max_live_messages_size)
+    return DBUS_DISPATCH_COMPLETE; /* complete for now */
+
+  if (!_dbus_transport_get_is_authenticated (transport))
+    {
+      switch (_dbus_auth_do_work (transport->auth))
+        {
+        case DBUS_AUTH_STATE_WAITING_FOR_MEMORY:
+          return DBUS_DISPATCH_NEED_MEMORY;
+        case DBUS_AUTH_STATE_AUTHENTICATED_WITH_UNUSED_BYTES:
+          if (!recover_unused_bytes (transport))
+            return DBUS_DISPATCH_NEED_MEMORY;
+          break;
+        default:
+          break;
+        }
+    }
+  
+  if (!_dbus_message_loader_queue_messages (transport->loader))
+    return DBUS_DISPATCH_NEED_MEMORY;
+
+  if (_dbus_message_loader_peek_message (transport->loader) != NULL)
+    return DBUS_DISPATCH_DATA_REMAINS;
+  else
+    return DBUS_DISPATCH_COMPLETE;
+}
+
+/**
+ * Processes data we've read while handling a watch, potentially
+ * converting some of it to messages and queueing those messages on
+ * the connection.
+ *
+ * @param transport the transport
+ * @returns #TRUE if we had enough memory to queue all messages
+ */
+dbus_bool_t
+_dbus_transport_queue_messages (DBusTransport *transport)
+{
+  DBusDispatchStatus status;
+  
+  /* Queue any messages */
+  while ((status = _dbus_transport_get_dispatch_status (transport)) == DBUS_DISPATCH_DATA_REMAINS)
+    {
+      DBusMessage *message;
+      DBusList *link;
+
+      link = _dbus_message_loader_pop_message_link (transport->loader);
+      _dbus_assert (link != NULL);
+      
+      message = link->data;
+      
+      _dbus_verbose ("queueing received message %p\n", message);
+
+      _dbus_message_add_size_counter (message, transport->live_messages_size);
+
+      /* pass ownership of link and message ref to connection */
+      _dbus_connection_queue_received_message_link (transport->connection,
+                                                    link);
+    }
+
+  if (_dbus_message_loader_get_is_corrupted (transport->loader))
+    {
+      _dbus_verbose ("Corrupted message stream, disconnecting\n");
+      _dbus_transport_disconnect (transport);
+    }
+
+  return status != DBUS_DISPATCH_NEED_MEMORY;
+}
+
 /**
  * See dbus_connection_set_max_message_size().
  *
@@ -561,4 +849,62 @@ _dbus_transport_get_max_live_messages_size (DBusTransport  *transport)
   return transport->max_live_messages_size;
 }
 
+/**
+ * See dbus_connection_get_unix_user().
+ *
+ * @param transport the transport
+ * @param uid return location for the user ID
+ * @returns #TRUE if uid is filled in with a valid user ID
+ */
+dbus_bool_t
+_dbus_transport_get_unix_user (DBusTransport *transport,
+                               unsigned long *uid)
+{
+  DBusCredentials auth_identity;
+
+  *uid = _DBUS_INT_MAX; /* better than some root or system user in
+                         * case of bugs in the caller. Caller should
+                         * never use this value on purpose, however.
+                         */
+  
+  if (!transport->authenticated)
+    return FALSE;
+  
+  _dbus_auth_get_identity (transport->auth, &auth_identity);
+
+  if (auth_identity.uid >= 0)
+    {
+      *uid = auth_identity.uid;
+      return TRUE;
+    }
+  else
+    return FALSE;
+}
+
+/**
+ * See dbus_connection_set_unix_user_function().
+ *
+ * @param transport the transport
+ * @param function the predicate
+ * @param data data to pass to the predicate
+ * @param free_data_function function to free the data
+ * @param old_data the old user data to be freed
+ * @param old_free_data_function old free data function to free it with
+ */
+void
+_dbus_transport_set_unix_user_function (DBusTransport             *transport,
+                                        DBusAllowUnixUserFunction  function,
+                                        void                      *data,
+                                        DBusFreeFunction           free_data_function,
+                                        void                     **old_data,
+                                        DBusFreeFunction          *old_free_data_function)
+{  
+  *old_data = transport->unix_user_data;
+  *old_free_data_function = transport->free_unix_user_data;
+
+  transport->unix_user_function = function;
+  transport->unix_user_data = data;
+  transport->free_unix_user_data = free_data_function;
+}
+
 /** @} */