Update To 11.40.268.0
[platform/framework/web/crosswalk.git] / src / net / socket / ssl_server_socket_unittest.cc
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 // This test suite uses SSLClientSocket to test the implementation of
6 // SSLServerSocket. In order to establish connections between the sockets
7 // we need two additional classes:
8 // 1. FakeSocket
9 //    Connects SSL socket to FakeDataChannel. This class is just a stub.
10 //
11 // 2. FakeDataChannel
12 //    Implements the actual exchange of data between two FakeSockets.
13 //
14 // Implementations of these two classes are included in this file.
15
16 #include "net/socket/ssl_server_socket.h"
17
18 #include <stdlib.h>
19
20 #include <queue>
21
22 #include "base/compiler_specific.h"
23 #include "base/files/file_path.h"
24 #include "base/files/file_util.h"
25 #include "base/message_loop/message_loop.h"
26 #include "crypto/nss_util.h"
27 #include "crypto/rsa_private_key.h"
28 #include "net/base/address_list.h"
29 #include "net/base/completion_callback.h"
30 #include "net/base/host_port_pair.h"
31 #include "net/base/io_buffer.h"
32 #include "net/base/ip_endpoint.h"
33 #include "net/base/net_errors.h"
34 #include "net/base/net_log.h"
35 #include "net/base/test_data_directory.h"
36 #include "net/cert/cert_status_flags.h"
37 #include "net/cert/mock_cert_verifier.h"
38 #include "net/cert/x509_certificate.h"
39 #include "net/http/transport_security_state.h"
40 #include "net/socket/client_socket_factory.h"
41 #include "net/socket/socket_test_util.h"
42 #include "net/socket/ssl_client_socket.h"
43 #include "net/socket/stream_socket.h"
44 #include "net/ssl/ssl_config_service.h"
45 #include "net/ssl/ssl_info.h"
46 #include "net/test/cert_test_util.h"
47 #include "testing/gtest/include/gtest/gtest.h"
48 #include "testing/platform_test.h"
49
50 namespace net {
51
52 namespace {
53
54 class FakeDataChannel {
55  public:
56   FakeDataChannel()
57       : read_buf_len_(0),
58         closed_(false),
59         write_called_after_close_(false),
60         weak_factory_(this) {
61   }
62
63   int Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback) {
64     DCHECK(read_callback_.is_null());
65     DCHECK(!read_buf_.get());
66     if (closed_)
67       return 0;
68     if (data_.empty()) {
69       read_callback_ = callback;
70       read_buf_ = buf;
71       read_buf_len_ = buf_len;
72       return ERR_IO_PENDING;
73     }
74     return PropogateData(buf, buf_len);
75   }
76
77   int Write(IOBuffer* buf, int buf_len, const CompletionCallback& callback) {
78     DCHECK(write_callback_.is_null());
79     if (closed_) {
80       if (write_called_after_close_)
81         return ERR_CONNECTION_RESET;
82       write_called_after_close_ = true;
83       write_callback_ = callback;
84       base::MessageLoop::current()->PostTask(
85           FROM_HERE, base::Bind(&FakeDataChannel::DoWriteCallback,
86                                 weak_factory_.GetWeakPtr()));
87       return ERR_IO_PENDING;
88     }
89     // This function returns synchronously, so make a copy of the buffer.
90     data_.push(new DrainableIOBuffer(
91         new StringIOBuffer(std::string(buf->data(), buf_len)),
92         buf_len));
93     base::MessageLoop::current()->PostTask(
94         FROM_HERE, base::Bind(&FakeDataChannel::DoReadCallback,
95                               weak_factory_.GetWeakPtr()));
96     return buf_len;
97   }
98
99   // Closes the FakeDataChannel. After Close() is called, Read() returns 0,
100   // indicating EOF, and Write() fails with ERR_CONNECTION_RESET. Note that
101   // after the FakeDataChannel is closed, the first Write() call completes
102   // asynchronously, which is necessary to reproduce bug 127822.
103   void Close() {
104     closed_ = true;
105   }
106
107  private:
108   void DoReadCallback() {
109     if (read_callback_.is_null() || data_.empty())
110       return;
111
112     int copied = PropogateData(read_buf_, read_buf_len_);
113     CompletionCallback callback = read_callback_;
114     read_callback_.Reset();
115     read_buf_ = NULL;
116     read_buf_len_ = 0;
117     callback.Run(copied);
118   }
119
120   void DoWriteCallback() {
121     if (write_callback_.is_null())
122       return;
123
124     CompletionCallback callback = write_callback_;
125     write_callback_.Reset();
126     callback.Run(ERR_CONNECTION_RESET);
127   }
128
129   int PropogateData(scoped_refptr<IOBuffer> read_buf, int read_buf_len) {
130     scoped_refptr<DrainableIOBuffer> buf = data_.front();
131     int copied = std::min(buf->BytesRemaining(), read_buf_len);
132     memcpy(read_buf->data(), buf->data(), copied);
133     buf->DidConsume(copied);
134
135     if (!buf->BytesRemaining())
136       data_.pop();
137     return copied;
138   }
139
140   CompletionCallback read_callback_;
141   scoped_refptr<IOBuffer> read_buf_;
142   int read_buf_len_;
143
144   CompletionCallback write_callback_;
145
146   std::queue<scoped_refptr<DrainableIOBuffer> > data_;
147
148   // True if Close() has been called.
149   bool closed_;
150
151   // Controls the completion of Write() after the FakeDataChannel is closed.
152   // After the FakeDataChannel is closed, the first Write() call completes
153   // asynchronously.
154   bool write_called_after_close_;
155
156   base::WeakPtrFactory<FakeDataChannel> weak_factory_;
157
158   DISALLOW_COPY_AND_ASSIGN(FakeDataChannel);
159 };
160
161 class FakeSocket : public StreamSocket {
162  public:
163   FakeSocket(FakeDataChannel* incoming_channel,
164              FakeDataChannel* outgoing_channel)
165       : incoming_(incoming_channel),
166         outgoing_(outgoing_channel) {
167   }
168
169   ~FakeSocket() override {}
170
171   int Read(IOBuffer* buf,
172            int buf_len,
173            const CompletionCallback& callback) override {
174     // Read random number of bytes.
175     buf_len = rand() % buf_len + 1;
176     return incoming_->Read(buf, buf_len, callback);
177   }
178
179   int Write(IOBuffer* buf,
180             int buf_len,
181             const CompletionCallback& callback) override {
182     // Write random number of bytes.
183     buf_len = rand() % buf_len + 1;
184     return outgoing_->Write(buf, buf_len, callback);
185   }
186
187   int SetReceiveBufferSize(int32 size) override { return OK; }
188
189   int SetSendBufferSize(int32 size) override { return OK; }
190
191   int Connect(const CompletionCallback& callback) override { return OK; }
192
193   void Disconnect() override {
194     incoming_->Close();
195     outgoing_->Close();
196   }
197
198   bool IsConnected() const override { return true; }
199
200   bool IsConnectedAndIdle() const override { return true; }
201
202   int GetPeerAddress(IPEndPoint* address) const override {
203     IPAddressNumber ip_address(kIPv4AddressSize);
204     *address = IPEndPoint(ip_address, 0 /*port*/);
205     return OK;
206   }
207
208   int GetLocalAddress(IPEndPoint* address) const override {
209     IPAddressNumber ip_address(4);
210     *address = IPEndPoint(ip_address, 0);
211     return OK;
212   }
213
214   const BoundNetLog& NetLog() const override { return net_log_; }
215
216   void SetSubresourceSpeculation() override {}
217   void SetOmniboxSpeculation() override {}
218
219   bool WasEverUsed() const override { return true; }
220
221   bool UsingTCPFastOpen() const override { return false; }
222
223   bool WasNpnNegotiated() const override { return false; }
224
225   NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; }
226
227   bool GetSSLInfo(SSLInfo* ssl_info) override { return false; }
228
229  private:
230   BoundNetLog net_log_;
231   FakeDataChannel* incoming_;
232   FakeDataChannel* outgoing_;
233
234   DISALLOW_COPY_AND_ASSIGN(FakeSocket);
235 };
236
237 }  // namespace
238
239 // Verify the correctness of the test helper classes first.
240 TEST(FakeSocketTest, DataTransfer) {
241   // Establish channels between two sockets.
242   FakeDataChannel channel_1;
243   FakeDataChannel channel_2;
244   FakeSocket client(&channel_1, &channel_2);
245   FakeSocket server(&channel_2, &channel_1);
246
247   const char kTestData[] = "testing123";
248   const int kTestDataSize = strlen(kTestData);
249   const int kReadBufSize = 1024;
250   scoped_refptr<IOBuffer> write_buf = new StringIOBuffer(kTestData);
251   scoped_refptr<IOBuffer> read_buf = new IOBuffer(kReadBufSize);
252
253   // Write then read.
254   int written =
255       server.Write(write_buf.get(), kTestDataSize, CompletionCallback());
256   EXPECT_GT(written, 0);
257   EXPECT_LE(written, kTestDataSize);
258
259   int read = client.Read(read_buf.get(), kReadBufSize, CompletionCallback());
260   EXPECT_GT(read, 0);
261   EXPECT_LE(read, written);
262   EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), read));
263
264   // Read then write.
265   TestCompletionCallback callback;
266   EXPECT_EQ(ERR_IO_PENDING,
267             server.Read(read_buf.get(), kReadBufSize, callback.callback()));
268
269   written = client.Write(write_buf.get(), kTestDataSize, CompletionCallback());
270   EXPECT_GT(written, 0);
271   EXPECT_LE(written, kTestDataSize);
272
273   read = callback.WaitForResult();
274   EXPECT_GT(read, 0);
275   EXPECT_LE(read, written);
276   EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), read));
277 }
278
279 class SSLServerSocketTest : public PlatformTest {
280  public:
281   SSLServerSocketTest()
282       : socket_factory_(ClientSocketFactory::GetDefaultFactory()),
283         cert_verifier_(new MockCertVerifier()),
284         transport_security_state_(new TransportSecurityState) {
285     cert_verifier_->set_default_result(CERT_STATUS_AUTHORITY_INVALID);
286   }
287
288  protected:
289   void Initialize() {
290     scoped_ptr<ClientSocketHandle> client_connection(new ClientSocketHandle);
291     client_connection->SetSocket(
292         scoped_ptr<StreamSocket>(new FakeSocket(&channel_1_, &channel_2_)));
293     scoped_ptr<StreamSocket> server_socket(
294         new FakeSocket(&channel_2_, &channel_1_));
295
296     base::FilePath certs_dir(GetTestCertsDirectory());
297
298     base::FilePath cert_path = certs_dir.AppendASCII("unittest.selfsigned.der");
299     std::string cert_der;
300     ASSERT_TRUE(base::ReadFileToString(cert_path, &cert_der));
301
302     scoped_refptr<X509Certificate> cert =
303         X509Certificate::CreateFromBytes(cert_der.data(), cert_der.size());
304
305     base::FilePath key_path = certs_dir.AppendASCII("unittest.key.bin");
306     std::string key_string;
307     ASSERT_TRUE(base::ReadFileToString(key_path, &key_string));
308     std::vector<uint8> key_vector(
309         reinterpret_cast<const uint8*>(key_string.data()),
310         reinterpret_cast<const uint8*>(key_string.data() +
311                                        key_string.length()));
312
313     scoped_ptr<crypto::RSAPrivateKey> private_key(
314         crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector));
315
316     SSLConfig ssl_config;
317     ssl_config.false_start_enabled = false;
318     ssl_config.channel_id_enabled = false;
319
320     // Certificate provided by the host doesn't need authority.
321     SSLConfig::CertAndStatus cert_and_status;
322     cert_and_status.cert_status = CERT_STATUS_AUTHORITY_INVALID;
323     cert_and_status.der_cert = cert_der;
324     ssl_config.allowed_bad_certs.push_back(cert_and_status);
325
326     HostPortPair host_and_pair("unittest", 0);
327     SSLClientSocketContext context;
328     context.cert_verifier = cert_verifier_.get();
329     context.transport_security_state = transport_security_state_.get();
330     client_socket_ =
331         socket_factory_->CreateSSLClientSocket(
332             client_connection.Pass(), host_and_pair, ssl_config, context);
333     server_socket_ = CreateSSLServerSocket(
334         server_socket.Pass(),
335         cert.get(), private_key.get(), SSLConfig());
336   }
337
338   FakeDataChannel channel_1_;
339   FakeDataChannel channel_2_;
340   scoped_ptr<SSLClientSocket> client_socket_;
341   scoped_ptr<SSLServerSocket> server_socket_;
342   ClientSocketFactory* socket_factory_;
343   scoped_ptr<MockCertVerifier> cert_verifier_;
344   scoped_ptr<TransportSecurityState> transport_security_state_;
345 };
346
347 // This test only executes creation of client and server sockets. This is to
348 // test that creation of sockets doesn't crash and have minimal code to run
349 // under valgrind in order to help debugging memory problems.
350 TEST_F(SSLServerSocketTest, Initialize) {
351   Initialize();
352 }
353
354 // This test executes Connect() on SSLClientSocket and Handshake() on
355 // SSLServerSocket to make sure handshaking between the two sockets is
356 // completed successfully.
357 TEST_F(SSLServerSocketTest, Handshake) {
358   Initialize();
359
360   TestCompletionCallback connect_callback;
361   TestCompletionCallback handshake_callback;
362
363   int server_ret = server_socket_->Handshake(handshake_callback.callback());
364   EXPECT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING);
365
366   int client_ret = client_socket_->Connect(connect_callback.callback());
367   EXPECT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING);
368
369   if (client_ret == ERR_IO_PENDING) {
370     EXPECT_EQ(OK, connect_callback.WaitForResult());
371   }
372   if (server_ret == ERR_IO_PENDING) {
373     EXPECT_EQ(OK, handshake_callback.WaitForResult());
374   }
375
376   // Make sure the cert status is expected.
377   SSLInfo ssl_info;
378   client_socket_->GetSSLInfo(&ssl_info);
379   EXPECT_EQ(CERT_STATUS_AUTHORITY_INVALID, ssl_info.cert_status);
380 }
381
382 TEST_F(SSLServerSocketTest, DataTransfer) {
383   Initialize();
384
385   TestCompletionCallback connect_callback;
386   TestCompletionCallback handshake_callback;
387
388   // Establish connection.
389   int client_ret = client_socket_->Connect(connect_callback.callback());
390   ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING);
391
392   int server_ret = server_socket_->Handshake(handshake_callback.callback());
393   ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING);
394
395   client_ret = connect_callback.GetResult(client_ret);
396   ASSERT_EQ(OK, client_ret);
397   server_ret = handshake_callback.GetResult(server_ret);
398   ASSERT_EQ(OK, server_ret);
399
400   const int kReadBufSize = 1024;
401   scoped_refptr<StringIOBuffer> write_buf =
402       new StringIOBuffer("testing123");
403   scoped_refptr<DrainableIOBuffer> read_buf =
404       new DrainableIOBuffer(new IOBuffer(kReadBufSize), kReadBufSize);
405
406   // Write then read.
407   TestCompletionCallback write_callback;
408   TestCompletionCallback read_callback;
409   server_ret = server_socket_->Write(
410       write_buf.get(), write_buf->size(), write_callback.callback());
411   EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING);
412   client_ret = client_socket_->Read(
413       read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
414   EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);
415
416   server_ret = write_callback.GetResult(server_ret);
417   EXPECT_GT(server_ret, 0);
418   client_ret = read_callback.GetResult(client_ret);
419   ASSERT_GT(client_ret, 0);
420
421   read_buf->DidConsume(client_ret);
422   while (read_buf->BytesConsumed() < write_buf->size()) {
423     client_ret = client_socket_->Read(
424         read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
425     EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);
426     client_ret = read_callback.GetResult(client_ret);
427     ASSERT_GT(client_ret, 0);
428     read_buf->DidConsume(client_ret);
429   }
430   EXPECT_EQ(write_buf->size(), read_buf->BytesConsumed());
431   read_buf->SetOffset(0);
432   EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size()));
433
434   // Read then write.
435   write_buf = new StringIOBuffer("hello123");
436   server_ret = server_socket_->Read(
437       read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
438   EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING);
439   client_ret = client_socket_->Write(
440       write_buf.get(), write_buf->size(), write_callback.callback());
441   EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);
442
443   server_ret = read_callback.GetResult(server_ret);
444   ASSERT_GT(server_ret, 0);
445   client_ret = write_callback.GetResult(client_ret);
446   EXPECT_GT(client_ret, 0);
447
448   read_buf->DidConsume(server_ret);
449   while (read_buf->BytesConsumed() < write_buf->size()) {
450     server_ret = server_socket_->Read(
451         read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
452     EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING);
453     server_ret = read_callback.GetResult(server_ret);
454     ASSERT_GT(server_ret, 0);
455     read_buf->DidConsume(server_ret);
456   }
457   EXPECT_EQ(write_buf->size(), read_buf->BytesConsumed());
458   read_buf->SetOffset(0);
459   EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size()));
460 }
461
462 // A regression test for bug 127822 (http://crbug.com/127822).
463 // If the server closes the connection after the handshake is finished,
464 // the client's Write() call should not cause an infinite loop.
465 // NOTE: this is a test for SSLClientSocket rather than SSLServerSocket.
466 TEST_F(SSLServerSocketTest, ClientWriteAfterServerClose) {
467   Initialize();
468
469   TestCompletionCallback connect_callback;
470   TestCompletionCallback handshake_callback;
471
472   // Establish connection.
473   int client_ret = client_socket_->Connect(connect_callback.callback());
474   ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING);
475
476   int server_ret = server_socket_->Handshake(handshake_callback.callback());
477   ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING);
478
479   client_ret = connect_callback.GetResult(client_ret);
480   ASSERT_EQ(OK, client_ret);
481   server_ret = handshake_callback.GetResult(server_ret);
482   ASSERT_EQ(OK, server_ret);
483
484   scoped_refptr<StringIOBuffer> write_buf = new StringIOBuffer("testing123");
485
486   // The server closes the connection. The server needs to write some
487   // data first so that the client's Read() calls from the transport
488   // socket won't return ERR_IO_PENDING.  This ensures that the client
489   // will call Read() on the transport socket again.
490   TestCompletionCallback write_callback;
491
492   server_ret = server_socket_->Write(
493       write_buf.get(), write_buf->size(), write_callback.callback());
494   EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING);
495
496   server_ret = write_callback.GetResult(server_ret);
497   EXPECT_GT(server_ret, 0);
498
499   server_socket_->Disconnect();
500
501   // The client writes some data. This should not cause an infinite loop.
502   client_ret = client_socket_->Write(
503       write_buf.get(), write_buf->size(), write_callback.callback());
504   EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);
505
506   client_ret = write_callback.GetResult(client_ret);
507   EXPECT_GT(client_ret, 0);
508
509   base::MessageLoop::current()->PostDelayedTask(
510       FROM_HERE, base::MessageLoop::QuitClosure(),
511       base::TimeDelta::FromMilliseconds(10));
512   base::MessageLoop::current()->Run();
513 }
514
515 // This test executes ExportKeyingMaterial() on the client and server sockets,
516 // after connecting them, and verifies that the results match.
517 // This test will fail if False Start is enabled (see crbug.com/90208).
518 TEST_F(SSLServerSocketTest, ExportKeyingMaterial) {
519   Initialize();
520
521   TestCompletionCallback connect_callback;
522   TestCompletionCallback handshake_callback;
523
524   int client_ret = client_socket_->Connect(connect_callback.callback());
525   ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING);
526
527   int server_ret = server_socket_->Handshake(handshake_callback.callback());
528   ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING);
529
530   if (client_ret == ERR_IO_PENDING) {
531     ASSERT_EQ(OK, connect_callback.WaitForResult());
532   }
533   if (server_ret == ERR_IO_PENDING) {
534     ASSERT_EQ(OK, handshake_callback.WaitForResult());
535   }
536
537   const int kKeyingMaterialSize = 32;
538   const char* kKeyingLabel = "EXPERIMENTAL-server-socket-test";
539   const char* kKeyingContext = "";
540   unsigned char server_out[kKeyingMaterialSize];
541   int rv = server_socket_->ExportKeyingMaterial(kKeyingLabel,
542                                                 false, kKeyingContext,
543                                                 server_out, sizeof(server_out));
544   ASSERT_EQ(OK, rv);
545
546   unsigned char client_out[kKeyingMaterialSize];
547   rv = client_socket_->ExportKeyingMaterial(kKeyingLabel,
548                                             false, kKeyingContext,
549                                             client_out, sizeof(client_out));
550   ASSERT_EQ(OK, rv);
551   EXPECT_EQ(0, memcmp(server_out, client_out, sizeof(server_out)));
552
553   const char* kKeyingLabelBad = "EXPERIMENTAL-server-socket-test-bad";
554   unsigned char client_bad[kKeyingMaterialSize];
555   rv = client_socket_->ExportKeyingMaterial(kKeyingLabelBad,
556                                             false, kKeyingContext,
557                                             client_bad, sizeof(client_bad));
558   ASSERT_EQ(rv, OK);
559   EXPECT_NE(0, memcmp(server_out, client_bad, sizeof(server_out)));
560 }
561
562 }  // namespace net