Upstream version 10.39.225.0
[platform/framework/web/crosswalk.git] / src / net / server / http_server_unittest.cc
index 467bde4..216cb03 100644 (file)
@@ -2,11 +2,13 @@
 // Use of this source code is governed by a BSD-style license that can be
 // found in the LICENSE file.
 
+#include <algorithm>
 #include <utility>
 #include <vector>
 
 #include "base/bind.h"
 #include "base/bind_helpers.h"
+#include "base/callback_helpers.h"
 #include "base/compiler_specific.h"
 #include "base/format_macros.h"
 #include "base/memory/ref_counted.h"
 #include "net/base/ip_endpoint.h"
 #include "net/base/net_errors.h"
 #include "net/base/net_log.h"
+#include "net/base/net_util.h"
 #include "net/base/test_completion_callback.h"
+#include "net/http/http_response_headers.h"
+#include "net/http/http_util.h"
 #include "net/server/http_server.h"
 #include "net/server/http_server_request_info.h"
 #include "net/socket/tcp_client_socket.h"
-#include "net/socket/tcp_listen_socket.h"
+#include "net/socket/tcp_server_socket.h"
 #include "net/url_request/url_fetcher.h"
 #include "net/url_request/url_fetcher_delegate.h"
 #include "net/url_request/url_request_context.h"
@@ -90,10 +95,6 @@ class TestHttpClient {
     Write();
   }
 
-  bool Read(std::string* message) {
-    return Read(message, 1);
-  }
-
   bool Read(std::string* message, int expected_bytes) {
     int total_bytes_received = 0;
     message->clear();
@@ -110,6 +111,18 @@ class TestHttpClient {
     return true;
   }
 
+  bool ReadResponse(std::string* message) {
+    if (!Read(message, 1))
+      return false;
+    while (!IsCompleteResponse(*message)) {
+      std::string chunk;
+      if (!Read(&chunk, 1))
+        return false;
+      message->append(chunk);
+    }
+    return true;
+  }
+
  private:
   void OnConnect(const base::Closure& quit_loop, int result) {
     connect_result_ = result;
@@ -134,13 +147,27 @@ class TestHttpClient {
 
   void ReadInternal(const net::CompletionCallback& callback) {
     read_buffer_ = new IOBufferWithSize(kMaxExpectedResponseLength);
-    int result = socket_->Read(read_buffer_,
-                               kMaxExpectedResponseLength,
-                               callback);
+    int result =
+        socket_->Read(read_buffer_.get(), kMaxExpectedResponseLength, callback);
     if (result != ERR_IO_PENDING)
       callback.Run(result);
   }
 
+  bool IsCompleteResponse(const std::string& response) {
+    // Check end of headers first.
+    int end_of_headers = HttpUtil::LocateEndOfHeaders(response.data(),
+                                                      response.size());
+    if (end_of_headers < 0)
+      return false;
+
+    // Return true if response has data equal to or more than content length.
+    int64 body_size = static_cast<int64>(response.size()) - end_of_headers;
+    DCHECK_LE(0, body_size);
+    scoped_refptr<HttpResponseHeaders> headers(new HttpResponseHeaders(
+        HttpUtil::AssembleRawHeaders(response.data(), end_of_headers)));
+    return body_size >= headers->GetContentLength();
+  }
+
   scoped_refptr<IOBufferWithSize> read_buffer_;
   scoped_refptr<DrainableIOBuffer> write_buffer_;
   scoped_ptr<TCPClientSocket> socket_;
@@ -155,11 +182,15 @@ class HttpServerTest : public testing::Test,
   HttpServerTest() : quit_after_request_count_(0) {}
 
   virtual void SetUp() OVERRIDE {
-    TCPListenSocketFactory socket_factory("127.0.0.1", 0);
-    server_ = new HttpServer(socket_factory, this);
+    scoped_ptr<ServerSocket> server_socket(
+        new TCPServerSocket(NULL, net::NetLog::Source()));
+    server_socket->ListenWithAddressAndPort("127.0.0.1", 0, 1);
+    server_.reset(new HttpServer(server_socket.Pass(), this));
     ASSERT_EQ(OK, server_->GetLocalAddress(&server_address_));
   }
 
+  virtual void OnConnect(int connection_id) OVERRIDE {}
+
   virtual void OnHttpRequest(int connection_id,
                              const HttpServerRequestInfo& info) OVERRIDE {
     requests_.push_back(std::make_pair(info, connection_id));
@@ -199,8 +230,13 @@ class HttpServerTest : public testing::Test,
     return requests_[request_index].second;
   }
 
+  void HandleAcceptResult(scoped_ptr<StreamSocket> socket) {
+    server_->accepted_socket_.reset(socket.release());
+    server_->HandleAcceptResult(OK);
+  }
+
  protected:
-  scoped_refptr<HttpServer> server_;
+  scoped_ptr<HttpServer> server_;
   IPEndPoint server_address_;
   base::Closure run_loop_quit_func_;
   std::vector<std::pair<HttpServerRequestInfo, int> > requests_;
@@ -209,6 +245,8 @@ class HttpServerTest : public testing::Test,
   size_t quit_after_request_count_;
 };
 
+namespace {
+
 class WebSocketTest : public HttpServerTest {
   virtual void OnHttpRequest(int connection_id,
                              const HttpServerRequestInfo& info) OVERRIDE {
@@ -407,7 +445,7 @@ TEST_F(HttpServerTest, Send200) {
   server_->Send200(GetConnectionId(0), "Response!", "text/plain");
 
   std::string response;
-  ASSERT_TRUE(client.Read(&response));
+  ASSERT_TRUE(client.ReadResponse(&response));
   ASSERT_TRUE(StartsWithASCII(response, "HTTP/1.1 200 OK", true));
   ASSERT_TRUE(EndsWith(response, "Response!", true));
 }
@@ -427,25 +465,103 @@ TEST_F(HttpServerTest, SendRaw) {
   ASSERT_EQ(expected_response, response);
 }
 
-namespace {
-
-class MockStreamListenSocket : public StreamListenSocket {
+class MockStreamSocket : public StreamSocket {
  public:
-  MockStreamListenSocket(StreamListenSocket::Delegate* delegate)
-      : StreamListenSocket(kInvalidSocket, delegate) {}
+  MockStreamSocket()
+      : connected_(true),
+        read_buf_(NULL),
+        read_buf_len_(0) {}
+
+  // StreamSocket
+  virtual int Connect(const CompletionCallback& callback) OVERRIDE {
+    return ERR_NOT_IMPLEMENTED;
+  }
+  virtual void Disconnect() OVERRIDE {
+    connected_ = false;
+    if (!read_callback_.is_null()) {
+      read_buf_ = NULL;
+      read_buf_len_ = 0;
+      base::ResetAndReturn(&read_callback_).Run(ERR_CONNECTION_CLOSED);
+    }
+  }
+  virtual bool IsConnected() const OVERRIDE { return connected_; }
+  virtual bool IsConnectedAndIdle() const OVERRIDE { return IsConnected(); }
+  virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE {
+    return ERR_NOT_IMPLEMENTED;
+  }
+  virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE {
+    return ERR_NOT_IMPLEMENTED;
+  }
+  virtual const BoundNetLog& NetLog() const OVERRIDE { return net_log_; }
+  virtual void SetSubresourceSpeculation() OVERRIDE {}
+  virtual void SetOmniboxSpeculation() OVERRIDE {}
+  virtual bool WasEverUsed() const OVERRIDE { return true; }
+  virtual bool UsingTCPFastOpen() const OVERRIDE { return false; }
+  virtual bool WasNpnNegotiated() const OVERRIDE { return false; }
+  virtual NextProto GetNegotiatedProtocol() const OVERRIDE {
+    return kProtoUnknown;
+  }
+  virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { return false; }
 
-  virtual void Accept() OVERRIDE { NOTREACHED(); }
+  // Socket
+  virtual int Read(IOBuffer* buf, int buf_len,
+                   const CompletionCallback& callback) OVERRIDE {
+    if (!connected_) {
+      return ERR_SOCKET_NOT_CONNECTED;
+    }
+    if (pending_read_data_.empty()) {
+      read_buf_ = buf;
+      read_buf_len_ = buf_len;
+      read_callback_ = callback;
+      return ERR_IO_PENDING;
+    }
+    DCHECK_GT(buf_len, 0);
+    int read_len = std::min(static_cast<int>(pending_read_data_.size()),
+                            buf_len);
+    memcpy(buf->data(), pending_read_data_.data(), read_len);
+    pending_read_data_.erase(0, read_len);
+    return read_len;
+  }
+  virtual int Write(IOBuffer* buf, int buf_len,
+                    const CompletionCallback& callback) OVERRIDE {
+    return ERR_NOT_IMPLEMENTED;
+  }
+  virtual int SetReceiveBufferSize(int32 size) OVERRIDE {
+    return ERR_NOT_IMPLEMENTED;
+  }
+  virtual int SetSendBufferSize(int32 size) OVERRIDE {
+    return ERR_NOT_IMPLEMENTED;
+  }
+
+  void DidRead(const char* data, int data_len) {
+    if (!read_buf_.get()) {
+      pending_read_data_.append(data, data_len);
+      return;
+    }
+    int read_len = std::min(data_len, read_buf_len_);
+    memcpy(read_buf_->data(), data, read_len);
+    pending_read_data_.assign(data + read_len, data_len - read_len);
+    read_buf_ = NULL;
+    read_buf_len_ = 0;
+    base::ResetAndReturn(&read_callback_).Run(read_len);
+  }
 
  private:
-  virtual ~MockStreamListenSocket() {}
-};
+  virtual ~MockStreamSocket() {}
 
-}  // namespace
+  bool connected_;
+  scoped_refptr<IOBuffer> read_buf_;
+  int read_buf_len_;
+  CompletionCallback read_callback_;
+  std::string pending_read_data_;
+  BoundNetLog net_log_;
+
+  DISALLOW_COPY_AND_ASSIGN(MockStreamSocket);
+};
 
 TEST_F(HttpServerTest, RequestWithBodySplitAcrossPackets) {
-  StreamListenSocket* socket =
-      new MockStreamListenSocket(server_.get());
-  server_->DidAccept(NULL, make_scoped_ptr(socket));
+  MockStreamSocket* socket = new MockStreamSocket();
+  HandleAcceptResult(make_scoped_ptr<StreamSocket>(socket));
   std::string body("body");
   std::string request_text = base::StringPrintf(
       "GET /test HTTP/1.1\r\n"
@@ -453,9 +569,9 @@ TEST_F(HttpServerTest, RequestWithBodySplitAcrossPackets) {
       "Content-Length: %" PRIuS "\r\n\r\n%s",
       body.length(),
       body.c_str());
-  server_->DidRead(socket, request_text.c_str(), request_text.length() - 2);
+  socket->DidRead(request_text.c_str(), request_text.length() - 2);
   ASSERT_EQ(0u, requests_.size());
-  server_->DidRead(socket, request_text.c_str() + request_text.length() - 2, 2);
+  socket->DidRead(request_text.c_str() + request_text.length() - 2, 2);
   ASSERT_EQ(1u, requests_.size());
   ASSERT_EQ(body, GetRequest(0).data);
 }
@@ -477,7 +593,7 @@ TEST_F(HttpServerTest, MultipleRequestsOnSameConnection) {
   int client_connection_id = GetConnectionId(0);
   server_->Send200(client_connection_id, "Content for /test", "text/plain");
   std::string response1;
-  ASSERT_TRUE(client.Read(&response1));
+  ASSERT_TRUE(client.ReadResponse(&response1));
   ASSERT_TRUE(StartsWithASCII(response1, "HTTP/1.1 200 OK", true));
   ASSERT_TRUE(EndsWith(response1, "Content for /test", true));
 
@@ -488,7 +604,7 @@ TEST_F(HttpServerTest, MultipleRequestsOnSameConnection) {
   ASSERT_EQ(client_connection_id, GetConnectionId(1));
   server_->Send404(client_connection_id);
   std::string response2;
-  ASSERT_TRUE(client.Read(&response2));
+  ASSERT_TRUE(client.ReadResponse(&response2));
   ASSERT_TRUE(StartsWithASCII(response2, "HTTP/1.1 404 Not Found", true));
 
   client.Send("GET /test3 HTTP/1.1\r\n\r\n");
@@ -498,9 +614,31 @@ TEST_F(HttpServerTest, MultipleRequestsOnSameConnection) {
   ASSERT_EQ(client_connection_id, GetConnectionId(2));
   server_->Send200(client_connection_id, "Content for /test3", "text/plain");
   std::string response3;
-  ASSERT_TRUE(client.Read(&response3));
+  ASSERT_TRUE(client.ReadResponse(&response3));
   ASSERT_TRUE(StartsWithASCII(response3, "HTTP/1.1 200 OK", true));
   ASSERT_TRUE(EndsWith(response3, "Content for /test3", true));
 }
 
+class CloseOnConnectHttpServerTest : public HttpServerTest {
+ public:
+  virtual void OnConnect(int connection_id) OVERRIDE {
+    connection_ids_.push_back(connection_id);
+    server_->Close(connection_id);
+  }
+
+ protected:
+  std::vector<int> connection_ids_;
+};
+
+TEST_F(CloseOnConnectHttpServerTest, ServerImmediatelyClosesConnection) {
+  TestHttpClient client;
+  ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
+  client.Send("GET / HTTP/1.1\r\n\r\n");
+  ASSERT_FALSE(RunUntilRequestsReceived(1));
+  ASSERT_EQ(1ul, connection_ids_.size());
+  ASSERT_EQ(0ul, requests_.size());
+}
+
+}  // namespace
+
 }  // namespace net