check negative sized message
authorYoungjae Shin <yj99.shin@samsung.com>
Thu, 13 Oct 2022 05:56:41 +0000 (14:56 +0900)
committerYoungjae Shin <yj99.shin@samsung.com>
Wed, 9 Nov 2022 08:17:08 +0000 (17:17 +0900)
modules/tcp/Module.cc
modules/tcp/TCP.cc
modules/tcp/TCP.h
src/AITT.cc
tests/AITT_test.cc

index 8476a5f..5cc4791 100644 (file)
@@ -64,6 +64,8 @@ void Module::ThreadMain(void)
 void Module::Publish(const std::string &topic, const void *data, const int datalen,
       const std::string &correlation, AittQoS qos, bool retain)
 {
+    RET_IF(datalen < 0);
+
     // NOTE:
     // Iterate discovered service table
     // PublishMap
@@ -329,9 +331,9 @@ void Module::ReceiveData(MainLoopHandler::MainLoopResult result, int handle,
             return;
         }
 
-        int ret = tcp_data->client->RecvSizedData((void **)&msg, szmsg);
-        if (ret < 0) {
-            ERR("Got a disconnection message.");
+        szmsg = tcp_data->client->RecvSizedData((void **)&msg);
+        if (szmsg < 0) {
+            ERR("Got a disconnection message(%d)", szmsg);
             return impl->HandleClientDisconnect(handle);
         }
     } catch (std::exception &e) {
@@ -367,8 +369,8 @@ std::string Module::GetTopicName(Module::TCPData *tcp_data)
 {
     int32_t topic_length = 0;
     void *topic_data = nullptr;
-    int ret = tcp_data->client->RecvSizedData(&topic_data, topic_length);
-    if (ret < 0) {
+    topic_length = tcp_data->client->RecvSizedData(&topic_data);
+    if (topic_length < 0) {
         ERR("Got a disconnection message.");
         HandleClientDisconnect(tcp_data->client->GetHandle());
         return std::string();
index 26d2c32..d387f08 100644 (file)
@@ -114,7 +114,7 @@ void TCP::SetupOptions(const ConnectInfo &connect_info)
     }
 }
 
-void TCP::Send(const void *data, int32_t &data_size)
+int32_t TCP::Send(const void *data, int32_t data_size)
 {
     int32_t sent = 0;
     while (sent < data_size) {
@@ -126,18 +126,20 @@ void TCP::Send(const void *data, int32_t &data_size)
 
         sent += ret;
     }
-    data_size = sent;
+    return sent;
 }
 
-void TCP::SendSizedData(const void *data, int32_t &szData)
+void TCP::SendSizedData(const void *data, int32_t data_size)
 {
+    RET_IF(data_size < 0);
+
     if (secure)
-        SendSizedDataSecure(data, szData);
+        return SendSizedDataSecure(data, data_size);
     else
-        SendSizedDataNormal(data, szData);
+        return SendSizedDataNormal(data, data_size);
 }
 
-int TCP::Recv(void *data, int32_t &data_size)
+int32_t TCP::Recv(void *data, int32_t data_size)
 {
     int32_t received = 0;
     while (received < data_size) {
@@ -154,23 +156,21 @@ int TCP::Recv(void *data, int32_t &data_size)
         received += ret;
     }
 
-    data_size = received;
-    return 0;
+    return received;
 }
 
-int TCP::RecvSizedData(void **data, int32_t &szData)
+int32_t TCP::RecvSizedData(void **data)
 {
     if (secure)
-        return RecvSizedDataSecure(data, szData);
+        return RecvSizedDataSecure(data);
     else
-        return RecvSizedDataNormal(data, szData);
+        return RecvSizedDataNormal(data);
 }
 
-int TCP::HandleZeroMsg(void **data, int32_t &data_size)
+int32_t TCP::HandleZeroMsg(void **data)
 {
     // distinguish between connection problems and zero-size messages
     INFO("Got a zero-size message.");
-    data_size = 0;
     *data = nullptr;
     return 0;
 }
@@ -205,7 +205,7 @@ unsigned short TCP::GetPort(void)
     return ntohs(addr.sin_port);
 }
 
-void TCP::SendSizedDataNormal(const void *data, int32_t &data_size)
+void TCP::SendSizedDataNormal(const void *data, int32_t data_size)
 {
     int32_t fixed_data_size = data_size;
     if (0 == data_size) {
@@ -219,20 +219,20 @@ void TCP::SendSizedDataNormal(const void *data, int32_t &data_size)
     Send(data, data_size);
 }
 
-int TCP::RecvSizedDataNormal(void **data, int32_t &data_size)
+int32_t TCP::RecvSizedDataNormal(void **data)
 {
-    int ret;
+    int32_t result;
 
     int32_t data_len = 0;
     int32_t size_len = sizeof(data_len);
-    ret = Recv(static_cast<void *>(&data_len), size_len);
-    if (ret < 0) {
-        ERR("Recv() Fail(%d)", ret);
-        return ret;
+    result = Recv(static_cast<void *>(&data_len), size_len);
+    if (result < 0) {
+        ERR("Recv() Fail(%d)", result);
+        return result;
     }
 
     if (data_len == INT32_MAX)
-        return HandleZeroMsg(data, data_size);
+        return HandleZeroMsg(data);
 
     if (AITT_MESSAGE_MAX < data_len) {
         ERR("Invalid Size(%d)", data_len);
@@ -240,13 +240,12 @@ int TCP::RecvSizedDataNormal(void **data, int32_t &data_size)
     }
     void *data_buf = malloc(data_len);
     Recv(data_buf, data_len);
-    data_size = data_len;
     *data = data_buf;
 
-    return 0;
+    return data_len;
 }
 
-void TCP::SendSizedDataSecure(const void *data, int32_t &data_size)
+void TCP::SendSizedDataSecure(const void *data, int32_t data_size)
 {
     int32_t fixed_data_size = data_size;
     if (0 == data_size) {
@@ -272,16 +271,16 @@ void TCP::SendSizedDataSecure(const void *data, int32_t &data_size)
     }
 }
 
-int TCP::RecvSizedDataSecure(void **data, int32_t &data_size)
+int32_t TCP::RecvSizedDataSecure(void **data)
 {
-    int ret;
+    int32_t result;
 
     int32_t cipher_size_len = crypto.GetCryptogramSize(sizeof(int32_t));
     unsigned char cipher_size_buf[cipher_size_len];
-    ret = Recv(cipher_size_buf, cipher_size_len);
-    if (ret < 0) {
-        ERR("Recv() Fail(%d)", ret);
-        return ret;
+    result = Recv(cipher_size_buf, cipher_size_len);
+    if (result < 0) {
+        ERR("Recv() Fail(%d)", result);
+        return result;
     }
 
     unsigned char plain_size_buf[cipher_size_len];
@@ -289,7 +288,7 @@ int TCP::RecvSizedDataSecure(void **data, int32_t &data_size)
     crypto.Decrypt(cipher_size_buf, cipher_size_len, plain_size_buf);
     memcpy(&cipher_data_len, plain_size_buf, sizeof(cipher_data_len));
     if (cipher_data_len == INT32_MAX)
-        return HandleZeroMsg(data, data_size);
+        return HandleZeroMsg(data);
 
     if (AITT_MESSAGE_MAX < cipher_data_len) {
         ERR("Invalid Size(%d)", cipher_data_len);
@@ -298,9 +297,9 @@ int TCP::RecvSizedDataSecure(void **data, int32_t &data_size)
     unsigned char cipher_data_buf[cipher_data_len];
     Recv(cipher_data_buf, cipher_data_len);
     unsigned char *data_buf = static_cast<unsigned char *>(malloc(cipher_data_len));
-    data_size = crypto.Decrypt(cipher_data_buf, cipher_data_len, data_buf);
+    result = crypto.Decrypt(cipher_data_buf, cipher_data_len, data_buf);
     *data = data_buf;
-    return 0;
+    return result;
 }
 
 TCP::ConnectInfo::ConnectInfo() : port(0), secure(false), key(), iv()
index 6d73a4a..a626ebf 100644 (file)
@@ -45,22 +45,24 @@ class TCP {
     TCP(const std::string &host, const ConnectInfo &ConnectInfo);
     virtual ~TCP(void);
 
-    void Send(const void *data, int32_t &szData);
-    void SendSizedData(const void *data, int32_t &szData);
-    int Recv(void *data, int32_t &szData);
-    int RecvSizedData(void **data, int32_t &szData);
+    void SendSizedData(const void *data, int32_t data_size);
+    int RecvSizedData(void **data);
     int GetHandle(void);
     unsigned short GetPort(void);
     void GetPeerInfo(std::string &host, unsigned short &port);
 
+    // For unittest, it's public
+    int32_t Send(const void *data, int32_t data_size);
+    int32_t Recv(void *data, int32_t szData);
+
   private:
     TCP(int handle, sockaddr *addr, socklen_t addrlen, const ConnectInfo &connect_info);
     void SetupOptions(const ConnectInfo &connect_info);
-    int HandleZeroMsg(void **data, int32_t &data_size);
-    void SendSizedDataNormal(const void *data, int32_t &data_size);
-    int RecvSizedDataNormal(void **data, int32_t &data_size);
-    void SendSizedDataSecure(const void *data, int32_t &data_size);
-    int RecvSizedDataSecure(void **data, int32_t &data_size);
+    int HandleZeroMsg(void **data);
+    void SendSizedDataNormal(const void *data, int32_t data_size);
+    int RecvSizedDataNormal(void **data);
+    void SendSizedDataSecure(const void *data, int32_t data_size);
+    int RecvSizedDataSecure(void **data);
 
     int handle;
     socklen_t addrlen;
index 83cf273..2b50694 100644 (file)
@@ -74,9 +74,9 @@ void AITT::Disconnect(void)
 void AITT::Publish(const std::string &topic, const void *data, const int datalen,
       AittProtocol protocols, AittQoS qos, bool retain)
 {
-    if (AITT_PAYLOAD_MAX < datalen) {
+    if (datalen < 0 || AITT_PAYLOAD_MAX < datalen) {
         ERR("Invalid Size(%d)", datalen);
-        throw std::runtime_error("Invalid Size");
+        throw AittException(AittException::INVALID_ARG);
     }
 
     return pImpl->Publish(topic, data, datalen, protocols, qos, retain);
@@ -86,9 +86,9 @@ int AITT::PublishWithReply(const std::string &topic, const void *data, const int
       AittProtocol protocol, AittQoS qos, bool retain, const SubscribeCallback &cb, void *cbdata,
       const std::string &correlation)
 {
-    if (AITT_PAYLOAD_MAX < datalen) {
+    if (datalen < 0 || AITT_PAYLOAD_MAX < datalen) {
         ERR("Invalid Size(%d)", datalen);
-        throw std::runtime_error("Invalid Size");
+        throw AittException(AittException::INVALID_ARG);
     }
 
     return pImpl->PublishWithReply(topic, data, datalen, protocol, qos, retain, cb, cbdata,
@@ -99,9 +99,9 @@ int AITT::PublishWithReplySync(const std::string &topic, const void *data, const
       AittProtocol protocol, AittQoS qos, bool retain, const SubscribeCallback &cb, void *cbdata,
       const std::string &correlation, int timeout_ms)
 {
-    if (AITT_PAYLOAD_MAX < datalen) {
+    if (datalen < 0 || AITT_PAYLOAD_MAX < datalen) {
         ERR("Invalid Size(%d)", datalen);
-        throw std::runtime_error("Invalid Size");
+        throw AittException(AittException::INVALID_ARG);
     }
 
     return pImpl->PublishWithReplySync(topic, data, datalen, protocol, qos, retain, cb, cbdata,
@@ -121,9 +121,9 @@ void *AITT::Unsubscribe(AittSubscribeID handle)
 
 void AITT::SendReply(MSG *msg, const void *data, int datalen, bool end)
 {
-    if (AITT_PAYLOAD_MAX < datalen) {
+    if (datalen < 0 || AITT_PAYLOAD_MAX < datalen) {
         ERR("Invalid Size(%d)", datalen);
-        throw std::runtime_error("Invalid Size");
+        throw AittException(AittException::INVALID_ARG);
     }
 
     return pImpl->SendReply(msg, data, datalen, end);
index e0cf13d..efd4e09 100644 (file)
@@ -369,6 +369,18 @@ TEST_F(AITTTest, Publish_SECURE_TCP_P_Anytime)
     }
 }
 
+TEST_F(AITTTest, Publish_minus_size_N_Anytime)
+{
+    try {
+        AITT aitt(clientId, LOCAL_IP, AittOption(true, false));
+        aitt.Connect();
+        EXPECT_THROW(aitt.Publish(testTopic, TEST_MSG, -1, AITT_TYPE_TCP), aitt::AittException);
+        EXPECT_THROW(aitt.Publish(testTopic, TEST_MSG, -1, AITT_TYPE_MQTT), aitt::AittException);
+    } catch (std::exception &e) {
+        FAIL() << "Unexpected exception: " << e.what();
+    }
+}
+
 TEST_F(AITTTest, Publish_Multiple_Protocols_P_Anytime)
 {
     try {