1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/socket/transport_client_socket_pool_test_util.h"
9 #include "base/logging.h"
10 #include "base/memory/weak_ptr.h"
11 #include "base/run_loop.h"
12 #include "net/base/ip_endpoint.h"
13 #include "net/base/load_timing_info.h"
14 #include "net/base/load_timing_info_test_util.h"
15 #include "net/base/net_util.h"
16 #include "net/socket/client_socket_handle.h"
17 #include "net/socket/ssl_client_socket.h"
18 #include "net/udp/datagram_client_socket.h"
19 #include "testing/gtest/include/gtest/gtest.h"
// Parses the IP literal |ip| into an IPAddressNumber. CHECK-fails if |ip| is
// not a valid IP literal, so callers should only pass known-good literals
// (e.g. the fixed test addresses below).
25 IPAddressNumber ParseIP(const std::string& ip) {
26 IPAddressNumber number;
27 CHECK(ParseIPLiteralToNumber(ip, &number));
31 // A StreamSocket which connects synchronously and successfully.
// GetPeerAddress() reports the first entry of the AddressList passed to the
// constructor, and GetLocalAddress() reports a fixed test endpoint of the same
// address family (see SetIPv4Address() / SetIPv6Address()).
// NOTE(review): both address getters call addrlist_.front(), so the
// AddressList is assumed to be non-empty.
// EnableTCPFastOpenIfSupported() only records the request so that
// UsingTCPFastOpen() can report it back to the test.
32 class MockConnectClientSocket : public StreamSocket {
34 MockConnectClientSocket(const AddressList& addrlist, net::NetLog* net_log)
37 net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)),
38 use_tcp_fastopen_(false) {}
40 // StreamSocket implementation.
41 virtual int Connect(const CompletionCallback& callback) OVERRIDE {
45 virtual void Disconnect() OVERRIDE { connected_ = false; }
46 virtual bool IsConnected() const OVERRIDE { return connected_; }
47 virtual bool IsConnectedAndIdle() const OVERRIDE { return connected_; }
49 virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE {
50 *address = addrlist_.front();
53 virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE {
55 return ERR_SOCKET_NOT_CONNECTED;
// Report a local endpoint whose address family matches the peer address.
56 if (addrlist_.front().GetFamily() == ADDRESS_FAMILY_IPV4)
57 SetIPv4Address(address);
59 SetIPv6Address(address);
62 virtual const BoundNetLog& NetLog() const OVERRIDE { return net_log_; }
64 virtual void SetSubresourceSpeculation() OVERRIDE {}
65 virtual void SetOmniboxSpeculation() OVERRIDE {}
66 virtual bool WasEverUsed() const OVERRIDE { return false; }
67 virtual void EnableTCPFastOpenIfSupported() OVERRIDE {
68 use_tcp_fastopen_ = true;
70 virtual bool UsingTCPFastOpen() const OVERRIDE { return use_tcp_fastopen_; }
71 virtual bool WasNpnNegotiated() const OVERRIDE { return false; }
72 virtual NextProto GetNegotiatedProtocol() const OVERRIDE {
75 virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { return false; }
77 // Socket implementation.
78 virtual int Read(IOBuffer* buf,
80 const CompletionCallback& callback) OVERRIDE {
83 virtual int Write(IOBuffer* buf,
85 const CompletionCallback& callback) OVERRIDE {
88 virtual int SetReceiveBufferSize(int32 size) OVERRIDE { return OK; }
89 virtual int SetSendBufferSize(int32 size) OVERRIDE { return OK; }
// Addresses supplied at construction; only front() is consulted.
93 const AddressList addrlist_;
// Set by EnableTCPFastOpenIfSupported(); reported by UsingTCPFastOpen().
95 bool use_tcp_fastopen_;
97 DISALLOW_COPY_AND_ASSIGN(MockConnectClientSocket);
// A StreamSocket whose Connect() always fails synchronously with
// ERR_CONNECTION_FAILED. It never reports itself as connected, and both
// GetPeerAddress() and GetLocalAddress() return ERR_UNEXPECTED.
100 class MockFailingClientSocket : public StreamSocket {
102 MockFailingClientSocket(const AddressList& addrlist, net::NetLog* net_log)
103 : addrlist_(addrlist),
104 net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)),
105 use_tcp_fastopen_(false) {}
107 // StreamSocket implementation.
108 virtual int Connect(const CompletionCallback& callback) OVERRIDE {
109 return ERR_CONNECTION_FAILED;
112 virtual void Disconnect() OVERRIDE {}
114 virtual bool IsConnected() const OVERRIDE { return false; }
115 virtual bool IsConnectedAndIdle() const OVERRIDE { return false; }
116 virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE {
117 return ERR_UNEXPECTED;
119 virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE {
120 return ERR_UNEXPECTED;
122 virtual const BoundNetLog& NetLog() const OVERRIDE { return net_log_; }
124 virtual void SetSubresourceSpeculation() OVERRIDE {}
125 virtual void SetOmniboxSpeculation() OVERRIDE {}
126 virtual bool WasEverUsed() const OVERRIDE { return false; }
127 virtual void EnableTCPFastOpenIfSupported() OVERRIDE {
128 use_tcp_fastopen_ = true;
130 virtual bool UsingTCPFastOpen() const OVERRIDE { return use_tcp_fastopen_; }
131 virtual bool WasNpnNegotiated() const OVERRIDE { return false; }
132 virtual NextProto GetNegotiatedProtocol() const OVERRIDE {
133 return kProtoUnknown;
135 virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { return false; }
137 // Socket implementation.
138 virtual int Read(IOBuffer* buf,
140 const CompletionCallback& callback) OVERRIDE {
144 virtual int Write(IOBuffer* buf,
146 const CompletionCallback& callback) OVERRIDE {
149 virtual int SetReceiveBufferSize(int32 size) OVERRIDE { return OK; }
150 virtual int SetSendBufferSize(int32 size) OVERRIDE { return OK; }
// NOTE(review): addrlist_ is stored but no visible member reads it, since
// both address getters return ERR_UNEXPECTED.
153 const AddressList addrlist_;
154 BoundNetLog net_log_;
// Set by EnableTCPFastOpenIfSupported(); reported by UsingTCPFastOpen().
155 bool use_tcp_fastopen_;
157 DISALLOW_COPY_AND_ASSIGN(MockFailingClientSocket);
// A StreamSocket whose Connect() always returns ERR_IO_PENDING. The pending
// connect completes only when the closure from GetConnectCallback() runs. At
// that point the socket reports OK if it was constructed with
// |should_connect| == true, and ERR_CONNECTION_FAILED otherwise. The static
// Make*ClientSocket() helpers decide when, or whether, that closure is run.
160 class MockTriggerableClientSocket : public StreamSocket {
162 // |should_connect| indicates whether the socket should successfully complete
164 MockTriggerableClientSocket(const AddressList& addrlist,
166 net::NetLog* net_log)
167 : should_connect_(should_connect),
168 is_connected_(false),
170 net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)),
171 use_tcp_fastopen_(false),
172 weak_factory_(this) {}
174 // Call this method to get a closure which will trigger the connect callback
175 // when called. The closure can be called even after the socket is deleted; it
176 // will safely do nothing.
177 base::Closure GetConnectCallback() {
178 return base::Bind(&MockTriggerableClientSocket::DoCallback,
179 weak_factory_.GetWeakPtr());
// Creates a socket and immediately posts its trigger closure to the current
// MessageLoop, so the connect completes asynchronously once that loop runs
// the posted task.
182 static scoped_ptr<StreamSocket> MakeMockPendingClientSocket(
183 const AddressList& addrlist,
185 net::NetLog* net_log) {
186 scoped_ptr<MockTriggerableClientSocket> socket(
187 new MockTriggerableClientSocket(addrlist, should_connect, net_log));
188 base::MessageLoop::current()->PostTask(FROM_HERE,
189 socket->GetConnectCallback());
190 return socket.PassAs<StreamSocket>();
// Creates a socket whose trigger closure is posted to the current
// MessageLoop with the given |delay|.
193 static scoped_ptr<StreamSocket> MakeMockDelayedClientSocket(
194 const AddressList& addrlist,
196 const base::TimeDelta& delay,
197 net::NetLog* net_log) {
198 scoped_ptr<MockTriggerableClientSocket> socket(
199 new MockTriggerableClientSocket(addrlist, should_connect, net_log));
200 base::MessageLoop::current()->PostDelayedTask(
201 FROM_HERE, socket->GetConnectCallback(), delay);
202 return socket.PassAs<StreamSocket>();
// Creates a socket for which this helper posts no trigger task, so its
// Connect() stays pending.
205 static scoped_ptr<StreamSocket> MakeMockStalledClientSocket(
206 const AddressList& addrlist,
207 net::NetLog* net_log) {
208 scoped_ptr<MockTriggerableClientSocket> socket(
209 new MockTriggerableClientSocket(addrlist, true, net_log));
210 return socket.PassAs<StreamSocket>();
213 // StreamSocket implementation.
// Stores |callback| for DoCallback() to run later. DCHECKs that no callback
// has already been stored.
214 virtual int Connect(const CompletionCallback& callback) OVERRIDE {
215 DCHECK(callback_.is_null());
216 callback_ = callback;
217 return ERR_IO_PENDING;
220 virtual void Disconnect() OVERRIDE {}
222 virtual bool IsConnected() const OVERRIDE { return is_connected_; }
223 virtual bool IsConnectedAndIdle() const OVERRIDE { return is_connected_; }
224 virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE {
225 *address = addrlist_.front();
228 virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE {
230 return ERR_SOCKET_NOT_CONNECTED;
// Report a local endpoint whose address family matches the peer address.
231 if (addrlist_.front().GetFamily() == ADDRESS_FAMILY_IPV4)
232 SetIPv4Address(address);
234 SetIPv6Address(address);
237 virtual const BoundNetLog& NetLog() const OVERRIDE { return net_log_; }
239 virtual void SetSubresourceSpeculation() OVERRIDE {}
240 virtual void SetOmniboxSpeculation() OVERRIDE {}
241 virtual bool WasEverUsed() const OVERRIDE { return false; }
242 virtual void EnableTCPFastOpenIfSupported() OVERRIDE {
243 use_tcp_fastopen_ = true;
245 virtual bool UsingTCPFastOpen() const OVERRIDE { return use_tcp_fastopen_; }
246 virtual bool WasNpnNegotiated() const OVERRIDE { return false; }
247 virtual NextProto GetNegotiatedProtocol() const OVERRIDE {
248 return kProtoUnknown;
250 virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { return false; }
252 // Socket implementation.
253 virtual int Read(IOBuffer* buf,
255 const CompletionCallback& callback) OVERRIDE {
259 virtual int Write(IOBuffer* buf,
261 const CompletionCallback& callback) OVERRIDE {
264 virtual int SetReceiveBufferSize(int32 size) OVERRIDE { return OK; }
265 virtual int SetSendBufferSize(int32 size) OVERRIDE { return OK; }
// Completes the stored connect with the configured outcome.
// NOTE(review): assumes Connect() has already stored callback_; there is no
// null check before running it here.
269 is_connected_ = should_connect_;
270 callback_.Run(is_connected_ ? OK : ERR_CONNECTION_FAILED);
273 bool should_connect_;
275 const AddressList addrlist_;
276 BoundNetLog net_log_;
// Stored by Connect(); run by DoCallback().
277 CompletionCallback callback_;
278 bool use_tcp_fastopen_;
// Lets closures from GetConnectCallback() do nothing once the socket has
// been destroyed.
280 base::WeakPtrFactory<MockTriggerableClientSocket> weak_factory_;
282 DISALLOW_COPY_AND_ASSIGN(MockTriggerableClientSocket);
// Checks the LoadTimingInfo of a connected |handle| when queried as reused.
// The checks are:
//   - socket_reused is set;
//   - the socket has a valid NetLog source id;
//   - ExpectConnectTimingHasNoTimes() passes on the connect timing;
//   - ExpectLoadTimingHasOnlyConnectionTimes() passes on the whole struct.
287 void TestLoadTimingInfoConnectedReused(const ClientSocketHandle& handle) {
288 LoadTimingInfo load_timing_info;
289 // Only pass true in as |is_reused|, as in general, HttpStream types should
290 // have stricter concepts of reuse than socket pools.
291 EXPECT_TRUE(handle.GetLoadTimingInfo(true, &load_timing_info));
293 EXPECT_TRUE(load_timing_info.socket_reused);
294 EXPECT_NE(NetLog::Source::kInvalidId, load_timing_info.socket_log_id);
296 ExpectConnectTimingHasNoTimes(load_timing_info.connect_timing);
297 ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info);
// Checks the LoadTimingInfo of a freshly connected, non-reused |handle|.
// The checks are:
//   - the handle is not marked reused and socket_reused is false;
//   - the socket has a valid NetLog source id;
//   - the connect timing includes DNS times;
//   - only connection-related load times are populated.
// It then re-checks the same handle as if reused, via
// TestLoadTimingInfoConnectedReused().
300 void TestLoadTimingInfoConnectedNotReused(const ClientSocketHandle& handle) {
301 EXPECT_FALSE(handle.is_reused());
303 LoadTimingInfo load_timing_info;
304 EXPECT_TRUE(handle.GetLoadTimingInfo(false, &load_timing_info));
306 EXPECT_FALSE(load_timing_info.socket_reused);
307 EXPECT_NE(NetLog::Source::kInvalidId, load_timing_info.socket_log_id);
309 ExpectConnectTimingHasTimes(load_timing_info.connect_timing,
310 CONNECT_TIMING_HAS_DNS_TIMES);
311 ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info);
// Querying the same handle with |is_reused| == true must then satisfy the
// reused-socket expectations.
313 TestLoadTimingInfoConnectedReused(handle);
// Overwrites |address| with the fixed IPv4 test endpoint 1.1.1.1:80.
316 void SetIPv4Address(IPEndPoint* address) {
317 *address = IPEndPoint(ParseIP("1.1.1.1"), 80);
// Overwrites |address| with the fixed IPv6 test endpoint [1:abcd::3:4:ff]:80.
320 void SetIPv6Address(IPEndPoint* address) {
321 *address = IPEndPoint(ParseIP("1:abcd::3:4:ff"), 80);
// Default configuration of a new factory:
//   - every created transport socket is a MOCK_CLIENT_SOCKET;
//   - no per-call type list is installed (see set_client_socket_types());
//   - delayed mock sockets wait ClientSocketPool::kMaxConnectRetryIntervalMs
//     before completing.
324 MockTransportClientSocketFactory::MockTransportClientSocketFactory(
327 allocation_count_(0),
328 client_socket_type_(MOCK_CLIENT_SOCKET),
329 client_socket_types_(NULL),
330 client_socket_index_(0),
331 client_socket_index_max_(0),
332 delay_(base::TimeDelta::FromMilliseconds(
333 ClientSocketPool::kMaxConnectRetryIntervalMs)) {}
// Empty destructor; members are released by their own destructors.
335 MockTransportClientSocketFactory::~MockTransportClientSocketFactory() {}
// Datagram sockets are not supported by this mock factory; this returns an
// empty scoped_ptr.
337 scoped_ptr<DatagramClientSocket>
338 MockTransportClientSocketFactory::CreateDatagramClientSocket(
339 DatagramSocket::BindType bind_type,
340 const RandIntCallback& rand_int_cb,
342 const NetLog::Source& source) {
344 return scoped_ptr<DatagramClientSocket>();
// Creates the next mock transport socket for |addresses|. The socket type is
// taken from the array installed by set_client_socket_types() while
// unconsumed entries remain, and from client_socket_type_ otherwise. The
// caller's NetLog and source are ignored; sockets log to the factory's own
// |net_log_|.
347 scoped_ptr<StreamSocket>
348 MockTransportClientSocketFactory::CreateTransportClientSocket(
349 const AddressList& addresses,
350 NetLog* /* net_log */,
351 const NetLog::Source& /* source */) {
354 ClientSocketType type = client_socket_type_;
355 if (client_socket_types_ && client_socket_index_ < client_socket_index_max_) {
356 type = client_socket_types_[client_socket_index_++];
360 case MOCK_CLIENT_SOCKET:
361 return scoped_ptr<StreamSocket>(
362 new MockConnectClientSocket(addresses, net_log_));
363 case MOCK_FAILING_CLIENT_SOCKET:
364 return scoped_ptr<StreamSocket>(
365 new MockFailingClientSocket(addresses, net_log_));
366 case MOCK_PENDING_CLIENT_SOCKET:
367 return MockTriggerableClientSocket::MakeMockPendingClientSocket(
368 addresses, true, net_log_);
369 case MOCK_PENDING_FAILING_CLIENT_SOCKET:
370 return MockTriggerableClientSocket::MakeMockPendingClientSocket(
371 addresses, false, net_log_);
372 case MOCK_DELAYED_CLIENT_SOCKET:
373 return MockTriggerableClientSocket::MakeMockDelayedClientSocket(
374 addresses, true, delay_, net_log_);
375 case MOCK_DELAYED_FAILING_CLIENT_SOCKET:
376 return MockTriggerableClientSocket::MakeMockDelayedClientSocket(
377 addresses, false, delay_, net_log_);
378 case MOCK_STALLED_CLIENT_SOCKET:
379 return MockTriggerableClientSocket::MakeMockStalledClientSocket(addresses,
381 case MOCK_TRIGGERABLE_CLIENT_SOCKET: {
// Unlike the pending/delayed cases, the connect closure is not posted to the
// message loop. It is queued so that WaitForTriggerableSocketCreation() can
// hand it to the test, which decides when the connect completes.
382 scoped_ptr<MockTriggerableClientSocket> rv(
383 new MockTriggerableClientSocket(addresses, true, net_log_));
384 triggerable_sockets_.push(rv->GetConnectCallback());
385 // run_loop_quit_closure_ behaves like a condition variable. It will
386 // wake up WaitForTriggerableSocketCreation() if it is sleeping. We
387 // don't need to worry about atomicity because this code is
389 if (!run_loop_quit_closure_.is_null())
390 run_loop_quit_closure_.Run();
391 return rv.PassAs<StreamSocket>();
395 return scoped_ptr<StreamSocket>(
396 new MockConnectClientSocket(addresses, net_log_));
// SSL sockets are not supported by this mock factory; this returns an empty
// scoped_ptr.
400 scoped_ptr<SSLClientSocket>
401 MockTransportClientSocketFactory::CreateSSLClientSocket(
402 scoped_ptr<ClientSocketHandle> transport_socket,
403 const HostPortPair& host_and_port,
404 const SSLConfig& ssl_config,
405 const SSLClientSocketContext& context) {
407 return scoped_ptr<SSLClientSocket>();
// This mock factory creates no SSL sockets (CreateSSLClientSocket() returns an
// empty scoped_ptr), so there is presumably no SSL session state to clear.
410 void MockTransportClientSocketFactory::ClearSSLSessionCache() {
// Installs |type_list|, an array of |num_types| (> 0) socket types.
// Subsequent CreateTransportClientSocket() calls consume the entries in order
// before falling back to client_socket_type_. Consumption restarts at the
// first entry. Only the pointer is stored, so the array must outlive its use
// by this factory.
414 void MockTransportClientSocketFactory::set_client_socket_types(
415 ClientSocketType* type_list,
417 DCHECK_GT(num_types, 0);
418 client_socket_types_ = type_list;
419 client_socket_index_ = 0;
420 client_socket_index_max_ = num_types;
424 MockTransportClientSocketFactory::WaitForTriggerableSocketCreation() {
// Waits on a RunLoop until CreateTransportClientSocket() has queued at least
// one triggerable socket's connect closure; that function runs
// run_loop_quit_closure_ when it queues one. Then dequeues the oldest closure.
425 while (triggerable_sockets_.empty()) {
426 base::RunLoop run_loop;
427 run_loop_quit_closure_ = run_loop.QuitClosure();
429 run_loop_quit_closure_.Reset();
431 base::Closure trigger = triggerable_sockets_.front();
432 triggerable_sockets_.pop();