implement request/response pattern on TCP
authorYoungjae Shin <yj99.shin@samsung.com>
Thu, 19 Jan 2023 08:17:10 +0000 (17:17 +0900)
committerYoungjae Shin <yj99.shin@samsung.com>
Thu, 9 Mar 2023 02:19:50 +0000 (11:19 +0900)
- revise AittMsg also

common/AittTransport.h
modules/tcp/Module.cc
modules/tcp/Module.h
src/AITTImpl.cc
src/NullTransport.cc
src/NullTransport.h
tests/RequestResponse_test.cc
tools/FlexbufPrinter.cc
tools/FlexbufPrinter.h

index 0c03f5d..22fbe62 100644 (file)
@@ -42,13 +42,15 @@ class AittTransport {
     virtual ~AittTransport(void) = default;
 
     virtual void Publish(const std::string &topic, const void *data, const int datalen,
-          const std::string &correlation, AittQoS qos = AITT_QOS_AT_MOST_ONCE,
-          bool retain = false) = 0;
-    virtual void Publish(const std::string &topic, const void *data, const int datalen,
           AittQoS qos = AITT_QOS_AT_MOST_ONCE, bool retain = false) = 0;
     virtual void *Subscribe(const std::string &topic, const SubscribeCallback &cb,
           void *cbdata = nullptr, AittQoS qos = AITT_QOS_AT_MOST_ONCE) = 0;
     virtual void *Unsubscribe(void *handle) = 0;
+    virtual void PublishWithReply(const std::string &topic, const void *data, const int datalen,
+          AittQoS qos, bool retain, const std::string &reply_topic,
+          const std::string &correlation) = 0;
+    virtual void SendReply(AittMsg *msg, const void *data, const int datalen, AittQoS qos,
+          bool retain) = 0;
 
     AittProtocol GetProtocol() { return protocol; }
 
index 407d92d..70436d0 100644 (file)
@@ -61,8 +61,8 @@ void Module::ThreadMain(void)
     main_loop.Run();
 }
 
-void Module::Publish(const std::string &topic, const void *data, const int datalen,
-      const std::string &correlation, AittQoS qos, bool retain)
+void Module::PublishFull(const AittMsg &msg, const void *data, const int datalen, AittQoS qos,
+      bool retain, bool is_reply)
 {
     RET_IF(datalen < 0);
 
@@ -81,7 +81,7 @@ void Module::Publish(const std::string &topic, const void *data, const int datal
     std::lock_guard<std::mutex> auto_lock_publish(publishTableLock);
     for (PublishMap::iterator it = publishTable.begin(); it != publishTable.end(); ++it) {
         // NOTE: Find entries that have matched with the given topic
-        if (!discovery.CompareTopic(it->first, topic))
+        if (!discovery.CompareTopic(it->first, is_reply ? msg.GetResponseTopic() : msg.GetTopic()))
             continue;
 
         for (HostMap::iterator hostIt = it->second.begin(); hostIt != it->second.end(); ++hostIt) {
@@ -122,10 +122,12 @@ void Module::Publish(const std::string &topic, const void *data, const int datal
                 }
 
                 try {
-                    int32_t length = topic.length();
-                    portIt->second->SendSizedData(topic.c_str(), length);
-                    length = datalen;
-                    portIt->second->SendSizedData(data, length);
+                    flexbuffers::Builder fbb;
+                    PackMsgInfo(fbb, msg, is_reply);
+                    auto buffer = fbb.GetBuffer();
+                    portIt->second->SendSizedData(buffer.data(), buffer.size());
+
+                    portIt->second->SendSizedData(data, datalen);
                 } catch (std::exception &e) {
                     ERR("An exception(%s) occurs during Send().", e.what());
                 }
@@ -134,10 +136,35 @@ void Module::Publish(const std::string &topic, const void *data, const int datal
     }      // publishTable
 }
 
+void Module::PackMsgInfo(flexbuffers::Builder &fbb, const AittMsg &msg, bool is_reply)
+{
+    fbb.Map([&]() {
+        if (is_reply) {
+            if (!msg.GetResponseTopic().empty())
+                fbb.String("topic", msg.GetResponseTopic().c_str());
+        } else {
+            if (!msg.GetTopic().empty())
+                fbb.String("topic", msg.GetTopic().c_str());
+            if (!msg.GetResponseTopic().empty())
+                fbb.String("reply_topic", msg.GetResponseTopic().c_str());
+        }
+        if (!msg.GetCorrelation().empty())
+            fbb.String("correlation", msg.GetCorrelation().c_str());
+        if (msg.GetSequence() != 0)
+            fbb.UInt("sequence", msg.GetSequence());
+        if (msg.IsEndSequence())
+            fbb.Bool("end_sequence", msg.IsEndSequence());
+    });
+
+    fbb.Finish();
+}
+
 void Module::Publish(const std::string &topic, const void *data, const int datalen, AittQoS qos,
       bool retain)
 {
-    Publish(topic, data, datalen, std::string(), qos, retain);
+    AittMsg msg;
+    msg.SetTopic(topic);
+    PublishFull(msg, data, datalen, qos, retain);
 }
 
 void *Module::Subscribe(const std::string &topic, const AittTransport::SubscribeCallback &cb,
@@ -193,6 +220,26 @@ void *Module::Unsubscribe(void *handlePtr)
     return cbdata;
 }
 
+void Module::PublishWithReply(const std::string &topic, const void *data, const int datalen,
+      AittQoS qos, bool retain, const std::string &reply_topic, const std::string &correlation)
+{
+    AittMsg msg;
+    msg.SetTopic(topic);
+    msg.SetResponseTopic(reply_topic);
+    msg.SetCorrelation(correlation);
+    PublishFull(msg, data, datalen, qos, retain);
+}
+
+void Module::SendReply(AittMsg *msg, const void *data, const int datalen, AittQoS qos, bool retain)
+{
+    if (msg == nullptr) {
+        ERR("Invalid message(msg is nullptr)");
+        throw std::runtime_error("Invalid message");
+    }
+
+    PublishFull(*msg, data, datalen, qos, retain, true);
+}
+
 void Module::DiscoveryMessageCallback(const std::string &clientId, const std::string &status,
       const void *msg, const int szmsg)
 {
@@ -323,12 +370,11 @@ int Module::ReceiveData(MainLoopHandler::MainLoopResult result, int handle,
     AittMsg msg_info;
 
     try {
-        topic = impl->GetTopicName(tcp_data);
-        if (topic.empty()) {
+        impl->GetMsgInfo(msg_info, tcp_data);
+        if (msg_info.GetTopic().empty()) {
             ERR("A topic is empty.");
             return AITT_LOOP_EVENT_CONTINUE;
         }
-               msg_info.SetTopic(topic);
 
         szmsg = tcp_data->client->RecvSizedData((void **)&msg);
         if (szmsg < 0) {
@@ -364,26 +410,40 @@ int Module::HandleClientDisconnect(int handle)
     return AITT_LOOP_EVENT_REMOVE;
 }
 
-std::string Module::GetTopicName(Module::TCPData *tcp_data)
+void Module::GetMsgInfo(AittMsg &msg, Module::TCPData *tcp_data)
 {
-    int32_t topic_length = 0;
-    void *topic_data = nullptr;
-    topic_length = tcp_data->client->RecvSizedData(&topic_data);
-    if (topic_length < 0) {
+    int32_t info_length = 0;
+    void *msg_info = nullptr;
+    info_length = tcp_data->client->RecvSizedData(&msg_info);
+    if (info_length < 0) {
         ERR("Got a disconnection message.");
         HandleClientDisconnect(tcp_data->client->GetHandle());
-        return std::string();
+        return;
     }
-    if (nullptr == topic_data) {
+    if (nullptr == msg_info) {
         ERR("Unknown topic");
-        return std::string();
+        return;
     }
 
-    std::string topic = std::string(static_cast<char *>(topic_data), topic_length);
-    INFO("Complete topic = [%s], topic_len = %d", topic.c_str(), topic_length);
-    free(topic_data);
+    UnpackMsgInfo(msg, msg_info, info_length);
 
-    return topic;
+    free(msg_info);
+}
+
+void Module::UnpackMsgInfo(AittMsg &msg, const void *data, const size_t datalen)
+{
+    auto map = flexbuffers::GetRoot(static_cast<const uint8_t *>(data), datalen).AsMap();
+
+    if (map["topic"].IsString())
+        msg.SetTopic(map["topic"].AsString().str());
+    if (map["reply_topic"].IsString())
+        msg.SetResponseTopic(map["reply_topic"].AsString().str());
+    if (map["correlation"].IsString())
+        msg.SetCorrelation(map["correlation"].AsString().str());
+    if (map["sequence"].IsUInt())
+        msg.SetSequence(map["sequence"].AsUInt64());
+    if (map["end_sequence"].IsBool())
+        msg.SetEndSequence(map["end_sequence"].AsBool());
 }
 
 int Module::AcceptConnection(MainLoopHandler::MainLoopResult result, int handle,
index 028d407..a019e6e 100644 (file)
@@ -17,6 +17,7 @@
 
 #include <AittTransport.h>
 #include <MainLoopHandler.h>
+#include <flatbuffers/flexbuffers.h>
 
 #include <map>
 #include <memory>
@@ -40,14 +41,14 @@ class Module : public AittTransport {
     virtual ~Module(void);
 
     void Publish(const std::string &topic, const void *data, const int datalen,
-          const std::string &correlation, AittQoS qos = AITT_QOS_AT_MOST_ONCE,
-          bool retain = false) override;
-    void Publish(const std::string &topic, const void *data, const int datalen,
           AittQoS qos = AITT_QOS_AT_MOST_ONCE, bool retain = false) override;
 
     void *Subscribe(const std::string &topic, const SubscribeCallback &cb, void *cbdata = nullptr,
           AittQoS qos = AITT_QOS_AT_MOST_ONCE) override;
     void *Unsubscribe(void *handle) override;
+    void PublishWithReply(const std::string &topic, const void *data, const int datalen,
+          AittQoS qos, bool retain, const std::string &reply_topic, const std::string &correlation);
+    void SendReply(AittMsg *msg, const void *data, const int datalen, AittQoS qos, bool retain);
 
   private:
     struct TCPServerData : public MainLoopHandler::MainLoopData {
@@ -105,16 +106,20 @@ class Module : public AittTransport {
 
     static int AcceptConnection(MainLoopHandler::MainLoopResult result, int handle,
           MainLoopHandler::MainLoopData *watchData);
+    void PublishFull(const AittMsg &msg, const void *data, const int datalen,
+          AittQoS qos = AITT_QOS_AT_MOST_ONCE, bool retain = false, bool is_reply = false);
     void DiscoveryMessageCallback(const std::string &clientId, const std::string &status,
           const void *msg, const int szmsg);
     void UpdateDiscoveryMsg();
     static int ReceiveData(MainLoopHandler::MainLoopResult result, int handle,
           MainLoopHandler::MainLoopData *watchData);
     int HandleClientDisconnect(int handle);
-    std::string GetTopicName(TCPData *connect_info);
+    void GetMsgInfo(AittMsg &msg, TCPData *connect_info);
     void ThreadMain(void);
     void UpdatePublishTable(const std::string &topic, const std::string &host,
           const TCP::ConnectInfo &info);
+    void PackMsgInfo(flexbuffers::Builder &fbb, const AittMsg &msg, bool is_reply = false);
+    void UnpackMsgInfo(AittMsg &msg, const void *data, const size_t datalen);
 
     const char *const NAME[2] = {"TCP", "SECURE_TCP"};
     MainLoopHandler main_loop;
index 40a9862..630a694 100644 (file)
@@ -270,15 +270,13 @@ void *AITT::Impl::Unsubscribe(AittSubscribeID subscribe_id)
     return user_data;
 }
 
+// It's not supported with multiple protocols
 int AITT::Impl::PublishWithReply(const std::string &topic, const void *data, const int datalen,
       AittProtocol protocol, AittQoS qos, bool retain, const SubscribeCallback &cb, void *user_data,
       const std::string &correlation)
 {
     std::string replyTopic = topic + RESPONSE_POSTFIX + std::to_string(reply_id++);
 
-    if (protocol != AITT_TYPE_MQTT)
-        return -1;  // not yet support
-
     Subscribe(
           replyTopic,
           [this, cb](AittMsg *sub_msg, const void *sub_data, const int sub_datalen,
@@ -294,7 +292,19 @@ int AITT::Impl::PublishWithReply(const std::string &topic, const void *data, con
           },
           user_data, protocol, qos);
 
-    mq->PublishWithReply(topic, data, datalen, qos, false, replyTopic, correlation);
+    switch (protocol) {
+    case AITT_TYPE_MQTT:
+        mq->PublishWithReply(topic, data, datalen, qos, retain, replyTopic, correlation);
+        break;
+    case AITT_TYPE_TCP:
+    case AITT_TYPE_TCP_SECURE:
+        modules.Get(protocol).PublishWithReply(topic, data, datalen, qos, retain, replyTopic,
+              correlation);
+        break;
+    default:
+        ERR("Unknown AittProtocol(%d)", protocol);
+        return -1;
+    }
     return 0;
 }
 
@@ -370,14 +380,22 @@ void AITT::Impl::SendReply(AittMsg *msg, const void *data, const int datalen, bo
 {
     RET_IF(msg == nullptr);
 
-    if ((msg->GetProtocol() & AITT_TYPE_MQTT) != AITT_TYPE_MQTT)
-        return;  // not yet support
-
     if (end == false || msg->GetSequence())
         msg->IncreaseSequence();
     msg->SetEndSequence(end);
 
-    mq->SendReply(msg, data, datalen, AITT_QOS_AT_MOST_ONCE, false);
+    switch (msg->GetProtocol()) {
+    case AITT_TYPE_MQTT:
+        mq->SendReply(msg, data, datalen, AITT_QOS_AT_MOST_ONCE, false);
+        break;
+    case AITT_TYPE_TCP:
+    case AITT_TYPE_TCP_SECURE:
+        modules.Get(msg->GetProtocol()).SendReply(msg, data, datalen, AITT_QOS_AT_MOST_ONCE, false);
+        break;
+    default:
+        ERR("Unknown AittProtocol(%d)", msg->GetProtocol());
+        break;
+    }
 }
 
 void *AITT::Impl::SubscribeTCP(SubscribeInfo *handle, const std::string &topic,
index 5890ce8..76b7c57 100644 (file)
@@ -23,11 +23,6 @@ NullTransport::NullTransport(AittDiscovery& discovery, const std::string& ip)
 }
 
 void NullTransport::Publish(const std::string& topic, const void* data, const int datalen,
-      const std::string& correlation, AittQoS qos, bool retain)
-{
-}
-
-void NullTransport::Publish(const std::string& topic, const void* data, const int datalen,
       AittQoS qos, bool retain)
 {
 }
@@ -42,3 +37,13 @@ void* NullTransport::Unsubscribe(void* handle)
 {
     return nullptr;
 }
+
+void NullTransport::PublishWithReply(const std::string& topic, const void* data, const int datalen,
+      AittQoS qos, bool retain, const std::string& reply_topic, const std::string& correlation)
+{
+}
+
+void NullTransport::SendReply(AittMsg* msg, const void* data, const int datalen, AittQoS qos,
+      bool retain)
+{
+}
index b773ff1..9937bc3 100644 (file)
@@ -26,14 +26,17 @@ class NullTransport : public AittTransport {
     virtual ~NullTransport(void) = default;
 
     void Publish(const std::string &topic, const void *data, const int datalen,
-          const std::string &correlation, AittQoS qos = AITT_QOS_AT_MOST_ONCE,
-          bool retain = false) override;
-
-    void Publish(const std::string &topic, const void *data, const int datalen,
           AittQoS qos = AITT_QOS_AT_MOST_ONCE, bool retain = false) override;
 
     void *Subscribe(const std::string &topic, const SubscribeCallback &cb, void *cbdata = nullptr,
           AittQoS qos = AITT_QOS_AT_MOST_ONCE) override;
 
     void *Unsubscribe(void *handle) override;
+
+    void PublishWithReply(const std::string &topic, const void *data, const int datalen,
+          AittQoS qos, bool retain, const std::string &reply_topic,
+          const std::string &correlation) override;
+
+    void SendReply(AittMsg *msg, const void *data, const int datalen, AittQoS qos,
+          bool retain) override;
 };
index b777635..a11d108 100644 (file)
@@ -46,6 +46,7 @@ class AITTRRTest : public testing::Test, public AittTests {
     void CheckReplyCallback(bool toggle, bool *reply_ok, AittMsg *msg, const void *data,
           const int datalen, void *cbdata)
     {
+        DBG("CheckReplyCallback: topic(%s)", msg->GetTopic().c_str());
         CheckReply(msg, data, datalen);
         *reply_ok = true;
         if (toggle)
@@ -193,38 +194,51 @@ class AITTRRTest : public testing::Test, public AittTests {
 
 TEST_F(AITTRRTest, RequestResponse_P_Anytime)
 {
-    bool sub_ok, reply_ok;
-    sub_ok = reply_ok = false;
-
-    try {
-        AITT aitt(clientId, LOCAL_IP, AittOption(true, false));
-        aitt.Connect();
+    std::vector<AittProtocol> protocols = {AITT_TYPE_MQTT, AITT_TYPE_TCP, AITT_TYPE_TCP_SECURE};
 
-        aitt.Subscribe(rr_topic.c_str(),
-              [&](AittMsg *msg, const void *data, const int datalen, void *cbdata) {
-                  CheckSubscribe(msg, data, datalen);
-                  aitt.SendReply(msg, reply.c_str(), reply.size());
-                  sub_ok = true;
-              });
+    for (AittProtocol &protocol : protocols) {
+        bool sub_ok, reply_ok;
+        sub_ok = reply_ok = false;
+        ready = false;
 
-        aitt.PublishWithReply(rr_topic.c_str(), message.c_str(), message.size(), AITT_TYPE_MQTT,
-              AITT_QOS_AT_MOST_ONCE, false,
-              std::bind(&AITTRRTest::CheckReplyCallback, GetHandle(), true, &reply_ok,
-                    std::placeholders::_1, std::placeholders::_2, std::placeholders::_3,
-                    std::placeholders::_4),
-              nullptr, correlation);
+        try {
+            AITT aitt(clientId, LOCAL_IP, AittOption(true, false));
+            aitt.Connect();
 
-        mainLoop.AddTimeout(CHECK_INTERVAL,
-              [&](MainLoopHandler::MainLoopResult result, int fd,
-                    MainLoopHandler::MainLoopData *data) -> int {
-                  return ReadyCheck(static_cast<AittTests *>(this));
-              });
-        IterateEventLoop();
+            aitt.Subscribe(
+                  rr_topic.c_str(),
+                  [&](AittMsg *msg, const void *data, const int datalen, void *cbdata) {
+                      DBG("Subscribe Callback");
+                      CheckSubscribe(msg, data, datalen);
+                      usleep(100 * SLEEP_MS);
+                      aitt.SendReply(msg, reply.c_str(), reply.size());
+                      sub_ok = true;
+                  },
+                  nullptr, protocol);
+
+            // Wait a few seconds until the AITT client gets a server list (discover devices)
+            usleep(100 * SLEEP_MS);
+
+            aitt.PublishWithReply(rr_topic.c_str(), message.c_str(), message.size(), protocol,
+                  AITT_QOS_AT_MOST_ONCE, false,
+                  std::bind(&AITTRRTest::CheckReplyCallback, GetHandle(), true, &reply_ok,
+                        std::placeholders::_1, std::placeholders::_2, std::placeholders::_3,
+                        std::placeholders::_4),
+                  nullptr, correlation);
+
+            mainLoop.AddTimeout(CHECK_INTERVAL,
+                  [&](MainLoopHandler::MainLoopResult result, int fd,
+                        MainLoopHandler::MainLoopData *data) -> int {
+                      DBG("Timeout Callback");
+                      return ReadyCheck(static_cast<AittTests *>(this));
+                  });
+            IterateEventLoop();
 
-        EXPECT_TRUE(sub_ok);
-        EXPECT_TRUE(reply_ok);
-    } catch (std::exception &e) {
-        FAIL() << e.what();
+            EXPECT_TRUE(sub_ok);
+            EXPECT_TRUE(reply_ok);
+        } catch (std::exception &e) {
+            FAIL() << e.what();
+        }
     }
 }
 
index 04e9d80..aee5ddc 100644 (file)
 
 #include <iostream>
 
+#include "AittTypes.h"
 #include "aitt_internal.h"
 
-FlexbufPrinter::FlexbufPrinter() : tab(0)
+FlexbufPrinter::FlexbufPrinter() : tab(0), type(AITT_TYPE_UNKNOWN)
 {
 }
 
@@ -54,7 +55,10 @@ void FlexbufPrinter::PrettyMap(const flexbuffers::Reference &data, bool inline_v
     auto keys = map.Keys();
     for (size_t i = 0; i < keys.size(); i++) {
         std::cout << PrettyTab(false) << keys[i].AsKey() << " : ";
+        if (keys[i].AsString().str() == "TCP")
+            type = AITT_TYPE_TCP;
         PrettyParsing(map[keys[i].AsKey()], true);
+        type = AITT_TYPE_UNKNOWN;
     }
 
     tab--;
@@ -77,10 +81,16 @@ void FlexbufPrinter::PrettyVector(const flexbuffers::Reference &data, bool inlin
 void FlexbufPrinter::PrettyBlob(const flexbuffers::Reference &data, bool inline_value)
 {
     auto blob = data.AsBlob();
-    DBG_HEX_DUMP(blob.data(), blob.size());
-    // auto root = flexbuffers::GetRoot(static_cast<const uint8_t *>(blob.data()), blob.size());
-
-    // PrettyParsing(root, true);
+    if (type == AITT_TYPE_TCP) {
+        auto root = flexbuffers::GetRoot(static_cast<const uint8_t *>(blob.data()), blob.size());
+        PrettyParsing(root, true);
+    } else if (type == AITT_TYPE_TCP_SECURE) {
+        auto root = flexbuffers::GetRoot(static_cast<const uint8_t *>(blob.data()), blob.size());
+        PrettyParsing(root, true);
+        type = AITT_TYPE_UNKNOWN;
+    } else {
+        DBG_HEX_DUMP(blob.data(), blob.size());
+    }
 }
 
 void FlexbufPrinter::PrettyParsing(const flexbuffers::Reference &data, bool inline_value)
index 46bfe51..ac77909 100644 (file)
@@ -31,4 +31,5 @@ class FlexbufPrinter {
     void PrettyParsing(const flexbuffers::Reference &data, bool inline_value);
 
     int tab;
+    int type;
 };