Refactor SocketManager's timeout queue
[platform/core/security/key-manager.git] / unit-tests / test_socket-manager.cpp
1 /*
2  *  Copyright (c) 2021 Samsung Electronics Co., Ltd All Rights Reserved
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *  Unless required by applicable law or agreed to in writing, software
11  *  distributed under the License is distributed on an "AS IS" BASIS,
12  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License
15  */
16
#include <poll.h>
#include <unistd.h>

#include <cerrno>

#include <algorithm>
#include <chrono>
#include <condition_variable>
#include <mutex>
#include <sstream>
#include <stdexcept>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>

#include <boost/test/test_tools.hpp>
#include <boost_macros_wrapper.h>

#include <client-common.h>
#include <socket-manager.h>

#include <dpl/log/log.h>
37
38 using namespace CKM;
39
40 namespace {
41
// Path of the socket advertised by TestService; the test case unlinks it before starting.
constexpr char SERVICE_SOCKET_TEST[] = "/tmp/.central-key-manager-test.sock";
// Interface identifier placed in TestService's service description.
constexpr CKM::InterfaceID SOCKET_ID_TEST = 42;
// Upper bound for a single wait on an accept/close notification from TestService.
constexpr std::chrono::seconds CV_TIMEOUT(10);
45
46 struct TestSocketManager : public SocketManager {
47         size_t TimeoutQueueSize() const { return m_timeoutQueue.size(); }
48 };
49
/*
 * THREAD_REQUIRE_MESSAGE(test, message)
 *
 * Evaluates `test`; when it is false, throws std::runtime_error whose what() is
 * "<file>:<line> <stringified test> <message>". `message` may be a chain of operator<<
 * operands (e.g. "text " << value) and is evaluated only when the check fails.
 *
 * Used in this file instead of BOOST_REQUIRE* by the TestService callbacks; an exception that
 * propagates out of SocketManager::MainLoop() is recorded by SocketManagerLoop.
 *
 * The internal stream deliberately has an unusual name so that a `message` expression which
 * refers to a caller variable called `os` is not captured by the macro's own local.
 */
#define THREAD_REQUIRE_MESSAGE(test, message)                                      \
        do {                                                                           \
                if (!(test)) {                                                             \
                        std::ostringstream thread_require_os_;                                 \
                        thread_require_os_ << __FILE__ << ":" << __LINE__ << " " << #test      \
                                           << " " << message;                                  \
                        throw std::runtime_error(thread_require_os_.str());                    \
                }                                                                          \
        } while (0)

/* Same as THREAD_REQUIRE_MESSAGE but with an empty message. */
#define THREAD_REQUIRE(test) THREAD_REQUIRE_MESSAGE(test, "")
60
61 class TestService : public GenericSocketService {
62 public:
63         ServiceDescriptionVector GetServiceDescription() override {
64                 return ServiceDescriptionVector {
65                         {SERVICE_SOCKET_TEST, "", SOCKET_ID_TEST}
66                 };
67         }
68         void Event(const AcceptEvent &e) override {
69                 std::unique_lock<std::mutex> lock(m_mutex);
70
71                 THREAD_REQUIRE_MESSAGE(m_connections.empty() || m_connections.back().client != -1,
72                                        "Unexpected server entry waiting for client match " <<
73                                        m_connections.back().server);
74
75                 m_connections.push_back({-1 , e.connectionID.sock});
76
77                 LogDebug("AcceptEvent. Added: ? <=>" << e.connectionID.sock);
78
79                 CompareSizes();
80
81                 lock.unlock();
82                 m_cv.notify_one();
83         }
84         void Event(const WriteEvent &) override {}
85         void Event(const ReadEvent &) override {}
86         void Event(const CloseEvent &e) override {
87                 std::unique_lock<std::mutex> lock(m_mutex);
88                 THREAD_REQUIRE(!m_connections.empty());
89
90                 auto serverMatch = [&](const SocketPair& pair){
91                         return pair.server == e.connectionID.sock;
92                 };
93                 auto it = std::find_if(m_connections.begin(), m_connections.end(), serverMatch);
94
95                 THREAD_REQUIRE_MESSAGE(it != m_connections.end(),
96                                        "Can't find connection for server socket = " << e.connectionID.sock);
97
98                 LogDebug("CloseEvent. Removing: " << it->client << "<=>" << it->server);
99                 THREAD_REQUIRE(it->client != -1);
100
101                 m_connections.erase(it);
102
103                 CompareSizes();
104
105                 lock.unlock();
106                 m_cv.notify_one();
107         }
108         void Event(const SecurityEvent &) override {}
109
110         void Start() override {}
111         void Stop() override {}
112
113         void ConnectAndWait(SockRAII& client) {
114                 std::unique_lock<std::mutex> lock(m_mutex);
115
116                 THREAD_REQUIRE_MESSAGE(m_connections.empty() || m_connections.back().client != -1,
117                                        "Unexpected server entry waiting for client match " <<
118                                        m_connections.back().server);
119
120                 int ret = client.connect(GetServiceDescription()[0].serviceHandlerPath.c_str());
121                 BOOST_REQUIRE(ret == CKM_API_SUCCESS);
122
123                 LogDebug("Connected. Waiting for AcceptEvent for: " << client.get() << "<=> ?");
124
125                 BOOST_REQUIRE(m_cv.wait_for(lock, CV_TIMEOUT, [&]{ return AcceptEventArrived(); }));
126
127                 m_connections.back().client = client.get();
128
129                 LogDebug("Accepted. Matched client & server: " << m_connections.back().client << "<=>" <<
130                          m_connections.back().server);
131         }
132
133         void DisconnectAndWait(SockRAII& client) {
134                 int sock = client.get();
135                 client.disconnect();
136
137                 LogDebug("Disconnected. Waiting for CloseEvent for: " << sock << "<=> ?");
138
139                 std::unique_lock<std::mutex> lock(m_mutex);
140                 BOOST_REQUIRE(m_cv.wait_for(lock, CV_TIMEOUT, [&]{ return ClientAbsent(sock); }));
141         }
142
143         void WaitForRemainingClosures() {
144                 std::unique_lock<std::mutex> lock(m_mutex);
145                 if (!m_connections.empty())
146                         BOOST_TEST_MESSAGE("Waiting for remaining " << m_connections.size() << " to close.");
147
148                 BOOST_REQUIRE(m_cv.wait_for(lock, std::chrono::seconds(OVERRIDE_SOCKET_TIMEOUT + 2), [&]{
149                         return m_connections.empty();
150                 }));
151
152                 CompareSizes();
153         }
154
155 private:
156         bool ClientAbsent(int client) const {
157                 auto it = std::find_if(m_connections.begin(),
158                                        m_connections.end(), [&](const SocketPair& pair){
159                         return pair.client == client;
160                 });
161                 return it == m_connections.end();
162         }
163
164         bool AcceptEventArrived() const {
165                 return !m_connections.empty() && m_connections.back().client == -1;
166         }
167
168         void CompareSizes() const {
169                 auto manager = static_cast<TestSocketManager*>(m_serviceManager);
170                 THREAD_REQUIRE(m_connections.size() == manager->TimeoutQueueSize());
171         }
172
173         std::mutex m_mutex;
174         struct SocketPair {
175                 int client;
176                 int server;
177         };
178         std::vector<SocketPair> m_connections;
179         std::condition_variable m_cv;
180 };
181
182 class SocketManagerLoop {
183 public:
184         explicit SocketManagerLoop(SocketManager& manager) :
185                 m_manager(manager),
186                 m_thread([&]{
187                         try {
188                                 manager.MainLoop();
189                         } catch (const std::exception& e) {
190                                 m_exception = true;
191                                 m_what = e.what();
192                         } catch (...) {
193                                 m_exception = true;
194                         }
195                 })
196         {
197         }
198
199         ~SocketManagerLoop() {
200                 m_manager.MainLoopStop();
201                 m_thread.join();
202                 BOOST_CHECK_MESSAGE(!m_exception, m_what);
203         }
204
205 private:
206         bool m_exception = false;
207         std::string m_what = "Unknown exception";
208         SocketManager& m_manager;
209         std::thread m_thread;
210 };
211
212 } // namespace
213
214 BOOST_AUTO_TEST_SUITE(SOCKET_MANAGER_TEST)
215
216 POSITIVE_TEST_CASE(StressTestGrowingTimeoutQueue)
217 {
218         constexpr unsigned INITIAL_CONNECTIONS = 20;
219         constexpr unsigned REPEATS = 100000;
220         constexpr auto INTERVAL = REPEATS/10;
221
222         int ret = unlink(SERVICE_SOCKET_TEST);
223         int err = errno;
224         BOOST_REQUIRE(ret == 0 || (ret == -1 && err == ENOENT));
225
226         TestSocketManager manager;
227         auto service = new TestService();
228         manager.RegisterSocketService(service);
229
230         SocketManagerLoop loop(manager);
231
232         {
233                 SockRAII socket[INITIAL_CONNECTIONS];
234                 for (unsigned i=0;i<INITIAL_CONNECTIONS;i++)
235                         service->ConnectAndWait(socket[i]);
236
237                 BOOST_REQUIRE(manager.TimeoutQueueSize() == INITIAL_CONNECTIONS);
238
239                 SockRAII socket2;
240                 for(unsigned i=0;i<REPEATS;i++) {
241                         service->ConnectAndWait(socket2);
242                         service->DisconnectAndWait(socket2);
243
244                         if ((i + 1) % INTERVAL == 0)
245                                 BOOST_TEST_MESSAGE("Creating connections: " << i + 1 << "/" << REPEATS);
246                 }
247
248                 // wait for remaining connections to close if any
249                 service->WaitForRemainingClosures();
250         }
251 }
252
253 BOOST_AUTO_TEST_SUITE_END()