3d2bcdfb742421ecfbfd71f592bf59a544dfe664
[platform/framework/web/crosswalk.git] / src / net / websockets / websocket_basic_handshake_stream.cc
1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/websockets/websocket_basic_handshake_stream.h"
6
7 #include <algorithm>
8 #include <iterator>
9 #include <string>
10 #include <vector>
11
12 #include "base/base64.h"
13 #include "base/basictypes.h"
14 #include "base/bind.h"
15 #include "base/containers/hash_tables.h"
16 #include "base/stl_util.h"
17 #include "base/strings/string_util.h"
18 #include "base/strings/stringprintf.h"
19 #include "base/time/time.h"
20 #include "crypto/random.h"
21 #include "net/http/http_request_headers.h"
22 #include "net/http/http_request_info.h"
23 #include "net/http/http_response_body_drainer.h"
24 #include "net/http/http_response_headers.h"
25 #include "net/http/http_status_code.h"
26 #include "net/http/http_stream_parser.h"
27 #include "net/socket/client_socket_handle.h"
28 #include "net/websockets/websocket_basic_stream.h"
29 #include "net/websockets/websocket_extension_parser.h"
30 #include "net/websockets/websocket_handshake_constants.h"
31 #include "net/websockets/websocket_handshake_handler.h"
32 #include "net/websockets/websocket_handshake_request_info.h"
33 #include "net/websockets/websocket_handshake_response_info.h"
34 #include "net/websockets/websocket_stream.h"
35
36 namespace net {
37 namespace {
38
// Outcome of looking up a response header that is required to carry
// exactly one value.
enum GetHeaderResult {
  GET_HEADER_OK = 0,       // Exactly one value was found.
  GET_HEADER_MISSING = 1,  // No value was found.
  GET_HEADER_MULTIPLE = 2  // More than one value was found.
};
44
// Builds the failure text reported when |header_name| is absent from the
// handshake response.
std::string MissingHeaderMessage(const std::string& header_name) {
  std::string message("'");
  message += header_name;
  message += "' header is missing";
  return message;
}
48
// Builds the failure text reported when |header_name| occurs more than once
// in the handshake response.
std::string MultipleHeaderValuesMessage(const std::string& header_name) {
  std::string message("'");
  message.append(header_name);
  message.append("' header must not appear more than once in a response");
  return message;
}
55
56 std::string GenerateHandshakeChallenge() {
57   std::string raw_challenge(websockets::kRawChallengeLength, '\0');
58   crypto::RandBytes(string_as_array(&raw_challenge), raw_challenge.length());
59   std::string encoded_challenge;
60   base::Base64Encode(raw_challenge, &encoded_challenge);
61   return encoded_challenge;
62 }
63
64 void AddVectorHeaderIfNonEmpty(const char* name,
65                                const std::vector<std::string>& value,
66                                HttpRequestHeaders* headers) {
67   if (value.empty())
68     return;
69   headers->SetHeader(name, JoinString(value, ", "));
70 }
71
72 GetHeaderResult GetSingleHeaderValue(const HttpResponseHeaders* headers,
73                                      const base::StringPiece& name,
74                                      std::string* value) {
75   void* state = NULL;
76   size_t num_values = 0;
77   std::string temp_value;
78   while (headers->EnumerateHeader(&state, name, &temp_value)) {
79     if (++num_values > 1)
80       return GET_HEADER_MULTIPLE;
81     *value = temp_value;
82   }
83   return num_values > 0 ? GET_HEADER_OK : GET_HEADER_MISSING;
84 }
85
86 bool ValidateHeaderHasSingleValue(GetHeaderResult result,
87                                   const std::string& header_name,
88                                   std::string* failure_message) {
89   if (result == GET_HEADER_MISSING) {
90     *failure_message = MissingHeaderMessage(header_name);
91     return false;
92   }
93   if (result == GET_HEADER_MULTIPLE) {
94     *failure_message = MultipleHeaderValuesMessage(header_name);
95     return false;
96   }
97   DCHECK_EQ(result, GET_HEADER_OK);
98   return true;
99 }
100
101 bool ValidateUpgrade(const HttpResponseHeaders* headers,
102                      std::string* failure_message) {
103   std::string value;
104   GetHeaderResult result =
105       GetSingleHeaderValue(headers, websockets::kUpgrade, &value);
106   if (!ValidateHeaderHasSingleValue(result,
107                                     websockets::kUpgrade,
108                                     failure_message)) {
109     return false;
110   }
111
112   if (!LowerCaseEqualsASCII(value, websockets::kWebSocketLowercase)) {
113     *failure_message =
114         "'Upgrade' header value is not 'WebSocket': " + value;
115     return false;
116   }
117   return true;
118 }
119
120 bool ValidateSecWebSocketAccept(const HttpResponseHeaders* headers,
121                                 const std::string& expected,
122                                 std::string* failure_message) {
123   std::string actual;
124   GetHeaderResult result =
125       GetSingleHeaderValue(headers, websockets::kSecWebSocketAccept, &actual);
126   if (!ValidateHeaderHasSingleValue(result,
127                                     websockets::kSecWebSocketAccept,
128                                     failure_message)) {
129     return false;
130   }
131
132   if (expected != actual) {
133     *failure_message = "Incorrect 'Sec-WebSocket-Accept' header value";
134     return false;
135   }
136   return true;
137 }
138
139 bool ValidateConnection(const HttpResponseHeaders* headers,
140                         std::string* failure_message) {
141   // Connection header is permitted to contain other tokens.
142   if (!headers->HasHeader(HttpRequestHeaders::kConnection)) {
143     *failure_message = MissingHeaderMessage(HttpRequestHeaders::kConnection);
144     return false;
145   }
146   if (!headers->HasHeaderValue(HttpRequestHeaders::kConnection,
147                                websockets::kUpgrade)) {
148     *failure_message = "'Connection' header value must contain 'Upgrade'";
149     return false;
150   }
151   return true;
152 }
153
154 bool ValidateSubProtocol(
155     const HttpResponseHeaders* headers,
156     const std::vector<std::string>& requested_sub_protocols,
157     std::string* sub_protocol,
158     std::string* failure_message) {
159   void* state = NULL;
160   std::string value;
161   base::hash_set<std::string> requested_set(requested_sub_protocols.begin(),
162                                             requested_sub_protocols.end());
163   int count = 0;
164   bool has_multiple_protocols = false;
165   bool has_invalid_protocol = false;
166
167   while (!has_invalid_protocol || !has_multiple_protocols) {
168     std::string temp_value;
169     if (!headers->EnumerateHeader(
170             &state, websockets::kSecWebSocketProtocol, &temp_value))
171       break;
172     value = temp_value;
173     if (requested_set.count(value) == 0)
174       has_invalid_protocol = true;
175     if (++count > 1)
176       has_multiple_protocols = true;
177   }
178
179   if (has_multiple_protocols) {
180     *failure_message =
181         MultipleHeaderValuesMessage(websockets::kSecWebSocketProtocol);
182     return false;
183   } else if (count > 0 && requested_sub_protocols.size() == 0) {
184     *failure_message =
185         std::string("Response must not include 'Sec-WebSocket-Protocol' "
186                     "header if not present in request: ")
187         + value;
188     return false;
189   } else if (has_invalid_protocol) {
190     *failure_message =
191         "'Sec-WebSocket-Protocol' header value '" +
192         value +
193         "' in response does not match any of sent values";
194     return false;
195   } else if (requested_sub_protocols.size() > 0 && count == 0) {
196     *failure_message =
197         "Sent non-empty 'Sec-WebSocket-Protocol' header "
198         "but no response was received";
199     return false;
200   }
201   *sub_protocol = value;
202   return true;
203 }
204
205 bool ValidateExtensions(const HttpResponseHeaders* headers,
206                         const std::vector<std::string>& requested_extensions,
207                         std::string* extensions,
208                         std::string* failure_message) {
209   void* state = NULL;
210   std::string value;
211   while (headers->EnumerateHeader(
212       &state, websockets::kSecWebSocketExtensions, &value)) {
213     WebSocketExtensionParser parser;
214     parser.Parse(value);
215     if (parser.has_error()) {
216       // TODO(yhirano) Set appropriate failure message.
217       *failure_message =
218           "'Sec-WebSocket-Extensions' header value is "
219           "rejected by the parser: " +
220           value;
221       return false;
222     }
223     // TODO(ricea): Accept permessage-deflate with valid parameters.
224     *failure_message =
225         "Found an unsupported extension '" +
226         parser.extension().name() +
227         "' in 'Sec-WebSocket-Extensions' header";
228     return false;
229   }
230   return true;
231 }
232
233 }  // namespace
234
235 WebSocketBasicHandshakeStream::WebSocketBasicHandshakeStream(
236     scoped_ptr<ClientSocketHandle> connection,
237     WebSocketStream::ConnectDelegate* connect_delegate,
238     bool using_proxy,
239     std::vector<std::string> requested_sub_protocols,
240     std::vector<std::string> requested_extensions)
241     : state_(connection.release(), using_proxy),
242       connect_delegate_(connect_delegate),
243       http_response_info_(NULL),
244       requested_sub_protocols_(requested_sub_protocols),
245       requested_extensions_(requested_extensions) {}
246
247 WebSocketBasicHandshakeStream::~WebSocketBasicHandshakeStream() {}
248
249 int WebSocketBasicHandshakeStream::InitializeStream(
250     const HttpRequestInfo* request_info,
251     RequestPriority priority,
252     const BoundNetLog& net_log,
253     const CompletionCallback& callback) {
254   url_ = request_info->url;
255   state_.Initialize(request_info, priority, net_log, callback);
256   return OK;
257 }
258
259 int WebSocketBasicHandshakeStream::SendRequest(
260     const HttpRequestHeaders& headers,
261     HttpResponseInfo* response,
262     const CompletionCallback& callback) {
263   DCHECK(!headers.HasHeader(websockets::kSecWebSocketKey));
264   DCHECK(!headers.HasHeader(websockets::kSecWebSocketProtocol));
265   DCHECK(!headers.HasHeader(websockets::kSecWebSocketExtensions));
266   DCHECK(headers.HasHeader(HttpRequestHeaders::kOrigin));
267   DCHECK(headers.HasHeader(websockets::kUpgrade));
268   DCHECK(headers.HasHeader(HttpRequestHeaders::kConnection));
269   DCHECK(headers.HasHeader(websockets::kSecWebSocketVersion));
270   DCHECK(parser());
271
272   http_response_info_ = response;
273
274   // Create a copy of the headers object, so that we can add the
275   // Sec-WebSockey-Key header.
276   HttpRequestHeaders enriched_headers;
277   enriched_headers.CopyFrom(headers);
278   std::string handshake_challenge;
279   if (handshake_challenge_for_testing_) {
280     handshake_challenge = *handshake_challenge_for_testing_;
281     handshake_challenge_for_testing_.reset();
282   } else {
283     handshake_challenge = GenerateHandshakeChallenge();
284   }
285   enriched_headers.SetHeader(websockets::kSecWebSocketKey, handshake_challenge);
286
287   AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketProtocol,
288                             requested_sub_protocols_,
289                             &enriched_headers);
290   AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketExtensions,
291                             requested_extensions_,
292                             &enriched_headers);
293
294   ComputeSecWebSocketAccept(handshake_challenge,
295                             &handshake_challenge_response_);
296
297   DCHECK(connect_delegate_);
298   scoped_ptr<WebSocketHandshakeRequestInfo> request(
299       new WebSocketHandshakeRequestInfo(url_, base::Time::Now()));
300   request->headers.CopyFrom(enriched_headers);
301   connect_delegate_->OnStartOpeningHandshake(request.Pass());
302
303   return parser()->SendRequest(
304       state_.GenerateRequestLine(), enriched_headers, response, callback);
305 }
306
307 int WebSocketBasicHandshakeStream::ReadResponseHeaders(
308     const CompletionCallback& callback) {
309   // HttpStreamParser uses a weak pointer when reading from the
310   // socket, so it won't be called back after being destroyed. The
311   // HttpStreamParser is owned by HttpBasicState which is owned by this object,
312   // so this use of base::Unretained() is safe.
313   int rv = parser()->ReadResponseHeaders(
314       base::Bind(&WebSocketBasicHandshakeStream::ReadResponseHeadersCallback,
315                  base::Unretained(this),
316                  callback));
317   if (rv == ERR_IO_PENDING)
318     return rv;
319   if (rv == OK)
320     return ValidateResponse();
321   OnFinishOpeningHandshake();
322   return rv;
323 }
324
325 const HttpResponseInfo* WebSocketBasicHandshakeStream::GetResponseInfo() const {
326   return parser()->GetResponseInfo();
327 }
328
329 int WebSocketBasicHandshakeStream::ReadResponseBody(
330     IOBuffer* buf,
331     int buf_len,
332     const CompletionCallback& callback) {
333   return parser()->ReadResponseBody(buf, buf_len, callback);
334 }
335
336 void WebSocketBasicHandshakeStream::Close(bool not_reusable) {
337   // This class ignores the value of |not_reusable| and never lets the socket be
338   // re-used.
339   if (parser())
340     parser()->Close(true);
341 }
342
343 bool WebSocketBasicHandshakeStream::IsResponseBodyComplete() const {
344   return parser()->IsResponseBodyComplete();
345 }
346
347 bool WebSocketBasicHandshakeStream::CanFindEndOfResponse() const {
348   return parser() && parser()->CanFindEndOfResponse();
349 }
350
351 bool WebSocketBasicHandshakeStream::IsConnectionReused() const {
352   return parser()->IsConnectionReused();
353 }
354
355 void WebSocketBasicHandshakeStream::SetConnectionReused() {
356   parser()->SetConnectionReused();
357 }
358
359 bool WebSocketBasicHandshakeStream::IsConnectionReusable() const {
360   return false;
361 }
362
363 int64 WebSocketBasicHandshakeStream::GetTotalReceivedBytes() const {
364   return 0;
365 }
366
367 bool WebSocketBasicHandshakeStream::GetLoadTimingInfo(
368     LoadTimingInfo* load_timing_info) const {
369   return state_.connection()->GetLoadTimingInfo(IsConnectionReused(),
370                                                 load_timing_info);
371 }
372
373 void WebSocketBasicHandshakeStream::GetSSLInfo(SSLInfo* ssl_info) {
374   parser()->GetSSLInfo(ssl_info);
375 }
376
377 void WebSocketBasicHandshakeStream::GetSSLCertRequestInfo(
378     SSLCertRequestInfo* cert_request_info) {
379   parser()->GetSSLCertRequestInfo(cert_request_info);
380 }
381
382 bool WebSocketBasicHandshakeStream::IsSpdyHttpStream() const { return false; }
383
384 void WebSocketBasicHandshakeStream::Drain(HttpNetworkSession* session) {
385   HttpResponseBodyDrainer* drainer = new HttpResponseBodyDrainer(this);
386   drainer->Start(session);
387   // |drainer| will delete itself.
388 }
389
390 void WebSocketBasicHandshakeStream::SetPriority(RequestPriority priority) {
391   // TODO(ricea): See TODO comment in HttpBasicStream::SetPriority(). If it is
392   // gone, then copy whatever has happened there over here.
393 }
394
395 scoped_ptr<WebSocketStream> WebSocketBasicHandshakeStream::Upgrade() {
396   // TODO(ricea): Add deflate support.
397
398   // The HttpStreamParser object has a pointer to our ClientSocketHandle. Make
399   // sure it does not touch it again before it is destroyed.
400   state_.DeleteParser();
401   return scoped_ptr<WebSocketStream>(
402       new WebSocketBasicStream(state_.ReleaseConnection(),
403                                state_.read_buf(),
404                                sub_protocol_,
405                                extensions_));
406 }
407
408 void WebSocketBasicHandshakeStream::SetWebSocketKeyForTesting(
409     const std::string& key) {
410   handshake_challenge_for_testing_.reset(new std::string(key));
411 }
412
413 std::string WebSocketBasicHandshakeStream::GetFailureMessage() const {
414   return failure_message_;
415 }
416
417 void WebSocketBasicHandshakeStream::ReadResponseHeadersCallback(
418     const CompletionCallback& callback,
419     int result) {
420   if (result == OK)
421     result = ValidateResponse();
422   else
423     OnFinishOpeningHandshake();
424   callback.Run(result);
425 }
426
427 void WebSocketBasicHandshakeStream::OnFinishOpeningHandshake() {
428   DCHECK(connect_delegate_);
429   DCHECK(http_response_info_);
430   scoped_refptr<HttpResponseHeaders> headers = http_response_info_->headers;
431   scoped_ptr<WebSocketHandshakeResponseInfo> response(
432       new WebSocketHandshakeResponseInfo(url_,
433                                          headers->response_code(),
434                                          headers->GetStatusText(),
435                                          headers,
436                                          http_response_info_->response_time));
437   connect_delegate_->OnFinishOpeningHandshake(response.Pass());
438 }
439
440 int WebSocketBasicHandshakeStream::ValidateResponse() {
441   DCHECK(http_response_info_);
442   const scoped_refptr<HttpResponseHeaders>& headers =
443       http_response_info_->headers;
444
445   switch (headers->response_code()) {
446     case HTTP_SWITCHING_PROTOCOLS:
447       OnFinishOpeningHandshake();
448       return ValidateUpgradeResponse(headers);
449
450     // We need to pass these through for authentication to work.
451     case HTTP_UNAUTHORIZED:
452     case HTTP_PROXY_AUTHENTICATION_REQUIRED:
453       return OK;
454
455     // Other status codes are potentially risky (see the warnings in the
456     // WHATWG WebSocket API spec) and so are dropped by default.
457     default:
458       failure_message_ = base::StringPrintf("Unexpected status code: %d",
459                                             headers->response_code());
460       OnFinishOpeningHandshake();
461       return ERR_INVALID_RESPONSE;
462   }
463 }
464
465 int WebSocketBasicHandshakeStream::ValidateUpgradeResponse(
466     const scoped_refptr<HttpResponseHeaders>& headers) {
467   if (ValidateUpgrade(headers.get(), &failure_message_) &&
468       ValidateSecWebSocketAccept(headers.get(),
469                                  handshake_challenge_response_,
470                                  &failure_message_) &&
471       ValidateConnection(headers.get(), &failure_message_) &&
472       ValidateSubProtocol(headers.get(),
473                           requested_sub_protocols_,
474                           &sub_protocol_,
475                           &failure_message_) &&
476       ValidateExtensions(headers.get(),
477                          requested_extensions_,
478                          &extensions_,
479                          &failure_message_)) {
480     return OK;
481   }
482   failure_message_ = "Error during WebSocket handshake: " + failure_message_;
483   return ERR_INVALID_RESPONSE;
484 }
485
486 }  // namespace net