f8695951d4c31b8332a4af5a0be6e6d216b579a5
[platform/core/security/key-manager.git] / unit-tests / test_socket-manager.cpp
1 /*
2  *  Copyright (c) 2021 Samsung Electronics Co., Ltd All Rights Reserved
3  *
4  *  Licensed under the Apache License, Version 2.0 (the "License");
5  *  you may not use this file except in compliance with the License.
6  *  You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *  Unless required by applicable law or agreed to in writing, software
11  *  distributed under the License is distributed on an "AS IS" BASIS,
12  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *  See the License for the specific language governing permissions and
14  *  limitations under the License
15  */
16
#include <poll.h>
#include <unistd.h>

#include <cerrno>
#include <cstdlib>
#include <cstring>
#include <ctime>

#include <algorithm>
#include <chrono>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <sstream>
#include <stdexcept>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include <boost/test/test_tools.hpp>
#include <boost_macros_wrapper.h>

#include <test_common.h>
#include <dpl/log/log.h>

#include <client-common.h>
#include <socket-manager.h>
#include <message-buffer.h>

using namespace CKM;
43
44 namespace {
45
// Returns a pseudo-random number in [0, max), or 0 when `max` is 0 (which would
// otherwise be a modulo by zero).
//
// Uses rand_r() over a function-local seed initialised once from the current
// time, so sequences differ between runs. The seed is shared mutable state:
// concurrent calls from several threads would race on it.
size_t Random(size_t max)
{
	static unsigned int seed = static_cast<unsigned int>(::time(nullptr));
	if (max == 0)
		return 0;
	return ::rand_r(&seed) % max;
}
51
52 constexpr char SERVICE_SOCKET_TEST[] = "/tmp/.central-key-manager-test.sock";
53 constexpr CKM::InterfaceID SOCKET_ID_TEST = 42;
54 constexpr std::chrono::seconds CV_TIMEOUT(10);
55
56 struct TestSocketManager final : public SocketManager {
57         size_t TimeoutQueueSize() const { return m_timeoutQueue.size(); }
58 };
59
60 #define THREAD_REQUIRE_MESSAGE(test, message)                                      \
61         do {                                                                           \
62                 if (!(test)) {                                                             \
63                         std::ostringstream os;                                                 \
64                         os << __FILE__ << ":" << __LINE__ << " " << #test << " " << message;   \
65                         throw std::runtime_error(os.str());                                    \
66                 }                                                                          \
67         } while (0)
68
69 #define THREAD_REQUIRE(test) THREAD_REQUIRE_MESSAGE(test, "")
70
71 class NoOpService : public GenericSocketService {
72         void Event(const AcceptEvent &) override {}
73         void Event(const WriteEvent &) override {}
74         void Event(const ReadEvent &) override {}
75         void Event(const CloseEvent &) override {}
76         void Event(const SecurityEvent &) override {}
77
78         void Start() override {}
79         void Stop() override {}
80 };
81
82 class TestConnection final : public ServiceConnection {
83 public:
84         explicit TestConnection(const std::string& socketPath) :
85                 ServiceConnection(socketPath.c_str()),
86                 m_id(m_counter++) {
87                 auto ret = prepareConnection();
88                 BOOST_REQUIRE_MESSAGE(ret == CKM_API_SUCCESS, "ret = " << ret);
89         }
90
91         int Send(const CKM::RawBuffer &send_buf) {
92                 m_sent += send_buf.size();
93                 return ServiceConnection::send(SerializeMessage(send_buf));
94         }
95
96         template <typename T>
97         void Receive(const T& logReceived) {
98                 if (m_sent == 0) {
99                         // expect timeout
100                         auto ret = m_socket.waitForSocket(POLLIN, 100);
101                         BOOST_REQUIRE_MESSAGE(ret == 0, "ret = " << ret);
102                         logReceived(0);
103                         return;
104                 }
105
106                 int ret = ServiceConnection::receive(m_recv);
107                 BOOST_REQUIRE_MESSAGE(ret == CKM_API_SUCCESS, "ret = " << ret);
108                 BOOST_REQUIRE(m_recv.Ready());
109                 while (m_recv.Ready()) {
110                         RawBuffer tmp;
111                         m_recv.Deserialize(tmp);
112                         const auto size = tmp.size();
113                         BOOST_REQUIRE_MESSAGE(size <= m_sent, size << ">" << m_sent);
114                         logReceived(size);
115                         m_sent -= size;
116                 }
117         }
118
119         size_t GetId() const { return m_id; }
120
121 private:
122         size_t m_sent = 0;
123         size_t m_id;
124         MessageBuffer m_recv;
125         static inline size_t m_counter = 0;
126 };
127
128 class SocketManagerLoop final {
129 public:
130         explicit SocketManagerLoop(SocketManager& manager) :
131                 m_manager(manager),
132                 m_thread([&]{
133                         try {
134                                 manager.MainLoop();
135                         } catch (const std::exception& e) {
136                                 m_exception = true;
137                                 m_what = e.what();
138                         } catch (...) {
139                                 m_exception = true;
140                         }
141                 })
142         {
143         }
144
145         ~SocketManagerLoop() {
146                 m_manager.MainLoopStop();
147                 m_thread.join();
148
149                 BOOST_CHECK_MESSAGE(!m_exception, m_what);
150         }
151
152 private:
153         bool m_exception = false;
154         std::string m_what = "Unknown exception";
155         SocketManager& m_manager;
156         std::thread m_thread;
157 };
158
159 std::string Id2SockPath(int id) {
160         return std::string(SERVICE_SOCKET_TEST) + std::to_string(id);
161 }
162
163 void unlinkIfExists(const char* path) {
164         int ret = unlink(path);
165         int err = errno;
166         BOOST_REQUIRE(ret == 0 || (ret == -1 && err == ENOENT));
167 }
168
169 } // namespace
170
171 BOOST_AUTO_TEST_SUITE(SOCKET_MANAGER_TEST)
172
173 POSITIVE_TEST_CASE(StressTestGrowingTimeoutQueue)
174 {
175         constexpr unsigned INITIAL_CONNECTIONS = 20;
176         constexpr unsigned REPEATS = 100000;
177         constexpr auto INTERVAL = REPEATS/10;
178
179         class TestService final : public NoOpService {
180         public:
181                 ServiceDescriptionVector GetServiceDescription() override {
182                         return ServiceDescriptionVector {
183                                 {SERVICE_SOCKET_TEST, "", SOCKET_ID_TEST}
184                         };
185                 }
186                 void Event(const AcceptEvent &e) override {
187                         std::unique_lock<std::mutex> lock(m_mutex);
188
189                         THREAD_REQUIRE_MESSAGE(m_connections.empty() || m_connections.back().client != -1,
190                                                "Unexpected server entry waiting for client match " <<
191                                                m_connections.back().server);
192
193                         m_connections.push_back({-1 , e.connectionID.sock});
194
195                         LogDebug("AcceptEvent. Added: ? <=>" << e.connectionID.sock);
196
197                         CompareSizes();
198
199                         lock.unlock();
200                         m_cv.notify_one();
201                 }
202                 void Event(const CloseEvent &e) override {
203                         std::unique_lock<std::mutex> lock(m_mutex);
204                         THREAD_REQUIRE(!m_connections.empty());
205
206                         auto serverMatch = [&](const SocketPair& pair){
207                                 return pair.server == e.connectionID.sock;
208                         };
209                         auto it = std::find_if(m_connections.begin(), m_connections.end(), serverMatch);
210
211                         THREAD_REQUIRE_MESSAGE(it != m_connections.end(),
212                                                "Can't find connection for server socket = " <<
213                                                e.connectionID.sock);
214
215                         LogDebug("CloseEvent. Removing: " << it->client << "<=>" << it->server);
216                         THREAD_REQUIRE(it->client != -1);
217
218                         m_connections.erase(it);
219
220                         CompareSizes();
221
222                         lock.unlock();
223                         m_cv.notify_one();
224                 }
225
226                 void ConnectAndWait(SockRAII& client) {
227                         std::unique_lock<std::mutex> lock(m_mutex);
228
229                         THREAD_REQUIRE_MESSAGE(m_connections.empty() || m_connections.back().client != -1,
230                                                "Unexpected server entry waiting for client match " <<
231                                                m_connections.back().server);
232
233                         int ret = client.connect(GetServiceDescription()[0].serviceHandlerPath.c_str());
234                         BOOST_REQUIRE(ret == CKM_API_SUCCESS);
235
236                         LogDebug("Connected. Waiting for AcceptEvent for: " << client.get() << "<=> ?");
237
238                         BOOST_REQUIRE(m_cv.wait_for(lock, CV_TIMEOUT, [&]{ return AcceptEventArrived(); }));
239
240                         m_connections.back().client = client.get();
241
242                         LogDebug("Accepted. Matched client & server: " << m_connections.back().client <<
243                                  "<=>" << m_connections.back().server);
244                 }
245
246                 void DisconnectAndWait(SockRAII& client) {
247                         int sock = client.get();
248                         client.disconnect();
249
250                         LogDebug("Disconnected. Waiting for CloseEvent for: " << sock << "<=> ?");
251
252                         std::unique_lock<std::mutex> lock(m_mutex);
253                         BOOST_REQUIRE(m_cv.wait_for(lock, CV_TIMEOUT, [&]{ return ClientAbsent(sock); }));
254                 }
255
256                 void WaitForRemainingClosures() {
257                         std::unique_lock<std::mutex> lock(m_mutex);
258                         if (!m_connections.empty())
259                                 BOOST_TEST_MESSAGE("Waiting for remaining " << m_connections.size() <<
260                                                    " to close.");
261
262                         BOOST_REQUIRE(m_cv.wait_for(lock,
263                                                     std::chrono::seconds(OVERRIDE_SOCKET_TIMEOUT + 2),
264                                                     [&]{ return m_connections.empty(); }));
265
266                         CompareSizes();
267                 }
268
269         private:
270                 bool ClientAbsent(int client) const {
271                         auto it = std::find_if(m_connections.begin(),
272                                                m_connections.end(), [&](const SocketPair& pair){
273                                 return pair.client == client;
274                         });
275                         return it == m_connections.end();
276                 }
277
278                 bool AcceptEventArrived() const {
279                         return !m_connections.empty() && m_connections.back().client == -1;
280                 }
281
282                 void CompareSizes() const {
283                         auto manager = static_cast<TestSocketManager*>(m_serviceManager);
284                         THREAD_REQUIRE(m_connections.size() == manager->TimeoutQueueSize());
285                 }
286
287                 std::mutex m_mutex;
288                 struct SocketPair {
289                         int client;
290                         int server;
291                 };
292                 std::vector<SocketPair> m_connections;
293                 std::condition_variable m_cv;
294         };
295
296         unlinkIfExists(SERVICE_SOCKET_TEST);
297
298         TestSocketManager manager;
299         auto service = new TestService();
300         manager.RegisterSocketService(service);
301
302         SocketManagerLoop loop(manager);
303
304         {
305                 SockRAII socket[INITIAL_CONNECTIONS];
306                 for (unsigned i=0;i<INITIAL_CONNECTIONS;i++)
307                         service->ConnectAndWait(socket[i]);
308
309                 BOOST_REQUIRE(manager.TimeoutQueueSize() == INITIAL_CONNECTIONS);
310
311                 SockRAII socket2;
312                 for(unsigned i=0;i<REPEATS;i++) {
313                         service->ConnectAndWait(socket2);
314                         service->DisconnectAndWait(socket2);
315
316                         if ((i + 1) % INTERVAL == 0)
317                                 BOOST_TEST_MESSAGE("Creating connections: " << i + 1 << "/" << REPEATS);
318                 }
319
320                 // wait for remaining connections to close if any
321                 service->WaitForRemainingClosures();
322         }
323 }
324
325 POSITIVE_TEST_CASE(StressTestRandomSocketEvents)
326 {
327         // Too many services or connections may trigger server side timeouts (OVERRIDE_SOCKET_TIMEOUT)
328         constexpr int SERVICES = 4;
329         constexpr int INTERVAL = 1000;
330         constexpr int REPEATS = 10000;
331         constexpr int MAX_CONNECTIONS = 4;
332         // client and server read 2048B and 4096B at once respectively
333         constexpr size_t MAX_BUF_SIZE = 5000;
334
335         enum Event {
336                 CONNECT,
337                 DISCONNECT,
338                 SEND,
339                 RECEIVE,
340
341                 CNT
342         };
343
344         class TestService final : public NoOpService {
345         public:
346                 explicit TestService(int id) :
347                         m_desc({{Id2SockPath(id).c_str(), "", SOCKET_ID_TEST + id}}),
348                         m_id(id) {
349
350                         unlinkIfExists(GetSocketPath().c_str());
351                 }
352
353                 ServiceDescriptionVector GetServiceDescription() override { return m_desc; }
354
355                 void Event(const ReadEvent &e) override {
356                         LogDebug(e.connectionID.sock << ":" << e.connectionID.counter << " Received " <<
357                                  e.rawBuffer.size() << "B");
358                         m_serviceManager->Write(e.connectionID, e.rawBuffer);
359                 }
360
361                 size_t GetConnectionCount() const { return m_connections.size(); }
362
363                 void AddConnection() {
364                         m_connections.emplace_back(new TestConnection(GetSocketPath()));
365                         LogDebug(Prefix(m_connections.back()->GetId()) << "Connected");
366                 }
367
368                 void Disconnect(size_t idx) {
369                         auto it = m_connections.begin() + idx;
370                         auto cid = (*it)->GetId();
371                         if (idx != m_connections.size() - 1)
372                                 *it = std::move(m_connections.back());
373                         m_connections.pop_back();
374                         LogDebug(Prefix(cid) << "Disconnected");
375                 }
376
377                 void Send(size_t idx) {
378                         auto buffer = createRandom(Random(MAX_BUF_SIZE) + 1);
379                         auto& conn = m_connections.at(idx);
380                         auto ret = conn->Send(buffer);
381                         BOOST_REQUIRE_MESSAGE(ret == CKM_API_SUCCESS, "ret = " << ret);
382                         LogDebug(Prefix(conn->GetId())<< "Sent " << buffer.size() << "B");
383                 }
384
385                 void Receive(size_t idx) {
386                         auto& conn = m_connections.at(idx);
387                         conn->Receive([&](const size_t received) {
388                                 LogDebug(Prefix(conn->GetId()) << "Received " << received << "B");
389                         });
390                 }
391
392         private:
393                 const std::string& GetSocketPath() const {
394                         return m_desc.at(0).serviceHandlerPath;
395                 }
396
397                 std::string Prefix(size_t idx) const {
398                         return std::string(" ") + std::to_string(m_id) + ":" + std::to_string(idx) + " ";
399                 }
400
401                 ServiceDescriptionVector m_desc;
402                 int m_id;
403                 std::vector<std::unique_ptr<TestConnection>> m_connections;
404         };
405
406         SocketManager manager;
407         TestService* services[SERVICES];
408
409         for (int i = 0;i<SERVICES;i++) {
410                 services[i] = new TestService(i);
411                 manager.RegisterSocketService(services[i]);
412         }
413
414         SocketManagerLoop loop(manager);
415
416         for (unsigned i = 0;i < REPEATS; i++) {
417                 // random service
418                 auto service = services[Random(SERVICES)];
419
420                 // always connect if there are no active connections
421                 auto eIdx = CONNECT;
422                 auto cIdx = 0;
423                 size_t cCnt = service->GetConnectionCount();
424                 if (cCnt > 0) {
425                         cIdx = Random(cCnt);
426                         eIdx = static_cast<Event>(Random(CNT));
427                         if (eIdx == CONNECT && cCnt == MAX_CONNECTIONS)
428                                 eIdx = SEND; // don't connect if there are too many
429                 }
430
431                 switch (eIdx) {
432                 case CONNECT:
433                         service->AddConnection();
434                         break;
435
436                 case DISCONNECT:
437                         service->Disconnect(cIdx);
438                         break;
439
440                 case SEND:
441                         service->Send(cIdx);
442                         break;
443
444                 case RECEIVE:
445                         service->Receive(cIdx);
446                         break;
447
448                 default:
449                         BOOST_FAIL("Unexpected event");
450                 }
451
452                 if ((i + 1) % INTERVAL == 0)
453                         BOOST_TEST_MESSAGE("Executing random socket actions: " << i + 1 << "/" << REPEATS);
454         }
455 }
456
457 BOOST_AUTO_TEST_SUITE_END()