Upstream version 11.39.250.0
[platform/framework/web/crosswalk.git] / src / google_apis / gcm / engine / connection_handler_impl.cc
1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "google_apis/gcm/engine/connection_handler_impl.h"
6
7 #include "base/message_loop/message_loop.h"
8 #include "google/protobuf/io/coded_stream.h"
9 #include "google_apis/gcm/base/mcs_util.h"
10 #include "google_apis/gcm/base/socket_stream.h"
11 #include "google_apis/gcm/protocol/mcs.pb.h"
12 #include "net/base/net_errors.h"
13 #include "net/socket/stream_socket.h"
14
15 using namespace google::protobuf::io;
16
17 namespace gcm {
18
namespace {

// Wire-format sizes, in bytes, of the fixed-length MCS framing fields.
const int kVersionPacketLen = 1;  // Protocol version byte.
const int kTagPacketLen = 1;      // Message-type tag byte.

// Bounds on the varint32-encoded size field. A varint32 may occupy up to five
// bytes, because the high bit of every byte only signals whether another byte
// follows. The protocol caps payloads at 4KiB and the socket stream buffer is
// 8KiB, so two bytes (values up to roughly 16KiB) always suffice. Needing more
// means the stream buffer overflowed or the size field needed too many bytes,
// and both cases are treated as errors.
const int kSizePacketLenMin = 1;
const int kSizePacketLenMax = 2;

// MCS protocol version this client sends during login.
const int kMCSVersion = 41;

}  // namespace
39
40 ConnectionHandlerImpl::ConnectionHandlerImpl(
41     base::TimeDelta read_timeout,
42     const ProtoReceivedCallback& read_callback,
43     const ProtoSentCallback& write_callback,
44     const ConnectionChangedCallback& connection_callback)
45     : read_timeout_(read_timeout),
46       socket_(NULL),
47       handshake_complete_(false),
48       message_tag_(0),
49       message_size_(0),
50       read_callback_(read_callback),
51       write_callback_(write_callback),
52       connection_callback_(connection_callback),
53       weak_ptr_factory_(this) {
54 }
55
56 ConnectionHandlerImpl::~ConnectionHandlerImpl() {
57 }
58
59 void ConnectionHandlerImpl::Init(
60     const mcs_proto::LoginRequest& login_request,
61     net::StreamSocket* socket) {
62   DCHECK(!read_callback_.is_null());
63   DCHECK(!write_callback_.is_null());
64   DCHECK(!connection_callback_.is_null());
65
66   // Invalidate any previously outstanding reads.
67   weak_ptr_factory_.InvalidateWeakPtrs();
68
69   handshake_complete_ = false;
70   message_tag_ = 0;
71   message_size_ = 0;
72   socket_ = socket;
73   input_stream_.reset(new SocketInputStream(socket_));
74   output_stream_.reset(new SocketOutputStream(socket_));
75
76   Login(login_request);
77 }
78
79 void ConnectionHandlerImpl::Reset() {
80   CloseConnection();
81 }
82
83 bool ConnectionHandlerImpl::CanSendMessage() const {
84   return handshake_complete_ && output_stream_.get() &&
85       output_stream_->GetState() == SocketOutputStream::EMPTY;
86 }
87
88 void ConnectionHandlerImpl::SendMessage(
89     const google::protobuf::MessageLite& message) {
90   DCHECK_EQ(output_stream_->GetState(), SocketOutputStream::EMPTY);
91   DCHECK(handshake_complete_);
92
93   {
94     CodedOutputStream coded_output_stream(output_stream_.get());
95     DVLOG(1) << "Writing proto of size " << message.ByteSize();
96     int tag = GetMCSProtoTag(message);
97     DCHECK_NE(tag, -1);
98     coded_output_stream.WriteRaw(&tag, 1);
99     coded_output_stream.WriteVarint32(message.ByteSize());
100     message.SerializeToCodedStream(&coded_output_stream);
101   }
102
103   if (output_stream_->Flush(
104           base::Bind(&ConnectionHandlerImpl::OnMessageSent,
105                      weak_ptr_factory_.GetWeakPtr())) != net::ERR_IO_PENDING) {
106     OnMessageSent();
107   }
108 }
109
110 void ConnectionHandlerImpl::Login(
111     const google::protobuf::MessageLite& login_request) {
112   DCHECK_EQ(output_stream_->GetState(), SocketOutputStream::EMPTY);
113
114   const char version_byte[1] = {kMCSVersion};
115   const char login_request_tag[1] = {kLoginRequestTag};
116   {
117     CodedOutputStream coded_output_stream(output_stream_.get());
118     coded_output_stream.WriteRaw(version_byte, 1);
119     coded_output_stream.WriteRaw(login_request_tag, 1);
120     coded_output_stream.WriteVarint32(login_request.ByteSize());
121     login_request.SerializeToCodedStream(&coded_output_stream);
122   }
123
124   if (output_stream_->Flush(
125           base::Bind(&ConnectionHandlerImpl::OnMessageSent,
126                      weak_ptr_factory_.GetWeakPtr())) != net::ERR_IO_PENDING) {
127     base::MessageLoop::current()->PostTask(
128         FROM_HERE,
129         base::Bind(&ConnectionHandlerImpl::OnMessageSent,
130                    weak_ptr_factory_.GetWeakPtr()));
131   }
132
133   read_timeout_timer_.Start(FROM_HERE,
134                             read_timeout_,
135                             base::Bind(&ConnectionHandlerImpl::OnTimeout,
136                                        weak_ptr_factory_.GetWeakPtr()));
137   WaitForData(MCS_VERSION_TAG_AND_SIZE);
138 }
139
140 void ConnectionHandlerImpl::OnMessageSent() {
141   if (!output_stream_.get()) {
142     // The connection has already been closed. Just return.
143     DCHECK(!input_stream_.get());
144     DCHECK(!read_timeout_timer_.IsRunning());
145     return;
146   }
147
148   if (output_stream_->GetState() != SocketOutputStream::EMPTY) {
149     int last_error = output_stream_->last_error();
150     CloseConnection();
151     // If the socket stream had an error, plumb it up, else plumb up FAILED.
152     if (last_error == net::OK)
153       last_error = net::ERR_FAILED;
154     connection_callback_.Run(last_error);
155     return;
156   }
157
158   write_callback_.Run();
159 }
160
161 void ConnectionHandlerImpl::GetNextMessage() {
162   DCHECK(SocketInputStream::EMPTY == input_stream_->GetState() ||
163          SocketInputStream::READY == input_stream_->GetState());
164   message_tag_ = 0;
165   message_size_ = 0;
166
167   WaitForData(MCS_TAG_AND_SIZE);
168 }
169
170 void ConnectionHandlerImpl::WaitForData(ProcessingState state) {
171   DVLOG(1) << "Waiting for MCS data: state == " << state;
172
173   if (!input_stream_) {
174     // The connection has already been closed. Just return.
175     DCHECK(!output_stream_.get());
176     DCHECK(!read_timeout_timer_.IsRunning());
177     return;
178   }
179
180   if (input_stream_->GetState() != SocketInputStream::EMPTY &&
181       input_stream_->GetState() != SocketInputStream::READY) {
182     // An error occurred.
183     int last_error = output_stream_->last_error();
184     CloseConnection();
185     // If the socket stream had an error, plumb it up, else plumb up FAILED.
186     if (last_error == net::OK)
187       last_error = net::ERR_FAILED;
188     connection_callback_.Run(last_error);
189     return;
190   }
191
192   // Used to determine whether a Socket::Read is necessary.
193   int min_bytes_needed = 0;
194   // Used to limit the size of the Socket::Read.
195   int max_bytes_needed = 0;
196
197   switch(state) {
198     case MCS_VERSION_TAG_AND_SIZE:
199       min_bytes_needed = kVersionPacketLen + kTagPacketLen + kSizePacketLenMin;
200       max_bytes_needed = kVersionPacketLen + kTagPacketLen + kSizePacketLenMax;
201       break;
202     case MCS_TAG_AND_SIZE:
203       min_bytes_needed = kTagPacketLen + kSizePacketLenMin;
204       max_bytes_needed = kTagPacketLen + kSizePacketLenMax;
205       break;
206     case MCS_FULL_SIZE:
207       // If in this state, the minimum size packet length must already have been
208       // insufficient, so set both to the max length.
209       min_bytes_needed = kSizePacketLenMax;
210       max_bytes_needed = kSizePacketLenMax;
211       break;
212     case MCS_PROTO_BYTES:
213       read_timeout_timer_.Reset();
214       // No variability in the message size, set both to the same.
215       min_bytes_needed = message_size_;
216       max_bytes_needed = message_size_;
217       break;
218     default:
219       NOTREACHED();
220   }
221   DCHECK_GE(max_bytes_needed, min_bytes_needed);
222
223   int unread_byte_count = input_stream_->UnreadByteCount();
224   if (min_bytes_needed > unread_byte_count &&
225       input_stream_->Refresh(
226           base::Bind(&ConnectionHandlerImpl::WaitForData,
227                      weak_ptr_factory_.GetWeakPtr(),
228                      state),
229           max_bytes_needed - unread_byte_count) == net::ERR_IO_PENDING) {
230     return;
231   }
232
233   // Check for refresh errors.
234   if (input_stream_->GetState() != SocketInputStream::READY) {
235     // An error occurred.
236     int last_error = input_stream_->last_error();
237     CloseConnection();
238     // If the socket stream had an error, plumb it up, else plumb up FAILED.
239     if (last_error == net::OK)
240       last_error = net::ERR_FAILED;
241     connection_callback_.Run(last_error);
242     return;
243   }
244
245   // Check whether read is complete, or needs to be continued (
246   // SocketInputStream::Refresh can finish without reading all the data).
247   if (input_stream_->UnreadByteCount() < min_bytes_needed) {
248     DVLOG(1) << "Socket read finished prematurely. Waiting for "
249              << min_bytes_needed - input_stream_->UnreadByteCount()
250              << " more bytes.";
251     base::MessageLoop::current()->PostTask(
252         FROM_HERE,
253         base::Bind(&ConnectionHandlerImpl::WaitForData,
254                    weak_ptr_factory_.GetWeakPtr(),
255                    MCS_PROTO_BYTES));
256     return;
257   }
258
259   // Received enough bytes, process them.
260   DVLOG(1) << "Processing MCS data: state == " << state;
261   switch(state) {
262     case MCS_VERSION_TAG_AND_SIZE:
263       OnGotVersion();
264       break;
265     case MCS_TAG_AND_SIZE:
266       OnGotMessageTag();
267       break;
268     case MCS_FULL_SIZE:
269       OnGotMessageSize();
270       break;
271     case MCS_PROTO_BYTES:
272       OnGotMessageBytes();
273       break;
274     default:
275       NOTREACHED();
276   }
277 }
278
279 void ConnectionHandlerImpl::OnGotVersion() {
280   uint8 version = 0;
281   {
282     CodedInputStream coded_input_stream(input_stream_.get());
283     coded_input_stream.ReadRaw(&version, 1);
284   }
285   // TODO(zea): remove this when the server is ready.
286   if (version < kMCSVersion && version != 38) {
287     LOG(ERROR) << "Invalid GCM version response: " << static_cast<int>(version);
288     connection_callback_.Run(net::ERR_FAILED);
289     return;
290   }
291
292   input_stream_->RebuildBuffer();
293
294   // Process the LoginResponse message tag.
295   OnGotMessageTag();
296 }
297
298 void ConnectionHandlerImpl::OnGotMessageTag() {
299   if (input_stream_->GetState() != SocketInputStream::READY) {
300     LOG(ERROR) << "Failed to receive protobuf tag.";
301     read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>());
302     return;
303   }
304
305   {
306     CodedInputStream coded_input_stream(input_stream_.get());
307     coded_input_stream.ReadRaw(&message_tag_, 1);
308   }
309
310   DVLOG(1) << "Received proto of type "
311            << static_cast<unsigned int>(message_tag_);
312
313   if (!read_timeout_timer_.IsRunning()) {
314     read_timeout_timer_.Start(FROM_HERE,
315                               read_timeout_,
316                               base::Bind(&ConnectionHandlerImpl::OnTimeout,
317                                          weak_ptr_factory_.GetWeakPtr()));
318   }
319   OnGotMessageSize();
320 }
321
322 void ConnectionHandlerImpl::OnGotMessageSize() {
323   if (input_stream_->GetState() != SocketInputStream::READY) {
324     LOG(ERROR) << "Failed to receive message size.";
325     read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>());
326     return;
327   }
328
329   bool need_another_byte = false;
330   int prev_byte_count = input_stream_->UnreadByteCount();
331   {
332     CodedInputStream coded_input_stream(input_stream_.get());
333     if (!coded_input_stream.ReadVarint32(&message_size_))
334       need_another_byte = true;
335   }
336
337   if (need_another_byte) {
338     DVLOG(1) << "Expecting another message size byte.";
339     if (prev_byte_count >= kSizePacketLenMax) {
340       // Already had enough bytes, something else went wrong.
341       LOG(ERROR) << "Failed to process message size, too many bytes needed.";
342       connection_callback_.Run(net::ERR_FILE_TOO_BIG);
343       return;
344     }
345     // Back up by the amount read (should always be 1 byte).
346     int bytes_read = prev_byte_count - input_stream_->UnreadByteCount();
347     DCHECK_EQ(bytes_read, 1);
348     input_stream_->BackUp(bytes_read);
349     WaitForData(MCS_FULL_SIZE);
350     return;
351   }
352
353   DVLOG(1) << "Proto size: " << message_size_;
354
355   if (message_size_ > 0)
356     WaitForData(MCS_PROTO_BYTES);
357   else
358     OnGotMessageBytes();
359 }
360
361 void ConnectionHandlerImpl::OnGotMessageBytes() {
362   read_timeout_timer_.Stop();
363   scoped_ptr<google::protobuf::MessageLite> protobuf(
364       BuildProtobufFromTag(message_tag_));
365   // Messages with no content are valid; just use the default protobuf for
366   // that tag.
367   if (protobuf.get() && message_size_ == 0) {
368     base::MessageLoop::current()->PostTask(
369         FROM_HERE,
370         base::Bind(&ConnectionHandlerImpl::GetNextMessage,
371                    weak_ptr_factory_.GetWeakPtr()));
372     read_callback_.Run(protobuf.Pass());
373     return;
374   }
375
376   if (input_stream_->GetState() != SocketInputStream::READY) {
377     LOG(ERROR) << "Failed to extract protobuf bytes of type "
378                << static_cast<unsigned int>(message_tag_);
379     // Reset the connection.
380     connection_callback_.Run(net::ERR_FAILED);
381     return;
382   }
383
384   if (!protobuf.get()) {
385      LOG(ERROR) << "Received message of invalid type "
386                 << static_cast<unsigned int>(message_tag_);
387      connection_callback_.Run(net::ERR_INVALID_ARGUMENT);
388      return;
389    }
390
391   {
392     CodedInputStream coded_input_stream(input_stream_.get());
393     if (!protobuf->ParsePartialFromCodedStream(&coded_input_stream)) {
394       LOG(ERROR) << "Unable to parse GCM message of type "
395                  << static_cast<unsigned int>(message_tag_);
396       // Reset the connection.
397       connection_callback_.Run(net::ERR_FAILED);
398       return;
399     }
400   }
401
402   input_stream_->RebuildBuffer();
403   base::MessageLoop::current()->PostTask(
404       FROM_HERE,
405       base::Bind(&ConnectionHandlerImpl::GetNextMessage,
406                  weak_ptr_factory_.GetWeakPtr()));
407   if (message_tag_ == kLoginResponseTag) {
408     if (handshake_complete_) {
409       LOG(ERROR) << "Unexpected login response.";
410     } else {
411       handshake_complete_ = true;
412       DVLOG(1) << "GCM Handshake complete.";
413       connection_callback_.Run(net::OK);
414     }
415   }
416   read_callback_.Run(protobuf.Pass());
417 }
418
419 void ConnectionHandlerImpl::OnTimeout() {
420   LOG(ERROR) << "Timed out waiting for GCM Protocol buffer.";
421   CloseConnection();
422   connection_callback_.Run(net::ERR_TIMED_OUT);
423 }
424
425 void ConnectionHandlerImpl::CloseConnection() {
426   DVLOG(1) << "Closing connection.";
427   read_timeout_timer_.Stop();
428   if (socket_)
429     socket_->Disconnect();
430   socket_ = NULL;
431   handshake_complete_ = false;
432   message_tag_ = 0;
433   message_size_ = 0;
434   input_stream_.reset();
435   output_stream_.reset();
436   weak_ptr_factory_.InvalidateWeakPtrs();
437 }
438
439 }  // namespace gcm