From: Youngjae Shin Date: Tue, 20 Sep 2022 01:13:30 +0000 (+0900) Subject: [refactoring] revise architecture of TCP module X-Git-Tag: accepted/tizen/unified/20220921.091818~1 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=56637ea5cf120c8f84d7002229e7bb8f7fc44ba8;p=platform%2Fcore%2Fml%2Faitt.git [refactoring] revise architecture of TCP module - revise module manager - revise encryption module --- diff --git a/CMakeLists.txt b/CMakeLists.txt index 2b4146b..dab5e51 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,6 @@ CMAKE_MINIMUM_REQUIRED(VERSION 3.4.1) SET(CMAKE_SKIP_BUILD_RPATH true) -PROJECT(aitt VERSION 0.0.1 LANGUAGES CXX) +PROJECT(aitt VERSION 0.0.1 LANGUAGES CXX C) SET_PROPERTY(GLOBAL PROPERTY GLOBAL_DEPENDS_DEBUG_MODE 0) SET(CMAKE_POSITION_INDEPENDENT_CODE TRUE) SET(CMAKE_CXX_STANDARD 11) @@ -25,10 +25,10 @@ ELSE(PLATFORM STREQUAL "android") ADD_DEFINITIONS(-DTIZEN) ADD_DEFINITIONS(-DPLATFORM=${PLATFORM}) SET(ADDITIONAL_OPT "-DTIZEN") - SET(PKGS dlog) + SET(TIZEN_LOG_PKG dlog) ENDIF(PLATFORM STREQUAL "tizen") SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Werror -Wno-psabi -fdiagnostics-color -fvisibility=hidden") - PKG_CHECK_MODULES(AITT_NEEDS REQUIRED ${PKGS} libmosquitto flatbuffers glib-2.0) + PKG_CHECK_MODULES(AITT_NEEDS REQUIRED ${TIZEN_LOG_PKG} libmosquitto flatbuffers glib-2.0) INCLUDE_DIRECTORIES(${AITT_NEEDS_INCLUDE_DIRS}) LINK_DIRECTORIES(${AITT_NEEDS_LIBRARY_DIRS}) ENDIF(PLATFORM STREQUAL "android") @@ -52,7 +52,11 @@ INCLUDE_DIRECTORIES(include common) AUX_SOURCE_DIRECTORY(src AITT_SRC) -ADD_LIBRARY(${PROJECT_NAME} SHARED ${AITT_SRC}) +SET(MODULE_MAN_SRC src/ModuleManager.cc src/NullTransport.cc) +ADD_LIBRARY(MODULE_MANAGER OBJECT ${MODULE_MAN_SRC}) +list(REMOVE_ITEM AITT_SRC ${MODULE_MAN_SRC}) + +ADD_LIBRARY(${PROJECT_NAME} SHARED ${AITT_SRC} $) TARGET_LINK_LIBRARIES(${PROJECT_NAME} Threads::Threads ${CMAKE_DL_LIBS} ${AITT_COMMON}) TARGET_LINK_LIBRARIES(${PROJECT_NAME} ${AITT_NEEDS_LIBRARIES}) diff --git a/android/aitt/build.gradle b/android/aitt/build.gradle index dfa7d7d..1f56b8b 100644 --- a/android/aitt/build.gradle +++ b/android/aitt/build.gradle @@ -3,7 +3,7 @@ plugins { id "de.undercouch.download" version "5.0.1" } -def thirdPartyDir = new File ("${rootProject.projectDir}/third_party") +def thirdPartyDir = new File("${rootProject.projectDir}/third_party") def flatbuffersDir = new File("${thirdPartyDir}/flatbuffers-2.0.0") def mosquittoDir = new File("${thirdPartyDir}/mosquitto-2.0.14") @@ -43,7 +43,9 @@ android { path file('./CMakeLists.txt') } } - + buildFeatures { + prefab true + } buildTypes { debug { debuggable true @@ -87,6 +89,8 @@ dependencies { implementation 'androidx.appcompat:appcompat:1.4.1' implementation 'com.google.flatbuffers:flatbuffers-java:2.0.0' + implementation 'com.android.ndk.thirdparty:openssl:1.1.1g-alpha-1' + implementation project(path: ':android:modules:tcp') implementation project(path: ':android:modules:webrtc') @@ -151,7 +155,7 @@ task jacocoTestReport(type: JacocoReport, dependsOn: ['testDebugUnitTest']) { } def fileFilter = ['**/R.class', '**/R$*.class', '**/BuildConfig.*', '**/Manifest*.*', '**/*Test*.*'] - def debugTree = fileTree(dir : "${buildDir}/intermediates/javac/debug", excludes: fileFilter) + def debugTree = fileTree(dir: "${buildDir}/intermediates/javac/debug", excludes: fileFilter) def mainSrc = "${project.projectDir}/src/main/java" diff --git a/common/AittDiscovery.cc b/common/AittDiscovery.cc index 348671c..efc54fe 100644 --- a/common/AittDiscovery.cc +++ b/common/AittDiscovery.cc @@ -20,16 +20,19 @@ #include #include "AittException.h" -#include "MQProxy.h" #include "aitt_internal.h" namespace aitt { -AittDiscovery::AittDiscovery(const std::string &id, const AittOption &option) - : id_(id), discovery_mq(new MQProxy(id + "d", option)), callback_handle(nullptr) +AittDiscovery::AittDiscovery(const std::string &id) : id_(id), callback_handle(nullptr) { } +void AittDiscovery::SetMQ(std::unique_ptr mq) +{ + discovery_mq = std::move(mq); +} + void AittDiscovery::Start(const std::string &host, int port, const std::string &username, const std::string &password) { @@ -151,6 +154,8 @@ const char *AittDiscovery::GetProtocolStr(AittProtocol protocol) return "mqtt"; case AITT_TYPE_TCP: return "tcp"; + case AITT_TYPE_TCP_SECURE: + return "tcp_secure"; case AITT_TYPE_WEBRTC: return "webrtc"; default: @@ -168,6 +173,9 @@ AittProtocol AittDiscovery::GetProtocol(const std::string &protocol_str) if (STR_EQ == protocol_str.compare(GetProtocolStr(AITT_TYPE_TCP))) return AITT_TYPE_TCP; + if (STR_EQ == protocol_str.compare(GetProtocolStr(AITT_TYPE_TCP_SECURE))) + return AITT_TYPE_TCP_SECURE; + if (STR_EQ == protocol_str.compare(GetProtocolStr(AITT_TYPE_WEBRTC))) return AITT_TYPE_WEBRTC; diff --git a/common/AittDiscovery.h b/common/AittDiscovery.h index 46f3080..d26892e 100644 --- a/common/AittDiscovery.h +++ b/common/AittDiscovery.h @@ -30,7 +30,9 @@ class AittDiscovery { using DiscoveryCallback = std::function; - explicit AittDiscovery(const std::string &id, const AittOption &option); + // AittDiscovery() = default; + explicit AittDiscovery(const std::string &id); + void SetMQ(std::unique_ptr mq); void Start(const std::string &host, int port, const std::string &username, const std::string &password); void Stop(); diff --git a/common/AittTransport.h b/common/AittTransport.h index 7aa6730..048beeb 100644 --- a/common/AittTransport.h +++ b/common/AittTransport.h @@ -29,31 +29,36 @@ namespace aitt { class AittTransport { public: - typedef void *(*ModuleEntry)(AittProtocol protocol, const char *ip, AittDiscovery &discovery); + typedef void *( + *ModuleEntry)(AittProtocol type, AittDiscovery &discovery, const std::string &my_ip); using SubscribeCallback = std::function; static constexpr const char *const MODULE_ENTRY_NAME = DEFINE_TO_STR(AITT_TRANSPORT_NEW); - explicit AittTransport(AittDiscovery &discovery) : discovery(discovery) {} + explicit AittTransport(AittProtocol type, AittDiscovery &discovery) + : protocol(type), discovery(discovery) + { + } virtual ~AittTransport(void) = default; virtual void Publish(const std::string &topic, const void *data, const size_t datalen, - AittQoS qos = AITT_QOS_AT_MOST_ONCE, bool retain = false) = 0; - - virtual void Publish(const std::string &topic, const void *data, const size_t datalen, const std::string &correlation, AittQoS qos = AITT_QOS_AT_MOST_ONCE, bool retain = false) = 0; + virtual void Publish(const std::string &topic, const void *data, const size_t datalen, + AittQoS qos = AITT_QOS_AT_MOST_ONCE, bool retain = false) = 0; + virtual void *Subscribe(const std::string &topic, const SubscribeCallback &cb, void *cbdata = nullptr, AittQoS qos = AITT_QOS_AT_MOST_ONCE) = 0; - virtual void *Subscribe(const std::string &topic, const SubscribeCallback &cb, const void *data, const size_t datalen, void *cbdata = nullptr, AittQoS qos = AITT_QOS_AT_MOST_ONCE) = 0; virtual void *Unsubscribe(void *handle) = 0; + AittProtocol GetProtocol() { return protocol; } protected: + AittProtocol protocol; AittDiscovery &discovery; }; diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 53eadd9..a908292 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -2,7 +2,7 @@ FILE(GLOB COMMON_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/*.cc) ADD_LIBRARY(${AITT_COMMON} SHARED ${COMMON_SRCS}) TARGET_LINK_LIBRARIES(${AITT_COMMON} ${AITT_NEEDS_LIBRARIES} Threads::Threads) -TARGET_COMPILE_OPTIONS(${AITT_COMMON} PUBLIC ${AITT_NEEDS_CFLAGS_OTHER} "-fvisibility=default") +TARGET_COMPILE_OPTIONS(${AITT_COMMON} PRIVATE ${AITT_NEEDS_CFLAGS_OTHER} "-fvisibility=default") IF(VERSIONING) SET_TARGET_PROPERTIES(${AITT_COMMON} PROPERTIES VERSION ${PROJECT_VERSION} diff --git a/common/MQProxy.cc b/common/MQProxy.cc deleted file mode 100644 index 628b3f5..0000000 --- a/common/MQProxy.cc +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "MQProxy.h" - -#include "ModuleLoader.h" -#include "MosquittoMQ.h" -#include "aitt_internal.h" - -namespace aitt { - -MQProxy::MQProxy(const std::string &id, const AittOption &option) : handle(nullptr, nullptr) -{ - if (option.GetUseCustomMqttBroker()) { - ModuleLoader loader; - handle = loader.OpenModule(ModuleLoader::TYPE_CUSTOM_MQTT); - - mq = loader.LoadMqttClient(handle.get(), id, option); - INFO("Custom MQ(%p)", mq.get()); - } else { - mq = std::unique_ptr(new MosquittoMQ(id, option.GetClearSession())); - INFO("Mosquitto MQ"); - } -} - -void MQProxy::SetConnectionCallback(const MQConnectionCallback &cb) -{ - mq->SetConnectionCallback(cb); -} - -void MQProxy::Connect(const std::string &host, int port, const std::string &username, - const std::string &password) -{ - mq->Connect(host, port, username, password); -} - -void MQProxy::SetWillInfo(const std::string &topic, const void *msg, size_t szmsg, int qos, - bool retain) -{ - mq->SetWillInfo(topic, msg, szmsg, qos, retain); -} - -void MQProxy::Disconnect(void) -{ - mq->Disconnect(); -} - -void MQProxy::Publish(const std::string &topic, const void *data, const size_t datalen, int qos, - bool retain) -{ - mq->Publish(topic, data, datalen, qos, retain); -} - -void MQProxy::PublishWithReply(const std::string &topic, const void *data, const size_t datalen, - int qos, bool retain, const std::string &reply_topic, const std::string &correlation) -{ - mq->PublishWithReply(topic, data, datalen, qos, retain, reply_topic, correlation); -} - -void MQProxy::SendReply(MSG *msg, const void *data, const size_t datalen, int qos, bool retain) -{ - mq->SendReply(msg, data, datalen, qos, retain); -} - -void *MQProxy::Subscribe(const std::string &topic, const SubscribeCallback &cb, void *user_data, - int qos) -{ - return mq->Subscribe(topic, cb, user_data, qos); -} - -void *MQProxy::Unsubscribe(void *handle) -{ - return mq->Unsubscribe(handle); -} - -} // namespace aitt diff --git a/common/MQProxy.h b/common/MQProxy.h deleted file mode 100644 index b6edb91..0000000 --- a/common/MQProxy.h +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include - -#include "MQ.h" - -namespace aitt { - -class MQProxy : public MQ { - public: - explicit MQProxy(const std::string &id, const AittOption &option); - virtual ~MQProxy() = default; - - void SetConnectionCallback(const MQConnectionCallback &cb); - void Connect(const std::string &host, int port, const std::string &username, - const std::string &password); - void SetWillInfo(const std::string &topic, const void *msg, size_t szmsg, int qos, bool retain); - void Disconnect(void); - void Publish(const std::string &topic, const void *data, const size_t datalen, int qos = 0, - bool retain = false); - void PublishWithReply(const std::string &topic, const void *data, const size_t datalen, int qos, - bool retain, const std::string &reply_topic, const std::string &correlation); - void SendReply(MSG *msg, const void *data, const size_t datalen, int qos, bool retain); - void *Subscribe(const std::string &topic, const SubscribeCallback &cb, - void *user_data = nullptr, int qos = 0); - void *Unsubscribe(void *handle); - - private: - std::unique_ptr handle; - std::unique_ptr mq; -}; - -} // namespace aitt diff --git a/common/ModuleLoader.cc b/common/ModuleLoader.cc deleted file mode 100644 index c00a5b0..0000000 --- a/common/ModuleLoader.cc +++ /dev/null @@ -1,117 +0,0 @@ -/* - * Copyright (c) 2021-2022 Samsung Electronics Co., Ltd All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "ModuleLoader.h" - -#include - -#include "AittException.h" -#include "MQ.h" -#include "NullTransport.h" -#include "aitt_internal.h" - -namespace aitt { - -std::string ModuleLoader::GetModuleFilename(Type type) -{ - if (type == TYPE_TCP || type == TYPE_SECURE_TCP) - return "libaitt-transport-tcp.so"; - if (type == TYPE_WEBRTC) - return "libaitt-transport-webrtc.so"; - if (type == TYPE_CUSTOM_MQTT) - return "libaitt-st-broker.so"; - - return std::string("Unknown"); -} - -ModuleLoader::ModuleHandle ModuleLoader::OpenModule(Type type) -{ - std::string filename = GetModuleFilename(type); - - ModuleHandle handle(dlopen(filename.c_str(), RTLD_LAZY | RTLD_LOCAL), - [](const void *handle) -> void { - if (dlclose(const_cast(handle))) - ERR("dlclose: %s", dlerror()); - }); - if (handle == nullptr) - ERR("dlopen(%s): %s", filename.c_str(), dlerror()); - - return handle; -} - -std::unique_ptr ModuleLoader::LoadTransport( - void *handle, AittProtocol protocol, const std::string &ip, AittDiscovery &discovery) -{ - if (handle == nullptr) { - ERR("handle is NULL"); - return std::unique_ptr( - new NullTransport(ip.c_str(), discovery)); - } - - AittTransport::ModuleEntry get_instance_fn = reinterpret_cast( - dlsym(handle, AittTransport::MODULE_ENTRY_NAME)); - if (get_instance_fn == nullptr) { - ERR("dlsym: %s", dlerror()); - return std::unique_ptr( - new NullTransport(ip.c_str(), discovery)); - } - - std::unique_ptr instance( - static_cast(get_instance_fn(protocol, ip.c_str(), discovery))); - if (instance == nullptr) { - ERR("get_instance_fn(AittTransport) Fail"); - return std::unique_ptr( - new NullTransport(ip.c_str(), discovery)); - } - - return instance; -} - -std::unique_ptr ModuleLoader::LoadMqttClient(void *handle, const std::string &id, - const AittOption &option) -{ - MQ::ModuleEntry get_instance_fn = - reinterpret_cast(dlsym(handle, MQ::MODULE_ENTRY_NAME)); - if (get_instance_fn == nullptr) { - ERR("dlsym: %s", dlerror()); - throw AittException(AittException::SYSTEM_ERR); - } - - std::unique_ptr instance(static_cast(get_instance_fn(id.c_str(), option))); - if (instance == nullptr) { - ERR("get_instance_fn(MQ) Fail"); - throw AittException(AittException::SYSTEM_ERR); - } - - return instance; -} - -AittProtocol ModuleLoader::GetProtocol(Type type) -{ - switch (type) { - case TYPE_TCP: - return AITT_TYPE_TCP; - case TYPE_SECURE_TCP: - return AITT_TYPE_SECURE_TCP; - case TYPE_WEBRTC: - return AITT_TYPE_WEBRTC; - case TYPE_RTSP: - default: - return AITT_TYPE_UNKNOWN; - } -} - -} // namespace aitt diff --git a/common/ModuleLoader.h b/common/ModuleLoader.h deleted file mode 100644 index d564ab9..0000000 --- a/common/ModuleLoader.h +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright (c) 2021-2022 Samsung Electronics Co., Ltd All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include -#include -#include - -#include "AittTransport.h" -#include "AittTypes.h" -#include "MQ.h" - -namespace aitt { - -class ModuleLoader { - public: - enum Type { - TYPE_TCP, - TYPE_SECURE_TCP, - TYPE_WEBRTC, - TYPE_RTSP, - TYPE_TRANSPORT_MAX, - TYPE_CUSTOM_MQTT, - }; - - using ModuleHandle = std::unique_ptr; - - ModuleLoader() = default; - virtual ~ModuleLoader() = default; - - ModuleHandle OpenModule(Type type); - std::unique_ptr LoadTransport( - void *handle, AittProtocol protocol, const std::string &ip, AittDiscovery &discovery); - std::unique_ptr LoadMqttClient(void *handle, const std::string &id, - const AittOption &option); - AittProtocol GetProtocol(Type type); - - private: - std::string GetModuleFilename(Type type); - - std::string ip; -}; - -} // namespace aitt diff --git a/common/MosquittoMQ.cc b/common/MosquittoMQ.cc index 1eb0d17..b25f4c3 100644 --- a/common/MosquittoMQ.cc +++ b/common/MosquittoMQ.cc @@ -376,7 +376,7 @@ void *MosquittoMQ::Unsubscribe(void *sub_handle) int mid = -1; int ret = mosquitto_unsubscribe(handle, &mid, topic.c_str()); if (ret != MOSQ_ERR_SUCCESS) { - ERR("mosquitto_unsubscribe(%s) Fail(%d)", topic.c_str(), ret); + ERR("mosquitto_unsubscribe(%s) Fail(%s)", topic.c_str(), mosquitto_strerror(ret)); throw AittException(AittException::MQTT_ERR); } diff --git a/common/aitt_internal.h b/common/aitt_internal.h index feef42b..7e782b2 100644 --- a/common/aitt_internal.h +++ b/common/aitt_internal.h @@ -67,6 +67,16 @@ PLATFORM_LOGE("[%lu] (%d:%s) \033[31m" fmt "\033[0m", GETTID(), _errno, errMsg, \ ##__VA_ARGS__); \ } while (0) + +#define DBG_HEX_DUMP(data, len) \ + do { \ + size_t i; \ + char dump[len * 3]; \ + for (i = 0; i < len; i++) { \ + snprintf(dump + i * 3, (len * 3) - (i * 3), "%02X ", data[i]); \ + } \ + DBG("%s", dump); \ + } while (0) #endif #define RET_IF(expr) \ diff --git a/include/AITT.h b/include/AITT.h index 19a7239..42ee6e2 100644 --- a/include/AITT.h +++ b/include/AITT.h @@ -49,7 +49,6 @@ class API AITT { void Publish(const std::string &topic, const void *data, const size_t datalen, AittProtocol protocols = AITT_TYPE_MQTT, AittQoS qos = AITT_QOS_AT_MOST_ONCE, bool retain = false); - int PublishWithReply(const std::string &topic, const void *data, const size_t datalen, AittProtocol protocol, AittQoS qos, bool retain, const SubscribeCallback &cb, void *cbdata, const std::string &correlation); @@ -61,7 +60,6 @@ class API AITT { AittSubscribeID Subscribe(const std::string &topic, const SubscribeCallback &cb, void *cbdata = nullptr, AittProtocol protocol = AITT_TYPE_MQTT, AittQoS qos = AITT_QOS_AT_MOST_ONCE); - void *Unsubscribe(AittSubscribeID handle); void SendReply(MSG *msg, const void *data, const size_t datalen, bool end = true); diff --git a/include/AittTypes.h b/include/AittTypes.h index 23fc51e..b9d22c8 100644 --- a/include/AittTypes.h +++ b/include/AittTypes.h @@ -23,7 +23,7 @@ enum AittProtocol { AITT_TYPE_UNKNOWN = 0, AITT_TYPE_MQTT = (0x1 << 0), // Publish message through the MQTT AITT_TYPE_TCP = (0x1 << 1), // Publish message to peers using the TCP - AITT_TYPE_SECURE_TCP = (0x1 << 2), // Publish message to peers using the TCP with AES + AITT_TYPE_TCP_SECURE = (0x1 << 2), // Publish message to peers using the Secure TCP AITT_TYPE_WEBRTC = (0x1 << 3), // Publish message to peers using the WEBRTC }; @@ -40,6 +40,9 @@ enum AittConnectionState { AITT_CONNECT_FAILED = 2, // Failed to connect to the mqtt broker. }; +// The maximum size in bytes of a message. It follows MQTT +#define AITT_MESSAGE_MAX 268435455 + #ifdef TIZEN #include #define TIZEN_ERROR_AITT -0x04020000 diff --git a/mock/mosquitto.cc b/mock/mosquitto.cc index c2f1581..5f39dd4 100644 --- a/mock/mosquitto.cc +++ b/mock/mosquitto.cc @@ -17,103 +17,105 @@ #include "MQMockTest.h" #include "MQTTMock.h" +#include "aitt_internal.h" MQTTMock *MQMockTest::mqttMock = nullptr; extern "C" { -int mosquitto_lib_init(void) +API int mosquitto_lib_init(void) { return MQMockTest::GetMock().mosquitto_lib_init(); } -int mosquitto_lib_cleanup(void) +API int mosquitto_lib_cleanup(void) { return MQMockTest::GetMock().mosquitto_lib_cleanup(); } -struct mosquitto *mosquitto_new(const char *id, bool clean_session, void *obj) +API struct mosquitto *mosquitto_new(const char *id, bool clean_session, void *obj) { return MQMockTest::GetMock().mosquitto_new(id, clean_session, obj); } -int mosquitto_int_option(struct mosquitto *mosq, enum mosq_opt_t option, int value) +API int mosquitto_int_option(struct mosquitto *mosq, enum mosq_opt_t option, int value) { return MQMockTest::GetMock().mosquitto_int_option(mosq, option, value); } -void mosquitto_destroy(struct mosquitto *mosq) +API void mosquitto_destroy(struct mosquitto *mosq) { return MQMockTest::GetMock().mosquitto_destroy(mosq); } -int mosquitto_username_pw_set(struct mosquitto *mosq, const char *username, const char *password) +API int mosquitto_username_pw_set(struct mosquitto *mosq, const char *username, + const char *password) { return MQMockTest::GetMock().mosquitto_username_pw_set(mosq, username, password); } -int mosquitto_will_set(struct mosquitto *mosq, const char *topic, int payloadlen, +API int mosquitto_will_set(struct mosquitto *mosq, const char *topic, int payloadlen, const void *payload, int qos, bool retain) { return MQMockTest::GetMock().mosquitto_will_set(mosq, topic, payloadlen, payload, qos, retain); } -int mosquitto_will_clear(struct mosquitto *mosq) +API int mosquitto_will_clear(struct mosquitto *mosq) { return MQMockTest::GetMock().mosquitto_will_clear(mosq); } -int mosquitto_connect(struct mosquitto *mosq, const char *host, int port, int keepalive) +API int mosquitto_connect(struct mosquitto *mosq, const char *host, int port, int keepalive) { return MQMockTest::GetMock().mosquitto_connect(mosq, host, port, keepalive); } -int mosquitto_disconnect(struct mosquitto *mosq) +API int mosquitto_disconnect(struct mosquitto *mosq) { return MQMockTest::GetMock().mosquitto_disconnect(mosq); } -int mosquitto_publish(struct mosquitto *mosq, int *mid, const char *topic, int payloadlen, +API int mosquitto_publish(struct mosquitto *mosq, int *mid, const char *topic, int payloadlen, const void *payload, int qos, bool retain) { return MQMockTest::GetMock().mosquitto_publish(mosq, mid, topic, payloadlen, payload, qos, retain); } -int mosquitto_subscribe(struct mosquitto *mosq, int *mid, const char *sub, int qos) +API int mosquitto_subscribe(struct mosquitto *mosq, int *mid, const char *sub, int qos) { return MQMockTest::GetMock().mosquitto_subscribe(mosq, mid, sub, qos); } -int mosquitto_unsubscribe(struct mosquitto *mosq, int *mid, const char *sub) +API int mosquitto_unsubscribe(struct mosquitto *mosq, int *mid, const char *sub) { return MQMockTest::GetMock().mosquitto_unsubscribe(mosq, mid, sub); } -int mosquitto_loop_start(struct mosquitto *mosq) +API int mosquitto_loop_start(struct mosquitto *mosq) { return MQMockTest::GetMock().mosquitto_loop_start(mosq); } -int mosquitto_loop_stop(struct mosquitto *mosq, bool force) +API int mosquitto_loop_stop(struct mosquitto *mosq, bool force) { return MQMockTest::GetMock().mosquitto_loop_stop(mosq, force); } -void mosquitto_message_v5_callback_set(struct mosquitto *mosq, +API void mosquitto_message_v5_callback_set(struct mosquitto *mosq, void (*on_message)(struct mosquitto *, void *, const struct mosquitto_message *, const struct mqtt5__property *)) { return MQMockTest::GetMock().mosquitto_message_v5_callback_set(mosq, on_message); } -void mosquitto_connect_v5_callback_set(struct mosquitto *mosq, +API void mosquitto_connect_v5_callback_set(struct mosquitto *mosq, void (*on_connect)(struct mosquitto *, void *, int, int, const mosquitto_property *)) { return MQMockTest::GetMock().mosquitto_connect_v5_callback_set(mosq, on_connect); } -void mosquitto_disconnect_v5_callback_set(struct mosquitto *mosq, +API void mosquitto_disconnect_v5_callback_set(struct mosquitto *mosq, void (*on_disconnect)(struct mosquitto *, void *, int, const mosquitto_property *)) { return MQMockTest::GetMock().mosquitto_disconnect_v5_callback_set(mosq, on_disconnect); diff --git a/modules/tcp/AESEncryptor.cc b/modules/tcp/AESEncryptor.cc old mode 100755 new mode 100644 index afafc32..4c1aa6a --- a/modules/tcp/AESEncryptor.cc +++ b/modules/tcp/AESEncryptor.cc @@ -1,126 +1,137 @@ -/* - * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "AESEncryptor.h" -#ifndef ANDROID -#include -#endif -#include -#include -#include -#include -#include - -#include "aitt_internal.h" - -using random_bytes_generator = - std::independent_bits_engine; - -AESEncryptor::AESEncryptor(void) -{ - GenerateCipherKey(); -} - -AESEncryptor::AESEncryptor(const unsigned char key[AES_KEY_BYTE_SIZE]) -{ - memcpy(cipher_key, key, AES_KEY_BYTE_SIZE); -} - -AESEncryptor::~AESEncryptor(void) -{ -} - -void AESEncryptor::GenerateCipherKey(void) -{ - std::random_device rd; - random_bytes_generator rbg(rd()); - std::vector key_vector(AES_KEY_BYTE_SIZE); - std::generate(begin(key_vector), end(key_vector), std::ref(rbg)); - std::copy(key_vector.begin(), key_vector.end(), cipher_key); -} - -unsigned char *AESEncryptor::GetEncryptedData( - const void *data, size_t data_length, size_t &encrypted_data_length) -{ - size_t padding_buffer_size = GetPaddingBufferSize(data_length); - DBG("data_length = %zu, padding_buffer_size = %zu", data_length, padding_buffer_size); - - unsigned char padding_buffer[padding_buffer_size]; - memcpy(padding_buffer, data, data_length); - - unsigned char *encrypted_data = (unsigned char *)malloc(padding_buffer_size); - for (int i = 0; i < static_cast(padding_buffer_size) / AESEncryptor::AES_KEY_BYTE_SIZE; - i++) { - Encrypt(padding_buffer + AESEncryptor::AES_KEY_BYTE_SIZE * i, - encrypted_data + AESEncryptor::AES_KEY_BYTE_SIZE * i); - } - encrypted_data_length = padding_buffer_size; - - return encrypted_data; -} - -void AESEncryptor::Encrypt(const unsigned char *target_data, unsigned char *encrypted_data) -{ -#ifndef ANDROID - AES_KEY encryption_key; - if (AES_set_encrypt_key(cipher_key, AES_KEY_BIT_SIZE, &encryption_key) < 0) { - ERR("Fail to AES_set_encrypt_key()"); - throw std::runtime_error(strerror(errno)); - } - - AES_ecb_encrypt(target_data, encrypted_data, &encryption_key, AES_ENCRYPT); -#endif -} - -void AESEncryptor::GetDecryptedData( - unsigned char *padding_buffer, size_t padding_buffer_size, size_t data_length, void *data) -{ - unsigned char decrypted_data[padding_buffer_size]; - for (int i = 0; i < (int)padding_buffer_size / AESEncryptor::AES_KEY_BYTE_SIZE; i++) { - Decrypt(padding_buffer + AESEncryptor::AES_KEY_BYTE_SIZE * i, - decrypted_data + AESEncryptor::AES_KEY_BYTE_SIZE * i); - } - memcpy(data, decrypted_data, data_length); -} - -void AESEncryptor::Decrypt(const unsigned char *target_data, unsigned char *decrypted_data) -{ -#ifndef ANDROID - AES_KEY decryption_key; - if (AES_set_decrypt_key(cipher_key, AES_KEY_BIT_SIZE, &decryption_key) < 0) { - ERR("Fail to AES_set_decrypt_key()"); - throw std::runtime_error(strerror(errno)); - } - - AES_ecb_encrypt(target_data, decrypted_data, &decryption_key, AES_DECRYPT); -#endif -} - -size_t AESEncryptor::GetPaddingBufferSize(size_t data_length) -{ - size_t padding_buffer_size = (data_length + AESEncryptor::AES_KEY_BYTE_SIZE) - / AESEncryptor::AES_KEY_BYTE_SIZE * AESEncryptor::AES_KEY_BYTE_SIZE; - if (padding_buffer_size % AESEncryptor::AES_KEY_BYTE_SIZE != 0) { - ERR("data_length is not a multiple of AES_KEY_BYTE_SIZE."); - return 0; - } - - return padding_buffer_size; -} - -const unsigned char *AESEncryptor::GetCipherKey(void) -{ - return cipher_key; -} +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "AESEncryptor.h" + +#include +#include + +#include +#include +#include + +#include "aitt_internal.h" + +namespace AittTCPNamespace { + +AESEncryptor::AESEncryptor() +{ +} + +AESEncryptor::~AESEncryptor(void) +{ +} + +void AESEncryptor::Init(const unsigned char *key, const unsigned char *iv) +{ + key_.insert(key_.begin(), key, key + AITT_TCP_ENCRYPTOR_KEY_LEN); + iv_.insert(iv_.begin(), iv, iv + AITT_TCP_ENCRYPTOR_IV_LEN); + + DBG_HEX_DUMP(key_.data(), key_.size()); + DBG_HEX_DUMP(iv_.data(), iv_.size()); +} + +size_t AESEncryptor::GetCryptogramSize(size_t plain_size) +{ + const int BLOCKSIZE = 16; + return (plain_size / BLOCKSIZE + 1) * BLOCKSIZE; +} + +void AESEncryptor::GenerateKey(unsigned char (&key)[AITT_TCP_ENCRYPTOR_KEY_LEN], + unsigned char (&iv)[AITT_TCP_ENCRYPTOR_IV_LEN]) +{ + std::mt19937 random_gen{std::random_device{}()}; + std::uniform_int_distribution<> gen(0, 255); + + size_t i; + for (i = 0; i < sizeof(iv); i++) { + key[i] = gen(random_gen); + iv[i] = gen(random_gen); + } + for (size_t j = i; j < sizeof(key); j++) { + key[j] = gen(random_gen); + } +} + +size_t AESEncryptor::Encrypt(const unsigned char *plaintext, int plaintext_len, + unsigned char *ciphertext) +{ + int len; + int ciphertext_len; + + if (key_.size() == 0) + return 0; + + std::unique_ptr ctx(EVP_CIPHER_CTX_new(), + [](EVP_CIPHER_CTX *c) { EVP_CIPHER_CTX_free(c); }); + if (ctx.get() == nullptr) { + ERR("EVP_CIPHER_CTX_new() Fail(%s)", strerror(errno)); + throw std::runtime_error(strerror(errno)); + } + + if (1 != EVP_EncryptInit_ex(ctx.get(), EVP_aes_256_cbc(), NULL, key_.data(), iv_.data())) { + ERR("EVP_EncryptInit_ex() Fail(%s)", strerror(errno)); + throw std::runtime_error(strerror(errno)); + } + + if (1 != EVP_EncryptUpdate(ctx.get(), ciphertext, &ciphertext_len, plaintext, plaintext_len)) { + ERR("EVP_EncryptUpdate() Fail(%s)", strerror(errno)); + throw std::runtime_error(strerror(errno)); + } + + if (1 != EVP_EncryptFinal_ex(ctx.get(), ciphertext + ciphertext_len, &len)) { + ERR("EVP_EncryptFinal_ex() Fail(%s)", strerror(errno)); + throw std::runtime_error(strerror(errno)); + } + + return ciphertext_len + len; +} + +size_t AESEncryptor::Decrypt(const unsigned char *ciphertext, int ciphertext_len, + unsigned char *plaintext) +{ + int len; + int plaintext_len; + + if (key_.size() == 0) + return 0; + + std::unique_ptr ctx(EVP_CIPHER_CTX_new(), + [](EVP_CIPHER_CTX *c) { EVP_CIPHER_CTX_free(c); }); + if (ctx.get() == nullptr) { + ERR("EVP_CIPHER_CTX_new() Fail(%s)", strerror(errno)); + throw std::runtime_error(strerror(errno)); + } + + if (1 != EVP_DecryptInit_ex(ctx.get(), EVP_aes_256_cbc(), NULL, key_.data(), iv_.data())) { + ERR("EVP_DecryptInit_ex() Fail(%s)", strerror(errno)); + throw std::runtime_error(strerror(errno)); + } + + if (1 != EVP_DecryptUpdate(ctx.get(), plaintext, &plaintext_len, ciphertext, ciphertext_len)) { + ERR("EVP_DecryptUpdate() Fail(%s)", strerror(errno)); + throw std::runtime_error(strerror(errno)); + } + + if (1 != EVP_DecryptFinal_ex(ctx.get(), plaintext + plaintext_len, &len)) { + ERR("EVP_DecryptFinal_ex() Fail(%s)", strerror(errno)); + throw std::runtime_error(strerror(errno)); + } + plaintext_len += len; + + return plaintext_len; +} + +} // namespace AittTCPNamespace diff --git a/modules/tcp/AESEncryptor.h b/modules/tcp/AESEncryptor.h old mode 100755 new mode 100644 index 59fc851..fc05123 --- a/modules/tcp/AESEncryptor.h +++ b/modules/tcp/AESEncryptor.h @@ -15,30 +15,30 @@ */ #pragma once -#include +#include +#include -class AESEncryptor { - public: - constexpr static int AES_KEY_BYTE_SIZE = 16; +// AES-256 CBC +#define AITT_TCP_ENCRYPTOR_KEY_LEN 32 +#define AITT_TCP_ENCRYPTOR_IV_LEN 16 +namespace AittTCPNamespace { + +class AESEncryptor { public: - AESEncryptor(void); - explicit AESEncryptor(const unsigned char key[AES_KEY_BYTE_SIZE]); - ~AESEncryptor(void); + AESEncryptor(); + virtual ~AESEncryptor(void); - unsigned char *GetEncryptedData( - const void *data, size_t data_length, size_t &encrypted_data_length); - void Encrypt(const unsigned char *target_data, unsigned char *encrypted_data); - void GetDecryptedData(unsigned char *padding_buffer, size_t padding_buffer_size, - size_t data_length, void *data); - void Decrypt(const unsigned char *target_data, unsigned char *decrypted_data); - size_t GetPaddingBufferSize(size_t data_length); - const unsigned char *GetCipherKey(void); + static void GenerateKey(unsigned char (&key)[AITT_TCP_ENCRYPTOR_KEY_LEN], + unsigned char (&iv)[AITT_TCP_ENCRYPTOR_IV_LEN]); + void Init(const unsigned char *key, const unsigned char *iv); + size_t GetCryptogramSize(size_t plain_size); + size_t Encrypt(const unsigned char *plaintext, int plaintext_len, unsigned char *ciphertext); + size_t Decrypt(const unsigned char *ciphertext, int ciphertext_len, unsigned char *plaintext); private: - void GenerateCipherKey(void); - - unsigned char cipher_key[AES_KEY_BYTE_SIZE]; - - constexpr static int AES_KEY_BIT_SIZE = AES_KEY_BYTE_SIZE << 3; + std::vector key_; + std::vector iv_; }; + +} // namespace AittTCPNamespace diff --git a/modules/tcp/CMakeLists.txt b/modules/tcp/CMakeLists.txt index ae7defc..96bd2d7 100644 --- a/modules/tcp/CMakeLists.txt +++ b/modules/tcp/CMakeLists.txt @@ -1,20 +1,27 @@ SET(AITT_TCP aitt-transport-tcp) +SET(AITT_SECURE_TCP aitt-transport-tcp-secure) INCLUDE_DIRECTORIES(${CMAKE_CURRENT_SOURCE_DIR}) IF(PLATFORM STREQUAL "tizen") - PKG_CHECK_MODULES(AITT_TCP_NEEDS REQUIRED openssl1.1) -ELSEIF( NOT PLATFORM STREQUAL "android") - PKG_CHECK_MODULES(AITT_TCP_NEEDS REQUIRED openssl) + SET(OPENSSL openssl1.1) +ELSE(PLATFORM STREQUAL "tizen") + SET(OPENSSL openssl) ENDIF(PLATFORM STREQUAL "tizen") +PKG_CHECK_MODULES(AITT_TCP_NEEDS REQUIRED ${OPENSSL}) INCLUDE_DIRECTORIES(${AITT_TCP_NEEDS_INCLUDE_DIRS}) LINK_DIRECTORIES(${AITT_TCP_NEEDS_LIBRARY_DIRS}) -ADD_LIBRARY(TCP_OBJ OBJECT TCP.cc TCPServer.cc AESEncryptor.cc) -ADD_LIBRARY(${AITT_TCP} SHARED $ ../transport_entry.cc Module.cc) +ADD_LIBRARY(TCP_OBJ STATIC TCP.cc TCPServer.cc AESEncryptor.cc) +ADD_LIBRARY(${AITT_TCP} SHARED ../transport_entry.cc Module.cc) +TARGET_LINK_LIBRARIES(${AITT_TCP} Threads::Threads TCP_OBJ ${AITT_COMMON} ${AITT_TCP_NEEDS_LIBRARIES}) -TARGET_LINK_LIBRARIES(${AITT_TCP} Threads::Threads ${AITT_COMMON} ${AITT_TCP_NEEDS_LIBRARIES}) +IF(PLATFORM STREQUAL "android") + FIND_PACKAGE(openssl REQUIRED CONFIG) + TARGET_LINK_LIBRARIES(TCP_OBJ openssl::crypto) + TARGET_LINK_LIBRARIES(${AITT_TCP} openssl::crypto) +ENDIF(PLATFORM STREQUAL "android") INSTALL(TARGETS ${AITT_TCP} DESTINATION ${CMAKE_INSTALL_LIBDIR}) diff --git a/modules/tcp/Module.cc b/modules/tcp/Module.cc index 409df96..a2df94e 100644 --- a/modules/tcp/Module.cc +++ b/modules/tcp/Module.cc @@ -19,14 +19,18 @@ #include #include +#include + #include "aitt_internal.h" -Module::Module(AittProtocol protocol, const std::string &ip, AittDiscovery &discovery) - : AittTransport(discovery), protocol(protocol), ip(ip) +namespace AittTCPNamespace { + +Module::Module(AittProtocol type, AittDiscovery &discovery, const std::string &my_ip) + : AittTransport(type, discovery), ip(my_ip), secure(type == AITT_TYPE_TCP_SECURE) { aittThread = std::thread(&Module::ThreadMain, this); - discovery_cb = discovery.AddDiscoveryCB(AITT_TYPE_TCP, + discovery_cb = discovery.AddDiscoveryCB(type, std::bind(&Module::DiscoveryMessageCallback, this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4)); DBG("Discovery Callback : %p, %d", this, discovery_cb); @@ -37,7 +41,7 @@ Module::~Module(void) discovery.RemoveDiscoveryCB(discovery_cb); while (main_loop.Quit() == false) { - // wait when called before the thread has completely created + // wait when called before the thread has completely created. usleep(1000); } @@ -47,16 +51,13 @@ Module::~Module(void) void Module::ThreadMain(void) { - pthread_setname_np(pthread_self(), "TCPWorkerLoop"); + if (secure) + pthread_setname_np(pthread_self(), "SecureTCPLoop"); + else + pthread_setname_np(pthread_self(), "NormalTCPLoop"); main_loop.Run(); } -void Module::Publish( - const std::string &topic, const void *data, const size_t datalen, AittQoS qos, bool retain) -{ - Publish(topic, data, datalen, std::string(), qos, retain); -} - void Module::Publish(const std::string &topic, const void *data, const size_t datalen, const std::string &correlation, AittQoS qos, bool retain) { @@ -78,202 +79,76 @@ void Module::Publish(const std::string &topic, const void *data, const size_t da if (!aitt::AittUtil::CompareTopic(it->first, topic)) continue; - INFO("[Topic] it->first (%s)", it->first.c_str()); for (HostMap::iterator hostIt = it->second.begin(); hostIt != it->second.end(); ++hostIt) { - INFO("[ClientID] hostIt->first (%s)", hostIt->first.c_str()); // Iterate all ports, // the current implementation only be able to have the ZERO or a SINGLE entry for (PortMap::iterator portIt = hostIt->second.begin(); portIt != hostIt->second.end(); ++portIt) { - // portIt->second // handle - INFO("[Port] portIt->first = (%d)", portIt->first); - if (!portIt->second) { // AITT_TYPE_TCP - std::string host = FindHost(hostIt); - if (host.empty() == true) { - ERR("clientTable or subscribeTable is broken."); - continue; + if (!portIt->second) { + std::string host; + { + ClientMap::iterator clientIt; + std::lock_guard auto_lock_client(clientTableLock); + + clientIt = clientTable.find(hostIt->first); + if (clientIt != clientTable.end()) + host = clientIt->second; + + // NOTE: + // otherwise, it is a critical error + // The broken clientTable or subscribeTable } - std::unique_ptr client(new TCP(host, portIt->first)); - // TODO: - // If the client gets disconnected, this channel entry must be cleared - // In order to do that, there should be an observer to monitor - // each connections and manipulate the discovered service table - INFO("A new TCP client for topic(%s) is created!!", topic.c_str()); - std::unique_ptr clientInfo(new TCPPublishInfo()); - clientInfo->client_handle = std::move(client); - portIt->second = std::move(clientInfo); - } - - if (protocol == AITT_TYPE_SECURE_TCP && !portIt->second->client_handle) { - std::string host = FindHost(hostIt); - if (host.empty() == true) { - ERR("clientTable or subscribeTable is broken."); - continue; - } std::unique_ptr client(new TCP(host, portIt->first)); - INFO("[SECURE_TCP] A new TCP client for topic(%s) is created!!", - topic.c_str()); - portIt->second->client_handle = std::move(client); + // TODO: + // If the client gets disconnected, + // This channel entry must be cleared + // In order to do that, + // There should be an observer to monitor + // each connections and manipulate + // the discovered service table + portIt->second = std::move(client); } - if (!portIt->second->client_handle) { + if (!portIt->second) { ERR("Failed to create a new client instance"); continue; } - if (protocol == AITT_TYPE_SECURE_TCP) { - if (SendEncryptedTopic(topic, portIt) == true) - SendEncryptedPayload(datalen, portIt, data); - } else { - if (SendTopic(topic, portIt) == true) - SendPayload(datalen, portIt, data); + try { + size_t length = topic.length(); + portIt->second->SendSizedData(topic.c_str(), length); + length = datalen; + portIt->second->SendSizedData(data, length); + } catch (std::exception &e) { + ERR("An exception(%s) occurs during Send().", e.what()); } } - } - } + } // connectionEntries + } // publishTable } -std::string Module::FindHost(HostMap::iterator &host_iterator) +void Module::Publish(const std::string &topic, const void *data, const size_t datalen, AittQoS qos, + bool retain) { - std::lock_guard auto_lock_client(clientTableLock); - ClientMap::iterator client_iterator = clientTable.find(host_iterator->first); - if (client_iterator != clientTable.end()) - return client_iterator->second; - - return std::string(); -} - -bool Module::SendEncryptedTopic(const std::string &topic, Module::PortMap::iterator &portIt) -{ - size_t topic_length = topic.length(); - unsigned char *encrypted_data = nullptr; - - try { - SendEncryptedData(portIt, static_cast(&topic_length), sizeof(topic_length)); - - SendEncryptedData(portIt, static_cast(topic.c_str()), topic_length); - } catch (std::exception &e) { - ERR("An exception(%s) occurs during SendExactSize().", e.what()); - free(encrypted_data); - return false; - } - - return true; -} - -void Module::SendEncryptedData( - Module::PortMap::iterator &port_iterator, const void *data, size_t data_length) -{ - size_t encrypted_data_size = 0; - unsigned char *encrypted_data = port_iterator->second->aes_encryptor->GetEncryptedData( - data, data_length, encrypted_data_size); - if (encrypted_data != nullptr && encrypted_data_size > 0) - SendExactSize(port_iterator, encrypted_data, encrypted_data_size); - - free(encrypted_data); -} - -void Module::SendExactSize( - Module::PortMap::iterator &port_iterator, const void *data, size_t data_length) -{ - size_t remaining_size = data_length; - while (0 < remaining_size) { - const char *data_index = static_cast(data) + (data_length - remaining_size); - size_t size_sent = remaining_size; - port_iterator->second->client_handle->Send(data_index, size_sent); - if (size_sent > 0) { - remaining_size -= size_sent; - } else if (size_sent == 0) { - DBG("size_sent == 0"); - remaining_size = 0; - } - } -} - -void Module::SendEncryptedPayload( - const size_t &datalen, Module::PortMap::iterator &portIt, const void *data) -{ - size_t payload_size = datalen; - if (0 == datalen) { - // Distinguish between connection problems and zero-size messages - INFO("Send a zero-size message."); - payload_size = UINT32_MAX; - } - - try { - SendEncryptedData(portIt, static_cast(&payload_size), sizeof(payload_size)); - if (payload_size == UINT32_MAX) { - INFO("An actual data size is 0. Skip this payload transmission."); - return; - } - - SendEncryptedData(portIt, data, datalen); - } catch (std::exception &e) { - ERR("An exception(%s) occurs during SendExactSize().", e.what()); - } -} - -bool Module::SendTopic(const std::string &topic, Module::PortMap::iterator &portIt) -{ - size_t topic_length = topic.length(); - - try { - SendExactSize(portIt, static_cast(&topic_length), sizeof(topic_length)); - - SendExactSize(portIt, static_cast(topic.c_str()), topic_length); - } catch (std::exception &e) { - ERR("An exception(%s) occurs during SendExactSize().", e.what()); - return false; - } - - return true; + Publish(topic, data, datalen, std::string(), qos, retain); } -void Module::SendPayload(const size_t &datalen, Module::PortMap::iterator &portIt, const void *data) -{ - size_t payload_size = datalen; - if (0 == datalen) { - // Distinguish between connection problems and zero-size messages - INFO("Send a zero-size message."); - payload_size = UINT32_MAX; - } - - try { - DBG("sizeof(payload_size) = %zu", sizeof(payload_size)); - SendExactSize(portIt, static_cast(&payload_size), sizeof(payload_size)); - - if (payload_size == UINT32_MAX) { - INFO("An actual data size is 0. Skip this payload transmission."); - return; - } - - DBG("datalen = %zu", datalen); - SendExactSize(portIt, data, datalen); - } catch (std::exception &e) { - ERR("An exception(%s) occurs during SendExactSize().", e.what()); - } -} void *Module::Subscribe(const std::string &topic, const AittTransport::SubscribeCallback &cb, void *cbdata, AittQoS qos) { std::unique_ptr tcpServer; unsigned short port = 0; - tcpServer = std::unique_ptr(new TCP::Server("0.0.0.0", port)); + tcpServer = std::unique_ptr(new TCP::Server("0.0.0.0", port, secure)); TCPServerData *listen_info = new TCPServerData; listen_info->impl = this; listen_info->cb = cb; listen_info->cbdata = cbdata; listen_info->topic = topic; - listen_info->is_secure = (protocol == AITT_TYPE_SECURE_TCP ? true : false); - if (listen_info->is_secure == true) { - tcpServer->CreateAESEncryptor(); - listen_info->aes_encryptor = tcpServer->GetAESEncryptor(); - } - auto handle = tcpServer->GetHandle(); + main_loop.AddWatch(handle, AcceptConnection, listen_info); { @@ -312,8 +187,8 @@ void *Module::Unsubscribe(void *handlePtr) void *cbdata = listen_info->cbdata; listen_info->client_lock.lock(); for (auto fd : listen_info->client_list) { - TCPData *connect_info = dynamic_cast(main_loop.RemoveWatch(fd)); - delete connect_info; + TCPData *tcp_data = dynamic_cast(main_loop.RemoveWatch(fd)); + delete tcp_data; } listen_info->client_list.clear(); listen_info->client_lock.unlock(); @@ -347,10 +222,7 @@ void Module::DiscoveryMessageCallback(const std::string &clientId, const std::st // serviceMessage (flexbuffers) // map { // "host": "192.168.1.11", - // "$topic": port, - // "$topic/port" : protocol - // // if protocol == AES_TYPE_SECURE_TCP, the below exists. - // "$topic/port/protocol" : cipher_key + // "$topic": {port, key, iv} // } auto map = flexbuffers::GetRoot(static_cast(msg), szmsg).AsMap(); std::string host = map["host"].AsString().c_str(); @@ -372,23 +244,31 @@ void Module::DiscoveryMessageCallback(const std::string &clientId, const std::st if (!topic.compare("host")) continue; - auto port = map[topic].AsUInt16(); - std::string protocol_topic = std::string(topic).append("/").append(std::to_string(port)); - auto protocol = map[protocol_topic].AsUInt16(); - const unsigned char *key = nullptr; - if (protocol == AITT_TYPE_SECURE_TCP) { - std::string key_topic = - std::string(protocol_topic).append("/").append(std::to_string(protocol)); - const char *transmitted_key = map[key_topic].AsString().c_str(); - key = reinterpret_cast(transmitted_key); - { - std::lock_guard autoLock(publishTableLock); - UpdatePublishTable(topic, clientId, port, key); + TCP::ConnectInfo info; + auto connectInfo = map[topic].AsVector(); + size_t vec_size = connectInfo.size(); + info.port = connectInfo[0].AsUInt16(); + if (secure) { + if (vec_size != 3) { + ERR("Unknown Message"); + return; } + info.secure = true; + auto key_blob = connectInfo[1].AsBlob(); + if (key_blob.size() == sizeof(info.key)) + memcpy(info.key, key_blob.data(), key_blob.size()); + else + ERR("Invalid key blob(%zu) != %zu", key_blob.size(), sizeof(info.key)); + + auto iv_blob = connectInfo[2].AsBlob(); + if (iv_blob.size() == sizeof(info.iv)) + memcpy(info.iv, iv_blob.data(), iv_blob.size()); + else + ERR("Invalid iv blob(%zu) != %zu", iv_blob.size(), sizeof(info.iv)); } { std::lock_guard autoLock(publishTableLock); - UpdatePublishTable(topic, clientId, port, key); + UpdatePublishTable(topic, clientId, info); } } } @@ -396,15 +276,6 @@ void Module::DiscoveryMessageCallback(const std::string &clientId, const std::st void Module::UpdateDiscoveryMsg() { flexbuffers::Builder fbb; - // flexbuffers - // { - // "host": "127.0.0.1", - // "$topic": $port, - // ... - // "$topic/port" : protocol - // // if protocol == AITT_TYPE_SECURE_TCP, then the below exists. - // "$topic/port/protocol" : key - // } fbb.Map([this, &fbb]() { fbb.String("host", ip); @@ -415,34 +286,33 @@ void Module::UpdateDiscoveryMsg() // } for (auto it = subscribeTable.begin(); it != subscribeTable.end(); ++it) { if (it->second) { - auto port = it->second->GetPort(); - fbb.UInt(it->first.c_str(), port); - if (protocol == AITT_TYPE_SECURE_TCP) { - std::string protocol_topic = - std::string(it->first.c_str()).append("/").append(std::to_string(port)); - fbb.UInt(protocol_topic.c_str(), static_cast(protocol)); - const unsigned char *key = it->second->GetKey(); - std::string key_topic = - protocol_topic.append("/").append(std::to_string(protocol)); - fbb.String(key_topic.c_str(), std::string(reinterpret_cast(key))); - } + fbb.Vector(it->first.c_str(), [&]() { + fbb.UInt(it->second->GetPort()); + if (secure) { + fbb.Blob(it->second->GetCryptoKey(), AITT_TCP_ENCRYPTOR_KEY_LEN); + fbb.Blob(it->second->GetCryptoIv(), AITT_TCP_ENCRYPTOR_IV_LEN); + } + }); } else { - fbb.UInt(it->first.c_str(), 0); // this is an error case + // this is an error case + TCP::ConnectInfo info; + fbb.Vector(it->first.c_str(), [&]() { fbb.UInt(it->second->GetPort()); }); } } }); fbb.Finish(); auto buf = fbb.GetBuffer(); - discovery.UpdateDiscoveryMsg(AITT_TYPE_TCP, buf.data(), buf.size()); + discovery.UpdateDiscoveryMsg(secure ? AITT_TYPE_TCP_SECURE : AITT_TYPE_TCP, buf.data(), + buf.size()); } void Module::ReceiveData(MainLoopHandler::MainLoopResult result, int handle, MainLoopHandler::MainLoopData *user_data) { - TCPData *connect_info = dynamic_cast(user_data); - RET_IF(connect_info == nullptr); - TCPServerData *parent_info = connect_info->parent; + TCPData *tcp_data = dynamic_cast(user_data); + RET_IF(tcp_data == nullptr); + TCPServerData *parent_info = tcp_data->parent; RET_IF(parent_info == nullptr); Module *impl = parent_info->impl; RET_IF(impl == nullptr); @@ -457,28 +327,16 @@ void Module::ReceiveData(MainLoopHandler::MainLoopResult result, int handle, std::string topic; try { - if (connect_info->parent->is_secure == true) { - topic = impl->ReceiveDecryptedTopic(connect_info); - if (topic.empty()) { - ERR("A topic is empty."); - return impl->HandleClientDisconnect(handle); - } - - if (impl->ReceiveDecryptedPayload(connect_info, szmsg, &msg) == false) { - free(msg); - return impl->HandleClientDisconnect(handle); - } - } else { - topic = impl->ReceiveTopic(connect_info); - if (topic.empty()) { - ERR("A topic is empty."); - return impl->HandleClientDisconnect(handle); - } + topic = impl->GetTopicName(tcp_data); + if (topic.empty()) { + ERR("A topic is empty."); + return; + } - if (impl->ReceivePayload(connect_info, szmsg, &msg) == false) { - free(msg); - return impl->HandleClientDisconnect(handle); - } + int ret = tcp_data->client->RecvSizedData((void **)&msg, szmsg); + if (ret < 0) { + ERR("Got a disconnection message."); + return impl->HandleClientDisconnect(handle); } } catch (std::exception &e) { ERR("An exception(%s) occurs", e.what()); @@ -495,138 +353,51 @@ void Module::ReceiveData(MainLoopHandler::MainLoopResult result, int handle, void Module::HandleClientDisconnect(int handle) { - TCPData *connect_info = dynamic_cast(main_loop.RemoveWatch(handle)); - if (connect_info == nullptr) { + TCPData *tcp_data = dynamic_cast(main_loop.RemoveWatch(handle)); + if (tcp_data == nullptr) { ERR("No watch data"); return; } - connect_info->parent->client_lock.lock(); - auto it = std::find(connect_info->parent->client_list.begin(), - connect_info->parent->client_list.end(), handle); - connect_info->parent->client_list.erase(it); - connect_info->parent->client_lock.unlock(); + tcp_data->parent->client_lock.lock(); + auto it = std::find(tcp_data->parent->client_list.begin(), tcp_data->parent->client_list.end(), + handle); + tcp_data->parent->client_list.erase(it); + tcp_data->parent->client_lock.unlock(); - delete connect_info; + delete tcp_data; } -std::string Module::ReceiveDecryptedTopic(Module::TCPData *connect_info) +std::string Module::GetTopicName(Module::TCPData *tcp_data) { size_t topic_length = 0; - ReceiveDecryptedData(connect_info, static_cast(&topic_length), sizeof(topic_length)); - - if (AITT_TOPIC_NAME_MAX < topic_length) { - ERR("Invalid topic name length(%zu)", topic_length); - return std::string(); - } - - char topic_buffer[topic_length]; - ReceiveDecryptedData(connect_info, topic_buffer, topic_length); - - std::string topic = std::string(topic_buffer, topic_length); - INFO("Complete topic = [%s], topic_len = %zu", topic.c_str(), topic_length); - - return topic; -} - -bool Module::ReceiveDecryptedPayload(Module::TCPData *connect_info, size_t &szmsg, char **msg) -{ - ReceiveDecryptedData(connect_info, static_cast(&szmsg), sizeof(szmsg)); - if (szmsg == 0) { + void *topic_data = nullptr; + int ret = tcp_data->client->RecvSizedData(&topic_data, topic_length); + if (ret < 0) { ERR("Got a disconnection message."); - return false; - } - - if (UINT32_MAX == szmsg) { - // Distinguish between connection problems and zero-size messages - INFO("Got a zero-size message. Skip this payload transmission."); - szmsg = 0; - } else { - *msg = static_cast(malloc(szmsg)); - ReceiveDecryptedData(connect_info, static_cast(*msg), szmsg); - } - - return true; -} - -void Module::ReceiveDecryptedData(Module::TCPData *connect_info, void *data, size_t data_length) -{ - size_t padding_buffer_size = - connect_info->parent->aes_encryptor->GetPaddingBufferSize(data_length); - DBG("data_length = %zu, padding_buffer_size = %zu", data_length, padding_buffer_size); - - unsigned char padding_buffer[padding_buffer_size]; - ReceiveExactSize(connect_info, static_cast(padding_buffer), padding_buffer_size); - - connect_info->parent->aes_encryptor->GetDecryptedData( - padding_buffer, padding_buffer_size, data_length, data); -} - -void Module::ReceiveExactSize(Module::TCPData *connect_info, void *data, size_t data_length) -{ - if (data_length == 0) { - DBG("data_length is zero."); - return; - } - - size_t remaining_size = data_length; - while (0 < remaining_size) { - char *data_index = (char *)data + (data_length - remaining_size); - size_t size_received = remaining_size; - connect_info->client->Recv(data_index, size_received); - if (size_received > 0) { - remaining_size -= size_received; - } else if (size_received == 0) { - DBG("size_received == 0"); - remaining_size = 0; - } + HandleClientDisconnect(tcp_data->client->GetHandle()); + return std::string(); } -} - -std::string Module::ReceiveTopic(Module::TCPData *connect_info) -{ - size_t topic_length = 0; - ReceiveExactSize(connect_info, static_cast(&topic_length), sizeof(topic_length)); - - if (AITT_TOPIC_NAME_MAX < topic_length) { - ERR("Invalid topic name length(%zu)", topic_length); + if (nullptr == topic_data) { + ERR("Unknown topic"); return std::string(); } - char topic_buffer[topic_length]; - ReceiveExactSize(connect_info, topic_buffer, topic_length); - std::string topic = std::string(topic_buffer, topic_length); + std::string topic = std::string(static_cast(topic_data), topic_length); INFO("Complete topic = [%s], topic_len = %zu", topic.c_str(), topic_length); + free(topic_data); return topic; } -bool Module::ReceivePayload(Module::TCPData *connect_info, size_t &szmsg, char **msg) -{ - ReceiveExactSize(connect_info, static_cast(&szmsg), sizeof(szmsg)); - if (szmsg == 0) { - ERR("Got a disconnection message."); - return false; - } - ERR("szmsg = [%zu]", szmsg); - if (UINT32_MAX == szmsg) { - // Distinguish between connection problems and zero-size messages - INFO("Got a zero-size message. Skip this payload transmission."); - szmsg = 0; - } else { - *msg = static_cast(malloc(szmsg)); - ReceiveExactSize(connect_info, static_cast(*msg), szmsg); - } - - return true; -} - void Module::AcceptConnection(MainLoopHandler::MainLoopResult result, int handle, MainLoopHandler::MainLoopData *user_data) { - // TODO: Update the discovery map - std::unique_ptr client; TCPServerData *listen_info = dynamic_cast(user_data); + RET_IF(listen_info == nullptr); Module *impl = listen_info->impl; + RET_IF(impl == nullptr); + + std::unique_ptr client; { std::lock_guard autoLock(impl->subscribeTableLock); @@ -635,7 +406,6 @@ void Module::AcceptConnection(MainLoopHandler::MainLoopResult result, int handle return; client = clientIt->second->AcceptPeer(); - INFO("A TCP connection (client handle=%d) is created.", client->GetHandle()); } if (client == nullptr) { @@ -643,61 +413,52 @@ void Module::AcceptConnection(MainLoopHandler::MainLoopResult result, int handle return; } - int cHandle = client->GetHandle(); - listen_info->client_list.push_back(cHandle); + int client_handle = client->GetHandle(); + listen_info->client_list.push_back(client_handle); TCPData *ecd = new TCPData; ecd->parent = listen_info; ecd->client = std::move(client); - impl->main_loop.AddWatch(cHandle, ReceiveData, ecd); + impl->main_loop.AddWatch(client_handle, ReceiveData, ecd); } void Module::UpdatePublishTable(const std::string &topic, const std::string &clientId, - unsigned short port, const unsigned char *key) + const TCP::ConnectInfo &info) { auto topicIt = publishTable.find(topic); - std::unique_ptr keyInfo(new TCPPublishInfo()); - if (key == nullptr) { - keyInfo = nullptr; - } else { - keyInfo->client_handle = nullptr; - keyInfo->aes_encryptor = new AESEncryptor(key); - } - if (topicIt == publishTable.end()) { PortMap portMap; - portMap.insert(PortMap::value_type(port, std::move(keyInfo))); + portMap.insert(PortMap::value_type(info, nullptr)); HostMap hostMap; hostMap.insert(HostMap::value_type(clientId, std::move(portMap))); publishTable.insert(PublishMap::value_type(topic, std::move(hostMap))); - INFO("A topic(%s) is inserted to the publish table.", topic.c_str()); return; } auto hostIt = topicIt->second.find(clientId); if (hostIt == topicIt->second.end()) { PortMap portMap; - portMap.insert(PortMap::value_type(port, std::move(keyInfo))); + portMap.insert(PortMap::value_type(info, nullptr)); topicIt->second.insert(HostMap::value_type(clientId, std::move(portMap))); - INFO("A HostMap element is added, clientId(%s).", clientId.c_str()); return; } - // NOTE: - // The current implementation only has a single port entry - // Therefore, if the hostIt is not empty, there is the previous connection if (!hostIt->second.empty()) { + ERR("there is the previous connection(The current implementation only has a single port " + "entry)"); auto portIt = hostIt->second.begin(); - INFO("A client handle already exists. port = %d", port); - if (portIt->first == port) - return; // Nothing is changed, keep the current handle - // Otherwise, delete the connection handle - // to make a new connection with the new port + if (portIt->first.port == info.port) { + DBG("nothing changed. keep the current handle"); + return; + } + + DBG("delete the connection handle to make a new connection with the new port"); hostIt->second.clear(); } - INFO("A PortMap element is inserted. clientId(%s), port = %d", clientId.c_str(), port); - hostIt->second.insert(PortMap::value_type(port, std::move(keyInfo))); + hostIt->second.insert(PortMap::value_type(info, nullptr)); } + +} // namespace AittTCPNamespace diff --git a/modules/tcp/Module.h b/modules/tcp/Module.h index 408a5f6..ed748a1 100644 --- a/modules/tcp/Module.h +++ b/modules/tcp/Module.h @@ -31,31 +31,27 @@ using AittTransport = aitt::AittTransport; using MainLoopHandler = aitt::MainLoopHandler; using AittDiscovery = aitt::AittDiscovery; +#define MODULE_NAMESPACE AittTCPNamespace +namespace AittTCPNamespace { + class Module : public AittTransport { public: - explicit Module(AittProtocol protocol, const std::string &ip, AittDiscovery &discovery); + explicit Module(AittProtocol type, AittDiscovery &discovery, const std::string &ip); virtual ~Module(void); void Publish(const std::string &topic, const void *data, const size_t datalen, - AittQoS qos = AITT_QOS_AT_MOST_ONCE, bool retain = false) override; - - void Publish_(const std::string &topic, const void *data, const size_t datalen, - const std::string &correlation, AittQoS qos, bool retain); - - void Publish(const std::string &topic, const void *data, const size_t datalen, const std::string &correlation, AittQoS qos = AITT_QOS_AT_MOST_ONCE, bool retain = false) override; + void Publish(const std::string &topic, const void *data, const size_t datalen, + AittQoS qos = AITT_QOS_AT_MOST_ONCE, bool retain = false) override; + void *Subscribe(const std::string &topic, const SubscribeCallback &cb, void *cbdata = nullptr, AittQoS qos = AITT_QOS_AT_MOST_ONCE) override; - void *Subscribe_(const std::string &topic, const AittTransport::SubscribeCallback &cb, - void *cbdata, AittQoS qos); - void *Subscribe(const std::string &topic, const SubscribeCallback &cb, const void *data, const size_t datalen, void *cbdata = nullptr, AittQoS qos = AITT_QOS_AT_MOST_ONCE) override; - void *Unsubscribe(void *handle) override; private: @@ -66,8 +62,6 @@ class Module : public AittTransport { std::string topic; std::vector client_list; std::mutex client_lock; - bool is_secure; - AESEncryptor *aes_encryptor; }; struct TCPData : public MainLoopHandler::MainLoopData { @@ -75,11 +69,6 @@ class Module : public AittTransport { std::unique_ptr client; }; - struct TCPPublishInfo { - std::unique_ptr client_handle; - AESEncryptor *aes_encryptor; - }; - // SubscribeTable // map { // "/customTopic/mytopic": $serverHandle, @@ -115,7 +104,8 @@ class Module : public AittTransport { // TCP handle should be the unique_ptr, so if we delete the entry from the map, // the handle must be released automatically // in order to make the handle "unique_ptr", it should be a class object not the "void *" - using PortMap = std::map>; + using PortMap = + std::map, TCP::ConnectInfo::Compare>; using HostMap = std::map; using PublishMap = std::map; @@ -124,29 +114,13 @@ class Module : public AittTransport { void DiscoveryMessageCallback(const std::string &clientId, const std::string &status, const void *msg, const int szmsg); void UpdateDiscoveryMsg(); - void ThreadMain(void); - std::string FindHost(HostMap::iterator &host_iterator); - bool SendEncryptedTopic(const std::string &topic, Module::PortMap::iterator &portIt); - void SendEncryptedData( - Module::PortMap::iterator &port_iterator, const void *data, size_t data_length); - void SendExactSize( - Module::PortMap::iterator &port_iterator, const void *data, size_t data_length); - void SendEncryptedPayload( - const size_t &datalen, Module::PortMap::iterator &portIt, const void *data); - bool SendTopic(const std::string &topic, Module::PortMap::iterator &portIt); - void SendPayload(const size_t &datalen, Module::PortMap::iterator &portIt, const void *data); static void ReceiveData(MainLoopHandler::MainLoopResult result, int handle, MainLoopHandler::MainLoopData *watchData); void HandleClientDisconnect(int handle); - std::string ReceiveDecryptedTopic(TCPData *connect_info); - bool ReceiveDecryptedPayload(Module::TCPData *connect_info, size_t &szmsg, char **msg); - static void ReceiveDecryptedData(Module::TCPData *connect_info, void *data, size_t data_length); - static void ReceiveExactSize( - Module::TCPData *connect_info, void *data, size_t data_length); - std::string ReceiveTopic(TCPData *connect_info); - bool ReceivePayload(Module::TCPData *connect_info, size_t &szmsg, char **msg); - void UpdatePublishTable(const std::string &topic, const std::string &host, unsigned short port, - const unsigned char *key); + std::string GetTopicName(TCPData *connect_info); + void ThreadMain(void); + void UpdatePublishTable(const std::string &topic, const std::string &host, + const TCP::ConnectInfo &info); MainLoopHandler main_loop; std::thread aittThread; @@ -158,7 +132,8 @@ class Module : public AittTransport { std::mutex subscribeTableLock; ClientMap clientTable; std::mutex clientTableLock; - - AittProtocol protocol; std::string ip; + bool secure; }; + +} // namespace AittTCPNamespace diff --git a/modules/tcp/TCP.cc b/modules/tcp/TCP.cc index 9984606..4815e20 100644 --- a/modules/tcp/TCP.cc +++ b/modules/tcp/TCP.cc @@ -15,6 +15,7 @@ */ #include "TCP.h" +#include #include #include #include @@ -28,12 +29,15 @@ #include "aitt_internal.h" -TCP::TCP(const std::string &host, unsigned short port) : handle(-1), addrlen(0), addr(nullptr) +namespace AittTCPNamespace { + +TCP::TCP(const std::string &host, const ConnectInfo &connect_info) + : handle(-1), addrlen(0), addr(nullptr), secure(false) { int ret = 0; do { - if (port == 0) { + if (connect_info.port == 0) { ret = EINVAL; break; } @@ -57,16 +61,16 @@ TCP::TCP(const std::string &host, unsigned short port) : handle(-1), addrlen(0), break; } - inet_addr->sin_port = htons(port); + inet_addr->sin_port = htons(connect_info.port); inet_addr->sin_family = AF_INET; ret = connect(handle, addr, addrlen); if (ret < 0) { - ERR("connect() Fail(%s, %d)", host.c_str(), port); + ERR("connect() Fail(%s, %d)", host.c_str(), connect_info.port); break; } - SetupOptions(); + SetupOptions(connect_info); return; } while (0); @@ -79,10 +83,10 @@ TCP::TCP(const std::string &host, unsigned short port) : handle(-1), addrlen(0), throw std::runtime_error(strerror(ret)); } -TCP::TCP(int handle, sockaddr *addr, socklen_t szAddr) - : handle(handle), addrlen(szAddr), addr(addr) +TCP::TCP(int handle, sockaddr *addr, socklen_t szAddr, const ConnectInfo &connect_info) + : handle(handle), addrlen(szAddr), addr(addr), secure(false) { - SetupOptions(); + SetupOptions(connect_info); } TCP::~TCP(void) @@ -95,7 +99,7 @@ TCP::~TCP(void) ERR_CODE(errno, "close"); } -void TCP::SetupOptions(void) +void TCP::SetupOptions(const ConnectInfo &connect_info) { int on = 1; @@ -103,28 +107,72 @@ void TCP::SetupOptions(void) if (ret < 0) { ERR_CODE(errno, "delay option setting failed"); } + + if (connect_info.secure) { + secure = true; + crypto.Init(connect_info.key, connect_info.iv); + } } void TCP::Send(const void *data, size_t &szData) { - int ret = send(handle, data, szData, 0); - if (ret < 0) { - ERR("Fail to send data, handle = %d, size = %zu", handle, szData); - throw std::runtime_error(strerror(errno)); + size_t sent = 0; + while (sent < szData) { + int ret = send(handle, static_cast(data) + sent, szData - sent, 0); + if (ret < 0) { + ERR("Fail to send data, handle = %d, size = %zu", handle, szData); + throw std::runtime_error(strerror(errno)); + } + + sent += ret; } + szData = sent; +} - szData = ret; +void TCP::SendSizedData(const void *data, size_t &szData) +{ + if (secure) + SendSizedDataSecure(data, szData); + else + SendSizedDataNormal(data, szData); } -void TCP::Recv(void *data, size_t &szData) +int TCP::Recv(void *data, size_t &szData) { - int ret = recv(handle, data, szData, 0); - if (ret < 0) { - ERR("Fail to recv data, handle = %d, size = %zu", handle, szData); - throw std::runtime_error(strerror(errno)); + size_t received = 0; + while (received < szData) { + int ret = recv(handle, static_cast(data) + received, szData - received, 0); + if (ret < 0) { + ERR("Fail to recv data, handle = %d, size = %zu", handle, szData); + throw std::runtime_error(strerror(errno)); + } + if (ret == 0) { + ERR("disconnected"); + return -1; + } + + received += ret; } - szData = ret; + szData = received; + return 0; +} + +int TCP::RecvSizedData(void **data, size_t &szData) +{ + if (secure) + return RecvSizedDataSecure(data, szData); + else + return RecvSizedDataNormal(data, szData); +} + +int TCP::HandleZeroMsg(void **data, size_t &data_size) +{ + // distinguish between connection problems and zero-size messages + INFO("Got a zero-size message."); + data_size = 0; + *data = nullptr; + return 0; } int TCP::GetHandle(void) @@ -156,3 +204,103 @@ unsigned short TCP::GetPort(void) return ntohs(addr.sin_port); } + +void TCP::SendSizedDataNormal(const void *data, size_t &data_size) +{ + size_t fixed_data_size = data_size; + if (0 == data_size) { + // distinguish between connection problems and zero-size messages + INFO("Send a zero-size message."); + fixed_data_size = UINT32_MAX; + } + + size_t size_len = sizeof(fixed_data_size); + Send(static_cast(&fixed_data_size), size_len); + Send(data, data_size); +} + +int TCP::RecvSizedDataNormal(void **data, size_t &data_size) +{ + int ret; + + size_t data_len = 0; + size_t size_len = sizeof(data_len); + ret = Recv(static_cast(&data_len), size_len); + if (ret < 0) { + ERR("Recv() Fail(%d)", ret); + return ret; + } + + if (data_len == UINT32_MAX) + return HandleZeroMsg(data, data_size); + + void *data_buf = malloc(data_len); + Recv(data_buf, data_len); + data_size = data_len; + *data = data_buf; + + return 0; +} + +void TCP::SendSizedDataSecure(const void *data, size_t &data_size) +{ + size_t fixed_data_size = data_size; + if (0 == data_size) { + // distinguish between connection problems and zero-size messages + INFO("Send a zero-size message."); + fixed_data_size = UINT32_MAX; + } + + size_t size_len; + if (data_size) { + unsigned char data_buf[crypto.GetCryptogramSize(data_size)]; + size_t data_len = + crypto.Encrypt(static_cast(data), data_size, data_buf); + unsigned char size_buf[crypto.GetCryptogramSize(sizeof(size_t))]; + size_len = crypto.Encrypt((unsigned char *)&data_len, sizeof(data_len), size_buf); + Send(size_buf, size_len); + Send(data_buf, data_len); + } else { + unsigned char size_buf[crypto.GetCryptogramSize(sizeof(size_t))]; + size_len = + crypto.Encrypt((unsigned char *)&fixed_data_size, sizeof(fixed_data_size), size_buf); + Send(size_buf, size_len); + } +} + +int TCP::RecvSizedDataSecure(void **data, size_t &data_size) +{ + int ret; + + size_t cipher_size_len = crypto.GetCryptogramSize(sizeof(size_t)); + unsigned char cipher_size_buf[cipher_size_len]; + ret = Recv(cipher_size_buf, cipher_size_len); + if (ret < 0) { + ERR("Recv() Fail(%d)", ret); + return ret; + } + + unsigned char plain_size_buf[cipher_size_len]; + size_t cipher_data_len = 0; + crypto.Decrypt(cipher_size_buf, cipher_size_len, plain_size_buf); + memcpy(&cipher_data_len, plain_size_buf, sizeof(cipher_data_len)); + if (cipher_data_len == UINT32_MAX) + return HandleZeroMsg(data, data_size); + + if (AITT_MESSAGE_MAX < cipher_data_len) { + ERR("Invalid Size(%zu)", cipher_data_len); + return -1; + } + unsigned char cipher_data_buf[cipher_data_len]; + Recv(cipher_data_buf, cipher_data_len); + unsigned char *data_buf = static_cast(malloc(cipher_data_len)); + data_size = crypto.Decrypt(cipher_data_buf, cipher_data_len, data_buf); + *data = data_buf; + return 0; +} + +TCP::ConnectInfo::ConnectInfo() : port(0), secure(false), key(), iv() +{ +} + +} // namespace AittTCPNamespace diff --git a/modules/tcp/TCP.h b/modules/tcp/TCP.h index 535819c..eac67af 100644 --- a/modules/tcp/TCP.h +++ b/modules/tcp/TCP.h @@ -20,24 +20,53 @@ #include +#include "AESEncryptor.h" + +namespace AittTCPNamespace { + class TCP { public: class Server; + struct ConnectInfo { + struct Compare { + bool operator()(const ConnectInfo &lhs, const ConnectInfo &rhs) const + { + return lhs.port < rhs.port; + } + }; - TCP(const std::string &host, unsigned short port); + ConnectInfo(); + unsigned short port; + bool secure; + unsigned char key[AITT_TCP_ENCRYPTOR_KEY_LEN]; + unsigned char iv[AITT_TCP_ENCRYPTOR_IV_LEN]; + }; + + TCP(const std::string &host, const ConnectInfo &ConnectInfo); virtual ~TCP(void); void Send(const void *data, size_t &szData); - void Recv(void *data, size_t &szData); + void SendSizedData(const void *data, size_t &szData); + int Recv(void *data, size_t &szData); + int RecvSizedData(void **data, size_t &szData); int GetHandle(void); unsigned short GetPort(void); void GetPeerInfo(std::string &host, unsigned short &port); private: - TCP(int handle, sockaddr *addr, socklen_t addrlen); - void SetupOptions(void); + TCP(int handle, sockaddr *addr, socklen_t addrlen, const ConnectInfo &connect_info); + void SetupOptions(const ConnectInfo &connect_info); + int HandleZeroMsg(void **data, size_t &data_size); + void SendSizedDataNormal(const void *data, size_t &data_size); + int RecvSizedDataNormal(void **data, size_t &data_size); + void SendSizedDataSecure(const void *data, size_t &data_size); + int RecvSizedDataSecure(void **data, size_t &data_size); int handle; socklen_t addrlen; sockaddr *addr; + bool secure; + AESEncryptor crypto; }; + +} // namespace AittTCPNamespace diff --git a/modules/tcp/TCPServer.cc b/modules/tcp/TCPServer.cc index 3b912b2..5a215ba 100644 --- a/modules/tcp/TCPServer.cc +++ b/modules/tcp/TCPServer.cc @@ -29,8 +29,10 @@ #define BACKLOG 10 // Accept only 10 simultaneously connections by default -TCP::Server::Server(const std::string &host, unsigned short &port) - : handle(-1), addr(nullptr), addrlen(0) +namespace AittTCPNamespace { + +TCP::Server::Server(const std::string &host, unsigned short &port, bool is_secure) + : handle(-1), addr(nullptr), addrlen(0), secure(is_secure), key(), iv() { int ret = 0; @@ -72,6 +74,9 @@ TCP::Server::Server(const std::string &host, unsigned short &port) if (ret < 0) break; + if (secure) + AESEncryptor::GenerateKey(key, iv); + return; } while (0); @@ -92,7 +97,6 @@ TCP::Server::~Server(void) return; free(addr); - if (close(handle) < 0) ERR_CODE(errno, "close"); } @@ -112,8 +116,13 @@ std::unique_ptr TCP::Server::AcceptPeer(void) free(peerAddr); throw std::runtime_error(strerror(errno)); } - - return std::unique_ptr(new TCP(peerHandle, peerAddr, szAddr)); + ConnectInfo info; + if (secure) { + info.secure = true; + memcpy(info.key, key, sizeof(key)); + memcpy(info.iv, iv, sizeof(iv)); + } + return std::unique_ptr(new TCP(peerHandle, peerAddr, szAddr, info)); } int TCP::Server::GetHandle(void) @@ -132,17 +141,14 @@ unsigned short TCP::Server::GetPort(void) return ntohs(addr.sin_port); } -void TCP::Server::CreateAESEncryptor(void) +const unsigned char *TCP::Server::GetCryptoKey(void) { - aes_encryptor = new AESEncryptor(); + return key; } -AESEncryptor *TCP::Server::GetAESEncryptor(void) +const unsigned char *TCP::Server::GetCryptoIv(void) { - return aes_encryptor; + return iv; } -const unsigned char *TCP::Server::GetKey(void) -{ - return aes_encryptor->GetCipherKey(); -} +} // namespace AittTCPNamespace diff --git a/modules/tcp/TCPServer.h b/modules/tcp/TCPServer.h index fdf3f58..7230c78 100644 --- a/modules/tcp/TCPServer.h +++ b/modules/tcp/TCPServer.h @@ -18,27 +18,29 @@ #include #include -#include "AESEncryptor.h" #include "TCP.h" +namespace AittTCPNamespace { + class TCP::Server { public: - Server(const std::string &host, unsigned short &port); - Server(const Server &) = default; - Server &operator=(const Server &) = default; + Server(const std::string &host, unsigned short &port, bool secure = false); virtual ~Server(void); std::unique_ptr AcceptPeer(void); int GetHandle(void); unsigned short GetPort(void); - void CreateAESEncryptor(void); - AESEncryptor *GetAESEncryptor(void); - const unsigned char *GetKey(void); + const unsigned char *GetCryptoKey(void); + const unsigned char *GetCryptoIv(void); private: int handle; sockaddr *addr; socklen_t addrlen; - AESEncryptor *aes_encryptor; + bool secure; + unsigned char key[AITT_TCP_ENCRYPTOR_KEY_LEN]; + unsigned char iv[AITT_TCP_ENCRYPTOR_IV_LEN]; }; + +} // namespace AittTCPNamespace diff --git a/modules/tcp/samples/CMakeLists.txt b/modules/tcp/samples/CMakeLists.txt index 7f071e6..836512c 100644 --- a/modules/tcp/samples/CMakeLists.txt +++ b/modules/tcp/samples/CMakeLists.txt @@ -1,3 +1,7 @@ -ADD_EXECUTABLE("aitt_tcp_test" tcp_test.cc $) -TARGET_LINK_LIBRARIES("aitt_tcp_test" ${PROJECT_NAME} Threads::Threads ${AITT_NEEDS_LIBRARIES} ${AITT_TCP_NEEDS_LIBRARIES}) +PKG_CHECK_MODULES(SAMPLE_NEEDS REQUIRED glib-2.0 ${TIZEN_LOG_PKG}) +INCLUDE_DIRECTORIES(${SAMPLE_NEEDS_INCLUDE_DIRS}) +LINK_DIRECTORIES(${SAMPLE_NEEDS_LIBRARY_DIRS}) + +ADD_EXECUTABLE("aitt_tcp_test" tcp_test.cc) +TARGET_LINK_LIBRARIES("aitt_tcp_test" TCP_OBJ ${SAMPLE_NEEDS_LIBRARIES} ${AITT_TCP_NEEDS_LIBRARIES}) INSTALL(TARGETS "aitt_tcp_test" DESTINATION ${AITT_TEST_BINDIR}) diff --git a/modules/tcp/samples/tcp_test.cc b/modules/tcp/samples/tcp_test.cc index f550c60..9bdca5c 100644 --- a/modules/tcp/samples/tcp_test.cc +++ b/modules/tcp/samples/tcp_test.cc @@ -33,6 +33,8 @@ __thread __aitt__tls__ __aitt; #define BYE_STRING "bye" #define SEND_INTERVAL 1000 +using namespace AittTCPNamespace; + class AittTcpSample { public: AittTcpSample(const std::string &host, unsigned short &port) @@ -202,7 +204,9 @@ int main(int argc, char *argv[]) SEND_INTERVAL, [](gpointer data) -> gboolean { Main *ctx = static_cast
(data); - std::unique_ptr client(new TCP(ctx->host, ctx->port)); + TCP::ConnectInfo info; + info.port = ctx->port; + std::unique_ptr client(new TCP(ctx->host, info)); INFO("Assigned client port: %u", client->GetPort()); diff --git a/modules/tcp/tests/AESEncryptor_test.cc b/modules/tcp/tests/AESEncryptor_test.cc old mode 100755 new mode 100644 index 4c22ef0..a5a45d1 --- a/modules/tcp/tests/AESEncryptor_test.cc +++ b/modules/tcp/tests/AESEncryptor_test.cc @@ -21,69 +21,44 @@ #include "aitt_internal.h" -static constexpr unsigned char TEST_CIPHER_KEY[] = {0x2b, 0x7e, 0x15, 0x16, 0x28, 0xae, 0xd2, 0xa6, 0xab, 0xf7, 0x15, 0x88, 0x09, 0xcf, 0x4f, 0x3c}; -static const std::string TEST_MESSAGE("TCP encryptions."); - -class AESEncryptorTest : public testing::Test { - public: - static void PrintKey(const unsigned char *key) - { - for (int i = 0; i < AESEncryptor::AES_KEY_BYTE_SIZE / 8; i++) { - DBG("%u %u %u %u %u %u %u %u", - key[8 * i + 0], key[8 * i + 1], key[8 * i + 2], key[8 * i + 3], key[8 * i + 4], key[8 * i + 5], key[8 * i + 6], key[8 * i + 7]); - } - } -}; - -TEST(AESEncryptor, Positive_Create_Anytime) -{ - std::unique_ptr aes_encryptor(new AESEncryptor()); - ASSERT_NE(aes_encryptor, nullptr); -} +static constexpr unsigned char TEST_CIPHER_KEY[] = {0x2b, 0x7e, 0x15, 0x16, 0x28, 0xae, 0xd2, 0xa6, + 0xab, 0xf7, 0x15, 0x88, 0x09, 0xcf, 0x4f, 0x3c, 0x2b, 0x7e, 0x15, 0x16, 0x28, 0xae, 0xd2, + 0xa6, 0xab, 0xf7, 0x15, 0x88, 0x09, 0xcf, 0x4f, 0x3c}; +static constexpr unsigned char TEST_CIPHER_IV[] = {0x2b, 0x7e, 0x15, 0x16, 0x28, 0xae, 0xd2, 0xa6, + 0xab, 0xf7, 0x15, 0x88, 0x09, 0xcf, 0x4f, 0x3c}; -TEST(AESEncryptor, Positive_CreateWithArgument_Anytime) -{ - std::unique_ptr aes_encryptor(new AESEncryptor(TEST_CIPHER_KEY)); - std::string aes_encryptor_key(reinterpret_cast(aes_encryptor->GetCipherKey())); - std::string test_key(reinterpret_cast(TEST_CIPHER_KEY), AESEncryptor::AES_KEY_BYTE_SIZE); - ASSERT_STREQ(aes_encryptor_key.c_str(), test_key.c_str()); -} +static const std::string TEST_MESSAGE("TCP encryptions."); -TEST(AESEncryptor, Positive_GenerateRandomKeys_Anytime) -{ - std::unique_ptr aes_encryptor_first(new AESEncryptor()); - std::unique_ptr aes_encryptor_second(new AESEncryptor()); - std::string first_key(reinterpret_cast(aes_encryptor_first->GetCipherKey()), AESEncryptor::AES_KEY_BYTE_SIZE); - std::string second_key(reinterpret_cast(aes_encryptor_second->GetCipherKey()), AESEncryptor::AES_KEY_BYTE_SIZE); - ASSERT_STRNE(first_key.c_str(), second_key.c_str()); -} +using namespace AittTCPNamespace; -TEST(AESEncryptor, Positive_Encrypt_Anytime) +TEST(AESEncryptor, Encrypt_P_Anytime) { - std::unique_ptr aes_encryptor(new AESEncryptor()); - AESEncryptorTest::PrintKey(aes_encryptor->GetCipherKey()); - try { - unsigned char encryption_buffer[AESEncryptor::AES_KEY_BYTE_SIZE]; - aes_encryptor->Encrypt(reinterpret_cast(TEST_MESSAGE.c_str()), encryption_buffer); + AESEncryptor encryptor; + encryptor.Init(TEST_CIPHER_KEY, TEST_CIPHER_IV); + + unsigned char encryption_buffer[encryptor.GetCryptogramSize(TEST_MESSAGE.size())]; + encryptor.Encrypt(reinterpret_cast(TEST_MESSAGE.c_str()), + TEST_MESSAGE.size(), encryption_buffer); } catch (std::exception &e) { - ASSERT_STREQ(e.what(), strerror(EINVAL)); + FAIL() << "Unexpected exception: " << e.what(); } } -TEST(AESEncryptor, Positive_EncryptDecryped_Anytime) +TEST(AESEncryptor, EncryptDecryped_P_Anytime) { - std::unique_ptr aes_encryptor(new AESEncryptor()); - AESEncryptorTest::PrintKey(aes_encryptor->GetCipherKey()); - try { - unsigned char encryption_buffer[AESEncryptor::AES_KEY_BYTE_SIZE]; - unsigned char decryption_buffer[AESEncryptor::AES_KEY_BYTE_SIZE]; - aes_encryptor->Encrypt(reinterpret_cast(TEST_MESSAGE.c_str()), encryption_buffer); - aes_encryptor->Decrypt(encryption_buffer, decryption_buffer); - std::string decrypted_message(reinterpret_cast(decryption_buffer), AESEncryptor::AES_KEY_BYTE_SIZE); - DBG("TEST_MESSAGE = (%s), decrypted_message = (%s)", TEST_MESSAGE.c_str(), decrypted_message.c_str()); - ASSERT_STREQ(decrypted_message.c_str(), TEST_MESSAGE.c_str()); + AESEncryptor encryptor; + encryptor.Init(TEST_CIPHER_KEY, TEST_CIPHER_IV); + + unsigned char ciphertext[encryptor.GetCryptogramSize(TEST_MESSAGE.size())]; + unsigned char plaintext[encryptor.GetCryptogramSize(TEST_MESSAGE.size())]; + size_t len = + encryptor.Encrypt(reinterpret_cast(TEST_MESSAGE.c_str()), + TEST_MESSAGE.size(), ciphertext); + len = encryptor.Decrypt(ciphertext, len, plaintext); + plaintext[len] = 0; + ASSERT_STREQ(TEST_MESSAGE.c_str(), reinterpret_cast(plaintext)); } catch (std::exception &e) { ASSERT_STREQ(e.what(), strerror(EINVAL)); } diff --git a/modules/tcp/tests/CMakeLists.txt b/modules/tcp/tests/CMakeLists.txt index 72d0177..5974a44 100644 --- a/modules/tcp/tests/CMakeLists.txt +++ b/modules/tcp/tests/CMakeLists.txt @@ -1,4 +1,4 @@ -PKG_CHECK_MODULES(UT_NEEDS REQUIRED gmock_main) +PKG_CHECK_MODULES(UT_NEEDS REQUIRED gmock_main ${TIZEN_LOG_PKG}) INCLUDE_DIRECTORIES(${UT_NEEDS_INCLUDE_DIRS}) LINK_DIRECTORIES(${UT_NEEDS_LIBRARY_DIRS}) @@ -6,8 +6,8 @@ SET(AITT_TCP_UT ${PROJECT_NAME}_tcp_ut) SET(AITT_TCP_UT_SRC TCP_test.cc TCPServer_test.cc AESEncryptor_test.cc) -ADD_EXECUTABLE(${AITT_TCP_UT} ${AITT_TCP_UT_SRC} $) -TARGET_LINK_LIBRARIES(${AITT_TCP_UT} ${UT_NEEDS_LIBRARIES} Threads::Threads ${AITT_NEEDS_LIBRARIES} ${AITT_TCP_NEEDS_LIBRARIES}) +ADD_EXECUTABLE(${AITT_TCP_UT} ${AITT_TCP_UT_SRC}) +TARGET_LINK_LIBRARIES(${AITT_TCP_UT} TCP_OBJ Threads::Threads ${UT_NEEDS_LIBRARIES} ${AITT_TCP_NEEDS_LIBRARIES}) INSTALL(TARGETS ${AITT_TCP_UT} DESTINATION ${AITT_TEST_BINDIR}) ADD_TEST( diff --git a/modules/tcp/tests/TCPServer_test.cc b/modules/tcp/tests/TCPServer_test.cc index e3eb62e..d2206aa 100644 --- a/modules/tcp/tests/TCPServer_test.cc +++ b/modules/tcp/tests/TCPServer_test.cc @@ -28,6 +28,8 @@ #define TEST_SERVER_PORT 8123 #define TEST_SERVER_AVAILABLE_PORT 0 +using namespace AittTCPNamespace; + TEST(TCPServer, Positive_Create_Anytime) { unsigned short port = TEST_SERVER_PORT; @@ -108,7 +110,9 @@ TEST(TCPServer, Positive_AcceptPeer_Anytime) { std::unique_lock lk(m); ready_cv.wait(lk, [&ready] { return ready; }); - std::unique_ptr tcp(new TCP(TEST_SERVER_ADDRESS, serverPort)); + TCP::ConnectInfo info; + info.port = serverPort; + std::unique_ptr tcp(new TCP(TEST_SERVER_ADDRESS, info)); connected_cv.wait(lk, [&connected] { return connected; }); } diff --git a/modules/tcp/tests/TCP_test.cc b/modules/tcp/tests/TCP_test.cc index d65875c..d356702 100644 --- a/modules/tcp/tests/TCP_test.cc +++ b/modules/tcp/tests/TCP_test.cc @@ -22,7 +22,7 @@ #include #include -#include "TCPServer.h" +#include "../TCPServer.h" #define TEST_SERVER_ADDRESS "127.0.0.1" #define TEST_SERVER_INVALID_ADDRESS "287.0.0.1" @@ -32,6 +32,8 @@ #define TEST_BUFFER_HELLO "Hello World" #define TEST_BUFFER_BYE "Good Bye" +using namespace AittTCPNamespace; + class TCPTest : public testing::Test { protected: void SetUp() override @@ -43,7 +45,9 @@ class TCPTest : public testing::Test { clientThread = std::thread([this](void) mutable -> void { std::unique_lock lk(m); ready_cv.wait(lk, [this] { return ready; }); - client = std::unique_ptr(new TCP(TEST_SERVER_ADDRESS, serverPort)); + TCP::ConnectInfo info; + info.port = serverPort; + client = std::unique_ptr(new TCP(TEST_SERVER_ADDRESS, info)); customTest(); }); @@ -78,7 +82,9 @@ class TCPTest : public testing::Test { TEST(TCP, Negative_Create_InvalidPort_Anytime) { try { - std::unique_ptr tcp(new TCP(TEST_SERVER_ADDRESS, TEST_SERVER_AVAILABLE_PORT)); + TCP::ConnectInfo info; + info.port = TEST_SERVER_AVAILABLE_PORT; + std::unique_ptr tcp(new TCP(TEST_SERVER_ADDRESS, info)); ASSERT_EQ(tcp, nullptr); } catch (std::exception &e) { ASSERT_STREQ(e.what(), strerror(EINVAL)); @@ -88,7 +94,9 @@ TEST(TCP, Negative_Create_InvalidPort_Anytime) TEST(TCP, Negative_Create_InvalidAddress_Anytime) { try { - std::unique_ptr tcp(new TCP(TEST_SERVER_INVALID_ADDRESS, TEST_SERVER_PORT)); + TCP::ConnectInfo info; + info.port = TEST_SERVER_PORT; + std::unique_ptr tcp(new TCP(TEST_SERVER_INVALID_ADDRESS, info)); ASSERT_EQ(tcp, nullptr); } catch (std::exception &e) { ASSERT_STREQ(e.what(), strerror(EINVAL)); @@ -127,7 +135,7 @@ TEST_F(TCPTest, Positive_SendRecv_Anytime) char byeBuffer[TEST_BUFFER_SIZE]; customTest = [this, &helloBuffer](void) mutable -> void { - size_t szData = sizeof(helloBuffer); + size_t szData = sizeof(TEST_BUFFER_HELLO); client->Recv(static_cast(helloBuffer), szData); szData = sizeof(TEST_BUFFER_BYE); @@ -139,7 +147,7 @@ TEST_F(TCPTest, Positive_SendRecv_Anytime) size_t szMsg = sizeof(TEST_BUFFER_HELLO); peer->Send(TEST_BUFFER_HELLO, szMsg); - szMsg = sizeof(byeBuffer); + szMsg = sizeof(TEST_BUFFER_BYE); peer->Recv(static_cast(byeBuffer), szMsg); ASSERT_STREQ(helloBuffer, TEST_BUFFER_HELLO); diff --git a/modules/transport_entry.cc b/modules/transport_entry.cc index 00881e8..064697a 100644 --- a/modules/transport_entry.cc +++ b/modules/transport_entry.cc @@ -20,15 +20,15 @@ #include "Module.h" #include "aitt_internal_definitions.h" -extern "C" { +using namespace MODULE_NAMESPACE; -API void *AITT_TRANSPORT_NEW(AittProtocol protocol, const char *ip, AittDiscovery &discovery) +extern "C" { +API void *AITT_TRANSPORT_NEW(AittProtocol type, AittDiscovery &discovery, const std::string &my_ip) { assert(STR_EQ == strcmp(__func__, aitt::AittTransport::MODULE_ENTRY_NAME) && "Entry point name is not matched"); - std::string ip_address(ip); - Module *module = new Module(protocol, ip_address, discovery); + Module *module = new Module(type, discovery, my_ip); // validate that the module creates valid object (which inherits AittTransport) AittTransport *transport_module = dynamic_cast(module); diff --git a/modules/webrtc/CMakeLists.txt b/modules/webrtc/CMakeLists.txt index 1e9ed81..c8c3bed 100644 --- a/modules/webrtc/CMakeLists.txt +++ b/modules/webrtc/CMakeLists.txt @@ -13,7 +13,7 @@ LINK_DIRECTORIES(${AITT_WEBRTC_NEEDS_LIBRARY_DIRS}) FILE(GLOB AITT_WEBRTC_SRC *.cc) list(REMOVE_ITEM AITT_WEBRTC_SRC ${CMAKE_CURRENT_SOURCE_DIR}/Module.cc) ADD_LIBRARY(WEBRTC_OBJ OBJECT ${AITT_WEBRTC_SRC}) -ADD_LIBRARY(${AITT_WEBRTC} SHARED $ ../transport_entry.cc Module.cc) +ADD_LIBRARY(${AITT_WEBRTC} SHARED ../transport_entry.cc Module.cc $) TARGET_LINK_LIBRARIES(${AITT_WEBRTC} ${AITT_WEBRTC_NEEDS_LIBRARIES} ${AITT_COMMON}) TARGET_COMPILE_OPTIONS(${AITT_WEBRTC} PUBLIC ${AITT_WEBRTC_NEEDS_CFLAGS_OTHER}) diff --git a/modules/webrtc/Module.cc b/modules/webrtc/Module.cc index d68b759..5ef73ad 100644 --- a/modules/webrtc/Module.cc +++ b/modules/webrtc/Module.cc @@ -21,8 +21,10 @@ #include "Config.h" #include "aitt_internal.h" -Module::Module(AittProtocol protocol, const std::string &ip, AittDiscovery &discovery) - : AittTransport(discovery) +namespace AittWebRTCNamespace { + +Module::Module(AittProtocol type, AittDiscovery &discovery, const std::string &ip) + : AittTransport(type, discovery) { } @@ -36,8 +38,8 @@ void Module::Publish(const std::string &topic, const void *data, const size_t da // TODO } -void Module::Publish( - const std::string &topic, const void *data, const size_t datalen, AittQoS qos, bool retain) +void Module::Publish(const std::string &topic, const void *data, const size_t datalen, AittQoS qos, + bool retain) { std::lock_guard publish_table_lock(publish_table_lock_); @@ -126,3 +128,5 @@ void *Module::Unsubscribe(void *handlePtr) return ret; } + +} // namespace AittWebRTCNamespace diff --git a/modules/webrtc/Module.h b/modules/webrtc/Module.h index 8d52f0c..6f56f59 100644 --- a/modules/webrtc/Module.h +++ b/modules/webrtc/Module.h @@ -33,9 +33,12 @@ using AittTransport = aitt::AittTransport; using MainLoopHandler = aitt::MainLoopHandler; using AittDiscovery = aitt::AittDiscovery; +#define MODULE_NAMESPACE AittWebRTCNamespace +namespace AittWebRTCNamespace { + class Module : public AittTransport { public: - explicit Module(AittProtocol protocol, const std::string &ip, AittDiscovery &discovery); + explicit Module(AittProtocol type, AittDiscovery &discovery, const std::string &ip); virtual ~Module(void); // TODO: How about regarding topic as service name? @@ -64,3 +67,5 @@ class Module : public AittTransport { std::map> subscribe_table_; std::mutex subscribe_table_lock_; }; + +} // namespace AittWebRTCNamespace diff --git a/modules/webrtc/MqttServer.cc b/modules/webrtc/MqttServer.cc index b8a9c45..12fd8eb 100644 --- a/modules/webrtc/MqttServer.cc +++ b/modules/webrtc/MqttServer.cc @@ -15,14 +15,14 @@ */ #include "MqttServer.h" -#include "MQProxy.h" +#include "MosquittoMQ.h" #include "aitt_internal.h" #define MQTT_HANDLER_MSG_QOS 1 #define MQTT_HANDLER_MGMT_QOS 2 MqttServer::MqttServer(const Config &config) - : mq(new aitt::MQProxy(config.GetLocalId(), AittOption(true, false))), + : mq(new aitt::MosquittoMQ(config.GetLocalId(), true)), connection_state_(ConnectionState::Disconnected) { broker_ip_ = config.GetBrokerIp(); diff --git a/src/AITTImpl.cc b/src/AITTImpl.cc index 61f0fa2..4f6d298 100644 --- a/src/AITTImpl.cc +++ b/src/AITTImpl.cc @@ -23,7 +23,7 @@ #include #include -#include "MQProxy.h" +#include "MosquittoMQ.h" #include "aitt_internal.h" #define WEBRTC_ROOM_ID_PREFIX std::string(AITT_MANAGED_TOPIC_PREFIX "webrtc/room/Room.webrtc") @@ -34,23 +34,18 @@ namespace aitt { AITT::Impl::Impl(AITT &parent, const std::string &id, const std::string &my_ip, const AittOption &option) : public_api(parent), + discovery(id), + modules(my_ip, discovery), id_(id), mqtt_broker_port_(0), - mq(new MQProxy(id, option)), - discovery(id, option), - reply_id(0), - transports{0} + reply_id(0) { - // TODO: Validate my_ip - ModuleLoader loader; - for (ModuleLoader::Type i = ModuleLoader::TYPE_TCP; i < ModuleLoader::TYPE_TRANSPORT_MAX; - i = ModuleLoader::Type(i + 1)) { - module_handles.push_back(loader.OpenModule(i)); - const ModuleLoader::ModuleHandle &handle = module_handles.back(); - if (handle == nullptr) - ERR("OpenModule() Fail"); - - transports[i] = loader.LoadTransport(handle.get(), loader.GetProtocol(i), my_ip, discovery); + if (option.GetUseCustomMqttBroker()) { + mq = modules.NewCustomMQ(id, option); + discovery.SetMQ(modules.NewCustomMQ(id + 'd', option)); + } else { + mq = std::unique_ptr(new MosquittoMQ(id, option.GetClearSession())); + discovery.SetMQ(std::unique_ptr(new MosquittoMQ(id + 'd', option.GetClearSession()))); } aittThread = std::thread(&AITT::Impl::ThreadMain, this); } @@ -128,13 +123,9 @@ void AITT::Impl::UnsubscribeAll() mq->Unsubscribe(subscribe_info->second); break; case AITT_TYPE_TCP: - transports[ModuleLoader::TYPE_TCP]->Unsubscribe(subscribe_info->second); - break; - case AITT_TYPE_SECURE_TCP: - transports[ModuleLoader::TYPE_SECURE_TCP]->Unsubscribe(subscribe_info->second); - break; + case AITT_TYPE_TCP_SECURE: case AITT_TYPE_WEBRTC: - transports[ModuleLoader::TYPE_WEBRTC]->Unsubscribe(subscribe_info->second); + modules.Get(subscribe_info->first).Unsubscribe(subscribe_info->second); break; default: @@ -159,10 +150,10 @@ void AITT::Impl::Publish(const std::string &topic, const void *data, const size_ mq->Publish(topic, data, datalen, qos, retain); if ((protocols & AITT_TYPE_TCP) == AITT_TYPE_TCP) - transports[ModuleLoader::TYPE_TCP]->Publish(topic, data, datalen, qos, retain); + modules.Get(AITT_TYPE_TCP).Publish(topic, data, datalen, qos, retain); - if ((protocols & AITT_TYPE_SECURE_TCP) == AITT_TYPE_SECURE_TCP) - transports[ModuleLoader::TYPE_SECURE_TCP]->Publish(topic, data, datalen, qos, retain); + if ((protocols & AITT_TYPE_TCP_SECURE) == AITT_TYPE_TCP_SECURE) + modules.Get(AITT_TYPE_TCP_SECURE).Publish(topic, data, datalen, qos, retain); if ((protocols & AITT_TYPE_WEBRTC) == AITT_TYPE_WEBRTC) PublishWebRtc(topic, data, datalen, qos, retain); @@ -183,7 +174,7 @@ void AITT::Impl::PublishWebRtc(const std::string &topic, const void *data, const }); fbb.Finish(); auto buf = fbb.GetBuffer(); - transports[ModuleLoader::TYPE_WEBRTC]->Publish(topic, buf.data(), buf.size(), qos, retain); + modules.Get(AITT_TYPE_WEBRTC).Publish(topic, buf.data(), buf.size(), qos, retain); } AittSubscribeID AITT::Impl::Subscribe(const std::string &topic, const AITT::SubscribeCallback &cb, @@ -191,19 +182,16 @@ AittSubscribeID AITT::Impl::Subscribe(const std::string &topic, const AITT::Subs { SubscribeInfo *info = new SubscribeInfo(); info->first = protocol; - void *subscribe_handle; - INFO("[PROTOCOL] %d", static_cast(protocol)); + void *subscribe_handle; switch (protocol) { case AITT_TYPE_MQTT: subscribe_handle = SubscribeMQ(info, &main_loop, topic, cb, user_data, qos); break; case AITT_TYPE_TCP: + case AITT_TYPE_TCP_SECURE: subscribe_handle = SubscribeTCP(info, topic, cb, user_data, qos); break; - case AITT_TYPE_SECURE_TCP: - subscribe_handle = SubscribeSecureTCP(info, topic, cb, user_data, qos); - break; case AITT_TYPE_WEBRTC: subscribe_handle = SubscribeWebRtc(info, topic, cb, user_data, qos); break; @@ -272,14 +260,12 @@ void *AITT::Impl::Unsubscribe(AittSubscribeID subscribe_id) case AITT_TYPE_MQTT: user_data = mq->Unsubscribe(found_info->second); break; - case AITT_TYPE_TCP: { - user_data = transports[ModuleLoader::TYPE_TCP]->Unsubscribe(found_info->second); - break; - } - case AITT_TYPE_WEBRTC: { - user_data = transports[ModuleLoader::TYPE_WEBRTC]->Unsubscribe(found_info->second); + case AITT_TYPE_TCP: + case AITT_TYPE_TCP_SECURE: + case AITT_TYPE_WEBRTC: + user_data = modules.Get(found_info->first).Unsubscribe(found_info->second); break; - } + default: ERR("Unknown AittProtocol(%d)", found_info->first); break; @@ -403,37 +389,20 @@ void AITT::Impl::SendReply(MSG *msg, const void *data, const int datalen, bool e void *AITT::Impl::SubscribeTCP(SubscribeInfo *handle, const std::string &topic, const SubscribeCallback &cb, void *user_data, AittQoS qos) { - return transports[ModuleLoader::TYPE_TCP]->Subscribe( - topic, - [handle, cb](const std::string &topic, const void *data, const size_t datalen, - void *user_data, const std::string &correlation) -> void { - MSG msg; - msg.SetID(handle); - msg.SetTopic(topic); - msg.SetCorrelation(correlation); - msg.SetProtocols(AITT_TYPE_TCP); - - return cb(&msg, data, datalen, user_data); - }, - user_data, qos); -} - -void *AITT::Impl::SubscribeSecureTCP(SubscribeInfo *handle, const std::string &topic, - const SubscribeCallback &cb, void *user_data, AittQoS qos) -{ - return transports[ModuleLoader::TYPE_SECURE_TCP]->Subscribe( - topic, - [handle, cb](const std::string &topic, const void *data, const size_t datalen, - void *user_data, const std::string &correlation) -> void { - MSG msg; - msg.SetID(handle); - msg.SetTopic(topic); - msg.SetCorrelation(correlation); - msg.SetProtocols(AITT_TYPE_SECURE_TCP); - - return cb(&msg, data, datalen, user_data); - }, - user_data, qos); + return modules.Get(handle->first) + .Subscribe( + topic, + [handle, cb](const std::string &topic, const void *data, const size_t datalen, + void *user_data, const std::string &correlation) -> void { + MSG msg; + msg.SetID(handle); + msg.SetTopic(topic); + msg.SetCorrelation(correlation); + msg.SetProtocols(handle->first); + + return cb(&msg, data, datalen, user_data); + }, + user_data, qos); } void *AITT::Impl::SubscribeWebRtc(SubscribeInfo *handle, const std::string &topic, @@ -449,18 +418,19 @@ void *AITT::Impl::SubscribeWebRtc(SubscribeInfo *handle, const std::string &topi fbb.Finish(); auto buf = fbb.GetBuffer(); - return transports[ModuleLoader::TYPE_WEBRTC]->Subscribe( - topic, - [handle, cb](const std::string &topic, const void *data, const size_t datalen, - void *user_data, const std::string &correlation) -> void { - MSG msg; - msg.SetID(handle); - msg.SetTopic(topic); - msg.SetCorrelation(correlation); - msg.SetProtocols(AITT_TYPE_WEBRTC); - - return cb(&msg, data, datalen, user_data); - }, - buf.data(), buf.size(), user_data, qos); + return modules.Get(AITT_TYPE_WEBRTC) + .Subscribe( + topic, + [handle, cb](const std::string &topic, const void *data, const size_t datalen, + void *user_data, const std::string &correlation) -> void { + MSG msg; + msg.SetID(handle); + msg.SetTopic(topic); + msg.SetCorrelation(correlation); + msg.SetProtocols(AITT_TYPE_WEBRTC); + + return cb(&msg, data, datalen, user_data); + }, + buf.data(), buf.size(), user_data, qos); } } // namespace aitt diff --git a/src/AITTImpl.h b/src/AITTImpl.h index 6b9005e..fcd7f08 100644 --- a/src/AITTImpl.h +++ b/src/AITTImpl.h @@ -28,7 +28,7 @@ #include "AittDiscovery.h" #include "MQ.h" #include "MainLoopHandler.h" -#include "ModuleLoader.h" +#include "ModuleManager.h" namespace aitt { @@ -50,18 +50,15 @@ class AITT::Impl { void Publish(const std::string &topic, const void *data, const size_t datalen, AittProtocol protocols, AittQoS qos, bool retain); - int PublishWithReply(const std::string &topic, const void *data, const size_t datalen, AittProtocol protocol, AittQoS qos, bool retain, const AITT::SubscribeCallback &cb, void *cbdata, const std::string &correlation); - int PublishWithReplySync(const std::string &topic, const void *data, const size_t datalen, AittProtocol protocol, AittQoS qos, bool retain, const SubscribeCallback &cb, void *cbdata, const std::string &correlation, int timeout_ms); AittSubscribeID Subscribe(const std::string &topic, const AITT::SubscribeCallback &cb, void *cbdata, AittProtocol protocols, AittQoS qos); - void *Unsubscribe(AittSubscribeID handle); void SendReply(MSG *msg, const void *data, const int datalen, bool end); @@ -78,8 +75,7 @@ class AITT::Impl { MainLoopHandler::MainLoopData *loop_data); void *SubscribeTCP(SubscribeInfo *, const std::string &topic, const SubscribeCallback &cb, void *cbdata, AittQoS qos); - void *SubscribeSecureTCP(SubscribeInfo *handle, const std::string &topic, - const SubscribeCallback &cb, void *user_data, AittQoS qos); + void *SubscribeWebRtc(SubscribeInfo *, const std::string &topic, const SubscribeCallback &cb, void *cbdata, AittQoS qos); void HandleTimeout(int timeout_ms, unsigned int &timeout_id, aitt::MainLoopHandler &sync_loop, @@ -87,21 +83,21 @@ class AITT::Impl { void PublishWebRtc(const std::string &topic, const void *data, const size_t datalen, AittQoS qos, bool retain); void UnsubscribeAll(); + void ThreadMain(void); AITT &public_api; - std::string id_; - std::string mqtt_broker_ip_; - int mqtt_broker_port_; - std::unique_ptr mq; AittDiscovery discovery; - unsigned short reply_id; - std::vector module_handles; - std::unique_ptr transports[ModuleLoader::TYPE_TRANSPORT_MAX]; MainLoopHandler main_loop; - void ThreadMain(void); std::thread aittThread; + ModuleManager modules; + std::unique_ptr mq; std::vector subscribed_list; std::mutex subscribed_list_mutex_; + + std::string id_; + std::string mqtt_broker_ip_; + int mqtt_broker_port_; + unsigned short reply_id; }; } // namespace aitt diff --git a/src/ModuleManager.cc b/src/ModuleManager.cc new file mode 100644 index 0000000..82c21d4 --- /dev/null +++ b/src/ModuleManager.cc @@ -0,0 +1,145 @@ +/* + * Copyright (c) 2021-2022 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ModuleManager.h" + +#include + +#include "AittException.h" +#include "NullTransport.h" +#include "aitt_internal.h" + +namespace aitt { + +ModuleManager::ModuleManager(const std::string &my_ip, AittDiscovery &d) + : ip(my_ip), discovery(d), custom_mqtt_handle(nullptr, nullptr), null_transport(discovery, ip) +{ + for (int i = TYPE_TCP; i < TYPE_TRANSPORT_MAX; ++i) { + transport_handles.push_back(ModuleHandle(nullptr, nullptr)); + LoadTransport(static_cast(i)); + } +} + +AittTransport &ModuleManager::Get(AittProtocol protocol) +{ + TransportType type = Convert(protocol); + AittTransport *module = transports[type].get(); + if (nullptr == module) + module = &null_transport; + + return *module; +} + +ModuleManager::TransportType ModuleManager::Convert(AittProtocol type) +{ + switch (type) { + case AITT_TYPE_TCP: + return TYPE_TCP; + case AITT_TYPE_TCP_SECURE: + return TYPE_TCP_SECURE; + case AITT_TYPE_WEBRTC: + return TYPE_WEBRTC; + + case AITT_TYPE_MQTT: + default: + ERR("Unknown Transport Type(%d)", type); + throw AittException(AittException::NO_DATA_ERR); + } + return TYPE_TRANSPORT_MAX; +} + +std::string ModuleManager::GetTransportFileName(TransportType type) +{ + switch (type) { + case TYPE_TCP: + case TYPE_TCP_SECURE: + return "libaitt-transport-tcp.so"; + case TYPE_WEBRTC: + return "libaitt-transport-webrtc.so"; + default: + ERR("Unknown Type(%d)", type); + break; + } + + return std::string("Unknown"); +} + +ModuleManager::ModuleHandle ModuleManager::OpenModule(const char *file) +{ + ModuleHandle handle(dlopen(file, RTLD_LAZY | RTLD_LOCAL), [](const void *handle) -> void { + if (dlclose(const_cast(handle))) + ERR("dlclose: %s", dlerror()); + }); + if (handle == nullptr) + ERR("dlopen(%s): %s", file, dlerror()); + + return handle; +} + +ModuleManager::ModuleHandle ModuleManager::OpenTransport(TransportType type) +{ + if (TYPE_TCP_SECURE == type) + type = TYPE_TCP; + + std::string filename = GetTransportFileName(type); + ModuleHandle handle = OpenModule(filename.c_str()); + + return handle; +} + +void ModuleManager::LoadTransport(TransportType type) +{ + transport_handles[type] = OpenTransport(type); + if (transport_handles[type] == nullptr) { + ERR("OpenTransport(%d) Fail", type); + return; + } + + AittTransport::ModuleEntry get_instance_fn = reinterpret_cast( + dlsym(transport_handles[type].get(), AittTransport::MODULE_ENTRY_NAME)); + if (get_instance_fn == nullptr) { + ERR("dlsym: %s", dlerror()); + return; + } + + AittProtocol protocol = static_cast(0x1 << (type + 1)); + transports[type] = std::unique_ptr( + static_cast(get_instance_fn(protocol, discovery, ip.c_str()))); + if (transports[type] == nullptr) { + ERR("get_instance_fn(%d) Fail", protocol); + } +} + +std::unique_ptr ModuleManager::NewCustomMQ(const std::string &id, const AittOption &option) +{ + ModuleHandle handle = OpenModule("libaitt-st-broker.so"); + + MQ::ModuleEntry get_instance_fn = + reinterpret_cast(dlsym(handle.get(), MQ::MODULE_ENTRY_NAME)); + if (get_instance_fn == nullptr) { + ERR("dlsym: %s", dlerror()); + throw AittException(AittException::SYSTEM_ERR); + } + + std::unique_ptr instance(static_cast(get_instance_fn(id.c_str(), option))); + if (instance == nullptr) { + ERR("get_instance_fn(MQ) Fail"); + throw AittException(AittException::SYSTEM_ERR); + } + + return instance; +} + +} // namespace aitt diff --git a/src/ModuleManager.h b/src/ModuleManager.h new file mode 100644 index 0000000..69c43b7 --- /dev/null +++ b/src/ModuleManager.h @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2021-2022 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include + +#include "AittDiscovery.h" +#include "AittTransport.h" +#include "MQ.h" +#include "NullTransport.h" + +namespace aitt { + +class ModuleManager { + public: + explicit ModuleManager(const std::string &my_ip, AittDiscovery &d); + virtual ~ModuleManager() = default; + + AittTransport &Get(AittProtocol type); + std::unique_ptr NewCustomMQ(const std::string &id, const AittOption &option); + + private: + using ModuleHandle = std::unique_ptr; + + // It should be ("the number of shifts" - 1) of AittProtocol + enum TransportType { + TYPE_TCP, //(0x1 << 1) + TYPE_TCP_SECURE, //(0x1 << 2) + TYPE_WEBRTC, //(0x1 << 3) + TYPE_RTSP, + TYPE_TRANSPORT_MAX, + }; + + TransportType Convert(AittProtocol type); + std::string GetTransportFileName(TransportType type); + ModuleHandle OpenModule(const char *file); + ModuleHandle OpenTransport(TransportType type); + void LoadTransport(TransportType type); + + std::string ip; + AittDiscovery &discovery; + std::vector transport_handles; + std::unique_ptr transports[TYPE_TRANSPORT_MAX]; + ModuleHandle custom_mqtt_handle; + NullTransport null_transport; +}; + +} // namespace aitt diff --git a/common/NullTransport.cc b/src/NullTransport.cc similarity index 75% rename from common/NullTransport.cc rename to src/NullTransport.cc index 28a3b1f..e5b46db 100644 --- a/common/NullTransport.cc +++ b/src/NullTransport.cc @@ -17,23 +17,23 @@ #include "aitt_internal.h" -NullTransport::NullTransport(const std::string& ip, AittDiscovery& discovery) - : AittTransport(discovery) +NullTransport::NullTransport(AittDiscovery& discovery, const std::string& ip) + : AittTransport(AITT_TYPE_UNKNOWN, discovery) { } -void NullTransport::Publish( - const std::string& topic, const void* data, const size_t datalen, AittQoS qos, bool retain) +void NullTransport::Publish(const std::string& topic, const void* data, const size_t datalen, + const std::string& correlation, AittQoS qos, bool retain) { } void NullTransport::Publish(const std::string& topic, const void* data, const size_t datalen, - const std::string& correlation, AittQoS qos, bool retain) + AittQoS qos, bool retain) { } -void* NullTransport::Subscribe( - const std::string& topic, const SubscribeCallback& cb, void* cbdata, AittQoS qos) +void* NullTransport::Subscribe(const std::string& topic, const SubscribeCallback& cb, void* cbdata, + AittQoS qos) { return nullptr; } diff --git a/common/NullTransport.h b/src/NullTransport.h similarity index 95% rename from common/NullTransport.h rename to src/NullTransport.h index ab72c78..e283705 100644 --- a/common/NullTransport.h +++ b/src/NullTransport.h @@ -21,16 +21,16 @@ using AittDiscovery = aitt::AittDiscovery; class NullTransport : public AittTransport { public: - explicit NullTransport(const std::string &ip, AittDiscovery &discovery); + explicit NullTransport(AittDiscovery &discovery, const std::string &ip); virtual ~NullTransport(void) = default; void Publish(const std::string &topic, const void *data, const size_t datalen, - AittQoS qos = AITT_QOS_AT_MOST_ONCE, bool retain = false) override; - - void Publish(const std::string &topic, const void *data, const size_t datalen, const std::string &correlation, AittQoS qos = AITT_QOS_AT_MOST_ONCE, bool retain = false) override; + void Publish(const std::string &topic, const void *data, const size_t datalen, + AittQoS qos = AITT_QOS_AT_MOST_ONCE, bool retain = false) override; + void *Subscribe(const std::string &topic, const SubscribeCallback &cb, void *cbdata = nullptr, AittQoS qos = AITT_QOS_AT_MOST_ONCE) override; diff --git a/src/aitt_c.cc b/src/aitt_c.cc index 9e175f2..4ff6363 100644 --- a/src/aitt_c.cc +++ b/src/aitt_c.cc @@ -98,7 +98,7 @@ API aitt_option_h aitt_option_new() return handle; } -void aitt_option_destroy(aitt_option_h handle) +API void aitt_option_destroy(aitt_option_h handle) { if (handle == nullptr) { ERR("handle is NULL"); diff --git a/tests/AITT_TCP_test.cc b/tests/AITT_TCP_test.cc index d7a88d0..ed72da4 100644 --- a/tests/AITT_TCP_test.cc +++ b/tests/AITT_TCP_test.cc @@ -119,5 +119,5 @@ TEST_F(AITTTCPTest, TCP_Wildcards2_Anytime) TEST_F(AITTTCPTest, SECURE_TCP_Wildcards_Anytime) { - TCPWildcardsTopicTemplate(AITT_TYPE_SECURE_TCP); + TCPWildcardsTopicTemplate(AITT_TYPE_TCP_SECURE); } diff --git a/tests/AITT_test.cc b/tests/AITT_test.cc index 3918bc2..ec7b902 100644 --- a/tests/AITT_test.cc +++ b/tests/AITT_test.cc @@ -40,7 +40,10 @@ class AITTTest : public testing::Test, public AittTests { [](aitt::MSG *handle, const void *msg, const size_t szmsg, void *cbdata) -> void { AITTTest *test = static_cast(cbdata); test->ToggleReady(); - DBG("Subscribe invoked: %s %zu", static_cast(msg), szmsg); + if (msg) + DBG("Subscribe invoked: %s %zu", static_cast(msg), szmsg); + else + DBG("Subscribe invoked: zero size msg(%zu)", szmsg); }, static_cast(this), protocol); @@ -64,7 +67,7 @@ class AITTTest : public testing::Test, public AittTests { void PublishDisconnectTemplate(AittProtocol protocol) { const char character_set[] = - "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890"; + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890"; std::mt19937 random_gen{std::random_device{}()}; std::uniform_int_distribution gen(0, 61); @@ -81,7 +84,7 @@ class AITTTest : public testing::Test, public AittTests { int cnt = 0; aitt.Subscribe( - STRESS_TEST_TOPIC, + TEST_STRESS_TOPIC, [&](aitt::MSG *handle, const void *msg, const size_t szmsg, void *cbdata) -> void { AITTTest *test = static_cast(cbdata); @@ -90,9 +93,10 @@ class AITTTest : public testing::Test, public AittTests { FAIL() << "Unexpected value" << cnt; } - DBG("A subscription message is arrived. cnt = %d", cnt); - const char *receivedMsg = static_cast(msg); - ASSERT_TRUE(!strcmp(receivedMsg, dump_msg)); + if (msg) { + ASSERT_TRUE(!strncmp(static_cast(msg), dump_msg, + sizeof(dump_msg))); + } if (cnt == 10) test->ToggleReady(); @@ -110,14 +114,14 @@ class AITTTest : public testing::Test, public AittTests { for (int i = 0; i < 10; i++) { INFO("size = %zu", sizeof(dump_msg)); - aitt1.Publish(STRESS_TEST_TOPIC, dump_msg, sizeof(dump_msg), protocol, - AITT_QOS_AT_MOST_ONCE, true); + aitt1.Publish(TEST_STRESS_TOPIC, dump_msg, sizeof(dump_msg), protocol, + AITT_QOS_AT_MOST_ONCE); } g_timeout_add(10, AittTests::ReadyCheck, static_cast(this)); IterateEventLoop(); } - DBG("Client aitt1 is finished."); + DBG("Client aitt1 is finished"); // Here, an unexpected callback(szmsg = 0) is received // when the publisher is disconnected. @@ -125,8 +129,8 @@ class AITTTest : public testing::Test, public AittTests { ASSERT_TRUE(ready); ready = false; - aitt_retry.Publish(STRESS_TEST_TOPIC, dump_msg, sizeof(dump_msg), protocol, - AITT_QOS_AT_MOST_ONCE, true); + aitt_retry.Publish(TEST_STRESS_TOPIC, dump_msg, sizeof(dump_msg), protocol, + AITT_QOS_AT_MOST_ONCE); g_timeout_add(10, AittTests::ReadyCheck, static_cast(this)); @@ -134,7 +138,7 @@ class AITTTest : public testing::Test, public AittTests { ASSERT_TRUE(ready); - aitt_retry.Publish(STRESS_TEST_TOPIC, nullptr, 0, protocol, AITT_QOS_AT_LEAST_ONCE); + aitt_retry.Publish(TEST_STRESS_TOPIC, nullptr, 0, protocol, AITT_QOS_AT_LEAST_ONCE); // Check auto release of aitt. There should be no segmentation faults. } catch (std::exception &e) { FAIL() << "Unexpected exception: " << e.what(); @@ -202,10 +206,8 @@ class AITTTest : public testing::Test, public AittTests { void *cbdata) -> void { AITTTest *test = static_cast(cbdata); ++cnt; - if (cnt == 1) { - ASSERT_TRUE(msg == nullptr); + if (cnt == 1) test->ToggleReady(); - } DBG("Subscribe callback called: %d", cnt); }, static_cast(this), protocol); @@ -364,7 +366,7 @@ TEST_F(AITTTest, Positive_Publish_SECURE_TCP_Anytime) try { AITT aitt(clientId, LOCAL_IP, AittOption(true, false)); aitt.Connect(); - aitt.Publish(testTopic, TEST_MSG, sizeof(TEST_MSG), AITT_TYPE_SECURE_TCP); + aitt.Publish(testTopic, TEST_MSG, sizeof(TEST_MSG), AITT_TYPE_TCP_SECURE); } catch (std::exception &e) { FAIL() << "Unexpected exception: " << e.what(); } @@ -438,7 +440,7 @@ TEST_F(AITTTest, Positive_Unsubscribe_SECURE_TCP_Anytime) subscribeHandle = aitt.Subscribe( testTopic, [](aitt::MSG *handle, const void *msg, const size_t szmsg, void *cbdata) -> void {}, - nullptr, AITT_TYPE_SECURE_TCP); + nullptr, AITT_TYPE_TCP_SECURE); DBG("Subscribe handle: %p", reinterpret_cast(subscribeHandle)); aitt.Unsubscribe(subscribeHandle); } catch (std::exception &e) { @@ -547,7 +549,7 @@ TEST_F(AITTTest, Positve_PublishSubscribe_TCP_Anytime) TEST_F(AITTTest, Positve_PublishSubscribe_SECURE_TCP_Anytime) { - PubsubTemplate(TEST_MSG, AITT_TYPE_SECURE_TCP); + PubsubTemplate(TEST_MSG, AITT_TYPE_TCP_SECURE); } TEST_F(AITTTest, Positve_Publish_0_TCP_Anytime) @@ -557,7 +559,7 @@ TEST_F(AITTTest, Positve_Publish_0_TCP_Anytime) TEST_F(AITTTest, Positve_Publish_0_SECURE_TCP_Anytime) { - PubsubTemplate("", AITT_TYPE_SECURE_TCP); + PubsubTemplate("", AITT_TYPE_TCP_SECURE); } TEST_F(AITTTest, Positve_PublishSubscribe_Multiple_Protocols_Anytime) @@ -609,27 +611,27 @@ TEST_F(AITTTest, Positve_PublishSubscribe_TCP_twice_Anytime) TEST_F(AITTTest, Positve_PublishSubscribe_SECURE_TCP_twice_Anytime) { - PublishSubscribeTCPTwiceTemplate(AITT_TYPE_SECURE_TCP); + PublishSubscribeTCPTwiceTemplate(AITT_TYPE_TCP_SECURE); } -TEST_F(AITTTest, Positive_Subscribe_Retained_Anytime_TCP) +TEST_F(AITTTest, Positive_Subscribe_Retained_TCP_Anytime) { SubscribeRetainedTCPTemplate(AITT_TYPE_TCP); } -TEST_F(AITTTest, Positive_Subscribe_Retained_Anytime_SECURE_TCP) +TEST_F(AITTTest, Positive_Subscribe_Retained_SECURE_TCP_Anytime) { - SubscribeRetainedTCPTemplate(AITT_TYPE_SECURE_TCP); + SubscribeRetainedTCPTemplate(AITT_TYPE_TCP_SECURE); } -TEST_F(AITTTest, TCP_Publish_Disconnect_Anytime_TCP) +TEST_F(AITTTest, TCP_Publish_Disconnect_TCP_Anytime) { PublishDisconnectTemplate(AITT_TYPE_TCP); } -TEST_F(AITTTest, TCP_Publish_Disconnect_Anytime_SECURE_TCP) +TEST_F(AITTTest, TCP_Publish_Disconnect_SECURE_TCP_Anytime) { - PublishDisconnectTemplate(AITT_TYPE_SECURE_TCP); + PublishDisconnectTemplate(AITT_TYPE_TCP_SECURE); } TEST_F(AITTTest, WillSet_N_Anytime) diff --git a/tests/AittTests.h b/tests/AittTests.h index 4e7b349..10b4cf9 100644 --- a/tests/AittTests.h +++ b/tests/AittTests.h @@ -24,7 +24,7 @@ #define LOCAL_IP "127.0.0.1" #define TEST_C_TOPIC "test/topic_c" #define TEST_C_MSG "test123456789" -#define STRESS_TEST_TOPIC "test/stress1" +#define TEST_STRESS_TOPIC "test/stress1" #define TEST_MSG "This is aitt test message" #define TEST_MSG2 "This message is going to be delivered through a specified AittProtocol" diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index e81fb68..43f4d40 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -46,7 +46,7 @@ ADD_TEST( ) ########################################################################### -ADD_EXECUTABLE(${AITT_UT}_module ModuleLoader_test.cc) +ADD_EXECUTABLE(${AITT_UT}_module ModuleLoader_test.cc $) TARGET_LINK_LIBRARIES(${AITT_UT}_module ${UT_NEEDS_LIBRARIES} ${AITT_NEEDS_LIBRARIES} ${CMAKE_DL_LIBS} ${AITT_COMMON}) TARGET_INCLUDE_DIRECTORIES(${AITT_UT}_module PRIVATE ../src) diff --git a/tests/ModuleLoader_test.cc b/tests/ModuleLoader_test.cc index 196b574..702b21c 100644 --- a/tests/ModuleLoader_test.cc +++ b/tests/ModuleLoader_test.cc @@ -13,65 +13,58 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "ModuleLoader.h" - #include #include #include "AittTests.h" #include "AittTransport.h" +#include "ModuleManager.h" #include "aitt_internal.h" -using ModuleLoader = aitt::ModuleLoader; +using ModuleManager = aitt::ModuleManager; class ModuleLoaderTest : public testing::Test { public: - ModuleLoaderTest(void) : discovery("test", AittOption(false, false)) {} + ModuleLoaderTest(void) : discovery("test_id"), modules(LOCAL_IP, discovery) {} protected: void SetUp() override {} void TearDown() override {} aitt::AittDiscovery discovery; - aitt::ModuleLoader loader; + aitt::ModuleManager modules; }; -TEST_F(ModuleLoaderTest, LoadTransport_P_Anytime) +TEST_F(ModuleLoaderTest, Get_P_Anytime) { - ModuleLoader::ModuleHandle handle = loader.OpenModule(ModuleLoader::TYPE_TCP); - ASSERT_NE(handle, nullptr); - - std::shared_ptr module = loader.LoadTransport( - handle.get(), loader.GetProtocol(ModuleLoader::TYPE_TCP), LOCAL_IP, discovery); - ASSERT_NE(module, nullptr); + aitt::AittTransport &tcp = modules.Get(AITT_TYPE_TCP); + EXPECT_TRUE(tcp.GetProtocol() == AITT_TYPE_TCP); + aitt::AittTransport &tcp_secure = modules.Get(AITT_TYPE_TCP_SECURE); + EXPECT_TRUE(tcp_secure.GetProtocol() == AITT_TYPE_TCP_SECURE); } - -TEST_F(ModuleLoaderTest, LoadTransport_N_Anytime) +TEST_F(ModuleLoaderTest, Get_N_Anytime) { - ModuleLoader::ModuleHandle handle = loader.OpenModule(ModuleLoader::TYPE_TRANSPORT_MAX); - ASSERT_EQ(handle.get(), nullptr); - - auto module = loader.LoadTransport( - handle.get(), loader.GetProtocol(ModuleLoader::TYPE_TRANSPORT_MAX), LOCAL_IP, discovery); - ASSERT_NE(module, nullptr); + EXPECT_THROW( + { + aitt::AittTransport &module = modules.Get(AITT_TYPE_MQTT); + FAIL() << "Should not be called" << module.GetProtocol(); + }, + aitt::AittException); } -TEST_F(ModuleLoaderTest, LoadMqttClient_P_Anytime) +TEST_F(ModuleLoaderTest, NewCustomMQ_P) { - ModuleLoader::ModuleHandle handle = loader.OpenModule(ModuleLoader::TYPE_CUSTOM_MQTT); - if (handle) { - EXPECT_NO_THROW({ - auto module = loader.LoadMqttClient(handle.get(), "test", AittOption(false, true)); - ASSERT_NE(module, nullptr); - }); - } + EXPECT_NO_THROW({ + std::unique_ptr mq = modules.NewCustomMQ("test", AittOption(false, true)); + mq->SetConnectionCallback([](int status) {}); + }); } -TEST_F(ModuleLoaderTest, LoadMqttClient_N_Anytime) +TEST_F(ModuleLoaderTest, NewCustomMQ_N_Anytime) { EXPECT_THROW( { - loader.LoadMqttClient(nullptr, "test", AittOption(false, true)); + modules.NewCustomMQ("test", AittOption(false, false)); FAIL() << "Should not be called"; }, aitt::AittException); diff --git a/tests/RequestResponse_test.cc b/tests/RequestResponse_test.cc index fc26049..21bcd0f 100644 --- a/tests/RequestResponse_test.cc +++ b/tests/RequestResponse_test.cc @@ -255,7 +255,7 @@ TEST_F(AITTRRTest, RequestResponse_asymmetry_Anytime) EXPECT_TRUE(sub_ok); EXPECT_TRUE(reply_ok); - } catch (std::exception &e) { + } catch (aitt::AittException &e) { FAIL() << e.what(); } }