1 # -*- test-case-name: twisted.conch.test.test_insults -*-
2 # Copyright (c) Twisted Matrix Laboratories.
3 # See LICENSE for details.
5 from twisted.trial import unittest
6 from twisted.test.proto_helpers import StringTransport
8 from twisted.conch.insults.insults import ServerProtocol, ClientProtocol
9 from twisted.conch.insults.insults import CS_UK, CS_US, CS_DRAWING, CS_ALTERNATE, CS_ALTERNATE_SPECIAL
10 from twisted.conch.insults.insults import G0, G1
11 from twisted.conch.insults.insults import modes
13 def _getattr(mock, name):
14 return super(Mock, mock).__getattribute__(name)
16 def occurrences(mock):
17 return _getattr(mock, 'occurrences')
20 return _getattr(mock, 'methods')
22 def _append(mock, obj):
23 occurrences(mock).append(obj)
28 callReturnValue = default
30 def __init__(self, methods=None, callReturnValue=default):
32 @param methods: Mapping of names to return values
33 @param callReturnValue: object __call__ should return
38 self.methods = methods
39 if callReturnValue is not default:
40 self.callReturnValue = callReturnValue
42 def __call__(self, *a, **kw):
43 returnValue = _getattr(self, 'callReturnValue')
44 if returnValue is default:
46 # _getattr(self, 'occurrences').append(('__call__', returnValue, a, kw))
47 _append(self, ('__call__', returnValue, a, kw))
50 def __getattribute__(self, name):
51 methods = _getattr(self, 'methods')
53 attrValue = Mock(callReturnValue=methods[name])
56 # _getattr(self, 'occurrences').append((name, attrValue))
57 _append(self, (name, attrValue))
61 def assertCall(self, occurrence, methodName, expectedPositionalArgs=(),
62 expectedKeywordArgs={}):
63 attr, mock = occurrence
64 self.assertEqual(attr, methodName)
65 self.assertEqual(len(occurrences(mock)), 1)
66 [(call, result, args, kw)] = occurrences(mock)
67 self.assertEqual(call, "__call__")
68 self.assertEqual(args, expectedPositionalArgs)
69 self.assertEqual(kw, expectedKeywordArgs)
73 _byteGroupingTestTemplate = """\
74 def testByte%(groupName)s(self):
75 transport = StringTransport()
77 parser = self.protocolFactory(lambda: proto)
79 parser.makeConnection(transport)
81 bytes = self.TEST_BYTES
83 chunk = bytes[:%(bytesPer)d]
84 bytes = bytes[%(bytesPer)d:]
85 parser.dataReceived(chunk)
87 self.verifyResults(transport, proto, parser)
89 class ByteGroupingsMixin(MockMixin):
90 protocolFactory = None
92 for word, n in [('Pairs', 2), ('Triples', 3), ('Quads', 4), ('Quints', 5), ('Sexes', 6)]:
93 exec _byteGroupingTestTemplate % {'groupName': word, 'bytesPer': n}
96 def verifyResults(self, transport, proto, parser):
97 result = self.assertCall(occurrences(proto).pop(0), "makeConnection", (parser,))
98 self.assertEqual(occurrences(result), [])
100 del _byteGroupingTestTemplate
102 class ServerArrowKeys(ByteGroupingsMixin, unittest.TestCase):
103 protocolFactory = ServerProtocol
105 # All the arrow keys once
106 TEST_BYTES = '\x1b[A\x1b[B\x1b[C\x1b[D'
108 def verifyResults(self, transport, proto, parser):
109 ByteGroupingsMixin.verifyResults(self, transport, proto, parser)
111 for arrow in (parser.UP_ARROW, parser.DOWN_ARROW,
112 parser.RIGHT_ARROW, parser.LEFT_ARROW):
113 result = self.assertCall(occurrences(proto).pop(0), "keystrokeReceived", (arrow, None))
114 self.assertEqual(occurrences(result), [])
115 self.failIf(occurrences(proto))
118 class PrintableCharacters(ByteGroupingsMixin, unittest.TestCase):
119 protocolFactory = ServerProtocol
121 # Some letters and digits, first on their own, then capitalized,
122 # then modified with alt
124 TEST_BYTES = 'abc123ABC!@#\x1ba\x1bb\x1bc\x1b1\x1b2\x1b3'
126 def verifyResults(self, transport, proto, parser):
127 ByteGroupingsMixin.verifyResults(self, transport, proto, parser)
129 for char in 'abc123ABC!@#':
130 result = self.assertCall(occurrences(proto).pop(0), "keystrokeReceived", (char, None))
131 self.assertEqual(occurrences(result), [])
133 for char in 'abc123':
134 result = self.assertCall(occurrences(proto).pop(0), "keystrokeReceived", (char, parser.ALT))
135 self.assertEqual(occurrences(result), [])
137 occs = occurrences(proto)
138 self.failIf(occs, "%r should have been []" % (occs,))
140 class ServerFunctionKeys(ByteGroupingsMixin, unittest.TestCase):
141 """Test for parsing and dispatching function keys (F1 - F12)
143 protocolFactory = ServerProtocol
146 for bytes in ('OP', 'OQ', 'OR', 'OS', # F1 - F4
147 '15~', '17~', '18~', '19~', # F5 - F8
148 '20~', '21~', '23~', '24~'): # F9 - F12
149 byteList.append('\x1b[' + bytes)
150 TEST_BYTES = ''.join(byteList)
153 def verifyResults(self, transport, proto, parser):
154 ByteGroupingsMixin.verifyResults(self, transport, proto, parser)
155 for funcNum in range(1, 13):
156 funcArg = getattr(parser, 'F%d' % (funcNum,))
157 result = self.assertCall(occurrences(proto).pop(0), "keystrokeReceived", (funcArg, None))
158 self.assertEqual(occurrences(result), [])
159 self.failIf(occurrences(proto))
161 class ClientCursorMovement(ByteGroupingsMixin, unittest.TestCase):
162 protocolFactory = ClientProtocol
168 # Move the cursor down two, right four, up one, left two, up one, left two
169 TEST_BYTES = d2 + r4 + u1 + l2 + u1 + l2
172 def verifyResults(self, transport, proto, parser):
173 ByteGroupingsMixin.verifyResults(self, transport, proto, parser)
175 for (method, count) in [('Down', 2), ('Forward', 4), ('Up', 1),
176 ('Backward', 2), ('Up', 1), ('Backward', 2)]:
177 result = self.assertCall(occurrences(proto).pop(0), "cursor" + method, (count,))
178 self.assertEqual(occurrences(result), [])
179 self.failIf(occurrences(proto))
181 class ClientControlSequences(unittest.TestCase, MockMixin):
183 self.transport = StringTransport()
185 self.parser = ClientProtocol(lambda: self.proto)
186 self.parser.factory = self
187 self.parser.makeConnection(self.transport)
188 result = self.assertCall(occurrences(self.proto).pop(0), "makeConnection", (self.parser,))
189 self.failIf(occurrences(result))
191 def testSimpleCardinals(self):
192 self.parser.dataReceived(
193 ''.join([''.join(['\x1b[' + str(n) + ch for n in ('', 2, 20, 200)]) for ch in 'BACD']))
194 occs = occurrences(self.proto)
196 for meth in ("Down", "Up", "Forward", "Backward"):
197 for count in (1, 2, 20, 200):
198 result = self.assertCall(occs.pop(0), "cursor" + meth, (count,))
199 self.failIf(occurrences(result))
202 def testScrollRegion(self):
203 self.parser.dataReceived('\x1b[5;22r\x1b[r')
204 occs = occurrences(self.proto)
206 result = self.assertCall(occs.pop(0), "setScrollRegion", (5, 22))
207 self.failIf(occurrences(result))
209 result = self.assertCall(occs.pop(0), "setScrollRegion", (None, None))
210 self.failIf(occurrences(result))
213 def testHeightAndWidth(self):
214 self.parser.dataReceived("\x1b#3\x1b#4\x1b#5\x1b#6")
215 occs = occurrences(self.proto)
217 result = self.assertCall(occs.pop(0), "doubleHeightLine", (True,))
218 self.failIf(occurrences(result))
220 result = self.assertCall(occs.pop(0), "doubleHeightLine", (False,))
221 self.failIf(occurrences(result))
223 result = self.assertCall(occs.pop(0), "singleWidthLine")
224 self.failIf(occurrences(result))
226 result = self.assertCall(occs.pop(0), "doubleWidthLine")
227 self.failIf(occurrences(result))
230 def testCharacterSet(self):
231 self.parser.dataReceived(
232 ''.join([''.join(['\x1b' + g + n for n in 'AB012']) for g in '()']))
233 occs = occurrences(self.proto)
235 for which in (G0, G1):
236 for charset in (CS_UK, CS_US, CS_DRAWING, CS_ALTERNATE, CS_ALTERNATE_SPECIAL):
237 result = self.assertCall(occs.pop(0), "selectCharacterSet", (charset, which))
238 self.failIf(occurrences(result))
241 def testShifting(self):
242 self.parser.dataReceived("\x15\x14")
243 occs = occurrences(self.proto)
245 result = self.assertCall(occs.pop(0), "shiftIn")
246 self.failIf(occurrences(result))
248 result = self.assertCall(occs.pop(0), "shiftOut")
249 self.failIf(occurrences(result))
252 def testSingleShifts(self):
253 self.parser.dataReceived("\x1bN\x1bO")
254 occs = occurrences(self.proto)
256 result = self.assertCall(occs.pop(0), "singleShift2")
257 self.failIf(occurrences(result))
259 result = self.assertCall(occs.pop(0), "singleShift3")
260 self.failIf(occurrences(result))
263 def testKeypadMode(self):
264 self.parser.dataReceived("\x1b=\x1b>")
265 occs = occurrences(self.proto)
267 result = self.assertCall(occs.pop(0), "applicationKeypadMode")
268 self.failIf(occurrences(result))
270 result = self.assertCall(occs.pop(0), "numericKeypadMode")
271 self.failIf(occurrences(result))
274 def testCursor(self):
275 self.parser.dataReceived("\x1b7\x1b8")
276 occs = occurrences(self.proto)
278 result = self.assertCall(occs.pop(0), "saveCursor")
279 self.failIf(occurrences(result))
281 result = self.assertCall(occs.pop(0), "restoreCursor")
282 self.failIf(occurrences(result))
286 self.parser.dataReceived("\x1bc")
287 occs = occurrences(self.proto)
289 result = self.assertCall(occs.pop(0), "reset")
290 self.failIf(occurrences(result))
294 self.parser.dataReceived("\x1bD\x1bM\x1bE")
295 occs = occurrences(self.proto)
297 result = self.assertCall(occs.pop(0), "index")
298 self.failIf(occurrences(result))
300 result = self.assertCall(occs.pop(0), "reverseIndex")
301 self.failIf(occurrences(result))
303 result = self.assertCall(occs.pop(0), "nextLine")
304 self.failIf(occurrences(result))
308 self.parser.dataReceived(
309 "\x1b[" + ';'.join(map(str, [modes.KAM, modes.IRM, modes.LNM])) + "h")
310 self.parser.dataReceived(
311 "\x1b[" + ';'.join(map(str, [modes.KAM, modes.IRM, modes.LNM])) + "l")
312 occs = occurrences(self.proto)
314 result = self.assertCall(occs.pop(0), "setModes", ([modes.KAM, modes.IRM, modes.LNM],))
315 self.failIf(occurrences(result))
317 result = self.assertCall(occs.pop(0), "resetModes", ([modes.KAM, modes.IRM, modes.LNM],))
318 self.failIf(occurrences(result))
321 def testErasure(self):
322 self.parser.dataReceived(
323 "\x1b[K\x1b[1K\x1b[2K\x1b[J\x1b[1J\x1b[2J\x1b[3P")
324 occs = occurrences(self.proto)
326 for meth in ("eraseToLineEnd", "eraseToLineBeginning", "eraseLine",
327 "eraseToDisplayEnd", "eraseToDisplayBeginning",
329 result = self.assertCall(occs.pop(0), meth)
330 self.failIf(occurrences(result))
332 result = self.assertCall(occs.pop(0), "deleteCharacter", (3,))
333 self.failIf(occurrences(result))
336 def testLineDeletion(self):
337 self.parser.dataReceived("\x1b[M\x1b[3M")
338 occs = occurrences(self.proto)
341 result = self.assertCall(occs.pop(0), "deleteLine", (arg,))
342 self.failIf(occurrences(result))
345 def testLineInsertion(self):
346 self.parser.dataReceived("\x1b[L\x1b[3L")
347 occs = occurrences(self.proto)
350 result = self.assertCall(occs.pop(0), "insertLine", (arg,))
351 self.failIf(occurrences(result))
354 def testCursorPosition(self):
355 methods(self.proto)['reportCursorPosition'] = (6, 7)
356 self.parser.dataReceived("\x1b[6n")
357 self.assertEqual(self.transport.value(), "\x1b[7;8R")
358 occs = occurrences(self.proto)
360 result = self.assertCall(occs.pop(0), "reportCursorPosition")
361 # This isn't really an interesting assert, since it only tests that
362 # our mock setup is working right, but I'll include it anyway.
363 self.assertEqual(result, (6, 7))
366 def test_applicationDataBytes(self):
368 Contiguous non-control bytes are passed to a single call to the
369 C{write} method of the terminal to which the L{ClientProtocol} is
372 occs = occurrences(self.proto)
373 self.parser.dataReceived('a')
374 self.assertCall(occs.pop(0), "write", ("a",))
375 self.parser.dataReceived('bc')
376 self.assertCall(occs.pop(0), "write", ("bc",))
379 def _applicationDataTest(self, data, calls):
380 occs = occurrences(self.proto)
381 self.parser.dataReceived(data)
383 self.assertCall(occs.pop(0), *calls.pop(0))
384 self.assertFalse(occs, "No other calls should happen: %r" % (occs,))
387 def test_shiftInAfterApplicationData(self):
389 Application data bytes followed by a shift-in command are passed to a
390 call to C{write} before the terminal's C{shiftIn} method is called.
392 self._applicationDataTest(
398 def test_shiftOutAfterApplicationData(self):
400 Application data bytes followed by a shift-out command are passed to a
401 call to C{write} before the terminal's C{shiftOut} method is called.
403 self._applicationDataTest(
409 def test_cursorBackwardAfterApplicationData(self):
411 Application data bytes followed by a cursor-backward command are passed
412 to a call to C{write} before the terminal's C{cursorBackward} method is
415 self._applicationDataTest(
418 ("cursorBackward",)])
421 def test_escapeAfterApplicationData(self):
423 Application data bytes followed by an escape character are passed to a
424 call to C{write} before the terminal's handler method for the escape is
427 # Test a short escape
428 self._applicationDataTest(
434 self._applicationDataTest(
437 ("setModes", ([4],))])
439 # There's some other cases too, but they're all handled by the same
440 # codepaths as above.
444 class ServerProtocolOutputTests(unittest.TestCase):
446 Tests for the bytes L{ServerProtocol} writes to its transport when its
449 def test_nextLine(self):
451 L{ServerProtocol.nextLine} writes C{"\r\n"} to its transport.
453 # Why doesn't it write ESC E? Because ESC E is poorly supported. For
454 # example, gnome-terminal (many different versions) fails to scroll if
455 # it receives ESC E and the cursor is already on the last row.
456 protocol = ServerProtocol()
457 transport = StringTransport()
458 protocol.makeConnection(transport)
460 self.assertEqual(transport.value(), "\r\n")
464 class Deprecations(unittest.TestCase):
466 Tests to ensure deprecation of L{insults.colors} and L{insults.client}
469 def ensureDeprecated(self, message):
471 Ensures that the correct deprecation warning was issued.
473 warnings = self.flushWarnings()
474 self.assertIdentical(warnings[0]['category'], DeprecationWarning)
475 self.assertEqual(warnings[0]['message'], message)
476 self.assertEqual(len(warnings), 1)
479 def test_colors(self):
481 The L{insults.colors} module is deprecated
483 from twisted.conch.insults import colors
484 self.ensureDeprecated("twisted.conch.insults.colors was deprecated "
485 "in Twisted 10.1.0: Please use "
486 "twisted.conch.insults.helper instead.")
489 def test_client(self):
491 The L{insults.client} module is deprecated
493 from twisted.conch.insults import client
494 self.ensureDeprecated("twisted.conch.insults.client was deprecated "
495 "in Twisted 10.1.0: Please use "
496 "twisted.conch.insults.insults instead.")