1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/ssl/default_channel_id_store.h"
8 #include "base/message_loop/message_loop.h"
9 #include "base/metrics/histogram.h"
10 #include "net/base/net_errors.h"
14 // --------------------------------------------------------------------------
// Base class for a unit of work against a DefaultChannelIDStore. Tasks may be
// queued with EnqueueTask() (held in |waiting_tasks_|), which OnLoaded() walks
// and then clears.
16 class DefaultChannelIDStore::Task {
20 // Runs the task and invokes the client callback on the thread that
21 // originally constructed the task.
22 virtual void Run(DefaultChannelIDStore* store) = 0;
// Helper for subclasses to deliver their completion callback. A null
// |callback| is ignored.
25 void InvokeCallback(base::Closure callback) const;
// Base destructor. Subclass destructors are declared `override`, so this
// destructor is virtual, and tasks owned through Task* are destroyed correctly.
28 DefaultChannelIDStore::Task::~Task() {
// Delivers |callback| on behalf of a subclass. Null callbacks are skipped, so
// a task whose caller supplied no completion callback needs no special case.
31 void DefaultChannelIDStore::Task::InvokeCallback(
32 base::Closure callback) const {
33 if (!callback.is_null())
37 // --------------------------------------------------------------------------
// Deferred form of GetChannelID(). When run, it looks up
// |server_identifier_| and reports the result through |callback_|.
39 class DefaultChannelIDStore::GetChannelIDTask
40 : public DefaultChannelIDStore::Task {
42 GetChannelIDTask(const std::string& server_identifier,
43 const GetChannelIDCallback& callback);
44 ~GetChannelIDTask() override;
45 void Run(DefaultChannelIDStore* store) override;
// Identifier to look up when the task runs.
48 std::string server_identifier_;
// Invoked with (error, identifier, expiration time, private key, cert).
49 GetChannelIDCallback callback_;
// Captures the lookup key and the completion callback for use by Run().
52 DefaultChannelIDStore::GetChannelIDTask::GetChannelIDTask(
53 const std::string& server_identifier,
54 const GetChannelIDCallback& callback)
55 : server_identifier_(server_identifier),
// Out-of-line destructor (declared `override` in the class).
59 DefaultChannelIDStore::GetChannelIDTask::~GetChannelIDTask() {
// Performs the deferred lookup synchronously through GetChannelID(), passing
// an empty GetChannelIDCallback(), and forwards the outcome to |callback_|.
62 void DefaultChannelIDStore::GetChannelIDTask::Run(
63 DefaultChannelIDStore* store) {
64 base::Time expiration_time;
65 std::string private_key_result;
66 std::string cert_result;
67 int err = store->GetChannelID(
68 server_identifier_, &expiration_time, &private_key_result,
69 &cert_result, GetChannelIDCallback());
// GetChannelID() returns ERR_IO_PENDING when it defers the request by
// enqueueing a new task. That must not happen for a task that is already
// running.
70 DCHECK(err != ERR_IO_PENDING);
// base::Bind stores copies of the locals, so the callback does not depend on
// this stack frame.
72 InvokeCallback(base::Bind(callback_, err, server_identifier_,
73 expiration_time, private_key_result, cert_result));
76 // --------------------------------------------------------------------------
// Deferred form of SetChannelID(). When run, it stores the channel ID via
// SyncSetChannelID(). Unlike the other tasks, it takes no completion callback.
78 class DefaultChannelIDStore::SetChannelIDTask
79 : public DefaultChannelIDStore::Task {
81 SetChannelIDTask(const std::string& server_identifier,
82 base::Time creation_time,
83 base::Time expiration_time,
84 const std::string& private_key,
85 const std::string& cert);
86 ~SetChannelIDTask() override;
87 void Run(DefaultChannelIDStore* store) override;
// Values that Run() forwards to SyncSetChannelID().
90 std::string server_identifier_;
91 base::Time creation_time_;
92 base::Time expiration_time_;
93 std::string private_key_;
// Copies every field of the channel ID so Run() can apply it later.
97 DefaultChannelIDStore::SetChannelIDTask::SetChannelIDTask(
98 const std::string& server_identifier,
99 base::Time creation_time,
100 base::Time expiration_time,
101 const std::string& private_key,
102 const std::string& cert)
103 : server_identifier_(server_identifier),
104 creation_time_(creation_time),
105 expiration_time_(expiration_time),
106 private_key_(private_key),
// Out-of-line destructor (declared `override` in the class).
110 DefaultChannelIDStore::SetChannelIDTask::~SetChannelIDTask() {
// Applies the captured values through SyncSetChannelID(), which first deletes
// any existing entry for |server_identifier_| and then inserts the new one.
113 void DefaultChannelIDStore::SetChannelIDTask::Run(
114 DefaultChannelIDStore* store) {
115 store->SyncSetChannelID(server_identifier_, creation_time_,
116 expiration_time_, private_key_, cert_);
119 // --------------------------------------------------------------------------
120 // DeleteChannelIDTask
// Deferred form of DeleteChannelID(). When run, it removes the entry for
// |server_identifier_| and then invokes |callback_|.
121 class DefaultChannelIDStore::DeleteChannelIDTask
122 : public DefaultChannelIDStore::Task {
124 DeleteChannelIDTask(const std::string& server_identifier,
125 const base::Closure& callback);
126 ~DeleteChannelIDTask() override;
127 void Run(DefaultChannelIDStore* store) override;
// Identifier whose channel ID should be removed.
130 std::string server_identifier_;
// Completion notification. May be null, since InvokeCallback() skips nulls.
131 base::Closure callback_;
// Captures the identifier to delete and the completion callback.
134 DefaultChannelIDStore::DeleteChannelIDTask::
136 const std::string& server_identifier,
137 const base::Closure& callback)
138 : server_identifier_(server_identifier),
139 callback_(callback) {
// Out-of-line destructor (declared `override` in the class).
142 DefaultChannelIDStore::DeleteChannelIDTask::
143 ~DeleteChannelIDTask() {
// Deletes the entry synchronously, then signals completion.
146 void DefaultChannelIDStore::DeleteChannelIDTask::Run(
147 DefaultChannelIDStore* store) {
148 store->SyncDeleteChannelID(server_identifier_);
150 InvokeCallback(callback_);
153 // --------------------------------------------------------------------------
154 // DeleteAllCreatedBetweenTask
// Deferred form of DeleteAllCreatedBetween(). When run, it deletes every
// channel ID whose creation time falls in [|delete_begin_|, |delete_end_|)
// and then invokes |callback_|. A null bound means "unbounded".
155 class DefaultChannelIDStore::DeleteAllCreatedBetweenTask
156 : public DefaultChannelIDStore::Task {
158 DeleteAllCreatedBetweenTask(base::Time delete_begin,
159 base::Time delete_end,
160 const base::Closure& callback);
161 ~DeleteAllCreatedBetweenTask() override;
162 void Run(DefaultChannelIDStore* store) override;
// Inclusive lower bound; a null value means no lower bound.
165 base::Time delete_begin_;
// Exclusive upper bound; a null value means no upper bound.
166 base::Time delete_end_;
// Completion notification. May be null.
167 base::Closure callback_;
// Captures the time range and the completion callback.
170 DefaultChannelIDStore::DeleteAllCreatedBetweenTask::
171 DeleteAllCreatedBetweenTask(
172 base::Time delete_begin,
173 base::Time delete_end,
174 const base::Closure& callback)
175 : delete_begin_(delete_begin),
176 delete_end_(delete_end),
177 callback_(callback) {
// Out-of-line destructor (declared `override` in the class).
180 DefaultChannelIDStore::DeleteAllCreatedBetweenTask::
181 ~DeleteAllCreatedBetweenTask() {
// Performs the range deletion synchronously, then signals completion.
184 void DefaultChannelIDStore::DeleteAllCreatedBetweenTask::Run(
185 DefaultChannelIDStore* store) {
186 store->SyncDeleteAllCreatedBetween(delete_begin_, delete_end_);
188 InvokeCallback(callback_);
191 // --------------------------------------------------------------------------
192 // GetAllChannelIDsTask
// Deferred form of GetAllChannelIDs(). When run, it collects a snapshot of
// every stored channel ID and passes it to |callback_|.
193 class DefaultChannelIDStore::GetAllChannelIDsTask
194 : public DefaultChannelIDStore::Task {
196 explicit GetAllChannelIDsTask(const GetChannelIDListCallback& callback);
197 ~GetAllChannelIDsTask() override;
198 void Run(DefaultChannelIDStore* store) override;
// Receives the snapshot. This is the task's only state: it operates on the
// whole store, so it does not carry a server identifier.
202 GetChannelIDListCallback callback_;
// Captures the callback that will receive the channel ID list.
205 DefaultChannelIDStore::GetAllChannelIDsTask::
206 GetAllChannelIDsTask(const GetChannelIDListCallback& callback)
207 : callback_(callback) {
// Out-of-line destructor (declared `override` in the class).
210 DefaultChannelIDStore::GetAllChannelIDsTask::
211 ~GetAllChannelIDsTask() {
// Builds a list of copies of all channel IDs via SyncGetAllChannelIDs() and
// binds that list into |callback_|.
214 void DefaultChannelIDStore::GetAllChannelIDsTask::Run(
215 DefaultChannelIDStore* store) {
216 ChannelIDList cert_list;
217 store->SyncGetAllChannelIDs(&cert_list);
219 InvokeCallback(base::Bind(callback_, cert_list));
222 // --------------------------------------------------------------------------
223 // DefaultChannelIDStore
// Creates a store that persists through |store|. |initialized_| starts out
// false. |weak_ptr_factory_| provides the weak pointer used for the Load()
// callback in InitStore().
225 DefaultChannelIDStore::DefaultChannelIDStore(
226 PersistentStore* store)
227 : initialized_(false),
230 weak_ptr_factory_(this) {}
// Looks up the channel ID for |server_identifier|.
// If the lookup is deferred, this enqueues a GetChannelIDTask carrying
// |callback| and returns ERR_IO_PENDING. |callback| later receives the result
// when that task runs. Otherwise the lookup is answered synchronously:
// ERR_FILE_NOT_FOUND if there is no entry, or else the expiration time,
// private key and certificate are copied into the out-parameters.
232 int DefaultChannelIDStore::GetChannelID(
233 const std::string& server_identifier,
234 base::Time* expiration_time,
235 std::string* private_key_result,
236 std::string* cert_result,
237 const GetChannelIDCallback& callback) {
238 DCHECK(CalledOnValidThread());
242 EnqueueTask(scoped_ptr<Task>(
243 new GetChannelIDTask(server_identifier, callback)));
244 return ERR_IO_PENDING;
247 ChannelIDMap::iterator it = channel_ids_.find(server_identifier);
249 if (it == channel_ids_.end())
250 return ERR_FILE_NOT_FOUND;
// Copy the fields out so the caller holds no references into the map.
252 ChannelID* channel_id = it->second;
253 *expiration_time = channel_id->expiration_time();
254 *private_key_result = channel_id->private_key();
255 *cert_result = channel_id->cert();
// Stores a channel ID for |server_identifier|, replacing any existing entry
// (see SyncSetChannelID()). The work is wrapped in a SetChannelIDTask and
// handed to RunOrEnqueueTask(). No completion callback is offered.
260 void DefaultChannelIDStore::SetChannelID(
261 const std::string& server_identifier,
262 base::Time creation_time,
263 base::Time expiration_time,
264 const std::string& private_key,
265 const std::string& cert) {
266 RunOrEnqueueTask(scoped_ptr<Task>(new SetChannelIDTask(
267 server_identifier, creation_time, expiration_time, private_key,
// Removes the channel ID for |server_identifier|, if any, and then invokes
// |callback|. A null callback is allowed. Dispatched via RunOrEnqueueTask().
271 void DefaultChannelIDStore::DeleteChannelID(
272 const std::string& server_identifier,
273 const base::Closure& callback) {
274 RunOrEnqueueTask(scoped_ptr<Task>(
275 new DeleteChannelIDTask(server_identifier, callback)));
// Deletes every channel ID created in [|delete_begin|, |delete_end|), then
// invokes |callback|. A null |delete_begin| or |delete_end| leaves that side
// of the range unbounded. Dispatched via RunOrEnqueueTask().
278 void DefaultChannelIDStore::DeleteAllCreatedBetween(
279 base::Time delete_begin,
280 base::Time delete_end,
281 const base::Closure& callback) {
282 RunOrEnqueueTask(scoped_ptr<Task>(
283 new DeleteAllCreatedBetweenTask(delete_begin, delete_end, callback)));
// Deletes every channel ID, then invokes |callback|.
286 void DefaultChannelIDStore::DeleteAll(
287 const base::Closure& callback) {
// Null times make both ends of the range unbounded, so every entry matches.
288 DeleteAllCreatedBetween(base::Time(), base::Time(), callback);
// Delivers a snapshot list of all channel IDs to |callback|. Dispatched via
// RunOrEnqueueTask().
291 void DefaultChannelIDStore::GetAllChannelIDs(
292 const GetChannelIDListCallback& callback) {
293 RunOrEnqueueTask(scoped_ptr<Task>(new GetAllChannelIDsTask(callback)));
// Returns the number of entries currently held in |channel_ids_|.
296 int DefaultChannelIDStore::GetChannelIDCount() {
297 DCHECK(CalledOnValidThread());
// NOTE(review): size() is an unsigned size type and is narrowed implicitly
// to the int return type; an explicit checked cast would make this visible.
299 return channel_ids_.size();
// Forwards SetForceKeepSessionState() to the backing PersistentStore.
302 void DefaultChannelIDStore::SetForceKeepSessionState() {
303 DCHECK(CalledOnValidThread());
307 store_->SetForceKeepSessionState();
// Out-of-line destructor.
310 DefaultChannelIDStore::~DefaultChannelIDStore() {
// Walks every entry of the in-memory map |channel_ids_| and then clears it.
314 void DefaultChannelIDStore::DeleteAllInMemory() {
315 DCHECK(CalledOnValidThread());
317 for (ChannelIDMap::iterator it = channel_ids_.begin();
318 it != channel_ids_.end(); ++it) {
321 channel_ids_.clear();
// Asks the backing store to Load() its channel IDs and report them to
// OnLoaded(). A backing store must exist (DCHECKed).
324 void DefaultChannelIDStore::InitStore() {
325 DCHECK(CalledOnValidThread());
326 DCHECK(store_.get()) << "Store must exist to initialize";
// OnLoaded() is bound through a WeakPtr. Per base::Bind semantics, the
// callback is not run if this object has been destroyed first.
329 store_->Load(base::Bind(&DefaultChannelIDStore::OnLoaded,
330 weak_ptr_factory_.GetWeakPtr()));
// Completion callback for the PersistentStore::Load() issued in InitStore().
// It moves the loaded ChannelIDs into |channel_ids_|, records UMA about the
// task queue, and then walks and clears |waiting_tasks_|.
333 void DefaultChannelIDStore::OnLoaded(
334 scoped_ptr<ScopedVector<ChannelID> > channel_ids) {
335 DCHECK(CalledOnValidThread());
// Each server identifier is expected to appear at most once in the loaded data.
337 for (std::vector<ChannelID*>::const_iterator it = channel_ids->begin();
338 it != channel_ids->end(); ++it) {
339 DCHECK(channel_ids_.find((*it)->server_identifier()) ==
341 channel_ids_[(*it)->server_identifier()] = *it;
// The raw pointers now live in |channel_ids_|. weak_clear() empties the
// ScopedVector without deleting them.
343 channel_ids->weak_clear();
// |wait_time| stays zero when no task was queued. Otherwise it measures time
// since the first task was enqueued (see EnqueueTask()).
347 base::TimeDelta wait_time;
348 if (!waiting_tasks_.empty())
349 wait_time = base::TimeTicks::Now() - waiting_tasks_start_time_;
350 DVLOG(1) << "Task delay " << wait_time.InMilliseconds();
351 UMA_HISTOGRAM_CUSTOM_TIMES("DomainBoundCerts.TaskMaxWaitTime",
353 base::TimeDelta::FromMilliseconds(1),
354 base::TimeDelta::FromMinutes(1),
356 UMA_HISTOGRAM_COUNTS_100("DomainBoundCerts.TaskWaitCount",
357 waiting_tasks_.size());
// Drain the queue of tasks that arrived before the load completed.
360 for (ScopedVector<Task>::iterator i = waiting_tasks_.begin();
361 i != waiting_tasks_.end(); ++i)
363 waiting_tasks_.clear();
// Synchronously replaces the entry for |server_identifier|. Any existing
// entry is deleted first, and then a ChannelID built from the arguments is
// inserted into both the map and the backing store.
366 void DefaultChannelIDStore::SyncSetChannelID(
367 const std::string& server_identifier,
368 base::Time creation_time,
369 base::Time expiration_time,
370 const std::string& private_key,
371 const std::string& cert) {
372 DCHECK(CalledOnValidThread());
// Delete first: InternalInsertChannelID() assigns through operator[] and
// would otherwise overwrite the old map entry without removing it from the
// backing store.
375 InternalDeleteChannelID(server_identifier);
376 InternalInsertChannelID(
379 server_identifier, creation_time, expiration_time, private_key,
// Synchronously deletes the entry for |server_identifier|. Does nothing if
// there is none.
383 void DefaultChannelIDStore::SyncDeleteChannelID(
384 const std::string& server_identifier) {
385 DCHECK(CalledOnValidThread());
387 InternalDeleteChannelID(server_identifier);
// Synchronously deletes every channel ID whose creation time lies in
// [|delete_begin|, |delete_end|). A null bound is treated as unbounded.
// Matching entries are removed from the backing store and from
// |channel_ids_|.
390 void DefaultChannelIDStore::SyncDeleteAllCreatedBetween(
391 base::Time delete_begin,
392 base::Time delete_end) {
393 DCHECK(CalledOnValidThread());
// |cur| is the element under test. The erase below targets |cur|, not the
// loop iterator |it|.
395 for (ChannelIDMap::iterator it = channel_ids_.begin();
396 it != channel_ids_.end();) {
397 ChannelIDMap::iterator cur = it;
399 ChannelID* channel_id = cur->second;
400 if ((delete_begin.is_null() ||
401 channel_id->creation_time() >= delete_begin) &&
402 (delete_end.is_null() || channel_id->creation_time() < delete_end)) {
404 store_->DeleteChannelID(*channel_id);
406 channel_ids_.erase(cur);
// Appends a copy of every stored ChannelID to |channel_id_list|. Existing
// list contents are kept.
411 void DefaultChannelIDStore::SyncGetAllChannelIDs(
412 ChannelIDList* channel_id_list) {
413 DCHECK(CalledOnValidThread());
415 for (ChannelIDMap::iterator it = channel_ids_.begin();
416 it != channel_ids_.end(); ++it)
417 channel_id_list->push_back(*it->second);
// Appends |task| to |waiting_tasks_|. Ownership moves from the scoped_ptr
// into the ScopedVector.
420 void DefaultChannelIDStore::EnqueueTask(scoped_ptr<Task> task) {
421 DCHECK(CalledOnValidThread());
// The start time is recorded only for the first queued task, so OnLoaded()
// measures how long the oldest task waited.
423 if (waiting_tasks_.empty())
424 waiting_tasks_start_time_ = base::TimeTicks::Now();
425 waiting_tasks_.push_back(task.release());
// Dispatches |task|. The deferral path hands it to EnqueueTask().
// NOTE(review): presumably the task is run immediately once the store is
// ready; confirm against the full implementation.
428 void DefaultChannelIDStore::RunOrEnqueueTask(scoped_ptr<Task> task) {
429 DCHECK(CalledOnValidThread());
433 EnqueueTask(task.Pass());
// Removes the entry for |server_identifier| from the backing store and from
// |channel_ids_|. Returns without doing anything if there is no entry.
440 void DefaultChannelIDStore::InternalDeleteChannelID(
441 const std::string& server_identifier) {
442 DCHECK(CalledOnValidThread());
445 ChannelIDMap::iterator it = channel_ids_.find(server_identifier);
446 if (it == channel_ids_.end())
447 return; // There is nothing to delete.
449 ChannelID* channel_id = it->second;
451 store_->DeleteChannelID(*channel_id);
452 channel_ids_.erase(it);
// Adds |channel_id| to the backing store and records the pointer in
// |channel_ids_| under |server_identifier|.
456 void DefaultChannelIDStore::InternalInsertChannelID(
457 const std::string& server_identifier,
458 ChannelID* channel_id) {
459 DCHECK(CalledOnValidThread());
463 store_->AddChannelID(*channel_id);
// operator[] overwrites any existing map entry, so callers should delete the
// old entry first (as SyncSetChannelID() does).
464 channel_ids_[server_identifier] = channel_id;
// Out-of-line constructor and destructor for PersistentStore.
467 DefaultChannelIDStore::PersistentStore::PersistentStore() {}
469 DefaultChannelIDStore::PersistentStore::~PersistentStore() {}