support for multiple subscribing to same topic
authorYoungjae Shin <yj99.shin@samsung.com>
Thu, 4 May 2023 04:14:05 +0000 (13:14 +0900)
committerYoungjae Shin <yj99.shin@samsung.com>
Wed, 14 Jun 2023 11:53:03 +0000 (20:53 +0900)
modules/tcp/Module.cc
modules/tcp/Module.h
tests/AITT_TCP_test.cc
tests/AittTests.h

index 2520410..f6bc488 100644 (file)
@@ -18,6 +18,7 @@
 #include <flatbuffers/flexbuffers.h>
 #include <unistd.h>
 
+#include <algorithm>
 #include <random>
 
 #include "aitt_internal.h"
@@ -170,48 +171,76 @@ void Module::Publish(const std::string &topic, const void *data, const int datal
 void *Module::Subscribe(const std::string &topic, const AittTransport::SubscribeCallback &cb,
       void *cbdata, AittQoS qos)
 {
-    std::unique_ptr<TCP::Server> tcpServer;
-
-    unsigned short port = 0;
-    tcpServer = std::unique_ptr<TCP::Server>(new TCP::Server("0.0.0.0", port, secure));
-    TCPServerData *listen_info = new TCPServerData;
-    listen_info->impl = this;
-    listen_info->cb = cb;
-    listen_info->cbdata = cbdata;
-    listen_info->topic = topic;
-    auto handle = tcpServer->GetHandle();
-
-    main_loop.AddWatch(handle, AcceptConnection, listen_info);
-
-    {
-        std::lock_guard<std::mutex> autoLock(subscribeTableLock);
-        subscribeTable.insert(SubscribeMap::value_type(topic, std::move(tcpServer)));
+    TCPServerData *listen_info;
+    std::unique_ptr<Subscribe_CB_Info> cb_info(new Subscribe_CB_Info(cb, cbdata));
+    Subscribe_CB_Info *info_ptr = cb_info.get();
+
+    std::lock_guard<std::mutex> lock_from_here(subscribeTableLock);
+    auto it = std::find_if(subscribeTable.begin(), subscribeTable.end(),
+          [&](const SubscribeMap::value_type &entry) { return entry.first->topic == topic; });
+    if (it != subscribeTable.end()) {
+        listen_info = it->first;
+        listen_info->cb_list.push_back(std::move(cb_info));
+    } else {
+        unsigned short port = 0;
+        std::unique_ptr<TCP::Server> tcpServer(new TCP::Server("0.0.0.0", port, secure));
+
+        listen_info = new TCPServerData;
+        listen_info->impl = this;
+        listen_info->cb_list.push_back(std::move(cb_info));
+        listen_info->topic = topic;
+
+        main_loop.AddWatch(tcpServer->GetHandle(), AcceptConnection, listen_info);
+
+        auto ret =
+              subscribeTable.insert(SubscribeMap::value_type(listen_info, std::move(tcpServer)));
+        if (false == ret.second) {
+            ERR("insert(%s) Fail", topic.c_str());
+            throw std::runtime_error("insert() Fail: " + listen_info->topic);
+        }
         UpdateDiscoveryMsg();
     }
+    subscribe_handles.insert(SubscribeHandles::value_type(info_ptr, listen_info));
 
-    return reinterpret_cast<void *>(handle);
+    return info_ptr;
 }
 
 void *Module::Unsubscribe(void *handlePtr)
 {
     std::lock_guard<std::mutex> autoLock(subscribeTableLock);
 
-    int handle = static_cast<int>(reinterpret_cast<intptr_t>(handlePtr));
-    TCPServerData *listen_info = dynamic_cast<TCPServerData *>(main_loop.RemoveWatch(handle));
-    if (!listen_info)
+    auto handle_it = subscribe_handles.find(static_cast<Subscribe_CB_Info *>(handlePtr));
+    if (handle_it == subscribe_handles.end()) {
+        ERR("Unknown handle(%p)", handlePtr);
         return nullptr;
+    }
+
+    TCPServerData *listen_info = handle_it->second;
+    void *cbdata = handle_it->first->second;
+    subscribe_handles.erase(handle_it);
+
+    if (1 < listen_info->cb_list.size()) {
+        auto cb_it = std::find_if(listen_info->cb_list.begin(), listen_info->cb_list.end(),
+              [&](const std::unique_ptr<Subscribe_CB_Info> &cb_info) {
+                  return cb_info.get() == handlePtr;
+              });
+        if (cb_it != listen_info->cb_list.end())
+            listen_info->cb_list.erase(cb_it);
+        else
+            throw std::runtime_error("Invalid Callback Info");
+        return cbdata;
+    }
 
-    void *cbdata = listen_info->cbdata;
     for (auto fd : listen_info->client_list) {
         TCPData *tcp_data = dynamic_cast<TCPData *>(main_loop.RemoveWatch(fd));
         delete tcp_data;
     }
     listen_info->client_list.clear();
 
-    auto it = subscribeTable.find(listen_info->topic);
+    auto it = subscribeTable.find(listen_info);
     if (it == subscribeTable.end())
-        throw std::runtime_error("Service is not registered: " + listen_info->topic);
-
+        throw std::runtime_error("Not subscribed topic: " + listen_info->topic);
+    main_loop.RemoveWatch(it->second->GetHandle());
     subscribeTable.erase(it);
     delete listen_info;
 
@@ -316,20 +345,19 @@ void Module::DiscoveryMessageCallback(const std::string &clientId, const std::st
     }
 }
 
+// map {
+//   "host": "192.168.1.11",
+//   "$topic": {port, key, iv}
+// }
 void Module::UpdateDiscoveryMsg()
 {
     flexbuffers::Builder fbb;
     fbb.Map([this, &fbb]() {
         fbb.String("host", ip);
 
-        // SubscribeTable
-        // map {
-        //    "/customTopic/mytopic": $serverHandle,
-        //    ...
-        // }
         for (auto it = subscribeTable.begin(); it != subscribeTable.end(); ++it) {
             if (it->second) {
-                fbb.Vector(it->first.c_str(), [&]() {
+                fbb.Vector(it->first->topic.c_str(), [&]() {
                     fbb.UInt(it->second->GetPort());
                     if (secure) {
                         fbb.Blob(it->second->GetCryptoKey(), AITT_TCP_ENCRYPTOR_KEY_LEN);
@@ -339,7 +367,7 @@ void Module::UpdateDiscoveryMsg()
             } else {
                 // this is an error case
                 TCP::ConnectInfo info;
-                fbb.Vector(it->first.c_str(), [&]() { fbb.UInt(it->second->GetPort()); });
+                fbb.Vector(it->first->topic.c_str(), [&]() { fbb.UInt(it->second->GetPort()); });
             }
         }
     });
@@ -383,8 +411,15 @@ int Module::ReceiveData(MainLoopHandler::MainLoopResult result, int handle,
         return AITT_LOOP_EVENT_CONTINUE;
     }
 
-    auto callback = parent_info->cb;
-    callback(&msg_info, msg, szmsg, parent_info->cbdata);
+    std::vector<Subscribe_CB_Info> cb_list;
+    {
+        std::lock_guard<std::mutex> autoLock(impl->subscribeTableLock);
+        std::transform(parent_info->cb_list.begin(), parent_info->cb_list.end(),
+              std::back_inserter(cb_list),
+              [](std::unique_ptr<Subscribe_CB_Info> const &it) { return *it; });
+    }
+    for (auto const &it : cb_list)
+        it.first(&msg_info, msg, szmsg, it.second);
     free(msg);
 
     return AITT_LOOP_EVENT_CONTINUE;
@@ -454,7 +489,7 @@ int Module::AcceptConnection(MainLoopHandler::MainLoopResult result, int handle,
     {
         std::lock_guard<std::mutex> autoLock(impl->subscribeTableLock);
 
-        auto clientIt = impl->subscribeTable.find(listen_info->topic);
+        auto clientIt = impl->subscribeTable.find(listen_info);
         if (clientIt == impl->subscribeTable.end())
             return AITT_LOOP_EVENT_REMOVE;
 
@@ -469,11 +504,11 @@ int Module::AcceptConnection(MainLoopHandler::MainLoopResult result, int handle,
     int client_handle = client->GetHandle();
     listen_info->client_list.push_back(client_handle);
 
-    TCPData *ecd = new TCPData;
-    ecd->parent = listen_info;
-    ecd->client = std::move(client);
+    TCPData *tcp_data = new TCPData;
+    tcp_data->parent = listen_info;
+    tcp_data->client = std::move(client);
 
-    impl->main_loop.AddWatch(client_handle, ReceiveData, ecd);
+    impl->main_loop.AddWatch(client_handle, ReceiveData, tcp_data);
     return AITT_LOOP_EVENT_CONTINUE;
 }
 
index a019e6e..fccdf25 100644 (file)
@@ -51,10 +51,11 @@ class Module : public AittTransport {
     void SendReply(AittMsg *msg, const void *data, const int datalen, AittQoS qos, bool retain);
 
   private:
+    using Subscribe_CB_Info = std::pair<SubscribeCallback, void *>;
+
     struct TCPServerData : public MainLoopHandler::MainLoopData {
         Module *impl;
-        SubscribeCallback cb;
-        void *cbdata;
+        std::vector<std::unique_ptr<Subscribe_CB_Info>> cb_list;
         std::string topic;
         std::vector<int> client_list;
     };
@@ -69,7 +70,8 @@ class Module : public AittTransport {
     //    "/customTopic/mytopic": $serverHandle,
     //    ...
     // }
-    using SubscribeMap = std::map<std::string, std::unique_ptr<TCP::Server>>;
+    using SubscribeMap = std::map<TCPServerData *, std::unique_ptr<TCP::Server>>;
+    using SubscribeHandles = std::map<Subscribe_CB_Info *, TCPServerData *>;
 
     // ClientTable
     // map {
@@ -129,6 +131,7 @@ class Module : public AittTransport {
     PublishMap publishTable;
     std::mutex publishTableLock;
     SubscribeMap subscribeTable;
+    SubscribeHandles subscribe_handles;
     std::mutex subscribeTableLock;
     ClientMap clientTable;
     std::mutex clientTableLock;
index 211f1f9..06f039b 100644 (file)
@@ -88,6 +88,49 @@ class AITTTCPTest : public testing::Test, public AittTests {
             FAIL() << "Unexpected exception: " << e.what();
         }
     }
+    void TCP_SubscribeSameTopicTwiceTemplate(AittProtocol protocol)
+    {
+        try {
+            AITT aitt(clientId, LOCAL_IP);
+            aitt.Connect();
+
+            aitt.Subscribe(
+                  testTopic,
+                  [&](AittMsg *handle, const void *msg, const int szmsg, void *cbdata) -> void {
+                      AITTTCPTest *test = static_cast<AITTTCPTest *>(cbdata);
+                      test->ToggleReady();
+                  },
+                  static_cast<void *>(this), protocol);
+            aitt.Subscribe(
+                  testTopic,
+                  [&](AittMsg *handle, const void *msg, const int szmsg, void *cbdata) -> void {
+                      AITTTCPTest *test = static_cast<AITTTCPTest *>(cbdata);
+                      test->ToggleReady2();
+                  },
+                  static_cast<void *>(this), protocol);
+
+            usleep(100 * SLEEP_MS);
+            /*
+            while (aitt.CountSubscriber(testTopic, protocol) == 2) {
+                usleep(SLEEP_10MS);
+            }
+            */
+
+            aitt.Publish(testTopic, TEST_MSG, sizeof(TEST_MSG), protocol);
+
+            mainLoop.AddTimeout(CHECK_INTERVAL,
+                  [&](MainLoopHandler::MainLoopResult result, int fd,
+                        MainLoopHandler::MainLoopData *data) -> int {
+                      return ReadyAllCheck(static_cast<AittTests *>(this));
+                  });
+            IterateEventLoop();
+
+            ASSERT_TRUE(ready);
+            ASSERT_TRUE(ready2);
+        } catch (std::exception &e) {
+            FAIL() << "Unexpected exception: " << e.what();
+        }
+    }
 };
 
 TEST_F(AITTTCPTest, TCP_Wildcard_single_Anytime)
@@ -168,3 +211,13 @@ TEST_F(AITTTCPTest, SECURE_TCP_various_msg_Anytime)
         FAIL() << "Unexpected exception: " << e.what();
     }
 }
+
+TEST_F(AITTTCPTest, TCP_Subscribe_Same_Topic_twice_Anytime)
+{
+    TCP_SubscribeSameTopicTwiceTemplate(AITT_TYPE_TCP);
+}
+
+TEST_F(AITTTCPTest, Secure_TCP_Subscribe_Same_Topic_twice_Anytime)
+{
+    TCP_SubscribeSameTopicTwiceTemplate(AITT_TYPE_TCP_SECURE);
+}
index 97ee8ce..c867928 100644 (file)
@@ -66,6 +66,18 @@ class AittTests {
         return AITT_LOOP_EVENT_CONTINUE;
     }
 
+    int ReadyAllCheck(void *data)
+    {
+        AittTests *test = static_cast<AittTests *>(data);
+
+        if (test->ready && test->ready2) {
+            test->StopEventLoop();
+            return AITT_LOOP_EVENT_REMOVE;
+        }
+
+        return AITT_LOOP_EVENT_CONTINUE;
+    }
+
     void StopEventLoop(void) { mainLoop.Quit(); }
 
     void IterateEventLoop(void)