Initial import to Tizen
[profile/ivi/python-twisted.git] / twisted / conch / test / test_agent.py
1 # Copyright (c) Twisted Matrix Laboratories.
2 # See LICENSE for details.
3
4 """
5 Tests for L{twisted.conch.ssh.agent}.
6 """
7
8 import struct
9
10 from twisted.trial import unittest
11
12 try:
13     import OpenSSL
14 except ImportError:
15     iosim = None
16 else:
17     from twisted.test import iosim
18
19 try:
20     import Crypto.Cipher.DES3
21 except ImportError:
22     Crypto = None
23
24 try:
25     import pyasn1
26 except ImportError:
27     pyasn1 = None
28
29 if Crypto and pyasn1:
30     from twisted.conch.ssh import keys, agent
31 else:
32     keys = agent = None
33
34 from twisted.conch.test import keydata
35 from twisted.conch.error import ConchError, MissingKeyStoreError
36
37
38 class StubFactory(object):
39     """
40     Mock factory that provides the keys attribute required by the
41     SSHAgentServerProtocol
42     """
43     def __init__(self):
44         self.keys = {}
45
46
47
48 class AgentTestBase(unittest.TestCase):
49     """
50     Tests for SSHAgentServer/Client.
51     """
52     if iosim is None:
53         skip = "iosim requires SSL, but SSL is not available"
54     elif agent is None or keys is None:
55         skip = "Cannot run without PyCrypto or PyASN1"
56
57     def setUp(self):
58         # wire up our client <-> server
59         self.client, self.server, self.pump = iosim.connectedServerAndClient(
60             agent.SSHAgentServer, agent.SSHAgentClient)
61
62         # the server's end of the protocol is stateful and we store it on the
63         # factory, for which we only need a mock
64         self.server.factory = StubFactory()
65
66         # pub/priv keys of each kind
67         self.rsaPrivate = keys.Key.fromString(keydata.privateRSA_openssh)
68         self.dsaPrivate = keys.Key.fromString(keydata.privateDSA_openssh)
69
70         self.rsaPublic = keys.Key.fromString(keydata.publicRSA_openssh)
71         self.dsaPublic = keys.Key.fromString(keydata.publicDSA_openssh)
72
73
74
75 class TestServerProtocolContractWithFactory(AgentTestBase):
76     """
77     The server protocol is stateful and so uses its factory to track state
78     across requests.  This test asserts that the protocol raises if its factory
79     doesn't provide the necessary storage for that state.
80     """
81     def test_factorySuppliesKeyStorageForServerProtocol(self):
82         # need a message to send into the server
83         msg = struct.pack('!LB',1, agent.AGENTC_REQUEST_IDENTITIES)
84         del self.server.factory.__dict__['keys']
85         self.assertRaises(MissingKeyStoreError,
86                           self.server.dataReceived, msg)
87
88
89
90 class TestUnimplementedVersionOneServer(AgentTestBase):
91     """
92     Tests for methods with no-op implementations on the server. We need these
93     for clients, such as openssh, that try v1 methods before going to v2.
94
95     Because the client doesn't expose these operations with nice method names,
96     we invoke sendRequest directly with an op code.
97     """
98
99     def test_agentc_REQUEST_RSA_IDENTITIES(self):
100         """
101         assert that we get the correct op code for an RSA identities request
102         """
103         d = self.client.sendRequest(agent.AGENTC_REQUEST_RSA_IDENTITIES, '')
104         self.pump.flush()
105         def _cb(packet):
106             self.assertEqual(
107                 agent.AGENT_RSA_IDENTITIES_ANSWER, ord(packet[0]))
108         return d.addCallback(_cb)
109
110
111     def test_agentc_REMOVE_RSA_IDENTITY(self):
112         """
113         assert that we get the correct op code for an RSA remove identity request
114         """
115         d = self.client.sendRequest(agent.AGENTC_REMOVE_RSA_IDENTITY, '')
116         self.pump.flush()
117         return d.addCallback(self.assertEqual, '')
118
119
120     def test_agentc_REMOVE_ALL_RSA_IDENTITIES(self):
121         """
122         assert that we get the correct op code for an RSA remove all identities
123         request.
124         """
125         d = self.client.sendRequest(agent.AGENTC_REMOVE_ALL_RSA_IDENTITIES, '')
126         self.pump.flush()
127         return d.addCallback(self.assertEqual, '')
128
129
130
131 if agent is not None:
132     class CorruptServer(agent.SSHAgentServer):
133         """
134         A misbehaving server that returns bogus response op codes so that we can
135         verify that our callbacks that deal with these op codes handle such
136         miscreants.
137         """
138         def agentc_REQUEST_IDENTITIES(self, data):
139             self.sendResponse(254, '')
140
141
142         def agentc_SIGN_REQUEST(self, data):
143             self.sendResponse(254, '')
144
145
146
147 class TestClientWithBrokenServer(AgentTestBase):
148     """
149     verify error handling code in the client using a misbehaving server
150     """
151
152     def setUp(self):
153         AgentTestBase.setUp(self)
154         self.client, self.server, self.pump = iosim.connectedServerAndClient(
155             CorruptServer, agent.SSHAgentClient)
156         # the server's end of the protocol is stateful and we store it on the
157         # factory, for which we only need a mock
158         self.server.factory = StubFactory()
159
160
161     def test_signDataCallbackErrorHandling(self):
162         """
163         Assert that L{SSHAgentClient.signData} raises a ConchError
164         if we get a response from the server whose opcode doesn't match
165         the protocol for data signing requests.
166         """
167         d = self.client.signData(self.rsaPublic.blob(), "John Hancock")
168         self.pump.flush()
169         return self.assertFailure(d, ConchError)
170
171
172     def test_requestIdentitiesCallbackErrorHandling(self):
173         """
174         Assert that L{SSHAgentClient.requestIdentities} raises a ConchError
175         if we get a response from the server whose opcode doesn't match
176         the protocol for identity requests.
177         """
178         d = self.client.requestIdentities()
179         self.pump.flush()
180         return self.assertFailure(d, ConchError)
181
182
183
184 class TestAgentKeyAddition(AgentTestBase):
185     """
186     Test adding different flavors of keys to an agent.
187     """
188
189     def test_addRSAIdentityNoComment(self):
190         """
191         L{SSHAgentClient.addIdentity} adds the private key it is called
192         with to the SSH agent server to which it is connected, associating
193         it with the comment it is called with.
194
195         This test asserts that ommitting the comment produces an
196         empty string for the comment on the server.
197         """
198         d = self.client.addIdentity(self.rsaPrivate.privateBlob())
199         self.pump.flush()
200         def _check(ignored):
201             serverKey = self.server.factory.keys[self.rsaPrivate.blob()]
202             self.assertEqual(self.rsaPrivate, serverKey[0])
203             self.assertEqual('', serverKey[1])
204         return d.addCallback(_check)
205
206
207     def test_addDSAIdentityNoComment(self):
208         """
209         L{SSHAgentClient.addIdentity} adds the private key it is called
210         with to the SSH agent server to which it is connected, associating
211         it with the comment it is called with.
212
213         This test asserts that ommitting the comment produces an
214         empty string for the comment on the server.
215         """
216         d = self.client.addIdentity(self.dsaPrivate.privateBlob())
217         self.pump.flush()
218         def _check(ignored):
219             serverKey = self.server.factory.keys[self.dsaPrivate.blob()]
220             self.assertEqual(self.dsaPrivate, serverKey[0])
221             self.assertEqual('', serverKey[1])
222         return d.addCallback(_check)
223
224
225     def test_addRSAIdentityWithComment(self):
226         """
227         L{SSHAgentClient.addIdentity} adds the private key it is called
228         with to the SSH agent server to which it is connected, associating
229         it with the comment it is called with.
230
231         This test asserts that the server receives/stores the comment
232         as sent by the client.
233         """
234         d = self.client.addIdentity(
235             self.rsaPrivate.privateBlob(), comment='My special key')
236         self.pump.flush()
237         def _check(ignored):
238             serverKey = self.server.factory.keys[self.rsaPrivate.blob()]
239             self.assertEqual(self.rsaPrivate, serverKey[0])
240             self.assertEqual('My special key', serverKey[1])
241         return d.addCallback(_check)
242
243
244     def test_addDSAIdentityWithComment(self):
245         """
246         L{SSHAgentClient.addIdentity} adds the private key it is called
247         with to the SSH agent server to which it is connected, associating
248         it with the comment it is called with.
249
250         This test asserts that the server receives/stores the comment
251         as sent by the client.
252         """
253         d = self.client.addIdentity(
254             self.dsaPrivate.privateBlob(), comment='My special key')
255         self.pump.flush()
256         def _check(ignored):
257             serverKey = self.server.factory.keys[self.dsaPrivate.blob()]
258             self.assertEqual(self.dsaPrivate, serverKey[0])
259             self.assertEqual('My special key', serverKey[1])
260         return d.addCallback(_check)
261
262
263
264 class TestAgentClientFailure(AgentTestBase):
265     def test_agentFailure(self):
266         """
267         verify that the client raises ConchError on AGENT_FAILURE
268         """
269         d = self.client.sendRequest(254, '')
270         self.pump.flush()
271         return self.assertFailure(d, ConchError)
272
273
274
275 class TestAgentIdentityRequests(AgentTestBase):
276     """
277     Test operations against a server with identities already loaded.
278     """
279
280     def setUp(self):
281         AgentTestBase.setUp(self)
282         self.server.factory.keys[self.dsaPrivate.blob()] = (
283             self.dsaPrivate, 'a comment')
284         self.server.factory.keys[self.rsaPrivate.blob()] = (
285             self.rsaPrivate, 'another comment')
286
287
288     def test_signDataRSA(self):
289         """
290         Sign data with an RSA private key and then verify it with the public
291         key.
292         """
293         d = self.client.signData(self.rsaPublic.blob(), "John Hancock")
294         self.pump.flush()
295         def _check(sig):
296             expected = self.rsaPrivate.sign("John Hancock")
297             self.assertEqual(expected, sig)
298             self.assertTrue(self.rsaPublic.verify(sig, "John Hancock"))
299         return d.addCallback(_check)
300
301
302     def test_signDataDSA(self):
303         """
304         Sign data with a DSA private key and then verify it with the public
305         key.
306         """
307         d = self.client.signData(self.dsaPublic.blob(), "John Hancock")
308         self.pump.flush()
309         def _check(sig):
310             # Cannot do this b/c DSA uses random numbers when signing
311             #   expected = self.dsaPrivate.sign("John Hancock")
312             #   self.assertEqual(expected, sig)
313             self.assertTrue(self.dsaPublic.verify(sig, "John Hancock"))
314         return d.addCallback(_check)
315
316
317     def test_signDataRSAErrbackOnUnknownBlob(self):
318         """
319         Assert that we get an errback if we try to sign data using a key that
320         wasn't added.
321         """
322         del self.server.factory.keys[self.rsaPublic.blob()]
323         d = self.client.signData(self.rsaPublic.blob(), "John Hancock")
324         self.pump.flush()
325         return self.assertFailure(d, ConchError)
326
327
328     def test_requestIdentities(self):
329         """
330         Assert that we get all of the keys/comments that we add when we issue a
331         request for all identities.
332         """
333         d = self.client.requestIdentities()
334         self.pump.flush()
335         def _check(keyt):
336             expected = {}
337             expected[self.dsaPublic.blob()] = 'a comment'
338             expected[self.rsaPublic.blob()] = 'another comment'
339
340             received = {}
341             for k in keyt:
342                 received[keys.Key.fromString(k[0], type='blob').blob()] = k[1]
343             self.assertEqual(expected, received)
344         return d.addCallback(_check)
345
346
347
348 class TestAgentKeyRemoval(AgentTestBase):
349     """
350     Test support for removing keys in a remote server.
351     """
352
353     def setUp(self):
354         AgentTestBase.setUp(self)
355         self.server.factory.keys[self.dsaPrivate.blob()] = (
356             self.dsaPrivate, 'a comment')
357         self.server.factory.keys[self.rsaPrivate.blob()] = (
358             self.rsaPrivate, 'another comment')
359
360
361     def test_removeRSAIdentity(self):
362         """
363         Assert that we can remove an RSA identity.
364         """
365         # only need public key for this
366         d = self.client.removeIdentity(self.rsaPrivate.blob())
367         self.pump.flush()
368
369         def _check(ignored):
370             self.assertEqual(1, len(self.server.factory.keys))
371             self.assertIn(self.dsaPrivate.blob(), self.server.factory.keys)
372             self.assertNotIn(self.rsaPrivate.blob(), self.server.factory.keys)
373         return d.addCallback(_check)
374
375
376     def test_removeDSAIdentity(self):
377         """
378         Assert that we can remove a DSA identity.
379         """
380         # only need public key for this
381         d = self.client.removeIdentity(self.dsaPrivate.blob())
382         self.pump.flush()
383
384         def _check(ignored):
385             self.assertEqual(1, len(self.server.factory.keys))
386             self.assertIn(self.rsaPrivate.blob(), self.server.factory.keys)
387         return d.addCallback(_check)
388
389
390     def test_removeAllIdentities(self):
391         """
392         Assert that we can remove all identities.
393         """
394         d = self.client.removeAllIdentities()
395         self.pump.flush()
396
397         def _check(ignored):
398             self.assertEqual(0, len(self.server.factory.keys))
399         return d.addCallback(_check)