From 9270226d08ebe4d256f4f93977d67227a83baae2 Mon Sep 17 00:00:00 2001 From: Kyungwook Tak Date: Tue, 5 Apr 2016 17:51:27 +0900 Subject: [PATCH] Async client stub initial commit scan dir / scan dirs needs db integration in server, credential check of client and filesystem related modules. For now it requests file lists in target directory based on history and credentials (client-removable, server-readable) Change-Id: Ia487c916f31e70cc54b1b52b72dd98c096264dd0 Signed-off-by: Kyungwook Tak --- src/CMakeLists.txt | 3 + src/framework/CMakeLists.txt | 3 +- src/framework/client/async-logic.cpp | 182 ++++++++++++++++++++ src/framework/client/async-logic.h | 60 +++++++ src/framework/client/callback.h | 68 ++++++++ src/framework/client/content-screening.cpp | 188 +++++++++++++++------ src/framework/client/handle-ext.cpp | 140 +++++++++++++++ src/framework/client/handle-ext.h | 79 +++++++++ .../{common/cs-types.cpp => client/handle.cpp} | 69 +------- .../{common/wp-types.h => client/handle.h} | 41 ++--- src/framework/client/utils.h | 8 - src/framework/client/web-protection.cpp | 19 ++- src/framework/common/command-id.h | 8 +- src/framework/common/serialization.h | 33 ++++ src/framework/common/{wp-types.cpp => types.cpp} | 53 ++++-- src/framework/common/{cs-types.h => types.h} | 42 ++--- src/framework/service/logic.cpp | 42 ++++- src/framework/service/logic.h | 9 +- test/test-api-content-screening.cpp | 154 +++++++++++++++-- 19 files changed, 986 insertions(+), 215 deletions(-) create mode 100644 src/framework/client/async-logic.cpp create mode 100644 src/framework/client/async-logic.h create mode 100644 src/framework/client/callback.h create mode 100644 src/framework/client/handle-ext.cpp create mode 100644 src/framework/client/handle-ext.h rename src/framework/{common/cs-types.cpp => client/handle.cpp} (53%) rename src/framework/{common/wp-types.h => client/handle.h} (61%) rename src/framework/common/{wp-types.cpp => types.cpp} (56%) rename src/framework/common/{cs-types.h => types.h} (65%) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 256d27c..9ccf324 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -71,9 +71,12 @@ PKG_CHECK_MODULES(${TARGET_CSR_CLIENT}_DEP ) SET(${TARGET_CSR_CLIENT}_SRCS + framework/client/async-logic.cpp framework/client/content-screening.cpp framework/client/engine-manager.cpp framework/client/error.cpp + framework/client/handle.cpp + framework/client/handle-ext.cpp framework/client/utils.cpp framework/client/web-protection.cpp ) diff --git a/src/framework/CMakeLists.txt b/src/framework/CMakeLists.txt index 6b1dddf..3c9662d 100644 --- a/src/framework/CMakeLists.txt +++ b/src/framework/CMakeLists.txt @@ -28,12 +28,11 @@ SET(${TARGET_CSR_COMMON}_SRCS common/audit/logger.cpp common/binary-queue.cpp common/connection.cpp - common/cs-types.cpp + common/types.cpp common/dispatcher.cpp common/mainloop.cpp common/service.cpp common/socket.cpp - common/wp-types.cpp ) INCLUDE_DIRECTORIES( diff --git a/src/framework/client/async-logic.cpp b/src/framework/client/async-logic.cpp new file mode 100644 index 0000000..3002106 --- /dev/null +++ b/src/framework/client/async-logic.cpp @@ -0,0 +1,182 @@ +/* + * Copyright (c) 2016 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License + */ +/* + * @file async-logic.cpp + * @author Kyungwook Tak (k.tak@samsung.com) + * @version 1.0 + * @brief + */ +#include "client/async-logic.h" + +#include + +#include "common/audit/logger.h" + +namespace Csr { +namespace Client { + +AsyncLogic::AsyncLogic(Context &context, const Callback &cb, void *userdata, + const std::function &isStopped) : + m_origCtx(context), + m_ctx(context), + m_cb(cb), + m_userdata(userdata), + m_isStopped(isStopped), + m_dispatcher(new Dispatcher("/tmp/." SERVICE_NAME ".socket")) +{ +} + +AsyncLogic::~AsyncLogic() +{ + DEBUG("AsyncLogic dtor. Results num in " + "mother ctx[" << m_origCtx.size() << "] " + "and here[" << m_ctx.size() << "]"); + + for (auto &resultPtr : m_ctx.m_results) + m_origCtx.add(std::move(resultPtr)); + + DEBUG("Integrated mother ctx results num: " << m_origCtx.size()); +} + +std::pair AsyncLogic::scanDirs(const std::shared_ptr + &dirs) +{ + // TODO: canonicalize dirs. (e.g. Can omit subdirectory it there is + // parent directory in set) + std::pair t(Callback::Id::OnCompleted, [this] { + if (m_cb.onCompleted) + m_cb.onCompleted(this->m_userdata); + }); + + for (const auto &dir : *dirs) { + t = scanDir(dir); + + if (t.first != Callback::Id::OnCompleted) + return t; + } + + return t; +} + +std::pair AsyncLogic::scanDir(const std::string &dir) +{ + // For in case of there's already detected malware for dir + auto retResults = + m_dispatcher->methodCall>>( + CommandId::DIR_GET_RESULTS, m_ctx, dir); + + if (retResults.first != CSR_ERROR_NONE) { + ERROR("[Error] ret: " << retResults.first); + + for (auto r : retResults.second) + delete r; + + auto ec = retResults.first; + return std::make_pair(Callback::Id::OnError, [this, ec] { + if (this->m_cb.onError) + this->m_cb.onError(this->m_userdata, ec); + }); + } + + // Register already detected malwares to context to be freed with context. + for (auto r : retResults.second) { + add(r); + + if (m_cb.onDetected) + m_cb.onDetected(m_userdata, reinterpret_cast(r)); + } + + // Already scanned files are excluded according to history + auto retFiles = m_dispatcher->methodCall>( + CommandId::DIR_GET_FILES, m_ctx, dir); + + if (retFiles.first != CSR_ERROR_NONE) { + ERROR("[Error] ret: " << retFiles.first); + delete retFiles.second; + auto ec = retFiles.first; + return std::make_pair(Callback::Id::OnError, [this, ec] { + if (this->m_cb.onError) + this->m_cb.onError(this->m_userdata, ec); + }); + } + + // Let's start scan files! + std::shared_ptr strSetPtr(retFiles.second); + auto task = scanFiles(strSetPtr); + // TODO: register results(in outs) to db and update dir scanning history... + return task; +} + +std::pair AsyncLogic::scanFiles(const + std::shared_ptr &fileSet) +{ + for (const auto &file : *fileSet) { + if (m_isStopped()) { + INFO("async operation cancelled!"); + return std::make_pair(Callback::Id::OnCancelled, [this] { + if (this->m_cb.onCancelled) + this->m_cb.onCancelled(this->m_userdata); + }); + } + + auto ret = m_dispatcher->methodCall>( + CommandId::SCAN_FILE, m_ctx, file); + + if (ret.first != CSR_ERROR_NONE) { + ERROR("[Error] ret: " << ret.first); + delete ret.second; + auto ec = ret.first; + return std::make_pair(Callback::Id::OnError, [this, ec] { + if (this->m_cb.onError) + this->m_cb.onError(this->m_userdata, ec); + + return; + }); + } + + if (!ret.second->hasValue()) { + DEBUG("[Scanned] file[" << file << "]"); + delete ret.second; + + if (m_cb.onScanned) + m_cb.onScanned(m_userdata, file.c_str()); + + continue; + } + + // malware detected! + INFO("[Detected] file[" << file << "]"); + add(ret.second); + + if (m_cb.onDetected) + m_cb.onDetected(m_userdata, reinterpret_cast(ret.second)); + } + + return std::make_pair(Callback::Id::OnCompleted, [this] { + DEBUG("[Completed]"); + + if (this->m_cb.onCompleted) + this->m_cb.onCompleted(this->m_userdata); + }); +} + +void AsyncLogic::add(Result *r) +{ + m_ctx.add(r); +} + +} +} diff --git a/src/framework/client/async-logic.h b/src/framework/client/async-logic.h new file mode 100644 index 0000000..714c4f5 --- /dev/null +++ b/src/framework/client/async-logic.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2016 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License + */ +/* + * @file async-logic.h + * @author Kyungwook Tak (k.tak@samsung.com) + * @version 1.0 + * @brief + */ +#pragma once + +#include +#include + +#include "common/types.h" +#include "common/dispatcher.h" +#include "client/callback.h" + +namespace Csr { +namespace Client { + +class AsyncLogic { +public: + AsyncLogic(Context &context, const Callback &cb, void *userdata, + const std::function &isStopped); + virtual ~AsyncLogic(); + + std::pair scanFiles(const std::shared_ptr &files); + std::pair scanDir(const std::string &dir); + std::pair scanDirs(const std::shared_ptr &dirs); + + void stop(void); + +private: + void add(Result *); + + Context &m_origCtx; // for registering results for auto-release + Context m_ctx; // TODO: append it to handle context when destroyed + Callback m_cb; + void *m_userdata; + std::function m_isStopped; + + std::unique_ptr m_dispatcher; + +}; + +} +} diff --git a/src/framework/client/callback.h b/src/framework/client/callback.h new file mode 100644 index 0000000..c695c8d --- /dev/null +++ b/src/framework/client/callback.h @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2016 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License + */ +/* + * @file callback.h + * @author Kyungwook Tak (k.tak@samsung.com) + * @version 1.0 + * @brief CSR callback container for async operations + */ +#pragma once + +#include +#include + +#include "csr/content-screening-types.h" + +namespace Csr { + +struct Callback { + enum class Id : int { + OnCompleted, + OnCancelled, + OnError, + OnScanned, + OnDetected + }; + + Callback() {} + + Callback(const Callback &other) : + onScanned(other.onScanned), + onDetected(other.onDetected), + onCompleted(other.onCompleted), + onCancelled(other.onCancelled), + onError(other.onError) + { + } + + Callback &operator=(const Callback &other) + { + onScanned = other.onScanned; + onDetected = other.onDetected; + onCompleted = other.onCompleted; + onCancelled = other.onCancelled, + onError = other.onError; + return *this; + } + + std::function onScanned; + std::function onDetected; + std::function onCompleted; + std::function onCancelled; + std::function onError; +}; + +} diff --git a/src/framework/client/content-screening.cpp b/src/framework/client/content-screening.cpp index 10a3a22..802c82b 100644 --- a/src/framework/client/content-screening.cpp +++ b/src/framework/client/content-screening.cpp @@ -24,7 +24,8 @@ #include #include "client/utils.h" -#include "common/cs-types.h" +#include "client/handle-ext.h" +#include "client/async-logic.h" #include "common/command-id.h" #include "common/audit/logger.h" @@ -38,7 +39,7 @@ int csr_cs_context_create(csr_cs_context_h* phandle) if (phandle == nullptr) return CSR_ERROR_INVALID_PARAMETER; - *phandle = reinterpret_cast(new Cs::Context()); + *phandle = reinterpret_cast(new Client::HandleExt()); return CSR_ERROR_NONE; @@ -53,7 +54,7 @@ int csr_cs_context_destroy(csr_cs_context_h handle) if (handle == nullptr) return CSR_ERROR_INVALID_PARAMETER; - delete reinterpret_cast(handle); + delete reinterpret_cast(handle); return CSR_ERROR_NONE; @@ -120,19 +121,19 @@ int csr_cs_scan_file(csr_cs_context_h handle, const char *file_path, csr_cs_dete || file_path == nullptr || file_path[0] == '\0') return CSR_ERROR_INVALID_PARAMETER; - auto context = reinterpret_cast(handle); - auto ret = context->dispatch>( + auto hExt = reinterpret_cast(handle); + auto ret = hExt->dispatch>( CommandId::SCAN_FILE, - context, - Client::toStlString(file_path)); + hExt->getContext(), + std::string(file_path)); - if (ret.first != CSR_ERROR_NONE || ret.second == nullptr) { + if (ret.first != CSR_ERROR_NONE) { ERROR("Error! ret: " << ret.first); return ret.first; } + hExt->add(ret.second); *pdetected = reinterpret_cast(ret.second); - context->addResult(ret.second); return CSR_ERROR_NONE; @@ -140,97 +141,190 @@ int csr_cs_scan_file(csr_cs_context_h handle, const char *file_path, csr_cs_dete } API -int csr_cs_set_callback_on_detected(csr_cs_context_h handle, csr_cs_on_detected_cb callback) +int csr_cs_set_callback_on_file_scanned(csr_cs_context_h handle, csr_cs_on_file_scanned_cb callback) { - (void) handle; - (void) callback; + EXCEPTION_SAFE_START + + auto hExt = reinterpret_cast(handle); + + if (hExt == nullptr) + return CSR_ERROR_INVALID_PARAMETER; + + hExt->m_cb.onScanned = callback; - DEBUG("start!"); return CSR_ERROR_NONE; + + EXCEPTION_SAFE_END } API -int csr_cs_set_callback_on_completed(csr_cs_context_h handle, csr_cs_on_completed_cb callback) +int csr_cs_set_callback_on_detected(csr_cs_context_h handle, csr_cs_on_detected_cb callback) { - (void) handle; - (void) callback; + EXCEPTION_SAFE_START + + auto hExt = reinterpret_cast(handle); + + if (hExt == nullptr) + return CSR_ERROR_INVALID_PARAMETER; + + hExt->m_cb.onDetected = callback; - DEBUG("start!"); return CSR_ERROR_NONE; + + EXCEPTION_SAFE_END } API -int csr_cs_set_callback_on_cancelled(csr_cs_context_h handle, csr_cs_on_cancelled_cb callback) +int csr_cs_set_callback_on_completed(csr_cs_context_h handle, csr_cs_on_completed_cb callback) { - (void) handle; - (void) callback; + EXCEPTION_SAFE_START + + auto hExt = reinterpret_cast(handle); + + if (hExt == nullptr) + return CSR_ERROR_INVALID_PARAMETER; + + hExt->m_cb.onCompleted = callback; - DEBUG("start!"); return CSR_ERROR_NONE; + + EXCEPTION_SAFE_END } API -int csr_cs_set_callback_on_error(csr_cs_context_h handle, csr_cs_on_error_cb callback) +int csr_cs_set_callback_on_cancelled(csr_cs_context_h handle, csr_cs_on_cancelled_cb callback) { - (void) handle; - (void) callback; + EXCEPTION_SAFE_START + + auto hExt = reinterpret_cast(handle); + + if (hExt == nullptr) + return CSR_ERROR_INVALID_PARAMETER; + + hExt->m_cb.onCancelled = callback; - DEBUG("start!"); return CSR_ERROR_NONE; + + EXCEPTION_SAFE_END } API -int csr_cs_set_callback_on_file_scanned(csr_cs_context_h handle, csr_cs_on_file_scanned_cb callback) +int csr_cs_set_callback_on_error(csr_cs_context_h handle, csr_cs_on_error_cb callback) { - (void) handle; - (void) callback; + EXCEPTION_SAFE_START + + auto hExt = reinterpret_cast(handle); + + if (hExt == nullptr) + return CSR_ERROR_INVALID_PARAMETER; + + hExt->m_cb.onError = callback; - DEBUG("start!"); return CSR_ERROR_NONE; + + EXCEPTION_SAFE_END } API -int csr_cs_scan_files_async(csr_cs_context_h handle, const char **file_paths, unsigned int count, void *user_data) +int csr_cs_scan_files_async(csr_cs_context_h handle, const char **file_paths, unsigned int count, void *user_data) { - (void) handle; - (void) file_paths; - (void) count; - (void) user_data; + EXCEPTION_SAFE_START + + if (handle == nullptr || file_paths == nullptr || count == 0) + return CSR_ERROR_INVALID_PARAMETER; + + auto hExt = reinterpret_cast(handle); + + auto fileSet(std::make_shared()); + for (unsigned int i = 0; i < count; i++) { + if (file_paths[i] == nullptr) + return CSR_ERROR_INVALID_PARAMETER; + + fileSet->emplace(file_paths[i]); + } + + hExt->dispatchAsync([hExt, user_data, fileSet] { + Client::AsyncLogic l(hExt->getContext(), hExt->m_cb, user_data, + [&hExt] { return hExt->isStopped(); }); + + l.scanFiles(fileSet).second(); + }); - DEBUG("start!"); return CSR_ERROR_NONE; + + EXCEPTION_SAFE_END } API int csr_cs_scan_dir_async(csr_cs_context_h handle, const char *dir_path, void *user_data) { - (void) handle; - (void) dir_path; - (void) user_data; + EXCEPTION_SAFE_START + + if (handle == nullptr || dir_path == nullptr || dir_path[0] == '\0') + return CSR_ERROR_INVALID_PARAMETER; + + auto hExt = reinterpret_cast(handle); + + hExt->dispatchAsync([hExt, user_data, dir_path] { + Client::AsyncLogic l(hExt->getContext(), hExt->m_cb, user_data, + [&hExt] { return hExt->isStopped(); }); + + l.scanDir(dir_path).second(); + }); - DEBUG("start!"); return CSR_ERROR_NONE; + + EXCEPTION_SAFE_END } API -int csr_cs_scan_dirs_async(csr_cs_context_h handle, const char **file_paths, unsigned int count, void *user_data) +int csr_cs_scan_dirs_async(csr_cs_context_h handle, const char **dir_paths, unsigned int count, void *user_data) { - (void) handle; - (void) file_paths; - (void) count; - (void) user_data; + EXCEPTION_SAFE_START + + if (handle == nullptr || dir_paths == nullptr || count == 0) + return CSR_ERROR_INVALID_PARAMETER; + + auto hExt = reinterpret_cast(handle); + + auto dirSet(std::make_shared()); + for (unsigned int i = 0; i < count; i++) { + if (dir_paths[i] == nullptr) + return CSR_ERROR_INVALID_PARAMETER; + + dirSet->emplace(dir_paths[i]); + } + + hExt->dispatchAsync([hExt, user_data, dirSet] { + Client::AsyncLogic l(hExt->getContext(), hExt->m_cb, user_data, + [&hExt] { return hExt->isStopped(); }); + + l.scanDirs(dirSet).second(); + }); - DEBUG("start!"); return CSR_ERROR_NONE; + + EXCEPTION_SAFE_END } API int csr_cs_scan_cancel(csr_cs_context_h handle) { - (void) handle; + EXCEPTION_SAFE_START + + if (handle == nullptr) + return CSR_ERROR_INVALID_PARAMETER; + + auto hExt = reinterpret_cast(handle); + + if (hExt->isStopped()) + return CSR_ERROR_NONE; + + hExt->stop(); - DEBUG("start!"); return CSR_ERROR_NONE; + + EXCEPTION_SAFE_END } API diff --git a/src/framework/client/handle-ext.cpp b/src/framework/client/handle-ext.cpp new file mode 100644 index 0000000..c32d173 --- /dev/null +++ b/src/framework/client/handle-ext.cpp @@ -0,0 +1,140 @@ +/* + * Copyright (c) 2016 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License + */ +/* + * @file handle-ext.cpp + * @author Kyungwook Tak (k.tak@samsung.com) + * @version 1.0 + * @brief handle with async request extension + */ +#include "client/handle-ext.h" + +#include + +#include "client/utils.h" +#include "common/dispatcher.h" +#include "common/audit/logger.h" + +namespace Csr { +namespace Client { + +HandleExt::HandleExt() : m_stop(false) +{ +} + +HandleExt::~HandleExt() +{ + DEBUG("Destroying extended handle... join all workers..."); + eraseJoinableIf(); +} + +void HandleExt::stop() +{ + DEBUG("Stop & join all workers..."); + m_stop = true; + eraseJoinableIf(); +} + +bool HandleExt::isStopped() const +{ + return m_stop.load(); +} + +void HandleExt::eraseJoinableIf(std::function pred) +{ + std::unique_lock l(m_mutex); + DEBUG("clean joinable workers! current worker map size: " << + m_workerMap.size()); + auto it = m_workerMap.begin(); + + while (it != m_workerMap.end()) { + DEBUG("Worker map traversing to erase! current iter tid: " << it->first); + + if (!it->second.t.joinable()) + throw std::logic_error(FORMAT("All workers should be joinable " + "but it isn't. tid: " << it->first)); + + if (!pred(*it)) { + ++it; + continue; + } + + DEBUG("Joining worker! tid:" << it->first); + l.unlock(); + it->second.t.join(); // release lock for worker who calls done() + l.lock(); + DEBUG("Joined worker! tid:" << it->first); + it = m_workerMap.erase(it); + } +} + +void HandleExt::done() +{ + std::lock_guard l(m_mutex); + auto it = m_workerMap.find(std::this_thread::get_id()); + + if (it == m_workerMap.end()) + throw std::logic_error(FORMAT("worker done but it's not registered in map. " + "tid: " << std::this_thread::get_id())); + + it->second.isDone = true; +} + +void HandleExt::dispatchAsync(const Task &f) +{ + eraseJoinableIf([](const WorkerMapPair & pair) { + return pair.second.isDone.load(); + }); + // TODO: how to handle exceptions in workers + std::thread t([this, f] { + DEBUG("client async thread dispatched! tid: " << std::this_thread::get_id()); + + f(); + done(); + + DEBUG("client async thread done! tid: " << std::this_thread::get_id()); + }); + { + std::lock_guard l(m_mutex); + m_workerMap.emplace(t.get_id(), std::move(t)); + } +} + +HandleExt::Worker::Worker() : isDone(false) +{ + DEBUG("Worker default constructor called"); +} + +HandleExt::Worker::Worker(std::thread &&_t) : + isDone(false), + t(std::forward(_t)) +{ +} + +HandleExt::Worker::Worker(HandleExt::Worker &&other) : + isDone(other.isDone.load()), + t(std::move(other.t)) +{ +} + +HandleExt::Worker &HandleExt::Worker::operator=(HandleExt::Worker &&other) +{ + isDone = other.isDone.load(); + t = std::move(other.t); + return *this; +} + +} // namespace Client +} // namespace Csr diff --git a/src/framework/client/handle-ext.h b/src/framework/client/handle-ext.h new file mode 100644 index 0000000..fc154c6 --- /dev/null +++ b/src/framework/client/handle-ext.h @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2016 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License + */ +/* + * @file handle-ext.h + * @author Kyungwook Tak (k.tak@samsung.com) + * @version 1.0 + * @brief handle with async request extension + */ +#pragma once + +#include "client/handle.h" +#include "client/callback.h" +#include "common/types.h" + +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace Csr { +namespace Client { + +class HandleExt : public Handle { +public: + HandleExt(); + virtual ~HandleExt(); + + void dispatchAsync(const Task &task); + void stop(void); + bool isStopped(void) const; + + Callback m_cb; // TODO: to refine.. + +private: + struct Worker { + std::atomic isDone; + std::thread t; + + Worker(); + Worker(const std::thread &_t) = delete; // to prevent thread instance copied + Worker(std::thread &&_t); + Worker(Worker &&other); + Worker &operator=(Worker &&other); + }; + + using WorkerMapPair = std::pair; + + void eraseJoinableIf(std::function + = [](const WorkerMapPair &) + { + return true; + }); + void done(void); + + std::atomic m_stop; + std::mutex m_mutex; + std::map m_workerMap; +}; + +} // namespace Client +} // namespace Csr diff --git a/src/framework/common/cs-types.cpp b/src/framework/client/handle.cpp similarity index 53% rename from src/framework/common/cs-types.cpp rename to src/framework/client/handle.cpp index 21e255f..a8d3c17 100644 --- a/src/framework/common/cs-types.cpp +++ b/src/framework/client/handle.cpp @@ -14,81 +14,30 @@ * limitations under the License */ /* - * @file cs-types.cpp + * @file handle.cpp * @author Kyungwook Tak (k.tak@samsung.com) * @version 1.0 - * @brief CSR Content Screening internal types + * @brief Client request handle with dispatcher in it */ -#include "common/cs-types.h" +#include "client/handle.h" #include namespace Csr { -namespace Cs { +namespace Client { -Context::Context() +Context &Handle::getContext() noexcept { + return m_ctx; } -Context::~Context() -{ -} - -Context::Context(IStream &) -{ -} - -void Context::Serialize(IStream &) const -{ -} - -Context::Context(Context &&other) : - m_results(std::move(other.m_results)) -{ -} - -Context &Context::operator=(Context &&other) -{ - if (this == &other) - return *this; - - m_results = std::move(other.m_results); - - return *this; -} - -void Context::addResult(Result *result) +void Handle::add(Result *result) { if (result == nullptr) throw std::logic_error("result shouldn't be null"); - m_results.emplace(result); -} - -Result::Result() -{ -} - -Result::~Result() -{ -} - -Result::Result(IStream &) -{ -} - -void Result::Serialize(IStream &) const -{ -} - -Result::Result(Result &&) -{ -} - -Result &Result::operator=(Result &&) -{ - return *this; + m_ctx.add(result); } -} // namespace Cs +} // namespace Client } // namespace Csr diff --git a/src/framework/common/wp-types.h b/src/framework/client/handle.h similarity index 61% rename from src/framework/common/wp-types.h rename to src/framework/client/handle.h index 91ea0b0..a0705ea 100644 --- a/src/framework/common/wp-types.h +++ b/src/framework/client/handle.h @@ -14,56 +14,37 @@ * limitations under the License */ /* - * @file wp-types.h + * @file handle.h * @author Kyungwook Tak (k.tak@samsung.com) * @version 1.0 - * @brief CSR Web Protection internal types + * @brief Client request handle with dispatcher in it */ #pragma once -#include -#include #include +#include "common/types.h" #include "common/dispatcher.h" -#include "common/serialization.h" namespace Csr { -namespace Wp { +namespace Client { -class Result : public ISerializable { +class Handle { public: - Result(); - virtual ~Result(); - Result(IStream &); - virtual void Serialize(IStream &) const; - - Result(Result &&); - Result &operator=(Result &&); -}; - -class Context : public ISerializable { -public: - Context(); - virtual ~Context(); - Context(IStream &); - virtual void Serialize(IStream &) const; - - template + template Type dispatch(Args &&...); - Context(Context &&); - Context &operator=(Context &&); + void add(Result *); - void addResult(Result *); + Context &getContext(void) noexcept; private: std::unique_ptr m_dispatcher; - std::set> m_results; + Context m_ctx; }; template -Type Context::dispatch(Args &&...args) +Type Handle::dispatch(Args &&...args) { if (m_dispatcher == nullptr) m_dispatcher.reset(new Dispatcher("/tmp/." SERVICE_NAME ".socket")); @@ -71,5 +52,5 @@ Type Context::dispatch(Args &&...args) return m_dispatcher->methodCall(std::forward(args)...); } -} // namespace Wp +} // namespace Client } // namespace Csr diff --git a/src/framework/client/utils.h b/src/framework/client/utils.h index 6a4da8a..8d71711 100644 --- a/src/framework/client/utils.h +++ b/src/framework/client/utils.h @@ -21,11 +21,8 @@ */ #pragma once -#include #include -#include "common/audit/logger.h" - #define API __attribute__((visibility("default"))) #define EXCEPTION_SAFE_START return Csr::Client::exceptionGuard([&]()->int { @@ -34,11 +31,6 @@ namespace Csr { namespace Client { -inline std::string toStlString(const char *cstr) -{ - return (cstr == nullptr) ? std::string() : std::string(cstr); -} - int exceptionGuard(const std::function &); } // namespace Client diff --git a/src/framework/client/web-protection.cpp b/src/framework/client/web-protection.cpp index dc0cc71..e533168 100644 --- a/src/framework/client/web-protection.cpp +++ b/src/framework/client/web-protection.cpp @@ -24,7 +24,8 @@ #include #include "client/utils.h" -#include "common/wp-types.h" +#include "client/handle.h" +#include "common/types.h" #include "common/command-id.h" #include "common/audit/logger.h" @@ -38,7 +39,7 @@ int csr_wp_context_create(csr_wp_context_h* phandle) if (phandle == nullptr) return CSR_ERROR_INVALID_PARAMETER; - *phandle = reinterpret_cast(new Wp::Context()); + *phandle = reinterpret_cast(new Client::Handle()); return CSR_ERROR_NONE; @@ -53,7 +54,7 @@ int csr_wp_context_destroy(csr_wp_context_h handle) if (handle == nullptr) return CSR_ERROR_INVALID_PARAMETER; - delete reinterpret_cast(handle); + delete reinterpret_cast(handle); return CSR_ERROR_NONE; @@ -89,19 +90,19 @@ int csr_wp_check_url(csr_wp_context_h handle, const char *url, csr_wp_check_resu || url == nullptr || url[0] == '\0') return CSR_ERROR_INVALID_PARAMETER; - auto context = reinterpret_cast(handle); - auto ret = context->dispatch>( + auto h = reinterpret_cast(handle); + auto ret = h->dispatch>( CommandId::CHECK_URL, - context, - Client::toStlString(url)); + h->getContext(), + std::string(url)); - if (ret.first != CSR_ERROR_NONE || ret.second == nullptr) { + if (ret.first != CSR_ERROR_NONE) { ERROR("Error! ret: " << ret.first); return ret.first; } + h->add(ret.second); *presult = reinterpret_cast(ret.second); - context->addResult(ret.second); return CSR_ERROR_NONE; diff --git a/src/framework/common/command-id.h b/src/framework/common/command-id.h index 2ad184f..ea720df 100644 --- a/src/framework/common/command-id.h +++ b/src/framework/common/command-id.h @@ -24,9 +24,11 @@ namespace Csr { enum class CommandId : int { - SCAN_FILE = 0x01, - JUDGE_STATUS = 0x02, - CHECK_URL = 0x03 + SCAN_FILE = 0x01, + JUDGE_STATUS = 0x02, + CHECK_URL = 0x03, + DIR_GET_RESULTS = 0x04, + DIR_GET_FILES = 0x05 }; } diff --git a/src/framework/common/serialization.h b/src/framework/common/serialization.h index d06d62b..22c66ae 100644 --- a/src/framework/common/serialization.h +++ b/src/framework/common/serialization.h @@ -28,6 +28,7 @@ #include #include #include +#include #include #include "common/command-id.h" @@ -181,6 +182,20 @@ struct Serialization { Serialize(stream, *list); } + template + static void Serialize(IStream& stream, const std::set& set) + { + auto len = set.size(); + stream.write(sizeof(len), &len); + for (const auto &item : set) + Serialize(stream, item); + } + template + static void Serialize(IStream& stream, const std::set* const set) + { + Serialize(stream, *set); + } + // RawBuffer template static void Serialize(IStream& stream, const std::vector& vec) @@ -415,6 +430,24 @@ struct Deserialization { Deserialize(stream, *list); } + template + static void Deserialize(IStream& stream, std::set& set) + { + size_t len; + stream.read(sizeof(len), &len); + for (size_t i = 0; i < len; ++i) { + T obj; + Deserialize(stream, obj); + set.insert(std::move(obj)); + } + } + template + static void Deserialize(IStream& stream, std::set*& set) + { + set = new std::set; + Deserialize(stream, *set); + } + // RawBuffer template static void Deserialize(IStream& stream, std::vector& vec) diff --git a/src/framework/common/wp-types.cpp b/src/framework/common/types.cpp similarity index 56% rename from src/framework/common/wp-types.cpp rename to src/framework/common/types.cpp index e0ff970..ee0f4db 100644 --- a/src/framework/common/wp-types.cpp +++ b/src/framework/common/types.cpp @@ -14,18 +14,19 @@ * limitations under the License */ /* - * @file wp-types.cpp + * @file types.cpp * @author Kyungwook Tak (k.tak@samsung.com) * @version 1.0 - * @brief CSR Web Protection internal types + * @brief CSR internal serializable types */ -#include "common/wp-types.h" +#include "common/types.h" #include #include +#include "common/audit/logger.h" + namespace Csr { -namespace Wp { Context::Context() { @@ -43,6 +44,18 @@ void Context::Serialize(IStream &) const { } +// don't copy results.. context copy operation only should be used for option copy +Context::Context(const Context &) : + ISerializable(), + m_results() +{ +} + +Context &Context::operator=(const Context &) +{ + return *this; +} + Context::Context(Context &&other) : m_results(std::move(other.m_results)) { @@ -58,15 +71,25 @@ Context &Context::operator=(Context &&other) return *this; } -void Context::addResult(Result *result) +void Context::add(std::unique_ptr &&item) { - if (result == nullptr) - throw std::logic_error("result shouldn't be null"); + std::lock_guard l(m_mutex); + m_results.emplace_back(std::forward>(item)); +} + +void Context::add(Result *item) +{ + std::lock_guard l(m_mutex); + m_results.emplace_back(item); +} - m_results.emplace(result); +size_t Context::size() const +{ + std::lock_guard l(m_mutex); + return m_results.size(); } -Result::Result() +Result::Result() : m_hasVal(false) { } @@ -74,12 +97,14 @@ Result::~Result() { } -Result::Result(IStream &) +Result::Result(IStream &stream) { + Deserializer::Deserialize(stream, m_hasVal); } -void Result::Serialize(IStream &) const +void Result::Serialize(IStream &stream) const { + Serializer::Serialize(stream, m_hasVal); } Result::Result(Result &&) @@ -91,5 +116,9 @@ Result &Result::operator=(Result &&) return *this; } -} // namespace Wp +bool Result::hasValue() const +{ + return m_hasVal; +} + } // namespace Csr diff --git a/src/framework/common/cs-types.h b/src/framework/common/types.h similarity index 65% rename from src/framework/common/cs-types.h rename to src/framework/common/types.h index 53e50c9..d6a48b1 100644 --- a/src/framework/common/cs-types.h +++ b/src/framework/common/types.h @@ -14,22 +14,24 @@ * limitations under the License */ /* - * @file cs-types.h + * @file types.h * @author Kyungwook Tak (k.tak@samsung.com) * @version 1.0 - * @brief CSR Content Screening internal types + * @brief CSR internal serializable types */ #pragma once -#include +#include #include -#include +#include #include "common/dispatcher.h" #include "common/serialization.h" namespace Csr { -namespace Cs { + +using Task = std::function; +using StrSet = std::set; class Result : public ISerializable { public: @@ -40,6 +42,11 @@ public: Result(Result &&); Result &operator=(Result &&); + + bool hasValue(void) const; + +private: + bool m_hasVal; }; class Context : public ISerializable { @@ -52,24 +59,19 @@ public: Context(Context &&); Context &operator=(Context &&); - template - Type dispatch(Args &&...); + // TODO: Handling results vector between contexts should be refined.. + // copy ctor/assignments for serializing and results vector isn't included here. + Context(const Context &); + Context &operator=(const Context &); - void addResult(Result *); + void add(std::unique_ptr &&); + void add(Result *); + size_t size(void) const; + // for destroying with context + std::vector> m_results; private: - std::unique_ptr m_dispatcher; - std::set> m_results; + mutable std::mutex m_mutex; }; -template -Type Context::dispatch(Args &&...args) -{ - if (m_dispatcher == nullptr) - m_dispatcher.reset(new Dispatcher("/tmp/." SERVICE_NAME ".socket")); - - return m_dispatcher->methodCall(std::forward(args)...); -} - -} // namespace Cs } // namespace Csr diff --git a/src/framework/service/logic.cpp b/src/framework/service/logic.cpp index b4b6a1d..6073f58 100644 --- a/src/framework/service/logic.cpp +++ b/src/framework/service/logic.cpp @@ -46,7 +46,7 @@ RawBuffer Logic::dispatch(const RawBuffer &in) switch (info.first) { case CommandId::SCAN_FILE: { - Cs::Context context; + Context context; std::string filepath; info.second.Deserialize(context, filepath); return scanFile(context, filepath); @@ -54,12 +54,26 @@ RawBuffer Logic::dispatch(const RawBuffer &in) /* TODO: should we separate command->logic mapping of CS and WP ? */ case CommandId::CHECK_URL: { - Wp::Context context; + Context context; std::string url; info.second.Deserialize(context, url); return checkUrl(context, url); } + case CommandId::DIR_GET_RESULTS: { + Context context; + std::string dir; + info.second.Deserialize(context, dir); + return dirGetResults(context, dir); + } + + case CommandId::DIR_GET_FILES: { + Context context; + std::string dir; + info.second.Deserialize(context, dir); + return dirGetFiles(context, dir); + } + default: throw std::range_error(FORMAT("Command id[" << static_cast(info.first) << "] isn't in range.")); @@ -77,20 +91,36 @@ std::pair Logic::getRequestInfo(const RawBuffer &data) return std::make_pair(id, std::move(q)); } -RawBuffer Logic::scanFile(const Cs::Context &context, const std::string &filepath) +RawBuffer Logic::scanFile(const Context &context, const std::string &filepath) { INFO("Scan file[" << filepath << "] by engine"); (void) context; - return BinaryQueue::Serialize(CSR_ERROR_NONE, Cs::Result()).pop(); + return BinaryQueue::Serialize(CSR_ERROR_NONE, Result()).pop(); } -RawBuffer Logic::checkUrl(const Wp::Context &context, const std::string &url) +RawBuffer Logic::checkUrl(const Context &context, const std::string &url) { INFO("Check url[" << url << "] by engine"); (void) context; - return BinaryQueue::Serialize(CSR_ERROR_NONE, Wp::Result()).pop(); + return BinaryQueue::Serialize(CSR_ERROR_NONE, Result()).pop(); +} + +RawBuffer Logic::dirGetResults(const Context &context, const std::string &dir) +{ + INFO("Dir[" << dir << "] get results"); + (void) context; + + return BinaryQueue::Serialize(CSR_ERROR_NONE, StrSet()).pop(); +} + +RawBuffer Logic::dirGetFiles(const Context &context, const std::string &dir) +{ + INFO("Dir[" << dir << "] get files"); + (void) context; + + return BinaryQueue::Serialize(CSR_ERROR_NONE, std::vector()).pop(); } } diff --git a/src/framework/service/logic.h b/src/framework/service/logic.h index 22bff3a..3575b9d 100644 --- a/src/framework/service/logic.h +++ b/src/framework/service/logic.h @@ -24,8 +24,7 @@ #include #include -#include "common/cs-types.h" -#include "common/wp-types.h" +#include "common/types.h" #include "common/command-id.h" #include "common/raw-buffer.h" #include "common/binary-queue.h" @@ -42,8 +41,10 @@ public: private: std::pair getRequestInfo(const RawBuffer &); - RawBuffer scanFile(const Cs::Context &context, const std::string &filepath); - RawBuffer checkUrl(const Wp::Context &context, const std::string &url); + RawBuffer scanFile(const Context &context, const std::string &filepath); + RawBuffer checkUrl(const Context &context, const std::string &url); + RawBuffer dirGetResults(const Context &context, const std::string &dir); + RawBuffer dirGetFiles(const Context &context, const std::string &dir); }; } diff --git a/test/test-api-content-screening.cpp b/test/test-api-content-screening.cpp index 5c81ac5..5c89b09 100644 --- a/test/test-api-content-screening.cpp +++ b/test/test-api-content-screening.cpp @@ -23,37 +23,163 @@ #include #include +#include +#include #include +#include +#include +#include #include -BOOST_AUTO_TEST_SUITE(API_CONTENT_SCREENING) +namespace { -BOOST_AUTO_TEST_CASE(context_create_destroy) +class ContextPtr { +public: + ContextPtr() : m_context(nullptr) {} + ContextPtr(csr_cs_context_h context) : m_context(context) {} + virtual ~ContextPtr() + { + BOOST_REQUIRE(csr_cs_context_destroy(m_context) == CSR_ERROR_NONE); + } + + inline csr_cs_context_h get(void) + { + return m_context; + } + +private: + csr_cs_context_h m_context; +}; + +using ScopedContext = std::unique_ptr; + +inline ScopedContext makeScopedContext(csr_cs_context_h context) { - csr_cs_context_h handle; + return ScopedContext(new ContextPtr(context)); +} + +inline ScopedContext getContextHandle(void) +{ + csr_cs_context_h context; int ret = CSR_ERROR_UNKNOWN; + BOOST_REQUIRE_NO_THROW(ret = csr_cs_context_create(&context)); + BOOST_REQUIRE_MESSAGE(ret == CSR_ERROR_NONE, + "Failed to create context handle. ret: " << ret); + BOOST_REQUIRE(context != nullptr); + return makeScopedContext(context); +} - BOOST_REQUIRE_NO_THROW(ret = csr_cs_context_create(&handle)); - BOOST_REQUIRE(ret == CSR_ERROR_NONE); +} - BOOST_REQUIRE_NO_THROW(ret = csr_cs_context_destroy(handle)); - BOOST_REQUIRE(ret == CSR_ERROR_NONE); +BOOST_AUTO_TEST_SUITE(API_CONTENT_SCREENING) + +BOOST_AUTO_TEST_CASE(context_create_destroy) +{ + auto contextPtr = getContextHandle(); + (void) contextPtr; } BOOST_AUTO_TEST_CASE(scan_file) { - csr_cs_context_h handle; int ret = CSR_ERROR_UNKNOWN; - - BOOST_REQUIRE_NO_THROW(ret = csr_cs_context_create(&handle)); - BOOST_REQUIRE(ret == CSR_ERROR_NONE); - + auto contextPtr = getContextHandle(); + auto context = contextPtr->get(); csr_cs_detected_h detected; - BOOST_REQUIRE_NO_THROW(ret = csr_cs_scan_file(handle, "dummy_file_path", &detected)); + BOOST_REQUIRE_NO_THROW(ret = csr_cs_scan_file(context, "dummy_file_path", + &detected)); BOOST_REQUIRE(ret == CSR_ERROR_NONE); +} + +struct AsyncTestContext { + std::mutex m; + std::condition_variable cv; + int scannedCnt; + int detectedCnt; + int completedCnt; + int cancelledCnt; + int errorCnt; + + AsyncTestContext() : + scannedCnt(0), + detectedCnt(0), + completedCnt(0), + cancelledCnt(0), + errorCnt(0) {} +}; - BOOST_REQUIRE_NO_THROW(ret = csr_cs_context_destroy(handle)); +void on_scanned(void *userdata, const char *file) +{ + BOOST_MESSAGE("on_scanned called. file[" << file << "] scanned!"); + auto ctx = reinterpret_cast(userdata); + ctx->scannedCnt++; +} + +void on_error(void *userdata, int ec) +{ + BOOST_MESSAGE("on_error called. async request done with error code[" << ec << + "]"); + auto ctx = reinterpret_cast(userdata); + ctx->errorCnt++; +} + +void on_detected(void *userdata, csr_cs_detected_h detected) +{ + (void) detected; + BOOST_MESSAGE("on_detected called."); + auto ctx = reinterpret_cast(userdata); + ctx->detectedCnt++; +} + +void on_completed(void *userdata) +{ + BOOST_MESSAGE("on_completed called. async request completed succesfully."); + auto ctx = reinterpret_cast(userdata); + ctx->completedCnt++; + ctx->cv.notify_one(); +} + +void on_cancelled(void *userdata) +{ + BOOST_MESSAGE("on_cancelled called. async request canceled!"); + auto ctx = reinterpret_cast(userdata); + ctx->cancelledCnt++; +} + +BOOST_AUTO_TEST_CASE(scan_files_async) +{ + int ret = CSR_ERROR_UNKNOWN; + auto contextPtr = getContextHandle(); + auto context = contextPtr->get(); + BOOST_REQUIRE_NO_THROW(ret = csr_cs_set_callback_on_completed(context, + on_completed)); + BOOST_REQUIRE(ret == CSR_ERROR_NONE); + BOOST_REQUIRE_NO_THROW(ret = csr_cs_set_callback_on_error(context, on_error)); + BOOST_REQUIRE(ret == CSR_ERROR_NONE); + BOOST_REQUIRE_NO_THROW(ret = csr_cs_set_callback_on_cancelled(context, + on_cancelled)); + BOOST_REQUIRE(ret == CSR_ERROR_NONE); + BOOST_REQUIRE_NO_THROW(ret = csr_cs_set_callback_on_detected(context, + on_detected)); + BOOST_REQUIRE(ret == CSR_ERROR_NONE); + BOOST_REQUIRE_NO_THROW(ret = csr_cs_set_callback_on_file_scanned(context, + on_scanned)); + BOOST_REQUIRE(ret == CSR_ERROR_NONE); + const char *files[3] = { + TEST_DIR "/test_malware_file", + TEST_DIR "/test_normal_file", + TEST_DIR "/test_risky_file" + }; + AsyncTestContext testCtx; + BOOST_REQUIRE_NO_THROW(ret = + csr_cs_scan_files_async(context, files, sizeof(files) / sizeof(const char *), + &testCtx)); BOOST_REQUIRE(ret == CSR_ERROR_NONE); + std::unique_lock l(testCtx.m); + testCtx.cv.wait(l); + l.unlock(); + BOOST_REQUIRE_MESSAGE(testCtx.completedCnt == 1 && testCtx.scannedCnt == 3 && + testCtx.detectedCnt == 0 && testCtx.cancelledCnt == 0 && testCtx.errorCnt == 0, + "Async request result isn't expected."); } BOOST_AUTO_TEST_SUITE_END() -- 2.7.4