1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "google_apis/gcm/engine/connection_handler_impl.h"
7 #include "base/message_loop/message_loop.h"
8 #include "google/protobuf/io/coded_stream.h"
9 #include "google_apis/gcm/base/mcs_util.h"
10 #include "google_apis/gcm/base/socket_stream.h"
11 #include "google_apis/gcm/protocol/mcs.pb.h"
12 #include "net/base/net_errors.h"
13 #include "net/socket/stream_socket.h"
15 using namespace google::protobuf::io;
// Wire-format sizes (in bytes) of the header fields that precede each MCS
// protobuf, plus the protocol version this client speaks. Login() writes a
// version byte, a tag byte and a varint32 length before the LoginRequest;
// SendMessage() writes the same header without the version byte.
// WaitForData() uses these constants to bound how many bytes to read for each
// header-processing state.
21 // # of bytes a MCS version packet consumes.
22 const int kVersionPacketLen = 1;
23 // # of bytes a tag packet consumes.
24 const int kTagPacketLen = 1;
25 // Max # of bytes a length packet consumes. A Varint32 can consume up to 5 bytes
26 // (the MSB in each byte is reserved for denoting whether more bytes follow).
27 // But, the protocol only allows for 4KiB payloads, and the socket stream buffer
28 // is only of size 8KiB. As such we should never need more than 2 bytes (max
29 // value of 16KiB). Anything higher than that will result in an error, either
30 // because the socket stream buffer overflowed or too many bytes were required
31 // in the size packet.
// (Two varint bytes carry 7 value bits each, i.e. sizes up to 16383.)
32 const int kSizePacketLenMin = 1;
33 const int kSizePacketLenMax = 2;
35 // The current MCS protocol version.
// Written as the first byte by Login(); OnGotVersion() rejects a server
// version byte lower than this (with a temporary exception for 38).
36 const int kMCSVersion = 41;
// Constructs a handler; the socket is supplied later through Init().
// Stores the read timeout and the three client callbacks, and starts with the
// login handshake marked incomplete. weak_ptr_factory_ hands out the weak
// pointers that stream callbacks and posted tasks in this file are bound to,
// so InvalidateWeakPtrs() (Init(), CloseConnection()) can cancel them.
//   read_timeout: stored in read_timeout_; presumably the delay used for
//     read_timeout_timer_ (the Start() delay argument is not visible in this
//     extract) — TODO confirm.
//   read_callback: receives each parsed protobuf, or a null one when the tag
//     or size could not be read.
//   write_callback: run after a flush leaves the output stream EMPTY.
//   connection_callback: receives net::OK once the login handshake completes,
//     or a net error code when the connection fails.
40 ConnectionHandlerImpl::ConnectionHandlerImpl(
41 base::TimeDelta read_timeout,
42 const ProtoReceivedCallback& read_callback,
43 const ProtoSentCallback& write_callback,
44 const ConnectionChangedCallback& connection_callback)
45 : read_timeout_(read_timeout),
47 handshake_complete_(false),
50 read_callback_(read_callback),
51 write_callback_(write_callback),
52 connection_callback_(connection_callback),
53 weak_ptr_factory_(this) {
// Destructor. NOTE(review): its body is not visible in this extract.
56 ConnectionHandlerImpl::~ConnectionHandlerImpl() {
// (Re)initializes the handler for a new connection on |socket|.
// DCHECKs that all three callbacks are set, cancels any callbacks still bound
// to weak pointers from a previous connection, clears the handshake flag and
// wraps socket_ in fresh input/output streams.
// NOTE(review): neither the assignment of socket_ from |socket| nor any use of
// |login_request| is visible in this extract; presumably the full function
// stores |socket| and then starts the handshake with Login(login_request) —
// confirm against the full file.
59 void ConnectionHandlerImpl::Init(
60 const mcs_proto::LoginRequest& login_request,
61 net::StreamSocket* socket) {
62 DCHECK(!read_callback_.is_null());
63 DCHECK(!write_callback_.is_null());
64 DCHECK(!connection_callback_.is_null());
66 // Invalidate any previously outstanding reads.
67 weak_ptr_factory_.InvalidateWeakPtrs();
69 handshake_complete_ = false;
// Both streams wrap the same socket; reset() destroys any streams left over
// from a previous connection.
73 input_stream_.reset(new SocketInputStream(socket_));
74 output_stream_.reset(new SocketOutputStream(socket_));
// Public reset entry point.
// NOTE(review): the body is not visible in this extract; presumably it tears
// the connection down via CloseConnection() — confirm against the full file.
79 void ConnectionHandlerImpl::Reset() {
// Returns true when SendMessage()'s preconditions hold: the login handshake
// has completed, an output stream exists (CloseConnection() resets it to
// null), and that stream has nothing pending (state EMPTY).
83 bool ConnectionHandlerImpl::CanSendMessage() const {
84 return handshake_complete_ && output_stream_.get() &&
85 output_stream_->GetState() == SocketOutputStream::EMPTY;
// Serializes |message| onto the output stream as a 1-byte MCS tag, a varint32
// payload length and the payload bytes, then flushes it. OnMessageSent() is
// the flush completion callback; it is bound through a weak pointer so that
// CloseConnection()/Init() can cancel it.
// Preconditions (DCHECKed, and what CanSendMessage() reports): the output
// stream is EMPTY and the login handshake has completed.
88 void ConnectionHandlerImpl::SendMessage(
89 const google::protobuf::MessageLite& message) {
90 DCHECK_EQ(output_stream_->GetState(), SocketOutputStream::EMPTY);
91 DCHECK(handshake_complete_);
94 CodedOutputStream coded_output_stream(output_stream_.get());
95 DVLOG(1) << "Writing proto of size " << message.ByteSize();
96 int tag = GetMCSProtoTag(message);
// Emit the tag's value as a single byte. WriteRaw(&tag, 1) would copy the
// first byte of the int's in-memory representation, which is the tag value
// only on little-endian hosts; narrowing explicitly (as Login() does with its
// char arrays) is portable.
const char tag_byte[1] = {static_cast<char>(tag)};
98 coded_output_stream.WriteRaw(tag_byte, 1);
99 coded_output_stream.WriteVarint32(message.ByteSize());
100 message.SerializeToCodedStream(&coded_output_stream);
// When Flush() returns ERR_IO_PENDING, the bound OnMessageSent() callback is
// expected to run once the write finishes. NOTE(review): the branch taken for
// any other result is not visible in this extract; Login() posts
// OnMessageSent() to the message loop in that case.
103 if (output_stream_->Flush(
104 base::Bind(&ConnectionHandlerImpl::OnMessageSent,
105 weak_ptr_factory_.GetWeakPtr())) != net::ERR_IO_PENDING) {
// Starts the MCS login handshake. Writes the protocol version byte
// (kMCSVersion), the LoginRequest tag, the varint32 request size and the
// serialized request, then flushes. If Flush() does not return ERR_IO_PENDING,
// OnMessageSent() is posted to the current message loop rather than called
// directly. Finally, arms read_timeout_timer_ (which fires OnTimeout()) and
// waits for the server's version + tag + size header.
// The output stream must be EMPTY on entry (DCHECKed).
110 void ConnectionHandlerImpl::Login(
111 const google::protobuf::MessageLite& login_request) {
112 DCHECK_EQ(output_stream_->GetState(), SocketOutputStream::EMPTY);
114 const char version_byte[1] = {kMCSVersion};
115 const char login_request_tag[1] = {kLoginRequestTag};
117 CodedOutputStream coded_output_stream(output_stream_.get());
118 coded_output_stream.WriteRaw(version_byte, 1);
119 coded_output_stream.WriteRaw(login_request_tag, 1);
120 coded_output_stream.WriteVarint32(login_request.ByteSize());
121 login_request.SerializeToCodedStream(&coded_output_stream);
124 if (output_stream_->Flush(
125 base::Bind(&ConnectionHandlerImpl::OnMessageSent,
126 weak_ptr_factory_.GetWeakPtr())) != net::ERR_IO_PENDING) {
// The flush did not go asynchronous; report its outcome via a posted task,
// presumably so OnMessageSent() never runs re-entrantly inside Login().
127 base::MessageLoop::current()->PostTask(
129 base::Bind(&ConnectionHandlerImpl::OnMessageSent,
130 weak_ptr_factory_.GetWeakPtr()));
// NOTE(review): the timer's delay argument is not visible in this extract;
// presumably read_timeout_ — confirm.
133 read_timeout_timer_.Start(FROM_HERE,
135 base::Bind(&ConnectionHandlerImpl::OnTimeout,
136 weak_ptr_factory_.GetWeakPtr()));
137 WaitForData(MCS_VERSION_TAG_AND_SIZE);
// Completion handler for the output flushes issued by Login() and
// SendMessage(). The function handles three cases:
//   - No output stream: the connection was already closed; nothing to do.
//   - Stream not EMPTY after the flush: report the stream's last error (or
//     ERR_FAILED if it recorded none) via connection_callback_.
//   - Otherwise: notify write_callback_.
// NOTE(review): the early returns separating these cases are not visible in
// this extract.
140 void ConnectionHandlerImpl::OnMessageSent() {
141 if (!output_stream_.get()) {
142 // The connection has already been closed. Just return.
143 DCHECK(!input_stream_.get());
144 DCHECK(!read_timeout_timer_.IsRunning());
148 if (output_stream_->GetState() != SocketOutputStream::EMPTY) {
149 int last_error = output_stream_->last_error();
151 // If the socket stream had an error, plumb it up, else plumb up FAILED.
152 if (last_error == net::OK)
153 last_error = net::ERR_FAILED;
154 connection_callback_.Run(last_error);
158 write_callback_.Run();
// Begins reading the next message, starting with its tag + size header.
// DCHECKs that the input stream is in the EMPTY or READY state.
161 void ConnectionHandlerImpl::GetNextMessage() {
162 DCHECK(SocketInputStream::EMPTY == input_stream_->GetState() ||
163 SocketInputStream::READY == input_stream_->GetState());
167 WaitForData(MCS_TAG_AND_SIZE);
// Ensures enough bytes for |state| are buffered in the input stream, then
// dispatches to the matching OnGot*() handler. Steps:
//   - Returns immediately if the connection is already closed (no input
//     stream).
//   - Reports an input stream that is neither EMPTY nor READY via
//     connection_callback_, using that stream's last error.
//   - Computes the min/max bytes |state| needs. If fewer than the minimum are
//     buffered, refreshes the stream with this method re-bound as the
//     completion callback. If a read finishes short, re-posts itself.
// NOTE(review): several lines of this function (switch headers, breaks,
// returns, some bound arguments) are not visible in this extract.
170 void ConnectionHandlerImpl::WaitForData(ProcessingState state) {
171 DVLOG(1) << "Waiting for MCS data: state == " << state;
173 if (!input_stream_) {
174 // The connection has already been closed. Just return.
175 DCHECK(!output_stream_.get());
176 DCHECK(!read_timeout_timer_.IsRunning());
180 if (input_stream_->GetState() != SocketInputStream::EMPTY &&
181 input_stream_->GetState() != SocketInputStream::READY) {
182 // An error occurred.
// It is the input stream that is in an error state, so its last_error() is
// the one to report. Reading output_stream_->last_error() here would report
// a write-side error (or OK, masked as ERR_FAILED) instead of the actual
// read failure. This matches the refresh-error check below.
183 int last_error = input_stream_->last_error();
185 // If the socket stream had an error, plumb it up, else plumb up FAILED.
186 if (last_error == net::OK)
187 last_error = net::ERR_FAILED;
188 connection_callback_.Run(last_error);
192 // Used to determine whether a Socket::Read is necessary.
193 int min_bytes_needed = 0;
194 // Used to limit the size of the Socket::Read.
195 int max_bytes_needed = 0;
198 case MCS_VERSION_TAG_AND_SIZE:
199 min_bytes_needed = kVersionPacketLen + kTagPacketLen + kSizePacketLenMin;
200 max_bytes_needed = kVersionPacketLen + kTagPacketLen + kSizePacketLenMax;
202 case MCS_TAG_AND_SIZE:
203 min_bytes_needed = kTagPacketLen + kSizePacketLenMin;
204 max_bytes_needed = kTagPacketLen + kSizePacketLenMax;
// NOTE(review): the case label for the following branch is not visible in
// this extract; per OnGotMessageSize() it is presumably MCS_FULL_SIZE.
207 // If in this state, the minimum size packet length must already have been
208 // insufficient, so set both to the max length.
209 min_bytes_needed = kSizePacketLenMax;
210 max_bytes_needed = kSizePacketLenMax;
212 case MCS_PROTO_BYTES:
// Restart the read timeout before waiting for the payload.
213 read_timeout_timer_.Reset();
214 // No variability in the message size, set both to the same.
215 min_bytes_needed = message_size_;
216 max_bytes_needed = message_size_;
221 DCHECK_GE(max_bytes_needed, min_bytes_needed);
223 int unread_byte_count = input_stream_->UnreadByteCount();
// Only touch the socket when fewer than min_bytes_needed bytes are buffered.
// The second Refresh() argument (presumably a read-size cap) keeps the total
// buffered at or below max_bytes_needed. An asynchronous refresh re-enters
// WaitForData() through the weak-bound callback.
224 if (min_bytes_needed > unread_byte_count &&
225 input_stream_->Refresh(
226 base::Bind(&ConnectionHandlerImpl::WaitForData,
227 weak_ptr_factory_.GetWeakPtr(),
229 max_bytes_needed - unread_byte_count) == net::ERR_IO_PENDING) {
233 // Check for refresh errors.
234 if (input_stream_->GetState() != SocketInputStream::READY) {
235 // An error occurred.
236 int last_error = input_stream_->last_error();
238 // If the socket stream had an error, plumb it up, else plumb up FAILED.
239 if (last_error == net::OK)
240 last_error = net::ERR_FAILED;
241 connection_callback_.Run(last_error);
245 // Check whether read is complete, or needs to be continued (
246 // SocketInputStream::Refresh can finish without reading all the data).
247 if (input_stream_->UnreadByteCount() < min_bytes_needed) {
248 DVLOG(1) << "Socket read finished prematurely. Waiting for "
249 << min_bytes_needed - input_stream_->UnreadByteCount()
251 base::MessageLoop::current()->PostTask(
253 base::Bind(&ConnectionHandlerImpl::WaitForData,
254 weak_ptr_factory_.GetWeakPtr(),
259 // Received enough bytes, process them.
// NOTE(review): the switch header and per-state handler calls are not
// visible in this extract.
260 DVLOG(1) << "Processing MCS data: state == " << state;
262 case MCS_VERSION_TAG_AND_SIZE:
265 case MCS_TAG_AND_SIZE:
271 case MCS_PROTO_BYTES:
// Consumes the server's 1-byte MCS version. A version below kMCSVersion is
// rejected by reporting ERR_FAILED via connection_callback_. Version 38 is the
// exception: it is temporarily tolerated per the TODO. Afterwards the input
// buffer is rebuilt (RebuildBuffer()) and processing continues with the
// LoginResponse tag.
// NOTE(review): the declaration of |version|, the early return after the
// error, and the follow-up call are not visible in this extract.
279 void ConnectionHandlerImpl::OnGotVersion() {
282 CodedInputStream coded_input_stream(input_stream_.get());
283 coded_input_stream.ReadRaw(&version, 1);
285 // TODO(zea): remove this when the server is ready.
286 if (version < kMCSVersion && version != 38) {
287 LOG(ERROR) << "Invalid GCM version response: " << static_cast<int>(version);
288 connection_callback_.Run(net::ERR_FAILED);
292 input_stream_->RebuildBuffer();
294 // Process the LoginResponse message tag.
// Consumes the 1-byte message tag into message_tag_. If the input stream is
// not READY, logs and hands read_callback_ a null message instead. Arms
// read_timeout_timer_ (firing OnTimeout()) if it is not already running.
// OnGotMessageBytes() stops it, so the rest of the message must arrive before
// the timer fires.
298 void ConnectionHandlerImpl::OnGotMessageTag() {
299 if (input_stream_->GetState() != SocketInputStream::READY) {
300 LOG(ERROR) << "Failed to receive protobuf tag.";
301 read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>());
306 CodedInputStream coded_input_stream(input_stream_.get());
307 coded_input_stream.ReadRaw(&message_tag_, 1);
310 DVLOG(1) << "Received proto of type "
311 << static_cast<unsigned int>(message_tag_);
313 if (!read_timeout_timer_.IsRunning()) {
314 read_timeout_timer_.Start(FROM_HERE,
316 base::Bind(&ConnectionHandlerImpl::OnTimeout,
317 weak_ptr_factory_.GetWeakPtr()));
// Decodes the varint32 payload size into message_size_. The outcomes are:
//   - Input stream not READY: logs and hands read_callback_ a null message.
//   - Varint incomplete, with kSizePacketLenMax bytes already buffered: the
//     size cannot be encoded within this protocol's limit, so
//     ERR_FILE_TOO_BIG is reported.
//   - Varint incomplete otherwise: the byte consumed by the failed decode is
//     backed up and the handler waits for the full-size header
//     (MCS_FULL_SIZE).
//   - Decoded: waits for the payload bytes when message_size_ > 0.
// NOTE(review): the zero-size path after the final `if` is not visible here.
322 void ConnectionHandlerImpl::OnGotMessageSize() {
323 if (input_stream_->GetState() != SocketInputStream::READY) {
324 LOG(ERROR) << "Failed to receive message size.";
325 read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>());
329 bool need_another_byte = false;
330 int prev_byte_count = input_stream_->UnreadByteCount();
// NOTE(review): UnreadByteCount() below only reflects what the decode
// actually consumed after coded_input_stream is destroyed (its destructor
// repositions the underlying stream at the first unread byte). Presumably
// lines not shown here scope it in its own block — confirm.
332 CodedInputStream coded_input_stream(input_stream_.get());
333 if (!coded_input_stream.ReadVarint32(&message_size_))
334 need_another_byte = true;
337 if (need_another_byte) {
338 DVLOG(1) << "Expecting another message size byte.";
339 if (prev_byte_count >= kSizePacketLenMax) {
340 // Already had enough bytes, something else went wrong.
341 LOG(ERROR) << "Failed to process message size, too many bytes needed.";
342 connection_callback_.Run(net::ERR_FILE_TOO_BIG);
345 // Back up by the amount read (should always be 1 byte).
346 int bytes_read = prev_byte_count - input_stream_->UnreadByteCount();
347 DCHECK_EQ(bytes_read, 1);
348 input_stream_->BackUp(bytes_read);
349 WaitForData(MCS_FULL_SIZE);
353 DVLOG(1) << "Proto size: " << message_size_;
355 if (message_size_ > 0)
356 WaitForData(MCS_PROTO_BYTES);
// Handles a complete message. It first stops the read timer and builds a
// protobuf for message_tag_ via BuildProtobufFromTag(); a null result means an
// unknown tag. The outcomes are:
//   - Known tag with zero size: posts GetNextMessage() and delivers the
//     default (empty) message to read_callback_.
//   - Input stream not READY: ERR_FAILED via connection_callback_.
//   - Unknown tag: ERR_INVALID_ARGUMENT via connection_callback_.
//   - Parse failure: ERR_FAILED via connection_callback_. The parse is
//     "partial", so missing required fields are not rejected.
//   - Success: rebuilds the input buffer, posts GetNextMessage(), marks the
//     handshake complete on a LoginResponse (reporting net::OK), and delivers
//     the message to read_callback_.
// GetNextMessage() is posted through a weak pointer before read_callback_
// runs, so a callback that closes the connection also cancels that task.
// NOTE(review): early returns and the branch structure around the
// "Unexpected login response" log are not visible in this extract.
361 void ConnectionHandlerImpl::OnGotMessageBytes() {
362 read_timeout_timer_.Stop();
363 scoped_ptr<google::protobuf::MessageLite> protobuf(
364 BuildProtobufFromTag(message_tag_));
365 // Messages with no content are valid; just use the default protobuf for
// the tag as-is.
367 if (protobuf.get() && message_size_ == 0) {
368 base::MessageLoop::current()->PostTask(
370 base::Bind(&ConnectionHandlerImpl::GetNextMessage,
371 weak_ptr_factory_.GetWeakPtr()));
372 read_callback_.Run(protobuf.Pass());
376 if (input_stream_->GetState() != SocketInputStream::READY) {
377 LOG(ERROR) << "Failed to extract protobuf bytes of type "
378 << static_cast<unsigned int>(message_tag_);
379 // Reset the connection.
380 connection_callback_.Run(net::ERR_FAILED);
384 if (!protobuf.get()) {
385 LOG(ERROR) << "Received message of invalid type "
386 << static_cast<unsigned int>(message_tag_);
387 connection_callback_.Run(net::ERR_INVALID_ARGUMENT);
392 CodedInputStream coded_input_stream(input_stream_.get());
393 if (!protobuf->ParsePartialFromCodedStream(&coded_input_stream)) {
394 LOG(ERROR) << "Unable to parse GCM message of type "
395 << static_cast<unsigned int>(message_tag_);
396 // Reset the connection.
397 connection_callback_.Run(net::ERR_FAILED);
402 input_stream_->RebuildBuffer();
403 base::MessageLoop::current()->PostTask(
405 base::Bind(&ConnectionHandlerImpl::GetNextMessage,
406 weak_ptr_factory_.GetWeakPtr()));
407 if (message_tag_ == kLoginResponseTag) {
408 if (handshake_complete_) {
409 LOG(ERROR) << "Unexpected login response.";
411 handshake_complete_ = true;
412 DVLOG(1) << "GCM Handshake complete.";
413 connection_callback_.Run(net::OK);
416 read_callback_.Run(protobuf.Pass());
// read_timeout_timer_ callback (armed in Login() and OnGotMessageTag()).
// Logs the timeout and reports ERR_TIMED_OUT via connection_callback_.
419 void ConnectionHandlerImpl::OnTimeout() {
420 LOG(ERROR) << "Timed out waiting for GCM Protocol buffer.";
422 connection_callback_.Run(net::ERR_TIMED_OUT);
// Tears down the current connection: stops the read timer, disconnects
// socket_, clears the handshake flag and destroys both streams. It then
// invalidates all weak pointers, so stream callbacks and posted tasks bound
// through weak_ptr_factory_ will not run. OnMessageSent() and WaitForData()
// detect this closed state by checking for null streams.
425 void ConnectionHandlerImpl::CloseConnection() {
426 DVLOG(1) << "Closing connection.";
427 read_timeout_timer_.Stop();
429 socket_->Disconnect();
431 handshake_complete_ = false;
434 input_stream_.reset();
435 output_stream_.reset();
436 weak_ptr_factory_.InvalidateWeakPtrs();