1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 // TODO(vtl): Factor out the POSIX-specific bits of this test (once we have a
6 // non-POSIX implementation).
8 #include "mojo/system/raw_channel.h"
12 #include <sys/socket.h>
13 #include <sys/types.h>
18 #include "base/basictypes.h"
19 #include "base/bind.h"
20 #include "base/callback.h"
21 #include "base/compiler_specific.h"
22 #include "base/location.h"
23 #include "base/logging.h"
24 #include "base/memory/scoped_ptr.h"
25 #include "base/memory/scoped_vector.h"
26 #include "base/message_loop/message_loop.h"
27 #include "base/posix/eintr_wrapper.h"
28 #include "base/rand_util.h"
29 #include "base/synchronization/lock.h"
30 #include "base/synchronization/waitable_event.h"
31 #include "base/threading/platform_thread.h" // For |Sleep()|.
32 #include "base/threading/simple_thread.h"
33 #include "base/threading/thread.h"
34 #include "base/time/time.h"
35 #include "mojo/system/message_in_transit.h"
36 #include "mojo/system/platform_channel_handle.h"
37 #include "mojo/system/test_utils.h"
38 #include "testing/gtest/include/gtest/gtest.h"
44 MessageInTransit* MakeTestMessage(uint32_t num_bytes) {
45 std::vector<unsigned char> bytes(num_bytes, 0);
46 for (size_t i = 0; i < num_bytes; i++)
47 bytes[i] = static_cast<unsigned char>(i + num_bytes);
48 return MessageInTransit::Create(bytes.data(), num_bytes);
// Returns true if the |num_bytes| bytes at |bytes| match the pattern produced
// by |MakeTestMessage()| (the byte at offset |i| equals |i + num_bytes|
// truncated to an |unsigned char|); returns false at the first mismatch.
bool CheckMessageData(const void* bytes, uint32_t num_bytes) {
  const unsigned char* const data = static_cast<const unsigned char*>(bytes);
  uint32_t offset = 0;
  while (offset < num_bytes) {
    const unsigned char expected =
        static_cast<unsigned char>(offset + num_bytes);
    if (data[offset] != expected)
      return false;
    ++offset;
  }
  return true;
}
60 // -----------------------------------------------------------------------------
62 class RawChannelPosixTest : public testing::Test {
64 RawChannelPosixTest() : io_thread_("io_thread") {
69 virtual ~RawChannelPosixTest() {
72 virtual void SetUp() OVERRIDE {
73 io_thread_.StartWithOptions(
74 base::Thread::Options(base::MessageLoop::TYPE_IO, 0));
77 PCHECK(socketpair(AF_UNIX, SOCK_STREAM, 0, fds_) == 0);
79 // Set the ends to non-blocking.
80 PCHECK(fcntl(fds_[0], F_SETFL, O_NONBLOCK) == 0);
81 PCHECK(fcntl(fds_[1], F_SETFL, O_NONBLOCK) == 0);
84 virtual void TearDown() OVERRIDE {
86 CHECK_EQ(close(fds_[0]), 0);
88 CHECK_EQ(close(fds_[1]), 0);
94 int fd(size_t i) { return fds_[i]; }
95 void clear_fd(size_t i) { fds_[i] = -1; }
97 base::MessageLoop* io_thread_message_loop() {
98 return io_thread_.message_loop();
101 scoped_refptr<base::TaskRunner> io_thread_task_runner() {
102 return io_thread_message_loop()->message_loop_proxy();
106 base::Thread io_thread_;
109 DISALLOW_COPY_AND_ASSIGN(RawChannelPosixTest);
112 // RawChannelPosixTest.WriteMessage --------------------------------------------
114 class WriteOnlyRawChannelDelegate : public RawChannel::Delegate {
116 WriteOnlyRawChannelDelegate() {}
117 virtual ~WriteOnlyRawChannelDelegate() {}
119 // |RawChannel::Delegate| implementation:
120 virtual void OnReadMessage(const MessageInTransit& /*message*/) OVERRIDE {
123 virtual void OnFatalError(FatalError /*fatal_error*/) OVERRIDE {
128 DISALLOW_COPY_AND_ASSIGN(WriteOnlyRawChannelDelegate);
131 static const int64_t kMessageReaderSleepMs = 1;
132 static const size_t kMessageReaderMaxPollIterations = 3000;
134 class TestMessageReaderAndChecker {
136 explicit TestMessageReaderAndChecker(int fd) : fd_(fd) {}
137 ~TestMessageReaderAndChecker() { CHECK(bytes_.empty()); }
139 bool ReadAndCheckNextMessage(uint32_t expected_size) {
140 unsigned char buffer[4096];
142 for (size_t i = 0; i < kMessageReaderMaxPollIterations;) {
143 ssize_t read_size = HANDLE_EINTR(read(fd_, buffer, sizeof(buffer)));
145 PCHECK(errno == EAGAIN || errno == EWOULDBLOCK);
149 // Append newly-read data to |bytes_|.
150 bytes_.insert(bytes_.end(), buffer, buffer + read_size);
152 // If we have the header....
153 if (bytes_.size() >= sizeof(MessageInTransit)) {
154 const MessageInTransit* message =
155 reinterpret_cast<const MessageInTransit*>(bytes_.data());
156 CHECK_EQ(reinterpret_cast<size_t>(message) %
157 MessageInTransit::kMessageAlignment, 0u);
159 if (message->data_size() != expected_size) {
160 LOG(ERROR) << "Wrong size: " << message->data_size() << " instead of "
161 << expected_size << " bytes.";
165 // If we've read the whole message....
166 if (bytes_.size() >= message->size_with_header_and_padding()) {
167 if (!CheckMessageData(message->data(), message->data_size())) {
168 LOG(ERROR) << "Incorrect message data.";
172 // Erase message data.
173 bytes_.erase(bytes_.begin(),
175 message->size_with_header_and_padding());
180 if (static_cast<size_t>(read_size) < sizeof(buffer)) {
182 base::PlatformThread::Sleep(
183 base::TimeDelta::FromMilliseconds(kMessageReaderSleepMs));
187 LOG(ERROR) << "Too many iterations.";
194 // The start of the received data should always be on a message boundary.
195 std::vector<unsigned char> bytes_;
197 DISALLOW_COPY_AND_ASSIGN(TestMessageReaderAndChecker);
200 // Tests writing (and verifies reading using our own custom reader).
201 TEST_F(RawChannelPosixTest, WriteMessage) {
202 WriteOnlyRawChannelDelegate delegate;
203 scoped_ptr<RawChannel> rc(RawChannel::Create(PlatformChannelHandle(fd(0)),
205 io_thread_message_loop()));
206 // |RawChannel::Create()| takes ownership of the FD.
209 TestMessageReaderAndChecker checker(fd(1));
211 test::PostTaskAndWait(io_thread_task_runner(),
213 base::Bind(&RawChannel::Init,
214 base::Unretained(rc.get())));
216 // Write and read, for a variety of sizes.
217 for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1) {
218 EXPECT_TRUE(rc->WriteMessage(MakeTestMessage(size)));
219 EXPECT_TRUE(checker.ReadAndCheckNextMessage(size)) << size;
222 // Write/queue and read afterwards, for a variety of sizes.
223 for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1)
224 EXPECT_TRUE(rc->WriteMessage(MakeTestMessage(size)));
225 for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1)
226 EXPECT_TRUE(checker.ReadAndCheckNextMessage(size)) << size;
228 test::PostTaskAndWait(io_thread_task_runner(),
230 base::Bind(&RawChannel::Shutdown,
231 base::Unretained(rc.get())));
234 // RawChannelPosixTest.OnReadMessage -------------------------------------------
236 class ReadCheckerRawChannelDelegate : public RawChannel::Delegate {
238 ReadCheckerRawChannelDelegate()
239 : done_event_(false, false),
241 virtual ~ReadCheckerRawChannelDelegate() {}
243 // |RawChannel::Delegate| implementation (called on the I/O thread):
244 virtual void OnReadMessage(const MessageInTransit& message) OVERRIDE {
246 size_t expected_size;
247 bool should_signal = false;
249 base::AutoLock locker(lock_);
250 CHECK_LT(position_, expected_sizes_.size());
251 position = position_;
252 expected_size = expected_sizes_[position];
254 if (position_ >= expected_sizes_.size())
255 should_signal = true;
258 EXPECT_EQ(expected_size, message.data_size()) << position;
259 if (message.data_size() == expected_size) {
260 EXPECT_TRUE(CheckMessageData(message.data(), message.data_size()))
265 done_event_.Signal();
267 virtual void OnFatalError(FatalError /*fatal_error*/) OVERRIDE {
271 // Wait for all the messages (of sizes |expected_sizes_|) to be seen.
276 void SetExpectedSizes(const std::vector<uint32_t>& expected_sizes) {
277 base::AutoLock locker(lock_);
278 CHECK_EQ(position_, expected_sizes_.size());
279 expected_sizes_ = expected_sizes;
284 base::WaitableEvent done_event_;
286 base::Lock lock_; // Protects the following members.
287 std::vector<uint32_t> expected_sizes_;
290 DISALLOW_COPY_AND_ASSIGN(ReadCheckerRawChannelDelegate);
293 // Tests reading (writing using our own custom writer).
294 TEST_F(RawChannelPosixTest, OnReadMessage) {
295 // We're going to write to |fd(1)|. We'll do so in a blocking manner.
296 PCHECK(fcntl(fd(1), F_SETFL, 0) == 0);
298 ReadCheckerRawChannelDelegate delegate;
299 scoped_ptr<RawChannel> rc(RawChannel::Create(PlatformChannelHandle(fd(0)),
301 io_thread_message_loop()));
302 // |RawChannel::Create()| takes ownership of the FD.
305 test::PostTaskAndWait(io_thread_task_runner(),
307 base::Bind(&RawChannel::Init,
308 base::Unretained(rc.get())));
310 // Write and read, for a variety of sizes.
311 for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1) {
312 delegate.SetExpectedSizes(std::vector<uint32_t>(1, size));
313 MessageInTransit* message = MakeTestMessage(size);
314 EXPECT_EQ(static_cast<ssize_t>(message->size_with_header_and_padding()),
315 write(fd(1), message, message->size_with_header_and_padding()));
320 // Set up reader and write as fast as we can.
321 // Write/queue and read afterwards, for a variety of sizes.
322 std::vector<uint32_t> expected_sizes;
323 for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1)
324 expected_sizes.push_back(size);
325 delegate.SetExpectedSizes(expected_sizes);
326 for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1) {
327 MessageInTransit* message = MakeTestMessage(size);
328 EXPECT_EQ(static_cast<ssize_t>(message->size_with_header_and_padding()),
329 write(fd(1), message, message->size_with_header_and_padding()));
334 test::PostTaskAndWait(io_thread_task_runner(),
336 base::Bind(&RawChannel::Shutdown,
337 base::Unretained(rc.get())));
340 // RawChannelPosixTest.WriteMessageAndOnReadMessage ----------------------------
342 class RawChannelWriterThread : public base::SimpleThread {
344 RawChannelWriterThread(RawChannel* raw_channel, size_t write_count)
345 : base::SimpleThread("raw_channel_writer_thread"),
346 raw_channel_(raw_channel),
347 left_to_write_(write_count) {
350 virtual ~RawChannelWriterThread() {
355 virtual void Run() OVERRIDE {
356 static const int kMaxRandomMessageSize = 25000;
358 while (left_to_write_-- > 0) {
359 EXPECT_TRUE(raw_channel_->WriteMessage(MakeTestMessage(
360 static_cast<uint32_t>(base::RandInt(1, kMaxRandomMessageSize)))));
364 RawChannel* const raw_channel_;
365 size_t left_to_write_;
367 DISALLOW_COPY_AND_ASSIGN(RawChannelWriterThread);
370 class ReadCountdownRawChannelDelegate : public RawChannel::Delegate {
372 explicit ReadCountdownRawChannelDelegate(size_t expected_count)
373 : done_event_(false, false),
374 expected_count_(expected_count),
376 virtual ~ReadCountdownRawChannelDelegate() {}
378 // |RawChannel::Delegate| implementation (called on the I/O thread):
379 virtual void OnReadMessage(const MessageInTransit& message) OVERRIDE {
380 EXPECT_LT(count_, expected_count_);
383 EXPECT_TRUE(CheckMessageData(message.data(), message.data_size()));
385 if (count_ >= expected_count_)
386 done_event_.Signal();
388 virtual void OnFatalError(FatalError /*fatal_error*/) OVERRIDE {
392 // Wait for all the messages to have been seen.
398 base::WaitableEvent done_event_;
399 size_t expected_count_;
402 DISALLOW_COPY_AND_ASSIGN(ReadCountdownRawChannelDelegate);
405 TEST_F(RawChannelPosixTest, WriteMessageAndOnReadMessage) {
406 static const size_t kNumWriterThreads = 10;
407 static const size_t kNumWriteMessagesPerThread = 4000;
409 WriteOnlyRawChannelDelegate writer_delegate;
410 scoped_ptr<RawChannel> writer_rc(
411 RawChannel::Create(PlatformChannelHandle(fd(0)),
413 io_thread_message_loop()));
414 // |RawChannel::Create()| takes ownership of the FD.
417 test::PostTaskAndWait(io_thread_task_runner(),
419 base::Bind(&RawChannel::Init,
420 base::Unretained(writer_rc.get())));
422 ReadCountdownRawChannelDelegate reader_delegate(
423 kNumWriterThreads * kNumWriteMessagesPerThread);
424 scoped_ptr<RawChannel> reader_rc(
425 RawChannel::Create(PlatformChannelHandle(fd(1)),
427 io_thread_message_loop()));
428 // |RawChannel::Create()| takes ownership of the FD.
431 test::PostTaskAndWait(io_thread_task_runner(),
433 base::Bind(&RawChannel::Init,
434 base::Unretained(reader_rc.get())));
437 ScopedVector<RawChannelWriterThread> writer_threads;
438 for (size_t i = 0; i < kNumWriterThreads; i++) {
439 writer_threads.push_back(new RawChannelWriterThread(
440 writer_rc.get(), kNumWriteMessagesPerThread));
442 for (size_t i = 0; i < writer_threads.size(); i++)
443 writer_threads[i]->Start();
444 } // Joins all the writer threads.
446 // Sleep a bit, to let any extraneous reads be processed. (There shouldn't be
447 // any, but we want to know about them.)
448 base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds(100));
450 // Wait for reading to finish.
451 reader_delegate.Wait();
453 test::PostTaskAndWait(io_thread_task_runner(),
455 base::Bind(&RawChannel::Shutdown,
456 base::Unretained(reader_rc.get())));
458 test::PostTaskAndWait(io_thread_task_runner(),
460 base::Bind(&RawChannel::Shutdown,
461 base::Unretained(writer_rc.get())));
464 // RawChannelPosixTest.OnFatalError --------------------------------------------
466 class FatalErrorRecordingRawChannelDelegate : public RawChannel::Delegate {
468 FatalErrorRecordingRawChannelDelegate()
469 : got_fatal_error_event_(false, false),
470 on_fatal_error_call_count_(0),
471 last_fatal_error_(FATAL_ERROR_UNKNOWN) {}
472 virtual ~FatalErrorRecordingRawChannelDelegate() {}
474 // |RawChannel::Delegate| implementation:
475 virtual void OnReadMessage(const MessageInTransit& /*message*/) OVERRIDE {
478 virtual void OnFatalError(FatalError fatal_error) OVERRIDE {
479 CHECK_EQ(on_fatal_error_call_count_, 0);
480 on_fatal_error_call_count_++;
481 last_fatal_error_ = fatal_error;
482 got_fatal_error_event_.Signal();
485 FatalError WaitForFatalError() {
486 got_fatal_error_event_.Wait();
487 CHECK_EQ(on_fatal_error_call_count_, 1);
488 return last_fatal_error_;
492 base::WaitableEvent got_fatal_error_event_;
494 int on_fatal_error_call_count_;
495 FatalError last_fatal_error_;
497 DISALLOW_COPY_AND_ASSIGN(FatalErrorRecordingRawChannelDelegate);
500 // Tests fatal errors.
501 // TODO(vtl): Figure out how to make reading fail reliably. (I'm not convinced
503 TEST_F(RawChannelPosixTest, OnFatalError) {
504 FatalErrorRecordingRawChannelDelegate delegate;
505 scoped_ptr<RawChannel> rc(RawChannel::Create(PlatformChannelHandle(fd(0)),
507 io_thread_message_loop()));
508 // |RawChannel::Create()| takes ownership of the FD.
511 test::PostTaskAndWait(io_thread_task_runner(),
513 base::Bind(&RawChannel::Init,
514 base::Unretained(rc.get())));
516 // Close the other end, which should make writing fail.
517 CHECK_EQ(close(fd(1)), 0);
520 EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(1)));
522 // TODO(vtl): In theory, it's conceivable that closing the other end might
523 // lead to read failing. In practice, it doesn't seem to.
524 EXPECT_EQ(RawChannel::Delegate::FATAL_ERROR_FAILED_WRITE,
525 delegate.WaitForFatalError());
527 test::PostTaskAndWait(io_thread_task_runner(),
529 base::Bind(&RawChannel::Shutdown,
530 base::Unretained(rc.get())));
534 // RawChannelPosixTest.WriteMessageAfterShutdown -------------------------------
536 // Makes sure that calling |WriteMessage()| after |Shutdown()| behaves
538 TEST_F(RawChannelPosixTest, WriteMessageAfterShutdown) {
539 WriteOnlyRawChannelDelegate delegate;
540 scoped_ptr<RawChannel> rc(RawChannel::Create(PlatformChannelHandle(fd(0)),
542 io_thread_message_loop()));
543 // |RawChannel::Create()| takes ownership of the FD.
546 test::PostTaskAndWait(io_thread_task_runner(),
548 base::Bind(&RawChannel::Init,
549 base::Unretained(rc.get())));
550 test::PostTaskAndWait(io_thread_task_runner(),
552 base::Bind(&RawChannel::Shutdown,
553 base::Unretained(rc.get())));
555 EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(1)));
559 } // namespace system