1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "chrome/browser/extensions/api/cast_channel/cast_socket.h"
10 #include "base/callback_helpers.h"
11 #include "base/lazy_instance.h"
12 #include "base/strings/string_number_conversions.h"
13 #include "base/sys_byteorder.h"
14 #include "chrome/browser/extensions/api/cast_channel/cast_auth_util.h"
15 #include "chrome/browser/extensions/api/cast_channel/cast_channel.pb.h"
16 #include "chrome/browser/extensions/api/cast_channel/cast_message_util.h"
17 #include "net/base/address_list.h"
18 #include "net/base/host_port_pair.h"
19 #include "net/base/net_errors.h"
20 #include "net/base/net_util.h"
21 #include "net/cert/cert_verifier.h"
22 #include "net/cert/x509_certificate.h"
23 #include "net/http/transport_security_state.h"
24 #include "net/socket/client_socket_factory.h"
25 #include "net/socket/client_socket_handle.h"
26 #include "net/socket/ssl_client_socket.h"
27 #include "net/socket/stream_socket.h"
28 #include "net/socket/tcp_client_socket.h"
29 #include "net/ssl/ssl_config_service.h"
30 #include "net/ssl/ssl_info.h"
34 // Allowed schemes for Cast device URLs.
35 const char kCastInsecureScheme[] = "cast";
36 const char kCastSecureScheme[] = "casts";
38 // Size of the message header, in bytes. Don't use sizeof(MessageHeader)
39 // because of alignment; instead, sum the sizeof() for the fields.
40 const uint32 kMessageHeaderSize = sizeof(uint32);
42 // The default keepalive delay. On Linux, keepalives probes will be sent after
43 // the socket is idle for this length of time, and the socket will be closed
44 // after 9 failed probes. So the total idle time before close is 10 *
45 // kTcpKeepAliveDelaySecs.
46 const int kTcpKeepAliveDelaySecs = 10;
50 namespace extensions {
52 static base::LazyInstance<
53 ProfileKeyedAPIFactory<ApiResourceManager<api::cast_channel::CastSocket> > >
54 g_factory = LAZY_INSTANCE_INITIALIZER;
58 ProfileKeyedAPIFactory<ApiResourceManager<api::cast_channel::CastSocket> >*
59 ApiResourceManager<api::cast_channel::CastSocket>::GetFactoryInstance() {
60 return &g_factory.Get();
64 namespace cast_channel {
66 const uint32 kMaxMessageSize = 65536;
68 CastSocket::CastSocket(const std::string& owner_extension_id,
70 CastSocket::Delegate* delegate,
71 net::NetLog* net_log) :
72 ApiResource(owner_extension_id),
76 auth_required_(false),
77 error_state_(CHANNEL_ERROR_NONE),
78 ready_state_(READY_STATE_NONE),
79 write_callback_pending_(false),
80 read_callback_pending_(false),
81 current_message_size_(0),
83 next_state_(CONN_STATE_NONE) {
85 net_log_source_.type = net::NetLog::SOURCE_SOCKET;
86 net_log_source_.id = net_log_->NextID();
88 // We reuse these buffers for each message.
89 header_read_buffer_ = new net::GrowableIOBuffer();
90 header_read_buffer_->SetCapacity(kMessageHeaderSize);
91 body_read_buffer_ = new net::GrowableIOBuffer();
92 body_read_buffer_->SetCapacity(kMaxMessageSize);
93 current_read_buffer_ = header_read_buffer_;
96 CastSocket::~CastSocket() { }
98 const GURL& CastSocket::url() const {
102 scoped_ptr<net::TCPClientSocket> CastSocket::CreateTcpSocket() {
103 net::AddressList addresses(ip_endpoint_);
104 scoped_ptr<net::TCPClientSocket> tcp_socket(
105 new net::TCPClientSocket(addresses, net_log_, net_log_source_));
107 tcp_socket->SetKeepAlive(true, kTcpKeepAliveDelaySecs);
108 return tcp_socket.Pass();
111 scoped_ptr<net::SSLClientSocket> CastSocket::CreateSslSocket() {
112 net::SSLConfig ssl_config;
113 // If a peer cert was extracted in a previous attempt to connect, then
114 // whitelist that cert.
115 if (!peer_cert_.empty()) {
116 net::SSLConfig::CertAndStatus cert_and_status;
117 cert_and_status.cert_status = net::CERT_STATUS_AUTHORITY_INVALID;
118 cert_and_status.der_cert = peer_cert_;
119 ssl_config.allowed_bad_certs.push_back(cert_and_status);
122 cert_verifier_.reset(net::CertVerifier::CreateDefault());
123 transport_security_state_.reset(new net::TransportSecurityState);
124 net::SSLClientSocketContext context;
125 // CertVerifier and TransportSecurityState are owned by us, not the
127 context.cert_verifier = cert_verifier_.get();
128 context.transport_security_state = transport_security_state_.get();
130 scoped_ptr<net::ClientSocketHandle> connection(new net::ClientSocketHandle);
131 connection->SetSocket(tcp_socket_.PassAs<net::StreamSocket>());
132 net::HostPortPair host_and_port = net::HostPortPair::FromIPEndPoint(
135 return net::ClientSocketFactory::GetDefaultFactory()->CreateSSLClientSocket(
136 connection.Pass(), host_and_port, ssl_config, context);
139 bool CastSocket::ExtractPeerCert(std::string* cert) {
141 DCHECK(peer_cert_.empty());
142 net::SSLInfo ssl_info;
143 if (!socket_->GetSSLInfo(&ssl_info) || !ssl_info.cert.get())
145 bool result = net::X509Certificate::GetDEREncoded(
146 ssl_info.cert->os_cert_handle(), cert);
148 DVLOG(1) << "Successfully extracted peer certificate: " << *cert;
152 int CastSocket::SendAuthChallenge() {
153 CastMessage challenge_message;
154 CreateAuthChallengeMessage(&challenge_message);
155 DVLOG(1) << "Sending challenge: " << CastMessageToString(challenge_message);
156 return SendMessageInternal(
158 base::Bind(&CastSocket::OnChallengeEvent, AsWeakPtr()));
161 int CastSocket::ReadAuthChallengeReply() {
165 void CastSocket::OnConnectComplete(int result) {
166 int rv = DoConnectLoop(result);
167 if (rv != net::ERR_IO_PENDING)
168 DoConnectCallback(rv);
171 void CastSocket::OnChallengeEvent(int result) {
172 // result >= 0 means read or write succeeded synchronously.
173 int rv = DoConnectLoop(result >= 0 ? net::OK : result);
174 if (rv != net::ERR_IO_PENDING)
175 DoConnectCallback(rv);
178 void CastSocket::Connect(const net::CompletionCallback& callback) {
179 DCHECK(CalledOnValidThread());
180 int result = net::ERR_CONNECTION_FAILED;
181 DVLOG(1) << "Connect readyState = " << ready_state_;
182 if (ready_state_ != READY_STATE_NONE) {
183 callback.Run(result);
186 if (!ParseChannelUrl(url_)) {
187 CloseWithError(cast_channel::CHANNEL_ERROR_CONNECT_ERROR);
188 callback.Run(result);
191 connect_callback_ = callback;
192 next_state_ = CONN_STATE_TCP_CONNECT;
193 int rv = DoConnectLoop(net::OK);
194 if (rv != net::ERR_IO_PENDING)
195 DoConnectCallback(rv);
198 // This method performs the state machine transitions for connection flow.
199 // There are two entry points to this method:
200 // 1. public Connect method: this starts the flow
201 // 2. OnConnectComplete: callback method called when an async operation
202 // is done. OnConnectComplete calls this method to continue the state
203 // machine transitions.
204 int CastSocket::DoConnectLoop(int result) {
205 // Network operations can either finish sycnronously or asynchronously.
206 // This method executes the state machine transitions in a loop so that
207 // correct state transitions happen even when network operations finish
211 ConnectionState state = next_state_;
212 // All the Do* methods do not set next_state_ in case of an
213 // error. So set next_state_ to NONE to figure out if the Do*
214 // method changed state or not.
215 next_state_ = CONN_STATE_NONE;
217 case CONN_STATE_TCP_CONNECT:
220 case CONN_STATE_TCP_CONNECT_COMPLETE:
221 rv = DoTcpConnectComplete(rv);
223 case CONN_STATE_SSL_CONNECT:
224 DCHECK_EQ(net::OK, rv);
227 case CONN_STATE_SSL_CONNECT_COMPLETE:
228 rv = DoSslConnectComplete(rv);
230 case CONN_STATE_AUTH_CHALLENGE_SEND:
231 rv = DoAuthChallengeSend();
233 case CONN_STATE_AUTH_CHALLENGE_SEND_COMPLETE:
234 rv = DoAuthChallengeSendComplete(rv);
236 case CONN_STATE_AUTH_CHALLENGE_REPLY_COMPLETE:
237 rv = DoAuthChallengeReplyComplete(rv);
241 NOTREACHED() << "BUG in CastSocket state machine code";
244 } while (rv != net::ERR_IO_PENDING && next_state_ != CONN_STATE_NONE);
245 // Get out of the loop either when:
246 // a. A network operation is pending, OR
247 // b. The Do* method called did not change state
252 int CastSocket::DoTcpConnect() {
253 DVLOG(1) << "DoTcpConnect";
254 next_state_ = CONN_STATE_TCP_CONNECT_COMPLETE;
255 tcp_socket_ = CreateTcpSocket();
256 return tcp_socket_->Connect(
257 base::Bind(&CastSocket::OnConnectComplete, AsWeakPtr()));
260 int CastSocket::DoTcpConnectComplete(int result) {
261 DVLOG(1) << "DoTcpConnectComplete: " << result;
262 if (result == net::OK)
263 next_state_ = CONN_STATE_SSL_CONNECT;
267 int CastSocket::DoSslConnect() {
268 DVLOG(1) << "DoSslConnect";
269 next_state_ = CONN_STATE_SSL_CONNECT_COMPLETE;
270 socket_ = CreateSslSocket();
271 return socket_->Connect(
272 base::Bind(&CastSocket::OnConnectComplete, AsWeakPtr()));
275 int CastSocket::DoSslConnectComplete(int result) {
276 DVLOG(1) << "DoSslConnectComplete: " << result;
277 if (result == net::ERR_CERT_AUTHORITY_INVALID &&
278 peer_cert_.empty() &&
279 ExtractPeerCert(&peer_cert_)) {
280 next_state_ = CONN_STATE_TCP_CONNECT;
281 } else if (result == net::OK && auth_required_) {
282 next_state_ = CONN_STATE_AUTH_CHALLENGE_SEND;
287 int CastSocket::DoAuthChallengeSend() {
288 DVLOG(1) << "DoAuthChallengeSend";
289 next_state_ = CONN_STATE_AUTH_CHALLENGE_SEND_COMPLETE;
290 return SendAuthChallenge();
293 int CastSocket::DoAuthChallengeSendComplete(int result) {
294 DVLOG(1) << "DoAuthChallengeSendComplete: " << result;
295 if (result != net::OK)
297 next_state_ = CONN_STATE_AUTH_CHALLENGE_REPLY_COMPLETE;
298 return ReadAuthChallengeReply();
301 int CastSocket::DoAuthChallengeReplyComplete(int result) {
302 DVLOG(1) << "DoAuthChallengeReplyComplete: " << result;
303 if (result != net::OK)
305 if (!VerifyChallengeReply())
306 return net::ERR_FAILED;
307 DVLOG(1) << "Auth challenge verification succeeded";
311 bool CastSocket::VerifyChallengeReply() {
312 return AuthenticateChallengeReply(*challenge_reply_.get(), peer_cert_);
315 void CastSocket::DoConnectCallback(int result) {
316 ready_state_ = (result == net::OK) ? READY_STATE_OPEN : READY_STATE_CLOSED;
317 error_state_ = (result == net::OK) ?
318 CHANNEL_ERROR_NONE : CHANNEL_ERROR_CONNECT_ERROR;
319 base::ResetAndReturn(&connect_callback_).Run(result);
320 // Start the ReadData loop if not already started.
321 // If auth_required_ is true we would've started a ReadData loop already.
322 // TODO(munjal): This is a bit ugly. Refactor read and write code.
323 if (result == net::OK && !auth_required_)
327 void CastSocket::Close(const net::CompletionCallback& callback) {
328 DCHECK(CalledOnValidThread());
329 DVLOG(1) << "Close ReadyState = " << ready_state_;
330 tcp_socket_.reset(NULL);
332 cert_verifier_.reset(NULL);
333 transport_security_state_.reset(NULL);
334 ready_state_ = READY_STATE_CLOSED;
335 callback.Run(net::OK);
338 void CastSocket::SendMessage(const MessageInfo& message,
339 const net::CompletionCallback& callback) {
340 DCHECK(CalledOnValidThread());
341 DVLOG(1) << "Send ReadyState " << ready_state_;
342 int result = net::ERR_FAILED;
343 if (ready_state_ != READY_STATE_OPEN) {
344 callback.Run(result);
347 CastMessage message_proto;
348 if (!MessageInfoToCastMessage(message, &message_proto)) {
349 CloseWithError(cast_channel::CHANNEL_ERROR_INVALID_MESSAGE);
350 // TODO(mfoltz): Do a better job of signaling cast_channel errors to the
352 callback.Run(net::OK);
355 SendMessageInternal(message_proto, callback);
358 int CastSocket::SendMessageInternal(const CastMessage& message_proto,
359 const net::CompletionCallback& callback) {
360 WriteRequest write_request(callback);
361 if (!write_request.SetContent(message_proto))
362 return net::ERR_FAILED;
363 write_queue_.push(write_request);
367 int CastSocket::WriteData() {
368 DCHECK(CalledOnValidThread());
369 DVLOG(1) << "WriteData q = " << write_queue_.size();
370 if (write_queue_.empty() || write_callback_pending_)
371 return net::ERR_FAILED;
373 WriteRequest& request = write_queue_.front();
375 DVLOG(1) << "WriteData byte_count = " << request.io_buffer->size() <<
376 " bytes_written " << request.io_buffer->BytesConsumed();
378 write_callback_pending_ = true;
379 int result = socket_->Write(
380 request.io_buffer.get(),
381 request.io_buffer->BytesRemaining(),
382 base::Bind(&CastSocket::OnWriteData, AsWeakPtr()));
384 if (result != net::ERR_IO_PENDING)
390 void CastSocket::OnWriteData(int result) {
391 DCHECK(CalledOnValidThread());
392 DVLOG(1) << "OnWriteComplete result = " << result;
393 DCHECK(write_callback_pending_);
394 DCHECK(!write_queue_.empty());
395 write_callback_pending_ = false;
396 WriteRequest& request = write_queue_.front();
397 scoped_refptr<net::DrainableIOBuffer> io_buffer = request.io_buffer;
400 io_buffer->DidConsume(result);
401 if (io_buffer->BytesRemaining() > 0) {
402 DVLOG(1) << "OnWriteComplete size = " << io_buffer->size() <<
403 " consumed " << io_buffer->BytesConsumed() <<
404 " remaining " << io_buffer->BytesRemaining() <<
405 " # requests " << write_queue_.size();
409 DCHECK_EQ(io_buffer->BytesConsumed(), io_buffer->size());
410 DCHECK_EQ(io_buffer->BytesRemaining(), 0);
411 result = io_buffer->BytesConsumed();
414 request.callback.Run(result);
417 DVLOG(1) << "OnWriteComplete size = " << io_buffer->size() <<
418 " consumed " << io_buffer->BytesConsumed() <<
419 " remaining " << io_buffer->BytesRemaining() <<
420 " # requests " << write_queue_.size();
423 CloseWithError(CHANNEL_ERROR_SOCKET_ERROR);
427 if (!write_queue_.empty())
431 int CastSocket::ReadData() {
432 DCHECK(CalledOnValidThread());
434 return net::ERR_FAILED;
435 DCHECK(!read_callback_pending_);
436 read_callback_pending_ = true;
437 // Figure out if we are reading the header or body, and the remaining bytes.
438 uint32 num_bytes_to_read = 0;
439 if (header_read_buffer_->RemainingCapacity() > 0) {
440 current_read_buffer_ = header_read_buffer_;
441 num_bytes_to_read = header_read_buffer_->RemainingCapacity();
442 DCHECK_LE(num_bytes_to_read, kMessageHeaderSize);
444 DCHECK_GT(current_message_size_, 0U);
445 num_bytes_to_read = current_message_size_ - body_read_buffer_->offset();
446 current_read_buffer_ = body_read_buffer_;
447 DCHECK_LE(num_bytes_to_read, kMaxMessageSize);
449 DCHECK_GT(num_bytes_to_read, 0U);
450 // We read up to num_bytes_to_read into |current_read_buffer_|.
451 int result = socket_->Read(
452 current_read_buffer_.get(),
454 base::Bind(&CastSocket::OnReadData, AsWeakPtr()));
455 DVLOG(1) << "ReadData result = " << result;
458 } else if (result != net::ERR_IO_PENDING) {
459 CloseWithError(CHANNEL_ERROR_SOCKET_ERROR);
464 void CastSocket::OnReadData(int result) {
465 DCHECK(CalledOnValidThread());
466 DVLOG(1) << "OnReadData result = " << result
467 << " header offset = " << header_read_buffer_->offset()
468 << " body offset = " << body_read_buffer_->offset();
469 read_callback_pending_ = false;
471 CloseWithError(CHANNEL_ERROR_SOCKET_ERROR);
474 // We read some data. Move the offset in the current buffer forward.
475 DCHECK_LE(current_read_buffer_->offset() + result,
476 current_read_buffer_->capacity());
477 current_read_buffer_->set_offset(current_read_buffer_->offset() + result);
479 bool should_continue = true;
480 if (current_read_buffer_.get() == header_read_buffer_.get() &&
481 current_read_buffer_->RemainingCapacity() == 0) {
482 // If we have read a full header, process the contents.
483 should_continue = ProcessHeader();
484 } else if (current_read_buffer_.get() == body_read_buffer_.get() &&
485 static_cast<uint32>(current_read_buffer_->offset()) ==
486 current_message_size_) {
487 // If we have read a full body, process the contents.
488 should_continue = ProcessBody();
494 bool CastSocket::ProcessHeader() {
495 DCHECK_EQ(static_cast<uint32>(header_read_buffer_->offset()),
497 MessageHeader header;
498 MessageHeader::ReadFromIOBuffer(header_read_buffer_.get(), &header);
499 if (header.message_size > kMaxMessageSize) {
500 CloseWithError(cast_channel::CHANNEL_ERROR_INVALID_MESSAGE);
503 DVLOG(1) << "Parsed header { message_size: " << header.message_size << " }";
504 current_message_size_ = header.message_size;
508 bool CastSocket::ProcessBody() {
509 DCHECK_EQ(static_cast<uint32>(body_read_buffer_->offset()),
510 current_message_size_);
511 if (!ParseMessageFromBody()) {
512 CloseWithError(cast_channel::CHANNEL_ERROR_INVALID_MESSAGE);
515 current_message_size_ = 0;
516 header_read_buffer_->set_offset(0);
517 body_read_buffer_->set_offset(0);
518 current_read_buffer_ = header_read_buffer_;
522 bool CastSocket::ParseMessageFromBody() {
523 DCHECK(CalledOnValidThread());
524 DCHECK_EQ(static_cast<uint32>(body_read_buffer_->offset()),
525 current_message_size_);
526 CastMessage message_proto;
527 if (!message_proto.ParseFromArray(
528 body_read_buffer_->StartOfBuffer(),
529 current_message_size_))
531 DVLOG(1) << "Parsed message " << CastMessageToString(message_proto);
532 // If the message is an auth message then we handle it internally.
533 if (IsAuthMessage(message_proto)) {
534 challenge_reply_.reset(new CastMessage(message_proto));
535 OnChallengeEvent(net::OK);
536 } else if (delegate_) {
538 if (!CastMessageToMessageInfo(message_proto, &message))
540 delegate_->OnMessage(this, message);
546 bool CastSocket::Serialize(const CastMessage& message_proto,
547 std::string* message_data) {
548 DCHECK(message_data);
549 message_proto.SerializeToString(message_data);
550 size_t message_size = message_data->size();
551 if (message_size > kMaxMessageSize) {
552 message_data->clear();
555 CastSocket::MessageHeader header;
556 header.SetMessageSize(message_size);
557 header.PrependToString(message_data);
561 void CastSocket::CloseWithError(ChannelError error) {
562 DCHECK(CalledOnValidThread());
564 ready_state_ = READY_STATE_CLOSED;
565 error_state_ = error;
567 delegate_->OnError(this, error);
570 bool CastSocket::ParseChannelUrl(const GURL& url) {
571 DVLOG(1) << "url = " + url.spec();
572 if (url.SchemeIs(kCastInsecureScheme)) {
573 auth_required_ = false;
574 } else if (url.SchemeIs(kCastSecureScheme)) {
575 auth_required_ = true;
579 // TODO(mfoltz): Manual parsing, yech. Register cast[s] as standard schemes?
580 // TODO(mfoltz): Test for IPv6 addresses. Brackets or no brackets?
581 // TODO(mfoltz): Maybe enforce restriction to IPv4 private and IPv6 link-local
583 const std::string& path = url.path();
584 // Shortest possible: //A:B
585 if (path.size() < 5) {
588 if (path.find("//") != 0) {
591 size_t colon = path.find_last_of(':');
592 if (colon == std::string::npos || colon < 3 || colon > path.size() - 2) {
595 const std::string& ip_address_str = path.substr(2, colon - 2);
596 const std::string& port_str = path.substr(colon + 1);
597 DVLOG(1) << "addr " << ip_address_str << " port " << port_str;
599 if (!base::StringToInt(port_str, &port))
601 net::IPAddressNumber ip_address;
602 if (!net::ParseIPLiteralToNumber(ip_address_str, &ip_address))
604 ip_endpoint_ = net::IPEndPoint(ip_address, port);
608 void CastSocket::FillChannelInfo(ChannelInfo* channel_info) const {
609 DCHECK(CalledOnValidThread());
610 channel_info->channel_id = channel_id_;
611 channel_info->url = url_.spec();
612 channel_info->ready_state = ready_state_;
613 channel_info->error_state = error_state_;
616 bool CastSocket::CalledOnValidThread() const {
617 return thread_checker_.CalledOnValidThread();
620 CastSocket::MessageHeader::MessageHeader() : message_size(0) { }
622 void CastSocket::MessageHeader::SetMessageSize(size_t size) {
623 DCHECK(size < static_cast<size_t>(kuint32max));
625 message_size = static_cast<size_t>(size);
628 void CastSocket::MessageHeader::PrependToString(std::string* str) {
629 MessageHeader output = *this;
630 output.message_size = base::HostToNet32(message_size);
631 char char_array[kMessageHeaderSize];
632 memcpy(&char_array, &output, arraysize(char_array));
633 str->insert(0, char_array, arraysize(char_array));
636 void CastSocket::MessageHeader::ReadFromIOBuffer(
637 net::GrowableIOBuffer* buffer, MessageHeader* header) {
639 memcpy(&message_size, buffer->StartOfBuffer(), kMessageHeaderSize);
640 header->message_size = base::NetToHost32(message_size);
643 std::string CastSocket::MessageHeader::ToString() {
644 return "{message_size: " + base::UintToString(message_size) + "}";
647 CastSocket::WriteRequest::WriteRequest(const net::CompletionCallback& callback)
648 : callback(callback) { }
650 bool CastSocket::WriteRequest::SetContent(const CastMessage& message_proto) {
651 DCHECK(!io_buffer.get());
652 std::string message_data;
653 if (!Serialize(message_proto, &message_data))
655 io_buffer = new net::DrainableIOBuffer(new net::StringIOBuffer(message_data),
656 message_data.size());
660 CastSocket::WriteRequest::~WriteRequest() { }
662 } // namespace cast_channel
664 } // namespace extensions