dnsresolv: compare returned name; "better" CNAME handling
authorH. Peter Anvin <hpa@zytor.com>
Mon, 24 Aug 2009 23:32:48 +0000 (16:32 -0700)
committerH. Peter Anvin <hpa@zytor.com>
Mon, 24 Aug 2009 23:32:48 +0000 (16:32 -0700)
Compare the returned name with the name we requested.  Handle CNAMEs
by overwriting the requested name with the CNAME, however, we really
should:

a) rescan the packet from the beginning;
b) send a new request if we get no A record for the CNAME.

Signed-off-by: H. Peter Anvin <hpa@zytor.com>
core/fs/pxe/dnsresolv.c

index 2d90cc7..b2c36b7 100644 (file)
@@ -3,6 +3,13 @@
 #include <core.h>
 #include "pxe.h"
 
+/* DNS CLASS values we care about */
+#define CLASS_IN       1
+
+/* DNS TYPE values we care about */
+#define TYPE_A         1
+#define TYPE_CNAME     5
+
 /*
  * The DNS header structure
  */
@@ -81,29 +88,61 @@ int dns_mangle(char **dst, const char *p)
     
 
 /*
- * Compare two sets of DNS labels, in _s1_ and _s2_; the one in _s1_
- * is allowed pointers relative to a packet in DNSRecvBuf.
+ * Compare two sets of DNS labels, in _s1_ and _s2_; the one in _s2_
+ * is allowed pointers relative to a packet in buf.
  *
  */
-static bool dns_compare(const char *s1, const char *s2)
+static bool dns_compare(const void *s1, const void *s2, const void *buf)
 {
-#if 0
+    const uint8_t *q = s1;
+    const uint8_t *p = s2;
+    unsigned int c0, c1;
+
     while (1) {
-        if (*s1 < 0xc0)
-            break;
-        s1 = DNSRecvBuf + (((*s1++ & 0x3f) << 8) | (*s1++));
+       c0 = p[0];
+        if (c0 >= 0xc0) {
+           /* Follow pointer */
+           c1 = p[1];
+           p = (const uint8_t *)buf + ((c0 - 0xc0) << 8) + c1;
+       } else if (c0) {
+           c0++;               /* Include the length byte */
+           if (memcmp(q, p, c0))
+               return false;
+           q += c0;
+           p += c0;
+       } else {
+           return *q == 0;
+       }
+    }
+}
+
+/*
+ * Copy a DNS label into a buffer, considering the possibility that we might
+ * have to follow pointers relative to "buf".
+ * Returns a pointer to the first free byte *after* the terminal null.
+ */
+static void *dns_copylabel(void *dst, const void *src, const void *buf)
+{
+    uint8_t *q = dst;
+    const uint8_t *p = src;
+    unsigned int c0, c1;
+    
+    while (1) {
+       c0 = p[0];
+        if (c0 >= 0xc0) {
+           /* Follow pointer */
+           c1 = p[1];
+           p = (const uint8_t *)buf + ((c0 - 0xc0) << 8) + c1;
+       } else if (c0) {
+           c0++;               /* Include the length byte */
+           memcpy(q, p, c0);
+           p += c0;
+           q += c0;
+       } else {
+           *q++ = 0;
+           return q;
+       }
     }
-    if (*s1 == 0)
-        return true;
-    else if (*s1++ != *s2++)
-        return false; /* not same */
-    else
-        return !strcmp(s1, s2);
-#else
-    (void)s1;
-    (void)s2;
-    return true;
-#endif
 }
 
 /*
@@ -169,8 +208,8 @@ uint32_t dns_resolv(const char *name)
     
     /* Fill the DNS query packet */
     query = (struct dnsquery *)p;
-    query->qtype  = htons(1);  /* QTYPE  = 1 = A */
-    query->qclass = htons(1);  /* QCLASS = 1 = IN */
+    query->qtype  = htons(TYPE_A);
+    query->qclass = htons(CLASS_IN);
     p += sizeof(struct dnsquery);
    
     /* Now send it to name server */
@@ -230,15 +269,31 @@ uint32_t dns_resolv(const char *name)
 
         /* Parse the replies */
         while (reps--) {
-            same = dns_compare(p, (char *)(DNSSendBuf + sizeof(struct dnshdr)));
+            same = dns_compare(DNSSendBuf + sizeof(struct dnshdr),
+                              p, DNSRecvBuf);
             p = dns_skiplabel(p);
             rr = (struct dnsrr *)p;
-            rd_len = htons(rr->rdlength);
-            if (same && rd_len == 4   &&
-                htons(rr->type) == 1  && /* TYPE  == A */
-                htons(rr->class) == 1 )  /* CLASS == IN */
-                return *(uint32_t *)rr->rdata;
-            
+            rd_len = ntohs(rr->rdlength);
+            if (same && ntohs(rr->class) == CLASS_IN) {
+               switch (ntohs(rr->type)) {
+               case TYPE_A:
+                   if (rd_len == 4)
+                       return *(uint32_t *)rr->rdata;
+                   break;
+               case TYPE_CNAME:
+                   dns_copylabel(DNSSendBuf + sizeof(struct dnshdr),
+                                 rr->rdata, DNSRecvBuf);
+                   /*
+                    * We should probably rescan the packet from the top
+                    * here, and technically we might have to send a whole
+                    * new request here...
+                    */
+                   break;
+               default:
+                   break;
+               }
+           }             
+
             /* not the one we want, try next */
             p += sizeof(struct dnsrr) + rd_len;
         }