1 # -*- test-case-name: twisted.names.test.test_srvconnect -*-
2 # Copyright (c) Twisted Matrix Laboratories.
3 # See LICENSE for details.
import functools
import random

from zope.interface import implements

from twisted.internet import error, interfaces
from twisted.names import client, dns
from twisted.names.error import DNSNameError
from twisted.python.compat import reduce
class _SRVConnector_ClientFactoryWrapper:
    """
    Client factory wrapper handed to the reactor by the SRV connector.

    Connection events coming from the reactor-level connector are
    redirected to the SRV connector given to L{__init__}; every other
    attribute lookup is delegated to the wrapped factory.
    """

    def __init__(self, connector, wrappedFactory):
        """
        @param connector: the SRV connector to report connection events to.
        @param wrappedFactory: the client factory being wrapped.
        """
        self.__connector = connector
        self.__wrappedFactory = wrappedFactory

    def startedConnecting(self, connector):
        # Hand the wrapped factory the SRV connector rather than the
        # reactor-level connector passed in here.
        self.__wrappedFactory.startedConnecting(self.__connector)

    def clientConnectionFailed(self, connector, reason):
        self.__connector.connectionFailed(reason)

    def clientConnectionLost(self, connector, reason):
        self.__connector.connectionLost(reason)

    def __getattr__(self, key):
        """
        Delegate any attribute not found normally to the wrapped factory.

        @raise AttributeError: if the wrapped factory is not set, i.e. the
            instance was created without running C{__init__} (as
            C{copy.copy} and unpickling do).  Without this check, reading
            C{self.__wrappedFactory} would re-enter C{__getattr__} until
            the recursion limit is hit.
        """
        # Read the name-mangled attribute straight from the instance dict so
        # that a missing value cannot trigger another __getattr__ call.
        try:
            wrapped = self.__dict__[
                '_SRVConnector_ClientFactoryWrapper__wrappedFactory']
        except KeyError:
            raise AttributeError(key)
        return getattr(wrapped, key)
"""
A connector that looks up DNS SRV records. See RFC2782.

It resolves C{_<service>._<protocol>.<domain>} with
L{twisted.names.client.lookupService}, then connects to a server picked
from the answers (by priority and weight) using the reactor method named
by C{connectFuncName}.  If there are no usable SRV records it falls back
to C{(domain, service)} as the host and port.
"""

implements(interfaces.IConnector)
def __init__(self, reactor, service, domain, factory,
             protocol='tcp', connectFuncName='connectTCP',
             connectFuncArgs=(), connectFuncKwArgs=None):
    """
    @param reactor: the reactor whose C{connectFuncName} method makes the
        actual connection.
    @param service: the SRV service label (the leading underscore is added
        when building the query); also returned as the port by
        C{pickServer} when no SRV records are found.
    @param domain: the domain to look the service up in; if C{None},
        C{connect} fails immediately.
    @param factory: the client factory for the connection.
    @param protocol: the SRV protocol label (the leading underscore is
        added when building the query).
    @param connectFuncName: name of the reactor method used to connect.
    @param connectFuncArgs: extra positional arguments for that method.
    @param connectFuncKwArgs: extra keyword arguments for that method;
        C{None} means none.
    """
    self.reactor = reactor
    self.service = service
    self.domain = domain
    self.factory = factory

    self.protocol = protocol
    self.connectFuncName = connectFuncName
    self.connectFuncArgs = connectFuncArgs
    if connectFuncKwArgs is None:
        # A fresh dict per instance instead of one shared default object.
        connectFuncKwArgs = {}
    self.connectFuncKwArgs = connectFuncKwArgs

    self.connector = None       # set by _reallyConnect from the reactor
    self.servers = None         # candidates not yet tried in this round
    self.orderedServers = None  # list of servers already used in this round
    # Read by _reallyConnect(); initialised here so the attribute always
    # exists on the instance.
    self.stopAfterDNS = 0
def connect(self):
    """
    Start connection to remote server.

    The factory is started and told a connection attempt has begun.  When
    no SRV candidates are cached in C{self.servers}, the record
    C{_<service>._<protocol>.<domain>} is looked up first and a server is
    picked once the answer arrives; any failure along that chain is
    reported through C{connectionFailed}.  With cached candidates, the
    next one is tried if there is no underlying connector yet, otherwise
    the existing connector is asked to connect again.
    """
    self.factory.doStart()
    self.factory.startedConnecting(self)

    if self.servers:
        if self.connector is None:
            self._reallyConnect()
        else:
            self.connector.connect()
        return

    if self.domain is None:
        self.connectionFailed(error.DNSLookupError("Domain is not defined."))
        return

    query = '_%s._%s.%s' % (self.service, self.protocol, self.domain)
    lookup = client.lookupService(query)
    lookup.addCallbacks(self._cbGotServers, self._ebGotServers)
    lookup.addCallback(lambda ignored: self._reallyConnect())
    lookup.addErrback(self.connectionFailed)
def _ebGotServers(self, failure):
    """
    Errback for the SRV lookup: treat NXDOMAIN as "no SRV records".

    Only L{DNSNameError} is handled here; C{failure.trap} lets any other
    failure continue down the errback chain.  Both server lists are reset
    to empty, so C{pickServer} falls back to C{(domain, service)}.
    """
    failure.trap(DNSNameError)

    # Some DNS servers reply with NXDOMAIN when in fact there are
    # just no SRV records for that domain. Act as if we just got an
    # empty response and use fallback.
    self.servers, self.orderedServers = [], []
def _cbGotServers(self, result):
    """
    Callback for the SRV lookup: collect the usable SRV records.

    Records are stored in C{self.orderedServers} as
    C{(priority, weight, host, port)} tuples and C{self.servers} is
    emptied, so C{pickServer} starts a new round from them.

    @param result: the C{(answers, authority, additional)} triple produced
        by the lookup; only C{answers} is inspected.
    @raise error.DNSLookupError: if the single answer is an SRV record
        whose target is C{.}, which RFC 2782 defines as "service decidedly
        not available at this domain".
    """
    # Unpack explicitly: tuple parameters are Python-2-only syntax
    # (removed by PEP 3113).
    answers, _authority, _additional = result

    if (len(answers) == 1 and answers[0].type == dns.SRV
            and answers[0].payload
            and answers[0].payload.target == dns.Name('.')):
        # decidedly not available
        raise error.DNSLookupError("Service %s not available for domain %s."
                                   % (repr(self.service), repr(self.domain)))

    self.servers = []
    self.orderedServers = []
    for answer in answers:
        # Ignore non-SRV records and records without a payload.
        if answer.type != dns.SRV or not answer.payload:
            continue
        srv = answer.payload
        self.orderedServers.append(
            (srv.priority, srv.weight, str(srv.target), srv.port))
def _serverCmp(self, a, b):
    """
    Three-way compare two C{(priority, weight, host, port)} server tuples.

    Priority is compared first and weight breaks ties; host and port are
    ignored.

    @return: C{-1}, C{0} or C{1}, like the Python 2 builtin C{cmp}.
    """
    # The builtin cmp() is gone in Python 3; (x > y) - (x < y) is its
    # documented equivalent and also works on Python 2.
    left, right = (a[0], a[1]), (b[0], b[1])
    return (left > right) - (left < right)
def pickServer(self):
    """
    Pick the next server to try and move it to C{self.orderedServers}.

    Candidates are taken from C{self.servers}; when that list is exhausted
    a new round starts from C{self.orderedServers}.  Among the candidates
    with the lowest priority, the choice follows RFC 2782: records are
    ordered by weight (zero-weight records first), a running sum of
    weights is associated with each, a random integer between 0 and the
    total (inclusive) is drawn, and the first record whose running sum is
    greater than or equal to it is selected.

    @return: C{(host, port)} of the chosen server, or
        C{(self.domain, self.service)} when there are no SRV records.
    @raise RuntimeError: if no record was selected; unreachable in
        practice because the random value never exceeds the total weight.
    """
    assert self.servers is not None
    assert self.orderedServers is not None

    if not self.servers and not self.orderedServers:
        # no SRV record, fall back..
        return self.domain, self.service

    if not self.servers and self.orderedServers:
        # Every candidate has been tried once: start a new round.
        self.servers = self.orderedServers
        self.orderedServers = []

    # list.sort(cmpfunc) is Python-2-only; cmp_to_key keeps _serverCmp as
    # the single definition of the ordering.
    self.servers.sort(key=functools.cmp_to_key(self._serverCmp))
    minPriority = self.servers[0][0]

    # The lowest-priority records form a prefix of the sorted list; pair
    # each of their indices with the running sum of weights.
    runningSums = []
    weightSum = 0
    for index, server in enumerate(self.servers):
        if server[0] != minPriority:
            break
        weightSum += server[1]
        runningSums.append((index, weightSum))

    rand = random.randint(0, weightSum)
    for index, runningSum in runningSums:
        if runningSum >= rand:
            chosen = self.servers.pop(index)
            self.orderedServers.append(chosen)
            _priority, _weight, host, port = chosen
            return host, port

    raise RuntimeError(
        'Impossible %s pickServer result.' % (self.__class__.__name__,))
def _reallyConnect(self):
    """
    Connect to the server chosen by C{pickServer}.

    If C{self.stopAfterDNS} is set, the flag is cleared and nothing is
    connected.  Otherwise the reactor method named by
    C{self.connectFuncName} is called with the chosen host and port, a
    factory wrapper that reports back to this connector, and the extra
    connect arguments; its result becomes C{self.connector}.
    """
    if self.stopAfterDNS:
        # A stop was requested before a connector existed: honour it once.
        self.stopAfterDNS = 0
        return

    host, port = self.pickServer()
    self.host, self.port = host, port
    assert self.host is not None, 'Must have a host to connect to.'
    assert self.port is not None, 'Must have a port to connect to.'

    connectFunc = getattr(self.reactor, self.connectFuncName)
    wrapper = _SRVConnector_ClientFactoryWrapper(self, self.factory)
    self.connector = connectFunc(self.host, self.port, wrapper,
                                 *self.connectFuncArgs,
                                 **self.connectFuncKwArgs)
def stopConnecting(self):
    """
    Stop attempting to connect.

    With an underlying connector, it is told to stop.  Without one, the
    request is recorded in C{self.stopAfterDNS} instead.
    """
    if not self.connector:
        self.stopAfterDNS = 1
        return
    self.connector.stopConnecting()
def disconnect(self):
    """
    Disconnect, whatever our current state is.

    An established connector is disconnected; if there is none yet, this
    falls back to C{stopConnecting}.
    """
    connector = self.connector
    if connector is None:
        self.stopConnecting()
    else:
        connector.disconnect()
def getDestination(self):
    """
    Return the destination reported by the underlying reactor connector.

    Only valid once C{self.connector} has been set by a connection attempt.
    """
    connector = self.connector
    assert connector
    return connector.getDestination()
def connectionFailed(self, reason):
    """
    Report a failed connection attempt to the factory, then stop it.

    @param reason: the failure passed on to the factory's
        C{clientConnectionFailed}.
    """
    self.factory.clientConnectionFailed(self, reason)
    # Balances the doStart() issued by connect().
    self.factory.doStop()
def connectionLost(self, reason):
    """
    Report a lost connection to the factory, then stop it.

    @param reason: the failure passed on to the factory's
        C{clientConnectionLost}.
    """
    self.factory.clientConnectionLost(self, reason)
    # Balances the doStart() issued by connect().
    self.factory.doStop()