1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "components/cast_channel/cast_transport.h"
10 #include "base/callback_helpers.h"
11 #include "base/containers/queue.h"
12 #include "base/macros.h"
13 #include "base/memory/ptr_util.h"
14 #include "base/run_loop.h"
15 #include "base/test/scoped_task_environment.h"
16 #include "base/test/simple_test_clock.h"
17 #include "components/cast_channel/cast_framer.h"
18 #include "components/cast_channel/cast_socket.h"
19 #include "components/cast_channel/cast_test_util.h"
20 #include "components/cast_channel/logger.h"
21 #include "components/cast_channel/proto/cast_channel.pb.h"
22 #include "net/base/completion_callback.h"
23 #include "net/base/net_errors.h"
24 #include "net/log/test_net_log.h"
25 #include "net/socket/socket.h"
26 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
27 #include "services/network/network_context.h"
28 #include "testing/gmock/include/gmock/gmock.h"
29 #include "testing/gtest/include/gtest/gtest.h"
33 using testing::InSequence;
34 using testing::Invoke;
35 using testing::NotNull;
36 using testing::Return;
37 using testing::WithArg;
39 namespace cast_channel {
// Channel id handed to the CastTransportImpl under test; the tests also use it
// to query the Logger for the last error recorded on that channel.
constexpr int kChannelId = 0;
44 // Mockable placeholder for write completion events.
45 class CompleteHandler {
48 MOCK_METHOD1(Complete, void(int result));
51 DISALLOW_COPY_AND_ASSIGN(CompleteHandler);
54 // Creates a CastMessage proto with the bare minimum required fields set.
55 CastMessage CreateCastMessage() {
57 output.set_protocol_version(CastMessage::CASTV2_1_0);
58 output.set_namespace_("x");
59 output.set_source_id("source");
60 output.set_destination_id("destination");
61 output.set_payload_type(CastMessage::STRING);
62 output.set_payload_utf8("payload");
66 // FIFO queue of completion callbacks. Outstanding write operations are
67 // Push()ed into the queue. Callback completion is simulated by invoking
68 // Pop() in the same order as Push().
69 class CompletionQueue {
72 ~CompletionQueue() { CHECK_EQ(0u, cb_queue_.size()); }
74 // Enqueues a pending completion callback.
75 void Push(const net::CompletionCallback& cb) { cb_queue_.push(cb); }
76 // Runs the next callback and removes it from the queue.
78 CHECK_GT(cb_queue_.size(), 0u);
79 cb_queue_.front().Run(rv);
84 base::queue<net::CompletionCallback> cb_queue_;
85 DISALLOW_COPY_AND_ASSIGN(CompletionQueue);
88 // GMock action that reads data from an IOBuffer and writes it to a string
91 // buf_idx (template parameter 0): 0-based index of the net::IOBuffer
92 // in the function mock arg list.
93 // size_idx (template parameter 1): 0-based index of the byte count arg.
94 // str: pointer to the string which will receive data from the buffer.
95 ACTION_TEMPLATE(ReadBufferToString,
96 HAS_2_TEMPLATE_PARAMS(int, buf_idx, int, size_idx),
97 AND_1_VALUE_PARAMS(str)) {
98 str->assign(testing::get<buf_idx>(args)->data(),
99 testing::get<size_idx>(args));
102 // GMock action that writes data from a string to an IOBuffer.
104 // buf_idx (template parameter 0): 0-based index of the IOBuffer arg.
105 // str: the string containing data to be written to the IOBuffer.
106 ACTION_TEMPLATE(FillBufferFromString,
107 HAS_1_TEMPLATE_PARAMS(int, buf_idx),
108 AND_1_VALUE_PARAMS(str)) {
109 memcpy(testing::get<buf_idx>(args)->data(), str.data(), str.size());
112 // GMock action that enqueues a write completion callback in a queue.
114 // buf_idx (template parameter 0): 0-based index of the CompletionCallback.
115 // completion_queue: a pointer to the CompletionQueue.
116 ACTION_TEMPLATE(EnqueueCallback,
117 HAS_1_TEMPLATE_PARAMS(int, cb_idx),
118 AND_1_VALUE_PARAMS(completion_queue)) {
119 completion_queue->Push(testing::get<cb_idx>(args));
124 class MockSocket : public cast_channel::CastTransportImpl::Channel {
126 void Read(net::IOBuffer* buffer,
128 net::CompletionOnceCallback callback) override {
129 Read(buffer, bytes, base::AdaptCallbackForRepeating(std::move(callback)));
132 void Write(net::IOBuffer* buffer,
134 net::CompletionOnceCallback callback) override {
135 Write(buffer, bytes, base::AdaptCallbackForRepeating(std::move(callback)));
139 void(net::IOBuffer* buf,
141 const net::CompletionCallback& callback));
144 void(net::IOBuffer* buf,
146 const net::CompletionCallback& callback));
// Test fixture that connects a CastTransportImpl to a MockSocket (for socket
// I/O) and a MockCastTransportDelegate (for read-side notifications). Each
// test scripts the socket's Read()/Write() behavior and checks what the
// transport reports back.
149 class CastTransportTest : public testing::Test {
151 CastTransportTest() : logger_(new Logger()) {
// The delegate is handed to |transport_| as a unique_ptr via WrapUnique();
// |delegate_| keeps only a non-owning pointer for setting expectations.
152 delegate_ = new MockCastTransportDelegate;
153 transport_.reset(new CastTransportImpl(&mock_socket_, kChannelId,
154 CreateIPEndPointForTest(), logger_));
155 transport_->SetReadDelegate(base::WrapUnique(delegate_));
157 ~CastTransportTest() override {}
160 // Runs tasks on the test's message loop until none are immediately runnable.
161 void RunPendingTasks() {
162 base::RunLoop run_loop;
163 run_loop.RunUntilIdle();
// Declared first, so it is destroyed after every other member below.
166 base::test::ScopedTaskEnvironment task_environment_;
// Non-owning; the object is owned through the unique_ptr passed to
// SetReadDelegate() in the constructor.
167 MockCastTransportDelegate* delegate_;
// Declared before |transport_|, which holds a pointer to it, so |transport_|
// is destroyed first.
168 MockSocket mock_socket_;
170 std::unique_ptr<CastTransport> transport_;
173 // ----------------------------------------------------------------------------
174 // Asynchronous write tests
// Verifies that a message whose entire serialized frame is written by one
// asynchronous socket Write() completes SendMessage() with net::OK, and that
// the bytes handed to the socket equal MessageFramer::Serialize()'s output.
175 TEST_F(CastTransportTest, TestFullWriteAsync) {
176 CompletionQueue socket_cbs;
177 CompleteHandler write_handler;
180 CastMessage message = CreateCastMessage();
181 std::string serialized_message;
182 EXPECT_TRUE(MessageFramer::Serialize(message, &serialized_message));
// The socket is asked to write the whole frame in one call; the written bytes
// are captured into |output| and the completion callback is held back.
// NOTE(review): TestPartialWritesAsync static_casts the expected size to int;
// consider doing the same here for consistency.
184 EXPECT_CALL(mock_socket_, Write(NotNull(), serialized_message.size(), _))
185 .WillOnce(DoAll(ReadBufferToString<0, 1>(&output),
186 EnqueueCallback<2>(&socket_cbs)));
187 EXPECT_CALL(write_handler, Complete(net::OK));
188 transport_->SendMessage(
190 base::Bind(&CompleteHandler::Complete, base::Unretained(&write_handler)));
// Completes the pending socket write, reporting every byte as written.
192 socket_cbs.Pop(serialized_message.size());
194 EXPECT_EQ(serialized_message, output);
// Verifies that when the socket writes only part of a frame, the transport
// issues a second Write() for the remaining bytes. SendMessage()'s callback
// is expected to report net::OK only after that second write completes.
197 TEST_F(CastTransportTest, TestPartialWritesAsync) {
199 CompletionQueue socket_cbs;
200 CompleteHandler write_handler;
203 CastMessage message = CreateCastMessage();
204 std::string serialized_message;
205 EXPECT_TRUE(MessageFramer::Serialize(message, &serialized_message));
207 // The first Write() covers the whole frame, but only one byte of it is written.
208 EXPECT_CALL(mock_socket_,
209 Write(NotNull(), static_cast<int>(serialized_message.size()), _))
210 .WillOnce(DoAll(ReadBufferToString<0, 1>(&output),
211 EnqueueCallback<2>(&socket_cbs)));
212 // A second Write() then covers the remaining bytes.
215 Write(NotNull(), static_cast<int>(serialized_message.size() - 1), _))
216 .WillOnce(DoAll(ReadBufferToString<0, 1>(&output),
217 EnqueueCallback<2>(&socket_cbs)));
219 transport_->SendMessage(
221 base::Bind(&CompleteHandler::Complete, base::Unretained(&write_handler)));
// The first Write() was handed the complete frame.
223 EXPECT_EQ(serialized_message, output);
// Completing the remainder is what finishes the send.
227 EXPECT_CALL(write_handler, Complete(net::OK));
228 socket_cbs.Pop(serialized_message.size() - 1);
// The second Write() was handed the frame minus its first byte.
230 EXPECT_EQ(serialized_message.substr(1, serialized_message.size() - 1),
// Verifies what happens when the socket's asynchronous Write() fails with
// net::ERR_CONNECTION_RESET:
//   - the SendMessage() callback receives net::ERR_FAILED;
//   - the delegate is told about a CAST_SOCKET_ERROR;
//   - the logger records the original net error as the last SOCKET_WRITE
//     error for the channel.
234 TEST_F(CastTransportTest, TestWriteFailureAsync) {
235 CompletionQueue socket_cbs;
236 CompleteHandler write_handler;
237 CastMessage message = CreateCastMessage();
238 EXPECT_CALL(mock_socket_, Write(NotNull(), _, _))
239 .WillOnce(EnqueueCallback<2>(&socket_cbs));
// The caller sees a generic failure, not the socket's specific error code.
240 EXPECT_CALL(write_handler, Complete(net::ERR_FAILED));
241 EXPECT_CALL(*delegate_, OnError(ChannelError::CAST_SOCKET_ERROR));
242 transport_->SendMessage(
244 base::Bind(&CompleteHandler::Complete, base::Unretained(&write_handler)));
// Fails the pending socket write.
246 socket_cbs.Pop(net::ERR_CONNECTION_RESET);
// The specific socket error is preserved in the logger.
248 EXPECT_EQ(ChannelEvent::SOCKET_WRITE,
249 logger_->GetLastError(kChannelId).channel_event);
250 EXPECT_EQ(net::ERR_CONNECTION_RESET,
251 logger_->GetLastError(kChannelId).net_return_value);
254 // ----------------------------------------------------------------------------
255 // Asynchronous read tests
// Verifies that a message delivered by two asynchronous socket reads (first
// the frame header, then the whole body) is passed to the delegate's
// OnMessage(). Afterwards the transport is expected to start reading the next
// frame header.
256 TEST_F(CastTransportTest, TestFullReadAsync) {
258 CompletionQueue socket_cbs;
260 CastMessage message = CreateCastMessage();
261 std::string serialized_message;
262 EXPECT_TRUE(MessageFramer::Serialize(message, &serialized_message));
263 EXPECT_CALL(*delegate_, Start());
265 // Read bytes [0, 3].
// The action copies the whole frame into the buffer, but the Pop() below
// reports only header_size() bytes as read.
266 EXPECT_CALL(mock_socket_,
267 Read(NotNull(), MessageFramer::MessageHeader::header_size(), _))
268 .WillOnce(DoAll(FillBufferFromString<0>(serialized_message),
269 EnqueueCallback<2>(&socket_cbs)));
271 // Read bytes [4, n].
// Retired once satisfied so it cannot match any later read.
272 EXPECT_CALL(mock_socket_,
274 serialized_message.size() -
275 MessageFramer::MessageHeader::header_size(),
277 .WillOnce(DoAll(FillBufferFromString<0>(serialized_message.substr(
278 MessageFramer::MessageHeader::header_size(),
279 serialized_message.size() -
280 MessageFramer::MessageHeader::header_size())),
281 EnqueueCallback<2>(&socket_cbs)))
282 .RetiresOnSaturation();
284 EXPECT_CALL(*delegate_, OnMessage(EqualsProto(message)));
// Read of the next frame's header. gmock tries newer expectations first, so
// this one takes over from the first header-read expectation. It has no
// action, so the read is never completed.
285 EXPECT_CALL(mock_socket_,
286 Read(NotNull(), MessageFramer::MessageHeader::header_size(), _));
// Completes the header read, then the body read.
289 socket_cbs.Pop(MessageFramer::MessageHeader::header_size());
290 socket_cbs.Pop(serialized_message.size() -
291 MessageFramer::MessageHeader::header_size());
// Verifies that a message body arriving in two pieces is reassembled and
// delivered to the delegate's OnMessage(). The first body read returns all
// but the last byte, so the transport issues a 1-byte read for the rest.
295 TEST_F(CastTransportTest, TestPartialReadAsync) {
297 CompletionQueue socket_cbs;
299 CastMessage message = CreateCastMessage();
300 std::string serialized_message;
301 EXPECT_TRUE(MessageFramer::Serialize(message, &serialized_message));
303 EXPECT_CALL(*delegate_, Start());
305 // Read bytes [0, 3].
306 EXPECT_CALL(mock_socket_,
307 Read(NotNull(), MessageFramer::MessageHeader::header_size(), _))
308 .WillOnce(DoAll(FillBufferFromString<0>(serialized_message),
309 EnqueueCallback<2>(&socket_cbs)))
310 .RetiresOnSaturation();
311 // Read bytes [4, n-1].
// The transport asks for the whole body; only all-but-the-last byte arrives.
312 EXPECT_CALL(mock_socket_,
314 serialized_message.size() -
315 MessageFramer::MessageHeader::header_size(),
317 .WillOnce(DoAll(FillBufferFromString<0>(serialized_message.substr(
318 MessageFramer::MessageHeader::header_size(),
319 serialized_message.size() -
320 MessageFramer::MessageHeader::header_size() - 1)),
321 EnqueueCallback<2>(&socket_cbs)))
322 .RetiresOnSaturation();
// Read the final byte n.
324 EXPECT_CALL(mock_socket_, Read(NotNull(), 1, _))
325 .WillOnce(DoAll(FillBufferFromString<0>(serialized_message.substr(
326 serialized_message.size() - 1, 1)),
327 EnqueueCallback<2>(&socket_cbs)))
328 .RetiresOnSaturation();
329 EXPECT_CALL(*delegate_, OnMessage(EqualsProto(message)));
// Completes the header read and the short body read.
331 socket_cbs.Pop(MessageFramer::MessageHeader::header_size());
332 socket_cbs.Pop(serialized_message.size() -
333 MessageFramer::MessageHeader::header_size() - 1);
// Read of the next frame's header, presumably issued once the complete
// message has been delivered. It has no action, so it never completes.
334 EXPECT_CALL(mock_socket_,
335 Read(NotNull(), MessageFramer::MessageHeader::header_size(), _));
// Verifies that a header read failing with net::ERR_CONNECTION_RESET is
// reported to the delegate as CAST_SOCKET_ERROR. The logger is expected to
// record it as the last SOCKET_READ error, keeping the original net error.
339 TEST_F(CastTransportTest, TestReadErrorInHeaderAsync) {
340 CompletionQueue socket_cbs;
342 CastMessage message = CreateCastMessage();
343 std::string serialized_message;
344 EXPECT_TRUE(MessageFramer::Serialize(message, &serialized_message));
346 EXPECT_CALL(*delegate_, Start());
348 // Read bytes [0, 3].
// The buffer is filled, but the completion below reports an error instead of
// a byte count.
349 EXPECT_CALL(mock_socket_,
350 Read(NotNull(), MessageFramer::MessageHeader::header_size(), _))
351 .WillOnce(DoAll(FillBufferFromString<0>(serialized_message),
352 EnqueueCallback<2>(&socket_cbs)))
353 .RetiresOnSaturation();
355 EXPECT_CALL(*delegate_, OnError(ChannelError::CAST_SOCKET_ERROR));
357 // Header read failure.
358 socket_cbs.Pop(net::ERR_CONNECTION_RESET);
359 EXPECT_EQ(ChannelEvent::SOCKET_READ,
360 logger_->GetLastError(kChannelId).channel_event);
361 EXPECT_EQ(net::ERR_CONNECTION_RESET,
362 logger_->GetLastError(kChannelId).net_return_value);
// Verifies that a body read failing with net::ERR_CONNECTION_RESET, after a
// successful header read, is reported to the delegate as CAST_SOCKET_ERROR.
// The logger is expected to record it as the last SOCKET_READ error.
365 TEST_F(CastTransportTest, TestReadErrorInBodyAsync) {
366 CompletionQueue socket_cbs;
368 CastMessage message = CreateCastMessage();
369 std::string serialized_message;
370 EXPECT_TRUE(MessageFramer::Serialize(message, &serialized_message));
372 EXPECT_CALL(*delegate_, Start());
374 // Read bytes [0, 3].
375 EXPECT_CALL(mock_socket_,
376 Read(NotNull(), MessageFramer::MessageHeader::header_size(), _))
377 .WillOnce(DoAll(FillBufferFromString<0>(serialized_message),
378 EnqueueCallback<2>(&socket_cbs)))
379 .RetiresOnSaturation();
380 // Read bytes [4, n-1].
// The data copied here does not matter: this read is failed below.
381 EXPECT_CALL(mock_socket_,
383 serialized_message.size() -
384 MessageFramer::MessageHeader::header_size(),
386 .WillOnce(DoAll(FillBufferFromString<0>(serialized_message.substr(
387 MessageFramer::MessageHeader::header_size(),
388 serialized_message.size() -
389 MessageFramer::MessageHeader::header_size() - 1)),
390 EnqueueCallback<2>(&socket_cbs)))
391 .RetiresOnSaturation();
392 EXPECT_CALL(*delegate_, OnError(ChannelError::CAST_SOCKET_ERROR));
394 // Header read is OK.
395 socket_cbs.Pop(MessageFramer::MessageHeader::header_size());
// Body read fails.
397 socket_cbs.Pop(net::ERR_CONNECTION_RESET);
398 EXPECT_EQ(ChannelEvent::SOCKET_READ,
399 logger_->GetLastError(kChannelId).channel_event);
400 EXPECT_EQ(net::ERR_CONNECTION_RESET,
401 logger_->GetLastError(kChannelId).net_return_value);
// Verifies that a frame with an intact header but a body overwritten with 'x'
// bytes is rejected: the delegate is expected to receive
// ChannelError::INVALID_MESSAGE.
404 TEST_F(CastTransportTest, TestReadCorruptedMessageAsync) {
405 CompletionQueue socket_cbs;
407 CastMessage message = CreateCastMessage();
408 std::string serialized_message;
409 EXPECT_TRUE(MessageFramer::Serialize(message, &serialized_message));
411 // Corrupt the serialized message body (set every body byte to 'x').
412 for (size_t i = MessageFramer::MessageHeader::header_size();
413 i < serialized_message.size(); ++i) {
414 serialized_message[i] = 'x';
417 EXPECT_CALL(*delegate_, Start());
418 // Read bytes [0, 3].
419 EXPECT_CALL(mock_socket_,
420 Read(NotNull(), MessageFramer::MessageHeader::header_size(), _))
421 .WillOnce(DoAll(FillBufferFromString<0>(serialized_message),
422 EnqueueCallback<2>(&socket_cbs)))
423 .RetiresOnSaturation();
424 // Read bytes [4, n].
// NOTE(review): the Pop() below reports the full body size as read, but the
// substr() here copies one byte less, so the action leaves the last body
// byte of the buffer untouched. This looks copied from the partial-read
// tests; the "- 1" is probably unintended here.
425 EXPECT_CALL(mock_socket_,
427 serialized_message.size() -
428 MessageFramer::MessageHeader::header_size(),
430 .WillOnce(DoAll(FillBufferFromString<0>(serialized_message.substr(
431 MessageFramer::MessageHeader::header_size(),
432 serialized_message.size() -
433 MessageFramer::MessageHeader::header_size() - 1)),
434 EnqueueCallback<2>(&socket_cbs)))
435 .RetiresOnSaturation();
437 EXPECT_CALL(*delegate_, OnError(ChannelError::INVALID_MESSAGE));
// Completes the header read, then the (corrupted) body read.
439 socket_cbs.Pop(MessageFramer::MessageHeader::header_size());
440 socket_cbs.Pop(serialized_message.size() -
441 MessageFramer::MessageHeader::header_size());
444 } // namespace cast_channel