Update To 11.40.268.0
[platform/framework/web/crosswalk.git] / src / net / server / http_server_unittest.cc
1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include <algorithm>
6 #include <utility>
7 #include <vector>
8
9 #include "base/bind.h"
10 #include "base/bind_helpers.h"
11 #include "base/callback_helpers.h"
12 #include "base/compiler_specific.h"
13 #include "base/format_macros.h"
14 #include "base/memory/ref_counted.h"
15 #include "base/memory/scoped_ptr.h"
16 #include "base/memory/weak_ptr.h"
17 #include "base/message_loop/message_loop.h"
18 #include "base/message_loop/message_loop_proxy.h"
19 #include "base/run_loop.h"
20 #include "base/strings/string_split.h"
21 #include "base/strings/string_util.h"
22 #include "base/strings/stringprintf.h"
23 #include "base/time/time.h"
24 #include "net/base/address_list.h"
25 #include "net/base/io_buffer.h"
26 #include "net/base/ip_endpoint.h"
27 #include "net/base/net_errors.h"
28 #include "net/base/net_log.h"
29 #include "net/base/net_util.h"
30 #include "net/base/test_completion_callback.h"
31 #include "net/http/http_response_headers.h"
32 #include "net/http/http_util.h"
33 #include "net/server/http_server.h"
34 #include "net/server/http_server_request_info.h"
35 #include "net/socket/tcp_client_socket.h"
36 #include "net/socket/tcp_server_socket.h"
37 #include "net/url_request/url_fetcher.h"
38 #include "net/url_request/url_fetcher_delegate.h"
39 #include "net/url_request/url_request_context.h"
40 #include "net/url_request/url_request_context_getter.h"
41 #include "net/url_request/url_request_test_util.h"
42 #include "testing/gtest/include/gtest/gtest.h"
43
44 namespace net {
45
46 namespace {
47
// Size of the buffer TestHttpClient allocates for every socket read; a single
// underlying read therefore returns at most this many bytes.
const int kMaxExpectedResponseLength = 2048;
49
50 void SetTimedOutAndQuitLoop(const base::WeakPtr<bool> timed_out,
51                             const base::Closure& quit_loop_func) {
52   if (timed_out) {
53     *timed_out = true;
54     quit_loop_func.Run();
55   }
56 }
57
58 bool RunLoopWithTimeout(base::RunLoop* run_loop) {
59   bool timed_out = false;
60   base::WeakPtrFactory<bool> timed_out_weak_factory(&timed_out);
61   base::MessageLoop::current()->PostDelayedTask(
62       FROM_HERE,
63       base::Bind(&SetTimedOutAndQuitLoop,
64                  timed_out_weak_factory.GetWeakPtr(),
65                  run_loop->QuitClosure()),
66       base::TimeDelta::FromSeconds(1));
67   run_loop->Run();
68   return !timed_out;
69 }
70
71 class TestHttpClient {
72  public:
73   TestHttpClient() : connect_result_(OK) {}
74
75   int ConnectAndWait(const IPEndPoint& address) {
76     AddressList addresses(address);
77     NetLog::Source source;
78     socket_.reset(new TCPClientSocket(addresses, NULL, source));
79
80     base::RunLoop run_loop;
81     connect_result_ = socket_->Connect(base::Bind(&TestHttpClient::OnConnect,
82                                                   base::Unretained(this),
83                                                   run_loop.QuitClosure()));
84     if (connect_result_ != OK && connect_result_ != ERR_IO_PENDING)
85       return connect_result_;
86
87     if (!RunLoopWithTimeout(&run_loop))
88       return ERR_TIMED_OUT;
89     return connect_result_;
90   }
91
92   void Send(const std::string& data) {
93     write_buffer_ =
94         new DrainableIOBuffer(new StringIOBuffer(data), data.length());
95     Write();
96   }
97
98   bool Read(std::string* message, int expected_bytes) {
99     int total_bytes_received = 0;
100     message->clear();
101     while (total_bytes_received < expected_bytes) {
102       net::TestCompletionCallback callback;
103       ReadInternal(callback.callback());
104       int bytes_received = callback.WaitForResult();
105       if (bytes_received <= 0)
106         return false;
107
108       total_bytes_received += bytes_received;
109       message->append(read_buffer_->data(), bytes_received);
110     }
111     return true;
112   }
113
114   bool ReadResponse(std::string* message) {
115     if (!Read(message, 1))
116       return false;
117     while (!IsCompleteResponse(*message)) {
118       std::string chunk;
119       if (!Read(&chunk, 1))
120         return false;
121       message->append(chunk);
122     }
123     return true;
124   }
125
126  private:
127   void OnConnect(const base::Closure& quit_loop, int result) {
128     connect_result_ = result;
129     quit_loop.Run();
130   }
131
132   void Write() {
133     int result = socket_->Write(
134         write_buffer_.get(),
135         write_buffer_->BytesRemaining(),
136         base::Bind(&TestHttpClient::OnWrite, base::Unretained(this)));
137     if (result != ERR_IO_PENDING)
138       OnWrite(result);
139   }
140
141   void OnWrite(int result) {
142     ASSERT_GT(result, 0);
143     write_buffer_->DidConsume(result);
144     if (write_buffer_->BytesRemaining())
145       Write();
146   }
147
148   void ReadInternal(const net::CompletionCallback& callback) {
149     read_buffer_ = new IOBufferWithSize(kMaxExpectedResponseLength);
150     int result =
151         socket_->Read(read_buffer_.get(), kMaxExpectedResponseLength, callback);
152     if (result != ERR_IO_PENDING)
153       callback.Run(result);
154   }
155
156   bool IsCompleteResponse(const std::string& response) {
157     // Check end of headers first.
158     int end_of_headers = HttpUtil::LocateEndOfHeaders(response.data(),
159                                                       response.size());
160     if (end_of_headers < 0)
161       return false;
162
163     // Return true if response has data equal to or more than content length.
164     int64 body_size = static_cast<int64>(response.size()) - end_of_headers;
165     DCHECK_LE(0, body_size);
166     scoped_refptr<HttpResponseHeaders> headers(new HttpResponseHeaders(
167         HttpUtil::AssembleRawHeaders(response.data(), end_of_headers)));
168     return body_size >= headers->GetContentLength();
169   }
170
171   scoped_refptr<IOBufferWithSize> read_buffer_;
172   scoped_refptr<DrainableIOBuffer> write_buffer_;
173   scoped_ptr<TCPClientSocket> socket_;
174   int connect_result_;
175 };
176
177 }  // namespace
178
179 class HttpServerTest : public testing::Test,
180                        public HttpServer::Delegate {
181  public:
182   HttpServerTest() : quit_after_request_count_(0) {}
183
184   void SetUp() override {
185     scoped_ptr<ServerSocket> server_socket(
186         new TCPServerSocket(NULL, net::NetLog::Source()));
187     server_socket->ListenWithAddressAndPort("127.0.0.1", 0, 1);
188     server_.reset(new HttpServer(server_socket.Pass(), this));
189     ASSERT_EQ(OK, server_->GetLocalAddress(&server_address_));
190   }
191
192   void OnConnect(int connection_id) override {}
193
194   void OnHttpRequest(int connection_id,
195                      const HttpServerRequestInfo& info) override {
196     requests_.push_back(std::make_pair(info, connection_id));
197     if (requests_.size() == quit_after_request_count_)
198       run_loop_quit_func_.Run();
199   }
200
201   void OnWebSocketRequest(int connection_id,
202                           const HttpServerRequestInfo& info) override {
203     NOTREACHED();
204   }
205
206   void OnWebSocketMessage(int connection_id, const std::string& data) override {
207     NOTREACHED();
208   }
209
210   void OnClose(int connection_id) override {}
211
212   bool RunUntilRequestsReceived(size_t count) {
213     quit_after_request_count_ = count;
214     if (requests_.size() == count)
215       return true;
216
217     base::RunLoop run_loop;
218     run_loop_quit_func_ = run_loop.QuitClosure();
219     bool success = RunLoopWithTimeout(&run_loop);
220     run_loop_quit_func_.Reset();
221     return success;
222   }
223
224   HttpServerRequestInfo GetRequest(size_t request_index) {
225     return requests_[request_index].first;
226   }
227
228   int GetConnectionId(size_t request_index) {
229     return requests_[request_index].second;
230   }
231
232   void HandleAcceptResult(scoped_ptr<StreamSocket> socket) {
233     server_->accepted_socket_.reset(socket.release());
234     server_->HandleAcceptResult(OK);
235   }
236
237  protected:
238   scoped_ptr<HttpServer> server_;
239   IPEndPoint server_address_;
240   base::Closure run_loop_quit_func_;
241   std::vector<std::pair<HttpServerRequestInfo, int> > requests_;
242
243  private:
244   size_t quit_after_request_count_;
245 };
246
247 namespace {
248
249 class WebSocketTest : public HttpServerTest {
250   void OnHttpRequest(int connection_id,
251                      const HttpServerRequestInfo& info) override {
252     NOTREACHED();
253   }
254
255   void OnWebSocketRequest(int connection_id,
256                           const HttpServerRequestInfo& info) override {
257     HttpServerTest::OnHttpRequest(connection_id, info);
258   }
259
260   void OnWebSocketMessage(int connection_id, const std::string& data) override {
261   }
262 };
263
264 TEST_F(HttpServerTest, Request) {
265   TestHttpClient client;
266   ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
267   client.Send("GET /test HTTP/1.1\r\n\r\n");
268   ASSERT_TRUE(RunUntilRequestsReceived(1));
269   ASSERT_EQ("GET", GetRequest(0).method);
270   ASSERT_EQ("/test", GetRequest(0).path);
271   ASSERT_EQ("", GetRequest(0).data);
272   ASSERT_EQ(0u, GetRequest(0).headers.size());
273   ASSERT_TRUE(StartsWithASCII(GetRequest(0).peer.ToString(),
274                               "127.0.0.1",
275                               true));
276 }
277
278 TEST_F(HttpServerTest, RequestWithHeaders) {
279   TestHttpClient client;
280   ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
281   const char* kHeaders[][3] = {
282       {"Header", ": ", "1"},
283       {"HeaderWithNoWhitespace", ":", "1"},
284       {"HeaderWithWhitespace", "   :  \t   ", "1 1 1 \t  "},
285       {"HeaderWithColon", ": ", "1:1"},
286       {"EmptyHeader", ":", ""},
287       {"EmptyHeaderWithWhitespace", ":  \t  ", ""},
288       {"HeaderWithNonASCII", ":  ", "\xf7"},
289   };
290   std::string headers;
291   for (size_t i = 0; i < arraysize(kHeaders); ++i) {
292     headers +=
293         std::string(kHeaders[i][0]) + kHeaders[i][1] + kHeaders[i][2] + "\r\n";
294   }
295
296   client.Send("GET /test HTTP/1.1\r\n" + headers + "\r\n");
297   ASSERT_TRUE(RunUntilRequestsReceived(1));
298   ASSERT_EQ("", GetRequest(0).data);
299
300   for (size_t i = 0; i < arraysize(kHeaders); ++i) {
301     std::string field = base::StringToLowerASCII(std::string(kHeaders[i][0]));
302     std::string value = kHeaders[i][2];
303     ASSERT_EQ(1u, GetRequest(0).headers.count(field)) << field;
304     ASSERT_EQ(value, GetRequest(0).headers[field]) << kHeaders[i][0];
305   }
306 }
307
308 TEST_F(HttpServerTest, RequestWithDuplicateHeaders) {
309   TestHttpClient client;
310   ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
311   const char* kHeaders[][3] = {
312       {"FirstHeader", ": ", "1"},
313       {"DuplicateHeader", ": ", "2"},
314       {"MiddleHeader", ": ", "3"},
315       {"DuplicateHeader", ": ", "4"},
316       {"LastHeader", ": ", "5"},
317   };
318   std::string headers;
319   for (size_t i = 0; i < arraysize(kHeaders); ++i) {
320     headers +=
321         std::string(kHeaders[i][0]) + kHeaders[i][1] + kHeaders[i][2] + "\r\n";
322   }
323
324   client.Send("GET /test HTTP/1.1\r\n" + headers + "\r\n");
325   ASSERT_TRUE(RunUntilRequestsReceived(1));
326   ASSERT_EQ("", GetRequest(0).data);
327
328   for (size_t i = 0; i < arraysize(kHeaders); ++i) {
329     std::string field = base::StringToLowerASCII(std::string(kHeaders[i][0]));
330     std::string value = (field == "duplicateheader") ? "2,4" : kHeaders[i][2];
331     ASSERT_EQ(1u, GetRequest(0).headers.count(field)) << field;
332     ASSERT_EQ(value, GetRequest(0).headers[field]) << kHeaders[i][0];
333   }
334 }
335
336 TEST_F(HttpServerTest, HasHeaderValueTest) {
337   TestHttpClient client;
338   ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
339   const char* kHeaders[] = {
340       "Header: Abcd",
341       "HeaderWithNoWhitespace:E",
342       "HeaderWithWhitespace   :  \t   f \t  ",
343       "DuplicateHeader: g",
344       "HeaderWithComma: h, i ,j",
345       "DuplicateHeader: k",
346       "EmptyHeader:",
347       "EmptyHeaderWithWhitespace:  \t  ",
348       "HeaderWithNonASCII:  \xf7",
349   };
350   std::string headers;
351   for (size_t i = 0; i < arraysize(kHeaders); ++i) {
352     headers += std::string(kHeaders[i]) + "\r\n";
353   }
354
355   client.Send("GET /test HTTP/1.1\r\n" + headers + "\r\n");
356   ASSERT_TRUE(RunUntilRequestsReceived(1));
357   ASSERT_EQ("", GetRequest(0).data);
358
359   ASSERT_TRUE(GetRequest(0).HasHeaderValue("header", "abcd"));
360   ASSERT_FALSE(GetRequest(0).HasHeaderValue("header", "bc"));
361   ASSERT_TRUE(GetRequest(0).HasHeaderValue("headerwithnowhitespace", "e"));
362   ASSERT_TRUE(GetRequest(0).HasHeaderValue("headerwithwhitespace", "f"));
363   ASSERT_TRUE(GetRequest(0).HasHeaderValue("duplicateheader", "g"));
364   ASSERT_TRUE(GetRequest(0).HasHeaderValue("headerwithcomma", "h"));
365   ASSERT_TRUE(GetRequest(0).HasHeaderValue("headerwithcomma", "i"));
366   ASSERT_TRUE(GetRequest(0).HasHeaderValue("headerwithcomma", "j"));
367   ASSERT_TRUE(GetRequest(0).HasHeaderValue("duplicateheader", "k"));
368   ASSERT_FALSE(GetRequest(0).HasHeaderValue("emptyheader", "x"));
369   ASSERT_FALSE(GetRequest(0).HasHeaderValue("emptyheaderwithwhitespace", "x"));
370   ASSERT_TRUE(GetRequest(0).HasHeaderValue("headerwithnonascii", "\xf7"));
371 }
372
373 TEST_F(HttpServerTest, RequestWithBody) {
374   TestHttpClient client;
375   ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
376   std::string body = "a" + std::string(1 << 10, 'b') + "c";
377   client.Send(base::StringPrintf(
378       "GET /test HTTP/1.1\r\n"
379       "SomeHeader: 1\r\n"
380       "Content-Length: %" PRIuS "\r\n\r\n%s",
381       body.length(),
382       body.c_str()));
383   ASSERT_TRUE(RunUntilRequestsReceived(1));
384   ASSERT_EQ(2u, GetRequest(0).headers.size());
385   ASSERT_EQ(body.length(), GetRequest(0).data.length());
386   ASSERT_EQ('a', body[0]);
387   ASSERT_EQ('c', *body.rbegin());
388 }
389
390 TEST_F(WebSocketTest, RequestWebSocket) {
391   TestHttpClient client;
392   ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
393   client.Send(
394       "GET /test HTTP/1.1\r\n"
395       "Upgrade: WebSocket\r\n"
396       "Connection: SomethingElse, Upgrade\r\n"
397       "Sec-WebSocket-Version: 8\r\n"
398       "Sec-WebSocket-Key: key\r\n"
399       "\r\n");
400   ASSERT_TRUE(RunUntilRequestsReceived(1));
401 }
402
403 TEST_F(HttpServerTest, RequestWithTooLargeBody) {
404   class TestURLFetcherDelegate : public URLFetcherDelegate {
405    public:
406     TestURLFetcherDelegate(const base::Closure& quit_loop_func)
407         : quit_loop_func_(quit_loop_func) {}
408     ~TestURLFetcherDelegate() override {}
409
410     void OnURLFetchComplete(const URLFetcher* source) override {
411       EXPECT_EQ(HTTP_INTERNAL_SERVER_ERROR, source->GetResponseCode());
412       quit_loop_func_.Run();
413     }
414
415    private:
416     base::Closure quit_loop_func_;
417   };
418
419   base::RunLoop run_loop;
420   TestURLFetcherDelegate delegate(run_loop.QuitClosure());
421
422   scoped_refptr<URLRequestContextGetter> request_context_getter(
423       new TestURLRequestContextGetter(base::MessageLoopProxy::current()));
424   scoped_ptr<URLFetcher> fetcher(
425       URLFetcher::Create(GURL(base::StringPrintf("http://127.0.0.1:%d/test",
426                                                  server_address_.port())),
427                          URLFetcher::GET,
428                          &delegate));
429   fetcher->SetRequestContext(request_context_getter.get());
430   fetcher->AddExtraRequestHeader(
431       base::StringPrintf("content-length:%d", 1 << 30));
432   fetcher->Start();
433
434   ASSERT_TRUE(RunLoopWithTimeout(&run_loop));
435   ASSERT_EQ(0u, requests_.size());
436 }
437
438 TEST_F(HttpServerTest, Send200) {
439   TestHttpClient client;
440   ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
441   client.Send("GET /test HTTP/1.1\r\n\r\n");
442   ASSERT_TRUE(RunUntilRequestsReceived(1));
443   server_->Send200(GetConnectionId(0), "Response!", "text/plain");
444
445   std::string response;
446   ASSERT_TRUE(client.ReadResponse(&response));
447   ASSERT_TRUE(StartsWithASCII(response, "HTTP/1.1 200 OK", true));
448   ASSERT_TRUE(EndsWith(response, "Response!", true));
449 }
450
451 TEST_F(HttpServerTest, SendRaw) {
452   TestHttpClient client;
453   ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
454   client.Send("GET /test HTTP/1.1\r\n\r\n");
455   ASSERT_TRUE(RunUntilRequestsReceived(1));
456   server_->SendRaw(GetConnectionId(0), "Raw Data ");
457   server_->SendRaw(GetConnectionId(0), "More Data");
458   server_->SendRaw(GetConnectionId(0), "Third Piece of Data");
459
460   const std::string expected_response("Raw Data More DataThird Piece of Data");
461   std::string response;
462   ASSERT_TRUE(client.Read(&response, expected_response.length()));
463   ASSERT_EQ(expected_response, response);
464 }
465
// In-memory StreamSocket used to feed bytes to HttpServer without a network.
// Tests inject incoming data with DidRead(); Connect(), Write() and the
// address/buffer-size setters all return ERR_NOT_IMPLEMENTED.
class MockStreamSocket : public StreamSocket {
 public:
  MockStreamSocket()
      : connected_(true),
        read_buf_(NULL),
        read_buf_len_(0) {}

  // StreamSocket
  int Connect(const CompletionCallback& callback) override {
    return ERR_NOT_IMPLEMENTED;
  }
  // Marks the socket disconnected and fails any pending Read() with
  // ERR_CONNECTION_CLOSED.
  void Disconnect() override {
    connected_ = false;
    if (!read_callback_.is_null()) {
      read_buf_ = NULL;
      read_buf_len_ = 0;
      base::ResetAndReturn(&read_callback_).Run(ERR_CONNECTION_CLOSED);
    }
  }
  bool IsConnected() const override { return connected_; }
  bool IsConnectedAndIdle() const override { return IsConnected(); }
  int GetPeerAddress(IPEndPoint* address) const override {
    return ERR_NOT_IMPLEMENTED;
  }
  int GetLocalAddress(IPEndPoint* address) const override {
    return ERR_NOT_IMPLEMENTED;
  }
  const BoundNetLog& NetLog() const override { return net_log_; }
  void SetSubresourceSpeculation() override {}
  void SetOmniboxSpeculation() override {}
  bool WasEverUsed() const override { return true; }
  bool UsingTCPFastOpen() const override { return false; }
  bool WasNpnNegotiated() const override { return false; }
  NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; }
  bool GetSSLInfo(SSLInfo* ssl_info) override { return false; }

  // Socket
  // Returns buffered bytes immediately (at most |buf_len|) when DidRead() has
  // supplied any. Otherwise remembers |buf|, |buf_len| and |callback| and
  // returns ERR_IO_PENDING; the next DidRead() completes the read.
  int Read(IOBuffer* buf,
           int buf_len,
           const CompletionCallback& callback) override {
    if (!connected_) {
      return ERR_SOCKET_NOT_CONNECTED;
    }
    if (pending_read_data_.empty()) {
      read_buf_ = buf;
      read_buf_len_ = buf_len;
      read_callback_ = callback;
      return ERR_IO_PENDING;
    }
    DCHECK_GT(buf_len, 0);
    int read_len = std::min(static_cast<int>(pending_read_data_.size()),
                            buf_len);
    memcpy(buf->data(), pending_read_data_.data(), read_len);
    pending_read_data_.erase(0, read_len);
    return read_len;
  }
  int Write(IOBuffer* buf,
            int buf_len,
            const CompletionCallback& callback) override {
    return ERR_NOT_IMPLEMENTED;
  }
  int SetReceiveBufferSize(int32 size) override { return ERR_NOT_IMPLEMENTED; }
  int SetSendBufferSize(int32 size) override { return ERR_NOT_IMPLEMENTED; }

  // Simulates |data_len| bytes arriving from the peer. With no Read()
  // pending, the bytes are appended to the buffer for later reads. Otherwise
  // up to the pending buffer's length is copied into it, the remainder is
  // buffered, and the pending callback runs with the number of bytes copied.
  void DidRead(const char* data, int data_len) {
    if (!read_buf_.get()) {
      pending_read_data_.append(data, data_len);
      return;
    }
    int read_len = std::min(data_len, read_buf_len_);
    memcpy(read_buf_->data(), data, read_len);
    // assign() cannot drop data here: Read() only leaves a read pending when
    // pending_read_data_ is empty.
    pending_read_data_.assign(data + read_len, data_len - read_len);
    // Clear the pending-read state before running the callback, which may
    // immediately issue another Read().
    read_buf_ = NULL;
    read_buf_len_ = 0;
    base::ResetAndReturn(&read_callback_).Run(read_len);
  }

 private:
  // Not public: in this file the mock is owned and destroyed through a
  // StreamSocket pointer.
  ~MockStreamSocket() override {}

  bool connected_;
  // Buffer, length and callback of the pending Read(); a non-null
  // |read_buf_| marks a read as pending.
  scoped_refptr<IOBuffer> read_buf_;
  int read_buf_len_;
  CompletionCallback read_callback_;
  // Bytes delivered by DidRead() that no Read() has consumed yet.
  std::string pending_read_data_;
  BoundNetLog net_log_;

  DISALLOW_COPY_AND_ASSIGN(MockStreamSocket);
};
555
556 TEST_F(HttpServerTest, RequestWithBodySplitAcrossPackets) {
557   MockStreamSocket* socket = new MockStreamSocket();
558   HandleAcceptResult(make_scoped_ptr<StreamSocket>(socket));
559   std::string body("body");
560   std::string request_text = base::StringPrintf(
561       "GET /test HTTP/1.1\r\n"
562       "SomeHeader: 1\r\n"
563       "Content-Length: %" PRIuS "\r\n\r\n%s",
564       body.length(),
565       body.c_str());
566   socket->DidRead(request_text.c_str(), request_text.length() - 2);
567   ASSERT_EQ(0u, requests_.size());
568   socket->DidRead(request_text.c_str() + request_text.length() - 2, 2);
569   ASSERT_EQ(1u, requests_.size());
570   ASSERT_EQ(body, GetRequest(0).data);
571 }
572
573 TEST_F(HttpServerTest, MultipleRequestsOnSameConnection) {
574   // The idea behind this test is that requests with or without bodies should
575   // not break parsing of the next request.
576   TestHttpClient client;
577   ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
578   std::string body = "body";
579   client.Send(base::StringPrintf(
580       "GET /test HTTP/1.1\r\n"
581       "Content-Length: %" PRIuS "\r\n\r\n%s",
582       body.length(),
583       body.c_str()));
584   ASSERT_TRUE(RunUntilRequestsReceived(1));
585   ASSERT_EQ(body, GetRequest(0).data);
586
587   int client_connection_id = GetConnectionId(0);
588   server_->Send200(client_connection_id, "Content for /test", "text/plain");
589   std::string response1;
590   ASSERT_TRUE(client.ReadResponse(&response1));
591   ASSERT_TRUE(StartsWithASCII(response1, "HTTP/1.1 200 OK", true));
592   ASSERT_TRUE(EndsWith(response1, "Content for /test", true));
593
594   client.Send("GET /test2 HTTP/1.1\r\n\r\n");
595   ASSERT_TRUE(RunUntilRequestsReceived(2));
596   ASSERT_EQ("/test2", GetRequest(1).path);
597
598   ASSERT_EQ(client_connection_id, GetConnectionId(1));
599   server_->Send404(client_connection_id);
600   std::string response2;
601   ASSERT_TRUE(client.ReadResponse(&response2));
602   ASSERT_TRUE(StartsWithASCII(response2, "HTTP/1.1 404 Not Found", true));
603
604   client.Send("GET /test3 HTTP/1.1\r\n\r\n");
605   ASSERT_TRUE(RunUntilRequestsReceived(3));
606   ASSERT_EQ("/test3", GetRequest(2).path);
607
608   ASSERT_EQ(client_connection_id, GetConnectionId(2));
609   server_->Send200(client_connection_id, "Content for /test3", "text/plain");
610   std::string response3;
611   ASSERT_TRUE(client.ReadResponse(&response3));
612   ASSERT_TRUE(StartsWithASCII(response3, "HTTP/1.1 200 OK", true));
613   ASSERT_TRUE(EndsWith(response3, "Content for /test3", true));
614 }
615
616 class CloseOnConnectHttpServerTest : public HttpServerTest {
617  public:
618   void OnConnect(int connection_id) override {
619     connection_ids_.push_back(connection_id);
620     server_->Close(connection_id);
621   }
622
623  protected:
624   std::vector<int> connection_ids_;
625 };
626
627 TEST_F(CloseOnConnectHttpServerTest, ServerImmediatelyClosesConnection) {
628   TestHttpClient client;
629   ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
630   client.Send("GET / HTTP/1.1\r\n\r\n");
631   ASSERT_FALSE(RunUntilRequestsReceived(1));
632   ASSERT_EQ(1ul, connection_ids_.size());
633   ASSERT_EQ(0ul, requests_.size());
634 }
635
636 }  // namespace
637
638 }  // namespace net