1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "extensions/browser/api/cast_channel/cast_socket.h"
10 #include "base/bind.h"
11 #include "base/callback_helpers.h"
12 #include "base/format_macros.h"
13 #include "base/lazy_instance.h"
14 #include "base/numerics/safe_conversions.h"
15 #include "base/strings/string_number_conversions.h"
16 #include "base/strings/stringprintf.h"
17 #include "base/sys_byteorder.h"
18 #include "extensions/browser/api/cast_channel/cast_auth_util.h"
19 #include "extensions/browser/api/cast_channel/cast_channel.pb.h"
20 #include "extensions/browser/api/cast_channel/cast_message_util.h"
21 #include "extensions/browser/api/cast_channel/logger.h"
22 #include "extensions/browser/api/cast_channel/logger_util.h"
23 #include "net/base/address_list.h"
24 #include "net/base/host_port_pair.h"
25 #include "net/base/net_errors.h"
26 #include "net/base/net_util.h"
27 #include "net/cert/cert_verifier.h"
28 #include "net/cert/x509_certificate.h"
29 #include "net/http/transport_security_state.h"
30 #include "net/socket/client_socket_factory.h"
31 #include "net/socket/client_socket_handle.h"
32 #include "net/socket/ssl_client_socket.h"
33 #include "net/socket/stream_socket.h"
34 #include "net/socket/tcp_client_socket.h"
35 #include "net/ssl/ssl_config_service.h"
36 #include "net/ssl/ssl_info.h"
// Emits VLOG(level) output prefixed with "[<endpoint>, auth=<auth type>] ".
// Expects |ip_endpoint_| (a net::IPEndPoint) and |channel_auth_| (a
// ChannelAuthType) to be visible where the macro is expanded, which in this
// file means inside CastSocket member functions.
#define VLOG_WITH_CONNECTION(level)                          \
  VLOG(level) << "[" << ip_endpoint_.ToString() << ", auth=" \
              << channel_auth_ << "] "
// Default TCP keepalive delay, in seconds. On Linux, keepalive probes are sent
// once the socket has been idle for this long, and the socket is closed after
// 9 failed probes, so the total idle time before close is
// 10 * kTcpKeepAliveDelaySecs.
static const int kTcpKeepAliveDelaySecs = 10;
53 namespace extensions {
// Lazily constructed global factory for
// ApiResourceManager<core_api::cast_channel::CastSocket>.
55 static base::LazyInstance<BrowserContextKeyedAPIFactory<
56 ApiResourceManager<core_api::cast_channel::CastSocket> > > g_factory =
57 LAZY_INSTANCE_INITIALIZER;
// Returns the factory held by |g_factory|, creating it on first use.
61 BrowserContextKeyedAPIFactory<
62 ApiResourceManager<core_api::cast_channel::CastSocket> >*
63 ApiResourceManager<core_api::cast_channel::CastSocket>::GetFactoryInstance() {
64 return g_factory.Pointer();
68 namespace cast_channel {
72 proto::ReadyState ReadyStateToProto(ReadyState state) {
74 case READY_STATE_NONE:
75 return proto::READY_STATE_NONE;
76 case READY_STATE_CONNECTING:
77 return proto::READY_STATE_CONNECTING;
78 case READY_STATE_OPEN:
79 return proto::READY_STATE_OPEN;
80 case READY_STATE_CLOSING:
81 return proto::READY_STATE_CLOSING;
82 case READY_STATE_CLOSED:
83 return proto::READY_STATE_CLOSED;
86 return proto::READY_STATE_NONE;
90 proto::ConnectionState ConnectStateToProto(CastSocket::ConnectionState state) {
92 case CastSocket::CONN_STATE_NONE:
93 return proto::CONN_STATE_NONE;
94 case CastSocket::CONN_STATE_TCP_CONNECT:
95 return proto::CONN_STATE_TCP_CONNECT;
96 case CastSocket::CONN_STATE_TCP_CONNECT_COMPLETE:
97 return proto::CONN_STATE_TCP_CONNECT_COMPLETE;
98 case CastSocket::CONN_STATE_SSL_CONNECT:
99 return proto::CONN_STATE_SSL_CONNECT;
100 case CastSocket::CONN_STATE_SSL_CONNECT_COMPLETE:
101 return proto::CONN_STATE_SSL_CONNECT_COMPLETE;
102 case CastSocket::CONN_STATE_AUTH_CHALLENGE_SEND:
103 return proto::CONN_STATE_AUTH_CHALLENGE_SEND;
104 case CastSocket::CONN_STATE_AUTH_CHALLENGE_SEND_COMPLETE:
105 return proto::CONN_STATE_AUTH_CHALLENGE_SEND_COMPLETE;
106 case CastSocket::CONN_STATE_AUTH_CHALLENGE_REPLY_COMPLETE:
107 return proto::CONN_STATE_AUTH_CHALLENGE_REPLY_COMPLETE;
110 return proto::CONN_STATE_NONE;
114 proto::ReadState ReadStateToProto(CastSocket::ReadState state) {
116 case CastSocket::READ_STATE_NONE:
117 return proto::READ_STATE_NONE;
118 case CastSocket::READ_STATE_READ:
119 return proto::READ_STATE_READ;
120 case CastSocket::READ_STATE_READ_COMPLETE:
121 return proto::READ_STATE_READ_COMPLETE;
122 case CastSocket::READ_STATE_DO_CALLBACK:
123 return proto::READ_STATE_DO_CALLBACK;
124 case CastSocket::READ_STATE_ERROR:
125 return proto::READ_STATE_ERROR;
128 return proto::READ_STATE_NONE;
132 proto::WriteState WriteStateToProto(CastSocket::WriteState state) {
134 case CastSocket::WRITE_STATE_NONE:
135 return proto::WRITE_STATE_NONE;
136 case CastSocket::WRITE_STATE_WRITE:
137 return proto::WRITE_STATE_WRITE;
138 case CastSocket::WRITE_STATE_WRITE_COMPLETE:
139 return proto::WRITE_STATE_WRITE_COMPLETE;
140 case CastSocket::WRITE_STATE_DO_CALLBACK:
141 return proto::WRITE_STATE_DO_CALLBACK;
142 case CastSocket::WRITE_STATE_ERROR:
143 return proto::WRITE_STATE_ERROR;
146 return proto::WRITE_STATE_NONE;
150 proto::ErrorState ErrorStateToProto(ChannelError state) {
152 case CHANNEL_ERROR_NONE:
153 return proto::CHANNEL_ERROR_NONE;
154 case CHANNEL_ERROR_CHANNEL_NOT_OPEN:
155 return proto::CHANNEL_ERROR_CHANNEL_NOT_OPEN;
156 case CHANNEL_ERROR_AUTHENTICATION_ERROR:
157 return proto::CHANNEL_ERROR_AUTHENTICATION_ERROR;
158 case CHANNEL_ERROR_CONNECT_ERROR:
159 return proto::CHANNEL_ERROR_CONNECT_ERROR;
160 case CHANNEL_ERROR_SOCKET_ERROR:
161 return proto::CHANNEL_ERROR_SOCKET_ERROR;
162 case CHANNEL_ERROR_TRANSPORT_ERROR:
163 return proto::CHANNEL_ERROR_TRANSPORT_ERROR;
164 case CHANNEL_ERROR_INVALID_MESSAGE:
165 return proto::CHANNEL_ERROR_INVALID_MESSAGE;
166 case CHANNEL_ERROR_INVALID_CHANNEL_ID:
167 return proto::CHANNEL_ERROR_INVALID_CHANNEL_ID;
168 case CHANNEL_ERROR_CONNECT_TIMEOUT:
169 return proto::CHANNEL_ERROR_CONNECT_TIMEOUT;
170 case CHANNEL_ERROR_UNKNOWN:
171 return proto::CHANNEL_ERROR_UNKNOWN;
174 return proto::CHANNEL_ERROR_NONE;
// Constructs a socket for |ip_endpoint| on behalf of extension
// |owner_extension_id|. The constructor performs no network I/O; the
// connection is started later by Connect(). |channel_auth| must be
// CHANNEL_AUTH_TYPE_SSL or CHANNEL_AUTH_TYPE_SSL_VERIFIED (DCHECKed), and
// |timeout| is stored in |connect_timeout_| for Connect() to use.
180 CastSocket::CastSocket(const std::string& owner_extension_id,
181 const net::IPEndPoint& ip_endpoint,
182 ChannelAuthType channel_auth,
183 CastSocket::Delegate* delegate,
184 net::NetLog* net_log,
185 const base::TimeDelta& timeout,
186 const scoped_refptr<Logger>& logger)
187 : ApiResource(owner_extension_id),
189 ip_endpoint_(ip_endpoint),
190 channel_auth_(channel_auth),
192 current_message_size_(0),
193 current_message_(new CastMessage()),
196 connect_timeout_(timeout),
197 connect_timeout_timer_(new base::OneShotTimer<CastSocket>),
199 connect_state_(CONN_STATE_NONE),
200 write_state_(WRITE_STATE_NONE),
201 read_state_(READ_STATE_NONE),
202 error_state_(CHANNEL_ERROR_NONE),
203 ready_state_(READY_STATE_NONE) {
205 DCHECK(channel_auth_ == CHANNEL_AUTH_TYPE_SSL ||
206 channel_auth_ == CHANNEL_AUTH_TYPE_SSL_VERIFIED);
// This socket's NetLog events are attributed to their own SOURCE_SOCKET
// source, whose id is allocated from |net_log_|.
207 net_log_source_.type = net::NetLog::SOURCE_SOCKET;
208 net_log_source_.id = net_log_->NextID();
210 // Reuse these buffers for each message.
// The header buffer is sized for exactly one length header and the body
// buffer for the largest permitted message; reading starts with the header.
211 header_read_buffer_ = new net::GrowableIOBuffer();
212 header_read_buffer_->SetCapacity(MessageHeader::header_size());
213 body_read_buffer_ = new net::GrowableIOBuffer();
214 body_read_buffer_->SetCapacity(MessageHeader::max_message_size());
215 current_read_buffer_ = header_read_buffer_;
// Releases the socket's resources. Pending client callbacks are deliberately
// not run from the destructor.
218 CastSocket::~CastSocket() {
219 // Ensure that resources are freed but do not run pending callbacks to avoid
// Accessor for the socket's current ReadyState (presumably |ready_state_|,
// which SetReadyState() maintains).
224 ReadyState CastSocket::ready_state() const {
// Accessor for the socket's last ChannelError (presumably |error_state_|,
// which SetErrorState() maintains).
228 ChannelError CastSocket::error_state() const {
232 scoped_ptr<net::TCPClientSocket> CastSocket::CreateTcpSocket() {
233 net::AddressList addresses(ip_endpoint_);
234 return scoped_ptr<net::TCPClientSocket>(
235 new net::TCPClientSocket(addresses, net_log_, net_log_source_));
236 // Options cannot be set on the TCPClientSocket yet, because the
237 // underlying platform socket will not be created until Bind()
238 // or Connect() is called.
241 scoped_ptr<net::SSLClientSocket> CastSocket::CreateSslSocket(
242 scoped_ptr<net::StreamSocket> socket) {
243 net::SSLConfig ssl_config;
244 // If a peer cert was extracted in a previous attempt to connect, then
245 // whitelist that cert.
246 if (!peer_cert_.empty()) {
247 net::SSLConfig::CertAndStatus cert_and_status;
248 cert_and_status.cert_status = net::CERT_STATUS_AUTHORITY_INVALID;
249 cert_and_status.der_cert = peer_cert_;
250 ssl_config.allowed_bad_certs.push_back(cert_and_status);
251 logger_->LogSocketEvent(channel_id_, proto::SSL_CERT_WHITELISTED);
254 cert_verifier_.reset(net::CertVerifier::CreateDefault());
255 transport_security_state_.reset(new net::TransportSecurityState);
256 net::SSLClientSocketContext context;
257 // CertVerifier and TransportSecurityState are owned by us, not the
259 context.cert_verifier = cert_verifier_.get();
260 context.transport_security_state = transport_security_state_.get();
262 scoped_ptr<net::ClientSocketHandle> connection(new net::ClientSocketHandle);
263 connection->SetSocket(socket.Pass());
264 net::HostPortPair host_and_port = net::HostPortPair::FromIPEndPoint(
267 return net::ClientSocketFactory::GetDefaultFactory()->CreateSSLClientSocket(
268 connection.Pass(), host_and_port, ssl_config, context);
// Copies the DER encoding of the certificate presented on the SSL socket
// |socket_| into |cert|. DoSslConnectComplete() calls this after an
// ERR_CERT_AUTHORITY_INVALID failure so that the next attempt can whitelist
// the certificate. Must only be called while |peer_cert_| is still empty
// (DCHECKed).
// NOTE(review): presumably returns false when no certificate is available and
// otherwise the GetDEREncoded() outcome in |result| — confirm.
271 bool CastSocket::ExtractPeerCert(std::string* cert) {
273 DCHECK(peer_cert_.empty());
274 net::SSLInfo ssl_info;
// No SSL info, or no certificate in it: there is nothing to extract.
275 if (!socket_->GetSSLInfo(&ssl_info) || !ssl_info.cert.get())
278 logger_->LogSocketEvent(channel_id_, proto::SSL_INFO_OBTAINED);
280 bool result = net::X509Certificate::GetDEREncoded(
281 ssl_info.cert->os_cert_handle(), cert);
283 VLOG_WITH_CONNECTION(1) << "Successfully extracted peer certificate: "
// Logs 1 when DER extraction succeeded and 0 otherwise.
286 logger_->LogSocketEventWithRv(
287 channel_id_, proto::DER_ENCODED_CERT_OBTAIN, result ? 1 : 0);
291 bool CastSocket::VerifyChallengeReply() {
292 AuthResult result = AuthenticateChallengeReply(*challenge_reply_, peer_cert_);
293 logger_->LogSocketChallengeReplyEvent(channel_id_, result);
294 return result.success();
297 void CastSocket::Connect(const net::CompletionCallback& callback) {
298 DCHECK(CalledOnValidThread());
299 VLOG_WITH_CONNECTION(1) << "Connect readyState = " << ready_state_;
300 if (ready_state_ != READY_STATE_NONE) {
301 logger_->LogSocketEventWithDetails(
302 channel_id_, proto::CONNECT_FAILED, "ReadyState not NONE");
303 callback.Run(net::ERR_CONNECTION_FAILED);
307 connect_callback_ = callback;
308 SetReadyState(READY_STATE_CONNECTING);
309 SetConnectState(CONN_STATE_TCP_CONNECT);
311 if (connect_timeout_.InMicroseconds() > 0) {
312 DCHECK(connect_timeout_callback_.IsCancelled());
313 connect_timeout_callback_.Reset(
314 base::Bind(&CastSocket::OnConnectTimeout, base::Unretained(this)));
315 GetTimer()->Start(FROM_HERE,
317 connect_timeout_callback_.callback());
319 DoConnectLoop(net::OK);
322 void CastSocket::PostTaskToStartConnectLoop(int result) {
323 DCHECK(CalledOnValidThread());
324 DCHECK(connect_loop_callback_.IsCancelled());
325 connect_loop_callback_.Reset(base::Bind(&CastSocket::DoConnectLoop,
326 base::Unretained(this),
328 base::MessageLoop::current()->PostTask(FROM_HERE,
329 connect_loop_callback_.callback());
// Timer callback armed by Connect() when a positive timeout is configured.
// Abandons the connection attempt and reports net::ERR_TIMED_OUT through
// DoConnectCallback().
332 void CastSocket::OnConnectTimeout() {
333 DCHECK(CalledOnValidThread());
334 // Stop all pending connection setup tasks and report back to the client.
336 logger_->LogSocketEvent(channel_id_, proto::CONNECT_TIMED_OUT);
337 VLOG_WITH_CONNECTION(1) << "Timeout while establishing a connection.";
338 DoConnectCallback(net::ERR_TIMED_OUT);
341 // This method performs the state machine transitions for connection flow.
342 // There are two entry points to this method:
343 // 1. Connect method: this starts the flow
344 // 2. Callback from network operations that finish asynchronously
// |result| is the outcome of the previous step (net::OK when called from
// Connect()).
345 void CastSocket::DoConnectLoop(int result) {
// Consume any re-entry task posted by PostTaskToStartConnectLoop().
346 connect_loop_callback_.Cancel();
// Logged when the connect flow has been cancelled and must stop here.
348 LOG(ERROR) << "CANCELLED - Aborting DoConnectLoop.";
351 // Network operations can either finish synchronously or asynchronously.
352 // This method executes the state machine transitions in a loop so that
353 // correct state transitions happen even when network operations finish
357 ConnectionState state = connect_state_;
358 // Default to CONN_STATE_NONE, which breaks the processing loop if any
359 // handler fails to transition to another state to continue processing.
360 connect_state_ = CONN_STATE_NONE;
362 case CONN_STATE_TCP_CONNECT:
365 case CONN_STATE_TCP_CONNECT_COMPLETE:
366 rv = DoTcpConnectComplete(rv);
368 case CONN_STATE_SSL_CONNECT:
369 DCHECK_EQ(net::OK, rv);
372 case CONN_STATE_SSL_CONNECT_COMPLETE:
373 rv = DoSslConnectComplete(rv);
375 case CONN_STATE_AUTH_CHALLENGE_SEND:
376 rv = DoAuthChallengeSend();
378 case CONN_STATE_AUTH_CHALLENGE_SEND_COMPLETE:
379 rv = DoAuthChallengeSendComplete(rv);
381 case CONN_STATE_AUTH_CHALLENGE_REPLY_COMPLETE:
382 rv = DoAuthChallengeReplyComplete(rv);
385 NOTREACHED() << "BUG in connect flow. Unknown state: " << state;
388 } while (rv != net::ERR_IO_PENDING && connect_state_ != CONN_STATE_NONE);
389 // Get out of the loop either when:
390 // a. A network operation is pending, OR
391 // b. The Do* method called did not change state
393 // No state change occurred in do-while loop above. This means state has
394 // transitioned to NONE.
// The NONE state was assigned directly above rather than via
// SetConnectState(), so it is logged explicitly here.
395 if (connect_state_ == CONN_STATE_NONE) {
396 logger_->LogSocketConnectState(channel_id_,
397 ConnectStateToProto(connect_state_));
400 // Connect loop is finished: if there is no pending IO invoke the callback.
401 if (rv != net::ERR_IO_PENDING) {
403 DoConnectCallback(rv);
407 int CastSocket::DoTcpConnect() {
408 DCHECK(connect_loop_callback_.IsCancelled());
409 VLOG_WITH_CONNECTION(1) << "DoTcpConnect";
410 SetConnectState(CONN_STATE_TCP_CONNECT_COMPLETE);
411 tcp_socket_ = CreateTcpSocket();
413 int rv = tcp_socket_->Connect(
414 base::Bind(&CastSocket::DoConnectLoop, base::Unretained(this)));
415 logger_->LogSocketEventWithRv(channel_id_, proto::TCP_SOCKET_CONNECT, rv);
419 int CastSocket::DoTcpConnectComplete(int result) {
420 VLOG_WITH_CONNECTION(1) << "DoTcpConnectComplete: " << result;
421 if (result == net::OK) {
422 // Enable TCP protocol-level keep-alive.
423 bool result = tcp_socket_->SetKeepAlive(true, kTcpKeepAliveDelaySecs);
424 LOG_IF(WARNING, !result) << "Failed to SetKeepAlive.";
425 logger_->LogSocketEventWithRv(
426 channel_id_, proto::TCP_SOCKET_SET_KEEP_ALIVE, result ? 1 : 0);
427 SetConnectState(CONN_STATE_SSL_CONNECT);
432 int CastSocket::DoSslConnect() {
433 DCHECK(connect_loop_callback_.IsCancelled());
434 VLOG_WITH_CONNECTION(1) << "DoSslConnect";
435 SetConnectState(CONN_STATE_SSL_CONNECT_COMPLETE);
436 socket_ = CreateSslSocket(tcp_socket_.PassAs<net::StreamSocket>());
438 int rv = socket_->Connect(
439 base::Bind(&CastSocket::DoConnectLoop, base::Unretained(this)));
440 logger_->LogSocketEventWithRv(channel_id_, proto::SSL_SOCKET_CONNECT, rv);
444 int CastSocket::DoSslConnectComplete(int result) {
445 VLOG_WITH_CONNECTION(1) << "DoSslConnectComplete: " << result;
446 if (result == net::ERR_CERT_AUTHORITY_INVALID &&
447 peer_cert_.empty() && ExtractPeerCert(&peer_cert_)) {
448 SetConnectState(CONN_STATE_TCP_CONNECT);
449 } else if (result == net::OK &&
450 channel_auth_ == CHANNEL_AUTH_TYPE_SSL_VERIFIED) {
451 SetConnectState(CONN_STATE_AUTH_CHALLENGE_SEND);
456 int CastSocket::DoAuthChallengeSend() {
457 VLOG_WITH_CONNECTION(1) << "DoAuthChallengeSend";
458 SetConnectState(CONN_STATE_AUTH_CHALLENGE_SEND_COMPLETE);
460 CastMessage challenge_message;
461 CreateAuthChallengeMessage(&challenge_message);
462 VLOG_WITH_CONNECTION(1) << "Sending challenge: "
463 << CastMessageToString(challenge_message);
464 // Post a task to send auth challenge so that DoWriteLoop is not nested inside
465 // DoConnectLoop. This is not strictly necessary but keeps the write loop
466 // code decoupled from connect loop code.
467 DCHECK(send_auth_challenge_callback_.IsCancelled());
468 send_auth_challenge_callback_.Reset(
469 base::Bind(&CastSocket::SendCastMessageInternal,
470 base::Unretained(this),
472 base::Bind(&CastSocket::DoAuthChallengeSendWriteComplete,
473 base::Unretained(this))));
474 base::MessageLoop::current()->PostTask(
476 send_auth_challenge_callback_.callback());
477 // Always return IO_PENDING since the result is always asynchronous.
478 return net::ERR_IO_PENDING;
481 void CastSocket::DoAuthChallengeSendWriteComplete(int result) {
482 send_auth_challenge_callback_.Cancel();
483 VLOG_WITH_CONNECTION(2) << "DoAuthChallengeSendWriteComplete: " << result;
484 DCHECK_GT(result, 0);
485 DCHECK_EQ(write_queue_.size(), 1UL);
486 PostTaskToStartConnectLoop(result);
// CONN_STATE_AUTH_CHALLENGE_SEND_COMPLETE. Entered with the challenge write
// result, which DoAuthChallengeSendWriteComplete() forwards through
// PostTaskToStartConnectLoop(). It then waits for the peer's reply by starting
// the read loop. DoReadCallback() later stores the reply and resumes the
// connect loop.
489 int CastSocket::DoAuthChallengeSendComplete(int result) {
490 VLOG_WITH_CONNECTION(1) << "DoAuthChallengeSendComplete: " << result;
493 SetConnectState(CONN_STATE_AUTH_CHALLENGE_REPLY_COMPLETE);
495 // Post a task to start read loop so that DoReadLoop is not nested inside
496 // DoConnectLoop. This is not strictly necessary but keeps the read loop
497 // code decoupled from connect loop code.
498 PostTaskToStartReadLoop();
499 // Always return IO_PENDING since the result is always asynchronous.
500 return net::ERR_IO_PENDING;
// CONN_STATE_AUTH_CHALLENGE_REPLY_COMPLETE. Entered after DoReadCallback() has
// stored the peer's reply in |challenge_reply_|. Fails the connect flow with
// net::ERR_FAILED when VerifyChallengeReply() rejects the reply.
503 int CastSocket::DoAuthChallengeReplyComplete(int result) {
504 VLOG_WITH_CONNECTION(1) << "DoAuthChallengeReplyComplete: " << result;
507 if (!VerifyChallengeReply())
508 return net::ERR_FAILED;
509 VLOG_WITH_CONNECTION(1) << "Auth challenge verification succeeded";
// Finishes the connect flow with |result|. On net::OK the socket becomes
// OPEN, the error state is cleared, a read loop is posted, and
// |connect_callback_| is reset and run with OK. On any other result the
// socket becomes CLOSED, with CHANNEL_ERROR_CONNECT_TIMEOUT for
// net::ERR_TIMED_OUT and CHANNEL_ERROR_CONNECT_ERROR otherwise.
513 void CastSocket::DoConnectCallback(int result) {
514 SetReadyState((result == net::OK) ? READY_STATE_OPEN : READY_STATE_CLOSED);
515 if (result == net::OK) {
516 SetErrorState(CHANNEL_ERROR_NONE);
517 PostTaskToStartReadLoop();
518 VLOG_WITH_CONNECTION(1) << "Calling Connect_Callback";
// ResetAndReturn() clears |connect_callback_| before it runs.
519 base::ResetAndReturn(&connect_callback_).Run(result);
521 } else if (result == net::ERR_TIMED_OUT) {
522 SetErrorState(CHANNEL_ERROR_CONNECT_TIMEOUT);
524 SetErrorState(CHANNEL_ERROR_CONNECT_ERROR);
526 // Calls the connect callback.
// Client-requested close. Fails any pending connect or write callbacks via
// RunPendingCallbacksOnClose(), then reports net::OK to |callback|.
// NOTE(review): the socket is presumably moved to CLOSED first, e.g. by
// CloseInternal(), because RunPendingCallbacksOnClose() DCHECKs
// READY_STATE_CLOSED — confirm.
530 void CastSocket::Close(const net::CompletionCallback& callback) {
532 RunPendingCallbacksOnClose();
533 // Run this callback last. It may delete the socket.
534 callback.Run(net::OK);
// Shared teardown used by the close paths. Releases the certificate verifier
// and transport security state, cancels every task and timeout callback this
// object posted for itself, moves to READY_STATE_CLOSED, and logs
// SOCKET_CLOSED. An already-CLOSED socket is handled by the first check
// (presumably an early return).
537 void CastSocket::CloseInternal() {
538 // TODO(mfoltz): Enforce this when CastChannelAPITest is rewritten to create
539 // and free sockets on the same thread. crbug.com/398242
540 // DCHECK(CalledOnValidThread());
541 if (ready_state_ == READY_STATE_CLOSED) {
545 VLOG_WITH_CONNECTION(1) << "Close ReadyState = " << ready_state_;
// CreateSslSocket() allocates fresh instances for each SSL attempt.
548 cert_verifier_.reset();
549 transport_security_state_.reset();
552 // Cancel callbacks that we queued ourselves to re-enter the connect or read
554 connect_loop_callback_.Cancel();
555 send_auth_challenge_callback_.Cancel();
556 read_loop_callback_.Cancel();
557 connect_timeout_callback_.Cancel();
558 SetReadyState(READY_STATE_CLOSED);
559 logger_->LogSocketEvent(channel_id_, proto::SOCKET_CLOSED);
562 void CastSocket::RunPendingCallbacksOnClose() {
563 DCHECK_EQ(ready_state_, READY_STATE_CLOSED);
564 if (!connect_callback_.is_null()) {
565 connect_callback_.Run(net::ERR_CONNECTION_FAILED);
566 connect_callback_.Reset();
568 for (; !write_queue_.empty(); write_queue_.pop()) {
569 net::CompletionCallback& callback = write_queue_.front().callback;
570 callback.Run(net::ERR_FAILED);
575 void CastSocket::SendMessage(const MessageInfo& message,
576 const net::CompletionCallback& callback) {
577 DCHECK(CalledOnValidThread());
578 if (ready_state_ != READY_STATE_OPEN) {
579 logger_->LogSocketEventForMessage(channel_id_,
580 proto::SEND_MESSAGE_FAILED,
582 "Ready state not OPEN");
583 callback.Run(net::ERR_FAILED);
586 CastMessage message_proto;
587 if (!MessageInfoToCastMessage(message, &message_proto)) {
588 logger_->LogSocketEventForMessage(channel_id_,
589 proto::SEND_MESSAGE_FAILED,
591 "Failed to convert to CastMessage");
592 callback.Run(net::ERR_FAILED);
595 SendCastMessageInternal(message_proto, callback);
598 void CastSocket::SendCastMessageInternal(
599 const CastMessage& message,
600 const net::CompletionCallback& callback) {
601 WriteRequest write_request(callback);
602 if (!write_request.SetContent(message)) {
603 logger_->LogSocketEventForMessage(channel_id_,
604 proto::SEND_MESSAGE_FAILED,
605 message.namespace_(),
606 "SetContent failed");
607 callback.Run(net::ERR_FAILED);
611 write_queue_.push(write_request);
612 logger_->LogSocketEventForMessage(
614 proto::MESSAGE_ENQUEUED,
615 message.namespace_(),
616 base::StringPrintf("Queue size: %" PRIuS, write_queue_.size()));
617 if (write_state_ == WRITE_STATE_NONE) {
618 SetWriteState(WRITE_STATE_WRITE);
619 DoWriteLoop(net::OK);
// Drives the write state machine over |write_queue_|. It is entered from
// SendCastMessageInternal() with net::OK, and as the completion callback of
// socket_->Write() in DoWrite(). |result| feeds the first state handler.
623 void CastSocket::DoWriteLoop(int result) {
624 DCHECK(CalledOnValidThread());
625 VLOG_WITH_CONNECTION(1) << "DoWriteLoop queue size: " << write_queue_.size();
627 if (write_queue_.empty()) {
628 SetWriteState(WRITE_STATE_NONE);
632 // Network operations can either finish synchronously or asynchronously.
633 // This method executes the state machine transitions in a loop so that
634 // write state transitions happen even when network operations finish
638 WriteState state = write_state_;
639 write_state_ = WRITE_STATE_NONE;
641 case WRITE_STATE_WRITE:
644 case WRITE_STATE_WRITE_COMPLETE:
645 rv = DoWriteComplete(rv);
647 case WRITE_STATE_DO_CALLBACK:
648 rv = DoWriteCallback();
650 case WRITE_STATE_ERROR:
651 rv = DoWriteError(rv);
654 NOTREACHED() << "BUG in write flow. Unknown state: " << state;
657 } while (!write_queue_.empty() &&
658 rv != net::ERR_IO_PENDING &&
659 write_state_ != WRITE_STATE_NONE);
661 // No state change occurred in do-while loop above. This means state has
662 // transitioned to NONE.
// NONE was assigned directly above rather than via SetWriteState(), so it is
// logged explicitly here.
663 if (write_state_ == WRITE_STATE_NONE) {
664 logger_->LogSocketWriteState(channel_id_, WriteStateToProto(write_state_));
667 // If write loop is done because the queue is empty then set write
669 if (write_queue_.empty())
670 SetWriteState(WRITE_STATE_NONE);
672 // Write loop is done - if the result is ERR_FAILED then close with error.
673 if (rv == net::ERR_FAILED)
677 int CastSocket::DoWrite() {
678 DCHECK(!write_queue_.empty());
679 WriteRequest& request = write_queue_.front();
681 VLOG_WITH_CONNECTION(2) << "WriteData byte_count = "
682 << request.io_buffer->size() << " bytes_written "
683 << request.io_buffer->BytesConsumed();
685 SetWriteState(WRITE_STATE_WRITE_COMPLETE);
687 int rv = socket_->Write(
688 request.io_buffer.get(),
689 request.io_buffer->BytesRemaining(),
690 base::Bind(&CastSocket::DoWriteLoop, base::Unretained(this)));
691 logger_->LogSocketEventWithRv(channel_id_, proto::SOCKET_WRITE, rv);
696 int CastSocket::DoWriteComplete(int result) {
697 DCHECK(!write_queue_.empty());
698 if (result <= 0) { // NOTE that 0 also indicates an error
699 SetErrorState(CHANNEL_ERROR_SOCKET_ERROR);
700 SetWriteState(WRITE_STATE_ERROR);
701 return result == 0 ? net::ERR_FAILED : result;
704 // Some bytes were successfully written
705 WriteRequest& request = write_queue_.front();
706 scoped_refptr<net::DrainableIOBuffer> io_buffer = request.io_buffer;
707 io_buffer->DidConsume(result);
708 if (io_buffer->BytesRemaining() == 0) // Message fully sent
709 SetWriteState(WRITE_STATE_DO_CALLBACK);
711 SetWriteState(WRITE_STATE_WRITE);
// WRITE_STATE_DO_CALLBACK: the front request has been fully written. Sets the
// next state to WRITE_STATE_WRITE for any remaining requests, logs
// MESSAGE_WRITTEN, and runs the request's callback with the number of bytes
// written.
// NOTE(review): the callback runs while the request is still at the front of
// |write_queue_|. It is presumably popped afterwards — confirm.
716 int CastSocket::DoWriteCallback() {
717 DCHECK(!write_queue_.empty());
719 SetWriteState(WRITE_STATE_WRITE);
721 WriteRequest& request = write_queue_.front();
722 int bytes_consumed = request.io_buffer->BytesConsumed();
723 logger_->LogSocketEventForMessage(
725 proto::MESSAGE_WRITTEN,
726 request.message_namespace,
727 base::StringPrintf("Bytes: %d", bytes_consumed));
728 request.callback.Run(bytes_consumed);
// WRITE_STATE_ERROR. |result| is the negative error from DoWriteComplete().
// During the connect handshake (READY_STATE_CONNECTING), the error is handed
// to the connect loop. Otherwise every queued request's callback is run with
// |result|, and net::ERR_FAILED is returned so that DoWriteLoop() closes the
// socket with an error.
733 int CastSocket::DoWriteError(int result) {
734 DCHECK(!write_queue_.empty());
735 DCHECK_LT(result, 0);
737 // If inside connection flow, then there should be exactly one item in
739 if (ready_state_ == READY_STATE_CONNECTING) {
741 DCHECK(write_queue_.empty());
742 PostTaskToStartConnectLoop(result);
743 // Connect loop will handle the error. Return net::OK so that write flow
744 // does not try to report error also.
748 while (!write_queue_.empty()) {
749 WriteRequest& request = write_queue_.front();
750 request.callback.Run(result);
753 return net::ERR_FAILED;
756 void CastSocket::PostTaskToStartReadLoop() {
757 DCHECK(CalledOnValidThread());
758 DCHECK(read_loop_callback_.IsCancelled());
759 read_loop_callback_.Reset(base::Bind(&CastSocket::StartReadLoop,
760 base::Unretained(this)));
761 base::MessageLoop::current()->PostTask(FROM_HERE,
762 read_loop_callback_.callback());
// Task body posted by PostTaskToStartReadLoop(). Clears the pending task and,
// when no read loop is active (READ_STATE_NONE), moves to READ_STATE_READ to
// begin reading.
765 void CastSocket::StartReadLoop() {
766 read_loop_callback_.Cancel();
767 // Read loop would have already been started if read state is not NONE
768 if (read_state_ == READ_STATE_NONE) {
769 SetReadState(READ_STATE_READ);
// Drives the read state machine. It is also the completion callback that
// DoRead() passes to socket_->Read(), so |result| may be a read result.
// A net::ERR_FAILED outcome is routed to the connect loop during the
// handshake, and otherwise treated as a failure of the established
// connection.
774 void CastSocket::DoReadLoop(int result) {
775 DCHECK(CalledOnValidThread());
776 // Network operations can either finish synchronously or asynchronously.
777 // This method executes the state machine transitions in a loop so that
778 // read state transitions happen even when network operations finish
782 ReadState state = read_state_;
783 read_state_ = READ_STATE_NONE;
786 case READ_STATE_READ:
789 case READ_STATE_READ_COMPLETE:
790 rv = DoReadComplete(rv);
792 case READ_STATE_DO_CALLBACK:
793 rv = DoReadCallback();
795 case READ_STATE_ERROR:
796 rv = DoReadError(rv);
797 DCHECK_EQ(read_state_, READ_STATE_NONE);
800 NOTREACHED() << "BUG in read flow. Unknown state: " << state;
803 } while (rv != net::ERR_IO_PENDING && read_state_ != READ_STATE_NONE);
805 // No state change occurred in do-while loop above. This means state has
806 // transitioned to NONE.
// NONE was assigned directly above rather than via SetReadState(), so it is
// logged explicitly here.
807 if (read_state_ == READ_STATE_NONE) {
808 logger_->LogSocketReadState(channel_id_, ReadStateToProto(read_state_));
811 if (rv == net::ERR_FAILED) {
812 if (ready_state_ == READY_STATE_CONNECTING) {
813 // Read errors during the handshake should notify the caller via the
814 // connect callback. This will also send error status via the OnError
816 PostTaskToStartConnectLoop(net::ERR_FAILED);
818 // Connection is already established. Close and send error status via the
// READ_STATE_READ: issues one socket read. It fills either the rest of the
// fixed-size header or the rest of the current body, whose size
// ProcessHeader() stored in |current_message_size_|. The next state is
// READ_STATE_READ_COMPLETE, and asynchronous completion re-enters
// DoReadLoop().
// NOTE(review): a header announcing a zero-length body would leave
// num_bytes_to_read == 0 here and fail the CHECK_GT below, unless
// ProcessHeader() rejects such headers.
825 int CastSocket::DoRead() {
826 SetReadState(READ_STATE_READ_COMPLETE);
827 // Figure out whether to read header or body, and the remaining bytes.
828 uint32 num_bytes_to_read = 0;
829 if (header_read_buffer_->RemainingCapacity() > 0) {
830 current_read_buffer_ = header_read_buffer_;
831 num_bytes_to_read = header_read_buffer_->RemainingCapacity();
832 CHECK_LE(num_bytes_to_read, MessageHeader::header_size());
834 DCHECK_GT(current_message_size_, 0U);
835 num_bytes_to_read = current_message_size_ - body_read_buffer_->offset();
836 current_read_buffer_ = body_read_buffer_;
837 CHECK_LE(num_bytes_to_read, MessageHeader::max_message_size());
839 CHECK_GT(num_bytes_to_read, 0U);
841 // Read up to num_bytes_to_read into |current_read_buffer_|.
842 int rv = socket_->Read(
843 current_read_buffer_.get(),
845 base::Bind(&CastSocket::DoReadLoop, base::Unretained(this)));
846 logger_->LogSocketEventWithRv(channel_id_, proto::SOCKET_READ, rv);
// READ_STATE_READ_COMPLETE. Accounts for the |result| bytes just read into
// |current_read_buffer_| and picks the next read state:
//  - READ_STATE_ERROR on EOF or a read error, or on an invalid header/body;
//  - READ_STATE_DO_CALLBACK once a complete message has been parsed;
//  - READ_STATE_READ while more header or body bytes are needed.
856 if (result <= 0) { // 0 means EOF: the peer closed the socket
851 int CastSocket::DoReadComplete(int result) {
852 VLOG_WITH_CONNECTION(2) << "DoReadComplete result = " << result
853 << " header offset = "
854 << header_read_buffer_->offset()
855 << " body offset = " << body_read_buffer_->offset();
// NOTE(review): the log message below is also emitted for negative (error)
// results, not only for EOF.
857 VLOG_WITH_CONNECTION(1) << "Read error, peer closed the socket";
858 SetErrorState(CHANNEL_ERROR_SOCKET_ERROR);
859 SetReadState(READ_STATE_ERROR);
860 return result == 0 ? net::ERR_FAILED : result;
863 // Some data was read. Move the offset in the current buffer forward.
864 CHECK_LE(current_read_buffer_->offset() + result,
865 current_read_buffer_->capacity());
866 current_read_buffer_->set_offset(current_read_buffer_->offset() + result);
868 if (current_read_buffer_.get() == header_read_buffer_.get() &&
869 current_read_buffer_->RemainingCapacity() == 0) {
870 // A full header is read, process the contents.
871 if (!ProcessHeader()) {
872 SetErrorState(CHANNEL_ERROR_INVALID_MESSAGE);
873 SetReadState(READ_STATE_ERROR);
875 // Processed header, now read the body.
876 SetReadState(READ_STATE_READ);
878 } else if (current_read_buffer_.get() == body_read_buffer_.get() &&
879 static_cast<uint32>(current_read_buffer_->offset()) ==
880 current_message_size_) {
881 // Store a copy of current_message_size_ since it will be reset by
883 uint32 message_size = current_message_size_;
884 // Full body is read, process the contents.
886 logger_->LogSocketEventForMessage(
889 current_message_->namespace_(),
890 base::StringPrintf("Message size: %u", message_size));
891 SetReadState(READ_STATE_DO_CALLBACK);
893 SetErrorState(CHANNEL_ERROR_INVALID_MESSAGE);
894 SetReadState(READ_STATE_ERROR);
897 // Have not received full header or full body yet; keep reading.
898 SetReadState(READ_STATE_READ);
// READ_STATE_DO_CALLBACK: dispatches the message parsed into
// |current_message_|; the default next state is READ_STATE_READ.
// While CONNECTING, an auth message is copied into |challenge_reply_| and the
// connect loop is resumed. As far as the visible branches show, any other
// message fails the read with CHANNEL_ERROR_INVALID_MESSAGE and
// net::ERR_INVALID_RESPONSE.
// Once connected, the message is converted to MessageInfo and delivered to
// |delegate_|->OnMessage(). A conversion failure is treated as an invalid
// message.
904 int CastSocket::DoReadCallback() {
905 SetReadState(READ_STATE_READ);
906 const CastMessage& message = *current_message_;
907 if (ready_state_ == READY_STATE_CONNECTING) {
908 if (IsAuthMessage(message)) {
909 challenge_reply_.reset(new CastMessage(message));
910 logger_->LogSocketEvent(channel_id_, proto::RECEIVED_CHALLENGE_REPLY);
911 PostTaskToStartConnectLoop(net::OK);
914 SetReadState(READ_STATE_ERROR);
915 SetErrorState(CHANNEL_ERROR_INVALID_MESSAGE);
916 return net::ERR_INVALID_RESPONSE;
920 MessageInfo message_info;
921 if (!CastMessageToMessageInfo(message, &message_info)) {
922 current_message_->Clear();
923 SetReadState(READ_STATE_ERROR);
924 SetErrorState(CHANNEL_ERROR_INVALID_MESSAGE);
925 return net::ERR_INVALID_RESPONSE;
928 logger_->LogSocketEventForMessage(channel_id_,
929 proto::NOTIFY_ON_MESSAGE,
930 message.namespace_(),
932 delegate_->OnMessage(this, message_info);
// Clear the parsed message so the buffer is ready for the next one.
933 current_message_->Clear();
938 int CastSocket::DoReadError(int result) {
939 DCHECK_LE(result, 0);
940 return net::ERR_FAILED;
943 bool CastSocket::ProcessHeader() {
944 CHECK_EQ(static_cast<uint32>(header_read_buffer_->offset()),
945 MessageHeader::header_size());
946 MessageHeader header;
947 MessageHeader::ReadFromIOBuffer(header_read_buffer_.get(), &header);
948 if (header.message_size > MessageHeader::max_message_size())
951 VLOG_WITH_CONNECTION(2) << "Parsed header { message_size: "
952 << header.message_size << " }";
953 current_message_size_ = header.message_size;
957 bool CastSocket::ProcessBody() {
958 CHECK_EQ(static_cast<uint32>(body_read_buffer_->offset()),
959 current_message_size_);
960 if (!current_message_->ParseFromArray(
961 body_read_buffer_->StartOfBuffer(), current_message_size_)) {
964 current_message_size_ = 0;
965 header_read_buffer_->set_offset(0);
966 body_read_buffer_->set_offset(0);
967 current_read_buffer_ = header_read_buffer_;
972 bool CastSocket::Serialize(const CastMessage& message_proto,
973 std::string* message_data) {
974 DCHECK(message_data);
975 message_proto.SerializeToString(message_data);
976 size_t message_size = message_data->size();
977 if (message_size > MessageHeader::max_message_size()) {
978 message_data->clear();
981 CastSocket::MessageHeader header;
982 header.SetMessageSize(message_size);
983 header.PrependToString(message_data);
// Closes the socket after a failure and fails pending callbacks via
// RunPendingCallbacksOnClose(). It then reports |error_state_| and the
// logger's recent errors for this channel to |delegate_|->OnError().
// NOTE(review): the socket is presumably moved to CLOSED before
// RunPendingCallbacksOnClose(), since that function DCHECKs
// READY_STATE_CLOSED — confirm.
987 void CastSocket::CloseWithError() {
988 DCHECK(CalledOnValidThread());
990 RunPendingCallbacksOnClose();
992 logger_->LogSocketEvent(channel_id_, proto::NOTIFY_ON_ERROR);
993 delegate_->OnError(this, error_state_, logger_->GetLastErrors(channel_id_));
997 std::string CastSocket::CastUrl() const {
998 return ((channel_auth_ == CHANNEL_AUTH_TYPE_SSL_VERIFIED) ?
999 "casts://" : "cast://") + ip_endpoint_.ToString();
1002 bool CastSocket::CalledOnValidThread() const {
1003 return thread_checker_.CalledOnValidThread();
1006 base::Timer* CastSocket::GetTimer() {
1007 return connect_timeout_timer_.get();
1010 void CastSocket::SetConnectState(ConnectionState connect_state) {
1011 if (connect_state_ != connect_state) {
1012 connect_state_ = connect_state;
1013 logger_->LogSocketConnectState(channel_id_,
1014 ConnectStateToProto(connect_state_));
1018 void CastSocket::SetReadyState(ReadyState ready_state) {
1019 if (ready_state_ != ready_state) {
1020 ready_state_ = ready_state;
1021 logger_->LogSocketReadyState(channel_id_, ReadyStateToProto(ready_state_));
1025 void CastSocket::SetErrorState(ChannelError error_state) {
1026 if (error_state_ != error_state) {
1027 error_state_ = error_state;
1028 logger_->LogSocketErrorState(channel_id_, ErrorStateToProto(error_state_));
1032 void CastSocket::SetReadState(ReadState read_state) {
1033 if (read_state_ != read_state) {
1034 read_state_ = read_state;
1035 logger_->LogSocketReadState(channel_id_, ReadStateToProto(read_state_));
1039 void CastSocket::SetWriteState(WriteState write_state) {
1040 if (write_state_ != write_state) {
1041 write_state_ = write_state;
1042 logger_->LogSocketWriteState(channel_id_, WriteStateToProto(write_state_));
1046 CastSocket::MessageHeader::MessageHeader() : message_size(0) { }
1048 void CastSocket::MessageHeader::SetMessageSize(size_t size) {
1049 DCHECK_LT(size, static_cast<size_t>(kuint32max));
1050 DCHECK_GT(size, 0U);
1051 message_size = size;
1054 // TODO(mfoltz): Investigate replacing header serialization with base::Pickle,
1055 // if bit-for-bit compatible.
1056 void CastSocket::MessageHeader::PrependToString(std::string* str) {
1057 MessageHeader output = *this;
1058 output.message_size = base::HostToNet32(message_size);
1059 size_t header_size = base::checked_cast<size_t, uint32>(
1060 MessageHeader::header_size());
1061 scoped_ptr<char, base::FreeDeleter> char_array(
1062 static_cast<char*>(malloc(header_size)));
1063 memcpy(char_array.get(), &output, header_size);
1064 str->insert(0, char_array.get(), header_size);
1067 // TODO(mfoltz): Investigate replacing header deserialization with base::Pickle,
1068 // if bit-for-bit compatible.
1069 void CastSocket::MessageHeader::ReadFromIOBuffer(
1070 net::GrowableIOBuffer* buffer, MessageHeader* header) {
1071 uint32 message_size;
1072 size_t header_size = base::checked_cast<size_t, uint32>(
1073 MessageHeader::header_size());
1074 memcpy(&message_size, buffer->StartOfBuffer(), header_size);
1075 header->message_size = base::NetToHost32(message_size);
1078 std::string CastSocket::MessageHeader::ToString() {
1079 return "{message_size: " + base::UintToString(message_size) + "}";
1082 CastSocket::WriteRequest::WriteRequest(const net::CompletionCallback& callback)
1083 : callback(callback) { }
1085 bool CastSocket::WriteRequest::SetContent(const CastMessage& message_proto) {
1086 DCHECK(!io_buffer.get());
1087 std::string message_data;
1088 if (!Serialize(message_proto, &message_data))
1090 message_namespace = message_proto.namespace_();
1091 io_buffer = new net::DrainableIOBuffer(new net::StringIOBuffer(message_data),
1092 message_data.size());
1096 CastSocket::WriteRequest::~WriteRequest() { }
1098 } // namespace cast_channel
1099 } // namespace core_api
1100 } // namespace extensions
1102 #undef VLOG_WITH_CONNECTION