Add unix socket feature
authorSangwan Kwon <sangwan.kwon@samsung.com>
Thu, 2 Jan 2020 06:19:56 +0000 (15:19 +0900)
committer권상완/Security 2Lab(SR)/Engineer/삼성전자 <sangwan.kwon@samsung.com>
Mon, 6 Jan 2020 10:00:56 +0000 (19:00 +0900)
This would be changed to systemd socket.

Signed-off-by: Sangwan Kwon <sangwan.kwon@samsung.com>
src/vist/rmi/CMakeLists.txt
src/vist/rmi/impl/ondemand/socket.cpp [new file with mode: 0644]
src/vist/rmi/impl/ondemand/socket.hpp [new file with mode: 0644]
src/vist/rmi/impl/ondemand/tests/socket.cpp [new file with mode: 0644]

index 833d53c..dbd06e6 100644 (file)
@@ -16,7 +16,8 @@ SET(TARGET vist-rmi)
 SET(${TARGET}_SRCS gateway.cpp
                                   remote.cpp
                                   message.cpp
-                                  impl/general/protocol.cpp)
+                                  impl/general/protocol.cpp
+                                  impl/ondemand/socket.cpp)
 
 ADD_LIBRARY(${TARGET} SHARED ${${TARGET}_SRCS})
 SET_TARGET_PROPERTIES(${TARGET} PROPERTIES COMPILE_FLAGS "-fvisibility=hidden")
@@ -43,3 +44,6 @@ ADD_VIST_TEST(${RMI_TESTS})
 
 FILE(GLOB RMI_GENERAL_TESTS "impl/general/tests/*.cpp")
 ADD_VIST_TEST(${RMI_GENERAL_TESTS})
+
+FILE(GLOB RMI_ONDEMAND_TESTS "impl/ondemand/tests/*.cpp")
+ADD_VIST_TEST(${RMI_ONDEMAND_TESTS})
diff --git a/src/vist/rmi/impl/ondemand/socket.cpp b/src/vist/rmi/impl/ondemand/socket.cpp
new file mode 100644 (file)
index 0000000..7abd4d3
--- /dev/null
@@ -0,0 +1,152 @@
+/*
+ *  Copyright (c) 2018-present Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ *  Licensed under the Apache License, Version 2.0 (the "License");
+ *  you may not use this file except in compliance with the License.
+ *  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ *  Unless required by applicable law or agreed to in writing, software
+ *  distributed under the License is distributed on an "AS IS" BASIS,
+ *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ *  See the License for the specific language governing permissions and
+ *  limitations under the License
+ */
+
+#include "socket.hpp"
+
+#include <fstream>
+#include <iostream>
+#include <fcntl.h>
+
+#include <sys/socket.h>
+#include <sys/stat.h>
+#include <sys/un.h>
+
+namespace vist {
+namespace rmi {
+namespace impl {
+namespace ondemand {
+
+namespace {
+
+void set_cloexec(int fd)
+{
+       if (::fcntl(fd, F_SETFD, FD_CLOEXEC) == -1)
+               throw std::runtime_error("Failed to set CLOSEXEC.");
+}
+
+} // anonymous namespace
+
+Socket::Socket(int fd) noexcept : fd(fd)
+{
+}
+
+Socket::Socket(const std::string& path)
+{
+       if (path.size() >= sizeof(::sockaddr_un::sun_path))
+               throw std::invalid_argument("Socket path size is wrong.");
+
+       int fd = ::socket(AF_UNIX, SOCK_STREAM, 0);
+       if (fd == -1)
+               throw std::runtime_error("Failed to create socket.");
+
+       set_cloexec(fd);
+
+       ::sockaddr_un addr;
+       addr.sun_family = AF_UNIX;
+       ::strncpy(addr.sun_path, path.c_str(), sizeof(sockaddr_un::sun_path));
+
+       if (addr.sun_path[0] == '@')
+               addr.sun_path[0] = '\0';
+
+       struct stat buf;
+       if (::stat(path.c_str(), &buf) == 0)
+               if (::unlink(path.c_str()) == -1)
+                       throw std::runtime_error("Failed to remove exist socket.");
+
+       if (::bind(fd, reinterpret_cast<::sockaddr*>(&addr), sizeof(::sockaddr_un)) == -1) {
+               ::close(fd);
+               throw std::runtime_error("Failed to bind.");
+       }
+
+       if (::listen(fd, MAX_BACKLOG_SIZE) == -1) {
+               ::close(fd);
+               throw std::runtime_error("Failed to liten.");
+       }
+
+       this->fd = fd;
+}
+
+Socket::Socket(Socket&& that) : fd(that.fd)
+{
+       that.fd = -1;
+}
+
+Socket& Socket::operator=(Socket&& that)
+{
+       if (this == &that)
+               return *this;
+
+       this->fd = that.fd;
+       that.fd = -1;
+
+       return *this;
+}
+
+Socket::~Socket(void)
+{
+       if (fd != -1)
+               ::close(fd);
+}
+
+Socket Socket::accept(void) const
+{
+       errno = 0;
+       int fd = ::accept(this->fd, nullptr, nullptr);
+       if (fd == -1)
+               THROW(ErrCode::RuntimeError) << "Failed to accept: " << errno;
+
+       set_cloexec(fd);
+
+       return Socket(fd);
+}
+
+Socket Socket::connect(const std::string& path)
+{
+       if (path.size() >= sizeof(::sockaddr_un::sun_path))
+               throw std::invalid_argument("Socket path size is wrong.");
+
+       int fd = ::socket(AF_UNIX, SOCK_STREAM, 0);
+       if (fd == -1)
+               THROW(ErrCode::RuntimeError) << "Failed to create socket.";
+
+       set_cloexec(fd);
+
+       ::sockaddr_un addr;
+       addr.sun_family = AF_UNIX;
+       ::strncpy(addr.sun_path, path.c_str(), sizeof(::sockaddr_un::sun_path));
+
+       if (addr.sun_path[0] == '@')
+               addr.sun_path[0] = '\0';
+
+       errno = 0;
+       if (::connect(fd, reinterpret_cast<::sockaddr*>(&addr), sizeof(sockaddr_un)) == -1) {
+               ::close(fd);
+               THROW(ErrCode::RuntimeError) << "Failed to read connect to: " << path
+                                                                        << ", with: " << errno;
+       }
+
+       return Socket(fd);
+}
+
+int Socket::getFd(void) const noexcept
+{
+       return this->fd;
+}
+
+} // namespace ondemand
+} // namespace impl
+} // namespace rmi
+} // namespace vist
diff --git a/src/vist/rmi/impl/ondemand/socket.hpp b/src/vist/rmi/impl/ondemand/socket.hpp
new file mode 100644 (file)
index 0000000..993fe90
--- /dev/null
@@ -0,0 +1,99 @@
+/*
+ *  Copyright (c) 2018-present Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ *  Licensed under the Apache License, Version 2.0 (the "License");
+ *  you may not use this file except in compliance with the License.
+ *  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ *  Unless required by applicable law or agreed to in writing, software
+ *  distributed under the License is distributed on an "AS IS" BASIS,
+ *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ *  See the License for the specific language governing permissions and
+ *  limitations under the License
+ */
+
+#pragma once
+
+#include <vist/exception.hpp>
+
+#include <cstddef>
+#include <string>
+#include <stdexcept>
+
+#include <unistd.h>
+#include <errno.h>
+
+namespace vist {
+namespace rmi {
+namespace impl {
+namespace ondemand {
+
+class Socket {
+public:
+       explicit Socket(int fd) noexcept;
+       explicit Socket(const std::string& path);
+       virtual ~Socket(void);
+
+       Socket(const Socket&) = delete;
+       Socket& operator=(const Socket&) = delete;
+
+       Socket(Socket&&);
+       Socket& operator=(Socket&&);
+
+       Socket accept(void) const;
+       static Socket connect(const std::string& path);
+
+       template<typename T>
+       void send(const T* buffer, const std::size_t size = sizeof(T)) const;
+
+       template<typename T>
+       void recv(T* buffer, const std::size_t size = sizeof(T)) const;
+
+       int getFd(void) const noexcept;
+
+private:
+       const int MAX_BACKLOG_SIZE = 100;
+
+       int fd;
+};
+
+template<typename T>
+void Socket::send(const T *buffer, const std::size_t size) const
+{
+       std::size_t written = 0;
+       while (written < size) {
+               auto rest = reinterpret_cast<const unsigned char*>(buffer) + written;
+               auto bytes = ::write(this->fd, rest, size - written);
+               errno = 0;
+               if (bytes >= 0)
+                       written += bytes;
+               else if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR)
+                       continue;
+               else
+                       THROW(ErrCode::RuntimeError) << "Failed to write to socket: " << errno;
+       }
+}
+
+template<typename T>
+void Socket::recv(T *buffer, const std::size_t size) const
+{
+       std::size_t readen = 0;
+       while (readen < size) {
+               auto rest = reinterpret_cast<unsigned char*>(buffer) + readen;
+               auto bytes = ::read(this->fd, rest, size - readen);
+               errno = 0;
+               if (bytes >= 0)
+                       readen += bytes;
+               else if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR)
+                       continue;
+               else
+                       THROW(ErrCode::RuntimeError) << "Failed to read from socket: " << errno;
+       }
+}
+
+} // namespace ondemand
+} // namespace impl
+} // namespace rmi
+} // namespace vist
diff --git a/src/vist/rmi/impl/ondemand/tests/socket.cpp b/src/vist/rmi/impl/ondemand/tests/socket.cpp
new file mode 100644 (file)
index 0000000..86e0eaa
--- /dev/null
@@ -0,0 +1,101 @@
+/*
+ *  Copyright (c) 2020 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ *  Licensed under the Apache License, Version 2.0 (the "License");
+ *  you may not use this file except in compliance with the License.
+ *  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ *  Unless required by applicable law or agreed to in writing, software
+ *  distributed under the License is distributed on an "AS IS" BASIS,
+ *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ *  See the License for the specific language governing permissions and
+ *  limitations under the License
+ */
+
+#include <vist/rmi/impl/ondemand/socket.hpp>
+
+#include <string>
+#include <limits>
+#include <thread>
+#include <chrono>
+#include <cstring>
+
+#include <gtest/gtest.h>
+
+using namespace vist::rmi::impl::ondemand;
+
+TEST(SocketTests, socket_read_write)
+{
+       std::string sockPath = "./test.sock";
+       Socket socket(sockPath);
+
+       int input = std::numeric_limits<int>::max();
+       bool input2 = true;
+
+       int output = 0;
+       bool output2 = false;
+
+       auto client = std::thread([&]() {
+               std::this_thread::sleep_for(std::chrono::seconds(1));
+
+               // Send input to server.
+               Socket connected = Socket::connect(sockPath);
+               connected.send(&input);
+
+               // Recv input2 from server.
+               connected.recv(&output2);
+
+               EXPECT_EQ(input2, output2);
+       });
+
+       Socket accepted = socket.accept();
+
+       // Recv input from client.
+       accepted.recv(&output);
+       EXPECT_EQ(input, output);
+
+       // Send input2 to client.
+       accepted.send(&input2);
+
+       if (client.joinable())
+               client.join();
+}
+
+TEST(SocketTests, socket_abstract)
+{
+       std::string sockPath = "@sock";
+       Socket socket(sockPath);
+
+       int input = std::numeric_limits<int>::max();
+       bool input2 = true;
+
+       int output = 0;
+       bool output2 = false;
+
+       auto client = std::thread([&]() {
+               std::this_thread::sleep_for(std::chrono::seconds(1));
+
+               // Send input to server.
+               Socket connected = Socket::connect(sockPath);
+               connected.send(&input);
+
+               // Recv input2 from server.
+               connected.recv(&output2);
+
+               EXPECT_EQ(input2, output2);
+       });
+
+       Socket accepted = socket.accept();
+
+       // Recv input from client.
+       accepted.recv(&output);
+       EXPECT_EQ(input, output);
+
+       // Send input2 to client.
+       accepted.send(&input2);
+
+       if (client.joinable())
+               client.join();
+}