Fix race condition that allows disconnecting all clients too early 59/315959/2
authorKrzysztof Malysa <k.malysa@samsung.com>
Fri, 6 Dec 2024 12:16:28 +0000 (13:16 +0100)
committerKrzysztof Malysa <k.malysa@samsung.com>
Fri, 6 Dec 2024 12:32:27 +0000 (13:32 +0100)
Change-Id: Id0cfd9596c20f4fb09f64172fcc4b92aacdf4e6c

src/service/sockets/SocketManager.cpp
src/service/sockets/SocketManager.h

index 43816bf4a5f9bfbff68f80df50ed1e3c8a5d31bf..317430d55eaa2d68770133a6ed1c416262ce1206 100644 (file)
@@ -134,9 +134,11 @@ void SocketManager::mainLoop() {
 
     m_nonReadOnlyWorkerThread = std::thread([this] {
         auto markAllFdsAsWriteReady = [&] {
-            auto guard = std::lock_guard{m_fdsLock};
+            auto guard = std::lock_guard{m_fdsAndDisconnectAllClientsLock};
             for (size_t fd = 0; fd < m_fds.size(); ++fd)
                 m_fds[fd].m_writeReady->store(true);
+            m_disconnectAllClients =
+                m_disconnectAllClients || std::exchange(m_needToDisconnectAllClients, false);
         };
         for (;;) {
             auto reqV = m_nonReadOnlyRequests.recv();
@@ -190,11 +192,12 @@ void SocketManager::mainLoop() {
                 //   check them in a non-racy way
                 // Please note that disconnecting all clients has to happen before actual
                 // writing of the response in the main thread. This code is safe because marking
-                // all clients to disconnect happens above in the req.rquest->execute() call and
-                // if the main thread observes m_writeReady == true set in the loop below, it
-                // will check for disconnecting all clients before actual writing of the
+                // all clients to disconnect happens under the same mutex as marking sokects as
+                // writable and if the main thread observes m_writeReady == true set in the loop
+                // below, it will check for disconnecting all clients before actual writing of the
                 // responses.
                 markAllFdsAsWriteReady();
+
                 LOGD("non-read-only logic worker thread: sending response to request with socket fd"
                      " [%i] with generation [%" PRIu64 "] and sequence number [%i] of size [%i]",
                      req.socketFd, req.socketFdGeneration,
@@ -289,6 +292,9 @@ void SocketManager::mainLoop() {
         }
         m_fdsToCheckForReadButNotProcessedRequests.clear();
 
+        // Need to do below checks under mutex to see both hasDataToWrite() and
+        // m_disconnectAllClients changes.
+        auto guard = std::lock_guard{m_fdsAndDisconnectAllClientsLock};
         // TODO: do it more optimally by some marking which fds have data to write
         LOGD("checking sockets < %zu for data to write", m_fds.size());
         for (size_t fd = 0; fd < m_fds.size(); ++fd) {
@@ -300,11 +306,11 @@ void SocketManager::mainLoop() {
         }
 
         // If we noticed that some socket has data to write. We have to check
-        // m_needToDisconnectAllClients and disconnect all clients if true BEFORE writing
-        // the response. All this to retain soundness of the client and admin API.
+        // m_disconnectAllClients and disconnect all clients if true BEFORE writing the response.
+        // All this to retain soundness of the client and admin API.
         // Now we are safe to disconnect all clients without invalidating any file descriptor that
         // will be used somewhere else in this function.
-        if (m_needToDisconnectAllClients.exchange(false)) {
+        if (std::exchange(m_disconnectAllClients, false)) {
             LOGD("SocketManager disconnecting all clients");
             // m_fds.size() may change during iteration of the loop (closeSocket() calls
             // shrinkFds())
@@ -719,7 +725,7 @@ void SocketManager::createNonReadOnlyRequestResultsNumEventFd() {
 Descriptor &SocketManager::createDescriptorWatchedForRead(int fd, bool client) {
     assert(fd >= 0);
     if (static_cast<size_t>(fd) >= m_fds.size()) {
-        auto guard = std::lock_guard{m_fdsLock};
+        auto guard = std::lock_guard{m_fdsAndDisconnectAllClientsLock};
         m_fds.resize(fd + 1);
     }
     auto &desc = m_fds[fd];
@@ -737,7 +743,7 @@ void SocketManager::shrinkFds() {
     while (newSize > 0 && !m_fds[newSize - 1].isUsed()) {
         --newSize;
     }
-    auto guard = std::lock_guard{m_fdsLock};
+    auto guard = std::lock_guard{m_fdsAndDisconnectAllClientsLock};
     m_fds.resize(newSize);
 }
 
index 2adb4ee3e3371d8c91342b88aae6020a44afec75..754792147158b14e3a809ee4659244505ae12b97 100644 (file)
@@ -72,7 +72,7 @@ public:
     void signalDisconnectAllClients() {
         if (std::this_thread::get_id() != m_nonReadOnlyWorkerThread.get_id())
             throw UnexpectedErrorException{"signalDisconnectAllClients() call in the wrong thread"};
-        m_needToDisconnectAllClients.store(true);
+        m_needToDisconnectAllClients = true;
     }
 
     // Only safe to call from m_nonReadOnlyWorkerThread
@@ -114,19 +114,20 @@ private:
     int m_nonReadOnlyRequestResultsNumEventFd = -1;
     // Only accessed from m_nonReadOnlyWorkerThread
     bool m_needToStopMainLoop = false;
+    bool m_needToDisconnectAllClients = false;
     // Set in m_nonReadOnlyWorkerThread, read and reset in the main thread.
     // The problem is that handling requests in the m_nonReadOnlyWorkerThread may write to buffers
     // in m_fds, so that the main thread will write responses before the main thread will receive
     // notification through m_nonReadOnlyRequestResultsNumEventFd about the request being completed.
     // But to avoid situation where clients are still connected after the response was received we
-    // have to disconnect them before writing of the response happens, so we need "faster"
+    // have to disconnect them before sending of the response happens, so we need "faster"
     // notification method. That is why this is used instead of signaling the need to disconnect all
     // clients through NonReadOnlyRequestResult.
-    std::atomic_bool m_needToDisconnectAllClients = false;
+    bool m_disconnectAllClients = false;
 
     size_t m_openFdsLimit;
 
-    std::mutex m_fdsLock;
+    std::recursive_mutex m_fdsAndDisconnectAllClientsLock;
     std::vector<Descriptor> m_fds;
     Epoll m_epoll;
     std::vector<int> m_fdsToCheckForReadButNotProcessedRequests;