1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "mojo/system/raw_channel.h"
11 #include "base/bind.h"
12 #include "base/location.h"
13 #include "base/logging.h"
14 #include "base/macros.h"
15 #include "base/memory/scoped_ptr.h"
16 #include "base/memory/scoped_vector.h"
17 #include "base/rand_util.h"
18 #include "base/synchronization/lock.h"
19 #include "base/synchronization/waitable_event.h"
20 #include "base/threading/platform_thread.h" // For |Sleep()|.
21 #include "base/threading/simple_thread.h"
22 #include "base/time/time.h"
23 #include "build/build_config.h"
24 #include "mojo/common/test/test_utils.h"
25 #include "mojo/embedder/platform_channel_pair.h"
26 #include "mojo/embedder/platform_handle.h"
27 #include "mojo/embedder/scoped_platform_handle.h"
28 #include "mojo/system/message_in_transit.h"
29 #include "mojo/system/test_utils.h"
30 #include "testing/gtest/include/gtest/gtest.h"
// Builds a test message (type |kTypeMessagePipeEndpoint|, subtype
// |kSubtypeMessagePipeEndpointData|) carrying |num_bytes| payload bytes in a
// deterministic pattern: byte i is static_cast<unsigned char>(i + num_bytes).
// This is the same pattern that |CheckMessageData()| verifies.
36 scoped_ptr<MessageInTransit> MakeTestMessage(uint32_t num_bytes) {
37 std::vector<unsigned char> bytes(num_bytes, 0);
38 for (size_t i = 0; i < num_bytes; i++)
39 bytes[i] = static_cast<unsigned char>(i + num_bytes);
// When |num_bytes| is 0 the vector is empty and has no element 0 to take the
// address of, so NULL is passed as the data pointer instead.
40 return make_scoped_ptr(
41 new MessageInTransit(MessageInTransit::kTypeMessagePipeEndpoint,
42 MessageInTransit::kSubtypeMessagePipeEndpointData,
43 num_bytes, bytes.empty() ? NULL : &bytes[0]));
// Checks that the |num_bytes| bytes at |bytes| follow the pattern written by
// |MakeTestMessage(num_bytes)|: byte i must equal
// static_cast<unsigned char>(i + num_bytes). Callers treat a true result as
// "data correct" and false as "Incorrect message bytes".
46 bool CheckMessageData(const void* bytes, uint32_t num_bytes) {
47 const unsigned char* b = static_cast<const unsigned char*>(bytes);
48 for (uint32_t i = 0; i < num_bytes; i++) {
49 if (b[i] != static_cast<unsigned char>(i + num_bytes))
// Initializes |raw_channel| with |delegate| and CHECK-fails if |Init()|
// returns false. The tests run this on the I/O thread via
// |io_thread()->PostTaskAndWait()|.
55 void InitOnIOThread(RawChannel* raw_channel, RawChannel::Delegate* delegate) {
56 CHECK(raw_channel->Init(delegate));
// Builds |MakeTestMessage(num_bytes)| and writes its serialized main buffer
// directly to |handle| with |mojo::test::BlockingWrite()|, bypassing any
// |RawChannel|. Returns true only if the reported number of bytes written
// equals the full main buffer size.
59 bool WriteTestMessageToHandle(const embedder::PlatformHandle& handle,
61 scoped_ptr<MessageInTransit> message(MakeTestMessage(num_bytes));
63 size_t write_size = 0;
// The return value of |BlockingWrite()| is not checked; success is judged
// solely from |write_size|.
64 mojo::test::BlockingWrite(
65 handle, message->main_buffer(), message->main_buffer_size(), &write_size);
66 return write_size == message->main_buffer_size();
69 // -----------------------------------------------------------------------------
// Test fixture. |SetUp()| creates a connected |PlatformChannelPair|. It keeps
// the server end in |handles[0]| and the client end in |handles[1]|. Tests
// usually hand |handles[0]| to a |RawChannel| and use |handles[1]| as the
// peer. |io_thread_| is constructed with |kManualStart|, so it does not run
// until it is started explicitly (presumably elsewhere in |SetUp()| -- TODO
// confirm).
71 class RawChannelTest : public testing::Test {
73 RawChannelTest() : io_thread_(test::TestIOThread::kManualStart) {}
74 virtual ~RawChannelTest() {}
76 virtual void SetUp() OVERRIDE {
77 embedder::PlatformChannelPair channel_pair;
78 handles[0] = channel_pair.PassServerHandle();
79 handles[1] = channel_pair.PassClientHandle();
83 virtual void TearDown() OVERRIDE {
// Accessor for the I/O thread on which |RawChannel| init and shutdown run.
90 test::TestIOThread* io_thread() { return &io_thread_; }
// [0]: server end of the channel pair; [1]: client end.
92 embedder::ScopedPlatformHandle handles[2];
95 test::TestIOThread io_thread_;
97 DISALLOW_COPY_AND_ASSIGN(RawChannelTest);
100 // RawChannelTest.WriteMessage -------------------------------------------------
// Delegate for tests that only write through their |RawChannel|.
// |OnFatalError()| CHECKs that the only fatal error reported is
// |FATAL_ERROR_FAILED_READ|, which is expected once the connection is closed.
// A write fatal error therefore aborts the test via CHECK failure.
102 class WriteOnlyRawChannelDelegate : public RawChannel::Delegate {
104 WriteOnlyRawChannelDelegate() {}
105 virtual ~WriteOnlyRawChannelDelegate() {}
107 // |RawChannel::Delegate| implementation:
108 virtual void OnReadMessage(
109 const MessageInTransit::View& /*message_view*/) OVERRIDE {
112 virtual void OnFatalError(FatalError fatal_error) OVERRIDE {
113 // We'll get a read error when the connection is closed.
114 CHECK_EQ(fatal_error, FATAL_ERROR_FAILED_READ);
118 DISALLOW_COPY_AND_ASSIGN(WriteOnlyRawChannelDelegate);
// Polling parameters for |TestMessageReaderAndChecker::ReadAndCheckNextMessage()|.
// After a non-blocking read that does not fill the read buffer, it sleeps
// |kMessageReaderSleepMs| milliseconds. It gives up after
// |kMessageReaderMaxPollIterations| iterations of its polling loop (about 3 s
// of sleeping, assuming the counter advances once per sleep -- TODO confirm).
121 static const int64_t kMessageReaderSleepMs = 1;
122 static const size_t kMessageReaderMaxPollIterations = 3000;
// Reads messages directly from a platform handle, without a |RawChannel|, and
// checks them against the |MakeTestMessage()| format. Received bytes are
// buffered in |bytes_| until a whole message is available. It stores a plain
// |PlatformHandle| (the test passes |handles[1].get()|), so it presumably does
// not own the handle.
124 class TestMessageReaderAndChecker {
126 explicit TestMessageReaderAndChecker(embedder::PlatformHandle handle)
// All received data must have been consumed as complete messages.
128 ~TestMessageReaderAndChecker() { CHECK(bytes_.empty()); }
// Polls |handle_| with non-blocking reads until the next complete message is
// buffered. It then checks that the payload is |expected_size| bytes long and
// passes |CheckMessageData()|, and erases the consumed message from |bytes_|.
// It logs an error for a wrong size, wrong bytes, or
// |kMessageReaderMaxPollIterations| polls without a complete message. Callers
// wrap the result in EXPECT_TRUE().
130 bool ReadAndCheckNextMessage(uint32_t expected_size) {
131 unsigned char buffer[4096];
133 for (size_t i = 0; i < kMessageReaderMaxPollIterations;) {
134 size_t read_size = 0;
135 CHECK(mojo::test::NonBlockingRead(handle_, buffer, sizeof(buffer),
138 // Append newly-read data to |bytes_|.
139 bytes_.insert(bytes_.end(), buffer, buffer + read_size);
141 // If we have the header....
143 if (MessageInTransit::GetNextMessageSize(
144 bytes_.empty() ? NULL : &bytes_[0],
147 // If we've read the whole message....
148 if (bytes_.size() >= message_size) {
150 MessageInTransit::View message_view(message_size, &bytes_[0]);
151 CHECK_EQ(message_view.main_buffer_size(), message_size);
// NOTE(review): the log below reports |message_size| (the whole main buffer,
// per the CHECK_EQ above) rather than |message_view.num_bytes()|, which is
// the value actually compared against |expected_size|.
153 if (message_view.num_bytes() != expected_size) {
154 LOG(ERROR) << "Wrong size: " << message_size << " instead of "
155 << expected_size << " bytes.";
157 } else if (!CheckMessageData(message_view.bytes(),
158 message_view.num_bytes())) {
159 LOG(ERROR) << "Incorrect message bytes.";
163 // Erase message data.
164 bytes_.erase(bytes_.begin(),
166 message_view.main_buffer_size());
// A read that did not fill |buffer| suggests no more data is immediately
// available, so sleep briefly before polling again.
171 if (static_cast<size_t>(read_size) < sizeof(buffer)) {
173 base::PlatformThread::Sleep(
174 base::TimeDelta::FromMilliseconds(kMessageReaderSleepMs));
178 LOG(ERROR) << "Too many iterations.";
183 const embedder::PlatformHandle handle_;
185 // The start of the received data should always be on a message boundary.
186 std::vector<unsigned char> bytes_;
188 DISALLOW_COPY_AND_ASSIGN(TestMessageReaderAndChecker);
191 // Tests writing (and verifies reading using our own custom reader).
// |rc| wraps the server end (|handles[0]|). |checker| reads the client end
// (|handles[1]|) directly, without going through a |RawChannel|.
192 TEST_F(RawChannelTest, WriteMessage) {
193 WriteOnlyRawChannelDelegate delegate;
194 scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
195 TestMessageReaderAndChecker checker(handles[1].get());
196 io_thread()->PostTaskAndWait(FROM_HERE,
197 base::Bind(&InitOnIOThread, rc.get(),
198 base::Unretained(&delegate)));
200 // Write and read, for a variety of sizes.
// Sizes run 1, 2, 4, 7, 11, 17, ... (each step adds size / 2 + 1) while they
// stay below 5,000,000 bytes.
201 for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1) {
202 EXPECT_TRUE(rc->WriteMessage(MakeTestMessage(size)));
203 EXPECT_TRUE(checker.ReadAndCheckNextMessage(size)) << size;
206 // Write/queue and read afterwards, for a variety of sizes.
207 for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1)
208 EXPECT_TRUE(rc->WriteMessage(MakeTestMessage(size)));
209 for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1)
210 EXPECT_TRUE(checker.ReadAndCheckNextMessage(size)) << size;
// Shut the channel down on the I/O thread, where it was initialized.
212 io_thread()->PostTaskAndWait(FROM_HERE,
213 base::Bind(&RawChannel::Shutdown,
214 base::Unretained(rc.get())));
217 // RawChannelTest.OnReadMessage ------------------------------------------------
// Delegate that checks incoming messages, in order, against a list of
// expected payload sizes set from the test thread with |SetExpectedSizes()|.
// Each message's size and bytes are verified with EXPECT_*. When |position_|
// reaches the end of |expected_sizes_|, |done_event_| is signaled.
// |done_event_(false, false)| makes the event auto-reset and initially
// unsignaled.
219 class ReadCheckerRawChannelDelegate : public RawChannel::Delegate {
221 ReadCheckerRawChannelDelegate()
222 : done_event_(false, false),
224 virtual ~ReadCheckerRawChannelDelegate() {}
226 // |RawChannel::Delegate| implementation (called on the I/O thread):
227 virtual void OnReadMessage(
228 const MessageInTransit::View& message_view) OVERRIDE {
230 size_t expected_size;
231 bool should_signal = false;
// |expected_sizes_| is also written by |SetExpectedSizes()| on the test
// thread, so it is only read while holding |lock_|.
233 base::AutoLock locker(lock_);
234 CHECK_LT(position_, expected_sizes_.size());
235 position = position_;
236 expected_size = expected_sizes_[position];
238 if (position_ >= expected_sizes_.size())
239 should_signal = true;
242 EXPECT_EQ(expected_size, message_view.num_bytes()) << position;
// Only check the payload bytes when the size matched.
243 if (message_view.num_bytes() == expected_size) {
244 EXPECT_TRUE(CheckMessageData(message_view.bytes(),
245 message_view.num_bytes())) << position;
249 done_event_.Signal();
251 virtual void OnFatalError(FatalError fatal_error) OVERRIDE {
252 // We'll get a read error when the connection is closed.
253 CHECK_EQ(fatal_error, FATAL_ERROR_FAILED_READ);
256 // Waits for all the messages (of sizes |expected_sizes_|) to be seen.
// Replaces the list of expected sizes. CHECKs that every previously expected
// message has already been received.
261 void SetExpectedSizes(const std::vector<uint32_t>& expected_sizes) {
262 base::AutoLock locker(lock_);
263 CHECK_EQ(position_, expected_sizes_.size());
264 expected_sizes_ = expected_sizes;
269 base::WaitableEvent done_event_;
271 base::Lock lock_; // Protects the following members.
272 std::vector<uint32_t> expected_sizes_;
275 DISALLOW_COPY_AND_ASSIGN(ReadCheckerRawChannelDelegate);
278 // Tests reading (writing using our own custom writer).
// Messages are written straight to |handles[1]| with
// |WriteTestMessageToHandle()|. |rc| (on |handles[0]|) must deliver them to
// |delegate| intact and in order.
279 TEST_F(RawChannelTest, OnReadMessage) {
280 ReadCheckerRawChannelDelegate delegate;
281 scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
282 io_thread()->PostTaskAndWait(FROM_HERE,
283 base::Bind(&InitOnIOThread, rc.get(),
284 base::Unretained(&delegate)));
286 // Write and read, for a variety of sizes.
287 for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1) {
288 delegate.SetExpectedSizes(std::vector<uint32_t>(1, size));
290 EXPECT_TRUE(WriteTestMessageToHandle(handles[1].get(), size));
295 // Set up reader and write as fast as we can.
296 // Write/queue and read afterwards, for a variety of sizes.
297 std::vector<uint32_t> expected_sizes;
298 for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1)
299 expected_sizes.push_back(size);
300 delegate.SetExpectedSizes(expected_sizes);
301 for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1)
302 EXPECT_TRUE(WriteTestMessageToHandle(handles[1].get(), size));
// Shut the channel down on the I/O thread, where it was initialized.
305 io_thread()->PostTaskAndWait(FROM_HERE,
306 base::Bind(&RawChannel::Shutdown,
307 base::Unretained(rc.get())));
310 // RawChannelTest.WriteMessageAndOnReadMessage ---------------------------------
// Thread that writes |write_count| test messages through |raw_channel|, each
// with a random payload size in [1, kMaxRandomMessageSize] bytes. The channel
// is held as a raw pointer and is owned by the test.
312 class RawChannelWriterThread : public base::SimpleThread {
314 RawChannelWriterThread(RawChannel* raw_channel, size_t write_count)
315 : base::SimpleThread("raw_channel_writer_thread"),
316 raw_channel_(raw_channel),
317 left_to_write_(write_count) {
320 virtual ~RawChannelWriterThread() {
325 virtual void Run() OVERRIDE {
326 static const int kMaxRandomMessageSize = 25000;
// The post-decrement makes the body run exactly |write_count| times. When the
// loop exits, |left_to_write_| has wrapped around to SIZE_MAX.
328 while (left_to_write_-- > 0) {
329 EXPECT_TRUE(raw_channel_->WriteMessage(MakeTestMessage(
330 static_cast<uint32_t>(base::RandInt(1, kMaxRandomMessageSize)))));
334 RawChannel* const raw_channel_;
335 size_t left_to_write_;
337 DISALLOW_COPY_AND_ASSIGN(RawChannelWriterThread);
// Delegate that expects exactly |expected_count| messages, each of which must
// pass |CheckMessageData()|. |EXPECT_LT(count_, expected_count_)| flags any
// extra message. |done_event_| is signaled once |count_| reaches
// |expected_count_| (|count_| is presumably incremented once per message --
// TODO confirm). |OnFatalError()| only accepts |FATAL_ERROR_FAILED_READ|.
340 class ReadCountdownRawChannelDelegate : public RawChannel::Delegate {
342 explicit ReadCountdownRawChannelDelegate(size_t expected_count)
343 : done_event_(false, false),
344 expected_count_(expected_count),
346 virtual ~ReadCountdownRawChannelDelegate() {}
348 // |RawChannel::Delegate| implementation (called on the I/O thread):
349 virtual void OnReadMessage(
350 const MessageInTransit::View& message_view) OVERRIDE {
351 EXPECT_LT(count_, expected_count_);
354 EXPECT_TRUE(CheckMessageData(message_view.bytes(),
355 message_view.num_bytes()));
357 if (count_ >= expected_count_)
358 done_event_.Signal();
360 virtual void OnFatalError(FatalError fatal_error) OVERRIDE {
361 // We'll get a read error when the connection is closed.
362 CHECK_EQ(fatal_error, FATAL_ERROR_FAILED_READ);
365 // Waits for all the messages to have been seen.
371 base::WaitableEvent done_event_;
372 size_t expected_count_;
375 DISALLOW_COPY_AND_ASSIGN(ReadCountdownRawChannelDelegate);
// Stress test. |kNumWriterThreads| threads call |WriteMessage()| concurrently
// on the same |RawChannel| (|writer_rc|, on |handles[0]|). Each writes
// |kNumWriteMessagesPerThread| random-size messages. The |RawChannel| on the
// other end (|reader_rc|, on |handles[1]|) must receive exactly that many
// intact messages.
378 TEST_F(RawChannelTest, WriteMessageAndOnReadMessage) {
379 static const size_t kNumWriterThreads = 10;
380 static const size_t kNumWriteMessagesPerThread = 4000;
382 WriteOnlyRawChannelDelegate writer_delegate;
383 scoped_ptr<RawChannel> writer_rc(RawChannel::Create(handles[0].Pass()));
384 io_thread()->PostTaskAndWait(FROM_HERE,
385 base::Bind(&InitOnIOThread, writer_rc.get(),
386 base::Unretained(&writer_delegate)));
388 ReadCountdownRawChannelDelegate reader_delegate(
389 kNumWriterThreads * kNumWriteMessagesPerThread);
390 scoped_ptr<RawChannel> reader_rc(RawChannel::Create(handles[1].Pass()));
391 io_thread()->PostTaskAndWait(FROM_HERE,
392 base::Bind(&InitOnIOThread, reader_rc.get(),
393 base::Unretained(&reader_delegate)));
// The threads are owned by |writer_threads|, which deletes them when it goes
// out of scope.
396 ScopedVector<RawChannelWriterThread> writer_threads;
397 for (size_t i = 0; i < kNumWriterThreads; i++) {
398 writer_threads.push_back(new RawChannelWriterThread(
399 writer_rc.get(), kNumWriteMessagesPerThread));
401 for (size_t i = 0; i < writer_threads.size(); i++)
402 writer_threads[i]->Start();
403 } // Joins all the writer threads.
405 // Sleep a bit, to let any extraneous reads be processed. (There shouldn't be
406 // any, but we want to know about them.)
407 base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds(100));
409 // Wait for reading to finish.
410 reader_delegate.Wait();
// Shut both channels down on the I/O thread, reader first.
412 io_thread()->PostTaskAndWait(FROM_HERE,
413 base::Bind(&RawChannel::Shutdown,
414 base::Unretained(reader_rc.get())));
416 io_thread()->PostTaskAndWait(FROM_HERE,
417 base::Bind(&RawChannel::Shutdown,
418 base::Unretained(writer_rc.get())));
421 // RawChannelTest.OnFatalError -------------------------------------------------
// A |ReadCountdownRawChannelDelegate| that also records fatal errors. A read
// fatal error and a write fatal error may each be reported at most once, and
// only if the matching constructor flag was true. Each report signals an
// event; |WaitForReadFatalError()| and |WaitForWriteFatalError()| block on
// those events.
423 class FatalErrorRecordingRawChannelDelegate
424 : public ReadCountdownRawChannelDelegate {
426 FatalErrorRecordingRawChannelDelegate(size_t expected_read_count,
427 bool expect_read_error,
428 bool expect_write_error)
429 : ReadCountdownRawChannelDelegate(expected_read_count),
430 got_read_fatal_error_event_(false, false),
431 got_write_fatal_error_event_(false, false),
432 expecting_read_error_(expect_read_error),
433 expecting_write_error_(expect_write_error) {
436 virtual ~FatalErrorRecordingRawChannelDelegate() {}
// A failing ASSERT_TRUE() returns from |OnFatalError()| early, without
// signaling. An unexpected or repeated fatal error is therefore reported as a
// test failure. Clearing the |expecting_*| flag makes a second report of the
// same type fail.
438 virtual void OnFatalError(FatalError fatal_error) OVERRIDE {
439 if (fatal_error == FATAL_ERROR_FAILED_READ) {
440 ASSERT_TRUE(expecting_read_error_);
441 expecting_read_error_ = false;
442 got_read_fatal_error_event_.Signal();
443 } else if (fatal_error == FATAL_ERROR_FAILED_WRITE) {
444 ASSERT_TRUE(expecting_write_error_);
445 expecting_write_error_ = false;
446 got_write_fatal_error_event_.Signal();
452 void WaitForReadFatalError() { got_read_fatal_error_event_.Wait(); }
453 void WaitForWriteFatalError() { got_write_fatal_error_event_.Wait(); }
456 base::WaitableEvent got_read_fatal_error_event_;
457 base::WaitableEvent got_write_fatal_error_event_;
459 bool expecting_read_error_;
460 bool expecting_write_error_;
462 DISALLOW_COPY_AND_ASSIGN(FatalErrorRecordingRawChannelDelegate);
465 // Tests fatal errors.
// |delegate| expects no messages, plus exactly one read fatal error and one
// write fatal error.
466 TEST_F(RawChannelTest, OnFatalError) {
467 FatalErrorRecordingRawChannelDelegate delegate(0, true, true);
468 scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
469 io_thread()->PostTaskAndWait(FROM_HERE,
470 base::Bind(&InitOnIOThread, rc.get(),
471 base::Unretained(&delegate)));
473 // Close the handle of the other end, which should make writing fail.
// NOTE(review): no statement that actually closes |handles[1]| appears here,
// yet the EXPECT_FALSE below depends on it -- confirm.
476 EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(1)));
478 // We should get a write fatal error.
479 delegate.WaitForWriteFatalError();
481 // We should also get a read fatal error.
482 delegate.WaitForReadFatalError();
// Writing after the fatal errors must still fail.
484 EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(2)));
486 // Sleep a bit, to make sure we don't get another |OnFatalError()|
487 // notification. (If we actually get another one, the |ASSERT_TRUE()| in
// |OnFatalError()| fails the test.)
488 base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds(20));
490 io_thread()->PostTaskAndWait(FROM_HERE,
491 base::Bind(&RawChannel::Shutdown,
492 base::Unretained(rc.get())));
495 // RawChannelTest.ReadUnaffectedByWriteFatalError ------------------------------
// Checks that a write failure does not prevent delivery of data that was
// already buffered. Five messages (payload sizes 1, 2, 4, 7, 11) are written
// to |handles[1]| before |rc| is initialized. |delegate| expects all five,
// plus one read fatal error and one write fatal error.
497 TEST_F(RawChannelTest, ReadUnaffectedByWriteFatalError) {
498 const size_t kMessageCount = 5;
500 // Write a few messages into the other end.
501 uint32_t message_size = 1;
502 for (size_t i = 0; i < kMessageCount;
503 i++, message_size += message_size / 2 + 1)
504 EXPECT_TRUE(WriteTestMessageToHandle(handles[1].get(), message_size));
506 // Close the other end, which should make writing fail.
// NOTE(review): no statement that actually closes |handles[1]| appears
// here -- confirm.
509 // Only start up reading here. The system buffer should still contain the
510 // messages that were written.
511 FatalErrorRecordingRawChannelDelegate delegate(kMessageCount, true, true);
512 scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
513 io_thread()->PostTaskAndWait(FROM_HERE,
514 base::Bind(&InitOnIOThread, rc.get(),
515 base::Unretained(&delegate)));
517 EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(1)));
519 // We should definitely get a write fatal error.
520 delegate.WaitForWriteFatalError();
522 // Wait for reading to finish. A writing failure shouldn't affect reading.
// NOTE(review): no |delegate.Wait()| call appears here -- confirm.
525 // And then we should get a read fatal error.
526 delegate.WaitForReadFatalError();
528 io_thread()->PostTaskAndWait(FROM_HERE,
529 base::Bind(&RawChannel::Shutdown,
530 base::Unretained(rc.get())));
533 // RawChannelTest.WriteMessageAfterShutdown ------------------------------------
535 // Makes sure that calling |WriteMessage()| after |Shutdown()| behaves
// sensibly: the write must be rejected, i.e. |WriteMessage()| returns false.
537 TEST_F(RawChannelTest, WriteMessageAfterShutdown) {
538 WriteOnlyRawChannelDelegate delegate;
539 scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
540 io_thread()->PostTaskAndWait(FROM_HERE,
541 base::Bind(&InitOnIOThread, rc.get(),
542 base::Unretained(&delegate)));
543 io_thread()->PostTaskAndWait(FROM_HERE,
544 base::Bind(&RawChannel::Shutdown,
545 base::Unretained(rc.get())));
547 EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(1)));
550 // RawChannelTest.ShutdownOnReadMessage ----------------------------------------
// Delegate that shuts |raw_channel_| down from inside its first
// |OnReadMessage()| callback and then signals |done_event_|.
// |EXPECT_FALSE(did_shutdown_)| flags any message delivered after that
// shutdown. Any fatal error CHECK-fails.
552 class ShutdownOnReadMessageRawChannelDelegate : public RawChannel::Delegate {
554 explicit ShutdownOnReadMessageRawChannelDelegate(RawChannel* raw_channel)
555 : raw_channel_(raw_channel),
556 done_event_(false, false),
557 did_shutdown_(false) {}
558 virtual ~ShutdownOnReadMessageRawChannelDelegate() {}
560 // |RawChannel::Delegate| implementation (called on the I/O thread):
561 virtual void OnReadMessage(
562 const MessageInTransit::View& message_view) OVERRIDE {
563 EXPECT_FALSE(did_shutdown_);
564 EXPECT_TRUE(CheckMessageData(message_view.bytes(),
565 message_view.num_bytes()));
// Shut the channel down from within the delegate callback itself.
566 raw_channel_->Shutdown();
567 did_shutdown_ = true;
568 done_event_.Signal();
570 virtual void OnFatalError(FatalError /*fatal_error*/) OVERRIDE {
571 CHECK(false); // Should not get called.
574 // Waits for shutdown.
577 EXPECT_TRUE(did_shutdown_);
581 RawChannel* const raw_channel_;
582 base::WaitableEvent done_event_;
585 DISALLOW_COPY_AND_ASSIGN(ShutdownOnReadMessageRawChannelDelegate);
// Checks that shutting a |RawChannel| down from inside |OnReadMessage()|
// stops further delivery. Five 10-byte messages are queued before |rc| is
// initialized, yet the delegate must see only the first.
588 TEST_F(RawChannelTest, ShutdownOnReadMessage) {
589 // Write a few messages into the other end.
590 for (size_t count = 0; count < 5; count++)
591 EXPECT_TRUE(WriteTestMessageToHandle(handles[1].get(), 10));
593 scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
594 ShutdownOnReadMessageRawChannelDelegate delegate(rc.get());
595 io_thread()->PostTaskAndWait(FROM_HERE,
596 base::Bind(&InitOnIOThread, rc.get(),
597 base::Unretained(&delegate)));
599 // Wait for the delegate, which will shut the |RawChannel| down.
// NOTE(review): no |delegate.Wait()| call appears here -- confirm.
603 // RawChannelTest.ShutdownOnFatalError{Read, Write} ----------------------------
// Delegate that is meant to shut |raw_channel_| down from inside
// |OnFatalError()| when it receives a fatal error of type
// |shutdown_on_error_type_|, then signal |done_event_|. Any
// |OnReadMessage()| call CHECK-fails.
605 class ShutdownOnFatalErrorRawChannelDelegate : public RawChannel::Delegate {
607 ShutdownOnFatalErrorRawChannelDelegate(RawChannel* raw_channel,
608 FatalError shutdown_on_error_type)
609 : raw_channel_(raw_channel),
610 shutdown_on_error_type_(shutdown_on_error_type),
611 done_event_(false, false),
612 did_shutdown_(false) {}
613 virtual ~ShutdownOnFatalErrorRawChannelDelegate() {}
615 // |RawChannel::Delegate| implementation (called on the I/O thread):
616 virtual void OnReadMessage(
617 const MessageInTransit::View& /*message_view*/) OVERRIDE {
618 CHECK(false); // Should not get called.
620 virtual void OnFatalError(FatalError fatal_error) OVERRIDE {
621 EXPECT_FALSE(did_shutdown_);
// NOTE(review): as written, the |if| below guards the |Shutdown()| call.
// Shutdown therefore happens only for errors *other than*
// |shutdown_on_error_type_|, while |did_shutdown_| is set and |done_event_|
// is signaled unconditionally. This looks inverted (presumably an early
// |return;| belongs after the |if|) -- confirm.
622 if (fatal_error != shutdown_on_error_type_)
624 raw_channel_->Shutdown();
625 did_shutdown_ = true;
626 done_event_.Signal();
629 // Waits for shutdown.
632 EXPECT_TRUE(did_shutdown_);
636 RawChannel* const raw_channel_;
637 const FatalError shutdown_on_error_type_;
638 base::WaitableEvent done_event_;
641 DISALLOW_COPY_AND_ASSIGN(ShutdownOnFatalErrorRawChannelDelegate);
// Checks that the delegate can shut |rc| down from inside |OnFatalError()|
// when reading fails.
644 TEST_F(RawChannelTest, ShutdownOnFatalErrorRead) {
645 scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
646 ShutdownOnFatalErrorRawChannelDelegate delegate(
647 rc.get(), RawChannel::Delegate::FATAL_ERROR_FAILED_READ);
648 io_thread()->PostTaskAndWait(FROM_HERE,
649 base::Bind(&InitOnIOThread, rc.get(),
650 base::Unretained(&delegate)));
652 // Close the handle of the other end, which should make reading fail.
// NOTE(review): no statement that actually closes |handles[1]| appears
// here -- confirm.
655 // Wait for the delegate, which will shut the |RawChannel| down.
// NOTE(review): no |delegate.Wait()| call appears here -- confirm.
// Checks that the delegate can shut |rc| down from inside |OnFatalError()|
// when writing fails.
659 TEST_F(RawChannelTest, ShutdownOnFatalErrorWrite) {
660 scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
661 ShutdownOnFatalErrorRawChannelDelegate delegate(
662 rc.get(), RawChannel::Delegate::FATAL_ERROR_FAILED_WRITE);
663 io_thread()->PostTaskAndWait(FROM_HERE,
664 base::Bind(&InitOnIOThread, rc.get(),
665 base::Unretained(&delegate)));
667 // Close the handle of the other end, which should make writing fail.
// NOTE(review): no statement that actually closes |handles[1]| appears here,
// yet the EXPECT_FALSE below depends on it -- confirm.
670 EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(1)));
672 // Wait for the delegate, which will shut the |RawChannel| down.
// NOTE(review): no |delegate.Wait()| call appears here -- confirm.
677 } // namespace system