Redo sockets 38/96938/8
authorZofia Abramowska <z.abramowska@samsung.com>
Thu, 10 Nov 2016 09:32:56 +0000 (10:32 +0100)
committerBartlomiej Grzelewski <b.grzelewski@samsung.com>
Wed, 30 Nov 2016 13:42:15 +0000 (14:42 +0100)
Add Socket class, remove SelectRead add Poll class

Change-Id: Ia28b49825808554874af3a033d9b3e727f848659

15 files changed:
src/agent/main/NotificationTalker.cpp
src/agent/main/NotificationTalker.h
src/agent/notification-daemon/AskUserTalker.cpp
src/agent/notification-daemon/AskUserTalker.h
src/common/CMakeLists.txt
src/common/log/alog.h
src/common/socket/Poll.cpp [new file with mode: 0644]
src/common/socket/Poll.h [moved from src/common/socket/SelectRead.h with 58% similarity]
src/common/socket/SelectRead.cpp [deleted file]
src/common/socket/Socket.cpp
src/common/socket/Socket.h
src/common/types/NotificationRequest.h
src/common/types/Protocol.h
test/CMakeLists.txt
test/daemon/notificationTalker.cpp

index d5b2e96..ce7d640 100644 (file)
 #include <exception/ErrnoException.h>
 #include <exception/CynaraException.h>
 #include <log/alog.h>
+#include <socket/Poll.h>
 #include <socket/Socket.h>
 #include <translator/Translator.h>
 #include <config/Path.h>
 #include <types/Protocol.h>
 
+namespace {
+
+std::string getUserFromSocket(const AskUser::Socket &sock) {
+    char *user = nullptr;
+
+    int ret = cynara_creds_socket_get_user(sock.getFd(), USER_METHOD_DEFAULT,&user);
+    if (ret != CYNARA_API_SUCCESS) {
+        throw AskUser::CynaraException("cynara_creds_socket_get_user", ret);
+    }
+
+    std::unique_ptr<char, decltype(free) *> userPtr(user, ::free);
+    return user;
+}
+
+}
+
 namespace AskUser {
 
 namespace Agent {
 
-NotificationTalker::NotificationTalker() : m_failed(true), m_stopflag(false)
+NotificationTalker::NotificationTalker()
+    : m_poll(10),
+      m_serverSocket(Socket::PeerType::SERVER),
+      m_failed(true),
+      m_stopflag(false)
 {
     try {
-        m_sockfd = Socket::listen(Path::getSocketPath());
+        m_serverSocket.bindAndListen(Path::getSocketPath());
         m_thread = std::thread(&NotificationTalker::run, this);
         m_failed = false;
     } catch (const Exception &e) {
@@ -111,8 +132,8 @@ void NotificationTalker::removeRequest(RequestId id)
         }
         if (it == queue.begin()) {
             auto user = std::get<0>(pair);
-            auto it2 = m_userToFd.find(user);
-            if (it2 != m_userToFd.end())
+            auto it2 = m_userToPeer.find(user);
+            if (it2 != m_userToPeer.end())
                 sendDismiss(std::get<1>(*it2));
         }
 
@@ -127,17 +148,11 @@ void NotificationTalker::stop()
 
 void NotificationTalker::clear()
 {
-    for (auto& pair : m_fdStatus) {
-        int fd = std::get<0>(pair);
-        Socket::close(fd);
-    }
-
     m_fdStatus.clear();
     m_fdToUser.clear();
-    m_userToFd.clear();
+    m_userToPeer.clear();
 
-    Socket::close(m_sockfd);
-    m_sockfd = 0;
+    m_serverSocket.close();
 }
 
 NotificationTalker::~NotificationTalker()
@@ -147,30 +162,25 @@ NotificationTalker::~NotificationTalker()
     m_thread.join();
 }
 
-void NotificationTalker::sendRequest(int fd, const NotificationRequest &request)
+void NotificationTalker::sendRequest(Socket &peerSocket, const NotificationRequest &request)
 {
+    int fd = peerSocket.getFd();
     m_fdStatus[fd] = false;
 
     std::string data = Translator::Gui::notificationRequestToData(request.id,
                                                                   request.data.client,
                                                                   request.data.privilege);
-    auto size = data.size();
-
-    if (!Socket::send(fd, &size, sizeof(size))) {
-        remove(fd);
-        return;
-    }
-
-    if (!Socket::send(fd, data.c_str(), size)) {
+    if (!peerSocket.send(data)) {
         remove(fd);
         return;
     }
 }
 
-void NotificationTalker::sendDismiss(int fd)
+void NotificationTalker::sendDismiss(Socket &peerSocket)
 {
+    int fd = peerSocket.getFd();
     if (!m_fdStatus[fd]) {
-        if (!Socket::send(fd, &Protocol::dissmisCode, sizeof(Protocol::dissmisCode))) {
+        if (peerSocket.send(Protocol::dissmisCode)) {
             remove(fd);
             return;
         }
@@ -178,8 +188,9 @@ void NotificationTalker::sendDismiss(int fd)
     }
 }
 
-void NotificationTalker::parseResponse(NotificationResponse response, int fd)
+void NotificationTalker::parseResponse(NotificationResponse response, Socket &peerSocket)
 {
+    int fd = peerSocket.getFd();
     auto &queue = m_requests[m_fdToUser[fd]];
     if (queue.empty()) {
         ALOGD("Request canceled");
@@ -202,7 +213,7 @@ void NotificationTalker::parseResponse(NotificationResponse response, int fd)
 
     m_responseHandler(response);
 
-    if (!Socket::send(fd, &Protocol::ackCode, sizeof(Protocol::ackCode))) {
+    if (peerSocket.send(Protocol::ackCode)) {
         remove(fd);
         return;
     }
@@ -210,53 +221,67 @@ void NotificationTalker::parseResponse(NotificationResponse response, int fd)
     m_fdStatus[fd] = true;
 }
 
+bool NotificationTalker::recvResponse(Socket &peerSocket, NotificationResponse &response)
+{
+    int requestId, responseType;
+    if (!peerSocket.recv(requestId) || !peerSocket.recv(responseType)) {
+        ALOGE("Failed to fetch response");
+        return false;
+    }
+    response.id = static_cast<RequestId>(requestId);
+    response.response = static_cast<NResponseType>(responseType);
+    return true;
+}
+
 void NotificationTalker::recvResponses(int &rv)
 {
-    for (auto pair : m_userToFd) {
+    std::vector<std::string> usersToDelete;
+    for (auto &userPeer : m_userToPeer) {
         if (!rv) break;
-        int fd = std::get<1>(pair);
-
-        if (m_select.isSet(fd)) {
+        auto &peerSocket = userPeer.second;
+        int fd = peerSocket.getFd();
+        if (m_poll.getEvents(fd) | POLLIN) {
             --rv;
 
             NotificationResponse response;
-            if (Socket::recv(fd, &response, sizeof(response))) {
-                parseResponse(response, fd);
+            if (recvResponse(peerSocket, response)) {
+                parseResponse(response, peerSocket);
             } else {
-                remove(fd);
+                m_fdToUser.erase(fd);
+                m_fdStatus.erase(fd);
+                usersToDelete.push_back(userPeer.first);
             }
         }
     }
+
+    for (auto &user : usersToDelete) {
+        m_userToPeer.erase(user);
+    }
 }
 
 void NotificationTalker::newConnection(int &rv)
 {
-    if (m_select.isSet(m_sockfd)) {
+    if (m_poll.getEvents(m_serverSocket.getFd()) | POLLIN) {
         --rv;
-        int fd = Socket::accept(m_sockfd);
-        try {
-            char *user_c = nullptr;
+        Socket peerSocket = m_serverSocket.accept();
+        int peerFd = peerSocket.getFd();
 
-            int ret = cynara_creds_socket_get_user(fd, USER_METHOD_DEFAULT,&user_c);
-
-            std::unique_ptr<char[]> userPtr(user_c);
+        try {
+            std::string user = getUserFromSocket(peerSocket);
 
-            if (ret != CYNARA_API_SUCCESS) {
-                throw CynaraException("cynara_creds_socket_get_user", ret);
+            auto it = m_userToPeer.find(user);
+            // Same user connected second time
+            if (it != m_userToPeer.end()) {
+                remove(it->second.getFd());
             }
-            std::string user = user_c;
-
-            auto it = m_userToFd.find(user);
-            if (it != m_userToFd.end())
-                remove(std::get<1>(*it));
 
-            m_userToFd[user] = fd;
-            m_fdToUser[fd] = user;
-            m_fdStatus[fd] = true;
+            m_userToPeer.emplace(user, std::move(peerSocket));
+            m_fdToUser[peerFd] = user;
+            m_fdStatus[peerFd] = true;
 
             ALOGD("Accepted new conection for user: " << user);
         } catch (...) {
-            Socket::close(fd);
+            peerSocket.close();
             throw;
         }
     }
@@ -264,10 +289,9 @@ void NotificationTalker::newConnection(int &rv)
 
 void NotificationTalker::remove(int fd)
 {
-    Socket::close(fd);
     auto user = m_fdToUser[fd];
     m_fdToUser.erase(fd);
-    m_userToFd.erase(user);
+    m_userToPeer.erase(user);
     m_fdStatus.erase(fd);
 }
 
@@ -278,13 +302,12 @@ void NotificationTalker::run()
         while (!m_stopflag) {
             std::lock_guard<std::mutex> lock(m_bfLock);
 
-            m_select.add(m_sockfd);
+            m_poll.setEvents(m_serverSocket.getFd(), POLLIN);
 
-            for (auto pair : m_userToFd)
-                m_select.add(std::get<1>(pair));
+            for (auto &userPeer : m_userToPeer)
+                m_poll.setEvents(userPeer.second.getFd(), POLLIN);
 
-            m_select.setTimeout(100);
-            int rv = m_select.exec();
+            int rv = m_poll.wait(100);
 
             if (m_stopflag) {
                 clear();
@@ -295,14 +318,13 @@ void NotificationTalker::run()
                 newConnection(rv);
                 recvResponses(rv);
             }
+            for (auto &userPeer : m_userToPeer) {
+                const std::string &user = userPeer.first;
+                Socket &socketPeer = userPeer.second;
 
-            for (auto pair : m_fdStatus ) {
-                int fd = std::get<0>(pair);
-                bool b = std::get<1>(pair);
-                auto &queue = m_requests[m_fdToUser[fd]];
-                if (b && !queue.empty()) {
-                    NotificationRequest request = queue.front();
-                    sendRequest(fd, request);
+                auto &queue = m_requests[user];
+                if (m_fdStatus.at(socketPeer.getFd()) && !queue.empty()) {
+                    sendRequest(socketPeer, queue.front());
                 }
             }
         }
@@ -316,8 +338,9 @@ void NotificationTalker::run()
     }
 
     if (m_failed && m_responseHandler) {
-        for (auto &queuePair : m_requests) {
-            for (auto &request : std::get<1>(queuePair)) {
+        for (auto &userQueue : m_requests) {
+            auto &queue = userQueue.second;
+            for (auto &request : queue) {
                 m_responseHandler({request.id, NResponseType::Error});
             }
         }
index 1bed8b3..ea3fdb9 100644 (file)
@@ -29,7 +29,8 @@
 #include <string>
 #include <thread>
 
-#include <socket/SelectRead.h>
+#include <socket/Socket.h>
+#include <socket/Poll.h>
 #include <types/RequestId.h>
 #include <types/NotificationResponse.h>
 #include <types/NotificationRequest.h>
@@ -41,7 +42,7 @@ namespace AskUser {
 namespace Agent {
 
 typedef std::pair<std::string, int> UserToFdPair;
-typedef std::map<std::string, int> UserToFdMap;
+typedef std::map<std::string, Socket> UserToPeerMap;
 typedef std::map<int, std::string> FdToUserMap;
 typedef std::map<int, bool> FdStatus;
 
@@ -67,7 +68,8 @@ public:
 protected:
     void setErrorMsg(std::string s);
     void run();
-    void parseResponse(NotificationResponse response, int fd);
+    void parseResponse(NotificationResponse response, Socket &peerSocket);
+    bool recvResponse(Socket &peerSocket, NotificationResponse &response);
     void recvResponses(int &rv);
 
     void newConnection(int &rv);
@@ -77,16 +79,17 @@ protected:
 
     virtual void addRequest(NotificationRequest &&request);
     virtual void removeRequest(RequestId id);
-    virtual void sendRequest(int fd, const NotificationRequest &request);
-    virtual void sendDismiss(int fd);
+    virtual void sendRequest(Socket &peerSocket, const NotificationRequest &request);
+    virtual void sendDismiss(Socket &peerSocket);
 
     ResponseHandler m_responseHandler;
 
-    UserToFdMap m_userToFd;
+    UserToPeerMap m_userToPeer;
     FdToUserMap m_fdToUser;
     FdStatus m_fdStatus;
-    Socket::SelectRead m_select;
-    int m_sockfd = 0;
+    Poll m_poll;
+    Socket m_serverSocket;
+    std::vector<Socket> m_peers;
     bool m_failed;
     std::string m_errorMsg;
 
index c36f3db..ef86693 100644 (file)
@@ -28,7 +28,7 @@
 #include "PrivilegeInfo.h"
 
 #include <socket/Socket.h>
-#include <socket/SelectRead.h>
+#include <socket/Poll.h>
 #include <types/NotificationResponse.h>
 #include <types/Protocol.h>
 #include <types/NotificationRequest.h>
@@ -36,7 +36,6 @@
 #include <exception/Exception.h>
 #include <translator/Translator.h>
 #include <config/Path.h>
-#include <config/Limits.h>
 
 #include <security-manager.h>
 
@@ -88,48 +87,49 @@ void setSecurityLevel(const std::string &app, const std::string &perm, const std
 } /* namespace */
 
 
-AskUserTalker::AskUserTalker(GuiRunner *gui) : m_gui(gui) {
+AskUserTalker::AskUserTalker(GuiRunner *gui) : m_clientSocket(Socket::PeerType::CLIENT), m_gui(gui) {
     m_gui->setDropHandler([&](){return this->shouldDismiss();});
 }
 
-AskUserTalker::~AskUserTalker()
-{
-    try {
-        Socket::close(sockfd);
-    } catch (const std::exception &e) {
-        ALOGE(std::string("~AskUserTalker") + e.what());
-    } catch (...) {
-        ALOGE("~AskUserTalker: Unknow error");
+bool AskUserTalker::fetchRequest(NotificationRequest &request) {
+    std::string requestData;
+
+    if (!m_clientSocket.recv(requestData)) {
+        ALOGI("Failed fetching request, closing...");
+        return false;
+    }
+
+    request = Translator::Gui::dataToNotificationRequest(requestData);
+    ALOGD("Recieved data " << request.data.client << " " << request.data.privilege);
+
+    return true;
+}
+
+bool AskUserTalker::sendResponse(const NotificationResponse &response) {
+    if (!m_clientSocket.send(static_cast<int>(response.id))) {
+        ALOGI("Askuserd closed connection, closing...");
+        return false;
+    }
+    if (!m_clientSocket.send(static_cast<int>(response.response))) {
+        ALOGI("Askuserd closed connection, closing...");
+        return false;
     }
+    return true;
 }
 
 void AskUserTalker::run()
 {
-    sockfd = Socket::connect(Path::getSocketPath());
+    m_clientSocket.connect(Path::getSocketPath());
 
     while (!stopFlag) {
-        size_t size;
-        NotificationResponse response;
-
         ALOGD("Waiting for request...");
-
-        if (!Socket::recv(sockfd, &size, sizeof(size))) {
-            ALOGI("Askuserd closed connection, closing...");
-            break;
-        }
-
-        Limits::checkSizeLimit(size);
-
-        std::unique_ptr<char[]> buf(new char[size]);
-
-        if (!Socket::recv(sockfd, buf.get(), size)) {
-            ALOGI("Askuserd closed connection, closing...");
+        NotificationRequest request;
+        if (!fetchRequest(request)) {
+            ALOGE("Couldn't fetch request, closing...");
             break;
         }
 
-        NotificationRequest request = Translator::Gui::dataToNotificationRequest(buf.get());
-        ALOGD("Recieved data " << request.data.client << " " << request.data.privilege);
-
+        NotificationResponse response;
         response.response = m_gui->popupRun(request.data.client,
                 PrivilegeInfo::getPrivilegeDisplayName(request.data.privilege));
         response.id = request.id;
@@ -138,13 +138,13 @@ void AskUserTalker::run()
             continue;
         }
 
-        if (!Socket::send(sockfd, &response, sizeof(response))) {
-            ALOGI("Askuserd closed connection, closing...");
-            break;
-        }
+       if (!sendResponse(response)) {
+           ALOGE("Couldn't send response, closing...");
+           break;
+       }
 
-        uint8_t ack = 0x00;
-        if (!Socket::recv(sockfd, &ack, sizeof(ack))) {
+        int ack = 0;
+        if (!m_clientSocket.recv(ack)) {
             ALOGI("Askuserd closed connection, closing...");
             break;
         }
@@ -159,7 +159,9 @@ void AskUserTalker::run()
         case NResponseType::Never:
             setSecurityLevel(request.data.client, request.data.privilege,
                              Translator::Gui::responseToString(response.response));
+            break;
         default:
+            ALOGD("Unknown response type returned");
             break;
         }
     }
@@ -168,18 +170,22 @@ void AskUserTalker::run()
 void AskUserTalker::stop()
 {
     m_gui->stop();
-    Socket::close(sockfd);
+    m_clientSocket.close();
 }
 
 bool AskUserTalker::shouldDismiss()
 {
-    Socket::SelectRead select;
-    select.add(sockfd);
-    if (select.exec() == 0)
+    Poll poller(1);
+    poller.setEvents(sockfd, POLLIN);
+    int ret = poller.wait(0);
+    if (ret == 0)
         return false;
-
-    uint8_t a = 0x00;
-    Socket::recv(sockfd, &a, sizeof(a));
+    if (ret == -1) {
+        ALOGE("Poll failed");
+        return true;
+    }
+    int a = 0x00;
+    m_clientSocket.recv(a);
 
     if (a != Protocol::dissmisCode)
         throw Exception("Incorrect dismiss flag");
index a7b9e72..1a53589 100644 (file)
 
 #pragma once
 
-#include <functional>
-#include <queue>
-#include <memory>
-#include <mutex>
 
 #include "GuiRunner.h"
 
+#include <socket/Socket.h>
+#include <types/NotificationRequest.h>
+#include <types/NotificationResponse.h>
+
 namespace AskUser {
 
 namespace Notification {
@@ -36,7 +36,6 @@ class AskUserTalker
 {
 public:
       AskUserTalker(GuiRunner *gui);
-      ~AskUserTalker();
 
       void run();
       void stop();
@@ -44,6 +43,9 @@ public:
       bool shouldDismiss();
 
 private:
+      bool fetchRequest(NotificationRequest &request);
+      bool sendResponse(const NotificationResponse &response);
+      Socket m_clientSocket;
       GuiRunner *m_gui;
       int sockfd = 0;
       bool stopFlag = false;
index 57f9ad9..adcae22 100644 (file)
@@ -38,7 +38,7 @@ INCLUDE_DIRECTORIES(
 SET(COMMON_SOURCES
     ${COMMON_PATH}/log/alog.cpp
     ${COMMON_PATH}/socket/Socket.cpp
-    ${COMMON_PATH}/socket/SelectRead.cpp
+    ${COMMON_PATH}/socket/Poll.cpp
     ${COMMON_PATH}/translator/Translator.cpp
     ${COMMON_PATH}/types/AgentErrorMsg.cpp
     ${COMMON_PATH}/util/SafeFunction.cpp
index ff7e545..d4c6379 100644 (file)
@@ -53,6 +53,18 @@ namespace {
         } \
     } while (0)
 
+#define ALOGE_ERRNO(MESSAGE) \
+    do { \
+        if (LOG_ERR <= __alog_level) { \
+            constexpr int bufsize = 1024; \
+            char buf[bufsize]; \
+            char *err_str = strerror_r(errno, buf, bufsize); \
+            std::stringstream __LOG_msg; \
+            __LOG_msg << MESSAGE << " : " << err_str; \
+            __ALOG_FUN(LOG_ERR, __LOG_msg); \
+        } \
+    } while(0)
+
 #define ALOGM(...)  __ALOG(LOG_EMERG, __VA_ARGS__)   /* system is unusable */
 #define ALOGA(...)  __ALOG(LOG_ALERT, __VA_ARGS__)   /* action must be taken immediately */
 #define ALOGC(...)  __ALOG(LOG_CRIT, __VA_ARGS__)    /* critical conditions */
diff --git a/src/common/socket/Poll.cpp b/src/common/socket/Poll.cpp
new file mode 100644 (file)
index 0000000..598016d
--- /dev/null
@@ -0,0 +1,92 @@
+/*
+ *  Copyright (c) 2016 Samsung Electronics Co.
+ *
+ *  Licensed under the Apache License, Version 2.0 (the "License");
+ *  you may not use this file except in compliance with the License.
+ *  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ *  Unless required by applicable law or agreed to in writing, software
+ *  distributed under the License is distributed on an "AS IS" BASIS,
+ *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ *  See the License for the specific language governing permissions and
+ *  limitations under the License
+ */
+/**
+ * @file        Poll.cpp
+ * @author      Zofia Abramowska
+ * @brief       Definition of Poll class
+ */
+
+#include "Poll.h"
+
+#include <exception/ErrnoException.h>
+#include <log/alog.h>
+#include <unistd.h>
+
+namespace AskUser {
+
+void Poll::setEvents(int fd, int events) {
+    ALOGD("Setting events : " << events << " for : " << fd);
+    auto it = m_fdToPollFd.find(fd);
+    if (it != m_fdToPollFd.end()) {
+        ALOGD("Updating events for existing fd");
+        m_fds[it->second].events = events;
+        return;
+    }
+
+    if (m_fdsTaken == m_fdsCount) {
+        m_fdsCount *=2;
+        pollfd *newFds = new pollfd[m_fdsCount];
+        memcpy(newFds, m_fds, m_fdsTaken * sizeof(pollfd));
+        delete m_fds;
+        m_fds = newFds;
+    }
+    ALOGD("Adding new entry for fd");
+    m_fdToPollFd[fd] = m_fdsTaken;
+    m_fds[m_fdsTaken].fd = fd;
+    m_fds[m_fdsTaken].events = events;
+    m_fdsTaken++;
+}
+
+void Poll::unset(int fd) {
+    ALOGD("Unsetting fd : " << fd);
+    auto it = m_fdToPollFd.find(fd);
+    if (it == m_fdToPollFd.end()) {
+        ALOGD("Fd not set, ignoring");
+        return; // ignore
+    }
+
+    m_fdsTaken--;
+    if (m_fdsTaken == 0)
+        return;
+
+    // Move last one to place of unset if we have anything
+    auto last = m_fds[m_fdsTaken];
+
+    if (last.fd != fd) {
+        m_fdToPollFd[last.fd] = it->second;
+        m_fds[it->second].fd = last.fd;
+        m_fds[it->second].events = last.events;
+    }
+
+    m_fdToPollFd.erase(it);
+}
+
+int Poll::getEvents(int fd) {
+    ALOGD("Getting events for : " << fd);
+    auto it = m_fdToPollFd.find(fd);
+    if (it == m_fdToPollFd.end()) {
+        ALOGD("Fd not set, ignoring");
+        return 0;
+    }
+    return m_fds[it->second].revents;
+}
+
+int Poll::wait(int msec) {
+    ALOGD("Waiting for : " << (msec == -1 ? "infinity" : std::to_string(msec)));
+    return TEMP_FAILURE_RETRY(poll(m_fds, m_fdsTaken, msec));
+}
+
+} /* namespace AskUser */
similarity index 58%
rename from src/common/socket/SelectRead.h
rename to src/common/socket/Poll.h
index 3646ef5..beac3de 100644 (file)
  *  limitations under the License
  */
 /**
- * @file        SelectRead.h
- * @author      Oskar Świtalski <o.switalski@samsung.com>
- * @brief       Declaration of SelectRead class
+ * @file        Poll.h
+ * @author      Zofia Abramowska <z.abramowska@samsung.com>
+ * @brief       Declaration of Poll class
  */
 
 #pragma once
 
-#include <sys/select.h>
+#include <poll.h>
+#include <map>
 
 namespace AskUser {
 
-namespace Socket {
-
-class SelectRead {
+class Poll {
 public:
-    SelectRead();
+    Poll(int fdsCount) : m_fdsCount(fdsCount), m_fdsTaken(0) { m_fds = new pollfd[m_fdsCount]; }
+    ~Poll() { delete m_fds; }
+    void setEvents(int fd, int events);
+    void unset(int fd);
+    int getEvents(int fd);
+    int wait(int msec);
 
-    void add(int fd);
-    int exec();
-    bool isSet(int fd);
-    void setTimeout(int ms);
 private:
-    bool m_exec;
-
-    fd_set m_set;
-    int m_nfds;
-
-    timeval m_timeout;
+    pollfd *m_fds;
+    int m_fdsCount;
+    int m_fdsTaken;
+    std::map<int, int> m_fdToPollFd;
 };
 
-} /* namespace Socket */
-
 } /* namespace AskUser */
diff --git a/src/common/socket/SelectRead.cpp b/src/common/socket/SelectRead.cpp
deleted file mode 100644 (file)
index b34710e..0000000
+++ /dev/null
@@ -1,65 +0,0 @@
-/*
- *  Copyright (c) 2016 Samsung Electronics Co.
- *
- *  Licensed under the Apache License, Version 2.0 (the "License");
- *  you may not use this file except in compliance with the License.
- *  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- *  Unless required by applicable law or agreed to in writing, software
- *  distributed under the License is distributed on an "AS IS" BASIS,
- *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- *  See the License for the specific language governing permissions and
- *  limitations under the License
- */
-/**
- * @file        SelectRead.cpp
- * @author      Oskar Świtalski <o.switalski@samsung.com>
- * @brief       Definition of SelectRead class
- */
-
-#include "SelectRead.h"
-
-#include <exception/ErrnoException.h>
-
-namespace AskUser {
-
-namespace Socket {
-
-SelectRead::SelectRead() : m_exec(true), m_timeout({0, 0}) {}
-
-void SelectRead::add(int fd) {
-    if (m_exec) {
-        FD_ZERO(&m_set);
-        m_nfds = -1;
-        m_exec = false;
-    }
-
-    FD_SET(fd, &m_set);
-    m_nfds = m_nfds > fd ? m_nfds : fd;
-}
-
-int SelectRead::exec() {
-    int result = 0;
-
-    m_exec = true;
-
-    result = select(m_nfds + 1, &m_set, nullptr, nullptr, &m_timeout);
-    if (result == -1)
-        throw ErrnoException("Select failed");
-
-    return result;
-}
-
-bool SelectRead::isSet(int fd) {
-    return FD_ISSET(fd, &m_set);
-}
-
-void SelectRead::setTimeout(int ms) {
-    m_timeout.tv_usec = ms * 1000;
-}
-
-} /* namespace Socket */
-
-} /* namespace AskUser */
index db77992..4d32376 100644 (file)
 
 #include "Socket.h"
 
+#include <config/Limits.h>
 #include <exception/ErrnoException.h>
 #include <log/alog.h>
 
+#include <poll.h>
 #include <stdexcept>
 #include <sys/socket.h>
 #include <sys/un.h>
 #include <unistd.h>
+#include <vector>
 
 namespace AskUser {
+Socket::Socket(Socket &&other) : m_type(std::move(other.m_type)), m_fd(std::move(other.m_fd)) {
+    other.m_fd = -1;
+}
 
-namespace Socket {
-
-int accept(int fd) {
-    int retFd = TEMP_FAILURE_RETRY(::accept(fd, nullptr, nullptr));
-    if (retFd < 0)
-        throw ErrnoException("Accept socket error");
+Socket &Socket::operator=(Socket &&other) {
+    if (this == &other)
+        return *this;
 
-    ALOGD("Accepted socket <" << retFd << ">");
+    m_fd = std::move(other.m_fd);
+    other.m_fd = -1;
 
-    return retFd;
+    m_type = std::move(other.m_type);
+    return *this;
 }
 
-void close(int fd) {
-    int result = TEMP_FAILURE_RETRY(::close(fd));
-    if (result < 0) {
-        ALOGE("Close socket <" << fd << "> failed");
-        return;
+bool Socket::connect(const std::string &path) {
+    if (m_type != PeerType::CLIENT) {
+        ALOGW("Connect is available only for CLIENT type socket");
+        return false;
     }
-
-    ALOGD("Closed socket <" << fd << ">");
-}
-
-int connect(const std::string &path) {
-    int fd = -1;
-    int result = 0;
-    size_t length = 0;
-
     sockaddr_un remote;
     remote.sun_family = AF_UNIX;
 
-    if (path.size() >= sizeof(remote.sun_path))
-        throw std::invalid_argument("Path length is too big");
-    strcpy(remote.sun_path, path.c_str());
+    if (path.size() >= sizeof(remote.sun_path)) {
+        ALOGE("Length of socket path " << path << " is too big");
+        return false;
+    }
 
-    length = strlen(remote.sun_path) + sizeof(remote.sun_family);
+    strcpy(remote.sun_path, path.c_str());
+    ssize_t length = strlen(remote.sun_path) + sizeof(remote.sun_family);
 
-    fd = ::socket(AF_UNIX, SOCK_STREAM, 0);
-    if (fd == -1)
-        throw ErrnoException("Socket creation failed");
+    m_fd = ::socket(AF_UNIX, SOCK_STREAM, 0);
+    if (m_fd < 0) {
+        ALOGE_ERRNO("Socket creation failed");
+        return false;
+    }
 
-    result = TEMP_FAILURE_RETRY(::connect(fd, (struct sockaddr *)&remote, length));
-    if (result == -1)
-        throw ErrnoException("Connecting to <" + path + "> socket failed");
+    int result = TEMP_FAILURE_RETRY(::connect(m_fd, (struct sockaddr *)&remote, length));
+    if (result == -1) {
+        ALOGE_ERRNO("Couldn't connect to " << path);
+        return false;
+    }
 
     ALOGD("Connected to <" << path << "> socket");
-
-    return fd;
+    return true;
 }
 
-int listen(const std::string &path) {
-    int fd = -1;
+bool Socket::bindAndListen(const std::string &path) {
+    if (m_type != PeerType::SERVER) {
+        ALOGW("Connect is available only for SERVER type socket");
+        return false;
+    }
+
     int result = 0;
     size_t length = 0;
 
     sockaddr_un local;
     local.sun_family = AF_UNIX;
 
-    if (path.size() >= sizeof(local.sun_path))
-        throw std::invalid_argument("Socket path too long");
-    strcpy(local.sun_path, path.c_str());
+    if (path.size() >= sizeof(local.sun_path)) {
+        ALOGE("Length of socket path " << path << " is too big");
+        return false;
+    }
 
+    strcpy(local.sun_path, path.c_str());
     length = strlen(local.sun_path) + sizeof(local.sun_family);
 
     result = unlink(path.c_str());
-    if (result == -1 && errno != ENOENT)
-        throw ErrnoException("Unlink " + path + " failed");
+    if (result == -1 && errno != ENOENT) {
+        ALOGE_ERRNO("Unlink " + path + " failed");
+        return false;
+    }
 
-    fd = ::socket(AF_UNIX, SOCK_STREAM, 0);
-    if (fd == -1)
-        throw ErrnoException("Socket creation failed");
+    m_fd = ::socket(AF_UNIX, SOCK_STREAM, 0);
+    if (m_fd == -1) {
+        ALOGE_ERRNO("Socket creation failed");
+        return false;
+    }
 
-    result = ::bind(fd, (struct sockaddr *)&local, length);
-    if (result == -1)
-        throw ErrnoException("Binding to <" + path + "> failed");
+    result = ::bind(m_fd, (struct sockaddr *)&local, length);
+    if (result == -1) {
+        ALOGE_ERRNO("Binding to <" + path + "> failed");
+        return false;
+    }
 
-    result = ::listen(fd, 10);
-    if (result == -1)
-        throw ErrnoException("Listen on socked failed");
+    result = ::listen(m_fd, 5);
+    if (result == -1) {
+        ALOGE_ERRNO("Listen on socked failed");
+        return false;
+    }
 
     ALOGD("Listening on <" << path << "> socket");
+    return true;
+}
 
-    return fd;
+Socket Socket::accept() {
+    int retFd = TEMP_FAILURE_RETRY(::accept(m_fd, nullptr, nullptr));
+    if (retFd < 0) {
+        ALOGE_ERRNO("accept failed");
+        return -1;
+    }
+    ALOGD("Accepted socket <" << retFd << ">");
+    return Socket(retFd);
 }
 
-bool recv(int fd, void *buf, size_t size, int flags) {
+bool Socket::recvData(void *data, size_t len) {
     int result = 0;
     size_t bytesRead = 0;
 
-    while (bytesRead < size) {
-        result = TEMP_FAILURE_RETRY(::recv(fd, static_cast<char*>(buf) + bytesRead,
-                                           size - bytesRead, flags));
+    while (bytesRead < len) {
+        result = TEMP_FAILURE_RETRY(::read(m_fd, static_cast<char*>(data) + bytesRead,
+                                           len - bytesRead));
 
-        if (result < 0 && errno != ECONNRESET)
-            throw ErrnoException("Error receiving data from socket");
-        else if (result <= 0)
+        if (result < 0 && errno != ECONNRESET) {
+            ALOGE_ERRNO("Error receiving data from socket");
+            return false;
+        } else if (result <= 0)
             return false;
 
         bytesRead += result;
-
-        ALOGD("Recieved " << bytesRead << "/" << size << " byte(s)");
+        ALOGD("Recieved " << bytesRead << "/" << len << " byte(s)");
     }
 
     return true;
 }
 
-bool send(int fd, const void *buf, size_t size, int flags) {
+bool Socket::recv(std::string &msg) {
+    int size;
+    if(!recv(size))
+        return false;
+
+    std::vector<char> strData(size);
+    if (!recvData(strData.data(), static_cast<size_t>(size)))
+        return false;
+
+    msg.assign(strData.begin(), strData.end());
+    return true;
+}
+
+bool Socket::recv(int &msg) {
+    int msgData;
+    if (!recvData(&msgData, sizeof(msg)))
+        return false;
+
+    Limits::checkSizeLimit(msgData);
+
+    msg = msgData;
+    return true;
+}
+
+bool Socket::sendData(const void *data, size_t len) {
     int result = 0;
     size_t bytesSend = 0;
 
-    while (bytesSend < size) {
-
-        result = TEMP_FAILURE_RETRY(::send(fd, static_cast<const char*>(buf) + bytesSend,
-                                           size - bytesSend, flags | MSG_NOSIGNAL));
+    while (bytesSend < len) {
+        result = TEMP_FAILURE_RETRY(::send(m_fd, static_cast<const char*>(data) + bytesSend,
+                len - bytesSend, MSG_NOSIGNAL));
 
-        if (result < 0) {
-            if (errno == EPIPE)
-                return false;
-            else
-        throw ErrnoException("Error sending data to socket");
+        if (result < 0 && errno != EPIPE) {
+            ALOGE_ERRNO("Error sending data to socket");
+            return false;
+        } else if (result <= 0) {
+            ALOGE_ERRNO("Other side disconnected");
+            return false;
         }
 
         bytesSend += result;
 
-        ALOGD("Send " << result << "/" << size << " byte(s)");
+        ALOGD("Send " << result << "/" << len << " byte(s)");
     }
 
     return true;
 }
+bool Socket::send(const std::string &msg) {
+    if (!send(msg.size()))
+        return false;
+    return sendData(msg.c_str(), msg.size());
+}
+bool Socket::send(int msg) {
+    return sendData(&msg, sizeof(msg));
+}
 
-} /* namespace Socket */
+void Socket::close() {
+    ::close(m_fd);
+}
 
 } /* namespace AskUser */
index 0f66d0d..db5a68a 100644 (file)
 /**
  * @file        Socket.cpp
  * @author      Oskar Świtalski <o.switalski@samsung.com>
- * @brief       Declaration of Socket methods
+ * @author      Zofia Abramowska <z.abramowska@samsung.com>
+ * @brief       Declaration of client/server Socket wrapper
  */
 
 #pragma once
 
-#include <cstddef>
 #include <string>
+#include <unistd.h>
 
 namespace AskUser {
 
-namespace Socket {
+class Socket {
+public:
+    enum PeerType {
+        CLIENT, // connect
+        SERVER, // accept
+        PEER    // accepted peer
+    };
 
-int accept(int fd);
-void close(int fd);
-int connect(const std::string &path);
-int listen(const std::string &path);
-bool recv(int fd, void *buf, size_t size, int flags = 0);
-bool send(int fd, const void *buf, size_t size, int flags = 0);
 
-} /* namespace Socket */
+    Socket(PeerType type) : m_type(type), m_fd(-1) {}
+    Socket(int fd) : m_type(PEER), m_fd(fd) {}
+
+    Socket(const Socket &other) = delete;
+    Socket &operator=(const Socket &other) = delete;
+
+    Socket(Socket &&other);
+    Socket &operator=(Socket &&other);
+
+    ~Socket() { ::close(m_fd); }
+
+    int getFd() const { return m_fd; }
+
+    bool connect(const std::string &path);
+
+    bool bindAndListen(const std::string &path);
+    Socket accept();
+
+    bool recv(std::string &msg);
+    bool recv(int &msg);
+
+    bool send(const std::string &msg);
+    bool send(int msg);
+
+    void close();
+private:
+    bool recvData(void *data, size_t len);
+    bool sendData(const void *data, size_t len);
+    PeerType m_type;
+    int m_fd;
+};
 
 } /* namespace AskUser */
index 7447c52..35e9e65 100644 (file)
@@ -28,6 +28,7 @@
 namespace AskUser {
 
 struct NotificationRequest {
+    NotificationRequest() : id(-1) {};
     NotificationRequest(RequestId id_) : id(id_) {};
     NotificationRequest(RequestId id_, std::string client, std::string user, std::string privilege)
     : id(id_),
index dafd277..c97a2ed 100644 (file)
@@ -26,8 +26,8 @@
 namespace AskUser {
 namespace Protocol {
 
-constexpr uint8_t dissmisCode = 0xDE;
-constexpr uint8_t ackCode = 0xAC;
+constexpr int dissmisCode = 0xDE;
+constexpr int ackCode = 0xAC;
 
 } // namespace Protocol
 } // namespace AskUser
index 1a56a8b..01abfb8 100644 (file)
@@ -24,10 +24,11 @@ SET(TESTS_SOURCES
     ${TESTS_PATH}/common/translator.cpp
     ${TESTS_PATH}/daemon/notificationTalker.cpp
 
+    ${PROJECT_SOURCE_DIR}/src/common/config/Limits.cpp
     ${PROJECT_SOURCE_DIR}/src/common/config/Path.cpp
     ${PROJECT_SOURCE_DIR}/src/common/log/alog.cpp
     ${PROJECT_SOURCE_DIR}/src/common/socket/Socket.cpp
-    ${PROJECT_SOURCE_DIR}/src/common/socket/SelectRead.cpp
+    ${PROJECT_SOURCE_DIR}/src/common/socket/Poll.cpp
     ${PROJECT_SOURCE_DIR}/src/common/translator/Translator.cpp
     ${PROJECT_SOURCE_DIR}/src/common/types/AgentErrorMsg.cpp
     ${PROJECT_SOURCE_DIR}/src/agent/main/NotificationTalker.cpp
index 68d1956..67ef2eb 100644 (file)
@@ -66,7 +66,7 @@ public:
     }
 
     int getSockFd() {
-        return m_sockfd;
+        return m_serverSocket.getFd();
     }
 
     void clear() {