4 // Include ares internal files for DNS protocol details
11 #ifdef HAVE_NETINET_TCP_H
12 #include <netinet/tcp.h>
// Portability shims. The enclosing #ifdef/#else lines are not visible in this
// excerpt; presumably the first group is the Windows/Winsock branch
// (closesocket, one-argument mkdir) and the second the POSIX branch — confirm.
// BYTE_CAST is applied to the buffer arguments of setsockopt/recvfrom/sendto
// in this file; the other branch's definition of BYTE_CAST is not visible here.
21 #define BYTE_CAST (char *)
22 #define sclose(x) closesocket(x)
// This variant drops the mode argument `p` entirely.
23 #define mkdir_(d, p) mkdir(d)
// Variants that map sclose to close() and forward the mode to mkdir().
26 #define sclose(x) close(x)
27 #define mkdir_(d, p) mkdir(d, p)
// Static allocation-failure injection state shared by all LibraryTest users.
// fails_: bitmask of pending failures; ShouldAllocFail() tests bit 0 and
//   SetAllocFail() sets bits in it.
// size_fails_: counters keyed by requested allocation size, consulted by
//   ShouldAllocFail() (presumably decremented when a failure is injected).
36 unsigned long long LibraryTest::fails_ = 0;
37 std::map<size_t, int> LibraryTest::size_fails_;
// Pump the c-ares event loop for `channel`, alongside optional extra fds.
//
// Visible steps: ask ares_fds() for the read/write sets (stopping once it
// reports 0 fds — the statement after that check is not visible here), add
// every fd from get_extrafds() to the read set, select() with a 100ms
// tv_usec timeout, pass the results to ares_process(), then call
// process_extra(fd) for each extra fd that became readable. The enclosing loop
// header and tv_sec assignment are on lines not shown in this excerpt.
//
// get_extrafds: supplies additional fds (e.g. mock-server sockets) to poll.
// process_extra: called once per readable extra fd. Callers that pass
//   NoExtraFDs pass nullptr here; that is safe only because an empty
//   std::function is never invoked when the extra set is empty (invoking it
//   would throw std::bad_function_call).
39 void ProcessWork(ares_channel channel,
40 std::function<std::set<int>()> get_extrafds,
41 std::function<void(int)> process_extra) {
43 fd_set readers, writers;
46 // Retrieve the set of file descriptors that the library wants us to monitor.
49 nfds = ares_fds(channel, &readers, &writers);
50 if (nfds == 0) // no work left to do in the library
53 // Add in the extra FDs if present.
54 std::set<int> extrafds = get_extrafds();
55 for (int extrafd : extrafds) {
56 FD_SET(extrafd, &readers);
// select() requires nfds to exceed the highest fd in any set; the adjustment
// made inside this branch is on lines not visible here.
57 if (extrafd >= nfds) {
62 // Wait for activity or timeout.
64 tv.tv_usec = 100000; // 100ms
65 count = select(nfds, &readers, &writers, nullptr, &tv);
67 fprintf(stderr, "select() failed, errno %d\n", errno);
71 // Let the library process any activity.
72 ares_process(channel, &readers, &writers);
74 // Let the provided callback process any activity on the extra FD.
75 for (int extrafd : extrafds) {
76 if (FD_ISSET(extrafd, &readers)) {
77 process_extra(extrafd);
// Schedule an allocation failure by setting bit (nth - 1) of fails_.
// ShouldAllocFail() tests bit 0; presumably the mask is shifted per
// allocation, so this fails the nth subsequent allocation (the shift is not
// visible in this excerpt).
// NOTE(review): the assert admits nth == 64, where 1LL << 63 shifts into the
// sign bit of a signed long long (a signed-shift hazard, UB in C and before
// CWG 1457). fails_ is unsigned long long, so 1ULL is the natural literal.
// The line between the signature and the assert is not visible; confirm
// that nth <= 0 is rejected, since a negative shift count is undefined.
84 void LibraryTest::SetAllocFail(int nth) {
86 assert(nth <= (int)(8 * sizeof(fails_)));
87 fails_ |= (1LL << (nth - 1));
// Schedule a failure for an allocation request of `size` bytes (presumably
// by bumping size_fails_[size]; the body is not visible in this excerpt).
91 void LibraryTest::SetAllocSizeFail(size_t size) {
// Reset allocation-failure injection (body not visible in this excerpt;
// presumably clears both fails_ and size_fails_).
96 void LibraryTest::ClearFails() {
// Decide whether the current allocation of `size` bytes should fail.
// Consults bit 0 of fails_ and the size_fails_ counter for `size`; the
// bookkeeping that follows each check is on lines not visible here.
// NOTE(review): size_fails_[size] default-inserts a zero entry for every size
// ever requested; size_fails_.find(size) would avoid growing the map.
103 bool LibraryTest::ShouldAllocFail(size_t size) {
104 bool fail = (fails_ & 0x01);
106 if (size_fails_[size] > 0) {
// malloc-style hook for failure injection: when ShouldAllocFail(size) says so,
// logs the request (if verbose) and fails it. The failing return and the
// normal allocation path are on lines not visible here; presumably nullptr
// and malloc(size) respectively.
114 void* LibraryTest::amalloc(size_t size) {
115 if (ShouldAllocFail(size)) {
116 if (verbose) std::cerr << "Failing malloc(" << size << ") request" << std::endl;
// realloc-style hook for failure injection: fails the request (logging it if
// verbose) when ShouldAllocFail(size) says so, otherwise forwards to
// realloc(ptr, size).
124 void* LibraryTest::arealloc(void *ptr, size_t size) {
125 if (ShouldAllocFail(size)) {
126 if (verbose) std::cerr << "Failing realloc(" << ptr << ", " << size << ") request" << std::endl;
129 return realloc(ptr, size);
// free-style hook paired with amalloc/arealloc (body not visible in this
// excerpt).
134 void LibraryTest::afree(void *ptr) {
// get_extrafds callback for ProcessWork() when only the library's own fds
// need polling: always returns an empty set.
138 std::set<int> NoExtraFDs() {
139 return std::set<int>();
// Run the event loop on channel_ with no extra fds. process_extra is nullptr,
// which is safe only because NoExtraFDs never yields an fd to process.
142 void DefaultChannelTest::Process() {
143 ProcessWork(channel_, NoExtraFDs, nullptr);
// Run the event loop on channel_ with no extra fds. process_extra is nullptr,
// which is safe only because NoExtraFDs never yields an fd to process.
146 void DefaultChannelModeTest::Process() {
147 ProcessWork(channel_, NoExtraFDs, nullptr);
// Create a mock DNS server for `family` (AF_INET, or AF_INET6 otherwise).
// It receives UDP on `port` and TCP on `tcpport`; tcpport == 0 reuses the
// UDP port. Both sockets bind to the wildcard address (INADDR_ANY /
// all-zero in6 address). The TCP socket gets SO_REUSEADDR and TCP_NODELAY and
// listen()s with a backlog of 5.
//
// Failures are reported through gtest EXPECT_* checks, which are non-fatal:
// construction continues even if socket()/bind()/listen() failed.
//
// NOTE(review): tcpport_'s initializer reads udpport_. Members are initialized
// in declaration order, not mem-initializer order, so this is only correct if
// udpport_ is declared before tcpport_ in the class — confirm.
150 MockServer::MockServer(int family, int port, int tcpport)
151 : udpport_(port), tcpport_(tcpport ? tcpport : udpport_), qid_(-1) {
152 // Create a TCP socket to receive data on.
153 tcpfd_ = socket(family, SOCK_STREAM, 0);
154 EXPECT_NE(-1, tcpfd_);
// NOTE(review): setsockopt() return values are not checked.
156 setsockopt(tcpfd_, SOL_SOCKET, SO_REUSEADDR,
157 BYTE_CAST &optval , sizeof(int));
158 // Send TCP data right away.
159 setsockopt(tcpfd_, IPPROTO_TCP, TCP_NODELAY,
160 BYTE_CAST &optval , sizeof(int));
162 // Create a UDP socket to receive data on.
163 udpfd_ = socket(family, SOCK_DGRAM, 0);
164 EXPECT_NE(-1, udpfd_);
166 // Bind the sockets to the given port.
167 if (family == AF_INET) {
168 struct sockaddr_in addr;
169 memset(&addr, 0, sizeof(addr));
170 addr.sin_family = AF_INET;
171 addr.sin_addr.s_addr = htonl(INADDR_ANY);
172 addr.sin_port = htons(tcpport_);
173 int tcprc = bind(tcpfd_, (struct sockaddr*)&addr, sizeof(addr));
174 EXPECT_EQ(0, tcprc) << "Failed to bind AF_INET to TCP port " << tcpport_;
175 addr.sin_port = htons(udpport_);
176 int udprc = bind(udpfd_, (struct sockaddr*)&addr, sizeof(addr));
177 EXPECT_EQ(0, udprc) << "Failed to bind AF_INET to UDP port " << udpport_;
179 EXPECT_EQ(AF_INET6, family);
180 struct sockaddr_in6 addr;
181 memset(&addr, 0, sizeof(addr));
182 addr.sin6_family = AF_INET6;
// Redundant after the memset of the whole struct above; it documents the
// wildcard address explicitly.
183 memset(&addr.sin6_addr, 0, sizeof(addr.sin6_addr)); // in6addr_any
184 addr.sin6_port = htons(tcpport_);
185 int tcprc = bind(tcpfd_, (struct sockaddr*)&addr, sizeof(addr));
186 EXPECT_EQ(0, tcprc) << "Failed to bind AF_INET6 to TCP port " << tcpport_;
187 addr.sin6_port = htons(udpport_);
188 int udprc = bind(udpfd_, (struct sockaddr*)&addr, sizeof(addr));
189 EXPECT_EQ(0, udprc) << "Failed to bind AF_INET6 to UDP port " << udpport_;
191 if (verbose) std::cerr << "Configured "
192 << (family == AF_INET ? "IPv4" : "IPv6")
193 << " mock server with TCP socket " << tcpfd_
194 << " on port " << tcpport_
195 << " and UDP socket " << udpfd_
196 << " on port " << udpport_ << std::endl;
198 // For TCP, also need to listen for connections.
199 EXPECT_EQ(0, listen(tcpfd_, 5)) << "Failed to listen for TCP connections";
// Tear down the mock server. Iterates over the accepted TCP connections in
// connfds_; the rest of the body (presumably closing those fds and the two
// listening sockets) is not visible in this excerpt.
202 MockServer::~MockServer() {
203 for (int fd : connfds_) {
// Handle readiness on `fd` if it belongs to this server. fds other than
// tcpfd_, udpfd_ or an accepted connection in connfds_ are ignored.
//
// On the listening TCP socket (branch condition not visible here) a new
// connection is accept()ed and added to connfds_. On a data socket a packet is
// read with recvfrom(). For TCP, the leading 2-byte big-endian length prefix
// is compared with the data actually read, and a mismatch only warns. The
// packet is then validated as a DNS query: at least NS_HFIXEDSZ header bytes,
// QR bit clear, opcode QUERY, exactly one question, class IN. Its qid, name
// and type are passed to ProcessRequest().
//
// NOTE(review): the question pointer below uses the literal 12. That is the
// fixed DNS header size (NS_HFIXEDSZ in arpa/nameser.h, already used in the
// length check), so the named constant would be clearer.
210 void MockServer::ProcessFD(int fd) {
211 if (fd != tcpfd_ && fd != udpfd_ && connfds_.find(fd) == connfds_.end()) {
212 // Not one of our FDs.
216 int connfd = accept(tcpfd_, NULL, NULL);
218 std::cerr << "Error accepting connection on fd " << fd << std::endl;
220 connfds_.insert(connfd);
225 // Activity on a data-bearing file descriptor.
226 struct sockaddr_storage addr;
227 socklen_t addrlen = sizeof(addr);
229 int len = recvfrom(fd, BYTE_CAST buffer, sizeof(buffer), 0,
230 (struct sockaddr *)&addr, &addrlen);
// NOTE(review): connfds_ is a std::set<int> (see fds()), so connfds_.erase(fd)
// is equivalent here in O(log n). std::find is a linear scan, and if fd were
// absent, erase(end()) would be undefined. The enclosing condition
// (presumably a closed/failed TCP read) is not visible in this excerpt.
234 connfds_.erase(std::find(connfds_.begin(), connfds_.end(), fd));
239 std::cerr << "Packet too short (" << len << ")" << std::endl;
// NOTE(review): assumes `byte` is an unsigned type; with a signed char type,
// bytes >= 0x80 would sign-extend and corrupt tcplen — confirm.
242 int tcplen = (data[0] << 8) + data[1];
246 std::cerr << "Warning: TCP length " << tcplen
247 << " doesn't match remaining data length " << len << std::endl;
251 // Assume the packet is a well-formed DNS request and extract the request
253 if (len < NS_HFIXEDSZ) {
254 std::cerr << "Packet too short (" << len << ")" << std::endl;
257 int qid = DNS_HEADER_QID(data);
258 if (DNS_HEADER_QR(data) != 0) {
259 std::cerr << "Not a request" << std::endl;
262 if (DNS_HEADER_OPCODE(data) != ns_o_query) {
263 std::cerr << "Not a query (opcode " << DNS_HEADER_OPCODE(data)
267 if (DNS_HEADER_QDCOUNT(data) != 1) {
268 std::cerr << "Unexpected question count (" << DNS_HEADER_QDCOUNT(data)
272 byte* question = data + 12;
275 char *name = nullptr;
277 ares_expand_name(question, data, len, &name, &enclen);
279 std::cerr << "Failed to retrieve name" << std::endl;
284 std::string namestr(name);
288 std::cerr << "Unexpected question size (" << qlen
289 << " bytes after name)" << std::endl;
292 if (DNS_QUESTION_CLASS(question) != ns_c_in) {
293 std::cerr << "Unexpected question class (" << DNS_QUESTION_CLASS(question)
297 int rrtype = DNS_QUESTION_TYPE(question);
// Diagnostic dump of the request; its guard (presumably `if (verbose)`) is on
// a line not visible in this excerpt.
300 std::vector<byte> req(data, data + len);
301 std::cerr << "received " << (fd == udpfd_ ? "UDP" : "TCP") << " request " << PacketToString(req)
302 << " on port " << (fd == udpfd_ ? udpport_ : tcpport_) << std::endl;
303 std::cerr << "ProcessRequest(" << qid << ", '" << namestr
304 << "', " << RRTypeToString(rrtype) << ")" << std::endl;
306 ProcessRequest(fd, &addr, addrlen, qid, namestr, rrtype);
// All fds this server wants polled: every accepted TCP connection plus the
// listening TCP socket and the UDP socket.
309 std::set<int> MockServer::fds() const {
310 std::set<int> result = connfds_;
311 result.insert(tcpfd_);
312 result.insert(udpfd_);
// Answer a parsed query received on `fd`.
// First notifies gMock via OnRequest(name, rrtype). Then, if a reply has been
// configured in reply_ (the empty-reply branch body is not visible here), it
// sends a copy with bytes 0-1 overwritten by the query ID. The ID may be
// replaced by an explicitly configured one on lines not visible here. For
// TCP the reply gets a 2-byte big-endian length prefix. The reply is sent with
// sendto() to `addr`/`addrlen`, and short or failed sends are logged.
316 void MockServer::ProcessRequest(int fd, struct sockaddr_storage* addr, int addrlen,
317 int qid, const std::string& name, int rrtype) {
318 // Before processing, let gMock know the request is happening.
319 OnRequest(name, rrtype);
321 if (reply_.size() == 0) {
325 // Make a local copy of the current pending reply.
326 std::vector<byte> reply = reply_;
329 // Use the explicitly specified query ID.
332 if (reply.size() >= 2) {
333 // Overwrite the query ID if space to do so.
334 reply[0] = (byte)((qid >> 8) & 0xff);
335 reply[1] = (byte)(qid & 0xff);
337 if (verbose) std::cerr << "sending reply " << PacketToString(reply)
338 << " on port " << ((fd == udpfd_) ? udpport_ : tcpport_) << std::endl;
340 // Prefix with 2-byte length if TCP.
342 int len = reply.size();
// NOTE(review): the prefix keeps only the low 16 bits of len, so a reply over
// 65535 bytes would carry a silently truncated length.
343 std::vector<byte> vlen = {(byte)((len & 0xFF00) >> 8), (byte)(len & 0xFF)};
344 reply.insert(reply.begin(), vlen.begin(), vlen.end());
345 // Also, don't bother with the destination address.
350 int rc = sendto(fd, BYTE_CAST reply.data(), reply.size(), 0,
351 (struct sockaddr *)addr, addrlen);
352 if (rc < static_cast<int>(reply.size())) {
353 std::cerr << "Failed to send full reply, rc=" << rc << std::endl;
// Create `count` NiceMockServer instances for `family` on consecutive ports
// base_port, base_port + 1, ...; each uses the two-argument constructor, so
// its TCP port presumably defaults to the UDP port.
358 MockChannelOptsTest::NiceMockServers MockChannelOptsTest::BuildServers(int count, int family, int base_port) {
359 NiceMockServers servers;
361 for (int ii = 0; ii < count; ii++) {
362 std::unique_ptr<NiceMockServer> server(new NiceMockServer(family, base_port + ii));
363 servers.push_back(std::move(server));
// Build `count` mock servers and an ares channel wired to them.
//
// Options start from a copy of *givenopts, or are zeroed (the selecting
// condition is not visible here). The library is pointed at mock_port for UDP
// and TCP. Defaults fill in anything the caller's optmask did not already
// set: a millisecond timeout, a retry count (3 per the comment) and the
// search domains first.com/second.org/third.gov. TCP-only mode
// (ARES_FLAG_USEVC) is applied under a condition not visible here. After
// ares_init_options(), the server list is replaced with one loopback entry
// (127.0.0.1 or ::1) per mock server, carrying that server's own UDP/TCP
// ports, via ares_set_servers_ports().
//
// NOTE(review): server_ binds to *servers_[0], so count must be >= 1.
368 MockChannelOptsTest::MockChannelOptsTest(int count,
371 struct ares_options* givenopts,
373 : servers_(BuildServers(count, family, mock_port)),
374 server_(*servers_[0].get()), channel_(nullptr) {
375 // Set up channel options.
376 struct ares_options opts;
378 memcpy(&opts, givenopts, sizeof(opts));
380 memset(&opts, 0, sizeof(opts));
383 // Point the library at the first mock server by default (overridden below).
384 opts.udp_port = mock_port;
385 optmask |= ARES_OPT_UDP_PORT;
386 opts.tcp_port = mock_port;
387 optmask |= ARES_OPT_TCP_PORT;
389 // If not already overridden, set short-ish timeouts.
390 if (!(optmask & (ARES_OPT_TIMEOUTMS|ARES_OPT_TIMEOUT))) {
392 optmask |= ARES_OPT_TIMEOUTMS;
394 // If not already overridden, set 3 retries.
395 if (!(optmask & ARES_OPT_TRIES)) {
397 optmask |= ARES_OPT_TRIES;
399 // If not already overridden, set search domains.
// `domains` is a local array; it only has to outlive ares_init_options()
// below, which copies the option data into the channel.
400 const char *domains[3] = {"first.com", "second.org", "third.gov"};
401 if (!(optmask & ARES_OPT_DOMAINS)) {
403 opts.domains = (char**)domains;
404 optmask |= ARES_OPT_DOMAINS;
407 opts.flags |= ARES_FLAG_USEVC;
408 optmask |= ARES_OPT_FLAGS;
411 EXPECT_EQ(ARES_SUCCESS, ares_init_options(&channel_, &opts, optmask));
412 EXPECT_NE(nullptr, channel_);
414 // Set up servers after construction so we can set individual ports
// NOTE(review): the malloc() null check and the release of this list after
// ares_set_servers_ports() are on lines not visible here; confirm both exist.
415 struct ares_addr_port_node* prev = nullptr;
416 struct ares_addr_port_node* first = nullptr;
417 for (const auto& server : servers_) {
418 struct ares_addr_port_node* node = (struct ares_addr_port_node*)malloc(sizeof(*node));
424 node->next = nullptr;
425 node->family = family;
426 node->udp_port = server->udpport();
427 node->tcp_port = server->tcpport();
428 if (family == AF_INET) {
429 node->addr.addr4.s_addr = htonl(0x7F000001);
431 memset(&node->addr.addr6, 0, sizeof(node->addr.addr6));
432 node->addr.addr6._S6_un._S6_u8[15] = 1;
436 EXPECT_EQ(ARES_SUCCESS, ares_set_servers_ports(channel_, first));
444 std::cerr << "Configured library with servers:";
445 std::vector<std::string> servers = GetNameServers(channel_);
446 for (const auto& server : servers) {
447 std::cerr << " " << server;
449 std::cerr << std::endl;
// Destroy the ares channel. Any guard around the call (e.g. a null check on
// channel_) is on a line not visible in this excerpt.
453 MockChannelOptsTest::~MockChannelOptsTest() {
455 ares_destroy(channel_);
// Union of the pollable fds of every mock server (the accumulator `fds` is
// declared on a line not visible here).
460 std::set<int> MockChannelOptsTest::fds() const {
462 for (const auto& server : servers_) {
463 std::set<int> serverfds = server->fds();
464 fds.insert(serverfds.begin(), serverfds.end());
// Offer `fd` to every mock server; each server ignores fds it does not own.
469 void MockChannelOptsTest::ProcessFD(int fd) {
470 for (auto& server : servers_) {
471 server->ProcessFD(fd);
// Run ProcessWork() on channel_, polling all mock-server fds alongside the
// library's own and dispatching their readiness to ProcessFD().
475 void MockChannelOptsTest::Process() {
476 using namespace std::placeholders;
477 ProcessWork(channel_,
478 std::bind(&MockChannelOptsTest::fds, this),
479 std::bind(&MockChannelOptsTest::ProcessFD, this, _1));
// Print a HostResult as "<status> <host>" when complete, or "(incomplete)"
// otherwise (the branch condition, presumably result.done_, is not visible).
482 std::ostream& operator<<(std::ostream& os, const HostResult& result) {
485 os << StatusToString(result.status_) << " " << result.host_;
487 os << "(incomplete)";
// Snapshot a struct hostent into owned data: the name, each alias, the address
// family (addrtype_ stays -1 unless set from h_addrtype) and each address
// rendered via AddressToString(addr, h_length).
// NOTE(review): HostCallback passes whatever hostent c-ares supplies, which may
// be null on failure. The null guard and the loop-pointer increments are on
// lines not visible in this excerpt; confirm both exist.
493 HostEnt::HostEnt(const struct hostent *hostent) : addrtype_(-1) {
497 name_ = hostent->h_name;
498 if (hostent->h_aliases) {
499 char** palias = hostent->h_aliases;
500 while (*palias != nullptr) {
501 aliases_.push_back(*palias);
505 addrtype_ = hostent->h_addrtype;
506 if (hostent->h_addr_list) {
507 char** paddr = hostent->h_addr_list;
508 while (*paddr != nullptr) {
509 std::string addr = AddressToString(*paddr, hostent->h_length);
510 addrs_.push_back(addr);
// Print a HostEnt as 'name' followed by comma-separated alias and address
// lists (the surrounding delimiters are on lines not visible here).
516 std::ostream& operator<<(std::ostream& os, const HostEnt& host) {
518 os << "'" << host.name_ << "' "
520 for (size_t ii = 0; ii < host.aliases_.size(); ii++) {
521 if (ii > 0) os << ", ";
522 os << host.aliases_[ii];
526 for (size_t ii = 0; ii < host.addrs_.size(); ii++) {
527 if (ii > 0) os << ", ";
528 os << host.addrs_[ii];
// c-ares host callback: store status, timeouts and a HostEnt copy of `hostent`
// in the HostResult passed as `data`, mark it done, and log it if verbose.
// NOTE(review): EXPECT_NE is non-fatal, so a null `data` is still dereferenced
// below. This function returns void, so ASSERT_NE would stop it instead.
// static_cast is sufficient for the void* -> HostResult* conversion.
535 void HostCallback(void *data, int status, int timeouts,
536 struct hostent *hostent) {
537 EXPECT_NE(nullptr, data);
538 HostResult* result = reinterpret_cast<HostResult*>(data);
539 result->done_ = true;
540 result->status_ = status;
541 result->timeouts_ = timeouts;
542 result->host_ = HostEnt(hostent);
543 if (verbose) std::cerr << "HostCallback(" << *result << ")" << std::endl;
// Print a SearchResult as "<status> <decoded packet>" when complete, or
// "(incomplete)" otherwise (the branch condition is not visible here).
546 std::ostream& operator<<(std::ostream& os, const SearchResult& result) {
549 os << StatusToString(result.status_) << " " << PacketToString(result.data_);
551 os << "(incomplete)";
// c-ares query/search callback: store status, timeouts and a copy of the raw
// answer abuf[0, alen) in the SearchResult passed as `data`, mark it done, and
// log it if verbose.
// NOTE(review): as in HostCallback, the non-fatal EXPECT_NE does not prevent
// dereferencing a null `data`; ASSERT_NE would.
557 void SearchCallback(void *data, int status, int timeouts,
558 unsigned char *abuf, int alen) {
559 EXPECT_NE(nullptr, data);
560 SearchResult* result = reinterpret_cast<SearchResult*>(data);
561 result->done_ = true;
562 result->status_ = status;
563 result->timeouts_ = timeouts;
564 result->data_.assign(abuf, abuf + alen);
565 if (verbose) std::cerr << "SearchCallback(" << *result << ")" << std::endl;
// Print a NameInfoResult as "<status> <node> <service>" when complete, or
// "(incomplete)" otherwise (the branch condition is not visible here).
568 std::ostream& operator<<(std::ostream& os, const NameInfoResult& result) {
571 os << StatusToString(result.status_) << " " << result.node_ << " " << result.service_;
573 os << "(incomplete)";
// c-ares getnameinfo callback: store status, timeouts, node and service in
// the NameInfoResult passed as `data` (a null node/service becomes ""), mark
// it done, and log it if verbose.
// NOTE(review): the non-fatal EXPECT_NE does not prevent dereferencing a null
// `data`; ASSERT_NE would.
579 void NameInfoCallback(void *data, int status, int timeouts,
580 char *node, char *service) {
581 EXPECT_NE(nullptr, data);
582 NameInfoResult* result = reinterpret_cast<NameInfoResult*>(data);
583 result->done_ = true;
584 result->status_ = status;
585 result->timeouts_ = timeouts;
586 result->node_ = std::string(node ? node : "");
587 result->service_ = std::string(service ? service : "");
588 if (verbose) std::cerr << "NameInfoCallback(" << *result << ")" << std::endl;
// Return the channel's configured servers as strings. Each entry is the
// address rendered by AddressToString, with ":<udp_port>" appended when
// udp_port is non-zero. IPv6 addresses get extra output around them when a
// port is present (presumably brackets; those lines are not visible here).
// The list from ares_get_servers_ports() is released with ares_free_data().
// NOTE(review): for an unrecognised family, "<unknown family>" is pushed inside
// the switch, and ss.str() is pushed again after it. Unless the default case
// skips to the next node (not visible here), such a server yields two
// entries — confirm.
591 std::vector<std::string> GetNameServers(ares_channel channel) {
592 struct ares_addr_port_node* servers = nullptr;
593 EXPECT_EQ(ARES_SUCCESS, ares_get_servers_ports(channel, &servers));
594 struct ares_addr_port_node* server = servers;
595 std::vector<std::string> results;
597 std::stringstream ss;
598 switch (server->family) {
600 ss << AddressToString((char*)&server->addr.addr4, 4);
603 if (server->udp_port != 0) {
606 ss << AddressToString((char*)&server->addr.addr6, 16);
607 if (server->udp_port != 0) {
612 results.push_back("<unknown family>");
615 if (server->udp_port != 0) {
616 ss << ":" << server->udp_port;
618 results.push_back(ss.str());
619 server = server->next;
621 if (servers) ares_free_data(servers);
// Create directory `dirname` with mode 0755; the one-argument mkdir_ variant
// ignores the mode. Failure is only logged to stderr.
625 TransientDir::TransientDir(const std::string& dirname) : dirname_(dirname) {
626 if (mkdir_(dirname_.c_str(), 0755) != 0) {
627 std::cerr << "Failed to create subdirectory '" << dirname_ << "'" << std::endl;
// Remove the directory. The rmdir() result is ignored, so a directory that is
// still non-empty is silently left behind.
631 TransientDir::~TransientDir() {
632 rmdir(dirname_.c_str());
// Create (or truncate) `filename` with fopen(..., "w") and write `contents`.
// Open and short-write failures are logged to stderr. The fclose() and early
// returns are on lines not visible in this excerpt. The file is removed by
// the destructor.
635 TransientFile::TransientFile(const std::string& filename,
636 const std::string& contents)
637 : filename_(filename) {
638 FILE *f = fopen(filename.c_str(), "w");
640 std::cerr << "Error: failed to create '" << filename << "'" << std::endl;
643 int rc = fwrite(contents.data(), 1, contents.size(), f);
644 if (rc != (int)contents.size()) {
645 std::cerr << "Error: failed to write contents of '" << filename << "'" << std::endl;
// Delete the file; the unlink() result is ignored.
650 TransientFile::~TransientFile() {
651 unlink(filename_.c_str());
// Return a temporary file name from tempnam(dir, prefix) as a std::string
// (the release of `p`, presumably free(), is on a line not visible here).
// NOTE(review): tempnam() returns NULL on failure, and `p` is passed straight
// to std::string's constructor, which is undefined for a null pointer. Also,
// tempnam only picks a name; another process can create that path before the
// caller opens it.
654 std::string TempNam(const char *dir, const char *prefix) {
655 char *p = tempnam(dir, prefix);
656 std::string result(p);
// A TransientFile holding `contents`, named via TempNam(nullptr, "ares"), so
// tempnam's default directory choice applies with prefix "ares". It is
// unlinked by ~TransientFile().
661 TempFile::TempFile(const std::string& contents)
662 : TransientFile(TempNam(nullptr, "ares"), contents) {