From: Youngjae Shin Date: Thu, 13 Oct 2022 05:56:41 +0000 (+0900) Subject: check negative sized message X-Git-Tag: accepted/tizen/unified/20221115.172906~16 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=981a7c67bf90a090af9d6d3d4799dba4745a228e;p=platform%2Fcore%2Fml%2Faitt.git check negative sized message --- diff --git a/modules/tcp/Module.cc b/modules/tcp/Module.cc index 8476a5f..5cc4791 100644 --- a/modules/tcp/Module.cc +++ b/modules/tcp/Module.cc @@ -64,6 +64,8 @@ void Module::ThreadMain(void) void Module::Publish(const std::string &topic, const void *data, const int datalen, const std::string &correlation, AittQoS qos, bool retain) { + RET_IF(datalen < 0); + // NOTE: // Iterate discovered service table // PublishMap @@ -329,9 +331,9 @@ void Module::ReceiveData(MainLoopHandler::MainLoopResult result, int handle, return; } - int ret = tcp_data->client->RecvSizedData((void **)&msg, szmsg); - if (ret < 0) { - ERR("Got a disconnection message."); + szmsg = tcp_data->client->RecvSizedData((void **)&msg); + if (szmsg < 0) { + ERR("Got a disconnection message(%d)", szmsg); return impl->HandleClientDisconnect(handle); } } catch (std::exception &e) { @@ -367,8 +369,8 @@ std::string Module::GetTopicName(Module::TCPData *tcp_data) { int32_t topic_length = 0; void *topic_data = nullptr; - int ret = tcp_data->client->RecvSizedData(&topic_data, topic_length); - if (ret < 0) { + topic_length = tcp_data->client->RecvSizedData(&topic_data); + if (topic_length < 0) { ERR("Got a disconnection message."); HandleClientDisconnect(tcp_data->client->GetHandle()); return std::string(); diff --git a/modules/tcp/TCP.cc b/modules/tcp/TCP.cc index 26d2c32..d387f08 100644 --- a/modules/tcp/TCP.cc +++ b/modules/tcp/TCP.cc @@ -114,7 +114,7 @@ void TCP::SetupOptions(const ConnectInfo &connect_info) } } -void TCP::Send(const void *data, int32_t &data_size) +int32_t TCP::Send(const void *data, int32_t data_size) { int32_t sent = 0; while (sent < data_size) { @@ -126,18 +126,20 @@ void TCP::Send(const void *data, int32_t &data_size) sent += ret; } - data_size = sent; + return sent; } -void TCP::SendSizedData(const void *data, int32_t &szData) +void TCP::SendSizedData(const void *data, int32_t data_size) { + RET_IF(data_size < 0); + if (secure) - SendSizedDataSecure(data, szData); + return SendSizedDataSecure(data, data_size); else - SendSizedDataNormal(data, szData); + return SendSizedDataNormal(data, data_size); } -int TCP::Recv(void *data, int32_t &data_size) +int32_t TCP::Recv(void *data, int32_t data_size) { int32_t received = 0; while (received < data_size) { @@ -154,23 +156,21 @@ int TCP::Recv(void *data, int32_t &data_size) received += ret; } - data_size = received; - return 0; + return received; } -int TCP::RecvSizedData(void **data, int32_t &szData) +int32_t TCP::RecvSizedData(void **data) { if (secure) - return RecvSizedDataSecure(data, szData); + return RecvSizedDataSecure(data); else - return RecvSizedDataNormal(data, szData); + return RecvSizedDataNormal(data); } -int TCP::HandleZeroMsg(void **data, int32_t &data_size) +int32_t TCP::HandleZeroMsg(void **data) { // distinguish between connection problems and zero-size messages INFO("Got a zero-size message."); - data_size = 0; *data = nullptr; return 0; } @@ -205,7 +205,7 @@ unsigned short TCP::GetPort(void) return ntohs(addr.sin_port); } -void TCP::SendSizedDataNormal(const void *data, int32_t &data_size) +void TCP::SendSizedDataNormal(const void *data, int32_t data_size) { int32_t fixed_data_size = data_size; if (0 == data_size) { @@ -219,20 +219,20 @@ void TCP::SendSizedDataNormal(const void *data, int32_t &data_size) Send(data, data_size); } -int TCP::RecvSizedDataNormal(void **data, int32_t &data_size) +int32_t TCP::RecvSizedDataNormal(void **data) { - int ret; + int32_t result; int32_t data_len = 0; int32_t size_len = sizeof(data_len); - ret = Recv(static_cast(&data_len), size_len); - if (ret < 0) { - ERR("Recv() Fail(%d)", ret); - return ret; + result = Recv(static_cast(&data_len), size_len); + if (result < 0) { + ERR("Recv() Fail(%d)", result); + return result; } if (data_len == INT32_MAX) - return HandleZeroMsg(data, data_size); + return HandleZeroMsg(data); if (AITT_MESSAGE_MAX < data_len) { ERR("Invalid Size(%d)", data_len); @@ -240,13 +240,12 @@ int TCP::RecvSizedDataNormal(void **data, int32_t &data_size) } void *data_buf = malloc(data_len); Recv(data_buf, data_len); - data_size = data_len; *data = data_buf; - return 0; + return data_len; } -void TCP::SendSizedDataSecure(const void *data, int32_t &data_size) +void TCP::SendSizedDataSecure(const void *data, int32_t data_size) { int32_t fixed_data_size = data_size; if (0 == data_size) { @@ -272,16 +271,16 @@ void TCP::SendSizedDataSecure(const void *data, int32_t &data_size) } } -int TCP::RecvSizedDataSecure(void **data, int32_t &data_size) +int32_t TCP::RecvSizedDataSecure(void **data) { - int ret; + int32_t result; int32_t cipher_size_len = crypto.GetCryptogramSize(sizeof(int32_t)); unsigned char cipher_size_buf[cipher_size_len]; - ret = Recv(cipher_size_buf, cipher_size_len); - if (ret < 0) { - ERR("Recv() Fail(%d)", ret); - return ret; + result = Recv(cipher_size_buf, cipher_size_len); + if (result < 0) { + ERR("Recv() Fail(%d)", result); + return result; } unsigned char plain_size_buf[cipher_size_len]; @@ -289,7 +288,7 @@ int TCP::RecvSizedDataSecure(void **data, int32_t &data_size) crypto.Decrypt(cipher_size_buf, cipher_size_len, plain_size_buf); memcpy(&cipher_data_len, plain_size_buf, sizeof(cipher_data_len)); if (cipher_data_len == INT32_MAX) - return HandleZeroMsg(data, data_size); + return HandleZeroMsg(data); if (AITT_MESSAGE_MAX < cipher_data_len) { ERR("Invalid Size(%d)", cipher_data_len); @@ -298,9 +297,9 @@ int TCP::RecvSizedDataSecure(void **data, int32_t &data_size) unsigned char cipher_data_buf[cipher_data_len]; Recv(cipher_data_buf, cipher_data_len); unsigned char *data_buf = static_cast(malloc(cipher_data_len)); - data_size = crypto.Decrypt(cipher_data_buf, cipher_data_len, data_buf); + result = crypto.Decrypt(cipher_data_buf, cipher_data_len, data_buf); *data = data_buf; - return 0; + return result; } TCP::ConnectInfo::ConnectInfo() : port(0), secure(false), key(), iv() diff --git a/modules/tcp/TCP.h b/modules/tcp/TCP.h index 6d73a4a..a626ebf 100644 --- a/modules/tcp/TCP.h +++ b/modules/tcp/TCP.h @@ -45,22 +45,24 @@ class TCP { TCP(const std::string &host, const ConnectInfo &ConnectInfo); virtual ~TCP(void); - void Send(const void *data, int32_t &szData); - void SendSizedData(const void *data, int32_t &szData); - int Recv(void *data, int32_t &szData); - int RecvSizedData(void **data, int32_t &szData); + void SendSizedData(const void *data, int32_t data_size); + int RecvSizedData(void **data); int GetHandle(void); unsigned short GetPort(void); void GetPeerInfo(std::string &host, unsigned short &port); + // For unittest, it's public + int32_t Send(const void *data, int32_t data_size); + int32_t Recv(void *data, int32_t szData); + private: TCP(int handle, sockaddr *addr, socklen_t addrlen, const ConnectInfo &connect_info); void SetupOptions(const ConnectInfo &connect_info); - int HandleZeroMsg(void **data, int32_t &data_size); - void SendSizedDataNormal(const void *data, int32_t &data_size); - int RecvSizedDataNormal(void **data, int32_t &data_size); - void SendSizedDataSecure(const void *data, int32_t &data_size); - int RecvSizedDataSecure(void **data, int32_t &data_size); + int HandleZeroMsg(void **data); + void SendSizedDataNormal(const void *data, int32_t data_size); + int RecvSizedDataNormal(void **data); + void SendSizedDataSecure(const void *data, int32_t data_size); + int RecvSizedDataSecure(void **data); int handle; socklen_t addrlen; diff --git a/src/AITT.cc b/src/AITT.cc index 83cf273..2b50694 100644 --- a/src/AITT.cc +++ b/src/AITT.cc @@ -74,9 +74,9 @@ void AITT::Disconnect(void) void AITT::Publish(const std::string &topic, const void *data, const int datalen, AittProtocol protocols, AittQoS qos, bool retain) { - if (AITT_PAYLOAD_MAX < datalen) { + if (datalen < 0 || AITT_PAYLOAD_MAX < datalen) { ERR("Invalid Size(%d)", datalen); - throw std::runtime_error("Invalid Size"); + throw AittException(AittException::INVALID_ARG); } return pImpl->Publish(topic, data, datalen, protocols, qos, retain); @@ -86,9 +86,9 @@ int AITT::PublishWithReply(const std::string &topic, const void *data, const int AittProtocol protocol, AittQoS qos, bool retain, const SubscribeCallback &cb, void *cbdata, const std::string &correlation) { - if (AITT_PAYLOAD_MAX < datalen) { + if (datalen < 0 || AITT_PAYLOAD_MAX < datalen) { ERR("Invalid Size(%d)", datalen); - throw std::runtime_error("Invalid Size"); + throw AittException(AittException::INVALID_ARG); } return pImpl->PublishWithReply(topic, data, datalen, protocol, qos, retain, cb, cbdata, @@ -99,9 +99,9 @@ int AITT::PublishWithReplySync(const std::string &topic, const void *data, const AittProtocol protocol, AittQoS qos, bool retain, const SubscribeCallback &cb, void *cbdata, const std::string &correlation, int timeout_ms) { - if (AITT_PAYLOAD_MAX < datalen) { + if (datalen < 0 || AITT_PAYLOAD_MAX < datalen) { ERR("Invalid Size(%d)", datalen); - throw std::runtime_error("Invalid Size"); + throw AittException(AittException::INVALID_ARG); } return pImpl->PublishWithReplySync(topic, data, datalen, protocol, qos, retain, cb, cbdata, @@ -121,9 +121,9 @@ void *AITT::Unsubscribe(AittSubscribeID handle) void AITT::SendReply(MSG *msg, const void *data, int datalen, bool end) { - if (AITT_PAYLOAD_MAX < datalen) { + if (datalen < 0 || AITT_PAYLOAD_MAX < datalen) { ERR("Invalid Size(%d)", datalen); - throw std::runtime_error("Invalid Size"); + throw AittException(AittException::INVALID_ARG); } return pImpl->SendReply(msg, data, datalen, end); diff --git a/tests/AITT_test.cc b/tests/AITT_test.cc index e0cf13d..efd4e09 100644 --- a/tests/AITT_test.cc +++ b/tests/AITT_test.cc @@ -369,6 +369,18 @@ TEST_F(AITTTest, Publish_SECURE_TCP_P_Anytime) } } +TEST_F(AITTTest, Publish_minus_size_N_Anytime) +{ + try { + AITT aitt(clientId, LOCAL_IP, AittOption(true, false)); + aitt.Connect(); + EXPECT_THROW(aitt.Publish(testTopic, TEST_MSG, -1, AITT_TYPE_TCP), aitt::AittException); + EXPECT_THROW(aitt.Publish(testTopic, TEST_MSG, -1, AITT_TYPE_MQTT), aitt::AittException); + } catch (std::exception &e) { + FAIL() << "Unexpected exception: " << e.what(); + } +} + TEST_F(AITTTest, Publish_Multiple_Protocols_P_Anytime) { try {