1 // Copyright (c) 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "chrome/browser/policy/cloud/test_request_interceptor.h"
10 #include "base/bind.h"
11 #include "base/bind_helpers.h"
12 #include "base/memory/scoped_ptr.h"
13 #include "base/run_loop.h"
14 #include "base/sequenced_task_runner.h"
15 #include "content/public/browser/browser_thread.h"
16 #include "net/base/net_errors.h"
17 #include "net/base/upload_bytes_element_reader.h"
18 #include "net/base/upload_data_stream.h"
19 #include "net/base/upload_element_reader.h"
20 #include "net/test/url_request/url_request_mock_http_job.h"
21 #include "net/url_request/url_request_error_job.h"
22 #include "net/url_request/url_request_filter.h"
23 #include "net/url_request/url_request_interceptor.h"
24 #include "net/url_request/url_request_test_job.h"
27 namespace em = enterprise_management;
// Helper callback for jobs that should fail with a network |error|.
// |error| is a net error code (MaybeInterceptRequest() passes e.g.
// net::ERR_CONNECTION_REFUSED). Returns a newly allocated
// net::URLRequestErrorJob for |request|; this function keeps no reference to
// it, so ownership passes to the caller.
net::URLRequestJob* ErrorJobCallback(int error,
net::URLRequest* request,
net::NetworkDelegate* network_delegate) {
return new net::URLRequestErrorJob(request, network_delegate, error);
// Helper callback for jobs that should fail with a 400 HTTP error.
// The job replies with a "400 Bad request" status line, a protobuf
// Content-type header and empty response data (std::string()).
// NOTE(review): the trailing |true| is presumably URLRequestTestJob's
// auto-advance flag — confirm against url_request_test_job.h.
net::URLRequestJob* BadRequestJobCallback(
net::URLRequest* request,
net::NetworkDelegate* network_delegate) {
// The raw header block separates lines with '\0'. |headers| is therefore
// built with an explicit length (arraysize) so that the embedded NULs are
// kept; the C-string std::string constructor would stop at the first one.
static const char kBadHeaders[] =
"HTTP/1.1 400 Bad request\0"
"Content-type: application/protobuf\0"
std::string headers(kBadHeaders, arraysize(kBadHeaders));
return new net::URLRequestTestJob(
request, network_delegate, headers, std::string(), true);
// Helper callback for jobs that answer with the contents of |file_path|
// through a net::URLRequestMockHTTPJob (see FileJob()).
// The job is handed a task runner from the browser's blocking pool with
// SKIP_ON_SHUTDOWN behavior. NOTE(review): presumably this runner is used
// for the job's file I/O — confirm against url_request_mock_http_job.h.
net::URLRequestJob* FileJobCallback(const base::FilePath& file_path,
net::URLRequest* request,
net::NetworkDelegate* network_delegate) {
return new net::URLRequestMockHTTPJob(
content::BrowserThread::GetBlockingPool()
->GetTaskRunnerWithShutdownBehavior(
base::SequencedWorkerPool::SKIP_ON_SHUTDOWN));
// Parses the upload data in |request| into |request_msg|, and validates the
// request. The query string in the URL must contain the |expected_type| for
// the "request" parameter. Returns true if all checks succeeded, and the
// request data has been parsed into |request_msg|.
// Checks, in order: the method is POST, the URL mentions
// "request=<expected_type>", the upload consists of exactly one element
// reader, and its bytes parse as an em::DeviceManagementRequest.
bool ValidRequest(net::URLRequest* request,
const std::string& expected_type,
em::DeviceManagementRequest* request_msg) {
if (request->method() != "POST")
std::string spec = request->url().spec();
// NOTE(review): this is a plain substring search over the whole URL spec,
// so e.g. "request=registerfoo" also satisfies |expected_type| "register".
// Parsing the "request" query parameter would be stricter.
if (spec.find("request=" + expected_type) == std::string::npos)
// This assumes that the payload data was set from a single string. In that
// case the UploadDataStream has a single UploadBytesElementReader that holds
// the whole payload, which is the layout the checks below require.
const net::UploadDataStream* stream = request->get_upload();
// Anything other than exactly one element reader is rejected.
const ScopedVector<net::UploadElementReader>* readers =
stream->GetElementReaders();
if (!readers || readers->size() != 1u)
const net::UploadBytesElementReader* reader = (*readers)[0]->AsBytesReader();
std::string data(reader->bytes(), reader->length());
// The raw upload bytes must deserialize into |request_msg|.
if (!request_msg->ParseFromString(data))
// Helper callback for register jobs that should succeed. Validates the request
// parameters and returns an appropriate response job. If |expect_reregister|
// is true then the reregister flag must be set in the DeviceRegisterRequest
// message; if it is false the flag must be absent or false. The register
// request's type must equal |expected_type|. Any failed check produces the
// 400 job from BadRequestJobCallback(). On success the response carries a
// serialized em::DeviceManagementResponse with a fixed fake DM token.
net::URLRequestJob* RegisterJobCallback(
em::DeviceRegisterRequest::Type expected_type,
bool expect_reregister,
net::URLRequest* request,
net::NetworkDelegate* network_delegate) {
// Generic checks: POST, "request=register" in the URL, parseable body.
em::DeviceManagementRequest request_msg;
if (!ValidRequest(request, "register", &request_msg))
return BadRequestJobCallback(request, network_delegate);
// The message must contain a register_request and none of the other
// sub-requests listed here.
if (!request_msg.has_register_request() ||
request_msg.has_unregister_request() ||
request_msg.has_policy_request() ||
request_msg.has_device_status_report_request() ||
request_msg.has_session_status_report_request() ||
request_msg.has_auto_enrollment_request()) {
return BadRequestJobCallback(request, network_delegate);
const em::DeviceRegisterRequest& register_request =
request_msg.register_request();
// The reregister flag must agree with |expect_reregister|; an absent flag
// is treated the same as false.
if (expect_reregister &&
(!register_request.has_reregister() || !register_request.reregister())) {
return BadRequestJobCallback(request, network_delegate);
} else if (!expect_reregister &&
register_request.has_reregister() &&
register_request.reregister()) {
return BadRequestJobCallback(request, network_delegate);
// The registration type must be present and match |expected_type|.
if (!register_request.has_type() || register_request.type() != expected_type)
return BadRequestJobCallback(request, network_delegate);
// All checks passed: build the register response with a fake token.
em::DeviceManagementResponse response;
em::DeviceRegisterResponse* register_response =
response.mutable_register_response();
register_response->set_device_management_token("s3cr3t70k3n");
response.SerializeToString(&data);
// As in BadRequestJobCallback(), the header lines are '\0'-separated, so the
// string is built with an explicit length to keep the embedded NULs.
static const char kGoodHeaders[] =
"Content-type: application/protobuf\0"
std::string headers(kGoodHeaders, arraysize(kGoodHeaders));
// The serialized response is the job's response data.
return new net::URLRequestTestJob(
request, network_delegate, headers, data, true);
// Installs |interceptor| in the global net::URLRequestFilter for "http"
// requests whose host is |hostname|. Ownership of |interceptor| is passed to
// the filter via Pass(). The TestRequestInterceptor constructor binds this
// function with its hostname and its Delegate.
void RegisterHttpInterceptor(
const std::string& hostname,
scoped_ptr<net::URLRequestInterceptor> interceptor) {
net::URLRequestFilter::GetInstance()->AddHostnameInterceptor(
"http", hostname, interceptor.Pass());
// URLRequestInterceptor that serves requests for |hostname_| from a FIFO
// queue of JobCallbacks. Requests to other hosts fail with
// ERR_CONNECTION_REFUSED, and an empty queue yields a 400 response. Every
// method that touches the queue CHECKs that it runs on |io_task_runner_|.
class TestRequestInterceptor::Delegate : public net::URLRequestInterceptor {
Delegate(const std::string& hostname,
scoped_refptr<base::SequencedTaskRunner> io_task_runner);
~Delegate() override;
// net::URLRequestInterceptor implementation:
net::URLRequestJob* MaybeInterceptRequest(
net::URLRequest* request,
net::NetworkDelegate* network_delegate) const override;
// Stores the number of queued callbacks in |*pending_size|.
void GetPendingSize(size_t* pending_size) const;
// Appends |callback|; it will build the job for a later intercepted request.
void PushJobCallback(const JobCallback& callback);
// The only host whose requests are served from the queue.
const std::string hostname_;
// Sequence on which the queue-touching methods must run (CHECKed).
scoped_refptr<base::SequencedTaskRunner> io_task_runner_;
// The queue of pending callbacks. 'mutable' because MaybeInterceptRequest()
// is a const method; it can't reenter though, because it runs exclusively on
// the IO thread (|io_task_runner_|).
mutable std::queue<JobCallback> pending_job_callbacks_;
183 TestRequestInterceptor::Delegate::Delegate(
184 const std::string& hostname,
185 scoped_refptr<base::SequencedTaskRunner> io_task_runner)
186 : hostname_(hostname), io_task_runner_(io_task_runner) {}
188 TestRequestInterceptor::Delegate::~Delegate() {}
// Chooses the job for |request|:
//  - a host other than |hostname_| gets an ERR_CONNECTION_REFUSED error job;
//  - if no callback is queued, the request gets the default 400 job;
//  - otherwise the oldest queued callback is removed and run to build the job.
net::URLRequestJob* TestRequestInterceptor::Delegate::MaybeInterceptRequest(
net::URLRequest* request,
net::NetworkDelegate* network_delegate) const {
// The queue is only ever accessed on |io_task_runner_|.
CHECK(io_task_runner_->RunsTasksOnCurrentThread());
if (request->url().host() != hostname_) {
// Reject requests to other servers.
return ErrorJobCallback(
net::ERR_CONNECTION_REFUSED, request, network_delegate);
if (pending_job_callbacks_.empty()) {
// Reject dmserver requests by default.
return BadRequestJobCallback(request, network_delegate);
// Take the oldest callback (std::queue is FIFO) off the queue before
// running it, so that each callback serves exactly one request.
JobCallback callback = pending_job_callbacks_.front();
pending_job_callbacks_.pop();
return callback.Run(request, network_delegate);
// IO-side half of TestRequestInterceptor::GetPendingSize(): writes the number
// of queued callbacks to |*pending_size|. Must run on |io_task_runner_|.
void TestRequestInterceptor::Delegate::GetPendingSize(
size_t* pending_size) const {
CHECK(io_task_runner_->RunsTasksOnCurrentThread());
*pending_size = pending_job_callbacks_.size();
// IO-side half of TestRequestInterceptor::PushJobCallback(): appends
// |callback| to the back of the queue. Must run on |io_task_runner_|.
void TestRequestInterceptor::Delegate::PushJobCallback(
const JobCallback& callback) {
CHECK(io_task_runner_->RunsTasksOnCurrentThread());
pending_job_callbacks_.push(callback);
// Creates the Delegate that will serve |hostname| and binds
// RegisterHttpInterceptor() with it. That call moves ownership of the
// delegate into the URLRequestFilter. |delegate_| is kept only as a
// non-owning pointer for later calls made on |io_task_runner_|.
TestRequestInterceptor::TestRequestInterceptor(const std::string& hostname,
scoped_refptr<base::SequencedTaskRunner> io_task_runner)
: hostname_(hostname),
io_task_runner_(io_task_runner) {
delegate_ = new Delegate(hostname_, io_task_runner_);
// |interceptor| owns the delegate until base::Passed() hands it on to
// RegisterHttpInterceptor().
scoped_ptr<net::URLRequestInterceptor> interceptor(delegate_);
base::Bind(&RegisterHttpInterceptor, hostname_,
base::Passed(&interceptor)));
// Unregisters this interceptor's hostname handler from the URLRequestFilter.
TestRequestInterceptor::~TestRequestInterceptor() {
// RemoveHostnameHandler() destroys the |delegate_|, which is owned by
// the URLRequestFilter. base::Unretained() is used on the filter instance
// from GetInstance(), which is assumed to outlive the bound call.
base::Bind(&net::URLRequestFilter::RemoveHostnameHandler,
base::Unretained(net::URLRequestFilter::GetInstance()),
// Returns how many job callbacks are still queued in the delegate. The count
// is read by Delegate::GetPendingSize() on the IO task runner, and
// PostToIOAndWait() waits for that read to finish. The max() initializer is
// only a placeholder that the delegate overwrites.
size_t TestRequestInterceptor::GetPendingSize() {
size_t pending_size = std::numeric_limits<size_t>::max();
PostToIOAndWait(base::Bind(&Delegate::GetPendingSize,
base::Unretained(delegate_),
// Queues |callback| to build the job for a future intercepted request. The
// push happens on the IO task runner, and PostToIOAndWait() waits for it.
// base::Unretained(delegate_) relies on the delegate staying registered
// (owned by the URLRequestFilter) for this object's lifetime.
void TestRequestInterceptor::PushJobCallback(const JobCallback& callback) {
PostToIOAndWait(base::Bind(&Delegate::PushJobCallback,
base::Unretained(delegate_),
// Returns a JobCallback that fails the request with the net error |error|.
TestRequestInterceptor::JobCallback TestRequestInterceptor::ErrorJob(
return base::Bind(&ErrorJobCallback, error);
// Returns a JobCallback that answers with a "400 Bad request" response.
TestRequestInterceptor::JobCallback TestRequestInterceptor::BadRequestJob() {
return base::Bind(&BadRequestJobCallback);
// Returns a JobCallback that validates a register request against
// |expected_type| and |expect_reregister|. If the request is valid, it
// answers with a register response carrying a fake DM token; otherwise it
// answers with a 400 response.
TestRequestInterceptor::JobCallback TestRequestInterceptor::RegisterJob(
em::DeviceRegisterRequest::Type expected_type,
bool expect_reregister) {
return base::Bind(&RegisterJobCallback, expected_type, expect_reregister);
// Returns a JobCallback that serves |file_path| through a
// net::URLRequestMockHTTPJob.
TestRequestInterceptor::JobCallback TestRequestInterceptor::FileJob(
const base::FilePath& file_path) {
return base::Bind(&FileJobCallback, file_path);
// Posts |task| to |io_task_runner_| and then waits on |run_loop| for a
// follow-up task on the same runner. That follow-up posts the RunLoop's
// QuitClosure() back to the calling thread's message loop. Because the
// runner is sequenced, the follow-up runs only after |task| has finished.
// NOTE(review): base::MessageLoopProxy is deprecated in later Chromium in
// favour of base::ThreadTaskRunnerHandle::Get(); consider migrating.
void TestRequestInterceptor::PostToIOAndWait(const base::Closure& task) {
io_task_runner_->PostTask(FROM_HERE, task);
base::RunLoop run_loop;
// Sequenced after |task|; bounces the quit closure back to this thread.
// IgnoreResult() discards PostTask()'s bool return value.
io_task_runner_->PostTask(
base::IgnoreResult(&base::MessageLoopProxy::PostTask),
base::MessageLoopProxy::current(),
run_loop.QuitClosure()));
295 } // namespace policy