1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "base/memory/weak_ptr.h"
6 #include "base/run_loop.h"
7 #include "chrome/utility/local_discovery/service_discovery_client_impl.h"
8 #include "net/base/net_errors.h"
9 #include "net/dns/dns_protocol.h"
10 #include "net/dns/mdns_client_impl.h"
11 #include "net/dns/mock_mdns_socket_factory.h"
12 #include "testing/gmock/include/gmock/gmock.h"
13 #include "testing/gtest/include/gtest/gtest.h"
16 using ::testing::Invoke;
17 using ::testing::StrictMock;
18 using ::testing::NiceMock;
19 using ::testing::Mock;
20 using ::testing::SaveArg;
21 using ::testing::SetArgPointee;
22 using ::testing::Return;
23 using ::testing::Exactly;
25 namespace local_discovery {
// mDNS response carrying one PTR answer for _privet._tcp.local. The tests
// below expect it to announce the instance "hello._privet._tcp.local".
29 const uint8 kSamplePacketPTR[] = {
31 0x00, 0x00, // ID is zeroed out
32 0x81, 0x80, // Standard query response, RA, no error
33 0x00, 0x00, // No questions (for simplicity)
34 0x00, 0x01, // 1 RR (answers)
35 0x00, 0x00, // 0 authority RRs
36 0x00, 0x00, // 0 additional RRs
// Answer name: _privet._tcp.local as length-prefixed labels.
38 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
39 0x04, '_', 't', 'c', 'p',
40 0x05, 'l', 'o', 'c', 'a', 'l',
42 0x00, 0x0c, // TYPE is PTR.
43 0x00, 0x01, // CLASS is IN.
44 0x00, 0x00, // TTL (4 bytes) is 1 second.
46 0x00, 0x08, // RDLENGTH is 8 bytes.
// RDATA begins with the instance label "hello".
47 0x05, 'h', 'e', 'l', 'l', 'o',
// mDNS response carrying one SRV answer for hello._privet._tcp.local that
// points at host myhello.local, port 8888 (the values ServiceResolverTest
// expects).
51 const uint8 kSamplePacketSRV[] = {
53 0x00, 0x00, // ID is zeroed out
54 0x81, 0x80, // Standard query response, RA, no error
55 0x00, 0x00, // No questions (for simplicity)
56 0x00, 0x01, // 1 RR (answers)
57 0x00, 0x00, // 0 authority RRs
58 0x00, 0x00, // 0 additional RRs
// Answer name: hello._privet._tcp.local as length-prefixed labels.
60 0x05, 'h', 'e', 'l', 'l', 'o',
61 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
62 0x04, '_', 't', 'c', 'p',
63 0x05, 'l', 'o', 'c', 'a', 'l',
65 0x00, 0x21, // TYPE is SRV.
66 0x00, 0x01, // CLASS is IN.
67 0x00, 0x00, // TTL (4 bytes) is 1 second.
69 0x00, 0x15, // RDLENGTH is 21 bytes.
72 0x22, 0xb8, // port 8888
// SRV target host: myhello.local.
73 0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o',
74 0x05, 'l', 'o', 'c', 'a', 'l',
// mDNS response carrying one TXT answer for hello._privet._tcp.local whose
// single string is "hello" (ServiceResolverTest expects metadata {"hello"}).
78 const uint8 kSamplePacketTXT[] = {
80 0x00, 0x00, // ID is zeroed out
81 0x81, 0x80, // Standard query response, RA, no error
82 0x00, 0x00, // No questions (for simplicity)
83 0x00, 0x01, // 1 RR (answers)
84 0x00, 0x00, // 0 authority RRs
85 0x00, 0x00, // 0 additional RRs
// Answer name: hello._privet._tcp.local as length-prefixed labels.
87 0x05, 'h', 'e', 'l', 'l', 'o',
88 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
89 0x04, '_', 't', 'c', 'p',
90 0x05, 'l', 'o', 'c', 'a', 'l',
92 0x00, 0x10, // TYPE is TXT.
93 0x00, 0x01, // CLASS is IN.
94 0x00, 0x00, // TTL (4 bytes).
// NOTE(review): the high 16 bits of the TTL are zero, so it is at most 65535 s;
// the previously documented "20 hours, 47 minutes, 48 seconds" (0x00012474)
// cannot be encoded with this prefix.
96 0x00, 0x06, // RDLENGTH is 6 bytes.
// RDATA: one length-prefixed character-string, "hello".
97 0x05, 'h', 'e', 'l', 'l', 'o'
// mDNS response with two answers: the SRV record for
// hello._privet._tcp.local -> myhello.local:8888, and an A record for
// myhello.local (ServiceResolverTest expects the address 1.2.3.4).
100 const uint8 kSamplePacketSRVA[] = {
102 0x00, 0x00, // ID is zeroed out
103 0x81, 0x80, // Standard query response, RA, no error
104 0x00, 0x00, // No questions (for simplicity)
105 0x00, 0x02, // 2 RR (answers)
106 0x00, 0x00, // 0 authority RRs
107 0x00, 0x00, // 0 additional RRs
// Answer 1 (SRV), name: hello._privet._tcp.local.
109 0x05, 'h', 'e', 'l', 'l', 'o',
110 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
111 0x04, '_', 't', 'c', 'p',
112 0x05, 'l', 'o', 'c', 'a', 'l',
114 0x00, 0x21, // TYPE is SRV.
115 0x00, 0x01, // CLASS is IN.
116 0x00, 0x00, // TTL (4 bytes) is 16 seconds.
118 0x00, 0x15, // RDLENGTH is 21 bytes.
121 0x22, 0xb8, // port 8888
// SRV target host: myhello.local.
122 0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o',
123 0x05, 'l', 'o', 'c', 'a', 'l',
// Answer 2 (A), name: myhello.local.
126 0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o',
127 0x05, 'l', 'o', 'c', 'a', 'l',
129 0x00, 0x01, // TYPE is A.
130 0x00, 0x01, // CLASS is IN.
131 0x00, 0x00, // TTL (4 bytes) is 16 seconds.
133 0x00, 0x04, // RDLENGTH is 4 bytes.
// mDNS response with two PTR answers for _privet._tcp.local, announcing the
// instances "gdbye" and "hello" (see ReadCachedServicesMultiple).
138 const uint8 kSamplePacketPTR2[] = {
140 0x00, 0x00, // ID is zeroed out
141 0x81, 0x80, // Standard query response, RA, no error
142 0x00, 0x00, // No questions (for simplicity)
143 0x00, 0x02, // 2 RR (answers)
144 0x00, 0x00, // 0 authority RRs
145 0x00, 0x00, // 0 additional RRs
// Answer 1: PTR _privet._tcp.local -> instance "gdbye".
147 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
148 0x04, '_', 't', 'c', 'p',
149 0x05, 'l', 'o', 'c', 'a', 'l',
151 0x00, 0x0c, // TYPE is PTR.
152 0x00, 0x01, // CLASS is IN.
153 0x02, 0x00, // TTL (4 bytes); high bytes 0x02, 0x00 make it >= 0x02000000 s (~388 days), not 1 second.
155 0x00, 0x08, // RDLENGTH is 8 bytes.
156 0x05, 'g', 'd', 'b', 'y', 'e',
// Answer 2: PTR _privet._tcp.local -> instance "hello".
159 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
160 0x04, '_', 't', 'c', 'p',
161 0x05, 'l', 'o', 'c', 'a', 'l',
163 0x00, 0x0c, // TYPE is PTR.
164 0x00, 0x01, // CLASS is IN.
165 0x02, 0x00, // TTL (4 bytes); high bytes 0x02, 0x00 make it >= 0x02000000 s (~388 days), not 1 second.
167 0x00, 0x08, // RDLENGTH is 8 bytes.
168 0x05, 'h', 'e', 'l', 'l', 'o',
// gmock stand-in for a ServiceWatcher client. GetCallback() adapts the mocked
// OnServiceUpdated into a ServiceWatcher::UpdatedCallback so tests can set
// EXPECT_CALLs on the update type and service name.
172 class MockServiceWatcherClient {
174 MOCK_METHOD2(OnServiceUpdated,
175 void(ServiceWatcher::UpdateType, const std::string&));
// Bound with base::Unretained(this): the callback does not keep the mock
// alive, so the mock must outlive any watcher holding it. The tests declare
// the delegate before the watcher, so the delegate is destroyed last.
177 ServiceWatcher::UpdatedCallback GetCallback() {
178 return base::Bind(&MockServiceWatcherClient::OnServiceUpdated,
179 base::Unretained(this));
// Base fixture: an MDnsClientImpl fed by a MockMDnsSocketFactory (tests inject
// packets with SimulateReceive and observe sends via OnSendTo), wrapped by the
// ServiceDiscoveryClientImpl under test. Listening starts in the constructor.
183 class ServiceDiscoveryTest : public ::testing::Test {
185 ServiceDiscoveryTest()
186 : socket_factory_(new net::MockMDnsSocketFactory),
188 scoped_ptr<net::MDnsConnection::SocketFactory>(
190 service_discovery_client_(&mdns_client_) {
191 mdns_client_.StartListening();
194 virtual ~ServiceDiscoveryTest() {
// Runs the current message loop for |time_period|: a delayed task calls Stop(),
// which quits the loop. The task goes through a CancelableCallback local to
// this call, so it cannot fire after RunFor returns.
198 void RunFor(base::TimeDelta time_period) {
199 base::CancelableCallback<void()> callback(base::Bind(
200 &ServiceDiscoveryTest::Stop, base::Unretained(this)));
201 base::MessageLoop::current()->PostDelayedTask(
202 FROM_HERE, callback.callback(), time_period);
204 base::MessageLoop::current()->Run();
// Body of Stop(), the task RunFor posts: ends the current Run().
209 base::MessageLoop::current()->Quit();
// Presumably non-owning: the constructor wraps the factory in a
// scoped_ptr<net::MDnsConnection::SocketFactory>, which suggests ownership
// moves to mdns_client_. TODO confirm.
212 net::MockMDnsSocketFactory* socket_factory_;
213 net::MDnsClientImpl mdns_client_;
214 ServiceDiscoveryClientImpl service_discovery_client_;
// NOTE(review): loop_ is declared after mdns_client_ and
// service_discovery_client_, so it is constructed after them and destroyed
// before them; confirm neither needs a MessageLoop outside that window.
215 base::MessageLoop loop_;
// A PTR record received after the watcher is created is reported as
// UPDATE_ADDED. During the following 2-second run the watcher is expected to
// report UPDATE_REMOVED, presumably because the record's 1-second TTL (per
// kSamplePacketPTR) expires. StrictMock fails the test on any other call.
218 TEST_F(ServiceDiscoveryTest, AddRemoveService) {
219 StrictMock<MockServiceWatcherClient> delegate;
221 scoped_ptr<ServiceWatcher> watcher(
222 service_discovery_client_.CreateServiceWatcher(
223 "_privet._tcp.local", delegate.GetCallback()));
227 EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED,
228 "hello._privet._tcp.local"))
231 socket_factory_->SimulateReceive(
232 kSamplePacketPTR, sizeof(kSamplePacketPTR));
234 EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_REMOVED,
235 "hello._privet._tcp.local"))
238 RunFor(base::TimeDelta::FromSeconds(2));
// DiscoverNewServices() is expected to send a query through the mock socket
// factory, checked via the OnSendTo expectation.
241 TEST_F(ServiceDiscoveryTest, DiscoverNewServices) {
242 StrictMock<MockServiceWatcherClient> delegate;
244 scoped_ptr<ServiceWatcher> watcher(
245 service_discovery_client_.CreateServiceWatcher(
246 "_privet._tcp.local", delegate.GetCallback()));
250 EXPECT_CALL(*socket_factory_, OnSendTo(_))
253 watcher->DiscoverNewServices(false);
// A PTR record received before the watcher exists should still be reported as
// UPDATE_ADDED once pending tasks run. That is, the watcher picks up records
// that were already received (presumably from the mDNS cache).
256 TEST_F(ServiceDiscoveryTest, ReadCachedServices) {
257 socket_factory_->SimulateReceive(
258 kSamplePacketPTR, sizeof(kSamplePacketPTR));
260 StrictMock<MockServiceWatcherClient> delegate;
262 scoped_ptr<ServiceWatcher> watcher(
263 service_discovery_client_.CreateServiceWatcher(
264 "_privet._tcp.local", delegate.GetCallback()));
268 EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED,
269 "hello._privet._tcp.local"))
272 base::MessageLoop::current()->RunUntilIdle();
// Same as ReadCachedServices, but the pre-existing packet (kSamplePacketPTR2)
// announces two instances. Both "hello" and "gdbye" are expected as
// UPDATE_ADDED.
276 TEST_F(ServiceDiscoveryTest, ReadCachedServicesMultiple) {
277 socket_factory_->SimulateReceive(
278 kSamplePacketPTR2, sizeof(kSamplePacketPTR2));
280 StrictMock<MockServiceWatcherClient> delegate;
281 scoped_ptr<ServiceWatcher> watcher =
282 service_discovery_client_.CreateServiceWatcher(
283 "_privet._tcp.local", delegate.GetCallback());
287 EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED,
288 "hello._privet._tcp.local"))
291 EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED,
292 "gdbye._privet._tcp.local"))
295 base::MessageLoop::current()->RunUntilIdle();
// After a PTR record adds hello._privet._tcp.local, SRV and TXT records for
// that same instance are expected to be reported as UPDATE_CHANGED.
299 TEST_F(ServiceDiscoveryTest, OnServiceChanged) {
300 StrictMock<MockServiceWatcherClient> delegate;
301 scoped_ptr<ServiceWatcher> watcher(
302 service_discovery_client_.CreateServiceWatcher(
303 "_privet._tcp.local", delegate.GetCallback()));
307 EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED,
308 "hello._privet._tcp.local"))
311 socket_factory_->SimulateReceive(
312 kSamplePacketPTR, sizeof(kSamplePacketPTR));
314 base::MessageLoop::current()->RunUntilIdle();
316 EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_CHANGED,
317 "hello._privet._tcp.local"))
320 socket_factory_->SimulateReceive(
321 kSamplePacketSRV, sizeof(kSamplePacketSRV));
323 socket_factory_->SimulateReceive(
324 kSamplePacketTXT, sizeof(kSamplePacketTXT));
326 base::MessageLoop::current()->RunUntilIdle();
// Adds hello._privet._tcp.local via PTR, drains the loop, then delivers SRV
// and TXT records and expects UPDATE_CHANGED.
// NOTE(review): the visible steps are identical to OnServiceChanged, and the
// SRV and TXT records arrive as two separate packets despite the test name;
// confirm the intended scenario.
329 TEST_F(ServiceDiscoveryTest, SinglePacket) {
330 StrictMock<MockServiceWatcherClient> delegate;
331 scoped_ptr<ServiceWatcher> watcher(
332 service_discovery_client_.CreateServiceWatcher(
333 "_privet._tcp.local", delegate.GetCallback()));
337 EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED,
338 "hello._privet._tcp.local"))
341 socket_factory_->SimulateReceive(
342 kSamplePacketPTR, sizeof(kSamplePacketPTR));
344 // Reset the "already updated" flag.
345 base::MessageLoop::current()->RunUntilIdle();
347 EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_CHANGED,
348 "hello._privet._tcp.local"))
351 socket_factory_->SimulateReceive(
352 kSamplePacketSRV, sizeof(kSamplePacketSRV));
354 socket_factory_->SimulateReceive(
355 kSamplePacketTXT, sizeof(kSamplePacketTXT));
357 base::MessageLoop::current()->RunUntilIdle();
// Fixture for ServiceResolver tests. It holds the values the sample SRV, A and
// TXT packets should resolve to, and routes resolver results into a gmock
// method so tests can match them.
360 class ServiceResolverTest : public ServiceDiscoveryTest {
// Expected results: TXT string "hello", SRV target myhello.local:8888, and
// A address 1.2.3.4.
362 ServiceResolverTest() {
363 metadata_expected_.push_back("hello");
364 address_expected_ = net::HostPortPair("myhello.local", 8888);
365 ip_address_expected_.push_back(1);
366 ip_address_expected_.push_back(2);
367 ip_address_expected_.push_back(3);
368 ip_address_expected_.push_back(4);
371 ~ServiceResolverTest() {
// Resolver for "hello._privet._tcp.local"; results go to OnFinishedResolving.
// base::Unretained(this) means the callback does not keep the fixture alive.
375 resolver_ = service_discovery_client_.CreateServiceResolver(
376 "hello._privet._tcp.local",
377 base::Bind(&ServiceResolverTest::OnFinishedResolving,
378 base::Unretained(this)));
// Flattens the ServiceDescription into values gmock can match directly:
// address via ToString(), the metadata vector and the IP address.
381 void OnFinishedResolving(ServiceResolver::RequestStatus request_status,
382 const ServiceDescription& service_description) {
383 OnFinishedResolvingInternal(request_status,
384 service_description.address.ToString(),
385 service_description.metadata,
386 service_description.ip_address);
389 MOCK_METHOD4(OnFinishedResolvingInternal,
390 void(ServiceResolver::RequestStatus,
392 const std::vector<std::string>&,
393 const net::IPAddressNumber&));
396 scoped_ptr<ServiceResolver> resolver_;
// NOTE(review): ip_address_ is not referenced anywhere in the visible code;
// consider removing it.
397 net::IPAddressNumber ip_address_;
398 net::HostPortPair address_expected_;
399 std::vector<std::string> metadata_expected_;
400 net::IPAddressNumber ip_address_expected_;
// Starting resolution is expected to send a query. Receiving the SRV record and
// then the TXT record, with no A record, is expected to complete with
// STATUS_SUCCESS, the expected host/port and an empty IP address.
403 TEST_F(ServiceResolverTest, TxtAndSrvButNoA) {
404 EXPECT_CALL(*socket_factory_, OnSendTo(_))
407 resolver_->StartResolving();
409 socket_factory_->SimulateReceive(
410 kSamplePacketSRV, sizeof(kSamplePacketSRV));
412 base::MessageLoop::current()->RunUntilIdle();
415 OnFinishedResolvingInternal(ServiceResolver::STATUS_SUCCESS,
416 address_expected_.ToString(),
418 net::IPAddressNumber()));
420 socket_factory_->SimulateReceive(
421 kSamplePacketTXT, sizeof(kSamplePacketTXT));
// With TXT, SRV and A records all delivered (TXT first, then the SRV+A packet),
// resolution is expected to succeed with the expected host/port and the full
// IP address 1.2.3.4.
424 TEST_F(ServiceResolverTest, TxtSrvAndA) {
425 EXPECT_CALL(*socket_factory_, OnSendTo(_))
428 resolver_->StartResolving();
431 OnFinishedResolvingInternal(ServiceResolver::STATUS_SUCCESS,
432 address_expected_.ToString(),
434 ip_address_expected_));
436 socket_factory_->SimulateReceive(
437 kSamplePacketTXT, sizeof(kSamplePacketTXT));
439 socket_factory_->SimulateReceive(
440 kSamplePacketSRVA, sizeof(kSamplePacketSRVA));
// Only SRV and A records arrive. Within a 4-second run, resolution is expected
// to succeed with the expected host/port, the IP address and empty metadata,
// presumably after the TXT lookup gives up (see the NSEC TODO below).
443 TEST_F(ServiceResolverTest, JustSrv) {
444 EXPECT_CALL(*socket_factory_, OnSendTo(_))
447 resolver_->StartResolving();
450 OnFinishedResolvingInternal(ServiceResolver::STATUS_SUCCESS,
451 address_expected_.ToString(),
452 std::vector<std::string>(),
453 ip_address_expected_));
455 socket_factory_->SimulateReceive(
456 kSamplePacketSRVA, sizeof(kSamplePacketSRVA));
458 // TODO(noamsml): When NSEC record support is added, change this to use an
460 RunFor(base::TimeDelta::FromSeconds(4));
// When no records arrive at all, resolution is expected to report
// STATUS_REQUEST_TIMEOUT within a 4-second run of the message loop.
463 TEST_F(ServiceResolverTest, WithNothing) {
464 EXPECT_CALL(*socket_factory_, OnSendTo(_))
467 resolver_->StartResolving();
469 EXPECT_CALL(*this, OnFinishedResolvingInternal(
470 ServiceResolver::STATUS_REQUEST_TIMEOUT, _, _, _));
472 // TODO(noamsml): When NSEC record support is added, change this to use an
474 RunFor(base::TimeDelta::FromSeconds(4));
479 } // namespace local_discovery