1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "components/cast_channel/cast_socket.h"
12 #include "base/callback_helpers.h"
13 #include "base/files/file_util.h"
14 #include "base/location.h"
15 #include "base/macros.h"
16 #include "base/memory/ptr_util.h"
17 #include "base/memory/weak_ptr.h"
18 #include "base/path_service.h"
19 #include "base/run_loop.h"
20 #include "base/single_thread_task_runner.h"
21 #include "base/strings/string_number_conversions.h"
22 #include "base/sys_byteorder.h"
23 #include "base/test/bind_test_util.h"
24 #include "base/threading/thread_task_runner_handle.h"
25 #include "base/timer/mock_timer.h"
26 #include "build/build_config.h"
27 #include "components/cast_channel/cast_auth_util.h"
28 #include "components/cast_channel/cast_framer.h"
29 #include "components/cast_channel/cast_message_util.h"
30 #include "components/cast_channel/cast_test_util.h"
31 #include "components/cast_channel/cast_transport.h"
32 #include "components/cast_channel/logger.h"
33 #include "components/cast_channel/proto/cast_channel.pb.h"
34 #include "content/public/test/test_browser_thread_bundle.h"
35 #include "crypto/rsa_private_key.h"
36 #include "net/base/address_list.h"
37 #include "net/base/net_errors.h"
38 #include "net/cert/pem_tokenizer.h"
39 #include "net/socket/client_socket_handle.h"
40 #include "net/socket/socket_test_util.h"
41 #include "net/socket/ssl_client_socket.h"
42 #include "net/socket/ssl_server_socket.h"
43 #include "net/socket/tcp_client_socket.h"
44 #include "net/socket/tcp_server_socket.h"
45 #include "net/ssl/ssl_info.h"
46 #include "net/ssl/ssl_server_config.h"
47 #include "net/test/cert_test_util.h"
48 #include "net/test/test_data_directory.h"
49 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
50 #include "net/url_request/url_request_test_util.h"
51 #include "services/network/network_context.h"
52 #include "testing/gmock/include/gmock/gmock.h"
53 #include "testing/gtest/include/gtest/gtest.h"
// Timeout, in milliseconds, that the test fixture converts to a base::TimeDelta
// when building its socket open parameters. It is deliberately huge so that it
// never elapses by itself during a test.
55 const int64_t kDistantTimeoutMillis = 100000; // 100 seconds (never hit).
58 using ::testing::DoAll;
59 using ::testing::Invoke;
60 using ::testing::InvokeArgument;
61 using ::testing::NotNull;
62 using ::testing::Return;
63 using ::testing::SaveArg;
66 namespace cast_channel {
// Message namespace that CreateAuthReply() stamps on the fake auth reply.
68 const char kAuthNamespace[] = "urn:x-cast:com.google.cast.tp.deviceauth";
70 // Returns an auth challenge message inline.
// The message is filled in by CreateAuthChallengeMessage() using a freshly
// created AuthContext.
71 CastMessage CreateAuthChallenge() {
73 CreateAuthChallengeMessage(&output, AuthContext::Create());
77 // Returns an auth challenge response message inline.
// The reply is a CASTV2_1_0 message from "sender-0" to "receiver-0" on
// kAuthNamespace carrying the fixed binary payload "abcd".
// NOTE(review): the payload looks like a placeholder rather than a real signed
// response; the test sockets override VerifyChallengeReply() - confirm.
78 CastMessage CreateAuthReply() {
80 output.set_protocol_version(CastMessage::CASTV2_1_0);
81 output.set_source_id("sender-0");
82 output.set_destination_id("receiver-0");
83 output.set_payload_type(CastMessage::BINARY);
84 output.set_payload_binary("abcd");
85 output.set_namespace_(kAuthNamespace);
// Builds a small CASTV2_1_0 string message (namespace "ns", "source" ->
// "dest", UTF-8 payload "payload") for tests that need non-auth traffic.
89 CastMessage CreateTestMessage() {
90 CastMessage test_message;
91 test_message.set_protocol_version(CastMessage::CASTV2_1_0);
92 test_message.set_namespace_("ns");
93 test_message.set_source_id("source");
94 test_message.set_destination_id("dest");
95 test_message.set_payload_type(CastMessage::STRING);
96 test_message.set_payload_utf8("payload");
// Builds the path <DIR_SOURCE_ROOT>/components/test/data/cast_channel, the
// directory the tests load "self_signed.pem" from.
100 base::FilePath GetTestCertsDirectory() {
// Start from the source root, then append one path component at a time.
102 base::PathService::Get(base::DIR_SOURCE_ROOT, &path);
103 path = path.Append(FILE_PATH_LITERAL("components"));
104 path = path.Append(FILE_PATH_LITERAL("test"));
105 path = path.Append(FILE_PATH_LITERAL("data"));
106 path = path.Append(FILE_PATH_LITERAL("cast_channel"));
// Mock TCP client socket whose Connect() can either stall forever (returning
// ERR_IO_PENDING without running the callback) or defer to
// net::MockTCPClientSocket.
110 class MockTCPSocket : public net::MockTCPClientSocket {
// |do_nothing| is stored in do_nothing_ (presumably what selects the stalled
// Connect() path - confirm). |socket_provider| is forwarded to the base mock,
// which is constructed with an empty address list and no net log.
112 MockTCPSocket(bool do_nothing, net::SocketDataProvider* socket_provider)
113 : net::MockTCPClientSocket(net::AddressList(), nullptr, socket_provider) {
114 do_nothing_ = do_nothing;
115 set_enable_read_if_ready(true);
// Starts a connection attempt. |callback| is only handed to the base class on
// the non-stalled path; on the stalled path it is never run.
118 int Connect(net::CompletionOnceCallback callback) override {
120 // Stall the I/O event loop.
121 return net::ERR_IO_PENDING;
123 return net::MockTCPClientSocket::Connect(std::move(callback));
129 DISALLOW_COPY_AND_ASSIGN(MockTCPSocket);
// gMock sink for completion callbacks. Tests bind these methods (with
// base::Unretained) as the Close/Connect/Write callbacks and set EXPECT_CALLs
// on them.
132 class CompleteHandler {
135 MOCK_METHOD1(OnCloseComplete, void(int result));
136 MOCK_METHOD1(OnConnectComplete, void(CastSocket* socket));
137 MOCK_METHOD1(OnWriteComplete, void(int result));
138 MOCK_METHOD1(OnReadComplete, void(int result));
141 DISALLOW_COPY_AND_ASSIGN(CompleteHandler);
// CastSocketImpl subclass for tests. It replaces challenge-reply verification
// with a canned boolean, swaps the timer for a base::MockOneShotTimer that
// tests fire by hand, and installs the certificate from self_signed.pem as the
// peer certificate.
144 class TestCastSocketBase : public CastSocketImpl {
// |network_context| is returned unchanged by the getter callback given to
// CastSocketImpl; |open_params| describes the socket to open.
146 TestCastSocketBase(network::mojom::NetworkContext* network_context,
147 const CastSocketOpenParams& open_params,
149 : CastSocketImpl(base::BindRepeating(
150 [](network::mojom::NetworkContext* network_context) {
151 return network_context;
156 AuthContext::Create()),
// Verification succeeds and is permitted unless a test says otherwise.
157 verify_challenge_result_(true),
158 verify_challenge_disallow_(false),
159 mock_timer_(new base::MockOneShotTimer()) {
160 SetPeerCertForTesting(
161 net::ImportCertFromFile(GetTestCertsDirectory(), "self_signed.pem"));
163 ~TestCastSocketBase() override {}
// Sets the value VerifyChallengeReply() will return.
165 void SetVerifyChallengeResult(bool value) {
166 verify_challenge_result_ = value;
// Fires the mock timer, simulating expiry of whatever timeout is armed on it.
169 void TriggerTimeout() { mock_timer_->Fire(); }
// Runs VerifyChannelPolicy() on a default-constructed AuthResult.
171 bool TestVerifyChannelPolicyNone() {
172 AuthResult authResult;
173 return VerifyChannelPolicy(authResult);
// After this call, any VerifyChallengeReply() invocation records a test
// failure.
176 void DisallowVerifyChallengeResult() { verify_challenge_disallow_ = true; }
// Test override: performs no real verification, just returns the canned result
// (and fails the test if verification was disallowed).
179 bool VerifyChallengeReply() override {
180 EXPECT_FALSE(verify_challenge_disallow_);
181 return verify_challenge_result_;
// Hands out the mock timer in place of a real one.
184 base::OneShotTimer* GetTimer() override { return mock_timer_.get(); }
186 // Simulated result of verifying challenge reply.
187 bool verify_challenge_result_;
// When true, VerifyChallengeReply() is not expected to be called at all.
188 bool verify_challenge_disallow_;
// Timer returned by GetTimer(); fired manually through TriggerTimeout().
189 std::unique_ptr<base::MockOneShotTimer> mock_timer_;
192 DISALLOW_COPY_AND_ASSIGN(TestCastSocketBase);
// TestCastSocketBase that can additionally run on top of a MockCastTransport,
// so tests can script transport-level behaviour.
195 class MockTestCastSocket : public TestCastSocketBase {
// Factory used by the fixtures; wraps a newly allocated socket in a unique_ptr.
197 static std::unique_ptr<MockTestCastSocket> CreateSecure(
198 network::mojom::NetworkContext* network_context,
199 const CastSocketOpenParams& open_params,
201 return std::unique_ptr<MockTestCastSocket>(
202 new MockTestCastSocket(network_context, open_params, logger));
// NOTE(review): the base constructors are inherited here and a constructor
// forwarding the same three arguments is also declared below; one of the two
// looks redundant - confirm.
205 using TestCastSocketBase::TestCastSocketBase;
207 MockTestCastSocket(network::mojom::NetworkContext* network_context,
208 const CastSocketOpenParams& open_params,
210 : TestCastSocketBase(network_context, open_params, logger) {}
212 ~MockTestCastSocket() override {}
// Creates a MockCastTransport and installs it as the socket's transport.
// Ownership moves to the socket via WrapUnique; mock_transport_ keeps a raw,
// non-owning pointer so tests can set expectations on it.
214 void SetupMockTransport() {
215 mock_transport_ = new MockCastTransport;
216 SetTransportForTesting(base::WrapUnique(mock_transport_));
// Runs VerifyChannelPolicy() on an AuthResult flagged POLICY_AUDIO_ONLY.
219 bool TestVerifyChannelPolicyAudioOnly() {
220 AuthResult authResult;
221 authResult.channel_policies |= AuthResult::POLICY_AUDIO_ONLY;
222 return VerifyChannelPolicy(authResult);
// Returns the mock transport; CHECK-fails if SetupMockTransport() has not been
// called.
225 MockCastTransport* GetMockTransport() {
226 CHECK(mock_transport_);
227 return mock_transport_;
// Non-owning; null until SetupMockTransport() runs.
// NOTE(review): presumably dangles once the socket destroys its transport -
// verify before using it after a close.
231 MockCastTransport* mock_transport_ = nullptr;
233 DISALLOW_COPY_AND_ASSIGN(MockTestCastSocket);
// net::ClientSocketFactory used by the fixtures. It hands out mock TCP and SSL
// sockets driven by the connect results and read/write scripts configured
// through the helpers below, or a caller-supplied TCP socket when one is set.
236 class TestSocketFactory : public net::ClientSocketFactory {
// |ip| is the peer address recorded in every MockConnect this factory creates.
238 explicit TestSocketFactory(net::IPEndPoint ip) : ip_(ip) {}
239 ~TestSocketFactory() override = default;
241 // Socket connection helpers.
// Each stores the mode (sync/async) and result the next mock connect reports.
242 void SetupTcpConnect(net::IoMode mode, int result) {
243 tcp_connect_data_.reset(new net::MockConnect(mode, result, ip_));
245 void SetupSslConnect(net::IoMode mode, int result) {
246 ssl_connect_data_.reset(new net::MockConnect(mode, result, ip_));
249 // Socket I/O helpers.
// Results are appended to writes_/reads_ and copied into the data provider
// when the TCP socket is created, so they must be queued before that point.
250 void AddWriteResult(const net::MockWrite& write) { writes_.push_back(write); }
251 void AddWriteResult(net::IoMode mode, int result) {
252 AddWriteResult(net::MockWrite(mode, result));
// Only msg.size() is recorded as the write result; the bytes of |msg| are not
// stored.
254 void AddWriteResultForData(net::IoMode mode, const std::string& msg) {
255 AddWriteResult(mode, msg.size());
257 void AddReadResult(const net::MockRead& read) { reads_.push_back(read); }
258 void AddReadResult(net::IoMode mode, int result) {
259 AddReadResult(net::MockRead(mode, result));
// NOTE(review): the MockRead is built from data.c_str(), so presumably |data|
// must stay alive until the read has been consumed - confirm against MockRead.
261 void AddReadResultForData(net::IoMode mode, const std::string& data) {
262 AddReadResult(net::MockRead(mode, data.c_str(), data.size()));
265 // Helpers for modifying other connection-related behaviors.
// Makes the next TCP socket a MockTCPSocket created with do_nothing == true.
266 void SetupTcpConnectUnresponsive() { tcp_unresponsive_ = true; }
// Supplies a ready-made TCP socket that CreateTransportClientSocket() returns
// instead of building a mock.
269 std::unique_ptr<net::TransportClientSocket> tcp_client_socket) {
270 tcp_client_socket_ = std::move(tcp_client_socket);
// |closure| is run once, when the mock SSL socket is being created.
273 void SetTLSSocketCreatedClosure(base::OnceClosure closure) {
274 tls_socket_created_ = std::move(closure);
// Pauses the data provider if it already exists. socket_data_provider_paused_
// makes a provider created later start out paused (presumably set only when no
// provider exists yet - confirm).
278 if (socket_data_provider_)
279 socket_data_provider_->Pause();
281 socket_data_provider_paused_ = true;
// NOTE(review): unlike the pause path above, this dereferences
// socket_data_provider_ without a null check and does not clear
// socket_data_provider_paused_; call it only after the TCP socket exists.
284 void Resume() { socket_data_provider_->Resume(); }
// net::ClientSocketFactory implementation.
287 std::unique_ptr<net::DatagramClientSocket> CreateDatagramClientSocket(
288 net::DatagramSocket::BindType,
290 const net::NetLogSource&) override {
// Selection order: caller-supplied socket, then the unresponsive mock, then a
// mock scripted with reads_/writes_ and the configured TCP connect result.
294 std::unique_ptr<net::TransportClientSocket> CreateTransportClientSocket(
295 const net::AddressList&,
296 std::unique_ptr<net::SocketPerformanceWatcher>,
298 const net::NetLogSource&) override {
299 if (tcp_client_socket_)
300 return std::move(tcp_client_socket_);
302 if (tcp_unresponsive_) {
303 socket_data_provider_ = std::make_unique<net::StaticSocketDataProvider>();
304 return std::unique_ptr<net::TransportClientSocket>(
305 new MockTCPSocket(true, socket_data_provider_.get()));
307 socket_data_provider_ =
308 std::make_unique<net::StaticSocketDataProvider>(reads_, writes_);
// Dereferences tcp_connect_data_, so SetupTcpConnect() must have been called.
309 socket_data_provider_->set_connect_data(*tcp_connect_data_);
// Honour a pause requested before the provider existed.
310 if (socket_data_provider_paused_)
311 socket_data_provider_->Pause();
312 return std::unique_ptr<net::TransportClientSocket>(
313 new MockTCPSocket(false, socket_data_provider_.get()));
// Returns a mock SSL socket reporting the result configured through
// SetupSslConnect(), or defers to the default factory when none was configured.
316 std::unique_ptr<net::SSLClientSocket> CreateSSLClientSocket(
317 std::unique_ptr<net::ClientSocketHandle> client_handle,
318 const net::HostPortPair& host_and_port,
319 const net::SSLConfig& ssl_config,
320 const net::SSLClientSocketContext& context) override {
321 if (!ssl_connect_data_) {
322 // Test isn't overriding SSL socket creation.
323 return net::ClientSocketFactory::GetDefaultFactory()
324 ->CreateSSLClientSocket(std::move(client_handle), host_and_port,
325 ssl_config, context);
327 ssl_socket_data_provider_ = std::make_unique<net::SSLSocketDataProvider>(
328 ssl_connect_data_->mode, ssl_connect_data_->result);
// NOTE(review): the two commented-out statements in this function look like
// leftovers from an earlier version - confirm they can be deleted.
329 // auto client_handle = std::make_unique<net::ClientSocketHandle>();
// Notify the test that the TLS socket is about to exist.
331 if (tls_socket_created_)
332 std::move(tls_socket_created_).Run();
334 // client_handle->SetSocket(std::move(tcp_socket));
// The mock is built with a default HostPortPair and SSLConfig; the
// |host_and_port| and |ssl_config| arguments are not passed on here.
335 return std::make_unique<net::MockSSLClientSocket>(
336 std::move(client_handle), net::HostPortPair(), net::SSLConfig(),
337 ssl_socket_data_provider_.get());
339 std::unique_ptr<net::ProxyClientSocket> CreateProxyClientSocket(
340 std::unique_ptr<net::ClientSocketHandle> transport_socket,
341 const std::string& user_agent,
342 const net::HostPortPair& endpoint,
343 const net::ProxyServer& proxy_server,
344 net::HttpAuthController* http_auth_controller,
347 net::NextProto negotiated_protocol,
348 net::ProxyDelegate* proxy_delegate,
350 const net::NetworkTrafficAnnotationTag& traffic_annotation) override {
354 void ClearSSLSessionCache() override { NOTIMPLEMENTED(); }
357 // Simulated connect data
358 std::unique_ptr<net::MockConnect> tcp_connect_data_;
359 std::unique_ptr<net::MockConnect> ssl_connect_data_;
360 // Simulated read / write data
361 std::vector<net::MockWrite> writes_;
362 std::vector<net::MockRead> reads_;
// Data providers backing the most recently created mock TCP / SSL sockets.
363 std::unique_ptr<net::StaticSocketDataProvider> socket_data_provider_;
364 std::unique_ptr<net::SSLSocketDataProvider> ssl_socket_data_provider_;
// When true, a newly created data provider starts out paused.
365 bool socket_data_provider_paused_ = false;
366 // If true, makes TCP connection process stall. For timeout testing.
367 bool tcp_unresponsive_ = false;
// Caller-supplied TCP socket; takes precedence over the mocks when set.
368 std::unique_ptr<net::TransportClientSocket> tcp_client_socket_;
// Run once from CreateSSLClientSocket() on the mock SSL path.
369 base::OnceClosure tls_socket_created_;
371 DISALLOW_COPY_AND_ASSIGN(TestSocketFactory);
// Common fixture: owns the task environment, a NetworkContext whose
// URLRequestContext uses TestSocketFactory, the logger, a mock observer and the
// socket open parameters shared by the concrete fixtures below.
374 class CastSocketTestBase : public testing::Test {
377 : thread_bundle_(content::TestBrowserThreadBundle::IO_MAINLOOP),
// NOTE(review): `true` presumably delays initialization until the explicit
// Init() call in SetUp() - confirm against TestURLRequestContext.
378 url_request_context_(true),
379 logger_(new Logger()),
380 observer_(new MockCastSocketObserver()),
382 CreateIPEndPointForTest(),
383 base::TimeDelta::FromMilliseconds(kDistantTimeoutMillis)),
// The factory reports the same endpoint the socket is asked to open.
384 client_socket_factory_(socket_open_params_.ip_endpoint) {}
385 ~CastSocketTestBase() override {}
387 void SetUp() override {
// By default no test expects an application message to reach the observer.
388 EXPECT_CALL(*observer_, OnMessage(_, _)).Times(0);
// Route socket creation through the test factory, then build the
// NetworkContext on top of that URLRequestContext.
390 url_request_context_.set_client_socket_factory(&client_socket_factory_);
391 url_request_context_.Init();
392 network_context_ = std::make_unique<network::NetworkContext>(
393 nullptr, mojo::MakeRequest(&network_context_ptr_),
394 &url_request_context_);
397 // Runs all pending tasks in the message loop.
398 void RunPendingTasks() {
399 base::RunLoop run_loop;
400 run_loop.RunUntilIdle();
403 TestSocketFactory* client_socket_factory() { return &client_socket_factory_; }
405 content::TestBrowserThreadBundle thread_bundle_;
406 net::TestURLRequestContext url_request_context_;
407 std::unique_ptr<network::NetworkContext> network_context_;
408 network::mojom::NetworkContextPtr network_context_ptr_;
// Receives the completion callbacks that tests bind with base::Unretained.
410 CompleteHandler handler_;
411 std::unique_ptr<MockCastSocketObserver> observer_;
412 CastSocketOpenParams socket_open_params_;
// Declared after socket_open_params_, whose ip_endpoint initializes it.
413 TestSocketFactory client_socket_factory_;
416 DISALLOW_COPY_AND_ASSIGN(CastSocketTestBase);
// Fixture for tests that run a MockTestCastSocket over mocked TCP/SSL sockets.
419 class MockCastSocketTest : public CastSocketTestBase {
421 MockCastSocketTest() {}
// Closes the socket and expects the close callback to report net::OK.
423 void TearDown() override {
425 EXPECT_CALL(handler_, OnCloseComplete(net::OK));
426 socket_->Close(base::Bind(&CompleteHandler::OnCloseComplete,
427 base::Unretained(&handler_)));
// Creates |socket_| from the fixture's network context, open params and logger.
431 void CreateCastSocketSecure() {
432 socket_ = MockTestCastSocket::CreateSecure(network_context_.get(),
433 socket_open_params_, logger_);
// Connects |socket_| over a mock transport: expects the auth challenge to be
// sent (completing with net::OK), the transport to be started and the connect
// callback to run, then hands a message to the transport delegate.
436 void HandleAuthHandshake() {
437 socket_->SetupMockTransport();
438 CastMessage challenge_proto = CreateAuthChallenge();
439 EXPECT_CALL(*socket_->GetMockTransport(),
440 SendMessage(EqualsProto(challenge_proto), _))
441 .WillOnce(PostCompletionCallbackTask<1>(net::OK));
442 EXPECT_CALL(*socket_->GetMockTransport(), Start());
443 EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
444 socket_->AddObserver(observer_.get());
445 socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
446 base::Unretained(&handler_)));
448 socket_->GetMockTransport()->current_delegate()->OnMessage(
// Socket under test; null until CreateCastSocketSecure() is called.
453 std::unique_ptr<MockTestCastSocket> socket_;
456 DISALLOW_COPY_AND_ASSIGN(MockCastSocketTest);
// Fixture that connects a TestCastSocketBase to an in-process
// net::SSLServerSocket over a real loopback TCP connection.
459 class SslCastSocketTest : public CastSocketTestBase {
461 SslCastSocketTest() {}
// Closes the socket and expects the close callback to report net::OK.
463 void TearDown() override {
465 EXPECT_CALL(handler_, OnCloseComplete(net::OK));
466 socket_->Close(base::Bind(&CompleteHandler::OnCloseComplete,
467 base::Unretained(&handler_)));
// Creates the socket under test, an SSL server context from self_signed.pem,
// and a connected loopback TCP pair. The accepted end becomes the SSL server
// socket; the client end is given to the socket factory.
471 void CreateSockets() {
472 socket_ = std::make_unique<TestCastSocketBase>(
473 network_context_.get(), socket_open_params_, logger_);
// The server certificate and its private key come from the same PEM file.
476 net::ImportCertFromFile(GetTestCertsDirectory(), "self_signed.pem");
477 ASSERT_TRUE(server_cert_);
478 server_private_key_ = ReadTestKeyFromPEM("self_signed.pem");
479 ASSERT_TRUE(server_private_key_);
480 server_context_ = CreateSSLServerContext(
481 server_cert_.get(), *server_private_key_, server_ssl_config_);
// Listen on 127.0.0.1 port 0 with a backlog of 1, then connect a client to the
// address reported by GetLocalAddress().
483 tcp_server_socket_.reset(
484 new net::TCPServerSocket(nullptr, net::NetLogSource()));
486 tcp_server_socket_->ListenWithAddressAndPort("127.0.0.1", 0, 1));
487 net::IPEndPoint server_address;
488 ASSERT_EQ(net::OK, tcp_server_socket_->GetLocalAddress(&server_address));
489 tcp_client_socket_.reset(
490 new net::TCPClientSocket(net::AddressList(server_address), nullptr,
491 nullptr, net::NetLogSource()));
// Start Accept() and Connect() together and loop until neither result is
// still ERR_IO_PENDING.
493 std::unique_ptr<net::StreamSocket> accepted_socket;
494 accept_result_ = tcp_server_socket_->Accept(
495 &accepted_socket, base::Bind(&SslCastSocketTest::TcpAcceptCallback,
496 base::Unretained(this)));
497 connect_result_ = tcp_client_socket_->Connect(base::BindOnce(
498 &SslCastSocketTest::TcpConnectCallback, base::Unretained(this)));
499 while (accept_result_ == net::ERR_IO_PENDING ||
500 connect_result_ == net::ERR_IO_PENDING) {
503 ASSERT_EQ(net::OK, accept_result_);
504 ASSERT_EQ(net::OK, connect_result_);
505 ASSERT_TRUE(accepted_socket);
506 ASSERT_TRUE(tcp_client_socket_->IsConnected());
// Wrap the accepted TCP socket in an SSL server socket.
509 server_context_->CreateSSLServerSocket(std::move(accepted_socket));
510 ASSERT_TRUE(server_socket_);
// From here on the factory returns this connected socket for TCP creation.
512 client_socket_factory()->SetTcpSocket(std::move(tcp_client_socket_));
// Starts the CastSocket connect and runs the server side of the SSL handshake
// to completion, asserting that the server handshake returns net::OK.
515 void ConnectSockets() {
516 socket_->AddObserver(observer_.get());
517 socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
518 base::Unretained(&handler_)));
520 net::TestCompletionCallback handshake_callback;
521 int server_ret = handshake_callback.GetResult(
522 server_socket_->Handshake(handshake_callback.callback()));
524 ASSERT_EQ(net::OK, server_ret);
// Completion callbacks for the asynchronous Accept()/Connect() above.
527 void TcpAcceptCallback(int result) { accept_result_ = result; }
529 void TcpConnectCallback(int result) { connect_result_ = result; }
// Loads the first "PRIVATE KEY" PEM block from |name| in the test certs
// directory and builds an RSA key from its PrivateKeyInfo bytes.
531 std::unique_ptr<crypto::RSAPrivateKey> ReadTestKeyFromPEM(
532 const base::StringPiece& name) {
533 base::FilePath key_path = GetTestCertsDirectory().AppendASCII(name);
534 std::vector<std::string> headers({"PRIVATE KEY"});
535 std::string pem_data;
536 if (!base::ReadFileToString(key_path, &pem_data)) {
539 net::PEMTokenizer pem_tokenizer(pem_data, headers);
540 if (!pem_tokenizer.GetNext()) {
543 std::vector<uint8_t> key_vector(pem_tokenizer.data().begin(),
544 pem_tokenizer.data().end());
545 std::unique_ptr<crypto::RSAPrivateKey> key(
546 crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector));
// Reads from server_socket_ until |buffer| holds buffer_length bytes, then
// returns buffer_length.
// NOTE(review): the |socket| parameter is not used by the visible body, which
// always reads from server_socket_ - confirm intent.
// NOTE(review): EXPECT_GT is non-fatal, so a zero or negative read result is
// still passed to DidConsume(); consider bailing out on failure.
550 int ReadExactLength(net::IOBuffer* buffer,
552 net::Socket* socket) {
553 scoped_refptr<net::DrainableIOBuffer> draining_buffer =
554 base::MakeRefCounted<net::DrainableIOBuffer>(buffer, buffer_length);
555 while (draining_buffer->BytesRemaining() > 0) {
556 net::TestCompletionCallback read_callback;
557 int read_result = read_callback.GetResult(server_socket_->Read(
558 draining_buffer.get(), draining_buffer->BytesRemaining(),
559 read_callback.callback()));
560 EXPECT_GT(read_result, 0);
561 draining_buffer->DidConsume(read_result);
563 return buffer_length;
// Writes buffer_length bytes of |buffer| to server_socket_, looping over
// partial writes, then returns buffer_length.
// NOTE(review): as with ReadExactLength, |socket| is unused by the visible body
// and a failed write is still passed to DidConsume() - confirm.
566 int WriteExactLength(net::IOBuffer* buffer,
568 net::Socket* socket) {
569 scoped_refptr<net::DrainableIOBuffer> draining_buffer =
570 base::MakeRefCounted<net::DrainableIOBuffer>(buffer, buffer_length);
571 while (draining_buffer->BytesRemaining() > 0) {
572 net::TestCompletionCallback write_callback;
573 int write_result = write_callback.GetResult(server_socket_->Write(
574 draining_buffer.get(), draining_buffer->BytesRemaining(),
575 write_callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS));
576 EXPECT_GT(write_result, 0);
577 draining_buffer->DidConsume(write_result);
579 return buffer_length;
582 // Result values used for TCP socket setup. These should contain values from
587 // Underlying TCP sockets for |socket_| to communicate with |server_socket_|
588 // when testing with the real SSL implementation.
589 std::unique_ptr<net::TransportClientSocket> tcp_client_socket_;
590 std::unique_ptr<net::TCPServerSocket> tcp_server_socket_;
// Socket under test; created by CreateSockets().
592 std::unique_ptr<TestCastSocketBase> socket_;
594 // |server_socket_| is used for the *RealSSL tests in order to test the
595 // CastSocket over a real SSL socket. The other members below are used to
596 // initialize |server_socket_|.
597 std::unique_ptr<net::SSLServerSocket> server_socket_;
598 std::unique_ptr<net::SSLServerContext> server_context_;
599 std::unique_ptr<crypto::RSAPrivateKey> server_private_key_;
600 scoped_refptr<net::X509Certificate> server_cert_;
601 net::SSLServerConfig server_ssl_config_;
604 DISALLOW_COPY_AND_ASSIGN(SslCastSocketTest);
609 // Tests that the following connection flow works:
610 // - TCP connection succeeds (async)
611 // - SSL connection succeeds (async)
612 // - Cert is extracted successfully
613 // - Challenge request is sent (async)
614 // - Challenge response is received (async)
615 // - Credentials are verified successfully
616 TEST_F(MockCastSocketTest, TestConnectFullSecureFlowAsync) {
617 CreateCastSocketSecure();
// Both mock connects complete asynchronously with net::OK.
618 client_socket_factory()->SetupTcpConnect(net::ASYNC, net::OK);
619 client_socket_factory()->SetupSslConnect(net::ASYNC, net::OK);
// Drives the connect and auth exchange through a mock transport.
621 HandleAuthHandshake();
// A successful handshake leaves the channel open with no error recorded.
623 EXPECT_EQ(ReadyState::OPEN, socket_->ready_state());
624 EXPECT_EQ(ChannelError::NONE, socket_->error_state());
627 // Tests that the following connection flow works:
628 // - TCP connection succeeds (sync)
629 // - SSL connection succeeds (sync)
630 // - Cert is extracted successfully
631 // - Challenge request is sent (sync)
632 // - Challenge response is received (sync)
633 // - Credentials are verified successfully
634 TEST_F(MockCastSocketTest, TestConnectFullSecureFlowSync) {
// Both mock connects complete synchronously with net::OK.
635 client_socket_factory()->SetupTcpConnect(net::SYNCHRONOUS, net::OK);
636 client_socket_factory()->SetupSslConnect(net::SYNCHRONOUS, net::OK);
638 CreateCastSocketSecure();
// Drives the connect and auth exchange through a mock transport.
639 HandleAuthHandshake();
// A successful handshake leaves the channel open with no error recorded.
641 EXPECT_EQ(ReadyState::OPEN, socket_->ready_state());
642 EXPECT_EQ(ChannelError::NONE, socket_->error_state());
645 // Test that an AuthMessage with a mangled namespace triggers cancelation
646 // of the connection event loop.
647 TEST_F(MockCastSocketTest, TestConnectAuthMessageCorrupted) {
648 CreateCastSocketSecure();
649 socket_->SetupMockTransport();
651 client_socket_factory()->SetupTcpConnect(net::ASYNC, net::OK);
652 client_socket_factory()->SetupSslConnect(net::ASYNC, net::OK);
// The challenge is still sent successfully; only the reply is corrupted.
654 CastMessage challenge_proto = CreateAuthChallenge();
655 EXPECT_CALL(*socket_->GetMockTransport(),
656 SendMessage(EqualsProto(challenge_proto), _))
657 .WillOnce(PostCompletionCallbackTask<1>(net::OK));
658 EXPECT_CALL(*socket_->GetMockTransport(), Start());
659 EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
660 socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
661 base::Unretained(&handler_)));
// Take a normal auth reply and overwrite its namespace with a bogus value.
663 CastMessage mangled_auth_reply = CreateAuthReply();
664 mangled_auth_reply.set_namespace_("BOGUS_NAMESPACE");
666 socket_->GetMockTransport()->current_delegate()->OnMessage(
// The socket must end up closed with a transport error.
670 EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
671 EXPECT_EQ(ChannelError::TRANSPORT_ERROR, socket_->error_state());
673 // Verifies that the CastSocket's resources were torn down during channel
674 // close. (see http://crbug.com/504078)
675 EXPECT_EQ(nullptr, socket_->transport());
678 // Test connection error - TCP connect fails (async)
679 TEST_F(MockCastSocketTest, TestConnectTcpConnectErrorAsync) {
680 CreateCastSocketSecure();
// The mock TCP connect reports ERR_FAILED asynchronously.
682 client_socket_factory()->SetupTcpConnect(net::ASYNC, net::ERR_FAILED);
// The connect callback must still run even though the connection fails.
684 EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
685 socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
686 base::Unretained(&handler_)));
689 EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
690 EXPECT_EQ(ChannelError::CONNECT_ERROR, socket_->error_state());
693 // Test connection error - TCP connect fails (sync)
694 TEST_F(MockCastSocketTest, TestConnectTcpConnectErrorSync) {
695 CreateCastSocketSecure();
// The mock TCP connect reports ERR_FAILED synchronously.
697 client_socket_factory()->SetupTcpConnect(net::SYNCHRONOUS, net::ERR_FAILED);
// The connect callback must still run even though the connection fails.
699 EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
700 socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
701 base::Unretained(&handler_)));
704 EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
705 EXPECT_EQ(ChannelError::CONNECT_ERROR, socket_->error_state());
708 // Test connection error - timeout
709 TEST_F(MockCastSocketTest, TestConnectTcpTimeoutError) {
710 CreateCastSocketSecure();
// The TCP connect never completes, so only the timer can end the attempt.
711 client_socket_factory()->SetupTcpConnectUnresponsive();
712 EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
713 EXPECT_CALL(*observer_, OnError(_, ChannelError::CONNECT_TIMEOUT));
714 socket_->AddObserver(observer_.get());
715 socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
716 base::Unretained(&handler_)));
// Before the timer fires the socket is still connecting with no error.
719 EXPECT_EQ(ReadyState::CONNECTING, socket_->ready_state());
720 EXPECT_EQ(ChannelError::NONE, socket_->error_state());
// Fire the mock timer by hand to simulate the connect timeout elapsing.
721 socket_->TriggerTimeout();
724 EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
725 EXPECT_EQ(ChannelError::CONNECT_TIMEOUT, socket_->error_state());
728 // Test connection error - TCP socket returns timeout
729 TEST_F(MockCastSocketTest, TestConnectTcpSocketTimeoutError) {
730 CreateCastSocketSecure();
// Here the timeout comes from the socket result, not from the mock timer.
731 client_socket_factory()->SetupTcpConnect(net::SYNCHRONOUS,
732 net::ERR_CONNECTION_TIMED_OUT);
733 EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
734 EXPECT_CALL(*observer_, OnError(_, ChannelError::CONNECT_TIMEOUT));
735 socket_->AddObserver(observer_.get());
736 socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
737 base::Unretained(&handler_)));
740 EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
741 EXPECT_EQ(ChannelError::CONNECT_TIMEOUT, socket_->error_state());
// The logger must have recorded the underlying net error for this channel.
742 EXPECT_EQ(net::ERR_CONNECTION_TIMED_OUT,
743 logger_->GetLastError(socket_->id()).net_return_value);
746 // Test connection error - SSL connect fails (async)
// NOTE(review): despite the "async" name, both mock connects below are set to
// net::SYNCHRONOUS, the same as TestConnectSslConnectErrorSync. Confirm
// whether the SSL connect here was meant to use net::ASYNC.
747 TEST_F(MockCastSocketTest, TestConnectSslConnectErrorAsync) {
748 CreateCastSocketSecure();
750 client_socket_factory()->SetupTcpConnect(net::SYNCHRONOUS, net::OK);
751 client_socket_factory()->SetupSslConnect(net::SYNCHRONOUS, net::ERR_FAILED);
753 EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
754 socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
755 base::Unretained(&handler_)));
// The test expects an SSL connect failure to be reported as an authentication
// error.
758 EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
759 EXPECT_EQ(ChannelError::AUTHENTICATION_ERROR, socket_->error_state());
762 // Test connection error - SSL connect fails (sync)
763 TEST_F(MockCastSocketTest, TestConnectSslConnectErrorSync) {
764 CreateCastSocketSecure();
// TCP connects fine; the SSL connect then fails synchronously.
766 client_socket_factory()->SetupTcpConnect(net::SYNCHRONOUS, net::OK);
767 client_socket_factory()->SetupSslConnect(net::SYNCHRONOUS, net::ERR_FAILED);
769 EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
770 socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
771 base::Unretained(&handler_)));
// The test expects an SSL connect failure to be reported as an authentication
// error, with the raw net error kept in the logger.
774 EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
775 EXPECT_EQ(ChannelError::AUTHENTICATION_ERROR, socket_->error_state());
776 EXPECT_EQ(net::ERR_FAILED,
777 logger_->GetLastError(socket_->id()).net_return_value);
780 // Test connection error - SSL connect times out (sync)
781 TEST_F(MockCastSocketTest, TestConnectSslConnectTimeoutSync) {
// TCP connects fine; the SSL connect reports a timeout synchronously.
782 client_socket_factory()->SetupTcpConnect(net::SYNCHRONOUS, net::OK);
783 client_socket_factory()->SetupSslConnect(net::SYNCHRONOUS,
784 net::ERR_CONNECTION_TIMED_OUT);
786 CreateCastSocketSecure();
788 EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
789 socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
790 base::Unretained(&handler_)));
// Unlike a generic SSL failure, the test expects a timed-out SSL connect to be
// reported as CONNECT_TIMEOUT.
793 EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
794 EXPECT_EQ(ChannelError::CONNECT_TIMEOUT, socket_->error_state());
795 EXPECT_EQ(net::ERR_CONNECTION_TIMED_OUT,
796 logger_->GetLastError(socket_->id()).net_return_value);
799 // Test connection error - SSL connect times out (async)
800 TEST_F(MockCastSocketTest, TestConnectSslConnectTimeoutAsync) {
801 CreateCastSocketSecure();
// TCP connects fine; the SSL connect reports a timeout asynchronously.
803 client_socket_factory()->SetupTcpConnect(net::ASYNC, net::OK);
804 client_socket_factory()->SetupSslConnect(net::ASYNC,
805 net::ERR_CONNECTION_TIMED_OUT);
807 EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
808 socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
809 base::Unretained(&handler_)));
// The test expects a timed-out SSL connect to be reported as CONNECT_TIMEOUT.
812 EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
813 EXPECT_EQ(ChannelError::CONNECT_TIMEOUT, socket_->error_state());
816 // Test connection error - challenge send fails
817 TEST_F(MockCastSocketTest, TestConnectChallengeSendError) {
818 CreateCastSocketSecure();
819 socket_->SetupMockTransport();
821 client_socket_factory()->SetupTcpConnect(net::SYNCHRONOUS, net::OK);
822 client_socket_factory()->SetupSslConnect(net::SYNCHRONOUS, net::OK);
// Sending the auth challenge completes with ERR_CONNECTION_RESET.
823 EXPECT_CALL(*socket_->GetMockTransport(),
824 SendMessage(EqualsProto(CreateAuthChallenge()), _))
825 .WillOnce(PostCompletionCallbackTask<1>(net::ERR_CONNECTION_RESET));
827 EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
828 socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
829 base::Unretained(&handler_)));
// The test expects the failed send to surface as CAST_SOCKET_ERROR.
832 EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
833 EXPECT_EQ(ChannelError::CAST_SOCKET_ERROR, socket_->error_state());
836 // Test connection error - connection is destroyed after the challenge is
837 // sent, with the async result still lurking in the task queue.
838 TEST_F(MockCastSocketTest, TestConnectDestroyedAfterChallengeSent) {
839 CreateCastSocketSecure();
840 socket_->SetupMockTransport();
841 client_socket_factory()->SetupTcpConnect(net::SYNCHRONOUS, net::OK);
842 client_socket_factory()->SetupSslConnect(net::SYNCHRONOUS, net::OK);
// The send completion (ERR_CONNECTION_RESET) is posted as a task, not run
// inline.
843 EXPECT_CALL(*socket_->GetMockTransport(),
844 SendMessage(EqualsProto(CreateAuthChallenge()), _))
845 .WillOnce(PostCompletionCallbackTask<1>(net::ERR_CONNECTION_RESET));
// Unlike the neighbouring tests, no OnConnectComplete expectation is set
// before Connect().
846 socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
847 base::Unretained(&handler_)));
853 // Test connection error - challenge reply receive fails
854 TEST_F(MockCastSocketTest, TestConnectChallengeReplyReceiveError) {
855 CreateCastSocketSecure();
856 socket_->SetupMockTransport();
858 client_socket_factory()->SetupTcpConnect(net::SYNCHRONOUS, net::OK);
859 client_socket_factory()->SetupSslConnect(net::SYNCHRONOUS, net::OK);
// The challenge itself is sent successfully.
860 EXPECT_CALL(*socket_->GetMockTransport(),
861 SendMessage(EqualsProto(CreateAuthChallenge()), _))
862 .WillOnce(PostCompletionCallbackTask<1>(net::OK));
863 client_socket_factory()->AddReadResult(net::SYNCHRONOUS, net::ERR_FAILED);
864 EXPECT_CALL(*observer_, OnError(_, ChannelError::CAST_SOCKET_ERROR));
865 EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
866 EXPECT_CALL(*socket_->GetMockTransport(), Start());
867 socket_->AddObserver(observer_.get());
868 socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
869 base::Unretained(&handler_)));
// The transport is mocked, so the receive failure is delivered by calling the
// transport delegate's OnError() directly.
871 socket_->GetMockTransport()->current_delegate()->OnError(
872 ChannelError::CAST_SOCKET_ERROR);
875 EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
876 EXPECT_EQ(ChannelError::CAST_SOCKET_ERROR, socket_->error_state());
// Test connection error - the challenge reply arrives but verification of it
// fails.
879 TEST_F(MockCastSocketTest, TestConnectChallengeVerificationFails) {
880 CreateCastSocketSecure();
881 socket_->SetupMockTransport();
882 client_socket_factory()->SetupTcpConnect(net::ASYNC, net::OK);
883 client_socket_factory()->SetupSslConnect(net::ASYNC, net::OK);
// Make the stubbed VerifyChallengeReply() return false.
884 socket_->SetVerifyChallengeResult(false);
886 EXPECT_CALL(*observer_, OnError(_, ChannelError::AUTHENTICATION_ERROR));
887 CastMessage challenge_proto = CreateAuthChallenge();
888 EXPECT_CALL(*socket_->GetMockTransport(),
889 SendMessage(EqualsProto(challenge_proto), _))
890 .WillOnce(PostCompletionCallbackTask<1>(net::OK));
891 EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
892 EXPECT_CALL(*socket_->GetMockTransport(), Start());
893 socket_->AddObserver(observer_.get());
894 socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
895 base::Unretained(&handler_)));
// Deliver an otherwise well-formed auth reply to the transport delegate.
897 socket_->GetMockTransport()->current_delegate()->OnMessage(CreateAuthReply());
900 EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
901 EXPECT_EQ(ChannelError::AUTHENTICATION_ERROR, socket_->error_state());
904 // Sends message data through an actual non-mocked CastTransport object,
905 // testing the two components in integration.
906 TEST_F(MockCastSocketTest, TestConnectEndToEndWithRealTransportAsync) {
907 CreateCastSocketSecure();
908 client_socket_factory()->SetupTcpConnect(net::ASYNC, net::OK);
909 client_socket_factory()->SetupSslConnect(net::ASYNC, net::OK);
911 // Set low-level auth challenge expectations.
// The mock write result is the size of the serialized challenge.
912 CastMessage challenge = CreateAuthChallenge();
913 std::string challenge_str;
914 EXPECT_TRUE(MessageFramer::Serialize(challenge, &challenge_str));
915 client_socket_factory()->AddWriteResultForData(net::ASYNC, challenge_str);
917 // Set low-level auth reply expectations.
// The serialized reply is followed by a read that reports ERR_IO_PENDING.
918 CastMessage reply = CreateAuthReply();
919 std::string reply_str;
920 EXPECT_TRUE(MessageFramer::Serialize(reply, &reply_str));
921 client_socket_factory()->AddReadResultForData(net::ASYNC, reply_str);
922 client_socket_factory()->AddReadResult(net::ASYNC, net::ERR_IO_PENDING);
923 // Make sure the data is read by the TLS socket and not the TCP socket.
// The data provider stays paused until the factory creates the TLS socket.
924 client_socket_factory()->Pause();
925 client_socket_factory()->SetTLSSocketCreatedClosure(
926 base::BindLambdaForTesting([&] { client_socket_factory()->Resume(); }));
// Queue the write result for the application message sent after connecting.
928 CastMessage test_message = CreateTestMessage();
929 std::string test_message_str;
930 EXPECT_TRUE(MessageFramer::Serialize(test_message, &test_message_str));
931 client_socket_factory()->AddWriteResultForData(net::ASYNC, test_message_str);
933 EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
934 socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
935 base::Unretained(&handler_)));
937 EXPECT_EQ(ReadyState::OPEN, socket_->ready_state());
938 EXPECT_EQ(ChannelError::NONE, socket_->error_state());
940 // Send the test message through a real transport object.
941 EXPECT_CALL(handler_, OnWriteComplete(net::OK));
942 socket_->transport()->SendMessage(
943 test_message, base::Bind(&CompleteHandler::OnWriteComplete,
944 base::Unretained(&handler_)));
// Sending must leave the channel open with no error recorded.
947 EXPECT_EQ(ReadyState::OPEN, socket_->ready_state());
948 EXPECT_EQ(ChannelError::NONE, socket_->error_state());
951 // Same as TestConnectEndToEndWithRealTransportAsync, except synchronous.
// The connects and the data read/write results use net::SYNCHRONOUS here.
952 TEST_F(MockCastSocketTest, TestConnectEndToEndWithRealTransportSync) {
953 CreateCastSocketSecure();
954 client_socket_factory()->SetupTcpConnect(net::SYNCHRONOUS, net::OK);
955 client_socket_factory()->SetupSslConnect(net::SYNCHRONOUS, net::OK);
957 // Set low-level auth challenge expectations.
958 CastMessage challenge = CreateAuthChallenge();
959 std::string challenge_str;
960 EXPECT_TRUE(MessageFramer::Serialize(challenge, &challenge_str));
961 client_socket_factory()->AddWriteResultForData(net::SYNCHRONOUS,
964 // Set low-level auth reply expectations.
965 CastMessage reply = CreateAuthReply();
966 std::string reply_str;
967 EXPECT_TRUE(MessageFramer::Serialize(reply, &reply_str));
968 client_socket_factory()->AddReadResultForData(net::SYNCHRONOUS, reply_str);
// Note: the trailing ERR_IO_PENDING read result is still queued as
// net::ASYNC, unlike the other results in this test.
969 client_socket_factory()->AddReadResult(net::ASYNC, net::ERR_IO_PENDING);
970 // Make sure the data is read by the TLS socket and not the TCP socket.
// The factory is paused here; Resume() is called from the closure that runs
// when the TLS socket is created.
971 client_socket_factory()->Pause();
972 client_socket_factory()->SetTLSSocketCreatedClosure(
973 base::BindLambdaForTesting([&] { client_socket_factory()->Resume(); }));
// Serialize the test message that is sent after the connection is open.
975 CastMessage test_message = CreateTestMessage();
976 std::string test_message_str;
977 EXPECT_TRUE(MessageFramer::Serialize(test_message, &test_message_str));
978 client_socket_factory()->AddWriteResultForData(net::SYNCHRONOUS,
981 EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
982 socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
983 base::Unretained(&handler_)));
// The channel must be open and error-free once connected.
985 EXPECT_EQ(ReadyState::OPEN, socket_->ready_state());
986 EXPECT_EQ(ChannelError::NONE, socket_->error_state());
988 // Send the test message through a real transport object.
989 EXPECT_CALL(handler_, OnWriteComplete(net::OK));
990 socket_->transport()->SendMessage(
991 test_message, base::Bind(&CompleteHandler::OnWriteComplete,
992 base::Unretained(&handler_)));
// Sending must leave the channel open and error-free.
995 EXPECT_EQ(ReadyState::OPEN, socket_->ready_state());
996 EXPECT_EQ(ChannelError::NONE, socket_->error_state());
// Registers two observers, each one twice, and checks that an error raised
// through CastSocketMessageDelegate is delivered to both of them.
999 TEST_F(MockCastSocketTest, TestObservers) {
1000 CreateCastSocketSecure();
1001 // Test AddObserver. Each observer is added twice.
1002 MockCastSocketObserver observer1;
1003 MockCastSocketObserver observer2;
1004 socket_->AddObserver(&observer1);
1005 socket_->AddObserver(&observer1);
1006 socket_->AddObserver(&observer2);
1007 socket_->AddObserver(&observer2);
1009 // Test notify observers
// With gMock's default cardinality, each EXPECT_CALL below requires exactly
// one OnError(CONNECT_ERROR) call on that observer.
1010 EXPECT_CALL(observer1, OnError(_, cast_channel::ChannelError::CONNECT_ERROR));
1011 EXPECT_CALL(observer2, OnError(_, cast_channel::ChannelError::CONNECT_ERROR));
// Raise the error through a message delegate constructed on the socket.
1012 CastSocketImpl::CastSocketMessageDelegate delegate(socket_.get());
1013 delegate.OnError(cast_channel::ChannelError::CONNECT_ERROR);
// Calls Connect() twice with the TCP connect set up as unresponsive, then
// triggers the timeout; both connect callbacks are expected to run.
1016 TEST_F(MockCastSocketTest, TestOpenChannelConnectingSocket) {
1017 CreateCastSocketSecure();
1018 client_socket_factory()->SetupTcpConnectUnresponsive();
// First Connect(): no callback expectation has been set at this point.
1019 socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
1020 base::Unretained(&handler_)));
// OnConnectComplete must be invoked twice in total (Times(2)), covering the
// callback above and the one passed to the second Connect() below.
1023 EXPECT_CALL(handler_, OnConnectComplete(socket_.get())).Times(2);
1024 socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
1025 base::Unretained(&handler_)));
1026 socket_->TriggerTimeout();
// Calls Connect() after HandleAuthHandshake() has run and expects the connect
// callback to be invoked.
// NOTE(review): per the test name the socket is presumably already connected
// once HandleAuthHandshake() returns -- confirm against that helper.
1030 TEST_F(MockCastSocketTest, TestOpenChannelConnectedSocket) {
1031 CreateCastSocketSecure();
1032 client_socket_factory()->SetupTcpConnect(net::ASYNC, net::OK);
1033 client_socket_factory()->SetupSslConnect(net::ASYNC, net::OK);
1034 HandleAuthHandshake();
// This Connect() call must still complete its callback.
1036 EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
1037 socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
1038 base::Unretained(&handler_)));
// Calls Connect() twice with the TCP connect set up to fail (ERR_FAILED);
// each call has its own expectation that the connect callback runs.
1041 TEST_F(MockCastSocketTest, TestOpenChannelClosedSocket) {
1042 CreateCastSocketSecure();
1043 client_socket_factory()->SetupTcpConnect(net::ASYNC, net::ERR_FAILED);
// First attempt.
1045 EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
1046 socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
1047 base::Unretained(&handler_)));
// Second attempt on the same socket: the callback must be invoked again.
1050 EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
1051 socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
1052 base::Unretained(&handler_)));
1055 // https://crbug.com/874491, flaky on Win and Mac
// MAYBE_TestConnectEndToEndWithRealSSL is defined as the DISABLED_-prefixed
// name under the OS_WIN / OS_MACOSX condition; the second definition supplies
// the plain (enabled) test name.
1056 #if defined(OS_WIN) || defined(OS_MACOSX)
1057 #define MAYBE_TestConnectEndToEndWithRealSSL \
1058 DISABLED_TestConnectEndToEndWithRealSSL
1060 #define MAYBE_TestConnectEndToEndWithRealSSL TestConnectEndToEndWithRealSSL
1062 // Tests connecting through an actual non-mocked CastTransport object and
1063 // non-mocked SSLClientSocket, testing the components in integration.
// The test drives the peer side through server_socket_: it reads the auth
// challenge from it and writes the auth reply to it.
1064 TEST_F(SslCastSocketTest, MAYBE_TestConnectEndToEndWithRealSSL) {
1068 // Set low-level auth challenge expectations.
1069 CastMessage challenge = CreateAuthChallenge();
1070 std::string challenge_str;
1071 EXPECT_TRUE(MessageFramer::Serialize(challenge, &challenge_str));
// Read challenge_str.size() bytes from the server socket and check that they
// equal the serialized challenge.
1073 int challenge_buffer_length = challenge_str.size();
1074 scoped_refptr<net::IOBuffer> challenge_buffer =
1075 base::MakeRefCounted<net::IOBuffer>(challenge_buffer_length);
1076 int read = ReadExactLength(challenge_buffer.get(), challenge_buffer_length,
1077 server_socket_.get());
1079 EXPECT_EQ(challenge_buffer_length, read);
1080 EXPECT_EQ(challenge_str,
1081 std::string(challenge_buffer->data(), challenge_buffer_length));
1083 // Set low-level auth reply expectations.
1084 CastMessage reply = CreateAuthReply();
1085 std::string reply_str;
1086 EXPECT_TRUE(MessageFramer::Serialize(reply, &reply_str));
// Write the serialized reply through the server socket; the returned count
// must equal the buffer size.
1088 scoped_refptr<net::StringIOBuffer> reply_buffer =
1089 base::MakeRefCounted<net::StringIOBuffer>(reply_str);
1090 int written = WriteExactLength(reply_buffer.get(), reply_buffer->size(),
1091 server_socket_.get());
1093 EXPECT_EQ(reply_buffer->size(), written);
// The connect callback must run, leaving the channel open and error-free.
1094 EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
1097 EXPECT_EQ(ReadyState::OPEN, socket_->ready_state());
1098 EXPECT_EQ(ChannelError::NONE, socket_->error_state());
1101 // Sends message data through an actual non-mocked CastTransport object and
1102 // non-mocked SSLClientSocket, testing the components in integration.
// The test is compiled but not run by default because of the DISABLED_ prefix.
1103 TEST_F(SslCastSocketTest, DISABLED_TestMessageEndToEndWithRealSSL) {
1107 // Set low-level auth challenge expectations.
1108 CastMessage challenge = CreateAuthChallenge();
1109 std::string challenge_str;
1110 EXPECT_TRUE(MessageFramer::Serialize(challenge, &challenge_str));
// Read challenge_str.size() bytes from the server socket and check that they
// equal the serialized challenge.
1112 int challenge_buffer_length = challenge_str.size();
1113 scoped_refptr<net::IOBuffer> challenge_buffer =
1114 base::MakeRefCounted<net::IOBuffer>(challenge_buffer_length);
1116 int read = ReadExactLength(challenge_buffer.get(), challenge_buffer_length,
1117 server_socket_.get());
1119 EXPECT_EQ(challenge_buffer_length, read);
1120 EXPECT_EQ(challenge_str,
1121 std::string(challenge_buffer->data(), challenge_buffer_length));
1123 // Set low-level auth reply expectations.
1124 CastMessage reply = CreateAuthReply();
1125 std::string reply_str;
1126 EXPECT_TRUE(MessageFramer::Serialize(reply, &reply_str));
// Write the serialized reply through the server socket; the returned count
// must equal the buffer size.
1128 scoped_refptr<net::StringIOBuffer> reply_buffer =
1129 base::MakeRefCounted<net::StringIOBuffer>(reply_str);
1130 int written = WriteExactLength(reply_buffer.get(), reply_buffer->size(),
1131 server_socket_.get());
1133 EXPECT_EQ(reply_buffer->size(), written);
// The connect callback must run, leaving the channel open and error-free.
1134 EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
1137 EXPECT_EQ(ReadyState::OPEN, socket_->ready_state());
1138 EXPECT_EQ(ChannelError::NONE, socket_->error_state());
1140 // Send a test message through the ssl socket.
1141 CastMessage test_message = CreateTestMessage();
1142 std::string test_message_str;
1143 EXPECT_TRUE(MessageFramer::Serialize(test_message, &test_message_str));
// Buffer sized to receive the serialized test message on the server side.
1145 int test_message_length = test_message_str.size();
1146 scoped_refptr<net::IOBuffer> test_message_buffer =
1147 base::MakeRefCounted<net::IOBuffer>(test_message_length);
// Send through the socket's transport; the write callback must report net::OK.
1149 EXPECT_CALL(handler_, OnWriteComplete(net::OK));
1150 socket_->transport()->SendMessage(
1151 test_message, base::Bind(&CompleteHandler::OnWriteComplete,
1152 base::Unretained(&handler_)));
// Read the message back from the server socket (reusing `read`) and compare
// it to the serialized form.
1155 read = ReadExactLength(test_message_buffer.get(), test_message_length,
1156 server_socket_.get());
1158 EXPECT_EQ(test_message_length, read);
1159 EXPECT_EQ(test_message_str,
1160 std::string(test_message_buffer->data(), test_message_length));
// The channel must still be open and error-free after the exchange.
1162 EXPECT_EQ(ReadyState::OPEN, socket_->ready_state());
1163 EXPECT_EQ(ChannelError::NONE, socket_->error_state());
1166 } // namespace cast_channel