[Database] Prevent invalid memory access 08/300408/3
authorChanggyu Choi <changyu.choi@samsung.com>
Tue, 24 Oct 2023 05:58:48 +0000 (14:58 +0900)
committerChanggyu Choi <changyu.choi@samsung.com>
Mon, 30 Oct 2023 00:32:56 +0000 (09:32 +0900)
If the TransactionGuard is destroy later than the Database,
there is a possibility of accessing the invalid pointer.

Change-Id: I6570e78b777da02672db2bf83e26dba2fb1af990
Signed-off-by: Changgyu Choi <changyu.choi@samsung.com>
tests/tizen-database_unittests/src/test_database.cc
tizen-database/database.hpp

index 395780d..669bd87 100644 (file)
@@ -41,6 +41,15 @@ constexpr char Q_SELECT[] = "SELECT * FROM TestTable;";
 constexpr char Q_UPDATE[] = "UPDATE TestTable SET name=?;";
 constexpr char Q_DELETE[] = "DELETE FROM TestTable WHERE name=?;";
 
+class MoveDatabaseTestClass {
+ public:
+  MoveDatabaseTestClass(tizen_base::Database db) : db_(std::move(db)) {}
+  ~MoveDatabaseTestClass() = default;
+
+ private:
+  tizen_base::Database db_;
+};
+
 }  // namespace
 
 TEST(DBBasicTest, create_table) {
@@ -603,3 +612,10 @@ TEST_F(DatabaseTest, test_get_string) {
     EXPECT_EQ(*val, "1234");
   }
 }
+
+TEST_F(DatabaseTest, test_move_database) {
+  SetDefault();
+  tizen_base::Database db(TEST_DB, SQLITE_OPEN_READWRITE);
+  auto gaurd = db.CreateTransactionGuard();
+  MoveDatabaseTestClass test(std::move(db));
+}
index 61c4d5c..b8e9b6c 100644 (file)
@@ -201,45 +201,48 @@ class Database {
     TransactionGuard& operator = (const TransactionGuard&) = delete;
 
     TransactionGuard(TransactionGuard&& t) noexcept {
-      db_ = t.db_;
-      t.db_ = nullptr;
+      db_ = std::move(t.db_);
     }
 
     TransactionGuard& operator = (TransactionGuard&& t) noexcept {
       if (this != &t) {
-        if (db_)
-          sqlite3_exec(db_, "ROLLBACK", nullptr, nullptr, nullptr);
-        db_ = t.db_;
-        t.db_ = nullptr;
+        if (!db_.expired())
+          sqlite3_exec(db_.lock().get(), "ROLLBACK", nullptr, nullptr, nullptr);
+        db_ = std::move(t.db_);
       }
 
       return *this;
     }
 
-    explicit TransactionGuard(sqlite3* db) : db_(db) {
-      int r = sqlite3_exec(db, "BEGIN DEFERRED", nullptr, nullptr, nullptr);
+    explicit TransactionGuard(const std::shared_ptr<sqlite3>& db)
+        : db_(db) {
+      int r = sqlite3_exec(db_.lock().get(), "BEGIN DEFERRED",
+          nullptr, nullptr, nullptr);
       if (r != SQLITE_OK) {
         throw DbException("begin transaction failed", r);
       }
     }
 
     ~TransactionGuard() {
-      if (db_)
-        sqlite3_exec(db_, "ROLLBACK", nullptr, nullptr, nullptr);
+      if (!db_.expired())
+        sqlite3_exec(db_.lock().get(), "ROLLBACK", nullptr, nullptr, nullptr);
     }
 
     int Commit() {
-      int ret = sqlite3_exec(db_, "COMMIT", nullptr, nullptr, nullptr);
+      if (db_.expired())
+        return SQLITE_OK;
+
+      auto db = db_.lock();
+      int ret = sqlite3_exec(db.get(), "COMMIT", nullptr, nullptr, nullptr);
       if (ret != SQLITE_OK) {
-        sqlite3_exec(db_, "ROLLBACK", nullptr, nullptr, nullptr);
+        sqlite3_exec(db.get(), "ROLLBACK", nullptr, nullptr, nullptr);
       }
 
-      db_ = nullptr;
       return ret;
     }
 
    private:
-    sqlite3* db_ = nullptr;
+    std::weak_ptr<sqlite3> db_;
   };
 
   class Sql {
@@ -565,15 +568,17 @@ class Database {
     }
 
     explicit operator int() const {
-      if (db_ == nullptr)
+      auto db = db_.lock();
+      if (db == nullptr)
         return SQLITE_ERROR;
-      return sqlite3_errcode(db_);
+      return sqlite3_errcode(db.get());
     }
 
     explicit operator const char*() const {
-      if (db_ == nullptr)
+      auto db = db_.lock();
+      if (db == nullptr)
         return "";
-      return sqlite3_errmsg(db_);
+      return sqlite3_errmsg(db.get());
     }
 
     template <class T>
@@ -647,56 +652,56 @@ class Database {
 
    private:
     friend class Database;
-    Result(sqlite3_stmt* stmt, sqlite3* db, std::string query, bool is_done)
+    Result(sqlite3_stmt* stmt, const std::shared_ptr<sqlite3>& db, std::string query, bool is_done)
         : stmt_(stmt), db_(db), query_(std::move(query)), is_done_(is_done) {}
 
     sqlite3_stmt* stmt_ = nullptr;
-    sqlite3* db_ = nullptr;
+    std::weak_ptr<sqlite3> db_;
     std::string query_;
     bool is_done_ = false;
   };
 
-  Database(std::string db, int flags) {
-    int r = sqlite3_open_v2(db.c_str(), &db_, flags, nullptr);
+  Database(std::string db_path, int flags) {
+    sqlite3* raw_db = nullptr;
+    int r = sqlite3_open_v2(db_path.c_str(), &raw_db, flags, nullptr);
     if (r != SQLITE_OK)
       throw DbException("open failed", r);
+
+    db_.reset(raw_db, sqlite3_close_v2);
   }
 
-  Database(std::string db, int flags, std::function<bool(int)> busy_handler) {
-    int r = sqlite3_open_v2(db.c_str(), &db_, flags, nullptr);
+  Database(std::string db_path, int flags,
+      std::function<bool(int)> busy_handler) {
+    sqlite3* raw_db = nullptr;
+    int r = sqlite3_open_v2(db_path.c_str(), &raw_db, flags, nullptr);
     if (r != SQLITE_OK)
       throw DbException("sqlite3_open_v2() failed", r);
 
+    db_.reset(raw_db, sqlite3_close_v2);
     busy_handler_ = std::move(busy_handler);
-    r = sqlite3_busy_handler(db_, [](void* data, int count) {
+    r = sqlite3_busy_handler(db_.get(), [](void* data, int count) {
       Database* pDb = static_cast<Database*>(data);
       if (pDb->busy_handler_ && pDb->busy_handler_(count))
         return 1;
       return 0;
     }, this);
 
-    if (r != SQLITE_OK) {
-      sqlite3_close_v2(db_);
+    if (r != SQLITE_OK)
       throw DbException("sqlite3_busy_handler() failed", r);
-    }
-  }
-
-  ~Database() {
-    if (db_)
-      sqlite3_close_v2(db_);
   }
 
+  ~Database() = default;
   Database() = default;
   Database(const Database&) = delete;
   Database& operator = (const Database&) = delete;
 
   Database(Database&& db) noexcept {
-    db_ = db.db_;
+    db_ = std::move(db.db_);
     busy_handler_ = std::move(db.busy_handler_);
     db.db_ = nullptr;
     db.busy_handler_ = nullptr;
 
-    sqlite3_busy_handler(db_, [](void* data, int count) {
+    sqlite3_busy_handler(db_.get(), [](void* data, int count) {
       Database* pDb = static_cast<Database*>(data);
       if (pDb->busy_handler_ && pDb->busy_handler_(count))
         return 1;
@@ -712,13 +717,11 @@ class Database {
 
   Database& operator = (Database&& db) noexcept {
     if (this != &db) {
-      if (db_)
-        sqlite3_close_v2(db_);
-      db_ = db.db_;
+      db_ = std::move(db.db_);
       busy_handler_ = std::move(db.busy_handler_);
       db.db_ = nullptr;
       db.busy_handler_ = nullptr;
-      sqlite3_busy_handler(db_, [](void* data, int count) {
+      sqlite3_busy_handler(db_.get(), [](void* data, int count) {
         Database* pDb = static_cast<Database*>(data);
         if (pDb->busy_handler_ && pDb->busy_handler_(count))
           return 1;
@@ -738,7 +741,7 @@ class Database {
       throw DbException("Not opened");
 
     sqlite3_stmt* stmt = nullptr;
-    int r = sqlite3_prepare_v2(db_, sql.GetQuery().c_str(),
+    int r = sqlite3_prepare_v2(db_.get(), sql.GetQuery().c_str(),
         -1, &stmt, nullptr);
     if (r != SQLITE_OK)
       return { nullptr, nullptr, "", true };
@@ -750,7 +753,7 @@ class Database {
       throw DbException("Not opened");
 
     sqlite3_stmt* stmt = nullptr;
-    int r = sqlite3_prepare_v2(db_, sql.GetQuery().c_str(),
+    int r = sqlite3_prepare_v2(db_.get(), sql.GetQuery().c_str(),
         -1, &stmt, nullptr);
     if (r != SQLITE_OK) {
       return { nullptr, nullptr, "", true };
@@ -821,7 +824,7 @@ class Database {
 
   void OneStepExec(const Sql& sql) const {
     char* errmsg = nullptr;
-    int ret = sqlite3_exec(db_, sql.GetQuery().c_str(), nullptr, nullptr,
+    int ret = sqlite3_exec(db_.get(), sql.GetQuery().c_str(), nullptr, nullptr,
         &errmsg);
     if (ret != SQLITE_OK) {
       std::unique_ptr<char, decltype(sqlite3_free)*> errmsg_auto(
@@ -833,7 +836,7 @@ class Database {
   sqlite3* GetRaw() const {
     if (!db_)
       throw DbException("Not opened");
-    return db_;
+    return db_.get();
   }
 
  private:
@@ -869,7 +872,7 @@ class Database {
   }
 
  private:
-  sqlite3* db_ = nullptr;
+  std::shared_ptr<sqlite3> db_ = nullptr;
   std::function<bool(int)> busy_handler_;
 };