Upstream version 9.38.198.0
[platform/framework/web/crosswalk.git] / src / net / socket / socket_test_util.cc
index f993801..3498c13 100644 (file)
@@ -10,6 +10,7 @@
 #include "base/basictypes.h"
 #include "base/bind.h"
 #include "base/bind_helpers.h"
+#include "base/callback_helpers.h"
 #include "base/compiler_specific.h"
 #include "base/message_loop/message_loop.h"
 #include "base/run_loop.h"
@@ -277,7 +278,9 @@ SSLSocketDataProvider::SSLSocketDataProvider(IoMode mode, int result)
       client_cert_sent(false),
       cert_request_info(NULL),
       channel_id_sent(false),
-      connection_status(0) {
+      connection_status(0),
+      should_pause_on_connect(false),
+      is_in_session_cache(false) {
   SSLConnectionStatusSetVersion(SSL_CONNECTION_VERSION_TLS1_2,
                                 &connection_status);
   // Set to TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305
@@ -699,10 +702,13 @@ scoped_ptr<SSLClientSocket> MockClientSocketFactory::CreateSSLClientSocket(
     const HostPortPair& host_and_port,
     const SSLConfig& ssl_config,
     const SSLClientSocketContext& context) {
-  return scoped_ptr<SSLClientSocket>(
+  scoped_ptr<MockSSLClientSocket> socket(
       new MockSSLClientSocket(transport_socket.Pass(),
-                              host_and_port, ssl_config,
+                              host_and_port,
+                              ssl_config,
                               mock_ssl_data_.GetNext()));
+  ssl_client_sockets_.push_back(socket.get());
+  return socket.PassAs<SSLClientSocket>();
 }
 
 void MockClientSocketFactory::ClearSSLSessionCache() {
@@ -758,6 +764,15 @@ const BoundNetLog& MockClientSocket::NetLog() const {
   return net_log_;
 }
 
+bool MockClientSocket::InSessionCache() const {
+  NOTIMPLEMENTED();
+  return false;
+}
+
+void MockClientSocket::SetHandshakeCompletionCallback(const base::Closure& cb) {
+  NOTIMPLEMENTED();
+}
+
 void MockClientSocket::GetSSLCertRequestInfo(
   SSLCertRequestInfo* cert_request_info) {
 }
@@ -776,15 +791,14 @@ int MockClientSocket::GetTLSUniqueChannelBinding(std::string* out) {
   return OK;
 }
 
-ServerBoundCertService* MockClientSocket::GetServerBoundCertService() const {
+ChannelIDService* MockClientSocket::GetChannelIDService() const {
   NOTREACHED();
   return NULL;
 }
 
 SSLClientSocket::NextProtoStatus
-MockClientSocket::GetNextProto(std::string* proto, std::string* server_protos) {
+MockClientSocket::GetNextProto(std::string* proto) {
   proto->clear();
-  server_protos->clear();
   return SSLClientSocket::kNextProtoUnsupported;
 }
 
@@ -1298,31 +1312,24 @@ void DeterministicMockTCPClientSocket::OnReadComplete(const MockRead& data) {}
 void DeterministicMockTCPClientSocket::OnConnectComplete(
     const MockConnect& data) {}
 
-// static
-void MockSSLClientSocket::ConnectCallback(
-    MockSSLClientSocket* ssl_client_socket,
-    const CompletionCallback& callback,
-    int rv) {
-  if (rv == OK)
-    ssl_client_socket->connected_ = true;
-  callback.Run(rv);
-}
-
 MockSSLClientSocket::MockSSLClientSocket(
     scoped_ptr<ClientSocketHandle> transport_socket,
     const HostPortPair& host_port_pair,
     const SSLConfig& ssl_config,
     SSLSocketDataProvider* data)
     : MockClientSocket(
-         // Have to use the right BoundNetLog for LoadTimingInfo regression
-         // tests.
-         transport_socket->socket()->NetLog()),
+          // Have to use the right BoundNetLog for LoadTimingInfo regression
+          // tests.
+          transport_socket->socket()->NetLog()),
       transport_(transport_socket.Pass()),
       data_(data),
       is_npn_state_set_(false),
       new_npn_value_(false),
       is_protocol_negotiated_set_(false),
-      protocol_negotiated_(kProtoUnknown) {
+      protocol_negotiated_(kProtoUnknown),
+      next_connect_state_(STATE_NONE),
+      reached_connect_(false),
+      weak_factory_(this) {
   DCHECK(data_);
   peer_addr_ = data->connect.peer_addr;
 }
@@ -1342,28 +1349,23 @@ int MockSSLClientSocket::Write(IOBuffer* buf, int buf_len,
 }
 
 int MockSSLClientSocket::Connect(const CompletionCallback& callback) {
-  int rv = transport_->socket()->Connect(
-      base::Bind(&ConnectCallback, base::Unretained(this), callback));
-  if (rv == OK) {
-    if (data_->connect.result == OK)
-      connected_ = true;
-    if (data_->connect.mode == ASYNC) {
-      RunCallbackAsync(callback, data_->connect.result);
-      return ERR_IO_PENDING;
-    }
-    return data_->connect.result;
-  }
+  next_connect_state_ = STATE_SSL_CONNECT;
+  reached_connect_ = true;
+  int rv = DoConnectLoop(OK);
+  if (rv == ERR_IO_PENDING)
+    connect_callback_ = callback;
   return rv;
 }
 
 void MockSSLClientSocket::Disconnect() {
+  weak_factory_.InvalidateWeakPtrs();
   MockClientSocket::Disconnect();
   if (transport_->socket() != NULL)
     transport_->socket()->Disconnect();
 }
 
 bool MockSSLClientSocket::IsConnected() const {
-  return transport_->socket()->IsConnected();
+  return transport_->socket()->IsConnected() && connected_;
 }
 
 bool MockSSLClientSocket::WasEverUsed() const {
@@ -1387,6 +1389,15 @@ bool MockSSLClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
   return true;
 }
 
+bool MockSSLClientSocket::InSessionCache() const {
+  return data_->is_in_session_cache;
+}
+
+void MockSSLClientSocket::SetHandshakeCompletionCallback(
+    const base::Closure& cb) {
+  handshake_completion_callback_ = cb;
+}
+
 void MockSSLClientSocket::GetSSLCertRequestInfo(
     SSLCertRequestInfo* cert_request_info) {
   DCHECK(cert_request_info);
@@ -1400,9 +1411,8 @@ void MockSSLClientSocket::GetSSLCertRequestInfo(
 }
 
 SSLClientSocket::NextProtoStatus MockSSLClientSocket::GetNextProto(
-    std::string* proto, std::string* server_protos) {
+    std::string* proto) {
   *proto = data_->next_proto;
-  *server_protos = data_->server_protos;
   return data_->next_proto_status;
 }
 
@@ -1437,8 +1447,8 @@ void MockSSLClientSocket::set_channel_id_sent(bool channel_id_sent) {
   data_->channel_id_sent = channel_id_sent;
 }
 
-ServerBoundCertService* MockSSLClientSocket::GetServerBoundCertService() const {
-  return data_->server_bound_cert_service;
+ChannelIDService* MockSSLClientSocket::GetChannelIDService() const {
+  return data_->channel_id_service;
 }
 
 void MockSSLClientSocket::OnReadComplete(const MockRead& data) {
@@ -1449,6 +1459,69 @@ void MockSSLClientSocket::OnConnectComplete(const MockConnect& data) {
   NOTIMPLEMENTED();
 }
 
+void MockSSLClientSocket::RestartPausedConnect() {
+  DCHECK(data_->should_pause_on_connect);
+  DCHECK_EQ(next_connect_state_, STATE_SSL_CONNECT_COMPLETE);
+  OnIOComplete(data_->connect.result);
+}
+
+void MockSSLClientSocket::OnIOComplete(int result) {
+  int rv = DoConnectLoop(result);
+  if (rv != ERR_IO_PENDING)
+    base::ResetAndReturn(&connect_callback_).Run(rv);
+}
+
+int MockSSLClientSocket::DoConnectLoop(int result) {
+  DCHECK_NE(next_connect_state_, STATE_NONE);
+
+  int rv = result;
+  do {
+    ConnectState state = next_connect_state_;
+    next_connect_state_ = STATE_NONE;
+    switch (state) {
+      case STATE_SSL_CONNECT:
+        rv = DoSSLConnect();
+        break;
+      case STATE_SSL_CONNECT_COMPLETE:
+        rv = DoSSLConnectComplete(rv);
+        break;
+      default:
+        NOTREACHED() << "bad state";
+        rv = ERR_UNEXPECTED;
+        break;
+    }
+  } while (rv != ERR_IO_PENDING && next_connect_state_ != STATE_NONE);
+
+  return rv;
+}
+
+int MockSSLClientSocket::DoSSLConnect() {
+  next_connect_state_ = STATE_SSL_CONNECT_COMPLETE;
+
+  if (data_->should_pause_on_connect)
+    return ERR_IO_PENDING;
+
+  if (data_->connect.mode == ASYNC) {
+    base::MessageLoop::current()->PostTask(
+        FROM_HERE,
+        base::Bind(&MockSSLClientSocket::OnIOComplete,
+                   weak_factory_.GetWeakPtr(),
+                   data_->connect.result));
+    return ERR_IO_PENDING;
+  }
+
+  return data_->connect.result;
+}
+
+int MockSSLClientSocket::DoSSLConnectComplete(int result) {
+  if (result == OK)
+    connected_ = true;
+
+  if (!handshake_completion_callback_.is_null())
+    base::ResetAndReturn(&handshake_completion_callback_).Run();
+  return result;
+}
+
 MockUDPClientSocket::MockUDPClientSocket(SocketDataProvider* data,
                                          net::NetLog* net_log)
     : connected_(false),