Upstream version 10.39.225.0
[platform/framework/web/crosswalk.git] / src / remoting / protocol / channel_multiplexer.cc
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "remoting/protocol/channel_multiplexer.h"
6
7 #include <string.h>
8
9 #include "base/bind.h"
10 #include "base/callback.h"
11 #include "base/location.h"
12 #include "base/single_thread_task_runner.h"
13 #include "base/stl_util.h"
14 #include "base/thread_task_runner_handle.h"
15 #include "net/base/net_errors.h"
16 #include "net/socket/stream_socket.h"
17 #include "remoting/protocol/message_serialization.h"
18
19 namespace remoting {
20 namespace protocol {
21
22 namespace {
23 const int kChannelIdUnknown = -1;
24 const int kMaxPacketSize = 1024;
25
26 class PendingPacket {
27  public:
28   PendingPacket(scoped_ptr<MultiplexPacket> packet,
29                 const base::Closure& done_task)
30       : packet(packet.Pass()),
31         done_task(done_task),
32         pos(0U) {
33   }
34   ~PendingPacket() {
35     done_task.Run();
36   }
37
38   bool is_empty() { return pos >= packet->data().size(); }
39
40   int Read(char* buffer, size_t size) {
41     size = std::min(size, packet->data().size() - pos);
42     memcpy(buffer, packet->data().data() + pos, size);
43     pos += size;
44     return size;
45   }
46
47  private:
48   scoped_ptr<MultiplexPacket> packet;
49   base::Closure done_task;
50   size_t pos;
51
52   DISALLOW_COPY_AND_ASSIGN(PendingPacket);
53 };
54
55 }  // namespace
56
57 const char ChannelMultiplexer::kMuxChannelName[] = "mux";
58
59 struct ChannelMultiplexer::PendingChannel {
60   PendingChannel(const std::string& name,
61                  const ChannelCreatedCallback& callback)
62       : name(name), callback(callback) {
63   }
64   std::string name;
65   ChannelCreatedCallback callback;
66 };
67
68 class ChannelMultiplexer::MuxChannel {
69  public:
70   MuxChannel(ChannelMultiplexer* multiplexer, const std::string& name,
71              int send_id);
72   ~MuxChannel();
73
74   const std::string& name() { return name_; }
75   int receive_id() { return receive_id_; }
76   void set_receive_id(int id) { receive_id_ = id; }
77
78   // Called by ChannelMultiplexer.
79   scoped_ptr<net::StreamSocket> CreateSocket();
80   void OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
81                         const base::Closure& done_task);
82   void OnWriteFailed();
83
84   // Called by MuxSocket.
85   void OnSocketDestroyed();
86   bool DoWrite(scoped_ptr<MultiplexPacket> packet,
87                const base::Closure& done_task);
88   int DoRead(net::IOBuffer* buffer, int buffer_len);
89
90  private:
91   ChannelMultiplexer* multiplexer_;
92   std::string name_;
93   int send_id_;
94   bool id_sent_;
95   int receive_id_;
96   MuxSocket* socket_;
97   std::list<PendingPacket*> pending_packets_;
98
99   DISALLOW_COPY_AND_ASSIGN(MuxChannel);
100 };
101
102 class ChannelMultiplexer::MuxSocket : public net::StreamSocket,
103                                       public base::NonThreadSafe,
104                                       public base::SupportsWeakPtr<MuxSocket> {
105  public:
106   MuxSocket(MuxChannel* channel);
107   virtual ~MuxSocket();
108
109   void OnWriteComplete();
110   void OnWriteFailed();
111   void OnPacketReceived();
112
113   // net::StreamSocket interface.
114   virtual int Read(net::IOBuffer* buffer, int buffer_len,
115                    const net::CompletionCallback& callback) OVERRIDE;
116   virtual int Write(net::IOBuffer* buffer, int buffer_len,
117                     const net::CompletionCallback& callback) OVERRIDE;
118
119   virtual int SetReceiveBufferSize(int32 size) OVERRIDE {
120     NOTIMPLEMENTED();
121     return net::ERR_NOT_IMPLEMENTED;
122   }
123   virtual int SetSendBufferSize(int32 size) OVERRIDE {
124     NOTIMPLEMENTED();
125     return net::ERR_NOT_IMPLEMENTED;
126   }
127
128   virtual int Connect(const net::CompletionCallback& callback) OVERRIDE {
129     NOTIMPLEMENTED();
130     return net::ERR_NOT_IMPLEMENTED;
131   }
132   virtual void Disconnect() OVERRIDE {
133     NOTIMPLEMENTED();
134   }
135   virtual bool IsConnected() const OVERRIDE {
136     NOTIMPLEMENTED();
137     return true;
138   }
139   virtual bool IsConnectedAndIdle() const OVERRIDE {
140     NOTIMPLEMENTED();
141     return false;
142   }
143   virtual int GetPeerAddress(net::IPEndPoint* address) const OVERRIDE {
144     NOTIMPLEMENTED();
145     return net::ERR_NOT_IMPLEMENTED;
146   }
147   virtual int GetLocalAddress(net::IPEndPoint* address) const OVERRIDE {
148     NOTIMPLEMENTED();
149     return net::ERR_NOT_IMPLEMENTED;
150   }
151   virtual const net::BoundNetLog& NetLog() const OVERRIDE {
152     NOTIMPLEMENTED();
153     return net_log_;
154   }
155   virtual void SetSubresourceSpeculation() OVERRIDE {
156     NOTIMPLEMENTED();
157   }
158   virtual void SetOmniboxSpeculation() OVERRIDE {
159     NOTIMPLEMENTED();
160   }
161   virtual bool WasEverUsed() const OVERRIDE {
162     return true;
163   }
164   virtual bool UsingTCPFastOpen() const OVERRIDE {
165     return false;
166   }
167   virtual bool WasNpnNegotiated() const OVERRIDE {
168     return false;
169   }
170   virtual net::NextProto GetNegotiatedProtocol() const OVERRIDE {
171     return net::kProtoUnknown;
172   }
173   virtual bool GetSSLInfo(net::SSLInfo* ssl_info) OVERRIDE {
174     NOTIMPLEMENTED();
175     return false;
176   }
177
178  private:
179   MuxChannel* channel_;
180
181   net::CompletionCallback read_callback_;
182   scoped_refptr<net::IOBuffer> read_buffer_;
183   int read_buffer_size_;
184
185   bool write_pending_;
186   int write_result_;
187   net::CompletionCallback write_callback_;
188
189   net::BoundNetLog net_log_;
190
191   DISALLOW_COPY_AND_ASSIGN(MuxSocket);
192 };
193
194
195 ChannelMultiplexer::MuxChannel::MuxChannel(
196     ChannelMultiplexer* multiplexer,
197     const std::string& name,
198     int send_id)
199     : multiplexer_(multiplexer),
200       name_(name),
201       send_id_(send_id),
202       id_sent_(false),
203       receive_id_(kChannelIdUnknown),
204       socket_(NULL) {
205 }
206
207 ChannelMultiplexer::MuxChannel::~MuxChannel() {
208   // Socket must be destroyed before the channel.
209   DCHECK(!socket_);
210   STLDeleteElements(&pending_packets_);
211 }
212
213 scoped_ptr<net::StreamSocket> ChannelMultiplexer::MuxChannel::CreateSocket() {
214   DCHECK(!socket_);  // Can't create more than one socket per channel.
215   scoped_ptr<MuxSocket> result(new MuxSocket(this));
216   socket_ = result.get();
217   return result.PassAs<net::StreamSocket>();
218 }
219
220 void ChannelMultiplexer::MuxChannel::OnIncomingPacket(
221     scoped_ptr<MultiplexPacket> packet,
222     const base::Closure& done_task) {
223   DCHECK_EQ(packet->channel_id(), receive_id_);
224   if (packet->data().size() > 0) {
225     pending_packets_.push_back(new PendingPacket(packet.Pass(), done_task));
226     if (socket_) {
227       // Notify the socket that we have more data.
228       socket_->OnPacketReceived();
229     }
230   }
231 }
232
233 void ChannelMultiplexer::MuxChannel::OnWriteFailed() {
234   if (socket_)
235     socket_->OnWriteFailed();
236 }
237
238 void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() {
239   DCHECK(socket_);
240   socket_ = NULL;
241 }
242
243 bool ChannelMultiplexer::MuxChannel::DoWrite(
244     scoped_ptr<MultiplexPacket> packet,
245     const base::Closure& done_task) {
246   packet->set_channel_id(send_id_);
247   if (!id_sent_) {
248     packet->set_channel_name(name_);
249     id_sent_ = true;
250   }
251   return multiplexer_->DoWrite(packet.Pass(), done_task);
252 }
253
254 int ChannelMultiplexer::MuxChannel::DoRead(net::IOBuffer* buffer,
255                                            int buffer_len) {
256   int pos = 0;
257   while (buffer_len > 0 && !pending_packets_.empty()) {
258     DCHECK(!pending_packets_.front()->is_empty());
259     int result = pending_packets_.front()->Read(
260         buffer->data() + pos, buffer_len);
261     DCHECK_LE(result, buffer_len);
262     pos += result;
263     buffer_len -= pos;
264     if (pending_packets_.front()->is_empty()) {
265       delete pending_packets_.front();
266       pending_packets_.erase(pending_packets_.begin());
267     }
268   }
269   return pos;
270 }
271
272 ChannelMultiplexer::MuxSocket::MuxSocket(MuxChannel* channel)
273     : channel_(channel),
274       read_buffer_size_(0),
275       write_pending_(false),
276       write_result_(0) {
277 }
278
279 ChannelMultiplexer::MuxSocket::~MuxSocket() {
280   channel_->OnSocketDestroyed();
281 }
282
283 int ChannelMultiplexer::MuxSocket::Read(
284     net::IOBuffer* buffer, int buffer_len,
285     const net::CompletionCallback& callback) {
286   DCHECK(CalledOnValidThread());
287   DCHECK(read_callback_.is_null());
288
289   int result = channel_->DoRead(buffer, buffer_len);
290   if (result == 0) {
291     read_buffer_ = buffer;
292     read_buffer_size_ = buffer_len;
293     read_callback_ = callback;
294     return net::ERR_IO_PENDING;
295   }
296   return result;
297 }
298
299 int ChannelMultiplexer::MuxSocket::Write(
300     net::IOBuffer* buffer, int buffer_len,
301     const net::CompletionCallback& callback) {
302   DCHECK(CalledOnValidThread());
303
304   scoped_ptr<MultiplexPacket> packet(new MultiplexPacket());
305   size_t size = std::min(kMaxPacketSize, buffer_len);
306   packet->mutable_data()->assign(buffer->data(), size);
307
308   write_pending_ = true;
309   bool result = channel_->DoWrite(packet.Pass(), base::Bind(
310       &ChannelMultiplexer::MuxSocket::OnWriteComplete, AsWeakPtr()));
311
312   if (!result) {
313     // Cannot complete the write, e.g. if the connection has been terminated.
314     return net::ERR_FAILED;
315   }
316
317   // OnWriteComplete() might be called above synchronously.
318   if (write_pending_) {
319     DCHECK(write_callback_.is_null());
320     write_callback_ = callback;
321     write_result_ = size;
322     return net::ERR_IO_PENDING;
323   }
324
325   return size;
326 }
327
328 void ChannelMultiplexer::MuxSocket::OnWriteComplete() {
329   write_pending_ = false;
330   if (!write_callback_.is_null()) {
331     net::CompletionCallback cb;
332     std::swap(cb, write_callback_);
333     cb.Run(write_result_);
334   }
335 }
336
337 void ChannelMultiplexer::MuxSocket::OnWriteFailed() {
338   if (!write_callback_.is_null()) {
339     net::CompletionCallback cb;
340     std::swap(cb, write_callback_);
341     cb.Run(net::ERR_FAILED);
342   }
343 }
344
345 void ChannelMultiplexer::MuxSocket::OnPacketReceived() {
346   if (!read_callback_.is_null()) {
347     int result = channel_->DoRead(read_buffer_.get(), read_buffer_size_);
348     read_buffer_ = NULL;
349     DCHECK_GT(result, 0);
350     net::CompletionCallback cb;
351     std::swap(cb, read_callback_);
352     cb.Run(result);
353   }
354 }
355
356 ChannelMultiplexer::ChannelMultiplexer(StreamChannelFactory* factory,
357                                        const std::string& base_channel_name)
358     : base_channel_factory_(factory),
359       base_channel_name_(base_channel_name),
360       next_channel_id_(0),
361       weak_factory_(this) {
362 }
363
364 ChannelMultiplexer::~ChannelMultiplexer() {
365   DCHECK(pending_channels_.empty());
366   STLDeleteValues(&channels_);
367
368   // Cancel creation of the base channel if it hasn't finished.
369   if (base_channel_factory_)
370     base_channel_factory_->CancelChannelCreation(base_channel_name_);
371 }
372
373 void ChannelMultiplexer::CreateChannel(const std::string& name,
374                                        const ChannelCreatedCallback& callback) {
375   if (base_channel_.get()) {
376     // Already have |base_channel_|. Create new multiplexed channel
377     // synchronously.
378     callback.Run(GetOrCreateChannel(name)->CreateSocket());
379   } else if (!base_channel_.get() && !base_channel_factory_) {
380     // Fail synchronously if we failed to create |base_channel_|.
381     callback.Run(scoped_ptr<net::StreamSocket>());
382   } else {
383     // Still waiting for the |base_channel_|.
384     pending_channels_.push_back(PendingChannel(name, callback));
385
386     // If this is the first multiplexed channel then create the base channel.
387     if (pending_channels_.size() == 1U) {
388       base_channel_factory_->CreateChannel(
389           base_channel_name_,
390           base::Bind(&ChannelMultiplexer::OnBaseChannelReady,
391                      base::Unretained(this)));
392     }
393   }
394 }
395
396 void ChannelMultiplexer::CancelChannelCreation(const std::string& name) {
397   for (std::list<PendingChannel>::iterator it = pending_channels_.begin();
398        it != pending_channels_.end(); ++it) {
399     if (it->name == name) {
400       pending_channels_.erase(it);
401       return;
402     }
403   }
404 }
405
406 void ChannelMultiplexer::OnBaseChannelReady(
407     scoped_ptr<net::StreamSocket> socket) {
408   base_channel_factory_ = NULL;
409   base_channel_ = socket.Pass();
410
411   if (base_channel_.get()) {
412     // Initialize reader and writer.
413     reader_.Init(base_channel_.get(),
414                  base::Bind(&ChannelMultiplexer::OnIncomingPacket,
415                             base::Unretained(this)));
416     writer_.Init(base_channel_.get(),
417                  base::Bind(&ChannelMultiplexer::OnWriteFailed,
418                             base::Unretained(this)));
419   }
420
421   DoCreatePendingChannels();
422 }
423
424 void ChannelMultiplexer::DoCreatePendingChannels() {
425   if (pending_channels_.empty())
426     return;
427
428   // Every time this function is called it connects a single channel and posts a
429   // separate task to connect other channels. This is necessary because the
430   // callback may destroy the multiplexer or somehow else modify
431   // |pending_channels_| list (e.g. call CancelChannelCreation()).
432   base::ThreadTaskRunnerHandle::Get()->PostTask(
433       FROM_HERE, base::Bind(&ChannelMultiplexer::DoCreatePendingChannels,
434                             weak_factory_.GetWeakPtr()));
435
436   PendingChannel c = pending_channels_.front();
437   pending_channels_.erase(pending_channels_.begin());
438   scoped_ptr<net::StreamSocket> socket;
439   if (base_channel_.get())
440     socket = GetOrCreateChannel(c.name)->CreateSocket();
441   c.callback.Run(socket.Pass());
442 }
443
444 ChannelMultiplexer::MuxChannel* ChannelMultiplexer::GetOrCreateChannel(
445     const std::string& name) {
446   // Check if we already have a channel with the requested name.
447   std::map<std::string, MuxChannel*>::iterator it = channels_.find(name);
448   if (it != channels_.end())
449     return it->second;
450
451   // Create a new channel if we haven't found existing one.
452   MuxChannel* channel = new MuxChannel(this, name, next_channel_id_);
453   ++next_channel_id_;
454   channels_[channel->name()] = channel;
455   return channel;
456 }
457
458
459 void ChannelMultiplexer::OnWriteFailed(int error) {
460   for (std::map<std::string, MuxChannel*>::iterator it = channels_.begin();
461        it != channels_.end(); ++it) {
462     base::ThreadTaskRunnerHandle::Get()->PostTask(
463         FROM_HERE, base::Bind(&ChannelMultiplexer::NotifyWriteFailed,
464                               weak_factory_.GetWeakPtr(), it->second->name()));
465   }
466 }
467
468 void ChannelMultiplexer::NotifyWriteFailed(const std::string& name) {
469   std::map<std::string, MuxChannel*>::iterator it = channels_.find(name);
470   if (it != channels_.end()) {
471     it->second->OnWriteFailed();
472   }
473 }
474
475 void ChannelMultiplexer::OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
476                                           const base::Closure& done_task) {
477   DCHECK(packet->has_channel_id());
478   if (!packet->has_channel_id()) {
479     LOG(ERROR) << "Received packet without channel_id.";
480     done_task.Run();
481     return;
482   }
483
484   int receive_id = packet->channel_id();
485   MuxChannel* channel = NULL;
486   std::map<int, MuxChannel*>::iterator it =
487       channels_by_receive_id_.find(receive_id);
488   if (it != channels_by_receive_id_.end()) {
489     channel = it->second;
490   } else {
491     // This is a new |channel_id| we haven't seen before. Look it up by name.
492     if (!packet->has_channel_name()) {
493       LOG(ERROR) << "Received packet with unknown channel_id and "
494           "without channel_name.";
495       done_task.Run();
496       return;
497     }
498     channel = GetOrCreateChannel(packet->channel_name());
499     channel->set_receive_id(receive_id);
500     channels_by_receive_id_[receive_id] = channel;
501   }
502
503   channel->OnIncomingPacket(packet.Pass(), done_task);
504 }
505
506 bool ChannelMultiplexer::DoWrite(scoped_ptr<MultiplexPacket> packet,
507                                  const base::Closure& done_task) {
508   return writer_.Write(SerializeAndFrameMessage(*packet), done_task);
509 }
510
511 }  // namespace protocol
512 }  // namespace remoting