From 26ea9aaf27314c7136ed4e3a11c30c0ca110d80b Mon Sep 17 00:00:00 2001 From: Changgyu Choi Date: Tue, 24 Oct 2023 14:58:48 +0900 Subject: [PATCH] [Database] Prevent invalid memory access If the TransactionGuard is destroy later than the Database, there is a possibility of accessing the invalid pointer. Change-Id: I6570e78b777da02672db2bf83e26dba2fb1af990 Signed-off-by: Changgyu Choi --- .../tizen-database_unittests/src/test_database.cc | 16 ++++ tizen-database/database.hpp | 91 +++++++++++----------- 2 files changed, 63 insertions(+), 44 deletions(-) diff --git a/tests/tizen-database_unittests/src/test_database.cc b/tests/tizen-database_unittests/src/test_database.cc index 395780d..669bd87 100644 --- a/tests/tizen-database_unittests/src/test_database.cc +++ b/tests/tizen-database_unittests/src/test_database.cc @@ -41,6 +41,15 @@ constexpr char Q_SELECT[] = "SELECT * FROM TestTable;"; constexpr char Q_UPDATE[] = "UPDATE TestTable SET name=?;"; constexpr char Q_DELETE[] = "DELETE FROM TestTable WHERE name=?;"; +class MoveDatabaseTestClass { + public: + MoveDatabaseTestClass(tizen_base::Database db) : db_(std::move(db)) {} + ~MoveDatabaseTestClass() = default; + + private: + tizen_base::Database db_; +}; + } // namespace TEST(DBBasicTest, create_table) { @@ -603,3 +612,10 @@ TEST_F(DatabaseTest, test_get_string) { EXPECT_EQ(*val, "1234"); } } + +TEST_F(DatabaseTest, test_move_database) { + SetDefault(); + tizen_base::Database db(TEST_DB, SQLITE_OPEN_READWRITE); + auto gaurd = db.CreateTransactionGuard(); + MoveDatabaseTestClass test(std::move(db)); +} diff --git a/tizen-database/database.hpp b/tizen-database/database.hpp index 61c4d5c..b8e9b6c 100644 --- a/tizen-database/database.hpp +++ b/tizen-database/database.hpp @@ -201,45 +201,48 @@ class Database { TransactionGuard& operator = (const TransactionGuard&) = delete; TransactionGuard(TransactionGuard&& t) noexcept { - db_ = t.db_; - t.db_ = nullptr; + db_ = std::move(t.db_); } TransactionGuard& operator = (TransactionGuard&& t) noexcept { if (this != &t) { - if (db_) - sqlite3_exec(db_, "ROLLBACK", nullptr, nullptr, nullptr); - db_ = t.db_; - t.db_ = nullptr; + if (!db_.expired()) + sqlite3_exec(db_.lock().get(), "ROLLBACK", nullptr, nullptr, nullptr); + db_ = std::move(t.db_); } return *this; } - explicit TransactionGuard(sqlite3* db) : db_(db) { - int r = sqlite3_exec(db, "BEGIN DEFERRED", nullptr, nullptr, nullptr); + explicit TransactionGuard(const std::shared_ptr& db) + : db_(db) { + int r = sqlite3_exec(db_.lock().get(), "BEGIN DEFERRED", + nullptr, nullptr, nullptr); if (r != SQLITE_OK) { throw DbException("begin transaction failed", r); } } ~TransactionGuard() { - if (db_) - sqlite3_exec(db_, "ROLLBACK", nullptr, nullptr, nullptr); + if (!db_.expired()) + sqlite3_exec(db_.lock().get(), "ROLLBACK", nullptr, nullptr, nullptr); } int Commit() { - int ret = sqlite3_exec(db_, "COMMIT", nullptr, nullptr, nullptr); + if (db_.expired()) + return SQLITE_OK; + + auto db = db_.lock(); + int ret = sqlite3_exec(db.get(), "COMMIT", nullptr, nullptr, nullptr); if (ret != SQLITE_OK) { - sqlite3_exec(db_, "ROLLBACK", nullptr, nullptr, nullptr); + sqlite3_exec(db.get(), "ROLLBACK", nullptr, nullptr, nullptr); } - db_ = nullptr; return ret; } private: - sqlite3* db_ = nullptr; + std::weak_ptr db_; }; class Sql { @@ -565,15 +568,17 @@ class Database { } explicit operator int() const { - if (db_ == nullptr) + auto db = db_.lock(); + if (db == nullptr) return SQLITE_ERROR; - return sqlite3_errcode(db_); + return sqlite3_errcode(db.get()); } explicit operator const char*() const { - if (db_ == nullptr) + auto db = db_.lock(); + if (db == nullptr) return ""; - return sqlite3_errmsg(db_); + return sqlite3_errmsg(db.get()); } template @@ -647,56 +652,56 @@ class Database { private: friend class Database; - Result(sqlite3_stmt* stmt, sqlite3* db, std::string query, bool is_done) + Result(sqlite3_stmt* stmt, const std::shared_ptr& db, std::string query, bool is_done) : stmt_(stmt), db_(db), query_(std::move(query)), is_done_(is_done) {} sqlite3_stmt* stmt_ = nullptr; - sqlite3* db_ = nullptr; + std::weak_ptr db_; std::string query_; bool is_done_ = false; }; - Database(std::string db, int flags) { - int r = sqlite3_open_v2(db.c_str(), &db_, flags, nullptr); + Database(std::string db_path, int flags) { + sqlite3* raw_db = nullptr; + int r = sqlite3_open_v2(db_path.c_str(), &raw_db, flags, nullptr); if (r != SQLITE_OK) throw DbException("open failed", r); + + db_.reset(raw_db, sqlite3_close_v2); } - Database(std::string db, int flags, std::function busy_handler) { - int r = sqlite3_open_v2(db.c_str(), &db_, flags, nullptr); + Database(std::string db_path, int flags, + std::function busy_handler) { + sqlite3* raw_db = nullptr; + int r = sqlite3_open_v2(db_path.c_str(), &raw_db, flags, nullptr); if (r != SQLITE_OK) throw DbException("sqlite3_open_v2() failed", r); + db_.reset(raw_db, sqlite3_close_v2); busy_handler_ = std::move(busy_handler); - r = sqlite3_busy_handler(db_, [](void* data, int count) { + r = sqlite3_busy_handler(db_.get(), [](void* data, int count) { Database* pDb = static_cast(data); if (pDb->busy_handler_ && pDb->busy_handler_(count)) return 1; return 0; }, this); - if (r != SQLITE_OK) { - sqlite3_close_v2(db_); + if (r != SQLITE_OK) throw DbException("sqlite3_busy_handler() failed", r); - } - } - - ~Database() { - if (db_) - sqlite3_close_v2(db_); } + ~Database() = default; Database() = default; Database(const Database&) = delete; Database& operator = (const Database&) = delete; Database(Database&& db) noexcept { - db_ = db.db_; + db_ = std::move(db.db_); busy_handler_ = std::move(db.busy_handler_); db.db_ = nullptr; db.busy_handler_ = nullptr; - sqlite3_busy_handler(db_, [](void* data, int count) { + sqlite3_busy_handler(db_.get(), [](void* data, int count) { Database* pDb = static_cast(data); if (pDb->busy_handler_ && pDb->busy_handler_(count)) return 1; @@ -712,13 +717,11 @@ class Database { Database& operator = (Database&& db) noexcept { if (this != &db) { - if (db_) - sqlite3_close_v2(db_); - db_ = db.db_; + db_ = std::move(db.db_); busy_handler_ = std::move(db.busy_handler_); db.db_ = nullptr; db.busy_handler_ = nullptr; - sqlite3_busy_handler(db_, [](void* data, int count) { + sqlite3_busy_handler(db_.get(), [](void* data, int count) { Database* pDb = static_cast(data); if (pDb->busy_handler_ && pDb->busy_handler_(count)) return 1; @@ -738,7 +741,7 @@ class Database { throw DbException("Not opened"); sqlite3_stmt* stmt = nullptr; - int r = sqlite3_prepare_v2(db_, sql.GetQuery().c_str(), + int r = sqlite3_prepare_v2(db_.get(), sql.GetQuery().c_str(), -1, &stmt, nullptr); if (r != SQLITE_OK) return { nullptr, nullptr, "", true }; @@ -750,7 +753,7 @@ class Database { throw DbException("Not opened"); sqlite3_stmt* stmt = nullptr; - int r = sqlite3_prepare_v2(db_, sql.GetQuery().c_str(), + int r = sqlite3_prepare_v2(db_.get(), sql.GetQuery().c_str(), -1, &stmt, nullptr); if (r != SQLITE_OK) { return { nullptr, nullptr, "", true }; @@ -821,7 +824,7 @@ class Database { void OneStepExec(const Sql& sql) const { char* errmsg = nullptr; - int ret = sqlite3_exec(db_, sql.GetQuery().c_str(), nullptr, nullptr, + int ret = sqlite3_exec(db_.get(), sql.GetQuery().c_str(), nullptr, nullptr, &errmsg); if (ret != SQLITE_OK) { std::unique_ptr errmsg_auto( @@ -833,7 +836,7 @@ class Database { sqlite3* GetRaw() const { if (!db_) throw DbException("Not opened"); - return db_; + return db_.get(); } private: @@ -869,7 +872,7 @@ class Database { } private: - sqlite3* db_ = nullptr; + std::shared_ptr db_ = nullptr; std::function busy_handler_; }; -- 2.7.4