1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "tools/android/forwarder2/socket.h"
10 #include <netinet/in.h>
13 #include <sys/socket.h>
14 #include <sys/types.h>
17 #include "base/logging.h"
18 #include "base/posix/eintr_wrapper.h"
19 #include "base/safe_strerror_posix.h"
20 #include "tools/android/common/net.h"
21 #include "tools/android/forwarder2/common.h"
24 const int kNoTimeout = -1;
25 const int kConnectTimeOut = 10; // Seconds.
27 bool FamilyIsTCP(int family) {
28 return family == AF_INET || family == AF_INET6;
32 namespace forwarder2 {
34 bool Socket::BindUnix(const std::string& path) {
36 if (!InitUnixSocket(path) || !BindAndListen()) {
43 bool Socket::BindTcp(const std::string& host, int port) {
45 if (!InitTcpSocket(host, port) || !BindAndListen()) {
52 bool Socket::ConnectUnix(const std::string& path) {
54 if (!InitUnixSocket(path) || !Connect()) {
61 bool Socket::ConnectTcp(const std::string& host, int port) {
63 if (!InitTcpSocket(host, port) || !Connect()) {
75 addr_ptr_(reinterpret_cast<sockaddr*>(&addr_.addr4)),
76 addr_len_(sizeof(sockaddr)) {
77 memset(&addr_, 0, sizeof(addr_));
84 void Socket::Shutdown() {
86 PRESERVE_ERRNO_HANDLE_EINTR(shutdown(socket_, SHUT_RDWR));
90 void Socket::Close() {
97 bool Socket::InitSocketInternal() {
98 socket_ = socket(family_, SOCK_STREAM, 0);
100 PLOG(ERROR) << "socket";
103 tools::DisableNagle(socket_);
105 setsockopt(socket_, SOL_SOCKET, SO_REUSEADDR, &reuse_addr,
107 if (!SetNonBlocking())
112 bool Socket::SetNonBlocking() {
113 const int flags = fcntl(socket_, F_GETFL);
115 PLOG(ERROR) << "fcntl";
118 if (flags & O_NONBLOCK)
120 if (fcntl(socket_, F_SETFL, flags | O_NONBLOCK) < 0) {
121 PLOG(ERROR) << "fcntl";
127 bool Socket::InitUnixSocket(const std::string& path) {
128 static const size_t kPathMax = sizeof(addr_.addr_un.sun_path);
129 // Abstract sockets need one byte for the leading '\0', plus one spare byte so the name is always followed by a '\0'.
130 if (path.size() + 2 /* '\0' */ > kPathMax) {
131 LOG(ERROR) << "The provided path is too big to create a unix "
132 << "domain socket: " << path;
136 addr_.addr_un.sun_family = family_;
137 // Copied from net/socket/unix_domain_socket_posix.cc
138 // Convert the path given into abstract socket name. It must start with
139 // the '\0' character, so we are adding it. |addr_len| must specify the
140 // length of the structure exactly, as potentially the socket name may
141 // have '\0' characters embedded (although we don't support this).
142 // Note that addr_.addr_un.sun_path is already zero initialized.
143 memcpy(addr_.addr_un.sun_path + 1, path.c_str(), path.size());
144 addr_len_ = path.size() + offsetof(struct sockaddr_un, sun_path) + 1;
145 addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr_un);
146 return InitSocketInternal();
149 bool Socket::InitTcpSocket(const std::string& host, int port) {
152 // Use localhost: INADDR_LOOPBACK
154 addr_.addr4.sin_family = family_;
155 addr_.addr4.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
156 } else if (!Resolve(host)) {
159 CHECK(FamilyIsTCP(family_)) << "Invalid socket family.";
160 if (family_ == AF_INET) {
161 addr_.addr4.sin_port = htons(port_);
162 addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr4);
163 addr_len_ = sizeof(addr_.addr4);
164 } else if (family_ == AF_INET6) {
165 addr_.addr6.sin6_port = htons(port_);
166 addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr6);
167 addr_len_ = sizeof(addr_.addr6);
169 return InitSocketInternal();
172 bool Socket::BindAndListen() {
174 if (HANDLE_EINTR(bind(socket_, addr_ptr_, addr_len_)) < 0 ||
175 HANDLE_EINTR(listen(socket_, SOMAXCONN)) < 0) {
176 PLOG(ERROR) << "bind/listen";
180 if (port_ == 0 && FamilyIsTCP(family_)) {
182 memset(&addr, 0, sizeof(addr));
183 socklen_t addrlen = 0;
184 sockaddr* addr_ptr = NULL;
185 uint16* port_ptr = NULL;
186 if (family_ == AF_INET) {
187 addr_ptr = reinterpret_cast<sockaddr*>(&addr.addr4);
188 port_ptr = &addr.addr4.sin_port;
189 addrlen = sizeof(addr.addr4);
190 } else if (family_ == AF_INET6) {
191 addr_ptr = reinterpret_cast<sockaddr*>(&addr.addr6);
192 port_ptr = &addr.addr6.sin6_port;
193 addrlen = sizeof(addr.addr6);
196 if (getsockname(socket_, addr_ptr, &addrlen) != 0) {
197 PLOG(ERROR) << "getsockname";
201 port_ = ntohs(*port_ptr);
206 bool Socket::Accept(Socket* new_socket) {
207 DCHECK(new_socket != NULL);
208 if (!WaitForEvent(READ, kNoTimeout)) {
213 int new_socket_fd = HANDLE_EINTR(accept(socket_, NULL, NULL));
214 if (new_socket_fd < 0) {
218 tools::DisableNagle(new_socket_fd);
219 new_socket->socket_ = new_socket_fd;
220 if (!new_socket->SetNonBlocking())
225 bool Socket::Connect() {
226 DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
228 if (HANDLE_EINTR(connect(socket_, addr_ptr_, addr_len_)) < 0 &&
229 errno != EINPROGRESS) {
233 // Wait for connection to complete, or receive a notification.
234 if (!WaitForEvent(WRITE, kConnectTimeOut)) {
239 socklen_t opt_len = sizeof(socket_errno);
240 if (getsockopt(socket_, SOL_SOCKET, SO_ERROR, &socket_errno, &opt_len) < 0) {
241 PLOG(ERROR) << "getsockopt()";
245 if (socket_errno != 0) {
246 LOG(ERROR) << "Could not connect to host: " << safe_strerror(socket_errno);
253 bool Socket::Resolve(const std::string& host) {
254 struct addrinfo hints;
255 struct addrinfo* res;
256 memset(&hints, 0, sizeof(hints));
257 hints.ai_family = AF_UNSPEC;
258 hints.ai_socktype = SOCK_STREAM;
259 hints.ai_flags |= AI_CANONNAME;
261 int errcode = getaddrinfo(host.c_str(), NULL, &hints, &res);
268 family_ = res->ai_family;
269 switch (res->ai_family) {
272 reinterpret_cast<sockaddr_in*>(res->ai_addr),
273 sizeof(sockaddr_in));
277 reinterpret_cast<sockaddr_in6*>(res->ai_addr),
278 sizeof(sockaddr_in6));
285 int Socket::GetPort() {
286 if (!FamilyIsTCP(family_)) {
287 LOG(ERROR) << "Can't call GetPort() on an unix domain socket.";
293 int Socket::ReadNumBytes(void* buffer, size_t num_bytes) {
294 size_t bytes_read = 0;
296 while (bytes_read < num_bytes && ret > 0) {
297 ret = Read(static_cast<char*>(buffer) + bytes_read, num_bytes - bytes_read);
304 void Socket::SetSocketError() {
305 socket_error_ = true;
306 DCHECK_NE(EAGAIN, errno);
307 DCHECK_NE(EWOULDBLOCK, errno);
311 int Socket::Read(void* buffer, size_t buffer_size) {
312 if (!WaitForEvent(READ, kNoTimeout)) {
316 int ret = HANDLE_EINTR(read(socket_, buffer, buffer_size));
318 PLOG(ERROR) << "read";
324 int Socket::NonBlockingRead(void* buffer, size_t buffer_size) {
325 DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
326 int ret = HANDLE_EINTR(read(socket_, buffer, buffer_size));
328 PLOG(ERROR) << "read";
334 int Socket::Write(const void* buffer, size_t count) {
335 if (!WaitForEvent(WRITE, kNoTimeout)) {
339 int ret = HANDLE_EINTR(send(socket_, buffer, count, MSG_NOSIGNAL));
341 PLOG(ERROR) << "send";
347 int Socket::NonBlockingWrite(const void* buffer, size_t count) {
348 DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
349 int ret = HANDLE_EINTR(send(socket_, buffer, count, MSG_NOSIGNAL));
351 PLOG(ERROR) << "send";
357 int Socket::WriteString(const std::string& buffer) {
358 return WriteNumBytes(buffer.c_str(), buffer.size());
361 void Socket::AddEventFd(int event_fd) {
364 event.was_fired = false;
365 events_.push_back(event);
368 bool Socket::DidReceiveEventOnFd(int fd) const {
369 for (size_t i = 0; i < events_.size(); ++i)
370 if (events_[i].fd == fd)
371 return events_[i].was_fired;
375 bool Socket::DidReceiveEvent() const {
376 for (size_t i = 0; i < events_.size(); ++i)
377 if (events_[i].was_fired)
382 int Socket::WriteNumBytes(const void* buffer, size_t num_bytes) {
383 size_t bytes_written = 0;
385 while (bytes_written < num_bytes && ret > 0) {
386 ret = Write(static_cast<const char*>(buffer) + bytes_written,
387 num_bytes - bytes_written);
389 bytes_written += ret;
391 return bytes_written;
394 bool Socket::WaitForEvent(EventType type, int timeout_secs) {
397 DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
403 FD_SET(socket_, &read_fds);
405 FD_SET(socket_, &write_fds);
406 for (size_t i = 0; i < events_.size(); ++i)
407 FD_SET(events_[i].fd, &read_fds);
409 timeval* tv_ptr = NULL;
410 if (timeout_secs > 0) {
411 tv.tv_sec = timeout_secs;
415 int max_fd = socket_;
416 for (size_t i = 0; i < events_.size(); ++i)
417 if (events_[i].fd > max_fd)
418 max_fd = events_[i].fd;
420 select(max_fd + 1, &read_fds, &write_fds, NULL, tv_ptr)) <= 0) {
421 PLOG(ERROR) << "select";
424 bool event_was_fired = false;
425 for (size_t i = 0; i < events_.size(); ++i) {
426 if (FD_ISSET(events_[i].fd, &read_fds)) {
427 events_[i].was_fired = true;
428 event_was_fired = true;
431 return !event_was_fired;
435 pid_t Socket::GetUnixDomainSocketProcessOwner(const std::string& path) {
437 if (!socket.ConnectUnix(path))
440 socklen_t len = sizeof(ucred);
441 if (getsockopt(socket.socket_, SOL_SOCKET, SO_PEERCRED, &ucred, &len) == -1) {
442 CHECK_NE(ENOPROTOOPT, errno);
448 } // namespace forwarder2