Initial import to Tizen
[profile/ivi/python-twisted.git] / twisted / names / server.py
1 # -*- test-case-name: twisted.names.test.test_names -*-
2 # Copyright (c) Twisted Matrix Laboratories.
3 # See LICENSE for details.
4
5 """
6 Async DNS server
7
8 Future plans:
9     - Better config file format maybe
10     - Make sure to differentiate between different classes
11     - notice truncation bit
12
13 Important: No additional processing is done on some of the record types.
14 This violates the most basic RFC and is just plain annoying
15 for resolvers to deal with.  Fix it.
16
17 @author: Jp Calderone
18 """
19
20 import time
21
22 from twisted.internet import protocol
23 from twisted.names import dns, resolve
24 from twisted.python import log
25
26
27 class DNSServerFactory(protocol.ServerFactory):
28     """
29     Server factory and tracker for L{DNSProtocol} connections.  This
30     class also provides records for responses to DNS queries.
31
32     @ivar connections: A list of all the connected L{DNSProtocol}
33         instances using this object as their controller.
34     @type connections: C{list} of L{DNSProtocol}
35     """
36
37     protocol = dns.DNSProtocol
38     cache = None
39
40     def __init__(self, authorities = None, caches = None, clients = None, verbose = 0):
41         resolvers = []
42         if authorities is not None:
43             resolvers.extend(authorities)
44         if caches is not None:
45             resolvers.extend(caches)
46         if clients is not None:
47             resolvers.extend(clients)
48
49         self.canRecurse = not not clients
50         self.resolver = resolve.ResolverChain(resolvers)
51         self.verbose = verbose
52         if caches:
53             self.cache = caches[-1]
54         self.connections = []
55
56
57     def buildProtocol(self, addr):
58         p = self.protocol(self)
59         p.factory = self
60         return p
61
62
63     def connectionMade(self, protocol):
64         """
65         Track a newly connected L{DNSProtocol}.
66         """
67         self.connections.append(protocol)
68
69
70     def connectionLost(self, protocol):
71         """
72         Stop tracking a no-longer connected L{DNSProtocol}.
73         """
74         self.connections.remove(protocol)
75
76
77     def sendReply(self, protocol, message, address):
78         if self.verbose > 1:
79             s = ' '.join([str(a.payload) for a in message.answers])
80             auth = ' '.join([str(a.payload) for a in message.authority])
81             add = ' '.join([str(a.payload) for a in message.additional])
82             if not s:
83                 log.msg("Replying with no answers")
84             else:
85                 log.msg("Answers are " + s)
86                 log.msg("Authority is " + auth)
87                 log.msg("Additional is " + add)
88
89         if address is None:
90             protocol.writeMessage(message)
91         else:
92             protocol.writeMessage(message, address)
93
94         if self.verbose > 1:
95             log.msg("Processed query in %0.3f seconds" % (time.time() - message.timeReceived))
96
97
98     def gotResolverResponse(self, (ans, auth, add), protocol, message, address):
99         message.rCode = dns.OK
100         message.answers = ans
101         for x in ans:
102             if x.isAuthoritative():
103                 message.auth = 1
104                 break
105         message.authority = auth
106         message.additional = add
107         self.sendReply(protocol, message, address)
108
109         l = len(ans) + len(auth) + len(add)
110         if self.verbose:
111             log.msg("Lookup found %d record%s" % (l, l != 1 and "s" or ""))
112
113         if self.cache and l:
114             self.cache.cacheResult(
115                 message.queries[0], (ans, auth, add)
116             )
117
118
119     def gotResolverError(self, failure, protocol, message, address):
120         if failure.check(dns.DomainError, dns.AuthoritativeDomainError):
121             message.rCode = dns.ENAME
122         else:
123             message.rCode = dns.ESERVER
124             log.err(failure)
125
126         self.sendReply(protocol, message, address)
127         if self.verbose:
128             log.msg("Lookup failed")
129
130
131     def handleQuery(self, message, protocol, address):
132         # Discard all but the first query!  HOO-AAH HOOOOO-AAAAH
133         # (no other servers implement multi-query messages, so we won't either)
134         query = message.queries[0]
135
136         return self.resolver.query(query).addCallback(
137             self.gotResolverResponse, protocol, message, address
138         ).addErrback(
139             self.gotResolverError, protocol, message, address
140         )
141
142
143     def handleInverseQuery(self, message, protocol, address):
144         message.rCode = dns.ENOTIMP
145         self.sendReply(protocol, message, address)
146         if self.verbose:
147             log.msg("Inverse query from %r" % (address,))
148
149
150     def handleStatus(self, message, protocol, address):
151         message.rCode = dns.ENOTIMP
152         self.sendReply(protocol, message, address)
153         if self.verbose:
154             log.msg("Status request from %r" % (address,))
155
156
157     def handleNotify(self, message, protocol, address):
158         message.rCode = dns.ENOTIMP
159         self.sendReply(protocol, message, address)
160         if self.verbose:
161             log.msg("Notify message from %r" % (address,))
162
163
164     def handleOther(self, message, protocol, address):
165         message.rCode = dns.ENOTIMP
166         self.sendReply(protocol, message, address)
167         if self.verbose:
168             log.msg("Unknown op code (%d) from %r" % (message.opCode, address))
169
170
171     def messageReceived(self, message, proto, address = None):
172         message.timeReceived = time.time()
173
174         if self.verbose:
175             if self.verbose > 1:
176                 s = ' '.join([str(q) for q in message.queries])
177             elif self.verbose > 0:
178                 s = ' '.join([dns.QUERY_TYPES.get(q.type, 'UNKNOWN') for q in message.queries])
179
180             if not len(s):
181                 log.msg("Empty query from %r" % ((address or proto.transport.getPeer()),))
182             else:
183                 log.msg("%s query from %r" % (s, address or proto.transport.getPeer()))
184
185         message.recAv = self.canRecurse
186         message.answer = 1
187
188         if not self.allowQuery(message, proto, address):
189             message.rCode = dns.EREFUSED
190             self.sendReply(proto, message, address)
191         elif message.opCode == dns.OP_QUERY:
192             self.handleQuery(message, proto, address)
193         elif message.opCode == dns.OP_INVERSE:
194             self.handleInverseQuery(message, proto, address)
195         elif message.opCode == dns.OP_STATUS:
196             self.handleStatus(message, proto, address)
197         elif message.opCode == dns.OP_NOTIFY:
198             self.handleNotify(message, proto, address)
199         else:
200             self.handleOther(message, proto, address)
201
202
203     def allowQuery(self, message, protocol, address):
204         # Allow anything but empty queries
205         return len(message.queries)