Initial import to Tizen
[profile/ivi/python-twisted.git] / twisted / names / authority.py
1 # -*- test-case-name: twisted.names.test.test_names -*-
2 # Copyright (c) Twisted Matrix Laboratories.
3 # See LICENSE for details.
4
5 """
6 Authoritative resolvers.
7 """
8
9 import os
10 import time
11
12 from twisted.names import dns
13 from twisted.internet import defer
14 from twisted.python import failure
15
16 import common
17
18 def getSerial(filename = '/tmp/twisted-names.serial'):
19     """Return a monotonically increasing (across program runs) integer.
20
21     State is stored in the given file.  If it does not exist, it is
22     created with rw-/---/--- permissions.
23     """
24     serial = time.strftime('%Y%m%d')
25
26     o = os.umask(0177)
27     try:
28         if not os.path.exists(filename):
29             f = file(filename, 'w')
30             f.write(serial + ' 0')
31             f.close()
32     finally:
33         os.umask(o)
34
35     serialFile = file(filename, 'r')
36     lastSerial, ID = serialFile.readline().split()
37     ID = (lastSerial == serial) and (int(ID) + 1) or 0
38     serialFile.close()
39     serialFile = file(filename, 'w')
40     serialFile.write('%s %d' % (serial, ID))
41     serialFile.close()
42     serial = serial + ('%02d' % (ID,))
43     return serial
44
45
46 #class LookupCacherMixin(object):
47 #    _cache = None
48 #
49 #    def _lookup(self, name, cls, type, timeout = 10):
50 #        if not self._cache:
51 #            self._cache = {}
52 #            self._meth = super(LookupCacherMixin, self)._lookup
53 #
54 #        if self._cache.has_key((name, cls, type)):
55 #            return self._cache[(name, cls, type)]
56 #        else:
57 #            r = self._meth(name, cls, type, timeout)
58 #            self._cache[(name, cls, type)] = r
59 #            return r
60
61
62 class FileAuthority(common.ResolverBase):
63     """An Authority that is loaded from a file."""
64
65     soa = None
66     records = None
67
68     def __init__(self, filename):
69         common.ResolverBase.__init__(self)
70         self.loadFile(filename)
71         self._cache = {}
72
73
74     def __setstate__(self, state):
75         self.__dict__ = state
76 #        print 'setstate ', self.soa
77
78     def _lookup(self, name, cls, type, timeout = None):
79         cnames = []
80         results = []
81         authority = []
82         additional = []
83         default_ttl = max(self.soa[1].minimum, self.soa[1].expire)
84
85         domain_records = self.records.get(name.lower())
86
87         if domain_records:
88             for record in domain_records:
89                 if record.ttl is not None:
90                     ttl = record.ttl
91                 else:
92                     ttl = default_ttl
93
94                 if record.TYPE == dns.NS and name.lower() != self.soa[0].lower():
95                     # NS record belong to a child zone: this is a referral.  As
96                     # NS records are authoritative in the child zone, ours here
97                     # are not.  RFC 2181, section 6.1.
98                     authority.append(
99                         dns.RRHeader(name, record.TYPE, dns.IN, ttl, record, auth=False)
100                     )
101                 elif record.TYPE == type or type == dns.ALL_RECORDS:
102                     results.append(
103                         dns.RRHeader(name, record.TYPE, dns.IN, ttl, record, auth=True)
104                     )
105                 if record.TYPE == dns.CNAME:
106                     cnames.append(
107                         dns.RRHeader(name, record.TYPE, dns.IN, ttl, record, auth=True)
108                     )
109             if not results:
110                 results = cnames
111
112             for record in results + authority:
113                 section = {dns.NS: additional, dns.CNAME: results, dns.MX: additional}.get(record.type)
114                 if section is not None:
115                     n = str(record.payload.name)
116                     for rec in self.records.get(n.lower(), ()):
117                         if rec.TYPE == dns.A:
118                             section.append(
119                                 dns.RRHeader(n, dns.A, dns.IN, rec.ttl or default_ttl, rec, auth=True)
120                             )
121
122             if not results and not authority:
123                 # Empty response. Include SOA record to allow clients to cache
124                 # this response.  RFC 1034, sections 3.7 and 4.3.4, and RFC 2181
125                 # section 7.1.
126                 authority.append(
127                     dns.RRHeader(self.soa[0], dns.SOA, dns.IN, ttl, self.soa[1], auth=True)
128                     )
129             return defer.succeed((results, authority, additional))
130         else:
131             if name.lower().endswith(self.soa[0].lower()):
132                 # We are the authority and we didn't find it.  Goodbye.
133                 return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
134             return defer.fail(failure.Failure(dns.DomainError(name)))
135
136
137     def lookupZone(self, name, timeout = 10):
138         if self.soa[0].lower() == name.lower():
139             # Wee hee hee hooo yea
140             default_ttl = max(self.soa[1].minimum, self.soa[1].expire)
141             if self.soa[1].ttl is not None:
142                 soa_ttl = self.soa[1].ttl
143             else:
144                 soa_ttl = default_ttl
145             results = [dns.RRHeader(self.soa[0], dns.SOA, dns.IN, soa_ttl, self.soa[1], auth=True)]
146             for (k, r) in self.records.items():
147                 for rec in r:
148                     if rec.ttl is not None:
149                         ttl = rec.ttl
150                     else:
151                         ttl = default_ttl
152                     if rec.TYPE != dns.SOA:
153                         results.append(dns.RRHeader(k, rec.TYPE, dns.IN, ttl, rec, auth=True))
154             results.append(results[0])
155             return defer.succeed((results, (), ()))
156         return defer.fail(failure.Failure(dns.DomainError(name)))
157
158     def _cbAllRecords(self, results):
159         ans, auth, add = [], [], []
160         for res in results:
161             if res[0]:
162                 ans.extend(res[1][0])
163                 auth.extend(res[1][1])
164                 add.extend(res[1][2])
165         return ans, auth, add
166
167
168 class PySourceAuthority(FileAuthority):
169     """A FileAuthority that is built up from Python source code."""
170
171     def loadFile(self, filename):
172         g, l = self.setupConfigNamespace(), {}
173         execfile(filename, g, l)
174         if not l.has_key('zone'):
175             raise ValueError, "No zone defined in " + filename
176
177         self.records = {}
178         for rr in l['zone']:
179             if isinstance(rr[1], dns.Record_SOA):
180                 self.soa = rr
181             self.records.setdefault(rr[0].lower(), []).append(rr[1])
182
183
184     def wrapRecord(self, type):
185         return lambda name, *arg, **kw: (name, type(*arg, **kw))
186
187
188     def setupConfigNamespace(self):
189         r = {}
190         items = dns.__dict__.iterkeys()
191         for record in [x for x in items if x.startswith('Record_')]:
192             type = getattr(dns, record)
193             f = self.wrapRecord(type)
194             r[record[len('Record_'):]] = f
195         return r
196
197
198 class BindAuthority(FileAuthority):
199     """An Authority that loads BIND configuration files"""
200
201     def loadFile(self, filename):
202         self.origin = os.path.basename(filename) + '.' # XXX - this might suck
203         lines = open(filename).readlines()
204         lines = self.stripComments(lines)
205         lines = self.collapseContinuations(lines)
206         self.parseLines(lines)
207
208
209     def stripComments(self, lines):
210         return [
211             a.find(';') == -1 and a or a[:a.find(';')] for a in [
212                 b.strip() for b in lines
213             ]
214         ]
215
216
217     def collapseContinuations(self, lines):
218         L = []
219         state = 0
220         for line in lines:
221             if state == 0:
222                 if line.find('(') == -1:
223                     L.append(line)
224                 else:
225                     L.append(line[:line.find('(')])
226                     state = 1
227             else:
228                 if line.find(')') != -1:
229                     L[-1] += ' ' + line[:line.find(')')]
230                     state = 0
231                 else:
232                     L[-1] += ' ' + line
233         lines = L
234         L = []
235         for line in lines:
236             L.append(line.split())
237         return filter(None, L)
238
239
240     def parseLines(self, lines):
241         TTL = 60 * 60 * 3
242         ORIGIN = self.origin
243
244         self.records = {}
245
246         for (line, index) in zip(lines, range(len(lines))):
247             if line[0] == '$TTL':
248                 TTL = dns.str2time(line[1])
249             elif line[0] == '$ORIGIN':
250                 ORIGIN = line[1]
251             elif line[0] == '$INCLUDE': # XXX - oh, fuck me
252                 raise NotImplementedError('$INCLUDE directive not implemented')
253             elif line[0] == '$GENERATE':
254                 raise NotImplementedError('$GENERATE directive not implemented')
255             else:
256                 self.parseRecordLine(ORIGIN, TTL, line)
257
258
259     def addRecord(self, owner, ttl, type, domain, cls, rdata):
260         if not domain.endswith('.'):
261             domain = domain + '.' + owner
262         else:
263             domain = domain[:-1]
264         f = getattr(self, 'class_%s' % cls, None)
265         if f:
266             f(ttl, type, domain, rdata)
267         else:
268             raise NotImplementedError, "Record class %r not supported" % cls
269
270
271     def class_IN(self, ttl, type, domain, rdata):
272         record = getattr(dns, 'Record_%s' % type, None)
273         if record:
274             r = record(*rdata)
275             r.ttl = ttl
276             self.records.setdefault(domain.lower(), []).append(r)
277
278             print 'Adding IN Record', domain, ttl, r
279             if type == 'SOA':
280                 self.soa = (domain, r)
281         else:
282             raise NotImplementedError, "Record type %r not supported" % type
283
284
285     #
286     # This file ends here.  Read no further.
287     #
288     def parseRecordLine(self, origin, ttl, line):
289         MARKERS = dns.QUERY_CLASSES.values() + dns.QUERY_TYPES.values()
290         cls = 'IN'
291         owner = origin
292
293         if line[0] == '@':
294             line = line[1:]
295             owner = origin
296 #            print 'default owner'
297         elif not line[0].isdigit() and line[0] not in MARKERS:
298             owner = line[0]
299             line = line[1:]
300 #            print 'owner is ', owner
301
302         if line[0].isdigit() or line[0] in MARKERS:
303             domain = owner
304             owner = origin
305 #            print 'woops, owner is ', owner, ' domain is ', domain
306         else:
307             domain = line[0]
308             line = line[1:]
309 #            print 'domain is ', domain
310
311         if line[0] in dns.QUERY_CLASSES.values():
312             cls = line[0]
313             line = line[1:]
314 #            print 'cls is ', cls
315             if line[0].isdigit():
316                 ttl = int(line[0])
317                 line = line[1:]
318 #                print 'ttl is ', ttl
319         elif line[0].isdigit():
320             ttl = int(line[0])
321             line = line[1:]
322 #            print 'ttl is ', ttl
323             if line[0] in dns.QUERY_CLASSES.values():
324                 cls = line[0]
325                 line = line[1:]
326 #                print 'cls is ', cls
327
328         type = line[0]
329 #        print 'type is ', type
330         rdata = line[1:]
331 #        print 'rdata is ', rdata
332
333         self.addRecord(owner, ttl, type, domain, cls, rdata)