Revert "Imported Upstream version 1.17.1"
[platform/upstream/c-ares.git] / test / ares-test.cc
1 #include "ares-test.h"
2 #include "dns-proto.h"
3
4 // Include ares internal files for DNS protocol details
5 #include "nameser.h"
6 #include "ares_dns.h"
7
8 #ifdef HAVE_NETDB_H
9 #include <netdb.h>
10 #endif
11 #ifdef HAVE_NETINET_TCP_H
12 #include <netinet/tcp.h>
13 #endif
14 #include <stdio.h>
15 #include <stdlib.h>
16
17 #include <functional>
18 #include <sstream>
19
20 #ifdef WIN32
21 #define BYTE_CAST (char *)
22 #define sclose(x) closesocket(x)
23 #define mkdir_(d, p) mkdir(d)
24 #else
25 #define BYTE_CAST
26 #define sclose(x) close(x)
27 #define mkdir_(d, p) mkdir(d, p)
28 #endif
29
30 namespace ares {
31 namespace test {
32
33 bool verbose = false;
34 int mock_port = 5300;
35
36 unsigned long long LibraryTest::fails_ = 0;
37 std::map<size_t, int> LibraryTest::size_fails_;
38
39 void ProcessWork(ares_channel channel,
40                  std::function<std::set<int>()> get_extrafds,
41                  std::function<void(int)> process_extra) {
42   int nfds, count;
43   fd_set readers, writers;
44   struct timeval tv;
45   while (true) {
46     // Retrieve the set of file descriptors that the library wants us to monitor.
47     FD_ZERO(&readers);
48     FD_ZERO(&writers);
49     nfds = ares_fds(channel, &readers, &writers);
50     if (nfds == 0)  // no work left to do in the library
51       return;
52
53     // Add in the extra FDs if present.
54     std::set<int> extrafds = get_extrafds();
55     for (int extrafd : extrafds) {
56       FD_SET(extrafd, &readers);
57       if (extrafd >= nfds) {
58         nfds = extrafd + 1;
59       }
60     }
61
62     // Wait for activity or timeout.
63     tv.tv_sec = 0;
64     tv.tv_usec = 100000;  // 100ms
65     count = select(nfds, &readers, &writers, nullptr, &tv);
66     if (count < 0) {
67       fprintf(stderr, "select() failed, errno %d\n", errno);
68       return;
69     }
70
71     // Let the library process any activity.
72     ares_process(channel, &readers, &writers);
73
74     // Let the provided callback process any activity on the extra FD.
75     for (int extrafd : extrafds) {
76       if (FD_ISSET(extrafd, &readers)) {
77         process_extra(extrafd);
78       }
79     }
80   }
81 }
82
83 // static
84 void LibraryTest::SetAllocFail(int nth) {
85   assert(nth > 0);
86   assert(nth <= (int)(8 * sizeof(fails_)));
87   fails_ |= (1LL << (nth - 1));
88 }
89
90 // static
91 void LibraryTest::SetAllocSizeFail(size_t size) {
92   size_fails_[size]++;
93 }
94
95 // static
96 void LibraryTest::ClearFails() {
97   fails_ = 0;
98   size_fails_.clear();
99 }
100
101
102 // static
103 bool LibraryTest::ShouldAllocFail(size_t size) {
104   bool fail = (fails_ & 0x01);
105   fails_ >>= 1;
106   if (size_fails_[size] > 0) {
107     size_fails_[size]--;
108     fail = true;
109   }
110   return fail;
111 }
112
113 // static
114 void* LibraryTest::amalloc(size_t size) {
115   if (ShouldAllocFail(size)) {
116     if (verbose) std::cerr << "Failing malloc(" << size << ") request" << std::endl;
117     return nullptr;
118   } else {
119     return malloc(size);
120   }
121 }
122
123 // static
124 void* LibraryTest::arealloc(void *ptr, size_t size) {
125   if (ShouldAllocFail(size)) {
126     if (verbose) std::cerr << "Failing realloc(" << ptr << ", " << size << ") request" << std::endl;
127     return nullptr;
128   } else {
129     return realloc(ptr, size);
130   }
131 }
132
133 // static
134 void LibraryTest::afree(void *ptr) {
135   free(ptr);
136 }
137
138 std::set<int> NoExtraFDs() {
139   return std::set<int>();
140 }
141
142 void DefaultChannelTest::Process() {
143   ProcessWork(channel_, NoExtraFDs, nullptr);
144 }
145
146 void DefaultChannelModeTest::Process() {
147   ProcessWork(channel_, NoExtraFDs, nullptr);
148 }
149
150 MockServer::MockServer(int family, int port, int tcpport)
151   : udpport_(port), tcpport_(tcpport ? tcpport : udpport_), qid_(-1) {
152   // Create a TCP socket to receive data on.
153   tcpfd_ = socket(family, SOCK_STREAM, 0);
154   EXPECT_NE(-1, tcpfd_);
155   int optval = 1;
156   setsockopt(tcpfd_, SOL_SOCKET, SO_REUSEADDR,
157              BYTE_CAST &optval , sizeof(int));
158   // Send TCP data right away.
159   setsockopt(tcpfd_, IPPROTO_TCP, TCP_NODELAY,
160              BYTE_CAST &optval , sizeof(int));
161
162   // Create a UDP socket to receive data on.
163   udpfd_ = socket(family, SOCK_DGRAM, 0);
164   EXPECT_NE(-1, udpfd_);
165
166   // Bind the sockets to the given port.
167   if (family == AF_INET) {
168     struct sockaddr_in addr;
169     memset(&addr, 0, sizeof(addr));
170     addr.sin_family = AF_INET;
171     addr.sin_addr.s_addr = htonl(INADDR_ANY);
172     addr.sin_port = htons(tcpport_);
173     int tcprc = bind(tcpfd_, (struct sockaddr*)&addr, sizeof(addr));
174     EXPECT_EQ(0, tcprc) << "Failed to bind AF_INET to TCP port " << tcpport_;
175     addr.sin_port = htons(udpport_);
176     int udprc = bind(udpfd_, (struct sockaddr*)&addr, sizeof(addr));
177     EXPECT_EQ(0, udprc) << "Failed to bind AF_INET to UDP port " << udpport_;
178   } else {
179     EXPECT_EQ(AF_INET6, family);
180     struct sockaddr_in6 addr;
181     memset(&addr, 0, sizeof(addr));
182     addr.sin6_family = AF_INET6;
183     memset(&addr.sin6_addr, 0, sizeof(addr.sin6_addr));  // in6addr_any
184     addr.sin6_port = htons(tcpport_);
185     int tcprc = bind(tcpfd_, (struct sockaddr*)&addr, sizeof(addr));
186     EXPECT_EQ(0, tcprc) << "Failed to bind AF_INET6 to TCP port " << tcpport_;
187     addr.sin6_port = htons(udpport_);
188     int udprc = bind(udpfd_, (struct sockaddr*)&addr, sizeof(addr));
189     EXPECT_EQ(0, udprc) << "Failed to bind AF_INET6 to UDP port " << udpport_;
190   }
191   if (verbose) std::cerr << "Configured "
192                          << (family == AF_INET ? "IPv4" : "IPv6")
193                          << " mock server with TCP socket " << tcpfd_
194                          << " on port " << tcpport_
195                          << " and UDP socket " << udpfd_
196                          << " on port " << udpport_ << std::endl;
197
198   // For TCP, also need to listen for connections.
199   EXPECT_EQ(0, listen(tcpfd_, 5)) << "Failed to listen for TCP connections";
200 }
201
202 MockServer::~MockServer() {
203   for (int fd : connfds_) {
204     sclose(fd);
205   }
206   sclose(tcpfd_);
207   sclose(udpfd_);
208 }
209
210 void MockServer::ProcessFD(int fd) {
211   if (fd != tcpfd_ && fd != udpfd_ && connfds_.find(fd) == connfds_.end()) {
212     // Not one of our FDs.
213     return;
214   }
215   if (fd == tcpfd_) {
216     int connfd = accept(tcpfd_, NULL, NULL);
217     if (connfd < 0) {
218       std::cerr << "Error accepting connection on fd " << fd << std::endl;
219     } else {
220       connfds_.insert(connfd);
221     }
222     return;
223   }
224
225   // Activity on a data-bearing file descriptor.
226   struct sockaddr_storage addr;
227   socklen_t addrlen = sizeof(addr);
228   byte buffer[2048];
229   int len = recvfrom(fd, BYTE_CAST buffer, sizeof(buffer), 0,
230                      (struct sockaddr *)&addr, &addrlen);
231   byte* data = buffer;
232   if (fd != udpfd_) {
233     if (len == 0) {
234       connfds_.erase(std::find(connfds_.begin(), connfds_.end(), fd));
235       sclose(fd);
236       return;
237     }
238     if (len < 2) {
239       std::cerr << "Packet too short (" << len << ")" << std::endl;
240       return;
241     }
242     int tcplen = (data[0] << 8) + data[1];
243     data += 2;
244     len -= 2;
245     if (tcplen != len) {
246       std::cerr << "Warning: TCP length " << tcplen
247                 << " doesn't match remaining data length " << len << std::endl;
248     }
249   }
250
251   // Assume the packet is a well-formed DNS request and extract the request
252   // details.
253   if (len < NS_HFIXEDSZ) {
254     std::cerr << "Packet too short (" << len << ")" << std::endl;
255     return;
256   }
257   int qid = DNS_HEADER_QID(data);
258   if (DNS_HEADER_QR(data) != 0) {
259     std::cerr << "Not a request" << std::endl;
260     return;
261   }
262   if (DNS_HEADER_OPCODE(data) != ns_o_query) {
263     std::cerr << "Not a query (opcode " << DNS_HEADER_OPCODE(data)
264               << ")" << std::endl;
265     return;
266   }
267   if (DNS_HEADER_QDCOUNT(data) != 1) {
268     std::cerr << "Unexpected question count (" << DNS_HEADER_QDCOUNT(data)
269               << ")" << std::endl;
270     return;
271   }
272   byte* question = data + 12;
273   int qlen = len - 12;
274
275   char *name = nullptr;
276   long enclen;
277   ares_expand_name(question, data, len, &name, &enclen);
278   if (!name) {
279     std::cerr << "Failed to retrieve name" << std::endl;
280     return;
281   }
282   qlen -= enclen;
283   question += enclen;
284   std::string namestr(name);
285   free(name);
286
287   if (qlen < 4) {
288     std::cerr << "Unexpected question size (" << qlen
289               << " bytes after name)" << std::endl;
290     return;
291   }
292   if (DNS_QUESTION_CLASS(question) != ns_c_in) {
293     std::cerr << "Unexpected question class (" << DNS_QUESTION_CLASS(question)
294               << ")" << std::endl;
295     return;
296   }
297   int rrtype = DNS_QUESTION_TYPE(question);
298
299   if (verbose) {
300     std::vector<byte> req(data, data + len);
301     std::cerr << "received " << (fd == udpfd_ ? "UDP" : "TCP") << " request " << PacketToString(req)
302               << " on port " << (fd == udpfd_ ? udpport_ : tcpport_) << std::endl;
303     std::cerr << "ProcessRequest(" << qid << ", '" << namestr
304               << "', " << RRTypeToString(rrtype) << ")" << std::endl;
305   }
306   ProcessRequest(fd, &addr, addrlen, qid, namestr, rrtype);
307 }
308
309 std::set<int> MockServer::fds() const {
310   std::set<int> result = connfds_;
311   result.insert(tcpfd_);
312   result.insert(udpfd_);
313   return result;
314 }
315
316 void MockServer::ProcessRequest(int fd, struct sockaddr_storage* addr, int addrlen,
317                                 int qid, const std::string& name, int rrtype) {
318   // Before processing, let gMock know the request is happening.
319   OnRequest(name, rrtype);
320
321   if (reply_.size() == 0) {
322     return;
323   }
324
325   // Make a local copy of the current pending reply.
326   std::vector<byte> reply = reply_;
327
328   if (qid_ >= 0) {
329     // Use the explicitly specified query ID.
330     qid = qid_;
331   }
332   if (reply.size() >=  2) {
333     // Overwrite the query ID if space to do so.
334     reply[0] = (byte)((qid >> 8) & 0xff);
335     reply[1] = (byte)(qid & 0xff);
336   }
337   if (verbose) std::cerr << "sending reply " << PacketToString(reply)
338                          << " on port " << ((fd == udpfd_) ? udpport_ : tcpport_) << std::endl;
339
340   // Prefix with 2-byte length if TCP.
341   if (fd != udpfd_) {
342     int len = reply.size();
343     std::vector<byte> vlen = {(byte)((len & 0xFF00) >> 8), (byte)(len & 0xFF)};
344     reply.insert(reply.begin(), vlen.begin(), vlen.end());
345     // Also, don't bother with the destination address.
346     addr = nullptr;
347     addrlen = 0;
348   }
349
350   int rc = sendto(fd, BYTE_CAST reply.data(), reply.size(), 0,
351                   (struct sockaddr *)addr, addrlen);
352   if (rc < static_cast<int>(reply.size())) {
353     std::cerr << "Failed to send full reply, rc=" << rc << std::endl;
354   }
355 }
356
357 // static
358 MockChannelOptsTest::NiceMockServers MockChannelOptsTest::BuildServers(int count, int family, int base_port) {
359   NiceMockServers servers;
360   assert(count > 0);
361   for (int ii = 0; ii < count; ii++) {
362     std::unique_ptr<NiceMockServer> server(new NiceMockServer(family, base_port + ii));
363     servers.push_back(std::move(server));
364   }
365   return servers;
366 }
367
368 MockChannelOptsTest::MockChannelOptsTest(int count,
369                                          int family,
370                                          bool force_tcp,
371                                          struct ares_options* givenopts,
372                                          int optmask)
373   : servers_(BuildServers(count, family, mock_port)),
374     server_(*servers_[0].get()), channel_(nullptr) {
375   // Set up channel options.
376   struct ares_options opts;
377   if (givenopts) {
378     memcpy(&opts, givenopts, sizeof(opts));
379   } else {
380     memset(&opts, 0, sizeof(opts));
381   }
382
383   // Point the library at the first mock server by default (overridden below).
384   opts.udp_port = mock_port;
385   optmask |= ARES_OPT_UDP_PORT;
386   opts.tcp_port = mock_port;
387   optmask |= ARES_OPT_TCP_PORT;
388
389   // If not already overridden, set short-ish timeouts.
390   if (!(optmask & (ARES_OPT_TIMEOUTMS|ARES_OPT_TIMEOUT))) {
391     opts.timeout = 1500;
392     optmask |= ARES_OPT_TIMEOUTMS;
393   }
394   // If not already overridden, set 3 retries.
395   if (!(optmask & ARES_OPT_TRIES)) {
396     opts.tries = 3;
397     optmask |= ARES_OPT_TRIES;
398   }
399   // If not already overridden, set search domains.
400   const char *domains[3] = {"first.com", "second.org", "third.gov"};
401   if (!(optmask & ARES_OPT_DOMAINS)) {
402     opts.ndomains = 3;
403     opts.domains = (char**)domains;
404     optmask |= ARES_OPT_DOMAINS;
405   }
406   if (force_tcp) {
407     opts.flags |= ARES_FLAG_USEVC;
408     optmask |= ARES_OPT_FLAGS;
409   }
410
411   EXPECT_EQ(ARES_SUCCESS, ares_init_options(&channel_, &opts, optmask));
412   EXPECT_NE(nullptr, channel_);
413
414   // Set up servers after construction so we can set individual ports
415   struct ares_addr_port_node* prev = nullptr;
416   struct ares_addr_port_node* first = nullptr;
417   for (const auto& server : servers_) {
418     struct ares_addr_port_node* node = (struct ares_addr_port_node*)malloc(sizeof(*node));
419     if (prev) {
420       prev->next = node;
421     } else {
422       first = node;
423     }
424     node->next = nullptr;
425     node->family = family;
426     node->udp_port = server->udpport();
427     node->tcp_port = server->tcpport();
428     if (family == AF_INET) {
429       node->addr.addr4.s_addr = htonl(0x7F000001);
430     } else {
431       memset(&node->addr.addr6, 0, sizeof(node->addr.addr6));
432       node->addr.addr6._S6_un._S6_u8[15] = 1;
433     }
434     prev = node;
435   }
436   EXPECT_EQ(ARES_SUCCESS, ares_set_servers_ports(channel_, first));
437
438   while (first) {
439     prev = first;
440     first = first->next;
441     free(prev);
442   }
443   if (verbose) {
444     std::cerr << "Configured library with servers:";
445     std::vector<std::string> servers = GetNameServers(channel_);
446     for (const auto& server : servers) {
447       std::cerr << " " << server;
448     }
449     std::cerr << std::endl;
450   }
451 }
452
453 MockChannelOptsTest::~MockChannelOptsTest() {
454   if (channel_) {
455     ares_destroy(channel_);
456   }
457   channel_ = nullptr;
458 }
459
460 std::set<int> MockChannelOptsTest::fds() const {
461   std::set<int> fds;
462   for (const auto& server : servers_) {
463     std::set<int> serverfds = server->fds();
464     fds.insert(serverfds.begin(), serverfds.end());
465   }
466   return fds;
467 }
468
469 void MockChannelOptsTest::ProcessFD(int fd) {
470   for (auto& server : servers_) {
471     server->ProcessFD(fd);
472   }
473 }
474
475 void MockChannelOptsTest::Process() {
476   using namespace std::placeholders;
477   ProcessWork(channel_,
478               std::bind(&MockChannelOptsTest::fds, this),
479               std::bind(&MockChannelOptsTest::ProcessFD, this, _1));
480 }
481
482 std::ostream& operator<<(std::ostream& os, const HostResult& result) {
483   os << '{';
484   if (result.done_) {
485     os << StatusToString(result.status_) << " " << result.host_;
486   } else {
487     os << "(incomplete)";
488   }
489   os << '}';
490   return os;
491 }
492
493 HostEnt::HostEnt(const struct hostent *hostent) : addrtype_(-1) {
494   if (!hostent)
495     return;
496   if (hostent->h_name)
497     name_ = hostent->h_name;
498   if (hostent->h_aliases) {
499     char** palias = hostent->h_aliases;
500     while (*palias != nullptr) {
501       aliases_.push_back(*palias);
502       palias++;
503     }
504   }
505   addrtype_ = hostent->h_addrtype;
506   if (hostent->h_addr_list) {
507     char** paddr = hostent->h_addr_list;
508     while (*paddr != nullptr) {
509       std::string addr = AddressToString(*paddr, hostent->h_length);
510       addrs_.push_back(addr);
511       paddr++;
512     }
513   }
514 }
515
516 std::ostream& operator<<(std::ostream& os, const HostEnt& host) {
517   os << '{';
518   os << "'" << host.name_ << "' "
519      << "aliases=[";
520   for (size_t ii = 0; ii < host.aliases_.size(); ii++) {
521     if (ii > 0) os << ", ";
522     os << host.aliases_[ii];
523   }
524   os << "] ";
525   os << "addrs=[";
526   for (size_t ii = 0; ii < host.addrs_.size(); ii++) {
527     if (ii > 0) os << ", ";
528     os << host.addrs_[ii];
529   }
530   os << "]";
531   os << '}';
532   return os;
533 }
534
535 void HostCallback(void *data, int status, int timeouts,
536                   struct hostent *hostent) {
537   EXPECT_NE(nullptr, data);
538   HostResult* result = reinterpret_cast<HostResult*>(data);
539   result->done_ = true;
540   result->status_ = status;
541   result->timeouts_ = timeouts;
542   result->host_ = HostEnt(hostent);
543   if (verbose) std::cerr << "HostCallback(" << *result << ")" << std::endl;
544 }
545
546 std::ostream& operator<<(std::ostream& os, const SearchResult& result) {
547   os << '{';
548   if (result.done_) {
549     os << StatusToString(result.status_) << " " << PacketToString(result.data_);
550   } else {
551     os << "(incomplete)";
552   }
553   os << '}';
554   return os;
555 }
556
557 void SearchCallback(void *data, int status, int timeouts,
558                     unsigned char *abuf, int alen) {
559   EXPECT_NE(nullptr, data);
560   SearchResult* result = reinterpret_cast<SearchResult*>(data);
561   result->done_ = true;
562   result->status_ = status;
563   result->timeouts_ = timeouts;
564   result->data_.assign(abuf, abuf + alen);
565   if (verbose) std::cerr << "SearchCallback(" << *result << ")" << std::endl;
566 }
567
568 std::ostream& operator<<(std::ostream& os, const NameInfoResult& result) {
569   os << '{';
570   if (result.done_) {
571     os << StatusToString(result.status_) << " " << result.node_ << " " << result.service_;
572   } else {
573     os << "(incomplete)";
574   }
575   os << '}';
576   return os;
577 }
578
579 void NameInfoCallback(void *data, int status, int timeouts,
580                       char *node, char *service) {
581   EXPECT_NE(nullptr, data);
582   NameInfoResult* result = reinterpret_cast<NameInfoResult*>(data);
583   result->done_ = true;
584   result->status_ = status;
585   result->timeouts_ = timeouts;
586   result->node_ = std::string(node ? node : "");
587   result->service_ = std::string(service ? service : "");
588   if (verbose) std::cerr << "NameInfoCallback(" << *result << ")" << std::endl;
589 }
590
591 std::vector<std::string> GetNameServers(ares_channel channel) {
592   struct ares_addr_port_node* servers = nullptr;
593   EXPECT_EQ(ARES_SUCCESS, ares_get_servers_ports(channel, &servers));
594   struct ares_addr_port_node* server = servers;
595   std::vector<std::string> results;
596   while (server) {
597     std::stringstream ss;
598     switch (server->family) {
599     case AF_INET:
600       ss << AddressToString((char*)&server->addr.addr4, 4);
601       break;
602     case AF_INET6:
603       if (server->udp_port != 0) {
604         ss << '[';
605       }
606       ss << AddressToString((char*)&server->addr.addr6, 16);
607       if (server->udp_port != 0) {
608         ss << ']';
609       }
610       break;
611     default:
612       results.push_back("<unknown family>");
613       break;
614     }
615     if (server->udp_port != 0) {
616       ss << ":" << server->udp_port;
617     }
618     results.push_back(ss.str());
619     server = server->next;
620   }
621   if (servers) ares_free_data(servers);
622   return results;
623 }
624
625 TransientDir::TransientDir(const std::string& dirname) : dirname_(dirname) {
626   if (mkdir_(dirname_.c_str(), 0755) != 0) {
627     std::cerr << "Failed to create subdirectory '" << dirname_ << "'" << std::endl;
628   }
629 }
630
631 TransientDir::~TransientDir() {
632   rmdir(dirname_.c_str());
633 }
634
635 TransientFile::TransientFile(const std::string& filename,
636                              const std::string& contents)
637     : filename_(filename) {
638   FILE *f = fopen(filename.c_str(), "w");
639   if (f == nullptr) {
640     std::cerr << "Error: failed to create '" << filename << "'" << std::endl;
641     return;
642   }
643   int rc = fwrite(contents.data(), 1, contents.size(), f);
644   if (rc != (int)contents.size()) {
645     std::cerr << "Error: failed to write contents of '" << filename << "'" << std::endl;
646   }
647   fclose(f);
648 }
649
650 TransientFile::~TransientFile() {
651   unlink(filename_.c_str());
652 }
653
654 std::string TempNam(const char *dir, const char *prefix) {
655   char *p = tempnam(dir, prefix);
656   std::string result(p);
657   free(p);
658   return result;
659 }
660
661 TempFile::TempFile(const std::string& contents)
662   : TransientFile(TempNam(nullptr, "ares"), contents) {
663
664 }
665
666 }  // namespace test
667 }  // namespace ares