From fdc3f903339f60dd960da6c0a6175aa236fb4ffd Mon Sep 17 00:00:00 2001 From: Kyungwook Tak Date: Thu, 28 Apr 2016 14:53:25 +0900 Subject: [PATCH] Row inherits CsDetected struct remove unused functions in db namespace remove modified_time field in db schema modified_time isn't needed because we can determine whether file is modified by comparing current time and detected_time Change-Id: I95825bfbde8ada0649522088bc9a67b8874d32f0 Signed-off-by: Kyungwook Tak --- data/scripts/create_schema.sql | 25 +++-- src/framework/db/manager.cpp | 216 +++++++++++++++++++------------------- src/framework/db/manager.h | 54 +++------- src/framework/db/query.h | 23 +++-- src/framework/db/row.h | 44 ++++++++ src/framework/db/statement.cpp | 223 ++++++++++++++-------------------------- src/framework/db/statement.h | 61 +++++------ test/test-internal-database.cpp | 132 +++++++++++------------- 8 files changed, 355 insertions(+), 423 deletions(-) create mode 100644 src/framework/db/row.h diff --git a/data/scripts/create_schema.sql b/data/scripts/create_schema.sql index f867f3e..1a5b70c 100644 --- a/data/scripts/create_schema.sql +++ b/data/scripts/create_schema.sql @@ -1,20 +1,19 @@ -CREATE TABLE IF NOT EXISTS SCHEMA_INFO(name TEXT PRIMARY KEY NOT NULL, +CREATE TABLE IF NOT EXISTS SCHEMA_INFO(name TEXT PRIMARY KEY NOT NULL, value TEXT); -CREATE TABLE IF NOT EXISTS ENGINE_STATE(engine_id INTEGER PRIMARY KEY, - state INTEGER NOT NULL); +CREATE TABLE IF NOT EXISTS ENGINE_STATE(id INTEGER PRIMARY KEY, + state INTEGER NOT NULL); -CREATE TABLE IF NOT EXISTS SCAN_REQUEST(dir TEXT PRIMARY KEY, - last_scan INTEGER NOT NULL, +CREATE TABLE IF NOT EXISTS SCAN_REQUEST(dir TEXT PRIMARY KEY, + last_scan INTEGER NOT NULL, data_version TEXT NOT NULL); -CREATE TABLE IF NOT EXISTS DETECTED_MALWARE_FILE(path TEXT PRIMARY KEY NOT NULL, - data_version TEXT NOT NULL, - severity_level INTEGER NOT NULL, - threat_type INTEGER NOT NULL, - malware_name TEXT NOT NULL, - detailed_url TEXT, +CREATE TABLE IF NOT EXISTS DETECTED_MALWARE_FILE(path TEXT PRIMARY KEY NOT NULL, + data_version TEXT NOT NULL, + severity INTEGER NOT NULL, + threat INTEGER NOT NULL, + malware_name TEXT NOT NULL, + detailed_url TEXT NOT NULL, detected_time INTEGER NOT NULL, - modified_time INTEGER NOT NULL, - ignored INTEGER NOT NULL); + ignored INTEGER NOT NULL); diff --git a/src/framework/db/manager.cpp b/src/framework/db/manager.cpp index d70e59b..548950a 100644 --- a/src/framework/db/manager.cpp +++ b/src/framework/db/manager.cpp @@ -44,7 +44,8 @@ const std::string SCRIPT_CREATE_SCHEMA = "create_schema"; const std::string SCRIPT_DROP_ALL_ITEMS = "drop_all"; const std::string SCRIPT_MIGRATE = "migrate_"; -const std::string DB_VERSION_STR = "DB_VERSION"; +const std::string DB_VERSION_STR = "DB_VERSION"; +const std::string SCHEMA_INFO_TABLE = "SCHEMA_INFO"; } // namespace anonymous @@ -53,14 +54,6 @@ Manager::Manager(const std::string &dbfile, const std::string &scriptsDir) : Connection::Serialized), m_scriptsDir(scriptsDir) { - initDatabase(); - m_conn.exec("VACUUM;"); -} - -Manager::~Manager() {} - -void Manager::initDatabase() -{ // run migration if old database is present auto sv = getSchemaVersion(); @@ -85,6 +78,12 @@ void Manager::initDatabase() setSchemaVersion(SchemaVersion::LATEST); } + + m_conn.exec("VACUUM;"); +} + +Manager::~Manager() +{ } void Manager::resetDatabase() @@ -116,31 +115,39 @@ std::string Manager::getScript(const std::string &scriptName) return str; } -int Manager::getSchemaVersion() +bool Manager::isTableExist(const std::string &name) { - try { - Statement stmt(m_conn, Query::SEL_SCHEMA_INFO); + Statement stmt(m_conn, Query::CHK_TABLE); - int idx = 0; - stmt.bind(++idx, DB_VERSION_STR); + stmt.bind(name); - if (!stmt.step()) // Tables don't exist yet - return SchemaVersion::NOT_EXIST; + return stmt.step(); +} - return stmt.getInt(0); - } catch (const std::runtime_error &e) { - INFO("Database doesn't exist."); - return SchemaVersion::NOT_EXIST; // table not exist! +int Manager::getSchemaVersion() +{ + if (!isTableExist(SCHEMA_INFO_TABLE)) { + WARN("Schema table doesn't exist. This case would be the first time of " + "db manager instantiated in target"); + return SchemaVersion::NOT_EXIST; } + + Statement stmt(m_conn, Query::SEL_SCHEMA_INFO); + + stmt.bind(DB_VERSION_STR); + + if (!stmt.step()) + throw std::logic_error(FORMAT("schema info table should exist!")); + + return stmt.getInt(); } void Manager::setSchemaVersion(int sv) { Statement stmt(m_conn, Query::INS_SCHEMA_INFO); - int idx = 0; - stmt.bind(++idx, DB_VERSION_STR); - stmt.bind(++idx, sv); + stmt.bind(DB_VERSION_STR); + stmt.bind(sv); if (stmt.exec() == 0) throw std::runtime_error(FORMAT("Failed to set schema version!")); @@ -154,13 +161,9 @@ int Manager::getEngineState(int engineId) noexcept try { Statement stmt(m_conn, Query::SEL_ENGINE_STATE); - int idx = 0; - stmt.bind(++idx, engineId); - - if (!stmt.step()) - return -1; + stmt.bind(engineId); - return stmt.getInt(0); + return stmt.step() ? stmt.getInt() : -1; } catch (const std::exception &e) { ERROR("getEngineState failed. error_msg=" << e.what()); return -1; @@ -172,10 +175,10 @@ bool Manager::setEngineState(int engineId, int state) noexcept try { Statement stmt(m_conn, Query::INS_ENGINE_STATE); - int idx = 0; - stmt.bind(++idx, engineId); - stmt.bind(++idx, state); - return stmt.exec(); + stmt.bind(engineId); + stmt.bind(state); + + return stmt.exec() != 0; } catch (const std::exception &e) { ERROR("setEngineState failed. error_msg=" << e.what()); return false; @@ -192,14 +195,10 @@ long Manager::getLastScanTime(const std::string &dir, try { Statement stmt(m_conn, Query::SEL_SCAN_REQUEST); - int idx = 0; - stmt.bind(++idx, dir); - stmt.bind(++idx, dataVersion); + stmt.bind(dir); + stmt.bind(dataVersion); - if (!stmt.step()) - return -1; - - return static_cast(stmt.getInt64(1)); + return stmt.step() ? static_cast(stmt.getInt64()) : -1; } catch (const std::exception &e) { ERROR("getLastScanTime failed. error_msg=" << e.what()); return -1; @@ -212,11 +211,11 @@ bool Manager::insertLastScanTime(const std::string &dir, long scanTime, try { Statement stmt(m_conn, Query::INS_SCAN_REQUEST); - int idx = 0; - stmt.bind(++idx, dir); - stmt.bind(++idx, static_cast(scanTime)); - stmt.bind(++idx, dataVersion); - return stmt.exec(); + stmt.bind(dir); + stmt.bind(static_cast(scanTime)); + stmt.bind(dataVersion); + + return stmt.exec() != 0; } catch (const std::exception &e) { ERROR("insertLastScanTime failed. error_msg=" << e.what()); return false; @@ -228,8 +227,8 @@ bool Manager::deleteLastScanTime(const std::string &dir) noexcept try { Statement stmt(m_conn, Query::DEL_SCAN_REQUEST_BY_DIR); - int idx = 0; - stmt.bind(++idx, dir); + stmt.bind(dir); + stmt.exec(); return true; // even if no rows are deleted } catch (const std::exception &e) { @@ -242,7 +241,9 @@ bool Manager::cleanLastScanTime() noexcept { try { Statement stmt(m_conn, Query::DEL_SCAN_REQUEST); + stmt.exec(); + return true; // even if no rows are deleted } catch (const std::exception &e) { ERROR("cleanLastScanTime failed. error_msg=" << e.what()); @@ -253,87 +254,79 @@ bool Manager::cleanLastScanTime() noexcept //=========================================================================== // DETECTED_MALWARE_FILE table //=========================================================================== -DetectedListShrPtr Manager::getDetectedMalwares(const std::string &dir) noexcept +RowsShPtr Manager::getDetectedMalwares(const std::string &dir) noexcept { try { - DetectedListShrPtr detectedList = - std::make_shared>(); - Statement stmt(m_conn, Query::SEL_DETECTED_BY_DIR); + stmt.bind(dir); - int idx = 0; - stmt.bind(++idx, dir); + RowsShPtr rows = std::make_shared>(); while (stmt.step()) { - DetectedShrPtr detected = std::make_shared(); - idx = -1; - detected->path = stmt.getText(++idx); - detected->dataVersion = stmt.getText(++idx); - detected->severityLevel = stmt.getInt(++idx); - detected->threatType = stmt.getInt(++idx); - detected->name = stmt.getText(++idx); - detected->detailedUrl = stmt.getText(++idx); - detected->detected_time = static_cast(stmt.getInt64(++idx)); - detected->modified_time = static_cast(stmt.getInt64(++idx)); - detected->ignored = stmt.getInt(++idx); - - detectedList->push_back(detected); + RowShPtr row = std::make_shared(); + + row->targetName = stmt.getText(); + row->dataVersion = stmt.getText(); + row->severity = static_cast(stmt.getInt()); + row->threat = static_cast(stmt.getInt()); + row->malwareName = stmt.getText(); + row->detailedUrl = stmt.getText(); + row->ts = static_cast(stmt.getInt64()); + row->isIgnored = static_cast(stmt.getInt()); + + rows->emplace_back(std::move(row)); } - return detectedList; + return rows; } catch (const std::exception &e) { ERROR("getDetectedMalwares failed. error_msg=" << e.what()); return nullptr; } } -DetectedShrPtr Manager::getDetectedMalware(const std::string &path) noexcept +RowShPtr Manager::getDetectedMalware(const std::string &path) noexcept { try { - DetectedShrPtr detected = std::make_shared(); Statement stmt(m_conn, Query::SEL_DETECTED_BY_PATH); - - int idx = 0; - stmt.bind(++idx, path); + stmt.bind(path); if (!stmt.step()) return nullptr; - idx = -1; - detected->path = stmt.getText(++idx); - detected->dataVersion = stmt.getText(++idx); - detected->severityLevel = stmt.getInt(++idx); - detected->threatType = stmt.getInt(++idx); - detected->name = stmt.getText(++idx); - detected->detailedUrl = stmt.getText(++idx); - detected->detected_time = static_cast(stmt.getInt64(++idx)); - detected->modified_time = static_cast(stmt.getInt64(++idx)); - detected->ignored = stmt.getInt(++idx); - - return detected; + RowShPtr row = std::make_shared(); + row->targetName = stmt.getText(); + row->dataVersion = stmt.getText(); + row->severity = static_cast(stmt.getInt()); + row->threat = static_cast(stmt.getInt()); + row->malwareName = stmt.getText(); + row->detailedUrl = stmt.getText(); + row->ts = static_cast(stmt.getInt64()); + row->isIgnored = static_cast(stmt.getInt()); + + return row; } catch (const std::exception &e) { ERROR("getDetectedMalware failed. error_msg=" << e.what()); return nullptr; } } -bool Manager::insertDetectedMalware(const RowDetected &malware) noexcept +bool Manager::insertDetectedMalware(const CsDetected &d, + const std::string &dataVersion, + bool isIgnored) noexcept { try { Statement stmt(m_conn, Query::INS_DETECTED); - int idx = 0; - stmt.bind(++idx, malware.path); - stmt.bind(++idx, malware.dataVersion); - stmt.bind(++idx, malware.severityLevel); - stmt.bind(++idx, malware.threatType); - stmt.bind(++idx, malware.name); - stmt.bind(++idx, malware.detailedUrl); - stmt.bind(++idx, static_cast(malware.detected_time)); - stmt.bind(++idx, static_cast(malware.modified_time)); - stmt.bind(++idx, malware.ignored); - - return stmt.exec(); + stmt.bind(d.targetName); + stmt.bind(dataVersion); + stmt.bind(static_cast(d.severity)); + stmt.bind(static_cast(d.threat)); + stmt.bind(d.malwareName); + stmt.bind(d.detailedUrl); + stmt.bind(static_cast(d.ts)); + stmt.bind(static_cast(isIgnored)); + + return stmt.exec() == 1; } catch (const std::exception &e) { ERROR("insertDetectedMalware failed. error_msg=" << e.what()); return false; @@ -341,50 +334,51 @@ bool Manager::insertDetectedMalware(const RowDetected &malware) noexcept } bool Manager::setDetectedMalwareIgnored(const std::string &path, - int ignored) noexcept + bool flag) noexcept { try { Statement stmt(m_conn, Query::UPD_DETECTED_INGNORED); - int idx = 0; - stmt.bind(++idx, ignored); - stmt.bind(++idx, path); + stmt.bind(flag); + stmt.bind(path); - return stmt.exec(); + return stmt.exec() == 1; } catch (const std::exception &e) { ERROR("setDetectedMalwareIgnored failed. error_msg=" << e.what()); return false; } } -bool Manager::deleteDetecedMalware(const std::string &path) noexcept +bool Manager::deleteDetectedMalware(const std::string &path) noexcept { try { Statement stmt(m_conn, Query::DEL_DETECTED_BY_PATH); - int idx = 0; - stmt.bind(++idx, path); + stmt.bind(path); + stmt.exec(); + return true; // even if no rows are deleted } catch (const std::exception &e) { - ERROR("deleteDetecedMalware failed.error_msg=" << e.what()); + ERROR("deleteDetectedMalware failed.error_msg=" << e.what()); return false; } } -bool Manager::deleteDeprecatedDetecedMalwares(const std::string &dir, +bool Manager::deleteDeprecatedDetectedMalwares(const std::string &dir, const std::string &dataVersion) noexcept { try { Statement stmt(m_conn, Query::DEL_DETECTED_DEPRECATED); - int idx = 0; - stmt.bind(++idx, dir); - stmt.bind(++idx, dataVersion); + stmt.bind(dir); + stmt.bind(dataVersion); + stmt.exec(); + return true; // even if no rows are deleted } catch (const std::exception &e) { - ERROR("deleteDeprecatedDetecedMalwares failed.error_msg=" << e.what()); + ERROR("deleteDeprecatedDetectedMalwares failed.error_msg=" << e.what()); return false; } } diff --git a/src/framework/db/manager.h b/src/framework/db/manager.h index f24938d..c784113 100644 --- a/src/framework/db/manager.h +++ b/src/framework/db/manager.h @@ -26,40 +26,18 @@ #include #include "db/connection.h" +#include "db/row.h" +#include "common/cs-detected.h" namespace Csr { namespace Db { -struct RowDetected { - std::string path; - std::string dataVersion; // engine's data version - std::string name; - std::string detailedUrl; - int severityLevel; - int threatType; - int ignored; - long detected_time; - long modified_time; - - RowDetected() : - severityLevel(-1), - threatType(-1), - ignored(-1), - detected_time(-1), - modified_time(-1) {} - - virtual ~RowDetected() {} -}; - -using DetectedShrPtr = std::shared_ptr; -using DetectedListShrPtr = std::shared_ptr>; - class Manager { public: Manager(const std::string &dbfile, const std::string &scriptsDir); virtual ~Manager(); - // SCHEMA_INFO + // SCHEMA_INFO. it's public only for testing for now... int getSchemaVersion(); // ENGINE_STATE @@ -68,25 +46,25 @@ public: // SCAN_REQUEST long getLastScanTime(const std::string &dir, - const std::string &dataVersion) noexcept; + const std::string &dataVersion) noexcept; bool insertLastScanTime(const std::string &dir, long scanTime, - const std::string &dataVersion) noexcept; - bool deleteLastScanTime(const std::string &dir) noexcept; - bool cleanLastScanTime() noexcept; + const std::string &dataVersion) noexcept; + bool deleteLastScanTime(const std::string &dir) noexcept; + bool cleanLastScanTime() noexcept; // DETECTED_MALWARE_FILE & USER_RESPONSE - DetectedListShrPtr getDetectedMalwares(const std::string &dir) noexcept; - DetectedShrPtr getDetectedMalware(const std::string &path) noexcept; - bool insertDetectedMalware(const RowDetected &malware) noexcept; - bool setDetectedMalwareIgnored(const std::string &path, - int userResponse) noexcept; - bool deleteDetecedMalware(const std::string &path) noexcept; - bool deleteDeprecatedDetecedMalwares(const std::string &dir, - const std::string &dataVersion) noexcept; + RowsShPtr getDetectedMalwares(const std::string &dirpath) noexcept; + RowShPtr getDetectedMalware(const std::string &filepath) noexcept; + bool insertDetectedMalware(const CsDetected &, const std::string &dataVersion, + bool isIgnored) noexcept; + bool setDetectedMalwareIgnored(const std::string &path, bool flag) noexcept; + bool deleteDetectedMalware(const std::string &path) noexcept; + bool deleteDeprecatedDetectedMalwares(const std::string &dir, + const std::string &dataVersion) noexcept; private: - void initDatabase(); void resetDatabase(); + bool isTableExist(const std::string &name); std::string getScript(const std::string &scriptName); std::string getMigrationScript(int schemaVersion); diff --git a/src/framework/db/query.h b/src/framework/db/query.h index 8345815..c85c1ec 100644 --- a/src/framework/db/query.h +++ b/src/framework/db/query.h @@ -19,6 +19,9 @@ namespace Csr { namespace Db { namespace Query { +const std::string CHK_TABLE = + "select name from sqlite_master where type = 'table' and name = ?"; + const std::string SEL_SCHEMA_INFO = "select value from SCHEMA_INFO where name = ?"; @@ -26,13 +29,13 @@ const std::string INS_SCHEMA_INFO = "insert or replace into SCHEMA_INFO (name, value) values (?, ?)"; const std::string SEL_ENGINE_STATE = - "select state from ENGINE_STATE where engine_id = ?"; + "select state from ENGINE_STATE where id = ?"; const std::string INS_ENGINE_STATE = - "insert or replace into ENGINE_STATE (engine_id, state) values (?, ?)"; + "insert or replace into ENGINE_STATE (id, state) values (?, ?)"; const std::string SEL_SCAN_REQUEST = - "select dir, last_scan from SCAN_REQUEST where dir = ? and data_version = ?"; + "select last_scan from SCAN_REQUEST where dir = ? and data_version = ?"; const std::string INS_SCAN_REQUEST = "insert or replace into SCAN_REQUEST (dir, last_scan, data_version) " @@ -47,21 +50,21 @@ const std::string DEL_SCAN_REQUEST = const std::string SEL_DETECTED_BY_DIR = "SELECT path, data_version, " - " severity_level, threat_type, malware_name, " - " detailed_url, detected_time, modified_time, ignored " + " severity, threat, malware_name, " + " detailed_url, detected_time, ignored " " FROM detected_malware_file where path like ? || '%' "; const std::string SEL_DETECTED_BY_PATH = "SELECT path, data_version, " - " severity_level, threat_type, malware_name, " - " detailed_url, detected_time, modified_time, ignored " + " severity, threat, malware_name, " + " detailed_url, detected_time, ignored " " FROM detected_malware_file where path = ? "; const std::string INS_DETECTED = "insert or replace into DETECTED_MALWARE_FILE " - " (path, data_version, severity_level, threat_type, malware_name, " - " detailed_url, detected_time, modified_time, ignored) " - " values (?, ?, ?, ?, ?, ?, ?, ?, ?)"; + " (path, data_version, severity, threat, malware_name, " + " detailed_url, detected_time, ignored) " + " values (?, ?, ?, ?, ?, ?, ?, ?)"; const std::string UPD_DETECTED_INGNORED = "update DETECTED_MALWARE_FILE set ignored = ? where path = ?"; diff --git a/src/framework/db/row.h b/src/framework/db/row.h new file mode 100644 index 0000000..1f7ffad --- /dev/null +++ b/src/framework/db/row.h @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2016 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License + */ +/* + * @file row.h + * @author Kyungwook Tak (k.tak@samsung.com) + * @version 1.0 + * @brief db row + */ +#pragma once + +#include + +#include "common/cs-detected.h" + +namespace Csr { +namespace Db { + +class Row; +using RowShPtr = std::shared_ptr; +using RowsShPtr = std::shared_ptr>; + +struct Row : public Csr::CsDetected { + std::string dataVersion; // engine's data version + bool isIgnored; + + Row() : isIgnored(false) {} + virtual ~Row() {} +}; + +} // namespace Db +} // namespace Csr diff --git a/src/framework/db/statement.cpp b/src/framework/db/statement.cpp index 3746c4a..50e215b 100644 --- a/src/framework/db/statement.cpp +++ b/src/framework/db/statement.cpp @@ -19,229 +19,160 @@ #include #include "db/connection.h" +#include "common/audit/logger.h" namespace Csr { namespace Db { +namespace { + +const int BindingStartIndex = 0; +const int ColumnStartIndex = -1; + +} // namespace anonymous + Statement::Statement(const Connection &db, const std::string &query) : - m_stmt(nullptr), - m_columnCount(0), - m_validRow(false) + m_stmt(nullptr), m_bindingIndex(BindingStartIndex) { - if (SQLITE_OK != ::sqlite3_prepare_v2(db.get(), - query.c_str(), - query.size(), - &m_stmt, - nullptr)) + switch (sqlite3_prepare_v2(db.get(), query.c_str(), query.size(), &m_stmt, nullptr)) { + case SQLITE_OK: + m_bindParamCount = ::sqlite3_bind_parameter_count(m_stmt); + m_columnCount = ::sqlite3_column_count(m_stmt); + + // column index should be initialized after step(), so make it invalid here + m_columnIndex = m_columnCount + 1; + break; + + default: throw std::runtime_error(db.getErrorMessage()); + } - m_columnCount = sqlite3_column_count(m_stmt); } Statement::~Statement() { - if (::sqlite3_finalize(m_stmt) != SQLITE_OK) + if (SQLITE_OK != ::sqlite3_finalize(m_stmt)) throw std::runtime_error(getErrorMessage()); } void Statement::reset() { - if (::sqlite3_clear_bindings(m_stmt) != SQLITE_OK) - throw std::runtime_error(getErrorMessage()); + clearBindings(); if (::sqlite3_reset(m_stmt) != SQLITE_OK) throw std::runtime_error(getErrorMessage()); + + m_columnIndex = m_columnCount + 1; } -void Statement::clearBindings() +void Statement::clearBindings() const { if (::sqlite3_clear_bindings(m_stmt) != SQLITE_OK) throw std::runtime_error(getErrorMessage()); -} -std::string Statement::getErrorMessage() const -{ - return ::sqlite3_errmsg(::sqlite3_db_handle(m_stmt)); + m_bindingIndex = BindingStartIndex; } -std::string Statement::getErrorMessage(int errorCode) const +std::string Statement::getErrorMessage() const noexcept { - return ::sqlite3_errstr(errorCode); + return ::sqlite3_errmsg(::sqlite3_db_handle(m_stmt)); } bool Statement::step() { - return (m_validRow = (SQLITE_ROW == ::sqlite3_step(m_stmt))); -} - -int Statement::exec() -{ - if (SQLITE_DONE == ::sqlite3_step(m_stmt)) - m_validRow = false; - - return sqlite3_changes(sqlite3_db_handle(m_stmt)); -} + switch (::sqlite3_step(m_stmt)) { + case SQLITE_ROW: + m_columnIndex = ColumnStartIndex; + return true; + case SQLITE_DONE: + // column cannot be 'get' after sqlite done, so make index overflow. + m_columnIndex = m_columnCount + 1; + return false; -void Statement::bind(int index, int value) -{ - if (SQLITE_OK != ::sqlite3_bind_int(m_stmt, index, value)) - throw std::runtime_error(getErrorMessage()); -} - -void Statement::bind(int index, sqlite3_int64 value) -{ - if (SQLITE_OK != ::sqlite3_bind_int64(m_stmt, index, value)) - throw std::runtime_error(getErrorMessage()); -} - -void Statement::bind(int index, double value) -{ - if (SQLITE_OK != ::sqlite3_bind_double(m_stmt, index, value)) + default: throw std::runtime_error(getErrorMessage()); + } } -void Statement::bind(int index, const char *value) -{ - if (SQLITE_OK != ::sqlite3_bind_text(m_stmt, index, value, -1, - SQLITE_TRANSIENT)) - throw std::runtime_error(getErrorMessage()); -} - -void Statement::bind(int index, const std::string &value) -{ - if (SQLITE_OK != ::sqlite3_bind_text(m_stmt, index, value.c_str(), - static_cast(value.size()), SQLITE_TRANSIENT)) - throw std::runtime_error(getErrorMessage()); -} - -void Statement::bind(int index, const void *value, int size) -{ - if (SQLITE_OK != ::sqlite3_bind_blob(m_stmt, index, value, size, - SQLITE_TRANSIENT)) - throw std::runtime_error(getErrorMessage()); -} - -void Statement::bind(int index) -{ - if (SQLITE_OK != ::sqlite3_bind_null(m_stmt, index)) - throw std::runtime_error(getErrorMessage()); -} - -void Statement::bind(const std::string &name, int value) -{ - int index = sqlite3_bind_parameter_index(m_stmt, name.c_str()); - - if (SQLITE_OK != sqlite3_bind_int(m_stmt, index, value)) - throw std::runtime_error(getErrorMessage()); -} - -void Statement::bind(const std::string &name, sqlite3_int64 value) +int Statement::exec() { - int index = sqlite3_bind_parameter_index(m_stmt, name.c_str()); - - if (SQLITE_OK != ::sqlite3_bind_int64(m_stmt, index, value)) + if (::sqlite3_step(m_stmt) != SQLITE_DONE) throw std::runtime_error(getErrorMessage()); -} - -void Statement::bind(const std::string &name, double value) -{ - int index = sqlite3_bind_parameter_index(m_stmt, name.c_str()); - if (SQLITE_OK != ::sqlite3_bind_double(m_stmt, index, value)) - throw std::runtime_error(getErrorMessage()); + // column cannot be 'get' after sqlite done, so make index overflow. + m_columnIndex = m_columnCount + 1; + return sqlite3_changes(sqlite3_db_handle(m_stmt)); } -void Statement::bind(const std::string &name, const std::string &value) +void Statement::bind(int value) const { - int index = sqlite3_bind_parameter_index(m_stmt, name.c_str()); + if (!isBindingIndexValid()) + throw std::logic_error("index overflowed when binding int to stmt."); - if (SQLITE_OK != ::sqlite3_bind_text(m_stmt, index, value.c_str(), - static_cast(value.size()), SQLITE_TRANSIENT)) + if (SQLITE_OK != ::sqlite3_bind_int(m_stmt, ++m_bindingIndex, value)) throw std::runtime_error(getErrorMessage()); } -void Statement::bind(const std::string &name, const char *value) +void Statement::bind(sqlite3_int64 value) const { - int index = sqlite3_bind_parameter_index(m_stmt, name.c_str()); + if (!isBindingIndexValid()) + throw std::logic_error("index overflowed when binding int64 to stmt."); - if (SQLITE_OK != ::sqlite3_bind_text(m_stmt, index, value, -1, - SQLITE_TRANSIENT)) + if (SQLITE_OK != ::sqlite3_bind_int64(m_stmt, ++m_bindingIndex, value)) throw std::runtime_error(getErrorMessage()); } -void Statement::bind(const std::string &name, const void *value, int size) +void Statement::bind(const std::string &value) const { - int index = sqlite3_bind_parameter_index(m_stmt, name.c_str()); + if (!isBindingIndexValid()) + throw std::logic_error("index overflowed when binding string to stmt."); - if (SQLITE_OK != ::sqlite3_bind_blob(m_stmt, index, value, size, - SQLITE_TRANSIENT)) + if (SQLITE_OK != ::sqlite3_bind_text(m_stmt, ++m_bindingIndex, value.c_str(), -1, + SQLITE_STATIC)) throw std::runtime_error(getErrorMessage()); } -void Statement::bind(const std::string &name) +void Statement::bind(void) const { - int index = sqlite3_bind_parameter_index(m_stmt, name.c_str()); + if (!isBindingIndexValid()) + throw std::logic_error("index overflowed when binding fields from row"); - if (SQLITE_OK != ::sqlite3_bind_null(m_stmt, index)) + if (SQLITE_OK != ::sqlite3_bind_null(m_stmt, ++m_bindingIndex)) throw std::runtime_error(getErrorMessage()); } - -bool Statement::isColumnValid(int index) const noexcept -{ - return m_validRow && index < m_columnCount; -} - -bool Statement::isNullColumn(int index) const -{ - if (!isColumnValid(index)) - throw std::runtime_error(getErrorMessage(SQLITE_RANGE)); - - return SQLITE_NULL == sqlite3_column_type(m_stmt, index); -} - -std::string Statement::getColumnName(int index) const +bool Statement::isNullColumn() const { - if (!isColumnValid(index)) - throw std::runtime_error(getErrorMessage(SQLITE_RANGE)); + if (!isColumnIndexValid()) + throw std::runtime_error(FORMAT("column isn't valud for index: " << + m_columnIndex)); - return sqlite3_column_name(m_stmt, index); + return SQLITE_NULL == sqlite3_column_type(m_stmt, (m_columnIndex + 1)); } -int Statement::getInt(int index) const +int Statement::getInt() const { - return sqlite3_column_int(m_stmt, index); -} + if (!isColumnIndexValid()) + throw std::logic_error("index overflowed when getting fields from row"); -sqlite3_int64 Statement::getInt64(int index) const -{ - return sqlite3_column_int64(m_stmt, index); + return sqlite3_column_int(m_stmt, ++m_columnIndex); } -double Statement::getDouble(int index) const +sqlite3_int64 Statement::getInt64() const { - return sqlite3_column_double(m_stmt, index); -} + if (!isColumnIndexValid()) + throw std::logic_error("index overflowed when getting fields from row"); -const char *Statement::getText(int index) const -{ - return reinterpret_cast(sqlite3_column_text(m_stmt, index)); + return sqlite3_column_int64(m_stmt, ++m_columnIndex); } -const void *Statement::getBlob(int index) const +const char *Statement::getText() const { - return sqlite3_column_blob(m_stmt, index); -} + if (!isColumnIndexValid()) + throw std::logic_error("index overflowed when getting fields from row"); -int Statement::getType(int index) const -{ - return sqlite3_column_type(m_stmt, index); -} - -int Statement::getBytes(int index) const -{ - return sqlite3_column_bytes(m_stmt, index); + return reinterpret_cast(sqlite3_column_text(m_stmt, ++m_columnIndex)); } } // namespace Db diff --git a/src/framework/db/statement.h b/src/framework/db/statement.h index b5caf8d..df3e772 100644 --- a/src/framework/db/statement.h +++ b/src/framework/db/statement.h @@ -16,7 +16,6 @@ #pragma once #include -#include #include @@ -27,52 +26,48 @@ namespace Db { class Statement { public: - Statement(const Connection &db, const std::string &query); + Statement() = delete; + explicit Statement(const Connection &db, const std::string &query); virtual ~Statement(); int exec(); bool step(); void reset(); - void clearBindings(); + void clearBindings() const; - // bind values to query - void bind(int index, int value); - void bind(int index, sqlite3_int64 value); - void bind(int index, double value); - void bind(int index, const std::string &value); - void bind(int index, const char *value); - void bind(int index, const void *value, int size); - void bind(int index); + // bind values to query. index of column auto-incremented + void bind(int value) const; + void bind(sqlite3_int64 value) const; + void bind(double value) const; + void bind(const std::string &value) const; + void bind() const; - void bind(const std::string &name, int value); - void bind(const std::string &name, sqlite3_int64 value); - void bind(const std::string &name, double value); - void bind(const std::string &name, const std::string &value); - void bind(const std::string &name, const char *value); - void bind(const std::string &name, const void *value, int size); - void bind(const std::string &name); + // get column values. index of column auto-incremented + int getInt() const; + sqlite3_int64 getInt64() const; + const char *getText() const; - // get column values - std::string getColumnName(int index) const; - bool isNullColumn(int index) const; - - int getInt(int index) const; - sqlite3_int64 getInt64(int index) const; - double getDouble(int index) const; - const char *getText(int index) const; - const void *getBlob(int index) const; - int getType(int index) const; - int getBytes(int index) const; + bool isNullColumn() const; // it's checking func. not auto incremented. private: - std::string getErrorMessage() const; - std::string getErrorMessage(int errorCode) const; - bool isColumnValid(int index) const noexcept; + inline bool isBindingIndexValid() const noexcept + { + return m_bindingIndex <= m_bindParamCount; + } + + inline bool isColumnIndexValid() const noexcept + { + return m_columnIndex < m_columnCount; + } + + std::string getErrorMessage() const noexcept; sqlite3_stmt *m_stmt; + int m_bindParamCount; int m_columnCount; - bool m_validRow; + mutable int m_bindingIndex; + mutable int m_columnIndex; }; } // namespace Db diff --git a/test/test-internal-database.cpp b/test/test-internal-database.cpp index 1875476..3610bc5 100644 --- a/test/test-internal-database.cpp +++ b/test/test-internal-database.cpp @@ -19,7 +19,6 @@ * @version 1.0 * @brief CSR Content screening DB internal test */ - #include "db/manager.h" #include @@ -33,20 +32,18 @@ #define TEST_DB_FILE TEST_DIR "/test.db" #define TEST_DB_SCRIPTS RO_DBSPACE +using namespace Csr; + namespace { -void checkSameMalware(Csr::Db::RowDetected &malware1, - Csr::Db::RowDetected &malware2) +void checkSameMalware(const CsDetected &d, const Db::Row &r) { - ASSERT_IF(malware1.path, malware2.path); - ASSERT_IF(malware1.dataVersion, malware2.dataVersion); - ASSERT_IF(malware1.severityLevel, malware2.severityLevel); - ASSERT_IF(malware1.threatType, malware2.threatType); - ASSERT_IF(malware1.name, malware2.name); - ASSERT_IF(malware1.detailedUrl, malware2.detailedUrl); - ASSERT_IF(malware1.detected_time, malware2.detected_time); - ASSERT_IF(malware1.modified_time, malware2.modified_time); - ASSERT_IF(malware1.ignored, malware2.ignored); + ASSERT_IF(d.targetName, r.targetName); + ASSERT_IF(d.severity, r.severity); + ASSERT_IF(d.threat, r.threat); + ASSERT_IF(d.malwareName, r.malwareName); + ASSERT_IF(d.detailedUrl, r.detailedUrl); + ASSERT_IF(d.ts, r.ts); } } // namespace anonymous @@ -57,7 +54,7 @@ BOOST_AUTO_TEST_CASE(schema_info) { EXCEPTION_GUARD_START - Csr::Db::Manager db(TEST_DB_FILE, TEST_DB_SCRIPTS); + Db::Manager db(TEST_DB_FILE, TEST_DB_SCRIPTS); ASSERT_IF(db.getSchemaVersion(), 1); // latest version is 1 @@ -68,7 +65,7 @@ BOOST_AUTO_TEST_CASE(engine_state) { EXCEPTION_GUARD_START - Csr::Db::Manager db(TEST_DB_FILE, TEST_DB_SCRIPTS); + Db::Manager db(TEST_DB_FILE, TEST_DB_SCRIPTS); ASSERT_IF(db.setEngineState(1, 1), true); ASSERT_IF(db.setEngineState(2, 2), true); @@ -86,7 +83,7 @@ BOOST_AUTO_TEST_CASE(scan_time) { EXCEPTION_GUARD_START - Csr::Db::Manager db(TEST_DB_FILE, TEST_DB_SCRIPTS); + Db::Manager db(TEST_DB_FILE, TEST_DB_SCRIPTS); std::string dir = "/opt"; long scantime = 100; @@ -111,97 +108,88 @@ BOOST_AUTO_TEST_CASE(detected_malware_file) { EXCEPTION_GUARD_START - Csr::Db::Manager db(TEST_DB_FILE, TEST_DB_SCRIPTS); + Db::Manager db(TEST_DB_FILE, TEST_DB_SCRIPTS); std::string initDataVersion = "1.0.0"; std::string changedDataVersion = "2.0.0"; // insert - Csr::Db::RowDetected malware1; - malware1.path = "/opt/testmalware1"; - malware1.dataVersion = initDataVersion; - malware1.severityLevel = 1; - malware1.threatType = 1; - malware1.name = "testmalware1"; + CsDetected malware1; + malware1.targetName = "/opt/testmalware1"; + malware1.severity = CSR_CS_SEVERITY_MEDIUM; + malware1.threat = CSR_CS_THREAT_MALWARE; + malware1.malwareName = "testmalware1"; malware1.detailedUrl = "http://detailed.malware.com"; - malware1.detected_time = 100; - malware1.modified_time = 100; - malware1.ignored = 1; - - Csr::Db::RowDetected malware2; - malware2.path = "/opt/testmalware2"; - malware2.dataVersion = initDataVersion; - malware2.severityLevel = 2; - malware2.threatType = 2; - malware2.name = "testmalware2"; + malware1.ts = 100; + + CsDetected malware2; + malware2.targetName = "/opt/testmalware2"; + malware2.severity = CSR_CS_SEVERITY_HIGH; + malware2.threat = CSR_CS_THREAT_RISKY; + malware2.malwareName = "testmalware2"; malware2.detailedUrl = "http://detailed2.malware.com"; - malware2.detected_time = 210; - malware2.modified_time = 210; - malware2.ignored = 2; - - Csr::Db::RowDetected malware3; - malware3.path = "/opt/testmalware3"; - malware3.dataVersion = changedDataVersion; - malware3.severityLevel = 3; - malware3.threatType = 3; - malware3.name = "testmalware2"; - malware3.detailedUrl = "http://detailed2.malware.com"; - malware3.detected_time = 310; - malware3.modified_time = 310; - malware3.ignored = 3; + malware2.ts = 210; + + CsDetected malware3; + malware3.targetName = "/opt/testmalware3"; + malware3.severity = CSR_CS_SEVERITY_LOW; + malware3.threat = CSR_CS_THREAT_GENERIC; + malware3.malwareName = "testmalware3"; + malware3.detailedUrl = "http://detailed3.malware.com"; + malware3.ts = 310; // select test with vacant data - auto detected = db.getDetectedMalware(malware1.path); + auto detected = db.getDetectedMalware(malware1.targetName); CHECK_IS_NULL(detected); auto detectedList = db.getDetectedMalwares("/opt"); ASSERT_IF(detectedList->empty(), true); - // insertDetectedMalware test - ASSERT_IF(db.insertDetectedMalware(malware1), true); - ASSERT_IF(db.insertDetectedMalware(malware2), true); - - // getDetectedMalware test - detected = db.getDetectedMalware(malware1.path); + ASSERT_IF(db.insertDetectedMalware(malware1, initDataVersion, false), true); + detected = db.getDetectedMalware(malware1.targetName); checkSameMalware(malware1, *detected); - detected = db.getDetectedMalware(malware2.path); + ASSERT_IF(detected->dataVersion, initDataVersion); + ASSERT_IF(detected->isIgnored, false); + + ASSERT_IF(db.insertDetectedMalware(malware2, initDataVersion, true), true); + detected = db.getDetectedMalware(malware2.targetName); checkSameMalware(malware2, *detected); + ASSERT_IF(detected->dataVersion, initDataVersion); + ASSERT_IF(detected->isIgnored, true); // getDetectedMalwares test detectedList = db.getDetectedMalwares("/opt"); ASSERT_IF(detectedList->size(), static_cast(2)); for (auto &item : *detectedList) { - if (malware1.path == item->path) + if (malware1.targetName == item->targetName) checkSameMalware(malware1, *item); - else if (malware2.path == item->path) + else if (malware2.targetName == item->targetName) checkSameMalware(malware2, *item); else BOOST_REQUIRE_MESSAGE(false, "Failed. getDetectedMalwares"); } // setDetectedMalwareIgnored test - ASSERT_IF(db.setDetectedMalwareIgnored(malware1.path, 1), true); - - malware1.ignored = 1; - detected = db.getDetectedMalware(malware1.path); + ASSERT_IF(db.setDetectedMalwareIgnored(malware1.targetName, true), true); + detected = db.getDetectedMalware(malware1.targetName); checkSameMalware(malware1, *detected); + ASSERT_IF(detected->isIgnored, true); - // deleteDeprecatedDetecedMalwares test - ASSERT_IF(db.insertDetectedMalware(malware3), true); - - ASSERT_IF(db.deleteDeprecatedDetecedMalwares("/opt", changedDataVersion), true); - - detected = db.getDetectedMalware(malware3.path); + // deleteDeprecatedDetectedMalwares test + ASSERT_IF(db.insertDetectedMalware(malware3, changedDataVersion, false), true); + ASSERT_IF(db.deleteDeprecatedDetectedMalwares("/opt", changedDataVersion), true); + detected = db.getDetectedMalware(malware3.targetName); checkSameMalware(malware3, *detected); + ASSERT_IF(detected->dataVersion, changedDataVersion); + ASSERT_IF(detected->isIgnored, false); - detected = db.getDetectedMalware(malware1.path); - CHECK_IS_NULL(detected); - detected = db.getDetectedMalware(malware2.path); - CHECK_IS_NULL(detected); + CHECK_IS_NULL(db.getDetectedMalware(malware1.targetName)); + CHECK_IS_NULL(db.getDetectedMalware(malware2.targetName)); - // deleteDetecedMalware test - ASSERT_IF(db.deleteDetecedMalware(malware3.path), true); + // deleteDetectedMalware test + ASSERT_IF(db.deleteDetectedMalware(malware3.targetName), true); + CHECK_IS_NULL(db.getDetectedMalware(malware3.targetName)); EXCEPTION_GUARD_END } -- 2.7.4