Row inherits CsDetected struct 02/67802/3
authorKyungwook Tak <k.tak@samsung.com>
Thu, 28 Apr 2016 05:53:25 +0000 (14:53 +0900)
committerKyungwook Tak <k.tak@samsung.com>
Fri, 29 Apr 2016 03:03:57 +0000 (12:03 +0900)
remove unused functions in db namespace

remove modified_time field in db schema
modified_time isn't needed because we can determine whether
file is modified by comparing current time and detected_time

Change-Id: I95825bfbde8ada0649522088bc9a67b8874d32f0
Signed-off-by: Kyungwook Tak <k.tak@samsung.com>
data/scripts/create_schema.sql
src/framework/db/manager.cpp
src/framework/db/manager.h
src/framework/db/query.h
src/framework/db/row.h [new file with mode: 0644]
src/framework/db/statement.cpp
src/framework/db/statement.h
test/test-internal-database.cpp

index f867f3e..1a5b70c 100644 (file)
@@ -1,20 +1,19 @@
-CREATE TABLE IF NOT EXISTS SCHEMA_INFO(name TEXT PRIMARY KEY NOT NULL,
+CREATE TABLE IF NOT EXISTS SCHEMA_INFO(name  TEXT PRIMARY KEY NOT NULL,
                                        value TEXT);
 
-CREATE TABLE IF NOT EXISTS ENGINE_STATE(engine_id INTEGER PRIMARY KEY,
-                                        state     INTEGER NOT NULL);
+CREATE TABLE IF NOT EXISTS ENGINE_STATE(id    INTEGER PRIMARY KEY,
+                                        state INTEGER NOT NULL);
 
-CREATE TABLE IF NOT EXISTS SCAN_REQUEST(dir TEXT PRIMARY KEY,
-                                        last_scan INTEGER NOT NULL,
+CREATE TABLE IF NOT EXISTS SCAN_REQUEST(dir          TEXT PRIMARY KEY,
+                                        last_scan    INTEGER NOT NULL,
                                         data_version TEXT NOT NULL);
 
-CREATE TABLE IF NOT EXISTS DETECTED_MALWARE_FILE(path TEXT PRIMARY KEY NOT NULL,
-                                                 data_version TEXT NOT NULL,
-                                                 severity_level INTEGER NOT NULL,
-                                                 threat_type INTEGER NOT NULL,
-                                                 malware_name TEXT NOT NULL,
-                                                 detailed_url TEXT,
+CREATE TABLE IF NOT EXISTS DETECTED_MALWARE_FILE(path          TEXT PRIMARY KEY NOT NULL,
+                                                 data_version  TEXT NOT NULL,
+                                                 severity      INTEGER NOT NULL,
+                                                 threat        INTEGER NOT NULL,
+                                                 malware_name  TEXT NOT NULL,
+                                                 detailed_url  TEXT NOT NULL,
                                                  detected_time INTEGER NOT NULL,
-                                                 modified_time INTEGER NOT NULL,
-                                                 ignored INTEGER NOT NULL);
+                                                 ignored       INTEGER NOT NULL);
 
index d70e59b..548950a 100644 (file)
@@ -44,7 +44,8 @@ const std::string SCRIPT_CREATE_SCHEMA  = "create_schema";
 const std::string SCRIPT_DROP_ALL_ITEMS = "drop_all";
 const std::string SCRIPT_MIGRATE        = "migrate_";
 
-const std::string DB_VERSION_STR  = "DB_VERSION";
+const std::string DB_VERSION_STR    = "DB_VERSION";
+const std::string SCHEMA_INFO_TABLE = "SCHEMA_INFO";
 
 } // namespace anonymous
 
@@ -53,14 +54,6 @@ Manager::Manager(const std::string &dbfile, const std::string &scriptsDir) :
                   Connection::Serialized),
        m_scriptsDir(scriptsDir)
 {
-       initDatabase();
-       m_conn.exec("VACUUM;");
-}
-
-Manager::~Manager() {}
-
-void Manager::initDatabase()
-{
        // run migration if old database is present
        auto sv = getSchemaVersion();
 
@@ -85,6 +78,12 @@ void Manager::initDatabase()
 
                setSchemaVersion(SchemaVersion::LATEST);
        }
+
+       m_conn.exec("VACUUM;");
+}
+
+Manager::~Manager()
+{
 }
 
 void Manager::resetDatabase()
@@ -116,31 +115,39 @@ std::string Manager::getScript(const std::string &scriptName)
        return str;
 }
 
-int Manager::getSchemaVersion()
+bool Manager::isTableExist(const std::string &name)
 {
-       try {
-               Statement stmt(m_conn, Query::SEL_SCHEMA_INFO);
+       Statement stmt(m_conn, Query::CHK_TABLE);
 
-               int idx = 0;
-               stmt.bind(++idx, DB_VERSION_STR);
+       stmt.bind(name);
 
-               if (!stmt.step()) // Tables don't exist yet
-                       return SchemaVersion::NOT_EXIST;
+       return stmt.step();
+}
 
-               return stmt.getInt(0);
-       } catch (const std::runtime_error &e) {
-               INFO("Database doesn't exist.");
-               return SchemaVersion::NOT_EXIST; // table not exist!
+int Manager::getSchemaVersion()
+{
+       if (!isTableExist(SCHEMA_INFO_TABLE)) {
+               WARN("Schema table doesn't exist. This case would be the first time of "
+             "db manager instantiated in target");
+               return SchemaVersion::NOT_EXIST;
        }
+
+       Statement stmt(m_conn, Query::SEL_SCHEMA_INFO);
+
+       stmt.bind(DB_VERSION_STR);
+
+       if (!stmt.step())
+               throw std::logic_error(FORMAT("schema info table should exist!"));
+
+       return stmt.getInt();
 }
 
 void Manager::setSchemaVersion(int sv)
 {
        Statement stmt(m_conn, Query::INS_SCHEMA_INFO);
 
-       int idx = 0;
-       stmt.bind(++idx, DB_VERSION_STR);
-       stmt.bind(++idx, sv);
+       stmt.bind(DB_VERSION_STR);
+       stmt.bind(sv);
 
        if (stmt.exec() == 0)
                throw std::runtime_error(FORMAT("Failed to set schema version!"));
@@ -154,13 +161,9 @@ int Manager::getEngineState(int engineId) noexcept
        try {
                Statement stmt(m_conn, Query::SEL_ENGINE_STATE);
 
-               int idx = 0;
-               stmt.bind(++idx, engineId);
-
-               if (!stmt.step())
-                       return -1;
+               stmt.bind(engineId);
 
-               return stmt.getInt(0);
+               return stmt.step() ? stmt.getInt() : -1;
        } catch (const std::exception &e) {
                ERROR("getEngineState failed. error_msg=" << e.what());
                return -1;
@@ -172,10 +175,10 @@ bool Manager::setEngineState(int engineId, int state) noexcept
        try {
                Statement stmt(m_conn, Query::INS_ENGINE_STATE);
 
-               int idx = 0;
-               stmt.bind(++idx, engineId);
-               stmt.bind(++idx, state);
-               return stmt.exec();
+               stmt.bind(engineId);
+               stmt.bind(state);
+
+               return stmt.exec() != 0;
        } catch (const std::exception &e) {
                ERROR("setEngineState failed. error_msg=" << e.what());
                return false;
@@ -192,14 +195,10 @@ long Manager::getLastScanTime(const std::string &dir,
        try {
                Statement stmt(m_conn, Query::SEL_SCAN_REQUEST);
 
-               int idx = 0;
-               stmt.bind(++idx, dir);
-               stmt.bind(++idx, dataVersion);
+               stmt.bind(dir);
+               stmt.bind(dataVersion);
 
-               if (!stmt.step())
-                       return -1;
-
-               return static_cast<long>(stmt.getInt64(1));
+               return stmt.step() ? static_cast<time_t>(stmt.getInt64()) : -1;
        } catch (const std::exception &e) {
                ERROR("getLastScanTime failed. error_msg=" << e.what());
                return -1;
@@ -212,11 +211,11 @@ bool Manager::insertLastScanTime(const std::string &dir, long scanTime,
        try {
                Statement stmt(m_conn, Query::INS_SCAN_REQUEST);
 
-               int idx = 0;
-               stmt.bind(++idx, dir);
-               stmt.bind(++idx, static_cast<sqlite3_int64>(scanTime));
-               stmt.bind(++idx, dataVersion);
-               return stmt.exec();
+               stmt.bind(dir);
+               stmt.bind(static_cast<sqlite3_int64>(scanTime));
+               stmt.bind(dataVersion);
+
+               return stmt.exec() != 0;
        } catch (const std::exception &e) {
                ERROR("insertLastScanTime failed. error_msg=" << e.what());
                return false;
@@ -228,8 +227,8 @@ bool Manager::deleteLastScanTime(const std::string &dir) noexcept
        try {
                Statement stmt(m_conn, Query::DEL_SCAN_REQUEST_BY_DIR);
 
-               int idx = 0;
-               stmt.bind(++idx, dir);
+               stmt.bind(dir);
+
                stmt.exec();
                return true; // even if no rows are deleted
        } catch (const std::exception &e) {
@@ -242,7 +241,9 @@ bool Manager::cleanLastScanTime() noexcept
 {
        try {
                Statement stmt(m_conn, Query::DEL_SCAN_REQUEST);
+
                stmt.exec();
+
                return true; // even if no rows are deleted
        } catch (const std::exception &e) {
                ERROR("cleanLastScanTime failed. error_msg=" << e.what());
@@ -253,87 +254,79 @@ bool Manager::cleanLastScanTime() noexcept
 //===========================================================================
 // DETECTED_MALWARE_FILE table
 //===========================================================================
-DetectedListShrPtr Manager::getDetectedMalwares(const std::string &dir) noexcept
+RowsShPtr Manager::getDetectedMalwares(const std::string &dir) noexcept
 {
        try {
-               DetectedListShrPtr detectedList =
-                       std::make_shared<std::vector<DetectedShrPtr>>();
-
                Statement stmt(m_conn, Query::SEL_DETECTED_BY_DIR);
+               stmt.bind(dir);
 
-               int idx = 0;
-               stmt.bind(++idx, dir);
+               RowsShPtr rows = std::make_shared<std::vector<RowShPtr>>();
 
                while (stmt.step()) {
-                       DetectedShrPtr detected = std::make_shared<RowDetected>();
-                       idx = -1;
-                       detected->path = stmt.getText(++idx);
-                       detected->dataVersion = stmt.getText(++idx);
-                       detected->severityLevel = stmt.getInt(++idx);
-                       detected->threatType = stmt.getInt(++idx);
-                       detected->name = stmt.getText(++idx);
-                       detected->detailedUrl = stmt.getText(++idx);
-                       detected->detected_time = static_cast<long>(stmt.getInt64(++idx));
-                       detected->modified_time = static_cast<long>(stmt.getInt64(++idx));
-                       detected->ignored = stmt.getInt(++idx);
-
-                       detectedList->push_back(detected);
+                       RowShPtr row = std::make_shared<Row>();
+
+                       row->targetName = stmt.getText();
+                       row->dataVersion = stmt.getText();
+                       row->severity = static_cast<csr_cs_severity_level_e>(stmt.getInt());
+                       row->threat = static_cast<csr_cs_threat_type_e>(stmt.getInt());
+                       row->malwareName = stmt.getText();
+                       row->detailedUrl = stmt.getText();
+                       row->ts = static_cast<time_t>(stmt.getInt64());
+                       row->isIgnored = static_cast<bool>(stmt.getInt());
+
+                       rows->emplace_back(std::move(row));
                }
 
-               return detectedList;
+               return rows;
        } catch (const std::exception &e) {
                ERROR("getDetectedMalwares failed. error_msg=" << e.what());
                return nullptr;
        }
 }
 
-DetectedShrPtr Manager::getDetectedMalware(const std::string &path) noexcept
+RowShPtr Manager::getDetectedMalware(const std::string &path) noexcept
 {
        try {
-               DetectedShrPtr detected = std::make_shared<RowDetected>();
                Statement stmt(m_conn, Query::SEL_DETECTED_BY_PATH);
-
-               int idx = 0;
-               stmt.bind(++idx, path);
+               stmt.bind(path);
 
                if (!stmt.step())
                        return nullptr;
 
-               idx = -1;
-               detected->path = stmt.getText(++idx);
-               detected->dataVersion = stmt.getText(++idx);
-               detected->severityLevel = stmt.getInt(++idx);
-               detected->threatType = stmt.getInt(++idx);
-               detected->name = stmt.getText(++idx);
-               detected->detailedUrl = stmt.getText(++idx);
-               detected->detected_time = static_cast<long>(stmt.getInt64(++idx));
-               detected->modified_time = static_cast<long>(stmt.getInt64(++idx));
-               detected->ignored = stmt.getInt(++idx);
-
-               return detected;
+               RowShPtr row = std::make_shared<Row>();
+               row->targetName = stmt.getText();
+               row->dataVersion = stmt.getText();
+               row->severity = static_cast<csr_cs_severity_level_e>(stmt.getInt());
+               row->threat = static_cast<csr_cs_threat_type_e>(stmt.getInt());
+               row->malwareName = stmt.getText();
+               row->detailedUrl = stmt.getText();
+               row->ts = static_cast<time_t>(stmt.getInt64());
+               row->isIgnored = static_cast<bool>(stmt.getInt());
+
+               return row;
        } catch (const std::exception &e) {
                ERROR("getDetectedMalware failed. error_msg=" << e.what());
                return nullptr;
        }
 }
 
-bool Manager::insertDetectedMalware(const RowDetected &malware) noexcept
+bool Manager::insertDetectedMalware(const CsDetected &d,
+                                                                       const std::string &dataVersion,
+                                                                       bool isIgnored) noexcept
 {
        try {
                Statement stmt(m_conn, Query::INS_DETECTED);
 
-               int idx = 0;
-               stmt.bind(++idx, malware.path);
-               stmt.bind(++idx, malware.dataVersion);
-               stmt.bind(++idx, malware.severityLevel);
-               stmt.bind(++idx, malware.threatType);
-               stmt.bind(++idx, malware.name);
-               stmt.bind(++idx, malware.detailedUrl);
-               stmt.bind(++idx, static_cast<sqlite3_int64>(malware.detected_time));
-               stmt.bind(++idx, static_cast<sqlite3_int64>(malware.modified_time));
-               stmt.bind(++idx, malware.ignored);
-
-               return stmt.exec();
+               stmt.bind(d.targetName);
+               stmt.bind(dataVersion);
+               stmt.bind(static_cast<int>(d.severity));
+               stmt.bind(static_cast<int>(d.threat));
+               stmt.bind(d.malwareName);
+               stmt.bind(d.detailedUrl);
+               stmt.bind(static_cast<sqlite3_int64>(d.ts));
+               stmt.bind(static_cast<int>(isIgnored));
+
+               return stmt.exec() == 1;
        } catch (const std::exception &e) {
                ERROR("insertDetectedMalware failed. error_msg=" << e.what());
                return false;
@@ -341,50 +334,51 @@ bool Manager::insertDetectedMalware(const RowDetected &malware) noexcept
 }
 
 bool Manager::setDetectedMalwareIgnored(const std::string &path,
-                                                                               int ignored) noexcept
+                                                                               bool flag) noexcept
 {
        try {
                Statement stmt(m_conn, Query::UPD_DETECTED_INGNORED);
 
-               int idx = 0;
-               stmt.bind(++idx, ignored);
-               stmt.bind(++idx, path);
+               stmt.bind(flag);
+               stmt.bind(path);
 
-               return stmt.exec();
+               return stmt.exec() == 1;
        } catch (const std::exception &e) {
                ERROR("setDetectedMalwareIgnored failed. error_msg=" << e.what());
                return false;
        }
 }
 
-bool Manager::deleteDetecedMalware(const std::string &path) noexcept
+bool Manager::deleteDetectedMalware(const std::string &path) noexcept
 {
        try {
                Statement stmt(m_conn, Query::DEL_DETECTED_BY_PATH);
 
-               int idx = 0;
-               stmt.bind(++idx, path);
+               stmt.bind(path);
+
                stmt.exec();
+
                return true; // even if no rows are deleted
        } catch (const std::exception &e) {
-               ERROR("deleteDetecedMalware failed.error_msg=" << e.what());
+               ERROR("deleteDetectedMalware failed.error_msg=" << e.what());
                return false;
        }
 }
 
-bool Manager::deleteDeprecatedDetecedMalwares(const std::string &dir,
+bool Manager::deleteDeprecatedDetectedMalwares(const std::string &dir,
                const std::string &dataVersion) noexcept
 {
        try {
                Statement stmt(m_conn, Query::DEL_DETECTED_DEPRECATED);
 
-               int idx = 0;
-               stmt.bind(++idx, dir);
-               stmt.bind(++idx, dataVersion);
+               stmt.bind(dir);
+               stmt.bind(dataVersion);
+
                stmt.exec();
+
                return true; // even if no rows are deleted
        } catch (const std::exception &e) {
-               ERROR("deleteDeprecatedDetecedMalwares failed.error_msg=" << e.what());
+               ERROR("deleteDeprecatedDetectedMalwares failed.error_msg=" << e.what());
                return false;
        }
 }
index f24938d..c784113 100644 (file)
 #include <memory>
 
 #include "db/connection.h"
+#include "db/row.h"
+#include "common/cs-detected.h"
 
 namespace Csr {
 namespace Db {
 
-struct RowDetected {
-       std::string path;
-       std::string dataVersion; // engine's data version
-       std::string name;
-       std::string detailedUrl;
-       int severityLevel;
-       int threatType;
-       int ignored;
-       long detected_time;
-       long modified_time;
-
-       RowDetected() :
-               severityLevel(-1),
-               threatType(-1),
-               ignored(-1),
-               detected_time(-1),
-               modified_time(-1) {}
-
-       virtual ~RowDetected() {}
-};
-
-using DetectedShrPtr = std::shared_ptr<RowDetected>;
-using DetectedListShrPtr = std::shared_ptr<std::vector<DetectedShrPtr>>;
-
 class Manager {
 public:
        Manager(const std::string &dbfile, const std::string &scriptsDir);
        virtual ~Manager();
 
-       // SCHEMA_INFO
+       // SCHEMA_INFO. it's public only for testing for now...
        int getSchemaVersion();
 
        // ENGINE_STATE
@@ -68,25 +46,25 @@ public:
 
        // SCAN_REQUEST
        long getLastScanTime(const std::string &dir,
-                                                const std::string &dataVersion)  noexcept;
+                                                const std::string &dataVersion) noexcept;
        bool insertLastScanTime(const std::string &dir, long scanTime,
-                                                       const std::string &dataVersion)  noexcept;
-       bool deleteLastScanTime(const std::string &dir)  noexcept;
-       bool cleanLastScanTime()  noexcept;
+                                                       const std::string &dataVersion) noexcept;
+       bool deleteLastScanTime(const std::string &dir) noexcept;
+       bool cleanLastScanTime() noexcept;
 
        // DETECTED_MALWARE_FILE & USER_RESPONSE
-       DetectedListShrPtr getDetectedMalwares(const std::string &dir)  noexcept;
-       DetectedShrPtr getDetectedMalware(const std::string &path)  noexcept;
-       bool insertDetectedMalware(const RowDetected &malware)  noexcept;
-       bool setDetectedMalwareIgnored(const std::string &path,
-                                                                  int userResponse)  noexcept;
-       bool deleteDetecedMalware(const std::string &path) noexcept;
-       bool deleteDeprecatedDetecedMalwares(const std::string &dir,
-                                                                                const std::string &dataVersion)  noexcept;
+       RowsShPtr getDetectedMalwares(const std::string &dirpath) noexcept;
+       RowShPtr getDetectedMalware(const std::string &filepath) noexcept;
+       bool insertDetectedMalware(const CsDetected &, const std::string &dataVersion,
+                                                          bool isIgnored) noexcept;
+       bool setDetectedMalwareIgnored(const std::string &path, bool flag) noexcept;
+       bool deleteDetectedMalware(const std::string &path) noexcept;
+       bool deleteDeprecatedDetectedMalwares(const std::string &dir,
+                                                                                 const std::string &dataVersion) noexcept;
 
 private:
-       void initDatabase();
        void resetDatabase();
+       bool isTableExist(const std::string &name);
        std::string getScript(const std::string &scriptName);
        std::string getMigrationScript(int schemaVersion);
 
index 8345815..c85c1ec 100644 (file)
@@ -19,6 +19,9 @@ namespace Csr {
 namespace Db {
 namespace Query {
 
+const std::string CHK_TABLE =
+       "select name from sqlite_master where type = 'table' and name = ?";
+
 const std::string SEL_SCHEMA_INFO =
        "select value from SCHEMA_INFO where name = ?";
 
@@ -26,13 +29,13 @@ const std::string INS_SCHEMA_INFO =
        "insert or replace into SCHEMA_INFO (name, value) values (?, ?)";
 
 const std::string SEL_ENGINE_STATE =
-       "select state from ENGINE_STATE where engine_id = ?";
+       "select state from ENGINE_STATE where id = ?";
 
 const std::string INS_ENGINE_STATE =
-       "insert or replace into ENGINE_STATE (engine_id, state) values (?, ?)";
+       "insert or replace into ENGINE_STATE (id, state) values (?, ?)";
 
 const std::string SEL_SCAN_REQUEST =
-       "select dir, last_scan from SCAN_REQUEST where dir = ? and data_version = ?";
+       "select last_scan from SCAN_REQUEST where dir = ? and data_version = ?";
 
 const std::string INS_SCAN_REQUEST =
        "insert or replace into SCAN_REQUEST (dir, last_scan, data_version) "
@@ -47,21 +50,21 @@ const std::string DEL_SCAN_REQUEST =
 
 const std::string SEL_DETECTED_BY_DIR =
        "SELECT path, data_version, "
-       " severity_level, threat_type, malware_name, "
-       " detailed_url, detected_time, modified_time, ignored "
+       " severity, threat, malware_name, "
+       " detailed_url, detected_time, ignored "
        " FROM detected_malware_file where path like ? || '%' ";
 
 const std::string SEL_DETECTED_BY_PATH =
        "SELECT path, data_version, "
-       " severity_level, threat_type, malware_name, "
-       " detailed_url, detected_time, modified_time, ignored "
+       " severity, threat, malware_name, "
+       " detailed_url, detected_time, ignored "
        " FROM detected_malware_file where path = ? ";
 
 const std::string INS_DETECTED =
        "insert or replace into DETECTED_MALWARE_FILE "
-       " (path, data_version, severity_level, threat_type, malware_name, "
-       " detailed_url, detected_time, modified_time, ignored) "
-       " values (?, ?, ?, ?, ?, ?, ?, ?, ?)";
+       " (path, data_version, severity, threat, malware_name, "
+       " detailed_url, detected_time, ignored) "
+       " values (?, ?, ?, ?, ?, ?, ?, ?)";
 
 const std::string UPD_DETECTED_INGNORED =
        "update DETECTED_MALWARE_FILE set ignored = ? where path = ?";
diff --git a/src/framework/db/row.h b/src/framework/db/row.h
new file mode 100644 (file)
index 0000000..1f7ffad
--- /dev/null
@@ -0,0 +1,44 @@
+/*
+ *  Copyright (c) 2016 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ *  Licensed under the Apache License, Version 2.0 (the "License");
+ *  you may not use this file except in compliance with the License.
+ *  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ *  Unless required by applicable law or agreed to in writing, software
+ *  distributed under the License is distributed on an "AS IS" BASIS,
+ *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ *  See the License for the specific language governing permissions and
+ *  limitations under the License
+ */
+/*
+ * @file        row.h
+ * @author      Kyungwook Tak (k.tak@samsung.com)
+ * @version     1.0
+ * @brief       db row
+ */
+#pragma once
+
+#include <memory>
+
+#include "common/cs-detected.h"
+
+namespace Csr {
+namespace Db {
+
+class Row;
+using RowShPtr = std::shared_ptr<Row>;
+using RowsShPtr = std::shared_ptr<std::vector<RowShPtr>>;
+
+struct Row : public Csr::CsDetected {
+       std::string dataVersion; // engine's data version
+       bool isIgnored;
+
+       Row() : isIgnored(false) {}
+       virtual ~Row() {}
+};
+
+} // namespace Db
+} // namespace Csr
index 3746c4a..50e215b 100644 (file)
 #include <stdexcept>
 
 #include "db/connection.h"
+#include "common/audit/logger.h"
 
 namespace Csr {
 namespace Db {
 
+namespace {
+
+const int BindingStartIndex = 0;
+const int ColumnStartIndex = -1;
+
+} // namespace anonymous
+
 Statement::Statement(const Connection &db, const std::string &query) :
-       m_stmt(nullptr),
-       m_columnCount(0),
-       m_validRow(false)
+       m_stmt(nullptr), m_bindingIndex(BindingStartIndex)
 {
-       if (SQLITE_OK != ::sqlite3_prepare_v2(db.get(),
-                                                                                 query.c_str(),
-                                                                                 query.size(),
-                                                                                 &m_stmt,
-                                                                                 nullptr))
+       switch (sqlite3_prepare_v2(db.get(), query.c_str(), query.size(), &m_stmt, nullptr)) {
+       case SQLITE_OK:
+               m_bindParamCount = ::sqlite3_bind_parameter_count(m_stmt);
+               m_columnCount = ::sqlite3_column_count(m_stmt);
+
+               // column index should be initialized after step(), so make it invalid here
+               m_columnIndex = m_columnCount + 1;
+               break;
+
+       default:
                throw std::runtime_error(db.getErrorMessage());
+       }
 
-       m_columnCount = sqlite3_column_count(m_stmt);
 }
 
 Statement::~Statement()
 {
-       if (::sqlite3_finalize(m_stmt) != SQLITE_OK)
+       if (SQLITE_OK != ::sqlite3_finalize(m_stmt))
                throw std::runtime_error(getErrorMessage());
 }
 
 void Statement::reset()
 {
-       if (::sqlite3_clear_bindings(m_stmt) != SQLITE_OK)
-               throw std::runtime_error(getErrorMessage());
+       clearBindings();
 
        if (::sqlite3_reset(m_stmt) != SQLITE_OK)
                throw std::runtime_error(getErrorMessage());
+
+       m_columnIndex = m_columnCount + 1;
 }
 
-void Statement::clearBindings()
+void Statement::clearBindings() const
 {
        if (::sqlite3_clear_bindings(m_stmt) != SQLITE_OK)
                throw std::runtime_error(getErrorMessage());
-}
 
-std::string Statement::getErrorMessage() const
-{
-       return ::sqlite3_errmsg(::sqlite3_db_handle(m_stmt));
+       m_bindingIndex = BindingStartIndex;
 }
 
-std::string Statement::getErrorMessage(int errorCode) const
+std::string Statement::getErrorMessage() const noexcept
 {
-       return ::sqlite3_errstr(errorCode);
+       return ::sqlite3_errmsg(::sqlite3_db_handle(m_stmt));
 }
 
 bool Statement::step()
 {
-       return (m_validRow = (SQLITE_ROW == ::sqlite3_step(m_stmt)));
-}
-
-int Statement::exec()
-{
-       if (SQLITE_DONE == ::sqlite3_step(m_stmt))
-               m_validRow = false;
-
-       return sqlite3_changes(sqlite3_db_handle(m_stmt));
-}
+       switch (::sqlite3_step(m_stmt)) {
+       case SQLITE_ROW:
+               m_columnIndex = ColumnStartIndex;
+               return true;
 
+       case SQLITE_DONE:
+               // column cannot be 'get' after sqlite done, so make index overflow.
+               m_columnIndex = m_columnCount + 1;
+               return false;
 
-void Statement::bind(int index, int value)
-{
-       if (SQLITE_OK != ::sqlite3_bind_int(m_stmt, index, value))
-               throw std::runtime_error(getErrorMessage());
-}
-
-void Statement::bind(int index, sqlite3_int64 value)
-{
-       if (SQLITE_OK != ::sqlite3_bind_int64(m_stmt, index, value))
-               throw std::runtime_error(getErrorMessage());
-}
-
-void Statement::bind(int index, double value)
-{
-       if (SQLITE_OK != ::sqlite3_bind_double(m_stmt, index, value))
+       default:
                throw std::runtime_error(getErrorMessage());
+       }
 }
 
-void Statement::bind(int index, const char *value)
-{
-       if (SQLITE_OK != ::sqlite3_bind_text(m_stmt, index, value, -1,
-                                                                                SQLITE_TRANSIENT))
-               throw std::runtime_error(getErrorMessage());
-}
-
-void Statement::bind(int index, const std::string &value)
-{
-       if (SQLITE_OK != ::sqlite3_bind_text(m_stmt, index, value.c_str(),
-                                                                                static_cast<int>(value.size()), SQLITE_TRANSIENT))
-               throw std::runtime_error(getErrorMessage());
-}
-
-void Statement::bind(int index, const void *value, int size)
-{
-       if (SQLITE_OK != ::sqlite3_bind_blob(m_stmt, index, value, size,
-                                                                                SQLITE_TRANSIENT))
-               throw std::runtime_error(getErrorMessage());
-}
-
-void Statement::bind(int index)
-{
-       if (SQLITE_OK != ::sqlite3_bind_null(m_stmt, index))
-               throw std::runtime_error(getErrorMessage());
-}
-
-void Statement::bind(const std::string &name, int value)
-{
-       int index = sqlite3_bind_parameter_index(m_stmt, name.c_str());
-
-       if (SQLITE_OK != sqlite3_bind_int(m_stmt, index, value))
-               throw std::runtime_error(getErrorMessage());
-}
-
-void Statement::bind(const std::string &name, sqlite3_int64 value)
+int Statement::exec()
 {
-       int index = sqlite3_bind_parameter_index(m_stmt, name.c_str());
-
-       if (SQLITE_OK != ::sqlite3_bind_int64(m_stmt, index, value))
+       if (::sqlite3_step(m_stmt) != SQLITE_DONE)
                throw std::runtime_error(getErrorMessage());
-}
-
-void Statement::bind(const std::string &name, double value)
-{
-       int index = sqlite3_bind_parameter_index(m_stmt, name.c_str());
 
-       if (SQLITE_OK != ::sqlite3_bind_double(m_stmt, index, value))
-               throw std::runtime_error(getErrorMessage());
+       // column cannot be 'get' after sqlite done, so make index overflow.
+       m_columnIndex = m_columnCount + 1;
+       return sqlite3_changes(sqlite3_db_handle(m_stmt));
 }
 
-void Statement::bind(const std::string &name, const std::string &value)
+void Statement::bind(int value) const
 {
-       int index = sqlite3_bind_parameter_index(m_stmt, name.c_str());
+       if (!isBindingIndexValid())
+               throw std::logic_error("index overflowed when binding int to stmt.");
 
-       if (SQLITE_OK != ::sqlite3_bind_text(m_stmt, index, value.c_str(),
-                                                                                static_cast<int>(value.size()), SQLITE_TRANSIENT))
+       if (SQLITE_OK != ::sqlite3_bind_int(m_stmt, ++m_bindingIndex, value))
                throw std::runtime_error(getErrorMessage());
 }
 
-void Statement::bind(const std::string &name, const char *value)
+void Statement::bind(sqlite3_int64 value) const
 {
-       int index = sqlite3_bind_parameter_index(m_stmt, name.c_str());
+       if (!isBindingIndexValid())
+               throw std::logic_error("index overflowed when binding int64 to stmt.");
 
-       if (SQLITE_OK != ::sqlite3_bind_text(m_stmt, index, value, -1,
-                                                                                SQLITE_TRANSIENT))
+       if (SQLITE_OK != ::sqlite3_bind_int64(m_stmt, ++m_bindingIndex, value))
                throw std::runtime_error(getErrorMessage());
 }
 
-void Statement::bind(const std::string &name, const void *value, int size)
+void Statement::bind(const std::string &value) const
 {
-       int index = sqlite3_bind_parameter_index(m_stmt, name.c_str());
+       if (!isBindingIndexValid())
+               throw std::logic_error("index overflowed when binding string to stmt.");
 
-       if (SQLITE_OK != ::sqlite3_bind_blob(m_stmt, index, value, size,
-                                                                                SQLITE_TRANSIENT))
+       if (SQLITE_OK != ::sqlite3_bind_text(m_stmt, ++m_bindingIndex, value.c_str(), -1,
+                                                                                SQLITE_STATIC))
                throw std::runtime_error(getErrorMessage());
 }
 
-void Statement::bind(const std::string &name)
+void Statement::bind(void) const
 {
-       int index = sqlite3_bind_parameter_index(m_stmt, name.c_str());
+       if (!isBindingIndexValid())
+               throw std::logic_error("index overflowed when binding fields from row");
 
-       if (SQLITE_OK != ::sqlite3_bind_null(m_stmt, index))
+       if (SQLITE_OK != ::sqlite3_bind_null(m_stmt, ++m_bindingIndex))
                throw std::runtime_error(getErrorMessage());
 }
 
-
-bool Statement::isColumnValid(int index) const noexcept
-{
-       return m_validRow && index < m_columnCount;
-}
-
-bool Statement::isNullColumn(int index) const
-{
-       if (!isColumnValid(index))
-               throw std::runtime_error(getErrorMessage(SQLITE_RANGE));
-
-       return SQLITE_NULL == sqlite3_column_type(m_stmt, index);
-}
-
-std::string Statement::getColumnName(int index) const
+bool Statement::isNullColumn() const
 {
-       if (!isColumnValid(index))
-               throw std::runtime_error(getErrorMessage(SQLITE_RANGE));
+       if (!isColumnIndexValid())
+               throw std::runtime_error(FORMAT("column isn't valud for index: " <<
+                                                                               m_columnIndex));
 
-       return sqlite3_column_name(m_stmt, index);
+       return SQLITE_NULL == sqlite3_column_type(m_stmt, (m_columnIndex + 1));
 }
 
-int Statement::getInt(int index) const
+int Statement::getInt() const
 {
-       return sqlite3_column_int(m_stmt, index);
-}
+       if (!isColumnIndexValid())
+               throw std::logic_error("index overflowed when getting fields from row");
 
-sqlite3_int64 Statement::getInt64(int index) const
-{
-       return sqlite3_column_int64(m_stmt, index);
+       return sqlite3_column_int(m_stmt, ++m_columnIndex);
 }
 
-double Statement::getDouble(int index) const
+sqlite3_int64 Statement::getInt64() const
 {
-       return sqlite3_column_double(m_stmt, index);
-}
+       if (!isColumnIndexValid())
+               throw std::logic_error("index overflowed when getting fields from row");
 
-const char *Statement::getText(int index) const
-{
-       return reinterpret_cast<const char *>(sqlite3_column_text(m_stmt, index));
+       return sqlite3_column_int64(m_stmt, ++m_columnIndex);
 }
 
-const void *Statement::getBlob(int index) const
+const char *Statement::getText() const
 {
-       return sqlite3_column_blob(m_stmt, index);
-}
+       if (!isColumnIndexValid())
+               throw std::logic_error("index overflowed when getting fields from row");
 
-int Statement::getType(int index) const
-{
-       return sqlite3_column_type(m_stmt, index);
-}
-
-int Statement::getBytes(int index) const
-{
-       return sqlite3_column_bytes(m_stmt, index);
+       return reinterpret_cast<const char *>(sqlite3_column_text(m_stmt, ++m_columnIndex));
 }
 
 } // namespace Db
index b5caf8d..df3e772 100644 (file)
@@ -16,7 +16,6 @@
 #pragma once
 
 #include <string>
-#include <map>
 
 #include <sqlite3.h>
 
@@ -27,52 +26,48 @@ namespace Db {
 
 class Statement {
 public:
-       Statement(const Connection &db, const std::string &query);
+       Statement() = delete;
+       explicit Statement(const Connection &db, const std::string &query);
        virtual ~Statement();
 
        int exec();
        bool step();
 
        void reset();
-       void clearBindings();
+       void clearBindings() const;
 
-       // bind values to query
-       void bind(int index, int value);
-       void bind(int index, sqlite3_int64 value);
-       void bind(int index, double value);
-       void bind(int index, const std::string &value);
-       void bind(int index, const char *value);
-       void bind(int index, const void *value, int size);
-       void bind(int index);
+       // bind values to query. index of column auto-incremented
+       void bind(int value) const;
+       void bind(sqlite3_int64 value) const;
+       void bind(double value) const;
+       void bind(const std::string &value) const;
+       void bind() const;
 
-       void bind(const std::string &name, int value);
-       void bind(const std::string &name, sqlite3_int64 value);
-       void bind(const std::string &name, double value);
-       void bind(const std::string &name, const std::string &value);
-       void bind(const std::string &name, const char *value);
-       void bind(const std::string &name, const void *value, int size);
-       void bind(const std::string &name);
+       // get column values. index of column auto-incremented
+       int getInt() const;
+       sqlite3_int64 getInt64() const;
+       const char *getText() const;
 
-       // get column values
-       std::string getColumnName(int index) const;
-       bool isNullColumn(int index) const;
-
-       int getInt(int index) const;
-       sqlite3_int64 getInt64(int index) const;
-       double getDouble(int index) const;
-       const char *getText(int index) const;
-       const void *getBlob(int index) const;
-       int getType(int index) const;
-       int getBytes(int index) const;
+       bool isNullColumn() const; // it's checking func. not auto incremented.
 
 private:
-       std::string getErrorMessage() const;
-       std::string getErrorMessage(int errorCode) const;
-       bool isColumnValid(int index) const noexcept;
+       inline bool isBindingIndexValid() const noexcept
+       {
+               return m_bindingIndex <= m_bindParamCount;
+       }
+
+       inline bool isColumnIndexValid() const noexcept
+       {
+               return m_columnIndex < m_columnCount;
+       }
+
+       std::string getErrorMessage() const noexcept;
 
        sqlite3_stmt *m_stmt;
+       int m_bindParamCount;
        int m_columnCount;
-       bool m_validRow;
+       mutable int m_bindingIndex;
+       mutable int m_columnIndex;
 };
 
 } // namespace Db
index 1875476..3610bc5 100644 (file)
@@ -19,7 +19,6 @@
  * @version     1.0
  * @brief       CSR Content screening DB internal test
  */
-
 #include "db/manager.h"
 
 #include <iostream>
 #define TEST_DB_FILE     TEST_DIR "/test.db"
 #define TEST_DB_SCRIPTS  RO_DBSPACE
 
+using namespace Csr;
+
 namespace {
 
-void checkSameMalware(Csr::Db::RowDetected &malware1,
-                                         Csr::Db::RowDetected &malware2)
+void checkSameMalware(const CsDetected &d, const Db::Row &r)
 {
-       ASSERT_IF(malware1.path,          malware2.path);
-       ASSERT_IF(malware1.dataVersion,   malware2.dataVersion);
-       ASSERT_IF(malware1.severityLevel, malware2.severityLevel);
-       ASSERT_IF(malware1.threatType,    malware2.threatType);
-       ASSERT_IF(malware1.name,          malware2.name);
-       ASSERT_IF(malware1.detailedUrl,   malware2.detailedUrl);
-       ASSERT_IF(malware1.detected_time, malware2.detected_time);
-       ASSERT_IF(malware1.modified_time, malware2.modified_time);
-       ASSERT_IF(malware1.ignored,       malware2.ignored);
+       ASSERT_IF(d.targetName,  r.targetName);
+       ASSERT_IF(d.severity,    r.severity);
+       ASSERT_IF(d.threat,      r.threat);
+       ASSERT_IF(d.malwareName, r.malwareName);
+       ASSERT_IF(d.detailedUrl, r.detailedUrl);
+       ASSERT_IF(d.ts,          r.ts);
 }
 
 } // namespace anonymous
@@ -57,7 +54,7 @@ BOOST_AUTO_TEST_CASE(schema_info)
 {
        EXCEPTION_GUARD_START
 
-       Csr::Db::Manager db(TEST_DB_FILE, TEST_DB_SCRIPTS);
+       Db::Manager db(TEST_DB_FILE, TEST_DB_SCRIPTS);
 
        ASSERT_IF(db.getSchemaVersion(), 1); // latest version is 1
 
@@ -68,7 +65,7 @@ BOOST_AUTO_TEST_CASE(engine_state)
 {
        EXCEPTION_GUARD_START
 
-       Csr::Db::Manager db(TEST_DB_FILE, TEST_DB_SCRIPTS);
+       Db::Manager db(TEST_DB_FILE, TEST_DB_SCRIPTS);
 
        ASSERT_IF(db.setEngineState(1, 1), true);
        ASSERT_IF(db.setEngineState(2, 2), true);
@@ -86,7 +83,7 @@ BOOST_AUTO_TEST_CASE(scan_time)
 {
        EXCEPTION_GUARD_START
 
-       Csr::Db::Manager db(TEST_DB_FILE, TEST_DB_SCRIPTS);
+       Db::Manager db(TEST_DB_FILE, TEST_DB_SCRIPTS);
 
        std::string dir = "/opt";
        long scantime = 100;
@@ -111,97 +108,88 @@ BOOST_AUTO_TEST_CASE(detected_malware_file)
 {
        EXCEPTION_GUARD_START
 
-       Csr::Db::Manager db(TEST_DB_FILE, TEST_DB_SCRIPTS);
+       Db::Manager db(TEST_DB_FILE, TEST_DB_SCRIPTS);
 
        std::string initDataVersion = "1.0.0";
        std::string changedDataVersion = "2.0.0";
 
        // insert
-       Csr::Db::RowDetected malware1;
-       malware1.path = "/opt/testmalware1";
-       malware1.dataVersion = initDataVersion;
-       malware1.severityLevel = 1;
-       malware1.threatType = 1;
-       malware1.name = "testmalware1";
+       CsDetected malware1;
+       malware1.targetName = "/opt/testmalware1";
+       malware1.severity = CSR_CS_SEVERITY_MEDIUM;
+       malware1.threat = CSR_CS_THREAT_MALWARE;
+       malware1.malwareName = "testmalware1";
        malware1.detailedUrl = "http://detailed.malware.com";
-       malware1.detected_time = 100;
-       malware1.modified_time = 100;
-       malware1.ignored = 1;
-
-       Csr::Db::RowDetected malware2;
-       malware2.path = "/opt/testmalware2";
-       malware2.dataVersion = initDataVersion;
-       malware2.severityLevel = 2;
-       malware2.threatType = 2;
-       malware2.name = "testmalware2";
+       malware1.ts = 100;
+
+       CsDetected malware2;
+       malware2.targetName = "/opt/testmalware2";
+       malware2.severity = CSR_CS_SEVERITY_HIGH;
+       malware2.threat = CSR_CS_THREAT_RISKY;
+       malware2.malwareName = "testmalware2";
        malware2.detailedUrl = "http://detailed2.malware.com";
-       malware2.detected_time = 210;
-       malware2.modified_time = 210;
-       malware2.ignored = 2;
-
-       Csr::Db::RowDetected malware3;
-       malware3.path = "/opt/testmalware3";
-       malware3.dataVersion = changedDataVersion;
-       malware3.severityLevel = 3;
-       malware3.threatType = 3;
-       malware3.name = "testmalware2";
-       malware3.detailedUrl = "http://detailed2.malware.com";
-       malware3.detected_time = 310;
-       malware3.modified_time = 310;
-       malware3.ignored = 3;
+       malware2.ts = 210;
+
+       CsDetected malware3;
+       malware3.targetName = "/opt/testmalware3";
+       malware3.severity = CSR_CS_SEVERITY_LOW;
+       malware3.threat = CSR_CS_THREAT_GENERIC;
+       malware3.malwareName = "testmalware3";
+       malware3.detailedUrl = "http://detailed3.malware.com";
+       malware3.ts = 310;
 
        // select test with vacant data
-       auto detected = db.getDetectedMalware(malware1.path);
+       auto detected = db.getDetectedMalware(malware1.targetName);
        CHECK_IS_NULL(detected);
 
        auto detectedList = db.getDetectedMalwares("/opt");
        ASSERT_IF(detectedList->empty(), true);
 
-       // insertDetectedMalware test
-       ASSERT_IF(db.insertDetectedMalware(malware1), true);
-       ASSERT_IF(db.insertDetectedMalware(malware2), true);
-
-       // getDetectedMalware test
-       detected = db.getDetectedMalware(malware1.path);
+       ASSERT_IF(db.insertDetectedMalware(malware1, initDataVersion, false), true);
+       detected = db.getDetectedMalware(malware1.targetName);
        checkSameMalware(malware1, *detected);
-       detected = db.getDetectedMalware(malware2.path);
+       ASSERT_IF(detected->dataVersion, initDataVersion);
+       ASSERT_IF(detected->isIgnored, false);
+
+       ASSERT_IF(db.insertDetectedMalware(malware2, initDataVersion, true), true);
+       detected = db.getDetectedMalware(malware2.targetName);
        checkSameMalware(malware2, *detected);
+       ASSERT_IF(detected->dataVersion, initDataVersion);
+       ASSERT_IF(detected->isIgnored, true);
 
        // getDetectedMalwares test
        detectedList = db.getDetectedMalwares("/opt");
        ASSERT_IF(detectedList->size(), static_cast<size_t>(2));
 
        for (auto &item : *detectedList) {
-               if (malware1.path == item->path)
+               if (malware1.targetName == item->targetName)
                        checkSameMalware(malware1, *item);
-               else if (malware2.path == item->path)
+               else if (malware2.targetName == item->targetName)
                        checkSameMalware(malware2, *item);
                else
                        BOOST_REQUIRE_MESSAGE(false, "Failed. getDetectedMalwares");
        }
 
        // setDetectedMalwareIgnored test
-       ASSERT_IF(db.setDetectedMalwareIgnored(malware1.path, 1), true);
-
-       malware1.ignored = 1;
-       detected = db.getDetectedMalware(malware1.path);
+       ASSERT_IF(db.setDetectedMalwareIgnored(malware1.targetName, true), true);
+       detected = db.getDetectedMalware(malware1.targetName);
        checkSameMalware(malware1, *detected);
+       ASSERT_IF(detected->isIgnored, true);
 
-       // deleteDeprecatedDetecedMalwares test
-       ASSERT_IF(db.insertDetectedMalware(malware3), true);
-
-       ASSERT_IF(db.deleteDeprecatedDetecedMalwares("/opt", changedDataVersion), true);
-
-       detected = db.getDetectedMalware(malware3.path);
+       // deleteDeprecatedDetectedMalwares test
+       ASSERT_IF(db.insertDetectedMalware(malware3, changedDataVersion, false), true);
+       ASSERT_IF(db.deleteDeprecatedDetectedMalwares("/opt", changedDataVersion), true);
+       detected = db.getDetectedMalware(malware3.targetName);
        checkSameMalware(malware3, *detected);
+       ASSERT_IF(detected->dataVersion, changedDataVersion);
+       ASSERT_IF(detected->isIgnored, false);
 
-       detected = db.getDetectedMalware(malware1.path);
-       CHECK_IS_NULL(detected);
-       detected = db.getDetectedMalware(malware2.path);
-       CHECK_IS_NULL(detected);
+       CHECK_IS_NULL(db.getDetectedMalware(malware1.targetName));
+       CHECK_IS_NULL(db.getDetectedMalware(malware2.targetName));
 
-       // deleteDetecedMalware test
-       ASSERT_IF(db.deleteDetecedMalware(malware3.path), true);
+       // deleteDetectedMalware test
+       ASSERT_IF(db.deleteDetectedMalware(malware3.targetName), true);
+       CHECK_IS_NULL(db.getDetectedMalware(malware3.targetName));
 
        EXCEPTION_GUARD_END
 }