1 // Copyright (c) 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
10 #include "base/at_exit.h"
11 #include "base/bind.h"
12 #include "base/callback.h"
13 #include "base/command_line.h"
14 #include "base/files/file_path.h"
15 #include "base/lazy_instance.h"
16 #include "base/logging.h"
17 #include "base/memory/scoped_ptr.h"
18 #include "base/message_loop/message_loop.h"
19 #include "base/run_loop.h"
20 #include "base/strings/string_number_conversions.h"
21 #include "base/strings/string_split.h"
22 #include "base/strings/string_util.h"
23 #include "base/strings/stringprintf.h"
24 #include "base/synchronization/waitable_event.h"
25 #include "base/threading/thread.h"
26 #include "base/threading/thread_local.h"
27 #include "chrome/test/chromedriver/logging.h"
28 #include "chrome/test/chromedriver/net/port_server.h"
29 #include "chrome/test/chromedriver/server/http_handler.h"
30 #include "chrome/test/chromedriver/version.h"
31 #include "net/base/ip_endpoint.h"
32 #include "net/base/net_errors.h"
33 #include "net/server/http_server.h"
34 #include "net/server/http_server_request_info.h"
35 #include "net/server/http_server_response_info.h"
36 #include "net/socket/tcp_server_socket.h"
// Default address HttpServer::Start() binds to. Requests whose peer address
// equals this value also bypass the IP whitelist in HandleRequestOnCmdThread().
// Declared as a char array rather than `const char*` so the name is a true
// constant: with `const char*` only the characters are const and the global
// pointer itself could be re-pointed at runtime.
const char kLocalHostAddress[] = "127.0.0.1";

// Socket send/receive buffer size applied to every accepted connection.
const int kBufferSize = 100 * 1024 * 1024;  // 100 MB
43 typedef base::Callback<
44 void(const net::HttpServerRequestInfo&, const HttpResponseSenderFunc&)>
45 HttpRequestHandlerFunc;
// HTTP server delegate for ChromeDriver. Owns a net::HttpServer, hands every
// incoming HTTP request to |handle_request_func_| and writes each response back
// on the connection it arrived on. WebSocket requests/messages and
// connection-close notifications are ignored (empty overrides).
// NOTE(review): this listing omits some lines of the class (the embedded
// numbering skips e.g. 48, 57, 65, 75, 78-79, 81, 87 and the closing brace);
// the comments below describe only what is shown.
47 class HttpServer : public net::HttpServer::Delegate {
// |handle_request_func| is run once for each HTTP request received.
49 explicit HttpServer(const HttpRequestHandlerFunc& handle_request_func)
50 : handle_request_func_(handle_request_func),
51 weak_factory_(this) {}
53 virtual ~HttpServer() {}
// Creates a TCP server socket listening on |port| and wraps it in a
// net::HttpServer that reports events to this delegate. The bind address
// starts as kLocalHostAddress and may be replaced by "0.0.0.0" (all
// interfaces). Returns true iff the new server can report its local address.
// NOTE(review): as shown, the "0.0.0.0" assignment is unconditional and
// |allow_remote| is never read; the numbering skips line 57, which presumably
// holds an `if (allow_remote)` guard -- confirm, since otherwise the server
// would always listen on all interfaces.
// NOTE(review): the return value of ListenWithAddressAndPort() is ignored, so
// a failed listen is only detected indirectly via GetLocalAddress(). The
// trailing 1 is presumably the listen backlog.
55 bool Start(int port, bool allow_remote) {
56 std::string binding_ip = kLocalHostAddress;
58 binding_ip = "0.0.0.0";
59 scoped_ptr<net::ServerSocket> server_socket(
60 new net::TCPServerSocket(NULL, net::NetLog::Source()));
61 server_socket->ListenWithAddressAndPort(binding_ip, port, 1);
62 server_.reset(new net::HttpServer(server_socket.Pass(), this));
63 net::IPEndPoint address;
64 return server_->GetLocalAddress(&address) == net::OK;
67 // Overridden from net::HttpServer::Delegate:
// Raises this connection's socket send and receive buffer sizes to
// kBufferSize.
68 virtual void OnConnect(int connection_id) OVERRIDE {
69 server_->SetSendBufferSize(connection_id, kBufferSize);
70 server_->SetReceiveBufferSize(connection_id, kBufferSize);
// Passes |info| to |handle_request_func_| together with a sender callback
// bound to OnResponse() through a WeakPtr. If this object has been destroyed
// by the time the response arrives, the response is dropped instead of
// touching freed memory.
72 virtual void OnHttpRequest(int connection_id,
73 const net::HttpServerRequestInfo& info) OVERRIDE {
74 handle_request_func_.Run(
76 base::Bind(&HttpServer::OnResponse,
77 weak_factory_.GetWeakPtr(),
80 virtual void OnWebSocketRequest(
82 const net::HttpServerRequestInfo& info) OVERRIDE {}
83 virtual void OnWebSocketMessage(int connection_id,
84 const std::string& data) OVERRIDE {}
85 virtual void OnClose(int connection_id) OVERRIDE {}
// Sends |response| on |connection_id| after forcing a "Connection: close"
// header.
88 void OnResponse(int connection_id,
89 scoped_ptr<net::HttpServerResponseInfo> response) {
90 // Don't support keep-alive, since there's no way to detect if the
91 // client is HTTP/1.0. In such cases, the client may hang waiting for
92 // the connection to close (e.g., python 2.7 urllib).
93 response->AddHeader("Connection", "close");
94 server_->SendResponse(connection_id, *response);
95 // Don't need to call server_->Close(), since SendResponse() will handle
// NOTE(review): the remainder of the comment above (embedded line 96) is not
// shown in this excerpt.
// Request callback supplied at construction.
99 HttpRequestHandlerFunc handle_request_func_;
// The underlying server, created by Start().
100 scoped_ptr<net::HttpServer> server_;
// Declared last so it is destroyed first. This invalidates outstanding WeakPtrs
// before the other members go away.
101 base::WeakPtrFactory<HttpServer> weak_factory_; // Should be last.
// Sender wrapper used on the command thread. It hands a finished response back
// to the IO side by posting a task to |io_task_runner| that runs
// |send_response_on_io_func| with |response|. Ownership of the response moves
// into the posted task via base::Passed.
// NOTE(review): the function's closing brace (embedded line 110) is not shown
// in this excerpt.
104 void SendResponseOnCmdThread(
105 const scoped_refptr<base::SingleThreadTaskRunner>& io_task_runner,
106 const HttpResponseSenderFunc& send_response_on_io_func,
107 scoped_ptr<net::HttpServerResponseInfo> response) {
108 io_task_runner->PostTask(
109 FROM_HERE, base::Bind(send_response_on_io_func, base::Passed(&response)));
// Request entry point on the command thread.
// - If |whitelisted_ips| is empty, every request is passed to |handler|.
// - Otherwise, a request whose peer address (without port) is neither
//   kLocalHostAddress nor in the list is logged as a warning. It is answered
//   with HTTP 401 "Unauthorized access" through |send_response_func|.
// Requests that are not rejected go to handler->Handle() together with
// |send_response_func|.
// NOTE(review): as shown, there is no `return` after the 401 is sent, so a
// rejected request would still reach handler->Handle(). The embedded numbering
// skips 127-130, which presumably hold the return and closing braces. Confirm
// this, since this is where the whitelist is enforced.
// NOTE(review): std::find is declared in <algorithm>, which is not among the
// includes shown in this file.
112 void HandleRequestOnCmdThread(
113 HttpHandler* handler,
114 const std::vector<std::string>& whitelisted_ips,
115 const net::HttpServerRequestInfo& request,
116 const HttpResponseSenderFunc& send_response_func) {
117 if (!whitelisted_ips.empty()) {
118 std::string peer_address = request.peer.ToStringWithoutPort();
// Local connections are always allowed; any other peer must be whitelisted.
119 if (peer_address != kLocalHostAddress &&
120 std::find(whitelisted_ips.begin(), whitelisted_ips.end(),
121 peer_address) == whitelisted_ips.end()) {
122 LOG(WARNING) << "unauthorized access from " << request.peer.ToString();
123 scoped_ptr<net::HttpServerResponseInfo> response(
124 new net::HttpServerResponseInfo(net::HTTP_UNAUTHORIZED));
125 response->SetBody("Unauthorized access", "text/plain");
126 send_response_func.Run(response.Pass());
131 handler->Handle(request, send_response_func);
// Called for each request on the thread that runs the HttpServer (presumably
// the IO thread, given the name and how RunServer wires it up). It re-posts the
// request to |cmd_task_runner|, where |handle_request_on_cmd_func| runs.
// |send_response_func| is replaced by a SendResponseOnCmdThread wrapper bound
// to this thread's MessageLoopProxy, so the eventual response is posted back
// here to be sent.
// NOTE(review): the embedded numbering skips 140 and 142 (presumably FROM_HERE
// and |request|) and the closing brace, so the PostTask/Bind argument lists
// are incomplete as shown.
134 void HandleRequestOnIOThread(
135 const scoped_refptr<base::SingleThreadTaskRunner>& cmd_task_runner,
136 const HttpRequestHandlerFunc& handle_request_on_cmd_func,
137 const net::HttpServerRequestInfo& request,
138 const HttpResponseSenderFunc& send_response_func) {
139 cmd_task_runner->PostTask(
141 base::Bind(handle_request_on_cmd_func,
143 base::Bind(&SendResponseOnCmdThread,
144 base::MessageLoopProxy::current(),
145 send_response_func)));
148 base::LazyInstance<base::ThreadLocalPointer<HttpServer> >
149 lazy_tls_server = LAZY_INSTANCE_INITIALIZER;
// Takes the HttpServer (if any) out of the thread-local slot filled by
// StartServerOnIOThread and clears the slot. RunServer posts this task to the
// IO thread after the command loop section.
// NOTE(review): as shown, |server| is read but neither used nor freed. The
// embedded numbering skips 155-156, which presumably release it. Confirm this,
// or the server leaks.
151 void StopServerOnIOThread() {
152 // Note, |server| may be NULL.
153 HttpServer* server = lazy_tls_server.Pointer()->Get();
154 lazy_tls_server.Pointer()->Set(NULL);
// Creates an HttpServer that forwards requests to |handle_request_func| and
// starts it on |port|. If Start() fails, "Port not available" is printed. On
// success, ownership moves into the thread-local |lazy_tls_server| slot, from
// which StopServerOnIOThread later takes it.
// NOTE(review): the embedded numbering skips 159 (presumably the
// |allow_remote| parameter used below) and 164-165 (presumably the exit after
// the failure message). Confirm against the full file.
158 void StartServerOnIOThread(int port,
160 const HttpRequestHandlerFunc& handle_request_func) {
161 scoped_ptr<HttpServer> temp_server(new HttpServer(handle_request_func));
162 if (!temp_server->Start(port, allow_remote)) {
163 printf("Port not available. Exiting...\n");
166 lazy_tls_server.Pointer()->Set(temp_server.release());
// Serves ChromeDriver's HTTP endpoint.
// - A dedicated "ChromeDriver IO" thread hosts the HttpServer.
// - Each request is re-posted to the command message loop created here on the
//   calling thread.
// - On the command loop, HandleRequestOnCmdThread applies |whitelisted_ips|
//   and dispatches to |handler|, which is given |cmd_run_loop|'s quit closure.
// Per the comment in the body, the command loop is quit after a shutdown
// request has been answered, and the server is then stopped on the IO thread.
// NOTE(review): the embedded numbering skips 170, 173, 183-185, 192-193 and
// 202. As a result, the allow_remote/adb_port parameters, part of the
// HttpHandler and StartServerOnIOThread argument lists, and the statement that
// runs |cmd_run_loop| are not shown.
169 void RunServer(int port,
171 const std::vector<std::string>& whitelisted_ips,
172 const std::string& url_base,
174 scoped_ptr<PortServer> port_server) {
// Start the IO thread with an IO-type message loop; CHECK-fails on error.
175 base::Thread io_thread("ChromeDriver IO");
176 CHECK(io_thread.StartWithOptions(
177 base::Thread::Options(base::MessageLoop::TYPE_IO, 0)));
// The command loop lives on the calling thread.
179 base::MessageLoop cmd_loop;
180 base::RunLoop cmd_run_loop;
181 HttpHandler handler(cmd_run_loop.QuitClosure(),
182 io_thread.message_loop_proxy(),
186 HttpRequestHandlerFunc handle_request_func =
187 base::Bind(&HandleRequestOnCmdThread, &handler, whitelisted_ips);
// Create and start the server on the IO thread. It forwards each request to
// the command loop through HandleRequestOnIOThread.
189 io_thread.message_loop()
190 ->PostTask(FROM_HERE,
191 base::Bind(&StartServerOnIOThread,
194 base::Bind(&HandleRequestOnIOThread,
195 cmd_loop.message_loop_proxy(),
196 handle_request_func)));
197 // Run the command loop. This loop is quit after the response for a shutdown
198 // request is posted to the IO loop. After the command loop quits, a task
199 // is posted to the IO loop to stop the server. Lastly, the IO thread is
200 // destroyed, which waits until all pending tasks have been completed.
201 // This assumes the response is sent synchronously as part of the IO task.
203 io_thread.message_loop()
204 ->PostTask(FROM_HERE, base::Bind(&StopServerOnIOThread));
// ChromeDriver entry point. Parses command-line switches, prints the startup
// banner, initializes logging and runs the HTTP server via RunServer().
// NOTE(review): this excerpt omits many lines of main (the embedded numbering
// skips e.g. 211, 220, 223-224, 230, 243, 246, 250-251, 292). In the lines
// shown:
// - |port|, |adb_port| and |options| have no visible declaration;
// - the help/version/error branches have no visible return or exit;
// - |allow_remote| is never set to true.
// Confirm all of these against the full file.
209 int main(int argc, char *argv[]) {
210 CommandLine::Init(argc, argv);
212 base::AtExitManager at_exit;
213 CommandLine* cmd_line = CommandLine::ForCurrentProcess();
215 #if defined(OS_LINUX)
216 // Select the locale from the environment by passing an empty string instead
217 // of the default "C" locale. This is particularly needed for the keycode
218 // conversion code to work.
219 setlocale(LC_ALL, "");
222 // Parse command line flags.
225 bool allow_remote = false;
226 std::vector<std::string> whitelisted_ips;
227 std::string url_base;
228 scoped_ptr<PortServer> port_server;
// -h / --help: print a usage line for each (option, description) pair below.
229 if (cmd_line->HasSwitch("h") || cmd_line->HasSwitch("help")) {
231 const char* kOptionAndDescriptions[] = {
232 "port=PORT", "port to listen on",
233 "adb-port=PORT", "adb server port",
234 "log-path=FILE", "write server log to file instead of stderr, "
235 "increases log level to INFO",
236 "verbose", "log verbosely",
237 "version", "print the version number and exit",
238 "silent", "log nothing",
239 "url-base", "base URL path prefix for commands, e.g. wd/url",
240 "port-server", "address of server to contact for reserving a port",
241 "whitelisted-ips", "comma-separated whitelist of remote IPv4 addresses "
242 "which are allowed to connect to ChromeDriver",
// Walk the array two entries at a time: [i] is the option and [i + 1] is its
// description.
244 for (size_t i = 0; i < arraysize(kOptionAndDescriptions) - 1; i += 2) {
245 options += base::StringPrintf(
247 kOptionAndDescriptions[i], kOptionAndDescriptions[i + 1]);
249 printf("Usage: %s [OPTIONS]\n\nOptions\n%s", argv[0], options.c_str());
// -v / --version: print the version string.
252 if (cmd_line->HasSwitch("v") || cmd_line->HasSwitch("version")) {
253 printf("ChromeDriver %s\n", kChromeDriverVersion);
// --port / --adb-port: each value must parse as an int, otherwise an error is
// printed.
256 if (cmd_line->HasSwitch("port")) {
257 if (!base::StringToInt(cmd_line->GetSwitchValueASCII("port"), &port)) {
258 printf("Invalid port. Exiting...\n");
262 if (cmd_line->HasSwitch("adb-port")) {
263 if (!base::StringToInt(cmd_line->GetSwitchValueASCII("adb-port"),
265 printf("Invalid adb-port. Exiting...\n");
// --port-server (Linux only): the value must start with '@'. The remainder is
// appended to a path handed to PortServer. On other platforms only a warning
// is printed.
269 if (cmd_line->HasSwitch("port-server")) {
270 #if defined(OS_LINUX)
271 std::string address = cmd_line->GetSwitchValueASCII("port-server");
272 if (address.empty() || address[0] != '@') {
273 printf("Invalid port-server. Exiting...\n");
277 // First character of path is \0 to use Linux's abstract namespace.
279 path += address.substr(1);
280 port_server.reset(new PortServer(path));
282 printf("Warning: port-server not implemented for this platform.\n");
// --url-base: normalize to start and end with '/' (e.g. "" -> "/",
// "wd/hub" -> "/wd/hub/"). After the first fix-up the string is non-empty, so
// the last-character check below is safe.
285 if (cmd_line->HasSwitch("url-base"))
286 url_base = cmd_line->GetSwitchValueASCII("url-base");
287 if (url_base.empty() || url_base[0] != '/')
288 url_base = "/" + url_base;
289 if (url_base[url_base.length() - 1] != '/')
290 url_base = url_base + "/";
// --whitelisted-ips: split the comma-separated value into |whitelisted_ips|.
291 if (cmd_line->HasSwitch("whitelisted-ips")) {
293 std::string whitelist = cmd_line->GetSwitchValueASCII("whitelisted-ips");
294 base::SplitString(whitelist, ',', &whitelisted_ips);
// Unless --silent is given, print the startup banner and the remote-access
// policy in effect.
296 if (!cmd_line->HasSwitch("silent")) {
298 "Starting ChromeDriver (v%s) on port %d\n", kChromeDriverVersion, port);
300 printf("Only local connections are allowed.\n");
301 } else if (!whitelisted_ips.empty()) {
302 printf("Remote connections are allowed by a whitelist (%s).\n",
303 cmd_line->GetSwitchValueASCII("whitelisted-ips").c_str());
305 printf("All remote connections are allowed. Use a whitelist instead!\n");
// Initialize logging; failure prints an error message.
310 if (!InitLogging()) {
311 printf("Unable to initialize logging. Exiting...\n");
314 RunServer(port, allow_remote, whitelisted_ips,
315 url_base, adb_port, port_server.Pass());