)
SET(${TARGET_CSR_CLIENT}_SRCS
+ framework/client/async-logic.cpp
framework/client/content-screening.cpp
framework/client/engine-manager.cpp
framework/client/error.cpp
+ framework/client/handle.cpp
+ framework/client/handle-ext.cpp
framework/client/utils.cpp
framework/client/web-protection.cpp
)
common/audit/logger.cpp
common/binary-queue.cpp
common/connection.cpp
- common/cs-types.cpp
+ common/types.cpp
common/dispatcher.cpp
common/mainloop.cpp
common/service.cpp
common/socket.cpp
- common/wp-types.cpp
)
INCLUDE_DIRECTORIES(
--- /dev/null
+/*
+ * Copyright (c) 2016 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License
+ */
+/*
+ * @file async-logic.cpp
+ * @author Kyungwook Tak (k.tak@samsung.com)
+ * @version 1.0
+ * @brief
+ */
+#include "client/async-logic.h"
+
+#include <utility>
+
+#include "common/audit/logger.h"
+
+namespace Csr {
+namespace Client {
+
+AsyncLogic::AsyncLogic(Context &context, const Callback &cb, void *userdata,
+ const std::function<bool()> &isStopped) :
+ m_origCtx(context),
+ m_ctx(context),
+ m_cb(cb),
+ m_userdata(userdata),
+ m_isStopped(isStopped),
+ m_dispatcher(new Dispatcher("/tmp/." SERVICE_NAME ".socket"))
+{
+}
+
+AsyncLogic::~AsyncLogic()
+{
+ DEBUG("AsyncLogic dtor. Results num in "
+ "mother ctx[" << m_origCtx.size() << "] "
+ "and here[" << m_ctx.size() << "]");
+
+ for (auto &resultPtr : m_ctx.m_results)
+ m_origCtx.add(std::move(resultPtr));
+
+ DEBUG("Integrated mother ctx results num: " << m_origCtx.size());
+}
+
+std::pair<Callback::Id, Task> AsyncLogic::scanDirs(const std::shared_ptr<StrSet>
+ &dirs)
+{
+ // TODO: canonicalize dirs. (e.g. Can omit subdirectory it there is
+ // parent directory in set)
+ std::pair<Callback::Id, Task> t(Callback::Id::OnCompleted, [this] {
+ if (m_cb.onCompleted)
+ m_cb.onCompleted(this->m_userdata);
+ });
+
+ for (const auto &dir : *dirs) {
+ t = scanDir(dir);
+
+ if (t.first != Callback::Id::OnCompleted)
+ return t;
+ }
+
+ return t;
+}
+
+std::pair<Callback::Id, Task> AsyncLogic::scanDir(const std::string &dir)
+{
+ // For in case of there's already detected malware for dir
+ auto retResults =
+ m_dispatcher->methodCall<std::pair<int, std::vector<Result *>>>(
+ CommandId::DIR_GET_RESULTS, m_ctx, dir);
+
+ if (retResults.first != CSR_ERROR_NONE) {
+ ERROR("[Error] ret: " << retResults.first);
+
+ for (auto r : retResults.second)
+ delete r;
+
+ auto ec = retResults.first;
+ return std::make_pair(Callback::Id::OnError, [this, ec] {
+ if (this->m_cb.onError)
+ this->m_cb.onError(this->m_userdata, ec);
+ });
+ }
+
+ // Register already detected malwares to context to be freed with context.
+ for (auto r : retResults.second) {
+ add(r);
+
+ if (m_cb.onDetected)
+ m_cb.onDetected(m_userdata, reinterpret_cast<csr_cs_detected_h>(r));
+ }
+
+ // Already scanned files are excluded according to history
+ auto retFiles = m_dispatcher->methodCall<std::pair<int, StrSet *>>(
+ CommandId::DIR_GET_FILES, m_ctx, dir);
+
+ if (retFiles.first != CSR_ERROR_NONE) {
+ ERROR("[Error] ret: " << retFiles.first);
+ delete retFiles.second;
+ auto ec = retFiles.first;
+ return std::make_pair(Callback::Id::OnError, [this, ec] {
+ if (this->m_cb.onError)
+ this->m_cb.onError(this->m_userdata, ec);
+ });
+ }
+
+ // Let's start scan files!
+ std::shared_ptr<StrSet> strSetPtr(retFiles.second);
+ auto task = scanFiles(strSetPtr);
+ // TODO: register results(in outs) to db and update dir scanning history...
+ return task;
+}
+
+std::pair<Callback::Id, Task> AsyncLogic::scanFiles(const
+ std::shared_ptr<StrSet> &fileSet)
+{
+ for (const auto &file : *fileSet) {
+ if (m_isStopped()) {
+ INFO("async operation cancelled!");
+ return std::make_pair(Callback::Id::OnCancelled, [this] {
+ if (this->m_cb.onCancelled)
+ this->m_cb.onCancelled(this->m_userdata);
+ });
+ }
+
+ auto ret = m_dispatcher->methodCall<std::pair<int, Result *>>(
+ CommandId::SCAN_FILE, m_ctx, file);
+
+ if (ret.first != CSR_ERROR_NONE) {
+ ERROR("[Error] ret: " << ret.first);
+ delete ret.second;
+ auto ec = ret.first;
+ return std::make_pair(Callback::Id::OnError, [this, ec] {
+ if (this->m_cb.onError)
+ this->m_cb.onError(this->m_userdata, ec);
+
+ return;
+ });
+ }
+
+ if (!ret.second->hasValue()) {
+ DEBUG("[Scanned] file[" << file << "]");
+ delete ret.second;
+
+ if (m_cb.onScanned)
+ m_cb.onScanned(m_userdata, file.c_str());
+
+ continue;
+ }
+
+ // malware detected!
+ INFO("[Detected] file[" << file << "]");
+ add(ret.second);
+
+ if (m_cb.onDetected)
+ m_cb.onDetected(m_userdata, reinterpret_cast<csr_cs_detected_h>(ret.second));
+ }
+
+ return std::make_pair(Callback::Id::OnCompleted, [this] {
+ DEBUG("[Completed]");
+
+ if (this->m_cb.onCompleted)
+ this->m_cb.onCompleted(this->m_userdata);
+ });
+}
+
+void AsyncLogic::add(Result *r)
+{
+ m_ctx.add(r);
+}
+
+}
+}
--- /dev/null
+/*
+ * Copyright (c) 2016 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License
+ */
+/*
+ * @file async-logic.h
+ * @author Kyungwook Tak (k.tak@samsung.com)
+ * @version 1.0
+ * @brief
+ */
+#pragma once
+
+#include <memory>
+#include <atomic>
+
+#include "common/types.h"
+#include "common/dispatcher.h"
+#include "client/callback.h"
+
+namespace Csr {
+namespace Client {
+
+class AsyncLogic {
+public:
+ AsyncLogic(Context &context, const Callback &cb, void *userdata,
+ const std::function<bool()> &isStopped);
+ virtual ~AsyncLogic();
+
+ std::pair<Callback::Id, Task> scanFiles(const std::shared_ptr<StrSet> &files);
+ std::pair<Callback::Id, Task> scanDir(const std::string &dir);
+ std::pair<Callback::Id, Task> scanDirs(const std::shared_ptr<StrSet> &dirs);
+
+ void stop(void);
+
+private:
+ void add(Result *);
+
+ Context &m_origCtx; // for registering results for auto-release
+ Context m_ctx; // TODO: append it to handle context when destroyed
+ Callback m_cb;
+ void *m_userdata;
+ std::function<bool()> m_isStopped;
+
+ std::unique_ptr<Dispatcher> m_dispatcher;
+
+};
+
+}
+}
--- /dev/null
+/*
+ * Copyright (c) 2016 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License
+ */
+/*
+ * @file callback.h
+ * @author Kyungwook Tak (k.tak@samsung.com)
+ * @version 1.0
+ * @brief CSR callback container for async operations
+ */
+#pragma once
+
+#include <utility>
+#include <stdexcept>
+
+#include "csr/content-screening-types.h"
+
+namespace Csr {
+
+struct Callback {
+ enum class Id : int {
+ OnCompleted,
+ OnCancelled,
+ OnError,
+ OnScanned,
+ OnDetected
+ };
+
+ Callback() {}
+
+ Callback(const Callback &other) :
+ onScanned(other.onScanned),
+ onDetected(other.onDetected),
+ onCompleted(other.onCompleted),
+ onCancelled(other.onCancelled),
+ onError(other.onError)
+ {
+ }
+
+ Callback &operator=(const Callback &other)
+ {
+ onScanned = other.onScanned;
+ onDetected = other.onDetected;
+ onCompleted = other.onCompleted;
+ onCancelled = other.onCancelled,
+ onError = other.onError;
+ return *this;
+ }
+
+ std::function<void(void *, const char *)> onScanned;
+ std::function<void(void *, csr_cs_detected_h)> onDetected;
+ std::function<void(void *)> onCompleted;
+ std::function<void(void *)> onCancelled;
+ std::function<void(void *, int)> onError;
+};
+
+}
#include <new>
#include "client/utils.h"
-#include "common/cs-types.h"
+#include "client/handle-ext.h"
+#include "client/async-logic.h"
#include "common/command-id.h"
#include "common/audit/logger.h"
if (phandle == nullptr)
return CSR_ERROR_INVALID_PARAMETER;
- *phandle = reinterpret_cast<csr_cs_context_h>(new Cs::Context());
+ *phandle = reinterpret_cast<csr_cs_context_h>(new Client::HandleExt());
return CSR_ERROR_NONE;
if (handle == nullptr)
return CSR_ERROR_INVALID_PARAMETER;
- delete reinterpret_cast<Cs::Context *>(handle);
+ delete reinterpret_cast<Client::HandleExt *>(handle);
return CSR_ERROR_NONE;
|| file_path == nullptr || file_path[0] == '\0')
return CSR_ERROR_INVALID_PARAMETER;
- auto context = reinterpret_cast<Cs::Context *>(handle);
- auto ret = context->dispatch<std::pair<int, Cs::Result *>>(
+ auto hExt = reinterpret_cast<Client::HandleExt *>(handle);
+ auto ret = hExt->dispatch<std::pair<int, Result *>>(
CommandId::SCAN_FILE,
- context,
- Client::toStlString(file_path));
+ hExt->getContext(),
+ std::string(file_path));
- if (ret.first != CSR_ERROR_NONE || ret.second == nullptr) {
+ if (ret.first != CSR_ERROR_NONE) {
ERROR("Error! ret: " << ret.first);
return ret.first;
}
+ hExt->add(ret.second);
*pdetected = reinterpret_cast<csr_cs_detected_h>(ret.second);
- context->addResult(ret.second);
return CSR_ERROR_NONE;
}
API
-int csr_cs_set_callback_on_detected(csr_cs_context_h handle, csr_cs_on_detected_cb callback)
+int csr_cs_set_callback_on_file_scanned(csr_cs_context_h handle, csr_cs_on_file_scanned_cb callback)
{
- (void) handle;
- (void) callback;
+ EXCEPTION_SAFE_START
+
+ auto hExt = reinterpret_cast<Client::HandleExt *>(handle);
+
+ if (hExt == nullptr)
+ return CSR_ERROR_INVALID_PARAMETER;
+
+ hExt->m_cb.onScanned = callback;
- DEBUG("start!");
return CSR_ERROR_NONE;
+
+ EXCEPTION_SAFE_END
}
API
-int csr_cs_set_callback_on_completed(csr_cs_context_h handle, csr_cs_on_completed_cb callback)
+int csr_cs_set_callback_on_detected(csr_cs_context_h handle, csr_cs_on_detected_cb callback)
{
- (void) handle;
- (void) callback;
+ EXCEPTION_SAFE_START
+
+ auto hExt = reinterpret_cast<Client::HandleExt *>(handle);
+
+ if (hExt == nullptr)
+ return CSR_ERROR_INVALID_PARAMETER;
+
+ hExt->m_cb.onDetected = callback;
- DEBUG("start!");
return CSR_ERROR_NONE;
+
+ EXCEPTION_SAFE_END
}
API
-int csr_cs_set_callback_on_cancelled(csr_cs_context_h handle, csr_cs_on_cancelled_cb callback)
+int csr_cs_set_callback_on_completed(csr_cs_context_h handle, csr_cs_on_completed_cb callback)
{
- (void) handle;
- (void) callback;
+ EXCEPTION_SAFE_START
+
+ auto hExt = reinterpret_cast<Client::HandleExt *>(handle);
+
+ if (hExt == nullptr)
+ return CSR_ERROR_INVALID_PARAMETER;
+
+ hExt->m_cb.onCompleted = callback;
- DEBUG("start!");
return CSR_ERROR_NONE;
+
+ EXCEPTION_SAFE_END
}
API
-int csr_cs_set_callback_on_error(csr_cs_context_h handle, csr_cs_on_error_cb callback)
+int csr_cs_set_callback_on_cancelled(csr_cs_context_h handle, csr_cs_on_cancelled_cb callback)
{
- (void) handle;
- (void) callback;
+ EXCEPTION_SAFE_START
+
+ auto hExt = reinterpret_cast<Client::HandleExt *>(handle);
+
+ if (hExt == nullptr)
+ return CSR_ERROR_INVALID_PARAMETER;
+
+ hExt->m_cb.onCancelled = callback;
- DEBUG("start!");
return CSR_ERROR_NONE;
+
+ EXCEPTION_SAFE_END
}
API
-int csr_cs_set_callback_on_file_scanned(csr_cs_context_h handle, csr_cs_on_file_scanned_cb callback)
+int csr_cs_set_callback_on_error(csr_cs_context_h handle, csr_cs_on_error_cb callback)
{
- (void) handle;
- (void) callback;
+ EXCEPTION_SAFE_START
+
+ auto hExt = reinterpret_cast<Client::HandleExt *>(handle);
+
+ if (hExt == nullptr)
+ return CSR_ERROR_INVALID_PARAMETER;
+
+ hExt->m_cb.onError = callback;
- DEBUG("start!");
return CSR_ERROR_NONE;
+
+ EXCEPTION_SAFE_END
}
API
-int csr_cs_scan_files_async(csr_cs_context_h handle, const char **file_paths, unsigned int count, void *user_data)
+int csr_cs_scan_files_async(csr_cs_context_h handle, const char **file_paths, unsigned int count, void *user_data)
{
- (void) handle;
- (void) file_paths;
- (void) count;
- (void) user_data;
+ EXCEPTION_SAFE_START
+
+ if (handle == nullptr || file_paths == nullptr || count == 0)
+ return CSR_ERROR_INVALID_PARAMETER;
+
+ auto hExt = reinterpret_cast<Client::HandleExt *>(handle);
+
+ auto fileSet(std::make_shared<StrSet>());
+ for (unsigned int i = 0; i < count; i++) {
+ if (file_paths[i] == nullptr)
+ return CSR_ERROR_INVALID_PARAMETER;
+
+ fileSet->emplace(file_paths[i]);
+ }
+
+ hExt->dispatchAsync([hExt, user_data, fileSet] {
+ Client::AsyncLogic l(hExt->getContext(), hExt->m_cb, user_data,
+ [&hExt] { return hExt->isStopped(); });
+
+ l.scanFiles(fileSet).second();
+ });
- DEBUG("start!");
return CSR_ERROR_NONE;
+
+ EXCEPTION_SAFE_END
}
API
int csr_cs_scan_dir_async(csr_cs_context_h handle, const char *dir_path, void *user_data)
{
- (void) handle;
- (void) dir_path;
- (void) user_data;
+ EXCEPTION_SAFE_START
+
+ if (handle == nullptr || dir_path == nullptr || dir_path[0] == '\0')
+ return CSR_ERROR_INVALID_PARAMETER;
+
+ auto hExt = reinterpret_cast<Client::HandleExt *>(handle);
+
+ hExt->dispatchAsync([hExt, user_data, dir_path] {
+ Client::AsyncLogic l(hExt->getContext(), hExt->m_cb, user_data,
+ [&hExt] { return hExt->isStopped(); });
+
+ l.scanDir(dir_path).second();
+ });
- DEBUG("start!");
return CSR_ERROR_NONE;
+
+ EXCEPTION_SAFE_END
}
API
-int csr_cs_scan_dirs_async(csr_cs_context_h handle, const char **file_paths, unsigned int count, void *user_data)
+int csr_cs_scan_dirs_async(csr_cs_context_h handle, const char **dir_paths, unsigned int count, void *user_data)
{
- (void) handle;
- (void) file_paths;
- (void) count;
- (void) user_data;
+ EXCEPTION_SAFE_START
+
+ if (handle == nullptr || dir_paths == nullptr || count == 0)
+ return CSR_ERROR_INVALID_PARAMETER;
+
+ auto hExt = reinterpret_cast<Client::HandleExt *>(handle);
+
+ auto dirSet(std::make_shared<StrSet>());
+ for (unsigned int i = 0; i < count; i++) {
+ if (dir_paths[i] == nullptr)
+ return CSR_ERROR_INVALID_PARAMETER;
+
+ dirSet->emplace(dir_paths[i]);
+ }
+
+ hExt->dispatchAsync([hExt, user_data, dirSet] {
+ Client::AsyncLogic l(hExt->getContext(), hExt->m_cb, user_data,
+ [&hExt] { return hExt->isStopped(); });
+
+ l.scanDirs(dirSet).second();
+ });
- DEBUG("start!");
return CSR_ERROR_NONE;
+
+ EXCEPTION_SAFE_END
}
API
int csr_cs_scan_cancel(csr_cs_context_h handle)
{
- (void) handle;
+ EXCEPTION_SAFE_START
+
+ if (handle == nullptr)
+ return CSR_ERROR_INVALID_PARAMETER;
+
+ auto hExt = reinterpret_cast<Client::HandleExt *>(handle);
+
+ if (hExt->isStopped())
+ return CSR_ERROR_NONE;
+
+ hExt->stop();
- DEBUG("start!");
return CSR_ERROR_NONE;
+
+ EXCEPTION_SAFE_END
}
API
--- /dev/null
+/*
+ * Copyright (c) 2016 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License
+ */
+/*
+ * @file handle-ext.cpp
+ * @author Kyungwook Tak (k.tak@samsung.com)
+ * @version 1.0
+ * @brief handle with async request extension
+ */
+#include "client/handle-ext.h"
+
+#include <algorithm>
+
+#include "client/utils.h"
+#include "common/dispatcher.h"
+#include "common/audit/logger.h"
+
+namespace Csr {
+namespace Client {
+
+HandleExt::HandleExt() : m_stop(false)
+{
+}
+
+HandleExt::~HandleExt()
+{
+ DEBUG("Destroying extended handle... join all workers...");
+ eraseJoinableIf();
+}
+
+void HandleExt::stop()
+{
+ DEBUG("Stop & join all workers...");
+ m_stop = true;
+ eraseJoinableIf();
+}
+
+bool HandleExt::isStopped() const
+{
+ return m_stop.load();
+}
+
+void HandleExt::eraseJoinableIf(std::function<bool(const WorkerMapPair &)> pred)
+{
+ std::unique_lock<std::mutex> l(m_mutex);
+ DEBUG("clean joinable workers! current worker map size: " <<
+ m_workerMap.size());
+ auto it = m_workerMap.begin();
+
+ while (it != m_workerMap.end()) {
+ DEBUG("Worker map traversing to erase! current iter tid: " << it->first);
+
+ if (!it->second.t.joinable())
+ throw std::logic_error(FORMAT("All workers should be joinable "
+ "but it isn't. tid: " << it->first));
+
+ if (!pred(*it)) {
+ ++it;
+ continue;
+ }
+
+ DEBUG("Joining worker! tid:" << it->first);
+ l.unlock();
+ it->second.t.join(); // release lock for worker who calls done()
+ l.lock();
+ DEBUG("Joined worker! tid:" << it->first);
+ it = m_workerMap.erase(it);
+ }
+}
+
+void HandleExt::done()
+{
+ std::lock_guard<std::mutex> l(m_mutex);
+ auto it = m_workerMap.find(std::this_thread::get_id());
+
+ if (it == m_workerMap.end())
+ throw std::logic_error(FORMAT("worker done but it's not registered in map. "
+ "tid: " << std::this_thread::get_id()));
+
+ it->second.isDone = true;
+}
+
+void HandleExt::dispatchAsync(const Task &f)
+{
+ eraseJoinableIf([](const WorkerMapPair & pair) {
+ return pair.second.isDone.load();
+ });
+ // TODO: how to handle exceptions in workers
+ std::thread t([this, f] {
+ DEBUG("client async thread dispatched! tid: " << std::this_thread::get_id());
+
+ f();
+ done();
+
+ DEBUG("client async thread done! tid: " << std::this_thread::get_id());
+ });
+ {
+ std::lock_guard<std::mutex> l(m_mutex);
+ m_workerMap.emplace(t.get_id(), std::move(t));
+ }
+}
+
+HandleExt::Worker::Worker() : isDone(false)
+{
+ DEBUG("Worker default constructor called");
+}
+
+HandleExt::Worker::Worker(std::thread &&_t) :
+ isDone(false),
+ t(std::forward<std::thread>(_t))
+{
+}
+
+HandleExt::Worker::Worker(HandleExt::Worker &&other) :
+ isDone(other.isDone.load()),
+ t(std::move(other.t))
+{
+}
+
+HandleExt::Worker &HandleExt::Worker::operator=(HandleExt::Worker &&other)
+{
+ isDone = other.isDone.load();
+ t = std::move(other.t);
+ return *this;
+}
+
+} // namespace Client
+} // namespace Csr
--- /dev/null
+/*
+ * Copyright (c) 2016 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License
+ */
+/*
+ * @file handle-ext.h
+ * @author Kyungwook Tak (k.tak@samsung.com)
+ * @version 1.0
+ * @brief handle with async request extension
+ */
+#pragma once
+
+#include "client/handle.h"
+#include "client/callback.h"
+#include "common/types.h"
+
+#include <mutex>
+#include <thread>
+#include <atomic>
+#include <string>
+#include <map>
+#include <utility>
+
+#include <set>
+#include <string>
+
+namespace Csr {
+namespace Client {
+
+class HandleExt : public Handle {
+public:
+ HandleExt();
+ virtual ~HandleExt();
+
+ void dispatchAsync(const Task &task);
+ void stop(void);
+ bool isStopped(void) const;
+
+ Callback m_cb; // TODO: to refine..
+
+private:
+ struct Worker {
+ std::atomic<bool> isDone;
+ std::thread t;
+
+ Worker();
+ Worker(const std::thread &_t) = delete; // to prevent thread instance copied
+ Worker(std::thread &&_t);
+ Worker(Worker &&other);
+ Worker &operator=(Worker &&other);
+ };
+
+ using WorkerMapPair = std::pair<const std::thread::id, Worker>;
+
+ void eraseJoinableIf(std::function<bool(const WorkerMapPair &)>
+ = [](const WorkerMapPair &)
+ {
+ return true;
+ });
+ void done(void);
+
+ std::atomic<bool> m_stop;
+ std::mutex m_mutex;
+ std::map<std::thread::id, Worker> m_workerMap;
+};
+
+} // namespace Client
+} // namespace Csr
* limitations under the License
*/
/*
- * @file cs-types.cpp
+ * @file handle.cpp
* @author Kyungwook Tak (k.tak@samsung.com)
* @version 1.0
- * @brief CSR Content Screening internal types
+ * @brief Client request handle with dispatcher in it
*/
-#include "common/cs-types.h"
+#include "client/handle.h"
#include <stdexcept>
namespace Csr {
-namespace Cs {
+namespace Client {
-Context::Context()
+Context &Handle::getContext() noexcept
{
+ return m_ctx;
}
-Context::~Context()
-{
-}
-
-Context::Context(IStream &)
-{
-}
-
-void Context::Serialize(IStream &) const
-{
-}
-
-Context::Context(Context &&other) :
- m_results(std::move(other.m_results))
-{
-}
-
-Context &Context::operator=(Context &&other)
-{
- if (this == &other)
- return *this;
-
- m_results = std::move(other.m_results);
-
- return *this;
-}
-
-void Context::addResult(Result *result)
+void Handle::add(Result *result)
{
if (result == nullptr)
throw std::logic_error("result shouldn't be null");
- m_results.emplace(result);
-}
-
-Result::Result()
-{
-}
-
-Result::~Result()
-{
-}
-
-Result::Result(IStream &)
-{
-}
-
-void Result::Serialize(IStream &) const
-{
-}
-
-Result::Result(Result &&)
-{
-}
-
-Result &Result::operator=(Result &&)
-{
- return *this;
+ m_ctx.add(result);
}
-} // namespace Cs
+} // namespace Client
} // namespace Csr
* limitations under the License
*/
/*
- * @file wp-types.h
+ * @file handle.h
* @author Kyungwook Tak (k.tak@samsung.com)
* @version 1.0
- * @brief CSR Web Protection internal types
+ * @brief Client request handle with dispatcher in it
*/
#pragma once
-#include <set>
-#include <memory>
#include <utility>
+#include "common/types.h"
#include "common/dispatcher.h"
-#include "common/serialization.h"
namespace Csr {
-namespace Wp {
+namespace Client {
-class Result : public ISerializable {
+class Handle {
public:
- Result();
- virtual ~Result();
- Result(IStream &);
- virtual void Serialize(IStream &) const;
-
- Result(Result &&);
- Result &operator=(Result &&);
-};
-
-class Context : public ISerializable {
-public:
- Context();
- virtual ~Context();
- Context(IStream &);
- virtual void Serialize(IStream &) const;
-
- template<typename Type, typename ...Args>
+ template<typename Type, typename ...Args>
Type dispatch(Args &&...);
- Context(Context &&);
- Context &operator=(Context &&);
+ void add(Result *);
- void addResult(Result *);
+ Context &getContext(void) noexcept;
private:
std::unique_ptr<Dispatcher> m_dispatcher;
- std::set<std::unique_ptr<Result>> m_results;
+ Context m_ctx;
};
template<typename Type, typename ...Args>
-Type Context::dispatch(Args &&...args)
+Type Handle::dispatch(Args &&...args)
{
if (m_dispatcher == nullptr)
m_dispatcher.reset(new Dispatcher("/tmp/." SERVICE_NAME ".socket"));
return m_dispatcher->methodCall<Type>(std::forward<Args>(args)...);
}
-} // namespace Wp
+} // namespace Client
} // namespace Csr
*/
#pragma once
-#include <string>
#include <functional>
-#include "common/audit/logger.h"
-
#define API __attribute__((visibility("default")))
#define EXCEPTION_SAFE_START return Csr::Client::exceptionGuard([&]()->int {
namespace Csr {
namespace Client {
-inline std::string toStlString(const char *cstr)
-{
- return (cstr == nullptr) ? std::string() : std::string(cstr);
-}
-
int exceptionGuard(const std::function<int()> &);
} // namespace Client
#include <new>
#include "client/utils.h"
-#include "common/wp-types.h"
+#include "client/handle.h"
+#include "common/types.h"
#include "common/command-id.h"
#include "common/audit/logger.h"
if (phandle == nullptr)
return CSR_ERROR_INVALID_PARAMETER;
- *phandle = reinterpret_cast<csr_wp_context_h>(new Wp::Context());
+ *phandle = reinterpret_cast<csr_wp_context_h>(new Client::Handle());
return CSR_ERROR_NONE;
if (handle == nullptr)
return CSR_ERROR_INVALID_PARAMETER;
- delete reinterpret_cast<Wp::Context *>(handle);
+ delete reinterpret_cast<Client::Handle *>(handle);
return CSR_ERROR_NONE;
|| url == nullptr || url[0] == '\0')
return CSR_ERROR_INVALID_PARAMETER;
- auto context = reinterpret_cast<Wp::Context *>(handle);
- auto ret = context->dispatch<std::pair<int, Wp::Result *>>(
+ auto h = reinterpret_cast<Client::Handle *>(handle);
+ auto ret = h->dispatch<std::pair<int, Result *>>(
CommandId::CHECK_URL,
- context,
- Client::toStlString(url));
+ h->getContext(),
+ std::string(url));
- if (ret.first != CSR_ERROR_NONE || ret.second == nullptr) {
+ if (ret.first != CSR_ERROR_NONE) {
ERROR("Error! ret: " << ret.first);
return ret.first;
}
+ h->add(ret.second);
*presult = reinterpret_cast<csr_wp_check_result_h>(ret.second);
- context->addResult(ret.second);
return CSR_ERROR_NONE;
namespace Csr {
enum class CommandId : int {
- SCAN_FILE = 0x01,
- JUDGE_STATUS = 0x02,
- CHECK_URL = 0x03
+ SCAN_FILE = 0x01,
+ JUDGE_STATUS = 0x02,
+ CHECK_URL = 0x03,
+ DIR_GET_RESULTS = 0x04,
+ DIR_GET_FILES = 0x05
};
}
#include <list>
#include <map>
#include <unordered_map>
+#include <set>
#include <memory>
#include "common/command-id.h"
Serialize(stream, *list);
}
+ template <typename T>
+ static void Serialize(IStream& stream, const std::set<T>& set)
+ {
+ auto len = set.size();
+ stream.write(sizeof(len), &len);
+ for (const auto &item : set)
+ Serialize(stream, item);
+ }
+ template <typename T>
+ static void Serialize(IStream& stream, const std::set<T>* const set)
+ {
+ Serialize(stream, *set);
+ }
+
// RawBuffer
template <typename A>
static void Serialize(IStream& stream, const std::vector<unsigned char, A>& vec)
Deserialize(stream, *list);
}
+ template <typename T>
+ static void Deserialize(IStream& stream, std::set<T>& set)
+ {
+ size_t len;
+ stream.read(sizeof(len), &len);
+ for (size_t i = 0; i < len; ++i) {
+ T obj;
+ Deserialize(stream, obj);
+ set.insert(std::move(obj));
+ }
+ }
+ template <typename T>
+ static void Deserialize(IStream& stream, std::set<T>*& set)
+ {
+ set = new std::set<T>;
+ Deserialize(stream, *set);
+ }
+
// RawBuffer
template <typename A>
static void Deserialize(IStream& stream, std::vector<unsigned char, A>& vec)
* limitations under the License
*/
/*
- * @file wp-types.cpp
+ * @file types.cpp
* @author Kyungwook Tak (k.tak@samsung.com)
* @version 1.0
- * @brief CSR Web Protection internal types
+ * @brief CSR internal serializable types
*/
-#include "common/wp-types.h"
+#include "common/types.h"
#include <stdexcept>
#include <utility>
+#include "common/audit/logger.h"
+
namespace Csr {
-namespace Wp {
Context::Context()
{
{
}
+// don't copy results.. context copy operation only should be used for option copy
+Context::Context(const Context &) :
+ ISerializable(),
+ m_results()
+{
+}
+
+Context &Context::operator=(const Context &)
+{
+ return *this;
+}
+
Context::Context(Context &&other) :
m_results(std::move(other.m_results))
{
return *this;
}
-void Context::addResult(Result *result)
+void Context::add(std::unique_ptr<Result> &&item)
{
- if (result == nullptr)
- throw std::logic_error("result shouldn't be null");
+ std::lock_guard<std::mutex> l(m_mutex);
+ m_results.emplace_back(std::forward<std::unique_ptr<Result>>(item));
+}
+
+void Context::add(Result *item)
+{
+ std::lock_guard<std::mutex> l(m_mutex);
+ m_results.emplace_back(item);
+}
- m_results.emplace(result);
+size_t Context::size() const
+{
+ std::lock_guard<std::mutex> l(m_mutex);
+ return m_results.size();
}
-Result::Result()
+Result::Result() : m_hasVal(false)
{
}
{
}
-Result::Result(IStream &)
+Result::Result(IStream &stream)
{
+ Deserializer<bool>::Deserialize(stream, m_hasVal);
}
-void Result::Serialize(IStream &) const
+void Result::Serialize(IStream &stream) const
{
+ Serializer<bool>::Serialize(stream, m_hasVal);
}
Result::Result(Result &&)
return *this;
}
-} // namespace Wp
+bool Result::hasValue() const
+{
+ return m_hasVal;
+}
+
} // namespace Csr
* limitations under the License
*/
/*
- * @file cs-types.h
+ * @file types.h
* @author Kyungwook Tak (k.tak@samsung.com)
* @version 1.0
- * @brief CSR Content Screening internal types
+ * @brief CSR internal serializable types
*/
#pragma once
-#include <set>
+#include <vector>
#include <memory>
-#include <utility>
+#include <mutex>
#include "common/dispatcher.h"
#include "common/serialization.h"
namespace Csr {
-namespace Cs {
+
+using Task = std::function<void()>;
+using StrSet = std::set<std::string>;
class Result : public ISerializable {
public:
Result(Result &&);
Result &operator=(Result &&);
+
+ bool hasValue(void) const;
+
+private:
+ bool m_hasVal;
};
class Context : public ISerializable {
Context(Context &&);
Context &operator=(Context &&);
- template<typename Type, typename ...Args>
- Type dispatch(Args &&...);
+ // TODO: Handling results vector between contexts should be refined..
+ // copy ctor/assignments for serializing and results vector isn't included here.
+ Context(const Context &);
+ Context &operator=(const Context &);
- void addResult(Result *);
+ void add(std::unique_ptr<Result> &&);
+ void add(Result *);
+ size_t size(void) const;
+ // for destroying with context
+ std::vector<std::unique_ptr<Result>> m_results;
private:
- std::unique_ptr<Dispatcher> m_dispatcher;
- std::set<std::unique_ptr<Result>> m_results;
+ mutable std::mutex m_mutex;
};
-template<typename Type, typename ...Args>
-Type Context::dispatch(Args &&...args)
-{
- if (m_dispatcher == nullptr)
- m_dispatcher.reset(new Dispatcher("/tmp/." SERVICE_NAME ".socket"));
-
- return m_dispatcher->methodCall<Type>(std::forward<Args>(args)...);
-}
-
-} // namespace Cs
} // namespace Csr
switch (info.first) {
case CommandId::SCAN_FILE: {
- Cs::Context context;
+ Context context;
std::string filepath;
info.second.Deserialize(context, filepath);
return scanFile(context, filepath);
/* TODO: should we separate command->logic mapping of CS and WP ? */
case CommandId::CHECK_URL: {
- Wp::Context context;
+ Context context;
std::string url;
info.second.Deserialize(context, url);
return checkUrl(context, url);
}
+ case CommandId::DIR_GET_RESULTS: {
+ Context context;
+ std::string dir;
+ info.second.Deserialize(context, dir);
+ return dirGetResults(context, dir);
+ }
+
+ case CommandId::DIR_GET_FILES: {
+ Context context;
+ std::string dir;
+ info.second.Deserialize(context, dir);
+ return dirGetFiles(context, dir);
+ }
+
default:
throw std::range_error(FORMAT("Command id[" << static_cast<int>(info.first)
<< "] isn't in range."));
return std::make_pair(id, std::move(q));
}
-RawBuffer Logic::scanFile(const Cs::Context &context, const std::string &filepath)
+RawBuffer Logic::scanFile(const Context &context, const std::string &filepath)
{
INFO("Scan file[" << filepath << "] by engine");
(void) context;
- return BinaryQueue::Serialize(CSR_ERROR_NONE, Cs::Result()).pop();
+ return BinaryQueue::Serialize(CSR_ERROR_NONE, Result()).pop();
}
-RawBuffer Logic::checkUrl(const Wp::Context &context, const std::string &url)
+RawBuffer Logic::checkUrl(const Context &context, const std::string &url)
{
INFO("Check url[" << url << "] by engine");
(void) context;
- return BinaryQueue::Serialize(CSR_ERROR_NONE, Wp::Result()).pop();
+ return BinaryQueue::Serialize(CSR_ERROR_NONE, Result()).pop();
+}
+
+RawBuffer Logic::dirGetResults(const Context &context, const std::string &dir)
+{
+ INFO("Dir[" << dir << "] get results");
+ (void) context;
+
+ return BinaryQueue::Serialize(CSR_ERROR_NONE, StrSet()).pop();
+}
+
+RawBuffer Logic::dirGetFiles(const Context &context, const std::string &dir)
+{
+ INFO("Dir[" << dir << "] get files");
+ (void) context;
+
+ return BinaryQueue::Serialize(CSR_ERROR_NONE, std::vector<Result>()).pop();
}
}
#include <string>
#include <utility>
-#include "common/cs-types.h"
-#include "common/wp-types.h"
+#include "common/types.h"
#include "common/command-id.h"
#include "common/raw-buffer.h"
#include "common/binary-queue.h"
private:
std::pair<CommandId, BinaryQueue> getRequestInfo(const RawBuffer &);
- RawBuffer scanFile(const Cs::Context &context, const std::string &filepath);
- RawBuffer checkUrl(const Wp::Context &context, const std::string &url);
+ RawBuffer scanFile(const Context &context, const std::string &filepath);
+ RawBuffer checkUrl(const Context &context, const std::string &url);
+ RawBuffer dirGetResults(const Context &context, const std::string &dir);
+ RawBuffer dirGetFiles(const Context &context, const std::string &dir);
};
}
#include <csr/content-screening.h>
#include <string>
+#include <memory>
+#include <new>
#include <iostream>
+#include <condition_variable>
+#include <thread>
+#include <mutex>
#include <boost/test/unit_test.hpp>
-BOOST_AUTO_TEST_SUITE(API_CONTENT_SCREENING)
+namespace {
-BOOST_AUTO_TEST_CASE(context_create_destroy)
+class ContextPtr {
+public:
+ ContextPtr() : m_context(nullptr) {}
+ ContextPtr(csr_cs_context_h context) : m_context(context) {}
+ virtual ~ContextPtr()
+ {
+ BOOST_REQUIRE(csr_cs_context_destroy(m_context) == CSR_ERROR_NONE);
+ }
+
+ inline csr_cs_context_h get(void)
+ {
+ return m_context;
+ }
+
+private:
+ csr_cs_context_h m_context;
+};
+
+using ScopedContext = std::unique_ptr<ContextPtr>;
+
+inline ScopedContext makeScopedContext(csr_cs_context_h context)
{
- csr_cs_context_h handle;
+ return ScopedContext(new ContextPtr(context));
+}
+
+inline ScopedContext getContextHandle(void)
+{
+ csr_cs_context_h context;
int ret = CSR_ERROR_UNKNOWN;
+ BOOST_REQUIRE_NO_THROW(ret = csr_cs_context_create(&context));
+ BOOST_REQUIRE_MESSAGE(ret == CSR_ERROR_NONE,
+ "Failed to create context handle. ret: " << ret);
+ BOOST_REQUIRE(context != nullptr);
+ return makeScopedContext(context);
+}
- BOOST_REQUIRE_NO_THROW(ret = csr_cs_context_create(&handle));
- BOOST_REQUIRE(ret == CSR_ERROR_NONE);
+}
- BOOST_REQUIRE_NO_THROW(ret = csr_cs_context_destroy(handle));
- BOOST_REQUIRE(ret == CSR_ERROR_NONE);
+BOOST_AUTO_TEST_SUITE(API_CONTENT_SCREENING)
+
+BOOST_AUTO_TEST_CASE(context_create_destroy)
+{
+ auto contextPtr = getContextHandle();
+ (void) contextPtr;
}
BOOST_AUTO_TEST_CASE(scan_file)
{
- csr_cs_context_h handle;
int ret = CSR_ERROR_UNKNOWN;
-
- BOOST_REQUIRE_NO_THROW(ret = csr_cs_context_create(&handle));
- BOOST_REQUIRE(ret == CSR_ERROR_NONE);
-
+ auto contextPtr = getContextHandle();
+ auto context = contextPtr->get();
csr_cs_detected_h detected;
- BOOST_REQUIRE_NO_THROW(ret = csr_cs_scan_file(handle, "dummy_file_path", &detected));
+ BOOST_REQUIRE_NO_THROW(ret = csr_cs_scan_file(context, "dummy_file_path",
+ &detected));
BOOST_REQUIRE(ret == CSR_ERROR_NONE);
+}
+
+struct AsyncTestContext {
+ std::mutex m;
+ std::condition_variable cv;
+ int scannedCnt;
+ int detectedCnt;
+ int completedCnt;
+ int cancelledCnt;
+ int errorCnt;
+
+ AsyncTestContext() :
+ scannedCnt(0),
+ detectedCnt(0),
+ completedCnt(0),
+ cancelledCnt(0),
+ errorCnt(0) {}
+};
- BOOST_REQUIRE_NO_THROW(ret = csr_cs_context_destroy(handle));
+void on_scanned(void *userdata, const char *file)
+{
+ BOOST_MESSAGE("on_scanned called. file[" << file << "] scanned!");
+ auto ctx = reinterpret_cast<AsyncTestContext *>(userdata);
+ ctx->scannedCnt++;
+}
+
+void on_error(void *userdata, int ec)
+{
+ BOOST_MESSAGE("on_error called. async request done with error code[" << ec <<
+ "]");
+ auto ctx = reinterpret_cast<AsyncTestContext *>(userdata);
+ ctx->errorCnt++;
+}
+
+void on_detected(void *userdata, csr_cs_detected_h detected)
+{
+ (void) detected;
+ BOOST_MESSAGE("on_detected called.");
+ auto ctx = reinterpret_cast<AsyncTestContext *>(userdata);
+ ctx->detectedCnt++;
+}
+
+void on_completed(void *userdata)
+{
+ BOOST_MESSAGE("on_completed called. async request completed succesfully.");
+ auto ctx = reinterpret_cast<AsyncTestContext *>(userdata);
+ ctx->completedCnt++;
+ ctx->cv.notify_one();
+}
+
+void on_cancelled(void *userdata)
+{
+ BOOST_MESSAGE("on_cancelled called. async request canceled!");
+ auto ctx = reinterpret_cast<AsyncTestContext *>(userdata);
+ ctx->cancelledCnt++;
+}
+
+BOOST_AUTO_TEST_CASE(scan_files_async)
+{
+ int ret = CSR_ERROR_UNKNOWN;
+ auto contextPtr = getContextHandle();
+ auto context = contextPtr->get();
+ BOOST_REQUIRE_NO_THROW(ret = csr_cs_set_callback_on_completed(context,
+ on_completed));
+ BOOST_REQUIRE(ret == CSR_ERROR_NONE);
+ BOOST_REQUIRE_NO_THROW(ret = csr_cs_set_callback_on_error(context, on_error));
+ BOOST_REQUIRE(ret == CSR_ERROR_NONE);
+ BOOST_REQUIRE_NO_THROW(ret = csr_cs_set_callback_on_cancelled(context,
+ on_cancelled));
+ BOOST_REQUIRE(ret == CSR_ERROR_NONE);
+ BOOST_REQUIRE_NO_THROW(ret = csr_cs_set_callback_on_detected(context,
+ on_detected));
+ BOOST_REQUIRE(ret == CSR_ERROR_NONE);
+ BOOST_REQUIRE_NO_THROW(ret = csr_cs_set_callback_on_file_scanned(context,
+ on_scanned));
+ BOOST_REQUIRE(ret == CSR_ERROR_NONE);
+ const char *files[3] = {
+ TEST_DIR "/test_malware_file",
+ TEST_DIR "/test_normal_file",
+ TEST_DIR "/test_risky_file"
+ };
+ AsyncTestContext testCtx;
+ BOOST_REQUIRE_NO_THROW(ret =
+ csr_cs_scan_files_async(context, files, sizeof(files) / sizeof(const char *),
+ &testCtx));
BOOST_REQUIRE(ret == CSR_ERROR_NONE);
+ std::unique_lock<std::mutex> l(testCtx.m);
+ testCtx.cv.wait(l);
+ l.unlock();
+ BOOST_REQUIRE_MESSAGE(testCtx.completedCnt == 1 && testCtx.scannedCnt == 3 &&
+ testCtx.detectedCnt == 0 && testCtx.cancelledCnt == 0 && testCtx.errorCnt == 0,
+ "Async request result isn't expected.");
}
BOOST_AUTO_TEST_SUITE_END()