Initial import to Tizen
[profile/ivi/python-twisted.git] / twisted / python / test / test_util.py
1 # -*- test-case-name: twisted.python.test.test_util
2 # Copyright (c) Twisted Matrix Laboratories.
3 # See LICENSE for details.
4
5 """
6 Tests for L{twisted.python.util}.
7 """
8
9 import os.path, sys
10 import shutil, errno
11 try:
12     import pwd, grp
13 except ImportError:
14     pwd = grp = None
15
16 from twisted.trial import unittest
17
18 from twisted.python import util
19 from twisted.internet import reactor
20 from twisted.internet.interfaces import IReactorProcess
21 from twisted.internet.protocol import ProcessProtocol
22 from twisted.internet.defer import Deferred
23 from twisted.internet.error import ProcessDone
24
25 from twisted.test.test_process import MockOS
26
27
28
29 class UtilTestCase(unittest.TestCase):
30
31     def testUniq(self):
32         l = ["a", 1, "ab", "a", 3, 4, 1, 2, 2, 4, 6]
33         self.assertEqual(util.uniquify(l), ["a", 1, "ab", 3, 4, 2, 6])
34
35     def testRaises(self):
36         self.failUnless(util.raises(ZeroDivisionError, divmod, 1, 0))
37         self.failIf(util.raises(ZeroDivisionError, divmod, 0, 1))
38
39         try:
40             util.raises(TypeError, divmod, 1, 0)
41         except ZeroDivisionError:
42             pass
43         else:
44             raise unittest.FailTest, "util.raises didn't raise when it should have"
45
46     def testUninterruptably(self):
47         def f(a, b):
48             self.calls += 1
49             exc = self.exceptions.pop()
50             if exc is not None:
51                 raise exc(errno.EINTR, "Interrupted system call!")
52             return a + b
53
54         self.exceptions = [None]
55         self.calls = 0
56         self.assertEqual(util.untilConcludes(f, 1, 2), 3)
57         self.assertEqual(self.calls, 1)
58
59         self.exceptions = [None, OSError, IOError]
60         self.calls = 0
61         self.assertEqual(util.untilConcludes(f, 2, 3), 5)
62         self.assertEqual(self.calls, 3)
63
64     def testNameToLabel(self):
65         """
66         Test the various kinds of inputs L{nameToLabel} supports.
67         """
68         nameData = [
69             ('f', 'F'),
70             ('fo', 'Fo'),
71             ('foo', 'Foo'),
72             ('fooBar', 'Foo Bar'),
73             ('fooBarBaz', 'Foo Bar Baz'),
74             ]
75         for inp, out in nameData:
76             got = util.nameToLabel(inp)
77             self.assertEqual(
78                 got, out,
79                 "nameToLabel(%r) == %r != %r" % (inp, got, out))
80
81
82     def test_uidFromNumericString(self):
83         """
84         When L{uidFromString} is called with a base-ten string representation
85         of an integer, it returns the integer.
86         """
87         self.assertEqual(util.uidFromString("100"), 100)
88
89
90     def test_uidFromUsernameString(self):
91         """
92         When L{uidFromString} is called with a base-ten string representation
93         of an integer, it returns the integer.
94         """
95         pwent = pwd.getpwuid(os.getuid())
96         self.assertEqual(util.uidFromString(pwent.pw_name), pwent.pw_uid)
97     if pwd is None:
98         test_uidFromUsernameString.skip = (
99             "Username/UID conversion requires the pwd module.")
100
101
102     def test_gidFromNumericString(self):
103         """
104         When L{gidFromString} is called with a base-ten string representation
105         of an integer, it returns the integer.
106         """
107         self.assertEqual(util.gidFromString("100"), 100)
108
109
110     def test_gidFromGroupnameString(self):
111         """
112         When L{gidFromString} is called with a base-ten string representation
113         of an integer, it returns the integer.
114         """
115         grent = grp.getgrgid(os.getgid())
116         self.assertEqual(util.gidFromString(grent.gr_name), grent.gr_gid)
117     if grp is None:
118         test_gidFromGroupnameString.skip = (
119             "Group Name/GID conversion requires the grp module.")
120
121
122
123 class SwitchUIDTest(unittest.TestCase):
124     """
125     Tests for L{util.switchUID}.
126     """
127
128     if getattr(os, "getuid", None) is None:
129         skip = "getuid/setuid not available"
130
131
132     def setUp(self):
133         self.mockos = MockOS()
134         self.patch(util, "os", self.mockos)
135         self.patch(util, "initgroups", self.initgroups)
136         self.initgroupsCalls = []
137
138
139     def initgroups(self, uid, gid):
140         """
141         Save L{util.initgroups} calls in C{self.initgroupsCalls}.
142         """
143         self.initgroupsCalls.append((uid, gid))
144
145
146     def test_uid(self):
147         """
148         L{util.switchUID} calls L{util.initgroups} and then C{os.setuid} with
149         the given uid.
150         """
151         util.switchUID(12000, None)
152         self.assertEqual(self.initgroupsCalls, [(12000, None)])
153         self.assertEqual(self.mockos.actions, [("setuid", 12000)])
154
155
156     def test_euid(self):
157         """
158         L{util.switchUID} calls L{util.initgroups} and then C{os.seteuid} with
159         the given uid if the C{euid} parameter is set to C{True}.
160         """
161         util.switchUID(12000, None, True)
162         self.assertEqual(self.initgroupsCalls, [(12000, None)])
163         self.assertEqual(self.mockos.seteuidCalls, [12000])
164
165
166     def test_currentUID(self):
167         """
168         If the current uid is the same as the uid passed to L{util.switchUID},
169         then initgroups does not get called, but a warning is issued.
170         """
171         uid = self.mockos.getuid()
172         util.switchUID(uid, None)
173         self.assertEqual(self.initgroupsCalls, [])
174         self.assertEqual(self.mockos.actions, [])
175         warnings = self.flushWarnings([util.switchUID])
176         self.assertEqual(len(warnings), 1)
177         self.assertIn('tried to drop privileges and setuid %i' % uid, 
178                       warnings[0]['message'])
179         self.assertIn('but uid is already %i' % uid, warnings[0]['message'])
180
181
182     def test_currentEUID(self):
183         """
184         If the current euid is the same as the euid passed to L{util.switchUID},
185         then initgroups does not get called, but a warning is issued.
186         """
187         euid = self.mockos.geteuid()
188         util.switchUID(euid, None, True)
189         self.assertEqual(self.initgroupsCalls, [])
190         self.assertEqual(self.mockos.seteuidCalls, [])
191         warnings = self.flushWarnings([util.switchUID])
192         self.assertEqual(len(warnings), 1)
193         self.assertIn('tried to drop privileges and seteuid %i' % euid, 
194                       warnings[0]['message'])
195         self.assertIn('but euid is already %i' % euid, warnings[0]['message'])
196
197
198
199 class TestMergeFunctionMetadata(unittest.TestCase):
200     """
201     Tests for L{mergeFunctionMetadata}.
202     """
203
204     def test_mergedFunctionBehavesLikeMergeTarget(self):
205         """
206         After merging C{foo}'s data into C{bar}, the returned function behaves
207         as if it is C{bar}.
208         """
209         foo_object = object()
210         bar_object = object()
211
212         def foo():
213             return foo_object
214
215         def bar(x, y, (a, b), c=10, *d, **e):
216             return bar_object
217
218         baz = util.mergeFunctionMetadata(foo, bar)
219         self.assertIdentical(baz(1, 2, (3, 4), quux=10), bar_object)
220
221
222     def test_moduleIsMerged(self):
223         """
224         Merging C{foo} into C{bar} returns a function with C{foo}'s
225         C{__module__}.
226         """
227         def foo():
228             pass
229
230         def bar():
231             pass
232         bar.__module__ = 'somewhere.else'
233
234         baz = util.mergeFunctionMetadata(foo, bar)
235         self.assertEqual(baz.__module__, foo.__module__)
236
237
238     def test_docstringIsMerged(self):
239         """
240         Merging C{foo} into C{bar} returns a function with C{foo}'s docstring.
241         """
242
243         def foo():
244             """
245             This is foo.
246             """
247
248         def bar():
249             """
250             This is bar.
251             """
252
253         baz = util.mergeFunctionMetadata(foo, bar)
254         self.assertEqual(baz.__doc__, foo.__doc__)
255
256
257     def test_nameIsMerged(self):
258         """
259         Merging C{foo} into C{bar} returns a function with C{foo}'s name.
260         """
261
262         def foo():
263             pass
264
265         def bar():
266             pass
267
268         baz = util.mergeFunctionMetadata(foo, bar)
269         self.assertEqual(baz.__name__, foo.__name__)
270
271
272     def test_instanceDictionaryIsMerged(self):
273         """
274         Merging C{foo} into C{bar} returns a function with C{bar}'s
275         dictionary, updated by C{foo}'s.
276         """
277
278         def foo():
279             pass
280         foo.a = 1
281         foo.b = 2
282
283         def bar():
284             pass
285         bar.b = 3
286         bar.c = 4
287
288         baz = util.mergeFunctionMetadata(foo, bar)
289         self.assertEqual(foo.a, baz.a)
290         self.assertEqual(foo.b, baz.b)
291         self.assertEqual(bar.c, baz.c)
292
293
294
295 class OrderedDictTest(unittest.TestCase):
296     def testOrderedDict(self):
297         d = util.OrderedDict()
298         d['a'] = 'b'
299         d['b'] = 'a'
300         d[3] = 12
301         d[1234] = 4321
302         self.assertEqual(repr(d), "{'a': 'b', 'b': 'a', 3: 12, 1234: 4321}")
303         self.assertEqual(d.values(), ['b', 'a', 12, 4321])
304         del d[3]
305         self.assertEqual(repr(d), "{'a': 'b', 'b': 'a', 1234: 4321}")
306         self.assertEqual(d, {'a': 'b', 'b': 'a', 1234:4321})
307         self.assertEqual(d.keys(), ['a', 'b', 1234])
308         self.assertEqual(list(d.iteritems()),
309                           [('a', 'b'), ('b','a'), (1234, 4321)])
310         item = d.popitem()
311         self.assertEqual(item, (1234, 4321))
312
313     def testInitialization(self):
314         d = util.OrderedDict({'monkey': 'ook',
315                               'apple': 'red'})
316         self.failUnless(d._order)
317
318         d = util.OrderedDict(((1,1),(3,3),(2,2),(0,0)))
319         self.assertEqual(repr(d), "{1: 1, 3: 3, 2: 2, 0: 0}")
320
321 class InsensitiveDictTest(unittest.TestCase):
322     def testPreserve(self):
323         InsensitiveDict=util.InsensitiveDict
324         dct=InsensitiveDict({'Foo':'bar', 1:2, 'fnz':{1:2}}, preserve=1)
325         self.assertEqual(dct['fnz'], {1:2})
326         self.assertEqual(dct['foo'], 'bar')
327         self.assertEqual(dct.copy(), dct)
328         self.assertEqual(dct['foo'], dct.get('Foo'))
329         assert 1 in dct and 'foo' in dct
330         self.assertEqual(eval(repr(dct)), dct)
331         keys=['Foo', 'fnz', 1]
332         for x in keys:
333             assert x in dct.keys()
334             assert (x, dct[x]) in dct.items()
335         self.assertEqual(len(keys), len(dct))
336         del dct[1]
337         del dct['foo']
338
339     def testNoPreserve(self):
340         InsensitiveDict=util.InsensitiveDict
341         dct=InsensitiveDict({'Foo':'bar', 1:2, 'fnz':{1:2}}, preserve=0)
342         keys=['foo', 'fnz', 1]
343         for x in keys:
344             assert x in dct.keys()
345             assert (x, dct[x]) in dct.items()
346         self.assertEqual(len(keys), len(dct))
347         del dct[1]
348         del dct['foo']
349
350
351
352
353 class PasswordTestingProcessProtocol(ProcessProtocol):
354     """
355     Write the string C{"secret\n"} to a subprocess and then collect all of
356     its output and fire a Deferred with it when the process ends.
357     """
358     def connectionMade(self):
359         self.output = []
360         self.transport.write('secret\n')
361
362     def childDataReceived(self, fd, output):
363         self.output.append((fd, output))
364
365     def processEnded(self, reason):
366         self.finished.callback((reason, self.output))
367
368
369 class GetPasswordTest(unittest.TestCase):
370     if not IReactorProcess.providedBy(reactor):
371         skip = "Process support required to test getPassword"
372
373     def test_stdin(self):
374         """
375         Making sure getPassword accepts a password from standard input by
376         running a child process which uses getPassword to read in a string
377         which it then writes it out again.  Write a string to the child
378         process and then read one and make sure it is the right string.
379         """
380         p = PasswordTestingProcessProtocol()
381         p.finished = Deferred()
382         reactor.spawnProcess(
383             p,
384             sys.executable,
385             [sys.executable,
386              '-c',
387              ('import sys\n'
388              'from twisted.python.util import getPassword\n'
389               'sys.stdout.write(getPassword())\n'
390               'sys.stdout.flush()\n')],
391             env={'PYTHONPATH': os.pathsep.join(sys.path)})
392
393         def processFinished((reason, output)):
394             reason.trap(ProcessDone)
395             self.assertIn((1, 'secret'), output)
396
397         return p.finished.addCallback(processFinished)
398
399
400
401 class SearchUpwardsTest(unittest.TestCase):
402     def testSearchupwards(self):
403         os.makedirs('searchupwards/a/b/c')
404         file('searchupwards/foo.txt', 'w').close()
405         file('searchupwards/a/foo.txt', 'w').close()
406         file('searchupwards/a/b/c/foo.txt', 'w').close()
407         os.mkdir('searchupwards/bar')
408         os.mkdir('searchupwards/bam')
409         os.mkdir('searchupwards/a/bar')
410         os.mkdir('searchupwards/a/b/bam')
411         actual=util.searchupwards('searchupwards/a/b/c',
412                                   files=['foo.txt'],
413                                   dirs=['bar', 'bam'])
414         expected=os.path.abspath('searchupwards') + os.sep
415         self.assertEqual(actual, expected)
416         shutil.rmtree('searchupwards')
417         actual=util.searchupwards('searchupwards/a/b/c',
418                                   files=['foo.txt'],
419                                   dirs=['bar', 'bam'])
420         expected=None
421         self.assertEqual(actual, expected)
422
423
424
425 class IntervalDifferentialTestCase(unittest.TestCase):
426     def testDefault(self):
427         d = iter(util.IntervalDifferential([], 10))
428         for i in range(100):
429             self.assertEqual(d.next(), (10, None))
430
431     def testSingle(self):
432         d = iter(util.IntervalDifferential([5], 10))
433         for i in range(100):
434             self.assertEqual(d.next(), (5, 0))
435
436     def testPair(self):
437         d = iter(util.IntervalDifferential([5, 7], 10))
438         for i in range(100):
439             self.assertEqual(d.next(), (5, 0))
440             self.assertEqual(d.next(), (2, 1))
441             self.assertEqual(d.next(), (3, 0))
442             self.assertEqual(d.next(), (4, 1))
443             self.assertEqual(d.next(), (1, 0))
444             self.assertEqual(d.next(), (5, 0))
445             self.assertEqual(d.next(), (1, 1))
446             self.assertEqual(d.next(), (4, 0))
447             self.assertEqual(d.next(), (3, 1))
448             self.assertEqual(d.next(), (2, 0))
449             self.assertEqual(d.next(), (5, 0))
450             self.assertEqual(d.next(), (0, 1))
451
452     def testTriple(self):
453         d = iter(util.IntervalDifferential([2, 4, 5], 10))
454         for i in range(100):
455             self.assertEqual(d.next(), (2, 0))
456             self.assertEqual(d.next(), (2, 0))
457             self.assertEqual(d.next(), (0, 1))
458             self.assertEqual(d.next(), (1, 2))
459             self.assertEqual(d.next(), (1, 0))
460             self.assertEqual(d.next(), (2, 0))
461             self.assertEqual(d.next(), (0, 1))
462             self.assertEqual(d.next(), (2, 0))
463             self.assertEqual(d.next(), (0, 2))
464             self.assertEqual(d.next(), (2, 0))
465             self.assertEqual(d.next(), (0, 1))
466             self.assertEqual(d.next(), (2, 0))
467             self.assertEqual(d.next(), (1, 2))
468             self.assertEqual(d.next(), (1, 0))
469             self.assertEqual(d.next(), (0, 1))
470             self.assertEqual(d.next(), (2, 0))
471             self.assertEqual(d.next(), (2, 0))
472             self.assertEqual(d.next(), (0, 1))
473             self.assertEqual(d.next(), (0, 2))
474
475     def testInsert(self):
476         d = iter(util.IntervalDifferential([], 10))
477         self.assertEqual(d.next(), (10, None))
478         d.addInterval(3)
479         self.assertEqual(d.next(), (3, 0))
480         self.assertEqual(d.next(), (3, 0))
481         d.addInterval(6)
482         self.assertEqual(d.next(), (3, 0))
483         self.assertEqual(d.next(), (3, 0))
484         self.assertEqual(d.next(), (0, 1))
485         self.assertEqual(d.next(), (3, 0))
486         self.assertEqual(d.next(), (3, 0))
487         self.assertEqual(d.next(), (0, 1))
488
489     def testRemove(self):
490         d = iter(util.IntervalDifferential([3, 5], 10))
491         self.assertEqual(d.next(), (3, 0))
492         self.assertEqual(d.next(), (2, 1))
493         self.assertEqual(d.next(), (1, 0))
494         d.removeInterval(3)
495         self.assertEqual(d.next(), (4, 0))
496         self.assertEqual(d.next(), (5, 0))
497         d.removeInterval(5)
498         self.assertEqual(d.next(), (10, None))
499         self.assertRaises(ValueError, d.removeInterval, 10)
500
501
502
503 class Record(util.FancyEqMixin):
504     """
505     Trivial user of L{FancyEqMixin} used by tests.
506     """
507     compareAttributes = ('a', 'b')
508
509     def __init__(self, a, b):
510         self.a = a
511         self.b = b
512
513
514
515 class DifferentRecord(util.FancyEqMixin):
516     """
517     Trivial user of L{FancyEqMixin} which is not related to L{Record}.
518     """
519     compareAttributes = ('a', 'b')
520
521     def __init__(self, a, b):
522         self.a = a
523         self.b = b
524
525
526
527 class DerivedRecord(Record):
528     """
529     A class with an inheritance relationship to L{Record}.
530     """
531
532
533
534 class EqualToEverything(object):
535     """
536     A class the instances of which consider themselves equal to everything.
537     """
538     def __eq__(self, other):
539         return True
540
541
542     def __ne__(self, other):
543         return False
544
545
546
547 class EqualToNothing(object):
548     """
549     A class the instances of which consider themselves equal to nothing.
550     """
551     def __eq__(self, other):
552         return False
553
554
555     def __ne__(self, other):
556         return True
557
558
559
560 class EqualityTests(unittest.TestCase):
561     """
562     Tests for L{FancyEqMixin}.
563     """
564     def test_identity(self):
565         """
566         Instances of a class which mixes in L{FancyEqMixin} but which
567         defines no comparison attributes compare by identity.
568         """
569         class Empty(util.FancyEqMixin):
570             pass
571
572         self.assertFalse(Empty() == Empty())
573         self.assertTrue(Empty() != Empty())
574         empty = Empty()
575         self.assertTrue(empty == empty)
576         self.assertFalse(empty != empty)
577
578
579     def test_equality(self):
580         """
581         Instances of a class which mixes in L{FancyEqMixin} should compare
582         equal if all of their attributes compare equal.  They should not
583         compare equal if any of their attributes do not compare equal.
584         """
585         self.assertTrue(Record(1, 2) == Record(1, 2))
586         self.assertFalse(Record(1, 2) == Record(1, 3))
587         self.assertFalse(Record(1, 2) == Record(2, 2))
588         self.assertFalse(Record(1, 2) == Record(3, 4))
589
590
591     def test_unequality(self):
592         """
593         Unequality between instances of a particular L{record} should be
594         defined as the negation of equality.
595         """
596         self.assertFalse(Record(1, 2) != Record(1, 2))
597         self.assertTrue(Record(1, 2) != Record(1, 3))
598         self.assertTrue(Record(1, 2) != Record(2, 2))
599         self.assertTrue(Record(1, 2) != Record(3, 4))
600
601
602     def test_differentClassesEquality(self):
603         """
604         Instances of different classes which mix in L{FancyEqMixin} should not
605         compare equal.
606         """
607         self.assertFalse(Record(1, 2) == DifferentRecord(1, 2))
608
609
610     def test_differentClassesInequality(self):
611         """
612         Instances of different classes which mix in L{FancyEqMixin} should
613         compare unequal.
614         """
615         self.assertTrue(Record(1, 2) != DifferentRecord(1, 2))
616
617
618     def test_inheritedClassesEquality(self):
619         """
620         An instance of a class which derives from a class which mixes in
621         L{FancyEqMixin} should compare equal to an instance of the base class
622         if and only if all of their attributes compare equal.
623         """
624         self.assertTrue(Record(1, 2) == DerivedRecord(1, 2))
625         self.assertFalse(Record(1, 2) == DerivedRecord(1, 3))
626         self.assertFalse(Record(1, 2) == DerivedRecord(2, 2))
627         self.assertFalse(Record(1, 2) == DerivedRecord(3, 4))
628
629
630     def test_inheritedClassesInequality(self):
631         """
632         An instance of a class which derives from a class which mixes in
633         L{FancyEqMixin} should compare unequal to an instance of the base
634         class if any of their attributes compare unequal.
635         """
636         self.assertFalse(Record(1, 2) != DerivedRecord(1, 2))
637         self.assertTrue(Record(1, 2) != DerivedRecord(1, 3))
638         self.assertTrue(Record(1, 2) != DerivedRecord(2, 2))
639         self.assertTrue(Record(1, 2) != DerivedRecord(3, 4))
640
641
642     def test_rightHandArgumentImplementsEquality(self):
643         """
644         The right-hand argument to the equality operator is given a chance
645         to determine the result of the operation if it is of a type
646         unrelated to the L{FancyEqMixin}-based instance on the left-hand
647         side.
648         """
649         self.assertTrue(Record(1, 2) == EqualToEverything())
650         self.assertFalse(Record(1, 2) == EqualToNothing())
651
652
653     def test_rightHandArgumentImplementsUnequality(self):
654         """
655         The right-hand argument to the non-equality operator is given a
656         chance to determine the result of the operation if it is of a type
657         unrelated to the L{FancyEqMixin}-based instance on the left-hand
658         side.
659         """
660         self.assertFalse(Record(1, 2) != EqualToEverything())
661         self.assertTrue(Record(1, 2) != EqualToNothing())
662
663
664
665 class RunAsEffectiveUserTests(unittest.TestCase):
666     """
667     Test for the L{util.runAsEffectiveUser} function.
668     """
669
670     if getattr(os, "geteuid", None) is None:
671         skip = "geteuid/seteuid not available"
672
673     def setUp(self):
674         self.mockos = MockOS()
675         self.patch(os, "geteuid", self.mockos.geteuid)
676         self.patch(os, "getegid", self.mockos.getegid)
677         self.patch(os, "seteuid", self.mockos.seteuid)
678         self.patch(os, "setegid", self.mockos.setegid)
679
680
681     def _securedFunction(self, startUID, startGID, wantUID, wantGID):
682         """
683         Check if wanted UID/GID matched start or saved ones.
684         """
685         self.assertTrue(wantUID == startUID or
686                         wantUID == self.mockos.seteuidCalls[-1])
687         self.assertTrue(wantGID == startGID or
688                         wantGID == self.mockos.setegidCalls[-1])
689
690
691     def test_forwardResult(self):
692         """
693         L{util.runAsEffectiveUser} forwards the result obtained by calling the
694         given function
695         """
696         result = util.runAsEffectiveUser(0, 0, lambda: 1)
697         self.assertEqual(result, 1)
698
699
700     def test_takeParameters(self):
701         """
702         L{util.runAsEffectiveUser} pass the given parameters to the given
703         function.
704         """
705         result = util.runAsEffectiveUser(0, 0, lambda x: 2*x, 3)
706         self.assertEqual(result, 6)
707
708
709     def test_takesKeyworkArguments(self):
710         """
711         L{util.runAsEffectiveUser} pass the keyword parameters to the given
712         function.
713         """
714         result = util.runAsEffectiveUser(0, 0, lambda x, y=1, z=1: x*y*z, 2, z=3)
715         self.assertEqual(result, 6)
716
717
718     def _testUIDGIDSwitch(self, startUID, startGID, wantUID, wantGID,
719                           expectedUIDSwitches, expectedGIDSwitches):
720         """
721         Helper method checking the calls to C{os.seteuid} and C{os.setegid}
722         made by L{util.runAsEffectiveUser}, when switching from startUID to
723         wantUID and from startGID to wantGID.
724         """
725         self.mockos.euid = startUID
726         self.mockos.egid = startGID
727         util.runAsEffectiveUser(
728             wantUID, wantGID,
729             self._securedFunction, startUID, startGID, wantUID, wantGID)
730         self.assertEqual(self.mockos.seteuidCalls, expectedUIDSwitches)
731         self.assertEqual(self.mockos.setegidCalls, expectedGIDSwitches)
732         self.mockos.seteuidCalls = []
733         self.mockos.setegidCalls = []
734
735
736     def test_root(self):
737         """
738         Check UID/GID switches when current effective UID is root.
739         """
740         self._testUIDGIDSwitch(0, 0, 0, 0, [], [])
741         self._testUIDGIDSwitch(0, 0, 1, 0, [1, 0], [])
742         self._testUIDGIDSwitch(0, 0, 0, 1, [], [1, 0])
743         self._testUIDGIDSwitch(0, 0, 1, 1, [1, 0], [1, 0])
744
745
746     def test_UID(self):
747         """
748         Check UID/GID switches when current effective UID is non-root.
749         """
750         self._testUIDGIDSwitch(1, 0, 0, 0, [0, 1], [])
751         self._testUIDGIDSwitch(1, 0, 1, 0, [], [])
752         self._testUIDGIDSwitch(1, 0, 1, 1, [0, 1, 0, 1], [1, 0])
753         self._testUIDGIDSwitch(1, 0, 2, 1, [0, 2, 0, 1], [1, 0])
754
755
756     def test_GID(self):
757         """
758         Check UID/GID switches when current effective GID is non-root.
759         """
760         self._testUIDGIDSwitch(0, 1, 0, 0, [], [0, 1])
761         self._testUIDGIDSwitch(0, 1, 0, 1, [], [])
762         self._testUIDGIDSwitch(0, 1, 1, 1, [1, 0], [])
763         self._testUIDGIDSwitch(0, 1, 1, 2, [1, 0], [2, 1])
764
765
766     def test_UIDGID(self):
767         """
768         Check UID/GID switches when current effective UID/GID is non-root.
769         """
770         self._testUIDGIDSwitch(1, 1, 0, 0, [0, 1], [0, 1])
771         self._testUIDGIDSwitch(1, 1, 0, 1, [0, 1], [])
772         self._testUIDGIDSwitch(1, 1, 1, 0, [0, 1, 0, 1], [0, 1])
773         self._testUIDGIDSwitch(1, 1, 1, 1, [], [])
774         self._testUIDGIDSwitch(1, 1, 2, 1, [0, 2, 0, 1], [])
775         self._testUIDGIDSwitch(1, 1, 1, 2, [0, 1, 0, 1], [2, 1])
776         self._testUIDGIDSwitch(1, 1, 2, 2, [0, 2, 0, 1], [2, 1])
777
778
779
780 class UnsignedIDTests(unittest.TestCase):
781     """
782     Tests for L{util.unsignedID} and L{util.setIDFunction}.
783     """
784     def setUp(self):
785         """
786         Save the value of L{util._idFunction} and arrange for it to be restored
787         after the test runs.
788         """
789         self.addCleanup(setattr, util, '_idFunction', util._idFunction)
790
791
792     def test_setIDFunction(self):
793         """
794         L{util.setIDFunction} returns the last value passed to it.
795         """
796         value = object()
797         previous = util.setIDFunction(value)
798         result = util.setIDFunction(previous)
799         self.assertIdentical(value, result)
800
801
802     def test_unsignedID(self):
803         """
804         L{util.unsignedID} uses the function passed to L{util.setIDFunction} to
805         determine the unique integer id of an object and then adjusts it to be
806         positive if necessary.
807         """
808         foo = object()
809         bar = object()
810
811         # A fake object identity mapping
812         objects = {foo: 17, bar: -73}
813         def fakeId(obj):
814             return objects[obj]
815
816         util.setIDFunction(fakeId)
817
818         self.assertEqual(util.unsignedID(foo), 17)
819         self.assertEqual(util.unsignedID(bar), (sys.maxint + 1) * 2 - 73)
820
821
822     def test_defaultIDFunction(self):
823         """
824         L{util.unsignedID} uses the built in L{id} by default.
825         """
826         obj = object()
827         idValue = id(obj)
828         if idValue < 0:
829             idValue += (sys.maxint + 1) * 2
830
831         self.assertEqual(util.unsignedID(obj), idValue)
832
833
834
835 class InitGroupsTests(unittest.TestCase):
836     """
837     Tests for L{util.initgroups}.
838     """
839
840     if pwd is None:
841         skip = "pwd not available"
842
843
844     def setUp(self):
845         self.addCleanup(setattr, util, "_c_initgroups", util._c_initgroups)
846         self.addCleanup(setattr, util, "setgroups", util.setgroups)
847
848
849     def test_initgroupsForceC(self):
850         """
851         If we fake the presence of the C extension, it's called instead of the
852         Python implementation.
853         """
854         calls = []
855         util._c_initgroups = lambda x, y: calls.append((x, y))
856         setgroupsCalls = []
857         util.setgroups = calls.append
858
859         util.initgroups(os.getuid(), 4)
860         self.assertEqual(calls, [(pwd.getpwuid(os.getuid())[0], 4)])
861         self.assertFalse(setgroupsCalls)
862
863
864     def test_initgroupsForcePython(self):
865         """
866         If we fake the absence of the C extension, the Python implementation is
867         called instead, calling C{os.setgroups}.
868         """
869         util._c_initgroups = None
870         calls = []
871         util.setgroups = calls.append
872         util.initgroups(os.getuid(), os.getgid())
873         # Something should be in the calls, we don't really care what
874         self.assertTrue(calls)
875
876
877     def test_initgroupsInC(self):
878         """
879         If the C extension is present, it's called instead of the Python
880         version.  We check that by making sure C{os.setgroups} is not called.
881         """
882         calls = []
883         util.setgroups = calls.append
884         try:
885             util.initgroups(os.getuid(), os.getgid())
886         except OSError:
887             pass
888         self.assertFalse(calls)
889
890
891     if util._c_initgroups is None:
892         test_initgroupsInC.skip = "C initgroups not available"