Upstream version 5.34.104.0
[platform/framework/web/crosswalk.git] / src / chrome / utility / local_discovery / service_discovery_client_impl.cc
1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include <utility>
6
7 #include "base/logging.h"
8 #include "base/memory/singleton.h"
9 #include "base/message_loop/message_loop_proxy.h"
10 #include "base/stl_util.h"
11 #include "chrome/utility/local_discovery/service_discovery_client_impl.h"
12 #include "net/dns/dns_protocol.h"
13 #include "net/dns/record_rdata.h"
14
15 namespace local_discovery {
16
17 namespace {
18 // TODO(noamsml): Make this configurable through the LocalDomainResolver
19 // interface.
20 const int kLocalDomainSecondAddressTimeoutMs = 100;
21
22 const int kInitialRequeryTimeSeconds = 1;
23 const int kMaxRequeryTimeSeconds = 2; // Time for last requery
24 }
25
26 ServiceDiscoveryClientImpl::ServiceDiscoveryClientImpl(
27     net::MDnsClient* mdns_client) : mdns_client_(mdns_client) {
28 }
29
30 ServiceDiscoveryClientImpl::~ServiceDiscoveryClientImpl() {
31 }
32
33 scoped_ptr<ServiceWatcher> ServiceDiscoveryClientImpl::CreateServiceWatcher(
34     const std::string& service_type,
35     const ServiceWatcher::UpdatedCallback& callback) {
36   return scoped_ptr<ServiceWatcher>(new ServiceWatcherImpl(
37       service_type, callback, mdns_client_));
38 }
39
40 scoped_ptr<ServiceResolver> ServiceDiscoveryClientImpl::CreateServiceResolver(
41     const std::string& service_name,
42     const ServiceResolver::ResolveCompleteCallback& callback) {
43   return scoped_ptr<ServiceResolver>(new ServiceResolverImpl(
44       service_name, callback, mdns_client_));
45 }
46
47 scoped_ptr<LocalDomainResolver>
48 ServiceDiscoveryClientImpl::CreateLocalDomainResolver(
49       const std::string& domain,
50       net::AddressFamily address_family,
51       const LocalDomainResolver::IPAddressCallback& callback) {
52   return scoped_ptr<LocalDomainResolver>(new LocalDomainResolverImpl(
53       domain, address_family, callback, mdns_client_));
54 }
55
56 ServiceWatcherImpl::ServiceWatcherImpl(
57     const std::string& service_type,
58     const ServiceWatcher::UpdatedCallback& callback,
59     net::MDnsClient* mdns_client)
60     : service_type_(service_type), callback_(callback), started_(false),
61       actively_refresh_services_(false), mdns_client_(mdns_client) {
62 }
63
64 void ServiceWatcherImpl::Start() {
65   DCHECK(!started_);
66   listener_ = mdns_client_->CreateListener(
67       net::dns_protocol::kTypePTR, service_type_, this);
68   started_ = listener_->Start();
69   if (started_)
70     ReadCachedServices();
71 }
72
73 ServiceWatcherImpl::~ServiceWatcherImpl() {
74 }
75
76 void ServiceWatcherImpl::DiscoverNewServices(bool force_update) {
77   DCHECK(started_);
78   if (force_update)
79     services_.clear();
80   SendQuery(kInitialRequeryTimeSeconds, force_update);
81 }
82
83 void ServiceWatcherImpl::SetActivelyRefreshServices(
84     bool actively_refresh_services) {
85   DCHECK(started_);
86   actively_refresh_services_ = actively_refresh_services;
87
88   for (ServiceListenersMap::iterator i = services_.begin();
89        i != services_.end(); i++) {
90     i->second->SetActiveRefresh(actively_refresh_services);
91   }
92 }
93
94 void ServiceWatcherImpl::ReadCachedServices() {
95   DCHECK(started_);
96   CreateTransaction(false /*network*/, true /*cache*/, false /*force refresh*/,
97                     &transaction_cache_);
98 }
99
100 bool ServiceWatcherImpl::CreateTransaction(
101     bool network, bool cache, bool force_refresh,
102     scoped_ptr<net::MDnsTransaction>* transaction) {
103   int transaction_flags = 0;
104   if (network)
105     transaction_flags |= net::MDnsTransaction::QUERY_NETWORK;
106
107   if (cache)
108     transaction_flags |= net::MDnsTransaction::QUERY_CACHE;
109
110   // TODO(noamsml): Add flag for force_refresh when supported.
111
112   if (transaction_flags) {
113     *transaction = mdns_client_->CreateTransaction(
114         net::dns_protocol::kTypePTR, service_type_, transaction_flags,
115         base::Bind(&ServiceWatcherImpl::OnTransactionResponse,
116                    base::Unretained(this), transaction));
117     return (*transaction)->Start();
118   }
119
120   return true;
121 }
122
123 std::string ServiceWatcherImpl::GetServiceType() const {
124   return listener_->GetName();
125 }
126
127 void ServiceWatcherImpl::OnRecordUpdate(
128     net::MDnsListener::UpdateType update,
129     const net::RecordParsed* record) {
130   DCHECK(started_);
131   if (record->type() == net::dns_protocol::kTypePTR) {
132     DCHECK(record->name() == GetServiceType());
133     const net::PtrRecordRdata* rdata = record->rdata<net::PtrRecordRdata>();
134
135     switch (update) {
136       case net::MDnsListener::RECORD_ADDED:
137         AddService(rdata->ptrdomain());
138         break;
139       case net::MDnsListener::RECORD_CHANGED:
140         NOTREACHED();
141         break;
142       case net::MDnsListener::RECORD_REMOVED:
143         RemovePTR(rdata->ptrdomain());
144         break;
145     }
146   } else {
147     DCHECK(record->type() == net::dns_protocol::kTypeSRV ||
148            record->type() == net::dns_protocol::kTypeTXT);
149     DCHECK(services_.find(record->name()) != services_.end());
150
151     if (record->type() == net::dns_protocol::kTypeSRV) {
152       if (update == net::MDnsListener::RECORD_REMOVED) {
153         RemoveSRV(record->name());
154       } else if (update == net::MDnsListener::RECORD_ADDED) {
155         AddSRV(record->name());
156       }
157     }
158
159     // If this is the first time we see an SRV record, do not send
160     // an UPDATE_CHANGED.
161     if (record->type() != net::dns_protocol::kTypeSRV ||
162         update != net::MDnsListener::RECORD_ADDED) {
163       DeferUpdate(UPDATE_CHANGED, record->name());
164     }
165   }
166 }
167
168 void ServiceWatcherImpl::OnCachePurged() {
169   // Not yet implemented.
170 }
171
172 void ServiceWatcherImpl::OnTransactionResponse(
173     scoped_ptr<net::MDnsTransaction>* transaction,
174     net::MDnsTransaction::Result result,
175     const net::RecordParsed* record) {
176   DCHECK(started_);
177   if (result == net::MDnsTransaction::RESULT_RECORD) {
178     const net::PtrRecordRdata* rdata = record->rdata<net::PtrRecordRdata>();
179     DCHECK(rdata);
180     AddService(rdata->ptrdomain());
181   } else if (result == net::MDnsTransaction::RESULT_DONE) {
182     transaction->reset();
183   }
184
185   // Do nothing for NSEC records. It is an error for hosts to broadcast an NSEC
186   // record for PTR records on any name.
187 }
188
189 ServiceWatcherImpl::ServiceListeners::ServiceListeners(
190     const std::string& service_name,
191     ServiceWatcherImpl* watcher,
192     net::MDnsClient* mdns_client)
193     : service_name_(service_name), mdns_client_(mdns_client),
194       update_pending_(false), has_ptr_(true), has_srv_(false) {
195   srv_listener_ = mdns_client->CreateListener(
196       net::dns_protocol::kTypeSRV, service_name, watcher);
197   txt_listener_ = mdns_client->CreateListener(
198       net::dns_protocol::kTypeTXT, service_name, watcher);
199 }
200
201 ServiceWatcherImpl::ServiceListeners::~ServiceListeners() {
202 }
203
204 bool ServiceWatcherImpl::ServiceListeners::Start() {
205   if (!srv_listener_->Start())
206     return false;
207   return txt_listener_->Start();
208 }
209
210 void ServiceWatcherImpl::ServiceListeners::SetActiveRefresh(
211     bool active_refresh) {
212   srv_listener_->SetActiveRefresh(active_refresh);
213
214   if (active_refresh && !has_srv_) {
215     DCHECK(has_ptr_);
216     srv_transaction_ = mdns_client_->CreateTransaction(
217         net::dns_protocol::kTypeSRV, service_name_,
218         net::MDnsTransaction::SINGLE_RESULT |
219         net::MDnsTransaction::QUERY_CACHE | net::MDnsTransaction::QUERY_NETWORK,
220         base::Bind(&ServiceWatcherImpl::ServiceListeners::OnSRVRecord,
221                    base::Unretained(this)));
222     srv_transaction_->Start();
223   } else if (!active_refresh) {
224     srv_transaction_.reset();
225   }
226 }
227
228 void ServiceWatcherImpl::ServiceListeners::OnSRVRecord(
229     net::MDnsTransaction::Result result,
230     const net::RecordParsed* record) {
231   set_has_srv(record != NULL);
232 }
233
234 void ServiceWatcherImpl::ServiceListeners::set_has_srv(bool has_srv) {
235   has_srv_ = has_srv;
236
237   srv_transaction_.reset();
238 }
239
240 void ServiceWatcherImpl::AddService(const std::string& service) {
241   DCHECK(started_);
242   std::pair<ServiceListenersMap::iterator, bool> found = services_.insert(
243       make_pair(service, linked_ptr<ServiceListeners>(NULL)));
244
245   if (found.second) {  // Newly inserted.
246     found.first->second = linked_ptr<ServiceListeners>(
247         new ServiceListeners(service, this, mdns_client_));
248     bool success = found.first->second->Start();
249     found.first->second->SetActiveRefresh(actively_refresh_services_);
250     DeferUpdate(UPDATE_ADDED, service);
251
252     DCHECK(success);
253   }
254
255   found.first->second->set_has_ptr(true);
256 }
257
258 void ServiceWatcherImpl::AddSRV(const std::string& service) {
259   DCHECK(started_);
260
261   ServiceListenersMap::iterator found = services_.find(service);
262   if (found != services_.end()) {
263     found->second->set_has_srv(true);
264   }
265 }
266
267 void ServiceWatcherImpl::DeferUpdate(ServiceWatcher::UpdateType update_type,
268                                      const std::string& service_name) {
269   ServiceListenersMap::iterator found = services_.find(service_name);
270
271   if (found != services_.end() && !found->second->update_pending()) {
272     found->second->set_update_pending(true);
273     base::MessageLoop::current()->PostTask(
274         FROM_HERE,
275         base::Bind(&ServiceWatcherImpl::DeliverDeferredUpdate, AsWeakPtr(),
276                    update_type, service_name));
277   }
278 }
279
280 void ServiceWatcherImpl::DeliverDeferredUpdate(
281     ServiceWatcher::UpdateType update_type, const std::string& service_name) {
282   ServiceListenersMap::iterator found = services_.find(service_name);
283
284   if (found != services_.end()) {
285     found->second->set_update_pending(false);
286     if (!callback_.is_null())
287       callback_.Run(update_type, service_name);
288   }
289 }
290
291 void ServiceWatcherImpl::RemovePTR(const std::string& service) {
292   DCHECK(started_);
293
294   ServiceListenersMap::iterator found = services_.find(service);
295   if (found != services_.end()) {
296     found->second->set_has_ptr(false);
297
298     if (!found->second->has_ptr_or_srv()) {
299       services_.erase(found);
300       if (!callback_.is_null())
301         callback_.Run(UPDATE_REMOVED, service);
302     }
303   }
304 }
305
306 void ServiceWatcherImpl::RemoveSRV(const std::string& service) {
307   DCHECK(started_);
308
309   ServiceListenersMap::iterator found = services_.find(service);
310   if (found != services_.end()) {
311     found->second->set_has_srv(false);
312
313     if (!found->second->has_ptr_or_srv()) {
314       services_.erase(found);
315       if (!callback_.is_null())
316         callback_.Run(UPDATE_REMOVED, service);
317     }
318   }
319 }
320
321 void ServiceWatcherImpl::OnNsecRecord(const std::string& name,
322                                       unsigned rrtype) {
323   // Do nothing. It is an error for hosts to broadcast an NSEC record for PTR
324   // on any name.
325 }
326
327 void ServiceWatcherImpl::ScheduleQuery(int timeout_seconds) {
328   if (timeout_seconds <= kMaxRequeryTimeSeconds) {
329     base::MessageLoop::current()->PostDelayedTask(
330         FROM_HERE,
331         base::Bind(&ServiceWatcherImpl::SendQuery,
332                    AsWeakPtr(),
333                    timeout_seconds * 2 /*next_timeout_seconds*/,
334                    false /*force_update*/),
335         base::TimeDelta::FromSeconds(timeout_seconds));
336   }
337 }
338
339 void ServiceWatcherImpl::SendQuery(int next_timeout_seconds,
340                                    bool force_update) {
341   CreateTransaction(true /*network*/, false /*cache*/, force_update,
342                     &transaction_network_);
343   ScheduleQuery(next_timeout_seconds);
344 }
345
346 ServiceResolverImpl::ServiceResolverImpl(
347     const std::string& service_name,
348     const ResolveCompleteCallback& callback,
349     net::MDnsClient* mdns_client)
350     : service_name_(service_name), callback_(callback),
351       metadata_resolved_(false), address_resolved_(false),
352       mdns_client_(mdns_client) {
353 }
354
355 void ServiceResolverImpl::StartResolving() {
356   address_resolved_ = false;
357   metadata_resolved_ = false;
358   service_staging_ = ServiceDescription();
359   service_staging_.service_name = service_name_;
360
361   if (!CreateTxtTransaction() || !CreateSrvTransaction()) {
362     ServiceNotFound(ServiceResolver::STATUS_REQUEST_TIMEOUT);
363   }
364 }
365
366 ServiceResolverImpl::~ServiceResolverImpl() {
367 }
368
369 bool ServiceResolverImpl::CreateTxtTransaction() {
370   txt_transaction_ = mdns_client_->CreateTransaction(
371       net::dns_protocol::kTypeTXT, service_name_,
372       net::MDnsTransaction::SINGLE_RESULT | net::MDnsTransaction::QUERY_CACHE |
373       net::MDnsTransaction::QUERY_NETWORK,
374       base::Bind(&ServiceResolverImpl::TxtRecordTransactionResponse,
375                  AsWeakPtr()));
376   return txt_transaction_->Start();
377 }
378
379 // TODO(noamsml): quick-resolve for AAAA records.  Since A records tend to be in
380 void ServiceResolverImpl::CreateATransaction() {
381   a_transaction_ = mdns_client_->CreateTransaction(
382       net::dns_protocol::kTypeA,
383       service_staging_.address.host(),
384       net::MDnsTransaction::SINGLE_RESULT | net::MDnsTransaction::QUERY_CACHE,
385       base::Bind(&ServiceResolverImpl::ARecordTransactionResponse,
386                  AsWeakPtr()));
387   a_transaction_->Start();
388 }
389
390 bool ServiceResolverImpl::CreateSrvTransaction() {
391   srv_transaction_ = mdns_client_->CreateTransaction(
392       net::dns_protocol::kTypeSRV, service_name_,
393       net::MDnsTransaction::SINGLE_RESULT | net::MDnsTransaction::QUERY_CACHE |
394       net::MDnsTransaction::QUERY_NETWORK,
395       base::Bind(&ServiceResolverImpl::SrvRecordTransactionResponse,
396                  AsWeakPtr()));
397   return srv_transaction_->Start();
398 }
399
400 std::string ServiceResolverImpl::GetName() const {
401   return service_name_;
402 }
403
404 void ServiceResolverImpl::SrvRecordTransactionResponse(
405     net::MDnsTransaction::Result status, const net::RecordParsed* record) {
406   srv_transaction_.reset();
407   if (status == net::MDnsTransaction::RESULT_RECORD) {
408     DCHECK(record);
409     service_staging_.address = RecordToAddress(record);
410     service_staging_.last_seen = record->time_created();
411     CreateATransaction();
412   } else {
413     ServiceNotFound(MDnsStatusToRequestStatus(status));
414   }
415 }
416
417 void ServiceResolverImpl::TxtRecordTransactionResponse(
418     net::MDnsTransaction::Result status, const net::RecordParsed* record) {
419   txt_transaction_.reset();
420   if (status == net::MDnsTransaction::RESULT_RECORD) {
421     DCHECK(record);
422     service_staging_.metadata = RecordToMetadata(record);
423   } else {
424     service_staging_.metadata = std::vector<std::string>();
425   }
426
427   metadata_resolved_ = true;
428   AlertCallbackIfReady();
429 }
430
431 void ServiceResolverImpl::ARecordTransactionResponse(
432     net::MDnsTransaction::Result status, const net::RecordParsed* record) {
433   a_transaction_.reset();
434
435   if (status == net::MDnsTransaction::RESULT_RECORD) {
436     DCHECK(record);
437     service_staging_.ip_address = RecordToIPAddress(record);
438   } else {
439     service_staging_.ip_address = net::IPAddressNumber();
440   }
441
442   address_resolved_ = true;
443   AlertCallbackIfReady();
444 }
445
446 void ServiceResolverImpl::AlertCallbackIfReady() {
447   if (metadata_resolved_ && address_resolved_) {
448     txt_transaction_.reset();
449     srv_transaction_.reset();
450     a_transaction_.reset();
451     if (!callback_.is_null())
452       callback_.Run(STATUS_SUCCESS, service_staging_);
453   }
454 }
455
456 void ServiceResolverImpl::ServiceNotFound(
457     ServiceResolver::RequestStatus status) {
458   txt_transaction_.reset();
459   srv_transaction_.reset();
460   a_transaction_.reset();
461   if (!callback_.is_null())
462     callback_.Run(status, ServiceDescription());
463 }
464
465 ServiceResolver::RequestStatus ServiceResolverImpl::MDnsStatusToRequestStatus(
466     net::MDnsTransaction::Result status) const {
467   switch (status) {
468     case net::MDnsTransaction::RESULT_RECORD:
469       return ServiceResolver::STATUS_SUCCESS;
470     case net::MDnsTransaction::RESULT_NO_RESULTS:
471       return ServiceResolver::STATUS_REQUEST_TIMEOUT;
472     case net::MDnsTransaction::RESULT_NSEC:
473       return ServiceResolver::STATUS_KNOWN_NONEXISTENT;
474     case net::MDnsTransaction::RESULT_DONE:  // Pass through.
475     default:
476       NOTREACHED();
477       return ServiceResolver::STATUS_REQUEST_TIMEOUT;
478   }
479 }
480
481 const std::vector<std::string>& ServiceResolverImpl::RecordToMetadata(
482     const net::RecordParsed* record) const {
483   DCHECK(record->type() == net::dns_protocol::kTypeTXT);
484   const net::TxtRecordRdata* txt_rdata = record->rdata<net::TxtRecordRdata>();
485   DCHECK(txt_rdata);
486   return txt_rdata->texts();
487 }
488
489 net::HostPortPair ServiceResolverImpl::RecordToAddress(
490     const net::RecordParsed* record) const {
491   DCHECK(record->type() == net::dns_protocol::kTypeSRV);
492   const net::SrvRecordRdata* srv_rdata = record->rdata<net::SrvRecordRdata>();
493   DCHECK(srv_rdata);
494   return net::HostPortPair(srv_rdata->target(), srv_rdata->port());
495 }
496
497 const net::IPAddressNumber& ServiceResolverImpl::RecordToIPAddress(
498     const net::RecordParsed* record) const {
499   DCHECK(record->type() == net::dns_protocol::kTypeA);
500   const net::ARecordRdata* a_rdata = record->rdata<net::ARecordRdata>();
501   DCHECK(a_rdata);
502   return a_rdata->address();
503 }
504
505 LocalDomainResolverImpl::LocalDomainResolverImpl(
506     const std::string& domain,
507     net::AddressFamily address_family,
508     const IPAddressCallback& callback,
509     net::MDnsClient* mdns_client)
510     : domain_(domain), address_family_(address_family), callback_(callback),
511       transactions_finished_(0), mdns_client_(mdns_client) {
512 }
513
514 LocalDomainResolverImpl::~LocalDomainResolverImpl() {
515   timeout_callback_.Cancel();
516 }
517
518 void LocalDomainResolverImpl::Start() {
519   if (address_family_ == net::ADDRESS_FAMILY_IPV4 ||
520       address_family_ == net::ADDRESS_FAMILY_UNSPECIFIED) {
521     transaction_a_ = CreateTransaction(net::dns_protocol::kTypeA);
522     transaction_a_->Start();
523   }
524
525   if (address_family_ == net::ADDRESS_FAMILY_IPV6 ||
526       address_family_ == net::ADDRESS_FAMILY_UNSPECIFIED) {
527     transaction_aaaa_ = CreateTransaction(net::dns_protocol::kTypeAAAA);
528     transaction_aaaa_->Start();
529   }
530 }
531
532 scoped_ptr<net::MDnsTransaction> LocalDomainResolverImpl::CreateTransaction(
533     uint16 type) {
534   return mdns_client_->CreateTransaction(
535       type, domain_, net::MDnsTransaction::SINGLE_RESULT |
536                      net::MDnsTransaction::QUERY_CACHE |
537                      net::MDnsTransaction::QUERY_NETWORK,
538       base::Bind(&LocalDomainResolverImpl::OnTransactionComplete,
539                  base::Unretained(this)));
540 }
541
542 void LocalDomainResolverImpl::OnTransactionComplete(
543     net::MDnsTransaction::Result result, const net::RecordParsed* record) {
544   transactions_finished_++;
545
546   if (result == net::MDnsTransaction::RESULT_RECORD) {
547     if (record->type() == net::dns_protocol::kTypeA) {
548       const net::ARecordRdata* rdata = record->rdata<net::ARecordRdata>();
549       address_ipv4_ = rdata->address();
550     } else {
551       DCHECK_EQ(net::dns_protocol::kTypeAAAA, record->type());
552       const net::AAAARecordRdata* rdata = record->rdata<net::AAAARecordRdata>();
553       address_ipv6_ = rdata->address();
554     }
555   }
556
557   if (transactions_finished_ == 1 &&
558       address_family_ == net::ADDRESS_FAMILY_UNSPECIFIED) {
559     timeout_callback_.Reset(base::Bind(
560         &LocalDomainResolverImpl::SendResolvedAddresses,
561         base::Unretained(this)));
562
563     base::MessageLoop::current()->PostDelayedTask(
564         FROM_HERE,
565         timeout_callback_.callback(),
566         base::TimeDelta::FromMilliseconds(kLocalDomainSecondAddressTimeoutMs));
567   } else if (transactions_finished_ == 2
568       || address_family_ != net::ADDRESS_FAMILY_UNSPECIFIED) {
569     SendResolvedAddresses();
570   }
571 }
572
573 bool LocalDomainResolverImpl::IsSuccess() {
574   return !address_ipv4_.empty() || !address_ipv6_.empty();
575 }
576
577 void LocalDomainResolverImpl::SendResolvedAddresses() {
578   transaction_a_.reset();
579   transaction_aaaa_.reset();
580   timeout_callback_.Cancel();
581   callback_.Run(IsSuccess(), address_ipv4_, address_ipv6_);
582 }
583
584 }  // namespace local_discovery