Upstream version 7.36.149.0
[platform/framework/web/crosswalk.git] / src / mojo / system / raw_channel_unittest.cc
1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "mojo/system/raw_channel.h"
6
7 #include <stdint.h>
8
9 #include <vector>
10
#include "base/bind.h"
#include "base/bind_helpers.h"
#include "base/location.h"
#include "base/logging.h"
#include "base/macros.h"
#include "base/memory/scoped_ptr.h"
#include "base/memory/scoped_vector.h"
#include "base/rand_util.h"
#include "base/synchronization/lock.h"
#include "base/synchronization/waitable_event.h"
#include "base/threading/platform_thread.h"  // For |Sleep()|.
#include "base/threading/simple_thread.h"
#include "base/time/time.h"
#include "build/build_config.h"
#include "mojo/common/test/test_utils.h"
#include "mojo/embedder/platform_channel_pair.h"
#include "mojo/embedder/platform_handle.h"
#include "mojo/embedder/scoped_platform_handle.h"
#include "mojo/system/message_in_transit.h"
#include "mojo/system/test_utils.h"
#include "testing/gtest/include/gtest/gtest.h"
31
32 namespace mojo {
33 namespace system {
34 namespace {
35
36 scoped_ptr<MessageInTransit> MakeTestMessage(uint32_t num_bytes) {
37   std::vector<unsigned char> bytes(num_bytes, 0);
38   for (size_t i = 0; i < num_bytes; i++)
39     bytes[i] = static_cast<unsigned char>(i + num_bytes);
40   return make_scoped_ptr(
41       new MessageInTransit(MessageInTransit::kTypeMessagePipeEndpoint,
42                            MessageInTransit::kSubtypeMessagePipeEndpointData,
43                            num_bytes, bytes.empty() ? NULL : &bytes[0]));
44 }
45
// Returns true iff every one of the |num_bytes| bytes at |bytes| matches the
// test pattern written by |MakeTestMessage()|: byte i must equal
// static_cast<unsigned char>(i + num_bytes). |bytes| is not dereferenced
// when |num_bytes| is zero.
bool CheckMessageData(const void* bytes, uint32_t num_bytes) {
  const unsigned char* const data = static_cast<const unsigned char*>(bytes);
  uint32_t index = 0;
  // Advance while the pattern holds; stopping early means a mismatch.
  while (index < num_bytes &&
         data[index] == static_cast<unsigned char>(index + num_bytes))
    ++index;
  return index == num_bytes;
}
54
55 void InitOnIOThread(RawChannel* raw_channel, RawChannel::Delegate* delegate) {
56   CHECK(raw_channel->Init(delegate));
57 }
58
59 bool WriteTestMessageToHandle(const embedder::PlatformHandle& handle,
60                               uint32_t num_bytes) {
61   scoped_ptr<MessageInTransit> message(MakeTestMessage(num_bytes));
62
63   size_t write_size = 0;
64   mojo::test::BlockingWrite(
65       handle, message->main_buffer(), message->main_buffer_size(), &write_size);
66   return write_size == message->main_buffer_size();
67 }
68
69 // -----------------------------------------------------------------------------
70
71 class RawChannelTest : public testing::Test {
72  public:
73   RawChannelTest() : io_thread_(test::TestIOThread::kManualStart) {}
74   virtual ~RawChannelTest() {}
75
76   virtual void SetUp() OVERRIDE {
77     embedder::PlatformChannelPair channel_pair;
78     handles[0] = channel_pair.PassServerHandle();
79     handles[1] = channel_pair.PassClientHandle();
80     io_thread_.Start();
81   }
82
83   virtual void TearDown() OVERRIDE {
84     io_thread_.Stop();
85     handles[0].reset();
86     handles[1].reset();
87   }
88
89  protected:
90   test::TestIOThread* io_thread() { return &io_thread_; }
91
92   embedder::ScopedPlatformHandle handles[2];
93
94  private:
95   test::TestIOThread io_thread_;
96
97   DISALLOW_COPY_AND_ASSIGN(RawChannelTest);
98 };
99
100 // RawChannelTest.WriteMessage -------------------------------------------------
101
102 class WriteOnlyRawChannelDelegate : public RawChannel::Delegate {
103  public:
104   WriteOnlyRawChannelDelegate() {}
105   virtual ~WriteOnlyRawChannelDelegate() {}
106
107   // |RawChannel::Delegate| implementation:
108   virtual void OnReadMessage(
109       const MessageInTransit::View& /*message_view*/) OVERRIDE {
110     NOTREACHED();
111   }
112   virtual void OnFatalError(FatalError fatal_error) OVERRIDE {
113     // We'll get a read error when the connection is closed.
114     CHECK_EQ(fatal_error, FATAL_ERROR_FAILED_READ);
115   }
116
117  private:
118   DISALLOW_COPY_AND_ASSIGN(WriteOnlyRawChannelDelegate);
119 };
120
// Polling parameters for TestMessageReaderAndChecker: after each read that
// does not fill its buffer it sleeps this many milliseconds, and it gives up
// after this many such short reads. (Namespace-scope |const| already has
// internal linkage, so |static| is unnecessary.)
const int64_t kMessageReaderSleepMs = 1;
const size_t kMessageReaderMaxPollIterations = 3000;
123
124 class TestMessageReaderAndChecker {
125  public:
126   explicit TestMessageReaderAndChecker(embedder::PlatformHandle handle)
127       : handle_(handle) {}
128   ~TestMessageReaderAndChecker() { CHECK(bytes_.empty()); }
129
130   bool ReadAndCheckNextMessage(uint32_t expected_size) {
131     unsigned char buffer[4096];
132
133     for (size_t i = 0; i < kMessageReaderMaxPollIterations;) {
134       size_t read_size = 0;
135       CHECK(mojo::test::NonBlockingRead(handle_, buffer, sizeof(buffer),
136                                         &read_size));
137
138       // Append newly-read data to |bytes_|.
139       bytes_.insert(bytes_.end(), buffer, buffer + read_size);
140
141       // If we have the header....
142       size_t message_size;
143       if (MessageInTransit::GetNextMessageSize(
144               bytes_.empty() ? NULL : &bytes_[0],
145               bytes_.size(),
146               &message_size)) {
147         // If we've read the whole message....
148         if (bytes_.size() >= message_size) {
149           bool rv = true;
150           MessageInTransit::View message_view(message_size, &bytes_[0]);
151           CHECK_EQ(message_view.main_buffer_size(), message_size);
152
153           if (message_view.num_bytes() != expected_size) {
154             LOG(ERROR) << "Wrong size: " << message_size << " instead of "
155                        << expected_size << " bytes.";
156             rv = false;
157           } else if (!CheckMessageData(message_view.bytes(),
158                                        message_view.num_bytes())) {
159             LOG(ERROR) << "Incorrect message bytes.";
160             rv = false;
161           }
162
163           // Erase message data.
164           bytes_.erase(bytes_.begin(),
165                        bytes_.begin() +
166                            message_view.main_buffer_size());
167           return rv;
168         }
169       }
170
171       if (static_cast<size_t>(read_size) < sizeof(buffer)) {
172         i++;
173         base::PlatformThread::Sleep(
174             base::TimeDelta::FromMilliseconds(kMessageReaderSleepMs));
175       }
176     }
177
178     LOG(ERROR) << "Too many iterations.";
179     return false;
180   }
181
182  private:
183   const embedder::PlatformHandle handle_;
184
185   // The start of the received data should always be on a message boundary.
186   std::vector<unsigned char> bytes_;
187
188   DISALLOW_COPY_AND_ASSIGN(TestMessageReaderAndChecker);
189 };
190
191 // Tests writing (and verifies reading using our own custom reader).
192 TEST_F(RawChannelTest, WriteMessage) {
193   WriteOnlyRawChannelDelegate delegate;
194   scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
195   TestMessageReaderAndChecker checker(handles[1].get());
196   io_thread()->PostTaskAndWait(FROM_HERE,
197                                base::Bind(&InitOnIOThread, rc.get(),
198                                           base::Unretained(&delegate)));
199
200   // Write and read, for a variety of sizes.
201   for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1) {
202     EXPECT_TRUE(rc->WriteMessage(MakeTestMessage(size)));
203     EXPECT_TRUE(checker.ReadAndCheckNextMessage(size)) << size;
204   }
205
206   // Write/queue and read afterwards, for a variety of sizes.
207   for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1)
208     EXPECT_TRUE(rc->WriteMessage(MakeTestMessage(size)));
209   for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1)
210     EXPECT_TRUE(checker.ReadAndCheckNextMessage(size)) << size;
211
212   io_thread()->PostTaskAndWait(FROM_HERE,
213                                base::Bind(&RawChannel::Shutdown,
214                                           base::Unretained(rc.get())));
215 }
216
217 // RawChannelTest.OnReadMessage ------------------------------------------------
218
219 class ReadCheckerRawChannelDelegate : public RawChannel::Delegate {
220  public:
221   ReadCheckerRawChannelDelegate()
222       : done_event_(false, false),
223         position_(0) {}
224   virtual ~ReadCheckerRawChannelDelegate() {}
225
226   // |RawChannel::Delegate| implementation (called on the I/O thread):
227   virtual void OnReadMessage(
228       const MessageInTransit::View& message_view) OVERRIDE {
229     size_t position;
230     size_t expected_size;
231     bool should_signal = false;
232     {
233       base::AutoLock locker(lock_);
234       CHECK_LT(position_, expected_sizes_.size());
235       position = position_;
236       expected_size = expected_sizes_[position];
237       position_++;
238       if (position_ >= expected_sizes_.size())
239         should_signal = true;
240     }
241
242     EXPECT_EQ(expected_size, message_view.num_bytes()) << position;
243     if (message_view.num_bytes() == expected_size) {
244       EXPECT_TRUE(CheckMessageData(message_view.bytes(),
245                   message_view.num_bytes())) << position;
246     }
247
248     if (should_signal)
249       done_event_.Signal();
250   }
251   virtual void OnFatalError(FatalError fatal_error) OVERRIDE {
252     // We'll get a read error when the connection is closed.
253     CHECK_EQ(fatal_error, FATAL_ERROR_FAILED_READ);
254   }
255
256   // Waits for all the messages (of sizes |expected_sizes_|) to be seen.
257   void Wait() {
258     done_event_.Wait();
259   }
260
261   void SetExpectedSizes(const std::vector<uint32_t>& expected_sizes) {
262     base::AutoLock locker(lock_);
263     CHECK_EQ(position_, expected_sizes_.size());
264     expected_sizes_ = expected_sizes;
265     position_ = 0;
266   }
267
268  private:
269   base::WaitableEvent done_event_;
270
271   base::Lock lock_;  // Protects the following members.
272   std::vector<uint32_t> expected_sizes_;
273   size_t position_;
274
275   DISALLOW_COPY_AND_ASSIGN(ReadCheckerRawChannelDelegate);
276 };
277
278 // Tests reading (writing using our own custom writer).
279 TEST_F(RawChannelTest, OnReadMessage) {
280   ReadCheckerRawChannelDelegate delegate;
281   scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
282   io_thread()->PostTaskAndWait(FROM_HERE,
283                                base::Bind(&InitOnIOThread, rc.get(),
284                                           base::Unretained(&delegate)));
285
286   // Write and read, for a variety of sizes.
287   for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1) {
288     delegate.SetExpectedSizes(std::vector<uint32_t>(1, size));
289
290     EXPECT_TRUE(WriteTestMessageToHandle(handles[1].get(), size));
291
292     delegate.Wait();
293   }
294
295   // Set up reader and write as fast as we can.
296   // Write/queue and read afterwards, for a variety of sizes.
297   std::vector<uint32_t> expected_sizes;
298   for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1)
299     expected_sizes.push_back(size);
300   delegate.SetExpectedSizes(expected_sizes);
301   for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1)
302     EXPECT_TRUE(WriteTestMessageToHandle(handles[1].get(), size));
303   delegate.Wait();
304
305   io_thread()->PostTaskAndWait(FROM_HERE,
306                                base::Bind(&RawChannel::Shutdown,
307                                           base::Unretained(rc.get())));
308 }
309
310 // RawChannelTest.WriteMessageAndOnReadMessage ---------------------------------
311
312 class RawChannelWriterThread : public base::SimpleThread {
313  public:
314   RawChannelWriterThread(RawChannel* raw_channel, size_t write_count)
315       : base::SimpleThread("raw_channel_writer_thread"),
316         raw_channel_(raw_channel),
317         left_to_write_(write_count) {
318   }
319
320   virtual ~RawChannelWriterThread() {
321     Join();
322   }
323
324  private:
325   virtual void Run() OVERRIDE {
326     static const int kMaxRandomMessageSize = 25000;
327
328     while (left_to_write_-- > 0) {
329       EXPECT_TRUE(raw_channel_->WriteMessage(MakeTestMessage(
330           static_cast<uint32_t>(base::RandInt(1, kMaxRandomMessageSize)))));
331     }
332   }
333
334   RawChannel* const raw_channel_;
335   size_t left_to_write_;
336
337   DISALLOW_COPY_AND_ASSIGN(RawChannelWriterThread);
338 };
339
340 class ReadCountdownRawChannelDelegate : public RawChannel::Delegate {
341  public:
342   explicit ReadCountdownRawChannelDelegate(size_t expected_count)
343       : done_event_(false, false),
344         expected_count_(expected_count),
345         count_(0) {}
346   virtual ~ReadCountdownRawChannelDelegate() {}
347
348   // |RawChannel::Delegate| implementation (called on the I/O thread):
349   virtual void OnReadMessage(
350       const MessageInTransit::View& message_view) OVERRIDE {
351     EXPECT_LT(count_, expected_count_);
352     count_++;
353
354     EXPECT_TRUE(CheckMessageData(message_view.bytes(),
355                 message_view.num_bytes()));
356
357     if (count_ >= expected_count_)
358       done_event_.Signal();
359   }
360   virtual void OnFatalError(FatalError fatal_error) OVERRIDE {
361     // We'll get a read error when the connection is closed.
362     CHECK_EQ(fatal_error, FATAL_ERROR_FAILED_READ);
363   }
364
365   // Waits for all the messages to have been seen.
366   void Wait() {
367     done_event_.Wait();
368   }
369
370  private:
371   base::WaitableEvent done_event_;
372   size_t expected_count_;
373   size_t count_;
374
375   DISALLOW_COPY_AND_ASSIGN(ReadCountdownRawChannelDelegate);
376 };
377
378 TEST_F(RawChannelTest, WriteMessageAndOnReadMessage) {
379   static const size_t kNumWriterThreads = 10;
380   static const size_t kNumWriteMessagesPerThread = 4000;
381
382   WriteOnlyRawChannelDelegate writer_delegate;
383   scoped_ptr<RawChannel> writer_rc(RawChannel::Create(handles[0].Pass()));
384   io_thread()->PostTaskAndWait(FROM_HERE,
385                                base::Bind(&InitOnIOThread, writer_rc.get(),
386                                           base::Unretained(&writer_delegate)));
387
388   ReadCountdownRawChannelDelegate reader_delegate(
389       kNumWriterThreads * kNumWriteMessagesPerThread);
390   scoped_ptr<RawChannel> reader_rc(RawChannel::Create(handles[1].Pass()));
391   io_thread()->PostTaskAndWait(FROM_HERE,
392                                base::Bind(&InitOnIOThread, reader_rc.get(),
393                                           base::Unretained(&reader_delegate)));
394
395   {
396     ScopedVector<RawChannelWriterThread> writer_threads;
397     for (size_t i = 0; i < kNumWriterThreads; i++) {
398       writer_threads.push_back(new RawChannelWriterThread(
399           writer_rc.get(), kNumWriteMessagesPerThread));
400     }
401     for (size_t i = 0; i < writer_threads.size(); i++)
402       writer_threads[i]->Start();
403   }  // Joins all the writer threads.
404
405   // Sleep a bit, to let any extraneous reads be processed. (There shouldn't be
406   // any, but we want to know about them.)
407   base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds(100));
408
409   // Wait for reading to finish.
410   reader_delegate.Wait();
411
412   io_thread()->PostTaskAndWait(FROM_HERE,
413                                base::Bind(&RawChannel::Shutdown,
414                                           base::Unretained(reader_rc.get())));
415
416   io_thread()->PostTaskAndWait(FROM_HERE,
417                                base::Bind(&RawChannel::Shutdown,
418                                           base::Unretained(writer_rc.get())));
419 }
420
421 // RawChannelTest.OnFatalError -------------------------------------------------
422
423 class FatalErrorRecordingRawChannelDelegate
424     : public ReadCountdownRawChannelDelegate {
425  public:
426   FatalErrorRecordingRawChannelDelegate(size_t expected_read_count,
427                                         bool expect_read_error,
428                                         bool expect_write_error)
429       : ReadCountdownRawChannelDelegate(expected_read_count),
430         got_read_fatal_error_event_(false, false),
431         got_write_fatal_error_event_(false, false),
432         expecting_read_error_(expect_read_error),
433         expecting_write_error_(expect_write_error) {
434   }
435
436   virtual ~FatalErrorRecordingRawChannelDelegate() {}
437
438   virtual void OnFatalError(FatalError fatal_error) OVERRIDE {
439     if (fatal_error == FATAL_ERROR_FAILED_READ) {
440       ASSERT_TRUE(expecting_read_error_);
441       expecting_read_error_ = false;
442       got_read_fatal_error_event_.Signal();
443     } else if (fatal_error == FATAL_ERROR_FAILED_WRITE) {
444       ASSERT_TRUE(expecting_write_error_);
445       expecting_write_error_ = false;
446       got_write_fatal_error_event_.Signal();
447     } else {
448       ASSERT_TRUE(false);
449     }
450   }
451
452   void WaitForReadFatalError() { got_read_fatal_error_event_.Wait(); }
453   void WaitForWriteFatalError() { got_write_fatal_error_event_.Wait(); }
454
455  private:
456   base::WaitableEvent got_read_fatal_error_event_;
457   base::WaitableEvent got_write_fatal_error_event_;
458
459   bool expecting_read_error_;
460   bool expecting_write_error_;
461
462   DISALLOW_COPY_AND_ASSIGN(FatalErrorRecordingRawChannelDelegate);
463 };
464
465 // Tests fatal errors.
466 TEST_F(RawChannelTest, OnFatalError) {
467   FatalErrorRecordingRawChannelDelegate delegate(0, true, true);
468   scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
469   io_thread()->PostTaskAndWait(FROM_HERE,
470                                base::Bind(&InitOnIOThread, rc.get(),
471                                           base::Unretained(&delegate)));
472
473   // Close the handle of the other end, which should make writing fail.
474   handles[1].reset();
475
476   EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(1)));
477
478   // We should get a write fatal error.
479   delegate.WaitForWriteFatalError();
480
481   // We should also get a read fatal error.
482   delegate.WaitForReadFatalError();
483
484   EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(2)));
485
486   // Sleep a bit, to make sure we don't get another |OnFatalError()|
487   // notification. (If we actually get another one, |OnFatalError()| crashes.)
488   base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds(20));
489
490   io_thread()->PostTaskAndWait(FROM_HERE,
491                                base::Bind(&RawChannel::Shutdown,
492                                           base::Unretained(rc.get())));
493 }
494
495 // RawChannelTest.ReadUnaffectedByWriteFatalError ------------------------------
496
497 TEST_F(RawChannelTest, ReadUnaffectedByWriteFatalError) {
498   const size_t kMessageCount = 5;
499
500   // Write a few messages into the other end.
501   uint32_t message_size = 1;
502   for (size_t i = 0; i < kMessageCount;
503        i++, message_size += message_size / 2 + 1)
504     EXPECT_TRUE(WriteTestMessageToHandle(handles[1].get(), message_size));
505
506   // Close the other end, which should make writing fail.
507   handles[1].reset();
508
509   // Only start up reading here. The system buffer should still contain the
510   // messages that were written.
511   FatalErrorRecordingRawChannelDelegate delegate(kMessageCount, true, true);
512   scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
513   io_thread()->PostTaskAndWait(FROM_HERE,
514                                base::Bind(&InitOnIOThread, rc.get(),
515                                           base::Unretained(&delegate)));
516
517   EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(1)));
518
519   // We should definitely get a write fatal error.
520   delegate.WaitForWriteFatalError();
521
522   // Wait for reading to finish. A writing failure shouldn't affect reading.
523   delegate.Wait();
524
525   // And then we should get a read fatal error.
526   delegate.WaitForReadFatalError();
527
528   io_thread()->PostTaskAndWait(FROM_HERE,
529                                base::Bind(&RawChannel::Shutdown,
530                                           base::Unretained(rc.get())));
531 }
532
533 // RawChannelTest.WriteMessageAfterShutdown ------------------------------------
534
535 // Makes sure that calling |WriteMessage()| after |Shutdown()| behaves
536 // correctly.
537 TEST_F(RawChannelTest, WriteMessageAfterShutdown) {
538   WriteOnlyRawChannelDelegate delegate;
539   scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
540   io_thread()->PostTaskAndWait(FROM_HERE,
541                                base::Bind(&InitOnIOThread, rc.get(),
542                                           base::Unretained(&delegate)));
543   io_thread()->PostTaskAndWait(FROM_HERE,
544                                base::Bind(&RawChannel::Shutdown,
545                                           base::Unretained(rc.get())));
546
547   EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(1)));
548 }
549
550 // RawChannelTest.ShutdownOnReadMessage ----------------------------------------
551
552 class ShutdownOnReadMessageRawChannelDelegate : public RawChannel::Delegate {
553  public:
554   explicit ShutdownOnReadMessageRawChannelDelegate(RawChannel* raw_channel)
555       : raw_channel_(raw_channel),
556         done_event_(false, false),
557         did_shutdown_(false) {}
558   virtual ~ShutdownOnReadMessageRawChannelDelegate() {}
559
560   // |RawChannel::Delegate| implementation (called on the I/O thread):
561   virtual void OnReadMessage(
562       const MessageInTransit::View& message_view) OVERRIDE {
563     EXPECT_FALSE(did_shutdown_);
564     EXPECT_TRUE(CheckMessageData(message_view.bytes(),
565                 message_view.num_bytes()));
566     raw_channel_->Shutdown();
567     did_shutdown_ = true;
568     done_event_.Signal();
569   }
570   virtual void OnFatalError(FatalError /*fatal_error*/) OVERRIDE {
571     CHECK(false);  // Should not get called.
572   }
573
574   // Waits for shutdown.
575   void Wait() {
576     done_event_.Wait();
577     EXPECT_TRUE(did_shutdown_);
578   }
579
580  private:
581   RawChannel* const raw_channel_;
582   base::WaitableEvent done_event_;
583   bool did_shutdown_;
584
585   DISALLOW_COPY_AND_ASSIGN(ShutdownOnReadMessageRawChannelDelegate);
586 };
587
588 TEST_F(RawChannelTest, ShutdownOnReadMessage) {
589   // Write a few messages into the other end.
590   for (size_t count = 0; count < 5; count++)
591     EXPECT_TRUE(WriteTestMessageToHandle(handles[1].get(), 10));
592
593   scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
594   ShutdownOnReadMessageRawChannelDelegate delegate(rc.get());
595   io_thread()->PostTaskAndWait(FROM_HERE,
596                                base::Bind(&InitOnIOThread, rc.get(),
597                                           base::Unretained(&delegate)));
598
599   // Wait for the delegate, which will shut the |RawChannel| down.
600   delegate.Wait();
601 }
602
603 // RawChannelTest.ShutdownOnFatalError{Read, Write} ----------------------------
604
605 class ShutdownOnFatalErrorRawChannelDelegate : public RawChannel::Delegate {
606  public:
607   ShutdownOnFatalErrorRawChannelDelegate(RawChannel* raw_channel,
608                                          FatalError shutdown_on_error_type)
609       : raw_channel_(raw_channel),
610         shutdown_on_error_type_(shutdown_on_error_type),
611         done_event_(false, false),
612         did_shutdown_(false) {}
613   virtual ~ShutdownOnFatalErrorRawChannelDelegate() {}
614
615   // |RawChannel::Delegate| implementation (called on the I/O thread):
616   virtual void OnReadMessage(
617       const MessageInTransit::View& /*message_view*/) OVERRIDE {
618     CHECK(false);  // Should not get called.
619   }
620   virtual void OnFatalError(FatalError fatal_error) OVERRIDE {
621     EXPECT_FALSE(did_shutdown_);
622     if (fatal_error != shutdown_on_error_type_)
623       return;
624     raw_channel_->Shutdown();
625     did_shutdown_ = true;
626     done_event_.Signal();
627   }
628
629   // Waits for shutdown.
630   void Wait() {
631     done_event_.Wait();
632     EXPECT_TRUE(did_shutdown_);
633   }
634
635  private:
636   RawChannel* const raw_channel_;
637   const FatalError shutdown_on_error_type_;
638   base::WaitableEvent done_event_;
639   bool did_shutdown_;
640
641   DISALLOW_COPY_AND_ASSIGN(ShutdownOnFatalErrorRawChannelDelegate);
642 };
643
644 TEST_F(RawChannelTest, ShutdownOnFatalErrorRead) {
645   scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
646   ShutdownOnFatalErrorRawChannelDelegate delegate(
647       rc.get(), RawChannel::Delegate::FATAL_ERROR_FAILED_READ);
648   io_thread()->PostTaskAndWait(FROM_HERE,
649                                base::Bind(&InitOnIOThread, rc.get(),
650                                           base::Unretained(&delegate)));
651
652   // Close the handle of the other end, which should stuff fail.
653   handles[1].reset();
654
655   // Wait for the delegate, which will shut the |RawChannel| down.
656   delegate.Wait();
657 }
658
659 TEST_F(RawChannelTest, ShutdownOnFatalErrorWrite) {
660   scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
661   ShutdownOnFatalErrorRawChannelDelegate delegate(
662       rc.get(), RawChannel::Delegate::FATAL_ERROR_FAILED_WRITE);
663   io_thread()->PostTaskAndWait(FROM_HERE,
664                                base::Bind(&InitOnIOThread, rc.get(),
665                                           base::Unretained(&delegate)));
666
667   // Close the handle of the other end, which should stuff fail.
668   handles[1].reset();
669
670   EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(1)));
671
672   // Wait for the delegate, which will shut the |RawChannel| down.
673   delegate.Wait();
674 }
675
676 }  // namespace
677 }  // namespace system
678 }  // namespace mojo