1 # Copyright (c) Twisted Matrix Laboratories.
2 # See LICENSE for details.
5 Tests for L{twisted.conch.ssh.agent}.
10 from twisted.trial import unittest
17 from twisted.test import iosim
20 import Crypto.Cipher.DES3
30 from twisted.conch.ssh import keys, agent
34 from twisted.conch.test import keydata
35 from twisted.conch.error import ConchError, MissingKeyStoreError
class StubFactory(object):
    """
    Mock factory that provides the C{keys} attribute required by
    L{agent.SSHAgentServer}.

    The agent server protocol keeps its key store on its factory, so this
    stub only has to supply an (initially empty) dictionary for it.

    @ivar keys: Mapping from a key's public blob to a C{(key, comment)}
        tuple, as populated by the tests in this module.
    """
    def __init__(self):
        # An instance attribute rather than a class attribute: every test
        # gets its own fresh store, and a test can remove it through the
        # instance __dict__ to simulate a factory without a key store.
        self.keys = {}
class AgentTestBase(unittest.TestCase):
    """
    Tests for SSHAgentServer/Client.

    Base class that connects an L{agent.SSHAgentClient} to an
    L{agent.SSHAgentServer} and loads one RSA and one DSA key pair.
    """
    if iosim is None:
        skip = "iosim requires SSL, but SSL is not available"
    elif agent is None or keys is None:
        skip = "Cannot run without PyCrypto or PyASN1"

    def setUp(self):
        """
        Connect a client and a server with
        L{iosim.connectedServerAndClient}, giving C{self.client},
        C{self.server} and the C{self.pump} that moves data between them,
        and parse the test keys.
        """
        # wire up our client <-> server
        self.client, self.server, self.pump = iosim.connectedServerAndClient(
            agent.SSHAgentServer, agent.SSHAgentClient)

        # the server's end of the protocol is stateful and we store it on the
        # factory, for which we only need a mock
        self.server.factory = StubFactory()

        # pub/priv keys of each kind
        self.rsaPrivate = keys.Key.fromString(keydata.privateRSA_openssh)
        self.dsaPrivate = keys.Key.fromString(keydata.privateDSA_openssh)

        self.rsaPublic = keys.Key.fromString(keydata.publicRSA_openssh)
        self.dsaPublic = keys.Key.fromString(keydata.publicDSA_openssh)
class TestServerProtocolContractWithFactory(AgentTestBase):
    """
    The server protocol is stateful and so uses its factory to track state
    across requests. This test asserts that the protocol raises if its factory
    doesn't provide the necessary storage for that state.
    """
    def test_factorySuppliesKeyStorageForServerProtocol(self):
        """
        L{agent.SSHAgentServer} raises L{MissingKeyStoreError} when it handles
        a request while its factory has no C{keys} attribute.
        """
        # need a message to send into the server: a 4-byte big-endian length
        # prefix of 1 followed by the single op-code byte
        msg = struct.pack('!LB', 1, agent.AGENTC_REQUEST_IDENTITIES)
        # remove the key store from the factory *instance*
        del self.server.factory.__dict__['keys']
        self.assertRaises(MissingKeyStoreError,
                          self.server.dataReceived, msg)
class TestUnimplementedVersionOneServer(AgentTestBase):
    """
    Tests for methods with no-op implementations on the server. We need these
    for clients, such as openssh, that try v1 methods before going to v2.

    Because the client doesn't expose these operations with nice method names,
    we invoke sendRequest directly with an op code.
    """

    def test_agentc_REQUEST_RSA_IDENTITIES(self):
        """
        The server answers an AGENTC_REQUEST_RSA_IDENTITIES request with a
        packet whose first byte is the AGENT_RSA_IDENTITIES_ANSWER op code.
        """
        d = self.client.sendRequest(agent.AGENTC_REQUEST_RSA_IDENTITIES, '')
        self.pump.flush()
        def _cb(packet):
            self.assertEqual(
                agent.AGENT_RSA_IDENTITIES_ANSWER, ord(packet[0]))
        return d.addCallback(_cb)

    def test_agentc_REMOVE_RSA_IDENTITY(self):
        """
        An AGENTC_REMOVE_RSA_IDENTITY request fires the client's Deferred with
        an empty string (presumably the client's rendering of a bare success
        reply -- confirm against SSHAgentClient).
        """
        d = self.client.sendRequest(agent.AGENTC_REMOVE_RSA_IDENTITY, '')
        self.pump.flush()
        return d.addCallback(self.assertEqual, '')

    def test_agentc_REMOVE_ALL_RSA_IDENTITIES(self):
        """
        An AGENTC_REMOVE_ALL_RSA_IDENTITIES request fires the client's
        Deferred with an empty string (presumably the client's rendering of a
        bare success reply -- confirm against SSHAgentClient).
        """
        d = self.client.sendRequest(agent.AGENTC_REMOVE_ALL_RSA_IDENTITIES, '')
        self.pump.flush()
        return d.addCallback(self.assertEqual, '')
# The agent module may be unavailable (see AgentTestBase.skip), in which case
# there is no base class to subclass and the tests using this are skipped.
if agent is not None:
    class CorruptServer(agent.SSHAgentServer):
        """
        A misbehaving server that returns bogus response op codes so that we can
        verify that our callbacks that deal with these op codes handle such
        errors.
        """

        def agentc_REQUEST_IDENTITIES(self, data):
            """
            Answer an identities request with the bogus op code 254 and an
            empty body instead of the expected answer.
            """
            self.sendResponse(254, '')

        def agentc_SIGN_REQUEST(self, data):
            """
            Answer a signing request with the bogus op code 254 and an empty
            body instead of the expected answer.
            """
            self.sendResponse(254, '')
class TestClientWithBrokenServer(AgentTestBase):
    """
    verify error handling code in the client using a misbehaving server
    """

    def setUp(self):
        """
        Replace the connection made by L{AgentTestBase.setUp} with one whose
        server end is a L{CorruptServer}.
        """
        # The base setUp is still needed for the test keys; the connection it
        # creates is simply discarded in favour of the one below.
        AgentTestBase.setUp(self)
        self.client, self.server, self.pump = iosim.connectedServerAndClient(
            CorruptServer, agent.SSHAgentClient)
        # the server's end of the protocol is stateful and we store it on the
        # factory, for which we only need a mock
        self.server.factory = StubFactory()

    def test_signDataCallbackErrorHandling(self):
        """
        The Deferred returned by L{SSHAgentClient.signData} fails with
        L{ConchError} if we get a response from the server whose opcode
        doesn't match the protocol for data signing requests.
        """
        d = self.client.signData(self.rsaPublic.blob(), "John Hancock")
        self.pump.flush()
        return self.assertFailure(d, ConchError)

    def test_requestIdentitiesCallbackErrorHandling(self):
        """
        The Deferred returned by L{SSHAgentClient.requestIdentities} fails
        with L{ConchError} if we get a response from the server whose opcode
        doesn't match the protocol for identity requests.
        """
        d = self.client.requestIdentities()
        self.pump.flush()
        return self.assertFailure(d, ConchError)
class TestAgentKeyAddition(AgentTestBase):
    """
    Test adding different flavors of keys to an agent.
    """

    def _addIdentityAndCheck(self, privateKey, comment=None):
        """
        Add C{privateKey} to the server through L{SSHAgentClient.addIdentity}
        and verify that the server stores it, keyed by its public blob,
        together with the expected comment.

        @param privateKey: The private L{keys.Key} to add.
        @param comment: The comment to send with the key, or C{None} to call
            C{addIdentity} without a comment, in which case the server is
            expected to store an empty string as the comment.
        @return: A L{Deferred} that fires once the checks have run.
        """
        if comment is None:
            d = self.client.addIdentity(privateKey.privateBlob())
            expectedComment = ''
        else:
            d = self.client.addIdentity(
                privateKey.privateBlob(), comment=comment)
            expectedComment = comment
        self.pump.flush()
        def _check(ignored):
            serverKey = self.server.factory.keys[privateKey.blob()]
            self.assertEqual(privateKey, serverKey[0])
            self.assertEqual(expectedComment, serverKey[1])
        return d.addCallback(_check)

    def test_addRSAIdentityNoComment(self):
        """
        L{SSHAgentClient.addIdentity} adds the private key it is called
        with to the SSH agent server to which it is connected, associating
        it with the comment it is called with.

        This test asserts that omitting the comment produces an
        empty string for the comment on the server.
        """
        return self._addIdentityAndCheck(self.rsaPrivate)

    def test_addDSAIdentityNoComment(self):
        """
        L{SSHAgentClient.addIdentity} adds the private key it is called
        with to the SSH agent server to which it is connected, associating
        it with the comment it is called with.

        This test asserts that omitting the comment produces an
        empty string for the comment on the server.
        """
        return self._addIdentityAndCheck(self.dsaPrivate)

    def test_addRSAIdentityWithComment(self):
        """
        L{SSHAgentClient.addIdentity} adds the private key it is called
        with to the SSH agent server to which it is connected, associating
        it with the comment it is called with.

        This test asserts that the server receives/stores the comment
        as sent by the client.
        """
        return self._addIdentityAndCheck(
            self.rsaPrivate, comment='My special key')

    def test_addDSAIdentityWithComment(self):
        """
        L{SSHAgentClient.addIdentity} adds the private key it is called
        with to the SSH agent server to which it is connected, associating
        it with the comment it is called with.

        This test asserts that the server receives/stores the comment
        as sent by the client.
        """
        return self._addIdentityAndCheck(
            self.dsaPrivate, comment='My special key')
class TestAgentClientFailure(AgentTestBase):
    """
    Tests for how L{agent.SSHAgentClient} reports a failure response from
    the server.
    """

    def test_agentFailure(self):
        """
        verify that the client raises ConchError on AGENT_FAILURE
        """
        # 254 is used here as a request type the server does not handle;
        # presumably the server answers it with AGENT_FAILURE -- confirm
        # against SSHAgentServer's handling of unknown request types.
        d = self.client.sendRequest(254, '')
        self.pump.flush()
        return self.assertFailure(d, ConchError)
class TestAgentIdentityRequests(AgentTestBase):
    """
    Test operations against a server with identities already loaded.
    """

    def setUp(self):
        """
        Load the DSA and RSA test keys into the server's key store, keyed by
        their public blobs and each with its own comment.
        """
        AgentTestBase.setUp(self)
        self.server.factory.keys[self.dsaPrivate.blob()] = (
            self.dsaPrivate, 'a comment')
        self.server.factory.keys[self.rsaPrivate.blob()] = (
            self.rsaPrivate, 'another comment')

    def test_signDataRSA(self):
        """
        Sign data with an RSA private key and then verify it with the public
        key.
        """
        d = self.client.signData(self.rsaPublic.blob(), "John Hancock")
        self.pump.flush()
        def _check(sig):
            expected = self.rsaPrivate.sign("John Hancock")
            self.assertEqual(expected, sig)
            self.assertTrue(self.rsaPublic.verify(sig, "John Hancock"))
        return d.addCallback(_check)

    def test_signDataDSA(self):
        """
        Sign data with a DSA private key and then verify it with the public
        key.
        """
        d = self.client.signData(self.dsaPublic.blob(), "John Hancock")
        self.pump.flush()
        def _check(sig):
            # DSA uses random numbers when signing, so the signature cannot
            # be compared with a locally computed one; only verify it.
            self.assertTrue(self.dsaPublic.verify(sig, "John Hancock"))
        return d.addCallback(_check)

    def test_signDataRSAErrbackOnUnknownBlob(self):
        """
        Assert that we get an errback if we try to sign data using a key that
        is not loaded in the server.
        """
        del self.server.factory.keys[self.rsaPublic.blob()]
        d = self.client.signData(self.rsaPublic.blob(), "John Hancock")
        self.pump.flush()
        return self.assertFailure(d, ConchError)

    def test_requestIdentities(self):
        """
        Assert that we get all of the keys/comments that we add when we issue a
        request for all identities.
        """
        d = self.client.requestIdentities()
        self.pump.flush()
        def _check(keyt):
            expected = {
                self.dsaPublic.blob(): 'a comment',
                self.rsaPublic.blob(): 'another comment',
            }
            # Parse each returned blob back into a key and re-serialize it
            # before comparing; index 1 of each entry is its comment.
            received = dict(
                (keys.Key.fromString(k[0], type='blob').blob(), k[1])
                for k in keyt)
            self.assertEqual(expected, received)
        return d.addCallback(_check)
class TestAgentKeyRemoval(AgentTestBase):
    """
    Test support for removing keys in a remote server.
    """

    def setUp(self):
        """
        Load the DSA and RSA test keys into the server's key store, keyed by
        their public blobs.
        """
        AgentTestBase.setUp(self)
        self.server.factory.keys[self.dsaPrivate.blob()] = (
            self.dsaPrivate, 'a comment')
        self.server.factory.keys[self.rsaPrivate.blob()] = (
            self.rsaPrivate, 'another comment')

    def test_removeRSAIdentity(self):
        """
        Assert that we can remove an RSA identity.
        """
        # only need public key for this
        d = self.client.removeIdentity(self.rsaPublic.blob())
        self.pump.flush()
        def _check(ignored):
            self.assertEqual(1, len(self.server.factory.keys))
            self.assertIn(self.dsaPublic.blob(), self.server.factory.keys)
            self.assertNotIn(self.rsaPublic.blob(), self.server.factory.keys)
        return d.addCallback(_check)

    def test_removeDSAIdentity(self):
        """
        Assert that we can remove a DSA identity.
        """
        # only need public key for this
        d = self.client.removeIdentity(self.dsaPublic.blob())
        self.pump.flush()
        def _check(ignored):
            self.assertEqual(1, len(self.server.factory.keys))
            self.assertIn(self.rsaPublic.blob(), self.server.factory.keys)
            self.assertNotIn(self.dsaPublic.blob(), self.server.factory.keys)
        return d.addCallback(_check)

    def test_removeAllIdentities(self):
        """
        Assert that we can remove all identities.
        """
        d = self.client.removeAllIdentities()
        self.pump.flush()
        def _check(ignored):
            self.assertEqual(0, len(self.server.factory.keys))
        return d.addCallback(_check)