- add sources.
[platform/framework/web/crosswalk.git] / src / chrome / test / chromedriver / net / websocket.cc
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "chrome/test/chromedriver/net/websocket.h"
6
7 #include <string.h>
8
9 #include "base/base64.h"
10 #include "base/bind.h"
11 #include "base/bind_helpers.h"
12 #include "base/memory/scoped_vector.h"
13 #include "base/rand_util.h"
14 #include "base/sha1.h"
15 #include "base/strings/string_number_conversions.h"
16 #include "base/strings/stringprintf.h"
17 #include "net/base/address_list.h"
18 #include "net/base/io_buffer.h"
19 #include "net/base/ip_endpoint.h"
20 #include "net/base/net_errors.h"
21 #include "net/base/net_util.h"
22 #include "net/base/sys_addrinfo.h"
23 #include "net/http/http_response_headers.h"
24 #include "net/http/http_util.h"
25 #include "net/websockets/websocket_frame.h"
26
27 #if defined(OS_WIN)
28 #include <Winsock2.h>
29 #endif
30
31 namespace {
32
33 bool ResolveHost(const std::string& host, net::IPAddressNumber* address) {
34   struct addrinfo hints;
35   memset(&hints, 0, sizeof(hints));
36   hints.ai_family = AF_UNSPEC;
37   hints.ai_socktype = SOCK_STREAM;
38
39   struct addrinfo* result;
40   if (getaddrinfo(host.c_str(), NULL, &hints, &result))
41     return false;
42
43   for (struct addrinfo* addr = result; addr; addr = addr->ai_next) {
44     if (addr->ai_family == AF_INET || addr->ai_family == AF_INET6) {
45       net::IPEndPoint end_point;
46       if (!end_point.FromSockAddr(addr->ai_addr, addr->ai_addrlen)) {
47         freeaddrinfo(result);
48         return false;
49       }
50       *address = end_point.address();
51     }
52   }
53   freeaddrinfo(result);
54   return true;
55 }
56
57 }  // namespace
58
59 WebSocket::WebSocket(const GURL& url, WebSocketListener* listener)
60     : url_(url),
61       listener_(listener),
62       state_(INITIALIZED),
63       write_buffer_(new net::DrainableIOBuffer(new net::IOBuffer(0), 0)),
64       read_buffer_(new net::IOBufferWithSize(4096)) {}
65
66 WebSocket::~WebSocket() {
67   CHECK(thread_checker_.CalledOnValidThread());
68 }
69
70 void WebSocket::Connect(const net::CompletionCallback& callback) {
71   CHECK(thread_checker_.CalledOnValidThread());
72   CHECK_EQ(INITIALIZED, state_);
73
74   net::IPAddressNumber address;
75   if (!net::ParseIPLiteralToNumber(url_.HostNoBrackets(), &address)) {
76     if (!ResolveHost(url_.HostNoBrackets(), &address)) {
77       callback.Run(net::ERR_ADDRESS_UNREACHABLE);
78       return;
79     }
80   }
81   int port = 80;
82   base::StringToInt(url_.port(), &port);
83   net::AddressList addresses(net::IPEndPoint(address, port));
84   net::NetLog::Source source;
85   socket_.reset(new net::TCPClientSocket(addresses, NULL, source));
86
87   state_ = CONNECTING;
88   connect_callback_ = callback;
89   int code = socket_->Connect(base::Bind(
90       &WebSocket::OnSocketConnect, base::Unretained(this)));
91   if (code != net::ERR_IO_PENDING)
92     OnSocketConnect(code);
93 }
94
95 bool WebSocket::Send(const std::string& message) {
96   CHECK(thread_checker_.CalledOnValidThread());
97   if (state_ != OPEN)
98     return false;
99
100   net::WebSocketFrameHeader header(net::WebSocketFrameHeader::kOpCodeText);
101   header.final = true;
102   header.masked = true;
103   header.payload_length = message.length();
104   int header_size = net::GetWebSocketFrameHeaderSize(header);
105   net::WebSocketMaskingKey masking_key = net::GenerateWebSocketMaskingKey();
106   std::string header_str;
107   header_str.resize(header_size);
108   CHECK_EQ(header_size, net::WriteWebSocketFrameHeader(
109       header, &masking_key, &header_str[0], header_str.length()));
110
111   std::string masked_message = message;
112   net::MaskWebSocketFramePayload(
113       masking_key, 0, &masked_message[0], masked_message.length());
114   Write(header_str + masked_message);
115   return true;
116 }
117
118 void WebSocket::OnSocketConnect(int code) {
119   if (code != net::OK) {
120     Close(code);
121     return;
122   }
123
124   CHECK(base::Base64Encode(base::RandBytesAsString(16), &sec_key_));
125   std::string handshake = base::StringPrintf(
126       "GET %s HTTP/1.1\r\n"
127       "Host: %s\r\n"
128       "Upgrade: websocket\r\n"
129       "Connection: Upgrade\r\n"
130       "Sec-WebSocket-Key: %s\r\n"
131       "Sec-WebSocket-Version: 13\r\n"
132       "Pragma: no-cache\r\n"
133       "Cache-Control: no-cache\r\n"
134       "\r\n",
135       url_.path().c_str(),
136       url_.host().c_str(),
137       sec_key_.c_str());
138   Write(handshake);
139   Read();
140 }
141
142 void WebSocket::Write(const std::string& data) {
143   pending_write_ += data;
144   if (!write_buffer_->BytesRemaining())
145     ContinueWritingIfNecessary();
146 }
147
148 void WebSocket::OnWrite(int code) {
149   if (!socket_->IsConnected()) {
150     // Supposedly if |StreamSocket| is closed, the error code may be undefined.
151     Close(net::ERR_FAILED);
152     return;
153   }
154   if (code < 0) {
155     Close(code);
156     return;
157   }
158
159   write_buffer_->DidConsume(code);
160   ContinueWritingIfNecessary();
161 }
162
163 void WebSocket::ContinueWritingIfNecessary() {
164   if (!write_buffer_->BytesRemaining()) {
165     if (pending_write_.empty())
166       return;
167     write_buffer_ = new net::DrainableIOBuffer(
168         new net::StringIOBuffer(pending_write_),
169         pending_write_.length());
170     pending_write_.clear();
171   }
172   int code =
173       socket_->Write(write_buffer_.get(),
174                      write_buffer_->BytesRemaining(),
175                      base::Bind(&WebSocket::OnWrite, base::Unretained(this)));
176   if (code != net::ERR_IO_PENDING)
177     OnWrite(code);
178 }
179
180 void WebSocket::Read() {
181   int code =
182       socket_->Read(read_buffer_.get(),
183                     read_buffer_->size(),
184                     base::Bind(&WebSocket::OnRead, base::Unretained(this)));
185   if (code != net::ERR_IO_PENDING)
186     OnRead(code);
187 }
188
189 void WebSocket::OnRead(int code) {
190   if (code <= 0) {
191     Close(code ? code : net::ERR_FAILED);
192     return;
193   }
194
195   if (state_ == CONNECTING)
196     OnReadDuringHandshake(read_buffer_->data(), code);
197   else if (state_ == OPEN)
198     OnReadDuringOpen(read_buffer_->data(), code);
199
200   if (state_ != CLOSED)
201     Read();
202 }
203
204 void WebSocket::OnReadDuringHandshake(const char* data, int len) {
205   handshake_response_ += std::string(data, len);
206   int headers_end = net::HttpUtil::LocateEndOfHeaders(
207       handshake_response_.data(), handshake_response_.size(), 0);
208   if (headers_end == -1)
209     return;
210
211   const char kMagicKey[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
212   std::string websocket_accept;
213   CHECK(base::Base64Encode(base::SHA1HashString(sec_key_ + kMagicKey),
214                            &websocket_accept));
215   scoped_refptr<net::HttpResponseHeaders> headers(
216       new net::HttpResponseHeaders(
217           net::HttpUtil::AssembleRawHeaders(
218               handshake_response_.data(), headers_end)));
219   if (headers->response_code() != 101 ||
220       !headers->HasHeaderValue("Upgrade", "WebSocket") ||
221       !headers->HasHeaderValue("Connection", "Upgrade") ||
222       !headers->HasHeaderValue("Sec-WebSocket-Accept", websocket_accept)) {
223     Close(net::ERR_FAILED);
224     return;
225   }
226   std::string leftover_message = handshake_response_.substr(headers_end);
227   handshake_response_.clear();
228   sec_key_.clear();
229   state_ = OPEN;
230   InvokeConnectCallback(net::OK);
231   if (!leftover_message.empty())
232     OnReadDuringOpen(leftover_message.c_str(), leftover_message.length());
233 }
234
235 void WebSocket::OnReadDuringOpen(const char* data, int len) {
236   ScopedVector<net::WebSocketFrameChunk> frame_chunks;
237   CHECK(parser_.Decode(data, len, &frame_chunks));
238   for (size_t i = 0; i < frame_chunks.size(); ++i) {
239     scoped_refptr<net::IOBufferWithSize> buffer = frame_chunks[i]->data;
240     if (buffer.get())
241       next_message_ += std::string(buffer->data(), buffer->size());
242     if (frame_chunks[i]->final_chunk) {
243       listener_->OnMessageReceived(next_message_);
244       next_message_.clear();
245     }
246   }
247 }
248
249 void WebSocket::InvokeConnectCallback(int code) {
250   net::CompletionCallback temp = connect_callback_;
251   connect_callback_.Reset();
252   CHECK(!temp.is_null());
253   temp.Run(code);
254 }
255
256 void WebSocket::Close(int code) {
257   socket_->Disconnect();
258   if (!connect_callback_.is_null())
259     InvokeConnectCallback(code);
260   if (state_ == OPEN)
261     listener_->OnClose();
262
263   state_ = CLOSED;
264 }
265
266