Initial import to Tizen
[profile/ivi/python-twisted.git] / twisted / names / srvconnect.py
1 # -*- test-case-name: twisted.names.test.test_srvconnect -*-
2 # Copyright (c) Twisted Matrix Laboratories.
3 # See LICENSE for details.
4
5 import random
6
7 from zope.interface import implements
8
9 from twisted.internet import error, interfaces
10
11 from twisted.names import client, dns
12 from twisted.names.error import DNSNameError
13 from twisted.python.compat import reduce
14
class _SRVConnector_ClientFactoryWrapper:
    """
    Client factory proxy handed to the reactor-level connector.

    Connection events coming from the reactor's connector are re-routed so
    they are reported against the owning connector given at construction;
    any other attribute lookup falls through to the wrapped factory.
    """

    def __init__(self, connector, wrappedFactory):
        """
        @param connector: the owning connector; it is what the wrapped
            factory sees in C{startedConnecting}, and it receives the
            C{connectionFailed}/C{connectionLost} notifications.
        @param wrappedFactory: the real client factory being proxied.
        """
        self.__connector, self.__wrappedFactory = connector, wrappedFactory

    def startedConnecting(self, connector):
        # The reactor-level connector argument is ignored on purpose: the
        # factory is told about the owning connector instead.
        factory = self.__wrappedFactory
        factory.startedConnecting(self.__connector)

    def clientConnectionFailed(self, connector, reason):
        # Delegate to the owning connector rather than the factory directly.
        owner = self.__connector
        owner.connectionFailed(reason)

    def clientConnectionLost(self, connector, reason):
        # Delegate to the owning connector rather than the factory directly.
        owner = self.__connector
        owner.connectionLost(reason)

    def __getattr__(self, key):
        # Only invoked for names not found on the wrapper itself, so every
        # other factory attribute/method is served by the wrapped factory.
        target = self.__wrappedFactory
        return getattr(target, key)
31
class SRVConnector:
    """
    A connector that looks up DNS SRV records. See RFC2782.

    The SRV records for C{_<service>._<protocol>.<domain>} are looked up and
    the resulting servers are tried following the RFC 2782 priority/weight
    selection rules. If the lookup yields no usable records, the connector
    falls back to C{(domain, service)} as host and port.
    """

    implements(interfaces.IConnector)

    # Set by stopConnecting() when no reactor connector exists yet (i.e. the
    # DNS lookup may still be pending); _reallyConnect() then skips the
    # connection attempt once and clears the flag.
    stopAfterDNS=0

    def __init__(self, reactor, service, domain, factory,
                 protocol='tcp', connectFuncName='connectTCP',
                 connectFuncArgs=(),
                 connectFuncKwArgs=None,
                 ):
        """
        @param reactor: object whose C{connectFuncName} method opens the
            actual connection.
        @param service: service label used in the SRV query, and the
            fallback port when no SRV records are found.
        @param domain: domain to query; if C{None}, L{connect} fails with
            C{error.DNSLookupError}.
        @param factory: client factory notified of connection events.
        @param protocol: protocol label used in the SRV query.
        @param connectFuncName: name of the reactor method used to connect.
        @param connectFuncArgs: extra positional arguments passed to that
            method after host, port and factory.
        @param connectFuncKwArgs: extra keyword arguments passed to that
            method; C{None} (the default) means none.
        """
        self.reactor = reactor
        self.service = service
        self.domain = domain
        self.factory = factory

        self.protocol = protocol
        self.connectFuncName = connectFuncName
        self.connectFuncArgs = connectFuncArgs
        if connectFuncKwArgs is None:
            # Fresh dict per instance instead of a shared mutable default.
            connectFuncKwArgs = {}
        self.connectFuncKwArgs = connectFuncKwArgs

        self.connector = None
        # Candidates not yet tried in the current round.
        self.servers = None
        self.orderedServers = None # list of servers already used in this round

    def connect(self):
        """Start connection to remote server."""
        self.factory.doStart()
        self.factory.startedConnecting(self)

        if not self.servers:
            # No candidates left (first attempt, or every candidate of the
            # round has been picked): query DNS before choosing a server.
            if self.domain is None:
                self.connectionFailed(error.DNSLookupError("Domain is not defined."))
                return
            d = client.lookupService('_%s._%s.%s' % (self.service,
                                                     self.protocol,
                                                     self.domain))
            d.addCallbacks(self._cbGotServers, self._ebGotServers)
            d.addCallback(lambda ignored: self._reallyConnect())
            d.addErrback(self.connectionFailed)
        elif self.connector is None:
            self._reallyConnect()
        else:
            self.connector.connect()

    def _ebGotServers(self, failure):
        """
        Treat an NXDOMAIN answer as "no SRV records"; re-raise anything else.

        @param failure: the lookup failure; anything other than
            C{DNSNameError} propagates via C{failure.trap}.
        """
        failure.trap(DNSNameError)

        # Some DNS servers reply with NXDOMAIN when in fact there are
        # just no SRV records for that domain. Act as if we just got an
        # empty response and use fallback.

        self.servers = []
        self.orderedServers = []

    def _cbGotServers(self, result):
        """
        Record the SRV answers of a lookup as the candidates for a new round.

        @param result: the C{(answers, authority, additional)} tuple
            delivered by the lookup.
        @raise error.DNSLookupError: if the only answer is an SRV record
            whose target is C{.}, i.e. the service is declared unavailable.
        """
        answers, auth, add = result
        if len(answers) == 1 and answers[0].type == dns.SRV \
                             and answers[0].payload \
                             and answers[0].payload.target == dns.Name('.'):
            # decidedly not available
            raise error.DNSLookupError("Service %s not available for domain %s."
                                       % (repr(self.service), repr(self.domain)))

        # Records go into orderedServers; pickServer() moves them into
        # servers when it starts the next round.
        self.servers = []
        self.orderedServers = []
        for a in answers:
            if a.type != dns.SRV or not a.payload:
                continue

            self.orderedServers.append((a.payload.priority, a.payload.weight,
                                        str(a.payload.target), a.payload.port))

    def _serverCmp(self, a, b):
        """
        Three-way compare two server tuples by (priority, weight).

        Kept for backward compatibility; L{pickServer} sorts with an
        equivalent key function instead.
        """
        keyA = (a[0], a[1])
        keyB = (b[0], b[1])
        return (keyA > keyB) - (keyA < keyB)

    def pickServer(self):
        """
        Pick the next server to try, using the RFC 2782 weighted selection.

        When C{self.servers} is empty, the already-tried servers in
        C{self.orderedServers} become the candidates of a new round.
        Candidates are sorted by (priority, weight); among those sharing the
        lowest priority, one is chosen at random with probability
        proportional to its weight (zero-weight records are only chosen when
        the random draw is 0). The chosen record moves from C{self.servers}
        to C{self.orderedServers}.

        @return: a C{(host, port)} tuple. When no SRV records are known,
            C{(self.domain, self.service)} is returned as a fallback.
        @raise RuntimeError: never expected; guards the selection loop.
        """
        assert self.servers is not None
        assert self.orderedServers is not None

        if not self.servers and not self.orderedServers:
            # no SRV record, fall back..
            return self.domain, self.service

        if not self.servers and self.orderedServers:
            # start new round
            self.servers = self.orderedServers
            self.orderedServers = []

        assert self.servers

        # Lowest priority first; within a priority, ascending weight puts
        # zero-weight records at the front as RFC 2782 requires.
        self.servers.sort(key=lambda record: (record[0], record[1]))
        minPriority = self.servers[0][0]

        # Running weight sums over the lowest-priority records, which the
        # sort has placed at the front of the list.
        runningSums = []
        weightSum = 0
        for index, record in enumerate(self.servers):
            if record[0] != minPriority:
                break
            weightSum += record[1]
            runningSums.append((index, weightSum))

        # Choose the first record whose running sum reaches a uniform random
        # value in [0, weightSum] (both ends inclusive).
        rand = random.randint(0, weightSum)
        for index, runningSum in runningSums:
            if runningSum >= rand:
                chosen = self.servers.pop(index)
                self.orderedServers.append(chosen)

                p, w, host, port = chosen
                return host, port

        raise RuntimeError('Impossible %s pickServer result.' % self.__class__.__name__)

    def _reallyConnect(self):
        """
        Pick a server and open the connection through the reactor.

        Does nothing (and clears the flag) if L{stopConnecting} was called
        before a reactor connector existed.
        """
        if self.stopAfterDNS:
            self.stopAfterDNS=0
            return

        self.host, self.port = self.pickServer()
        assert self.host is not None, 'Must have a host to connect to.'
        assert self.port is not None, 'Must have a port to connect to.'

        connectFunc = getattr(self.reactor, self.connectFuncName)
        self.connector=connectFunc(
            self.host, self.port,
            _SRVConnector_ClientFactoryWrapper(self, self.factory),
            *self.connectFuncArgs, **self.connectFuncKwArgs)

    def stopConnecting(self):
        """Stop attempting to connect."""
        if self.connector:
            self.connector.stopConnecting()
        else:
            # No reactor connector yet: cancel the attempt that would be
            # made once the DNS lookup completes.
            self.stopAfterDNS=1

    def disconnect(self):
        """Disconnect, whatever our state is."""
        if self.connector is not None:
            self.connector.disconnect()
        else:
            self.stopConnecting()

    def getDestination(self):
        """Return the destination of the underlying reactor connector."""
        assert self.connector
        return self.connector.getDestination()

    def connectionFailed(self, reason):
        """Report a failed connection to the factory, then stop it."""
        self.factory.clientConnectionFailed(self, reason)
        self.factory.doStop()

    def connectionLost(self, reason):
        """Report a lost connection to the factory, then stop it."""
        self.factory.clientConnectionLost(self, reason)
        self.factory.doStop()
186