From 837b8f4f3fd8c66c19ef935110c64bffcad795e6 Mon Sep 17 00:00:00 2001 From: Youngjae Shin Date: Thu, 4 May 2023 13:14:05 +0900 Subject: [PATCH] support for multiple subscribing to same topic --- modules/tcp/Module.cc | 111 ++++++++++++++++++++++++++++++++----------------- modules/tcp/Module.h | 9 ++-- tests/AITT_TCP_test.cc | 53 +++++++++++++++++++++++ tests/AittTests.h | 12 ++++++ 4 files changed, 144 insertions(+), 41 deletions(-) diff --git a/modules/tcp/Module.cc b/modules/tcp/Module.cc index 2520410..f6bc488 100644 --- a/modules/tcp/Module.cc +++ b/modules/tcp/Module.cc @@ -18,6 +18,7 @@ #include #include +#include #include #include "aitt_internal.h" @@ -170,48 +171,76 @@ void Module::Publish(const std::string &topic, const void *data, const int datal void *Module::Subscribe(const std::string &topic, const AittTransport::SubscribeCallback &cb, void *cbdata, AittQoS qos) { - std::unique_ptr tcpServer; - - unsigned short port = 0; - tcpServer = std::unique_ptr(new TCP::Server("0.0.0.0", port, secure)); - TCPServerData *listen_info = new TCPServerData; - listen_info->impl = this; - listen_info->cb = cb; - listen_info->cbdata = cbdata; - listen_info->topic = topic; - auto handle = tcpServer->GetHandle(); - - main_loop.AddWatch(handle, AcceptConnection, listen_info); - - { - std::lock_guard autoLock(subscribeTableLock); - subscribeTable.insert(SubscribeMap::value_type(topic, std::move(tcpServer))); + TCPServerData *listen_info; + std::unique_ptr cb_info(new Subscribe_CB_Info(cb, cbdata)); + Subscribe_CB_Info *info_ptr = cb_info.get(); + + std::lock_guard lock_from_here(subscribeTableLock); + auto it = std::find_if(subscribeTable.begin(), subscribeTable.end(), + [&](const SubscribeMap::value_type &entry) { return entry.first->topic == topic; }); + if (it != subscribeTable.end()) { + listen_info = it->first; + listen_info->cb_list.push_back(std::move(cb_info)); + } else { + unsigned short port = 0; + std::unique_ptr tcpServer(new TCP::Server("0.0.0.0", port, secure)); + + listen_info = new TCPServerData; + listen_info->impl = this; + listen_info->cb_list.push_back(std::move(cb_info)); + listen_info->topic = topic; + + main_loop.AddWatch(tcpServer->GetHandle(), AcceptConnection, listen_info); + + auto ret = + subscribeTable.insert(SubscribeMap::value_type(listen_info, std::move(tcpServer))); + if (false == ret.second) { + ERR("insert(%s) Fail", topic.c_str()); + throw std::runtime_error("insert() Fail: " + listen_info->topic); + } UpdateDiscoveryMsg(); } + subscribe_handles.insert(SubscribeHandles::value_type(info_ptr, listen_info)); - return reinterpret_cast(handle); + return info_ptr; } void *Module::Unsubscribe(void *handlePtr) { std::lock_guard autoLock(subscribeTableLock); - int handle = static_cast(reinterpret_cast(handlePtr)); - TCPServerData *listen_info = dynamic_cast(main_loop.RemoveWatch(handle)); - if (!listen_info) + auto handle_it = subscribe_handles.find(static_cast(handlePtr)); + if (handle_it == subscribe_handles.end()) { + ERR("Unknown handle(%p)", handlePtr); return nullptr; + } + + TCPServerData *listen_info = handle_it->second; + void *cbdata = handle_it->first->second; + subscribe_handles.erase(handle_it); + + if (1 < listen_info->cb_list.size()) { + auto cb_it = std::find_if(listen_info->cb_list.begin(), listen_info->cb_list.end(), + [&](const std::unique_ptr &cb_info) { + return cb_info.get() == handlePtr; + }); + if (cb_it != listen_info->cb_list.end()) + listen_info->cb_list.erase(cb_it); + else + throw std::runtime_error("Invalid Callback Info"); + return cbdata; + } - void *cbdata = listen_info->cbdata; for (auto fd : listen_info->client_list) { TCPData *tcp_data = dynamic_cast(main_loop.RemoveWatch(fd)); delete tcp_data; } listen_info->client_list.clear(); - auto it = subscribeTable.find(listen_info->topic); + auto it = subscribeTable.find(listen_info); if (it == subscribeTable.end()) - throw std::runtime_error("Service is not registered: " + listen_info->topic); - + throw std::runtime_error("Not subscribed topic: " + listen_info->topic); + main_loop.RemoveWatch(it->second->GetHandle()); subscribeTable.erase(it); delete listen_info; @@ -316,20 +345,19 @@ void Module::DiscoveryMessageCallback(const std::string &clientId, const std::st } } +// map { +// "host": "192.168.1.11", +// "$topic": {port, key, iv} +// } void Module::UpdateDiscoveryMsg() { flexbuffers::Builder fbb; fbb.Map([this, &fbb]() { fbb.String("host", ip); - // SubscribeTable - // map { - // "/customTopic/mytopic": $serverHandle, - // ... - // } for (auto it = subscribeTable.begin(); it != subscribeTable.end(); ++it) { if (it->second) { - fbb.Vector(it->first.c_str(), [&]() { + fbb.Vector(it->first->topic.c_str(), [&]() { fbb.UInt(it->second->GetPort()); if (secure) { fbb.Blob(it->second->GetCryptoKey(), AITT_TCP_ENCRYPTOR_KEY_LEN); @@ -339,7 +367,7 @@ void Module::UpdateDiscoveryMsg() } else { // this is an error case TCP::ConnectInfo info; - fbb.Vector(it->first.c_str(), [&]() { fbb.UInt(it->second->GetPort()); }); + fbb.Vector(it->first->topic.c_str(), [&]() { fbb.UInt(it->second->GetPort()); }); } } }); @@ -383,8 +411,15 @@ int Module::ReceiveData(MainLoopHandler::MainLoopResult result, int handle, return AITT_LOOP_EVENT_CONTINUE; } - auto callback = parent_info->cb; - callback(&msg_info, msg, szmsg, parent_info->cbdata); + std::vector cb_list; + { + std::lock_guard autoLock(impl->subscribeTableLock); + std::transform(parent_info->cb_list.begin(), parent_info->cb_list.end(), + std::back_inserter(cb_list), + [](std::unique_ptr const &it) { return *it; }); + } + for (auto const &it : cb_list) + it.first(&msg_info, msg, szmsg, it.second); free(msg); return AITT_LOOP_EVENT_CONTINUE; @@ -454,7 +489,7 @@ int Module::AcceptConnection(MainLoopHandler::MainLoopResult result, int handle, { std::lock_guard autoLock(impl->subscribeTableLock); - auto clientIt = impl->subscribeTable.find(listen_info->topic); + auto clientIt = impl->subscribeTable.find(listen_info); if (clientIt == impl->subscribeTable.end()) return AITT_LOOP_EVENT_REMOVE; @@ -469,11 +504,11 @@ int Module::AcceptConnection(MainLoopHandler::MainLoopResult result, int handle, int client_handle = client->GetHandle(); listen_info->client_list.push_back(client_handle); - TCPData *ecd = new TCPData; - ecd->parent = listen_info; - ecd->client = std::move(client); + TCPData *tcp_data = new TCPData; + tcp_data->parent = listen_info; + tcp_data->client = std::move(client); - impl->main_loop.AddWatch(client_handle, ReceiveData, ecd); + impl->main_loop.AddWatch(client_handle, ReceiveData, tcp_data); return AITT_LOOP_EVENT_CONTINUE; } diff --git a/modules/tcp/Module.h b/modules/tcp/Module.h index a019e6e..fccdf25 100644 --- a/modules/tcp/Module.h +++ b/modules/tcp/Module.h @@ -51,10 +51,11 @@ class Module : public AittTransport { void SendReply(AittMsg *msg, const void *data, const int datalen, AittQoS qos, bool retain); private: + using Subscribe_CB_Info = std::pair; + struct TCPServerData : public MainLoopHandler::MainLoopData { Module *impl; - SubscribeCallback cb; - void *cbdata; + std::vector> cb_list; std::string topic; std::vector client_list; }; @@ -69,7 +70,8 @@ class Module : public AittTransport { // "/customTopic/mytopic": $serverHandle, // ... // } - using SubscribeMap = std::map>; + using SubscribeMap = std::map>; + using SubscribeHandles = std::map; // ClientTable // map { @@ -129,6 +131,7 @@ class Module : public AittTransport { PublishMap publishTable; std::mutex publishTableLock; SubscribeMap subscribeTable; + SubscribeHandles subscribe_handles; std::mutex subscribeTableLock; ClientMap clientTable; std::mutex clientTableLock; diff --git a/tests/AITT_TCP_test.cc b/tests/AITT_TCP_test.cc index 211f1f9..06f039b 100644 --- a/tests/AITT_TCP_test.cc +++ b/tests/AITT_TCP_test.cc @@ -88,6 +88,49 @@ class AITTTCPTest : public testing::Test, public AittTests { FAIL() << "Unexpected exception: " << e.what(); } } + void TCP_SubscribeSameTopicTwiceTemplate(AittProtocol protocol) + { + try { + AITT aitt(clientId, LOCAL_IP); + aitt.Connect(); + + aitt.Subscribe( + testTopic, + [&](AittMsg *handle, const void *msg, const int szmsg, void *cbdata) -> void { + AITTTCPTest *test = static_cast(cbdata); + test->ToggleReady(); + }, + static_cast(this), protocol); + aitt.Subscribe( + testTopic, + [&](AittMsg *handle, const void *msg, const int szmsg, void *cbdata) -> void { + AITTTCPTest *test = static_cast(cbdata); + test->ToggleReady2(); + }, + static_cast(this), protocol); + + usleep(100 * SLEEP_MS); + /* + while (aitt.CountSubscriber(testTopic, protocol) == 2) { + usleep(SLEEP_10MS); + } + */ + + aitt.Publish(testTopic, TEST_MSG, sizeof(TEST_MSG), protocol); + + mainLoop.AddTimeout(CHECK_INTERVAL, + [&](MainLoopHandler::MainLoopResult result, int fd, + MainLoopHandler::MainLoopData *data) -> int { + return ReadyAllCheck(static_cast(this)); + }); + IterateEventLoop(); + + ASSERT_TRUE(ready); + ASSERT_TRUE(ready2); + } catch (std::exception &e) { + FAIL() << "Unexpected exception: " << e.what(); + } + } }; TEST_F(AITTTCPTest, TCP_Wildcard_single_Anytime) @@ -168,3 +211,13 @@ TEST_F(AITTTCPTest, SECURE_TCP_various_msg_Anytime) FAIL() << "Unexpected exception: " << e.what(); } } + +TEST_F(AITTTCPTest, TCP_Subscribe_Same_Topic_twice_Anytime) +{ + TCP_SubscribeSameTopicTwiceTemplate(AITT_TYPE_TCP); +} + +TEST_F(AITTTCPTest, Secure_TCP_Subscribe_Same_Topic_twice_Anytime) +{ + TCP_SubscribeSameTopicTwiceTemplate(AITT_TYPE_TCP_SECURE); +} diff --git a/tests/AittTests.h b/tests/AittTests.h index 97ee8ce..c867928 100644 --- a/tests/AittTests.h +++ b/tests/AittTests.h @@ -66,6 +66,18 @@ class AittTests { return AITT_LOOP_EVENT_CONTINUE; } + int ReadyAllCheck(void *data) + { + AittTests *test = static_cast(data); + + if (test->ready && test->ready2) { + test->StopEventLoop(); + return AITT_LOOP_EVENT_REMOVE; + } + + return AITT_LOOP_EVENT_CONTINUE; + } + void StopEventLoop(void) { mainLoop.Quit(); } void IterateEventLoop(void) -- 2.7.4