Fix resetting prepared statement
[platform/core/security/security-manager.git] / src / common / privilege_db.cpp
index 6c8d1f3..0498f21 100644 (file)
@@ -60,6 +60,7 @@ PrivilegeDb::PrivilegeDb(const std::string &path)
         mSqlConnection = new DB::SqlConnection(path,
                 DB::SqlConnection::Flag::None,
                 DB::SqlConnection::Flag::RW);
+        initDataCommands();
     } catch (DB::SqlConnection::Exception::Base &e) {
         LogError("Database initialization error: " << e.DumpToString());
         ThrowMsg(PrivilegeDb::Exception::IOError,
@@ -68,11 +69,43 @@ PrivilegeDb::PrivilegeDb(const std::string &path)
     };
 }
 
+void PrivilegeDb::initDataCommands()
+{
+    for (auto &it : Queries) {
+        m_commands.push_back(mSqlConnection->PrepareDataCommand(it.second));
+    }
+}
+
+PrivilegeDb::StatementWrapper::StatementWrapper(DB::SqlConnection::DataCommandAutoPtr &ref)
+    : m_ref(ref) {}
+
+PrivilegeDb::StatementWrapper::~StatementWrapper()
+{
+    m_ref->Reset();
+}
+
+DB::SqlConnection::DataCommand* PrivilegeDb::StatementWrapper::operator->()
+{
+    return m_ref.get();
+}
+
+PrivilegeDb::StatementWrapper PrivilegeDb::getStatement(StmtType queryType)
+{
+    return StatementWrapper(m_commands.at(static_cast<size_t>(queryType)));
+}
+
 PrivilegeDb::~PrivilegeDb()
 {
+    m_commands.clear();
     delete mSqlConnection;
 }
 
+PrivilegeDb &PrivilegeDb::getInstance()
+{
+    static PrivilegeDb privilegeDb;
+    return privilegeDb;
+}
+
 void PrivilegeDb::BeginTransaction(void)
 {
     try_catch<void>([&] {
@@ -97,27 +130,17 @@ void PrivilegeDb::RollbackTransaction(void)
 bool PrivilegeDb::PkgIdExists(const std::string &pkgId)
 {
     return try_catch<bool>([&] {
-        DB::SqlConnection::DataCommandAutoPtr command =
-                mSqlConnection->PrepareDataCommand(
-                        Queries.at(QueryType::EPkgIdExists));
-        command->BindString(1, pkgId.c_str());
-        if (command->Step()) {
-            // pkgId found in the database
-            command->Reset();
-            return true;
-        };
-
-        // pkgId not found in the database
-        return false;
+        auto command = getStatement(StmtType::EPkgIdExists);
+        command->BindString(1, pkgId);
+        return command->Step();
     });
 }
 
 bool PrivilegeDb::GetAppPkgId(const std::string &appId, std::string &pkgId)
 {
     return try_catch<bool>([&] {
-        DB::SqlConnection::DataCommandAutoPtr command =
-            mSqlConnection->PrepareDataCommand(Queries.at(QueryType::EGetPkgId));
-        command->BindString(1, appId.c_str());
+        auto command = getStatement(StmtType::EGetPkgId);
+        command->BindString(1, appId);
 
         if (!command->Step()) {
             // No application with such appId
@@ -132,25 +155,19 @@ bool PrivilegeDb::GetAppPkgId(const std::string &appId, std::string &pkgId)
 }
 
 void PrivilegeDb::AddApplication(const std::string &appId,
-        const std::string &pkgId, uid_t uid, bool &pkgIdIsNew)
+        const std::string &pkgId, uid_t uid)
 {
-    pkgIdIsNew = !(this->PkgIdExists(pkgId));
-
     try_catch<void>([&] {
-        DB::SqlConnection::DataCommandAutoPtr command =
-                mSqlConnection->PrepareDataCommand(
-                        Queries.at(QueryType::EAddApplication));
-
-        command->BindString(1, appId.c_str());
-        command->BindString(2, pkgId.c_str());
+        auto command = getStatement(StmtType::EAddApplication);
+        command->BindString(1, appId);
+        command->BindString(2, pkgId);
         command->BindInteger(3, static_cast<unsigned int>(uid));
 
         if (command->Step()) {
             LogDebug("Unexpected SQLITE_ROW answer to query: " <<
-                    Queries.at(QueryType::EAddApplication));
+                    Queries.at(StmtType::EAddApplication));
         };
 
-        command->Reset();
         LogDebug("Added appId: " << appId << ", pkgId: " << pkgId);
     });
 }
@@ -165,19 +182,15 @@ void PrivilegeDb::RemoveApplication(const std::string &appId, uid_t uid,
             return;
         }
 
-        DB::SqlConnection::DataCommandAutoPtr command =
-                mSqlConnection->PrepareDataCommand(
-                        Queries.at(QueryType::ERemoveApplication));
-
-        command->BindString(1, appId.c_str());
+        auto command = getStatement(StmtType::ERemoveApplication);
+        command->BindString(1, appId);
         command->BindInteger(2, static_cast<unsigned int>(uid));
 
         if (command->Step()) {
             LogDebug("Unexpected SQLITE_ROW answer to query: " <<
-                    Queries.at(QueryType::ERemoveApplication));
+                    Queries.at(StmtType::ERemoveApplication));
         };
 
-        command->Reset();
         LogDebug("Removed appId: " << appId);
 
         pkgIdIsNoMore = !(this->PkgIdExists(pkgId));
@@ -188,10 +201,8 @@ void PrivilegeDb::GetPkgPrivileges(const std::string &pkgId, uid_t uid,
         std::vector<std::string> &currentPrivileges)
 {
     try_catch<void>([&] {
-        DB::SqlConnection::DataCommandAutoPtr command =
-                mSqlConnection->PrepareDataCommand(
-                        Queries.at(QueryType::EGetPkgPrivileges));
-        command->BindString(1, pkgId.c_str());
+        auto command = getStatement(StmtType::EGetPkgPrivileges);
+        command->BindString(1, pkgId);
         command->BindInteger(2, static_cast<unsigned int>(uid));
 
         while (command->Step()) {
@@ -202,17 +213,33 @@ void PrivilegeDb::GetPkgPrivileges(const std::string &pkgId, uid_t uid,
     });
 }
 
-void PrivilegeDb::RemoveAppPrivileges(const std::string &appId, uid_t uid)
+void PrivilegeDb::GetAppPrivileges(const std::string &appId, uid_t uid,
+        std::vector<std::string> &currentPrivileges)
 {
     try_catch<void>([&] {
-        DB::SqlConnection::DataCommandAutoPtr command =
-            mSqlConnection->PrepareDataCommand(Queries.at(QueryType::ERemoveAppPrivileges));
+        auto command = getStatement(StmtType::EGetAppPrivileges);
+
+        command->BindString(1, appId);
+        command->BindInteger(2, static_cast<unsigned int>(uid));
+        currentPrivileges.clear();
+
+        while (command->Step()) {
+            std::string privilege = command->GetColumnString(0);
+            LogDebug("Got privilege: " << privilege);
+            currentPrivileges.push_back(privilege);
+        };
+    });
+}
 
-        command->BindString(1, appId.c_str());
+void PrivilegeDb::RemoveAppPrivileges(const std::string &appId, uid_t uid)
+{
+    try_catch<void>([&] {
+        auto command = getStatement(StmtType::ERemoveAppPrivileges);
+        command->BindString(1, appId);
         command->BindInteger(2, static_cast<unsigned int>(uid));
         if (command->Step()) {
             LogDebug("Unexpected SQLITE_ROW answer to query: " <<
-                    Queries.at(QueryType::ERemoveAppPrivileges));
+                    Queries.at(StmtType::ERemoveAppPrivileges));
         }
 
         LogDebug("Removed all privileges for appId: " << appId);
@@ -223,15 +250,14 @@ void PrivilegeDb::UpdateAppPrivileges(const std::string &appId, uid_t uid,
         const std::vector<std::string> &privileges)
 {
     try_catch<void>([&] {
-        DB::SqlConnection::DataCommandAutoPtr command =
-            mSqlConnection->PrepareDataCommand(Queries.at(QueryType::EAddAppPrivileges));
-        command->BindString(1, appId.c_str());
+        auto command = getStatement(StmtType::EAddAppPrivileges);
+        command->BindString(1, appId);
         command->BindInteger(2, static_cast<unsigned int>(uid));
 
         RemoveAppPrivileges(appId, uid);
 
         for (const auto &privilege : privileges) {
-            command->BindString(3, privilege.c_str());
+            command->BindString(3, privilege);
             command->Step();
             command->Reset();
             LogDebug("Added privilege: " << privilege << " to appId: " << appId);
@@ -243,10 +269,8 @@ void PrivilegeDb::GetPrivilegeGroups(const std::string &privilege,
         std::vector<std::string> &groups)
 {
    try_catch<void>([&] {
-        DB::SqlConnection::DataCommandAutoPtr command =
-                mSqlConnection->PrepareDataCommand(
-                        Queries.at(QueryType::EGetPrivilegeGroups));
-        command->BindString(1, privilege.c_str());
+        auto command = getStatement(StmtType::EGetPrivilegeGroups);
+        command->BindString(1, privilege);
 
         while (command->Step()) {
             std::string groupName = command->GetColumnString(0);
@@ -256,5 +280,110 @@ void PrivilegeDb::GetPrivilegeGroups(const std::string &privilege,
     });
 }
 
+void PrivilegeDb::GetUserApps(uid_t uid, std::vector<std::string> &apps)
+{
+   try_catch<void>([&] {
+        auto command = getStatement(StmtType::EGetUserApps);
+        command->BindInteger(1, static_cast<unsigned int>(uid));
+        apps.clear();
+        while (command->Step()) {
+            std::string app = command->GetColumnString(0);
+            LogDebug("User " << uid << " has app " << app << " installed");
+            apps.push_back(app);
+        };
+    });
+}
+
+void PrivilegeDb::GetAppIdsForPkgId(const std::string &pkgId,
+        std::vector<std::string> &appIds)
+{
+    try_catch<void>([&] {
+        auto command = getStatement(StmtType::EGetAppsInPkg);
+
+        command->BindString(1, pkgId);
+        appIds.clear();
+
+        while (command->Step()) {
+            std::string appId = command->GetColumnString (0);
+            LogDebug ("Got appid: " << appId << " for pkgId " << pkgId);
+            appIds.push_back(appId);
+        };
+    });
+}
+
+void PrivilegeDb::GetDefaultMapping(const std::string &version_from,
+                                    const std::string &version_to,
+                                    std::vector<std::string> &mappings)
+{
+    try_catch<void>([&] {
+        auto command = getStatement(StmtType::EGetDefaultMappings);
+        command->BindString(1, version_from);
+        command->BindString(2, version_to);
+
+        mappings.clear();
+        while (command->Step()) {
+            std::string mapping = command->GetColumnString(0);
+            LogDebug("Default Privilege from version " << version_from
+                    <<" to version " << version_to << " is " << mapping);
+            mappings.push_back(mapping);
+        }
+    });
+}
+
+void PrivilegeDb::GetPrivilegeMappings(const std::string &version_from,
+                                       const std::string &version_to,
+                                       const std::string &privilege,
+                                       std::vector<std::string> &mappings)
+{
+    try_catch<void>([&] {
+        auto command = getStatement(StmtType::EGetPrivilegeMappings);
+        command->BindString(1, version_from);
+        command->BindString(2, version_to);
+        command->BindString(3, privilege);
+
+        mappings.clear();
+        while (command->Step()) {
+            std::string mapping = command->GetColumnString(0);
+            LogDebug("Privilege " << privilege << " in version " << version_from
+                    <<" has mapping " << mapping << " in version " << version_to);
+            mappings.push_back(mapping);
+        }
+    });
+}
+
+void PrivilegeDb::GetPrivilegesMappings(const std::string &version_from,
+                                        const std::string &version_to,
+                                        const std::vector<std::string> &privileges,
+                                        std::vector<std::string> &mappings)
+{
+    try_catch<void>([&] {
+        auto deleteCmd = getStatement(StmtType::EDeletePrivilegesToMap);
+        deleteCmd->Step();
+
+        auto insertCmd = getStatement(StmtType::EInsertPrivilegeToMap);
+        for (auto &privilege : privileges) {
+            if (privilege.empty())
+                continue;
+            insertCmd->BindString(1, privilege);
+            insertCmd->Step();
+            insertCmd->Reset();
+        }
+
+        insertCmd->BindNull(1);
+        insertCmd->Step();
+
+        auto queryCmd = getStatement(StmtType::EGetPrivilegesMappings);
+        queryCmd->BindString(1, version_from);
+        queryCmd->BindString(2, version_to);
+
+        mappings.clear();
+        while (queryCmd->Step()) {
+            std::string mapping = queryCmd->GetColumnString(0);
+            LogDebug("Privilege set  in version " << version_from
+                     <<" has mapping " << mapping << " in version " << version_to);
+             mappings.push_back(mapping);
+        }
+    });
+}
 
 } //namespace SecurityManager