- add sources.
[platform/framework/web/crosswalk.git] / src / ipc / unix_domain_socket_util_unittest.cc
1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include <sys/socket.h>
6
7 #include "base/bind.h"
8 #include "base/files/file_path.h"
9 #include "base/path_service.h"
10 #include "base/posix/eintr_wrapper.h"
11 #include "base/synchronization/waitable_event.h"
12 #include "base/threading/thread.h"
13 #include "base/threading/thread_restrictions.h"
14 #include "ipc/unix_domain_socket_util.h"
15 #include "testing/gtest/include/gtest/gtest.h"
16
17 namespace {
18
19 class SocketAcceptor : public base::MessageLoopForIO::Watcher {
20  public:
21   SocketAcceptor(int fd, base::MessageLoopProxy* target_thread)
22       : server_fd_(-1),
23         target_thread_(target_thread),
24         started_watching_event_(false, false),
25         accepted_event_(false, false) {
26     target_thread->PostTask(FROM_HERE,
27         base::Bind(&SocketAcceptor::StartWatching, base::Unretained(this), fd));
28   }
29
30   virtual ~SocketAcceptor() {
31     Close();
32   }
33
34   int server_fd() const { return server_fd_; }
35
36   void WaitUntilReady() {
37     started_watching_event_.Wait();
38   }
39
40   void WaitForAccept() {
41     accepted_event_.Wait();
42   }
43
44   void Close() {
45     if (watcher_.get()) {
46       target_thread_->PostTask(FROM_HERE,
47           base::Bind(&SocketAcceptor::StopWatching, base::Unretained(this),
48               watcher_.release()));
49     }
50   }
51
52  private:
53   void StartWatching(int fd) {
54     watcher_.reset(new base::MessageLoopForIO::FileDescriptorWatcher);
55     base::MessageLoopForIO::current()->WatchFileDescriptor(
56         fd, true, base::MessageLoopForIO::WATCH_READ, watcher_.get(), this);
57     started_watching_event_.Signal();
58   }
59   void StopWatching(base::MessageLoopForIO::FileDescriptorWatcher* watcher) {
60     watcher->StopWatchingFileDescriptor();
61     delete watcher;
62   }
63   virtual void OnFileCanReadWithoutBlocking(int fd) OVERRIDE {
64     ASSERT_EQ(-1, server_fd_);
65     IPC::ServerAcceptConnection(fd, &server_fd_);
66     watcher_->StopWatchingFileDescriptor();
67     accepted_event_.Signal();
68   }
69   virtual void OnFileCanWriteWithoutBlocking(int fd) OVERRIDE {}
70
71   int server_fd_;
72   base::MessageLoopProxy* target_thread_;
73   scoped_ptr<base::MessageLoopForIO::FileDescriptorWatcher> watcher_;
74   base::WaitableEvent started_watching_event_;
75   base::WaitableEvent accepted_event_;
76
77   DISALLOW_COPY_AND_ASSIGN(SocketAcceptor);
78 };
79
80 const base::FilePath GetChannelDir() {
81 #if defined(OS_ANDROID)
82   base::FilePath tmp_dir;
83   PathService::Get(base::DIR_CACHE, &tmp_dir);
84   return tmp_dir;
85 #else
86   return base::FilePath("/var/tmp");
87 #endif
88 }
89
90 class TestUnixSocketConnection {
91  public:
92   TestUnixSocketConnection()
93       : worker_("WorkerThread"),
94         server_listen_fd_(-1),
95         server_fd_(-1),
96         client_fd_(-1) {
97     socket_name_ = GetChannelDir().Append("TestSocket");
98     base::Thread::Options options;
99     options.message_loop_type = base::MessageLoop::TYPE_IO;
100     worker_.StartWithOptions(options);
101   }
102
103   bool CreateServerSocket() {
104     IPC::CreateServerUnixDomainSocket(socket_name_, &server_listen_fd_);
105     if (server_listen_fd_ < 0)
106       return false;
107     struct stat socket_stat;
108     stat(socket_name_.value().c_str(), &socket_stat);
109     EXPECT_TRUE(S_ISSOCK(socket_stat.st_mode));
110     acceptor_.reset(new SocketAcceptor(server_listen_fd_,
111                                        worker_.message_loop_proxy().get()));
112     acceptor_->WaitUntilReady();
113     return true;
114   }
115
116   bool CreateClientSocket() {
117     DCHECK(server_listen_fd_ >= 0);
118     IPC::CreateClientUnixDomainSocket(socket_name_, &client_fd_);
119     if (client_fd_ < 0)
120       return false;
121     acceptor_->WaitForAccept();
122     server_fd_ = acceptor_->server_fd();
123     return server_fd_ >= 0;
124   }
125
126   virtual ~TestUnixSocketConnection() {
127     if (client_fd_ >= 0)
128       close(client_fd_);
129     if (server_fd_ >= 0)
130       close(server_fd_);
131     if (server_listen_fd_ >= 0) {
132       close(server_listen_fd_);
133       unlink(socket_name_.value().c_str());
134     }
135   }
136
137   int client_fd() const { return client_fd_; }
138   int server_fd() const { return server_fd_; }
139
140  private:
141   base::Thread worker_;
142   base::FilePath socket_name_;
143   int server_listen_fd_;
144   int server_fd_;
145   int client_fd_;
146   scoped_ptr<SocketAcceptor> acceptor_;
147 };
148
149 // Ensure that IPC::CreateServerUnixDomainSocket creates a socket that
150 // IPC::CreateClientUnixDomainSocket can successfully connect to.
151 TEST(UnixDomainSocketUtil, Connect) {
152   TestUnixSocketConnection connection;
153   ASSERT_TRUE(connection.CreateServerSocket());
154   ASSERT_TRUE(connection.CreateClientSocket());
155 }
156
157 // Ensure that messages can be sent across the resulting socket.
158 TEST(UnixDomainSocketUtil, SendReceive) {
159   TestUnixSocketConnection connection;
160   ASSERT_TRUE(connection.CreateServerSocket());
161   ASSERT_TRUE(connection.CreateClientSocket());
162
163   const char buffer[] = "Hello, server!";
164   size_t buf_len = sizeof(buffer);
165   size_t sent_bytes =
166       HANDLE_EINTR(send(connection.client_fd(), buffer, buf_len, 0));
167   ASSERT_EQ(buf_len, sent_bytes);
168   char recv_buf[sizeof(buffer)];
169   size_t received_bytes =
170       HANDLE_EINTR(recv(connection.server_fd(), recv_buf, buf_len, 0));
171   ASSERT_EQ(buf_len, received_bytes);
172   ASSERT_EQ(0, memcmp(recv_buf, buffer, buf_len));
173 }
174
175 }  // namespace