tizen 2.3.1 release
[external/protobuf.git] / python / mox.py
1 #!/usr/bin/python2.4
2 #
3 # Copyright 2008 Google Inc.
4 #
5 # Licensed under the Apache License, Version 2.0 (the "License");
6 # you may not use this file except in compliance with the License.
7 # You may obtain a copy of the License at
8 #
9 #      http://www.apache.org/licenses/LICENSE-2.0
10 #
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
16
17 # This file is used for testing.  The original is at:
18 #   http://code.google.com/p/pymox/
19
20 """Mox, an object-mocking framework for Python.
21
22 Mox works in the record-replay-verify paradigm.  When you first create
23 a mock object, it is in record mode.  You then programmatically set
24 the expected behavior of the mock object (what methods are to be
25 called on it, with what parameters, what they should return, and in
26 what order).
27
28 Once you have set up the expected mock behavior, you put it in replay
29 mode.  Now the mock responds to method calls just as you told it to.
30 If an unexpected method (or an expected method with unexpected
31 parameters) is called, then an exception will be raised.
32
33 Once you are done interacting with the mock, you need to verify that
34 all the expected interactions occured.  (Maybe your code exited
35 prematurely without calling some cleanup method!)  The verify phase
36 ensures that every expected method was called; otherwise, an exception
37 will be raised.
38
39 Suggested usage / workflow:
40
41   # Create Mox factory
42   my_mox = Mox()
43
44   # Create a mock data access object
45   mock_dao = my_mox.CreateMock(DAOClass)
46
47   # Set up expected behavior
48   mock_dao.RetrievePersonWithIdentifier('1').AndReturn(person)
49   mock_dao.DeletePerson(person)
50
51   # Put mocks in replay mode
52   my_mox.ReplayAll()
53
54   # Inject mock object and run test
55   controller.SetDao(mock_dao)
56   controller.DeletePersonById('1')
57
58   # Verify all methods were called as expected
59   my_mox.VerifyAll()
60 """
61
62 from collections import deque
63 import re
64 import types
65 import unittest
66
67 import stubout
68
69 class Error(AssertionError):
70   """Base exception for this module."""
71
72   pass
73
74
75 class ExpectedMethodCallsError(Error):
76   """Raised when Verify() is called before all expected methods have been called
77   """
78
79   def __init__(self, expected_methods):
80     """Init exception.
81
82     Args:
83       # expected_methods: A sequence of MockMethod objects that should have been
84       #   called.
85       expected_methods: [MockMethod]
86
87     Raises:
88       ValueError: if expected_methods contains no methods.
89     """
90
91     if not expected_methods:
92       raise ValueError("There must be at least one expected method")
93     Error.__init__(self)
94     self._expected_methods = expected_methods
95
96   def __str__(self):
97     calls = "\n".join(["%3d.  %s" % (i, m)
98                        for i, m in enumerate(self._expected_methods)])
99     return "Verify: Expected methods never called:\n%s" % (calls,)
100
101
102 class UnexpectedMethodCallError(Error):
103   """Raised when an unexpected method is called.
104
105   This can occur if a method is called with incorrect parameters, or out of the
106   specified order.
107   """
108
109   def __init__(self, unexpected_method, expected):
110     """Init exception.
111
112     Args:
113       # unexpected_method: MockMethod that was called but was not at the head of
114       #   the expected_method queue.
115       # expected: MockMethod or UnorderedGroup the method should have
116       #   been in.
117       unexpected_method: MockMethod
118       expected: MockMethod or UnorderedGroup
119     """
120
121     Error.__init__(self)
122     self._unexpected_method = unexpected_method
123     self._expected = expected
124
125   def __str__(self):
126     return "Unexpected method call: %s.  Expecting: %s" % \
127       (self._unexpected_method, self._expected)
128
129
130 class UnknownMethodCallError(Error):
131   """Raised if an unknown method is requested of the mock object."""
132
133   def __init__(self, unknown_method_name):
134     """Init exception.
135
136     Args:
137       # unknown_method_name: Method call that is not part of the mocked class's
138       #   public interface.
139       unknown_method_name: str
140     """
141
142     Error.__init__(self)
143     self._unknown_method_name = unknown_method_name
144
145   def __str__(self):
146     return "Method called is not a member of the object: %s" % \
147       self._unknown_method_name
148
149
150 class Mox(object):
151   """Mox: a factory for creating mock objects."""
152
153   # A list of types that should be stubbed out with MockObjects (as
154   # opposed to MockAnythings).
155   _USE_MOCK_OBJECT = [types.ClassType, types.InstanceType, types.ModuleType,
156                       types.ObjectType, types.TypeType]
157
158   def __init__(self):
159     """Initialize a new Mox."""
160
161     self._mock_objects = []
162     self.stubs = stubout.StubOutForTesting()
163
164   def CreateMock(self, class_to_mock):
165     """Create a new mock object.
166
167     Args:
168       # class_to_mock: the class to be mocked
169       class_to_mock: class
170
171     Returns:
172       MockObject that can be used as the class_to_mock would be.
173     """
174
175     new_mock = MockObject(class_to_mock)
176     self._mock_objects.append(new_mock)
177     return new_mock
178
179   def CreateMockAnything(self):
180     """Create a mock that will accept any method calls.
181
182     This does not enforce an interface.
183     """
184
185     new_mock = MockAnything()
186     self._mock_objects.append(new_mock)
187     return new_mock
188
189   def ReplayAll(self):
190     """Set all mock objects to replay mode."""
191
192     for mock_obj in self._mock_objects:
193       mock_obj._Replay()
194
195
196   def VerifyAll(self):
197     """Call verify on all mock objects created."""
198
199     for mock_obj in self._mock_objects:
200       mock_obj._Verify()
201
202   def ResetAll(self):
203     """Call reset on all mock objects.  This does not unset stubs."""
204
205     for mock_obj in self._mock_objects:
206       mock_obj._Reset()
207
208   def StubOutWithMock(self, obj, attr_name, use_mock_anything=False):
209     """Replace a method, attribute, etc. with a Mock.
210
211     This will replace a class or module with a MockObject, and everything else
212     (method, function, etc) with a MockAnything.  This can be overridden to
213     always use a MockAnything by setting use_mock_anything to True.
214
215     Args:
216       obj: A Python object (class, module, instance, callable).
217       attr_name: str.  The name of the attribute to replace with a mock.
218       use_mock_anything: bool. True if a MockAnything should be used regardless
219         of the type of attribute.
220     """
221
222     attr_to_replace = getattr(obj, attr_name)
223     if type(attr_to_replace) in self._USE_MOCK_OBJECT and not use_mock_anything:
224       stub = self.CreateMock(attr_to_replace)
225     else:
226       stub = self.CreateMockAnything()
227
228     self.stubs.Set(obj, attr_name, stub)
229
230   def UnsetStubs(self):
231     """Restore stubs to their original state."""
232
233     self.stubs.UnsetAll()
234
235 def Replay(*args):
236   """Put mocks into Replay mode.
237
238   Args:
239     # args is any number of mocks to put into replay mode.
240   """
241
242   for mock in args:
243     mock._Replay()
244
245
246 def Verify(*args):
247   """Verify mocks.
248
249   Args:
250     # args is any number of mocks to be verified.
251   """
252
253   for mock in args:
254     mock._Verify()
255
256
257 def Reset(*args):
258   """Reset mocks.
259
260   Args:
261     # args is any number of mocks to be reset.
262   """
263
264   for mock in args:
265     mock._Reset()
266
267
268 class MockAnything:
269   """A mock that can be used to mock anything.
270
271   This is helpful for mocking classes that do not provide a public interface.
272   """
273
274   def __init__(self):
275     """ """
276     self._Reset()
277
278   def __getattr__(self, method_name):
279     """Intercept method calls on this object.
280
281      A new MockMethod is returned that is aware of the MockAnything's
282      state (record or replay).  The call will be recorded or replayed
283      by the MockMethod's __call__.
284
285     Args:
286       # method name: the name of the method being called.
287       method_name: str
288
289     Returns:
290       A new MockMethod aware of MockAnything's state (record or replay).
291     """
292
293     return self._CreateMockMethod(method_name)
294
295   def _CreateMockMethod(self, method_name):
296     """Create a new mock method call and return it.
297
298     Args:
299       # method name: the name of the method being called.
300       method_name: str
301
302     Returns:
303       A new MockMethod aware of MockAnything's state (record or replay).
304     """
305
306     return MockMethod(method_name, self._expected_calls_queue,
307                       self._replay_mode)
308
309   def __nonzero__(self):
310     """Return 1 for nonzero so the mock can be used as a conditional."""
311
312     return 1
313
314   def __eq__(self, rhs):
315     """Provide custom logic to compare objects."""
316
317     return (isinstance(rhs, MockAnything) and
318             self._replay_mode == rhs._replay_mode and
319             self._expected_calls_queue == rhs._expected_calls_queue)
320
321   def __ne__(self, rhs):
322     """Provide custom logic to compare objects."""
323
324     return not self == rhs
325
326   def _Replay(self):
327     """Start replaying expected method calls."""
328
329     self._replay_mode = True
330
331   def _Verify(self):
332     """Verify that all of the expected calls have been made.
333
334     Raises:
335       ExpectedMethodCallsError: if there are still more method calls in the
336         expected queue.
337     """
338
339     # If the list of expected calls is not empty, raise an exception
340     if self._expected_calls_queue:
341       # The last MultipleTimesGroup is not popped from the queue.
342       if (len(self._expected_calls_queue) == 1 and
343           isinstance(self._expected_calls_queue[0], MultipleTimesGroup) and
344           self._expected_calls_queue[0].IsSatisfied()):
345         pass
346       else:
347         raise ExpectedMethodCallsError(self._expected_calls_queue)
348
349   def _Reset(self):
350     """Reset the state of this mock to record mode with an empty queue."""
351
352     # Maintain a list of method calls we are expecting
353     self._expected_calls_queue = deque()
354
355     # Make sure we are in setup mode, not replay mode
356     self._replay_mode = False
357
358
359 class MockObject(MockAnything, object):
360   """A mock object that simulates the public/protected interface of a class."""
361
362   def __init__(self, class_to_mock):
363     """Initialize a mock object.
364
365     This determines the methods and properties of the class and stores them.
366
367     Args:
368       # class_to_mock: class to be mocked
369       class_to_mock: class
370     """
371
372     # This is used to hack around the mixin/inheritance of MockAnything, which
373     # is not a proper object (it can be anything. :-)
374     MockAnything.__dict__['__init__'](self)
375
376     # Get a list of all the public and special methods we should mock.
377     self._known_methods = set()
378     self._known_vars = set()
379     self._class_to_mock = class_to_mock
380     for method in dir(class_to_mock):
381       if callable(getattr(class_to_mock, method)):
382         self._known_methods.add(method)
383       else:
384         self._known_vars.add(method)
385
386   def __getattr__(self, name):
387     """Intercept attribute request on this object.
388
389     If the attribute is a public class variable, it will be returned and not
390     recorded as a call.
391
392     If the attribute is not a variable, it is handled like a method
393     call. The method name is checked against the set of mockable
394     methods, and a new MockMethod is returned that is aware of the
395     MockObject's state (record or replay).  The call will be recorded
396     or replayed by the MockMethod's __call__.
397
398     Args:
399       # name: the name of the attribute being requested.
400       name: str
401
402     Returns:
403       Either a class variable or a new MockMethod that is aware of the state
404       of the mock (record or replay).
405
406     Raises:
407       UnknownMethodCallError if the MockObject does not mock the requested
408           method.
409     """
410
411     if name in self._known_vars:
412       return getattr(self._class_to_mock, name)
413
414     if name in self._known_methods:
415       return self._CreateMockMethod(name)
416
417     raise UnknownMethodCallError(name)
418
419   def __eq__(self, rhs):
420     """Provide custom logic to compare objects."""
421
422     return (isinstance(rhs, MockObject) and
423             self._class_to_mock == rhs._class_to_mock and
424             self._replay_mode == rhs._replay_mode and
425             self._expected_calls_queue == rhs._expected_calls_queue)
426
427   def __setitem__(self, key, value):
428     """Provide custom logic for mocking classes that support item assignment.
429
430     Args:
431       key: Key to set the value for.
432       value: Value to set.
433
434     Returns:
435       Expected return value in replay mode.  A MockMethod object for the
436       __setitem__ method that has already been called if not in replay mode.
437
438     Raises:
439       TypeError if the underlying class does not support item assignment.
440       UnexpectedMethodCallError if the object does not expect the call to
441         __setitem__.
442
443     """
444     setitem = self._class_to_mock.__dict__.get('__setitem__', None)
445
446     # Verify the class supports item assignment.
447     if setitem is None:
448       raise TypeError('object does not support item assignment')
449
450     # If we are in replay mode then simply call the mock __setitem__ method.
451     if self._replay_mode:
452       return MockMethod('__setitem__', self._expected_calls_queue,
453                         self._replay_mode)(key, value)
454
455
456     # Otherwise, create a mock method __setitem__.
457     return self._CreateMockMethod('__setitem__')(key, value)
458
459   def __getitem__(self, key):
460     """Provide custom logic for mocking classes that are subscriptable.
461
462     Args:
463       key: Key to return the value for.
464
465     Returns:
466       Expected return value in replay mode.  A MockMethod object for the
467       __getitem__ method that has already been called if not in replay mode.
468
469     Raises:
470       TypeError if the underlying class is not subscriptable.
471       UnexpectedMethodCallError if the object does not expect the call to
472         __setitem__.
473
474     """
475     getitem = self._class_to_mock.__dict__.get('__getitem__', None)
476
477     # Verify the class supports item assignment.
478     if getitem is None:
479       raise TypeError('unsubscriptable object')
480
481     # If we are in replay mode then simply call the mock __getitem__ method.
482     if self._replay_mode:
483       return MockMethod('__getitem__', self._expected_calls_queue,
484                         self._replay_mode)(key)
485
486
487     # Otherwise, create a mock method __getitem__.
488     return self._CreateMockMethod('__getitem__')(key)
489
490   def __call__(self, *params, **named_params):
491     """Provide custom logic for mocking classes that are callable."""
492
493     # Verify the class we are mocking is callable
494     callable = self._class_to_mock.__dict__.get('__call__', None)
495     if callable is None:
496       raise TypeError('Not callable')
497
498     # Because the call is happening directly on this object instead of a method,
499     # the call on the mock method is made right here
500     mock_method = self._CreateMockMethod('__call__')
501     return mock_method(*params, **named_params)
502
503   @property
504   def __class__(self):
505     """Return the class that is being mocked."""
506
507     return self._class_to_mock
508
509
510 class MockMethod(object):
511   """Callable mock method.
512
513   A MockMethod should act exactly like the method it mocks, accepting parameters
514   and returning a value, or throwing an exception (as specified).  When this
515   method is called, it can optionally verify whether the called method (name and
516   signature) matches the expected method.
517   """
518
519   def __init__(self, method_name, call_queue, replay_mode):
520     """Construct a new mock method.
521
522     Args:
523       # method_name: the name of the method
524       # call_queue: deque of calls, verify this call against the head, or add
525       #     this call to the queue.
526       # replay_mode: False if we are recording, True if we are verifying calls
527       #     against the call queue.
528       method_name: str
529       call_queue: list or deque
530       replay_mode: bool
531     """
532
533     self._name = method_name
534     self._call_queue = call_queue
535     if not isinstance(call_queue, deque):
536       self._call_queue = deque(self._call_queue)
537     self._replay_mode = replay_mode
538
539     self._params = None
540     self._named_params = None
541     self._return_value = None
542     self._exception = None
543     self._side_effects = None
544
545   def __call__(self, *params, **named_params):
546     """Log parameters and return the specified return value.
547
548     If the Mock(Anything/Object) associated with this call is in record mode,
549     this MockMethod will be pushed onto the expected call queue.  If the mock
550     is in replay mode, this will pop a MockMethod off the top of the queue and
551     verify this call is equal to the expected call.
552
553     Raises:
554       UnexpectedMethodCall if this call is supposed to match an expected method
555         call and it does not.
556     """
557
558     self._params = params
559     self._named_params = named_params
560
561     if not self._replay_mode:
562       self._call_queue.append(self)
563       return self
564
565     expected_method = self._VerifyMethodCall()
566
567     if expected_method._side_effects:
568       expected_method._side_effects(*params, **named_params)
569
570     if expected_method._exception:
571       raise expected_method._exception
572
573     return expected_method._return_value
574
575   def __getattr__(self, name):
576     """Raise an AttributeError with a helpful message."""
577
578     raise AttributeError('MockMethod has no attribute "%s". '
579         'Did you remember to put your mocks in replay mode?' % name)
580
581   def _PopNextMethod(self):
582     """Pop the next method from our call queue."""
583     try:
584       return self._call_queue.popleft()
585     except IndexError:
586       raise UnexpectedMethodCallError(self, None)
587
588   def _VerifyMethodCall(self):
589     """Verify the called method is expected.
590
591     This can be an ordered method, or part of an unordered set.
592
593     Returns:
594       The expected mock method.
595
596     Raises:
597       UnexpectedMethodCall if the method called was not expected.
598     """
599
600     expected = self._PopNextMethod()
601
602     # Loop here, because we might have a MethodGroup followed by another
603     # group.
604     while isinstance(expected, MethodGroup):
605       expected, method = expected.MethodCalled(self)
606       if method is not None:
607         return method
608
609     # This is a mock method, so just check equality.
610     if expected != self:
611       raise UnexpectedMethodCallError(self, expected)
612
613     return expected
614
615   def __str__(self):
616     params = ', '.join(
617         [repr(p) for p in self._params or []] +
618         ['%s=%r' % x for x in sorted((self._named_params or {}).items())])
619     desc = "%s(%s) -> %r" % (self._name, params, self._return_value)
620     return desc
621
622   def __eq__(self, rhs):
623     """Test whether this MockMethod is equivalent to another MockMethod.
624
625     Args:
626       # rhs: the right hand side of the test
627       rhs: MockMethod
628     """
629
630     return (isinstance(rhs, MockMethod) and
631             self._name == rhs._name and
632             self._params == rhs._params and
633             self._named_params == rhs._named_params)
634
635   def __ne__(self, rhs):
636     """Test whether this MockMethod is not equivalent to another MockMethod.
637
638     Args:
639       # rhs: the right hand side of the test
640       rhs: MockMethod
641     """
642
643     return not self == rhs
644
645   def GetPossibleGroup(self):
646     """Returns a possible group from the end of the call queue or None if no
647     other methods are on the stack.
648     """
649
650     # Remove this method from the tail of the queue so we can add it to a group.
651     this_method = self._call_queue.pop()
652     assert this_method == self
653
654     # Determine if the tail of the queue is a group, or just a regular ordered
655     # mock method.
656     group = None
657     try:
658       group = self._call_queue[-1]
659     except IndexError:
660       pass
661
662     return group
663
664   def _CheckAndCreateNewGroup(self, group_name, group_class):
665     """Checks if the last method (a possible group) is an instance of our
666     group_class. Adds the current method to this group or creates a new one.
667
668     Args:
669
670       group_name: the name of the group.
671       group_class: the class used to create instance of this new group
672     """
673     group = self.GetPossibleGroup()
674
675     # If this is a group, and it is the correct group, add the method.
676     if isinstance(group, group_class) and group.group_name() == group_name:
677       group.AddMethod(self)
678       return self
679
680     # Create a new group and add the method.
681     new_group = group_class(group_name)
682     new_group.AddMethod(self)
683     self._call_queue.append(new_group)
684     return self
685
686   def InAnyOrder(self, group_name="default"):
687     """Move this method into a group of unordered calls.
688
689     A group of unordered calls must be defined together, and must be executed
690     in full before the next expected method can be called.  There can be
691     multiple groups that are expected serially, if they are given
692     different group names.  The same group name can be reused if there is a
693     standard method call, or a group with a different name, spliced between
694     usages.
695
696     Args:
697       group_name: the name of the unordered group.
698
699     Returns:
700       self
701     """
702     return self._CheckAndCreateNewGroup(group_name, UnorderedGroup)
703
704   def MultipleTimes(self, group_name="default"):
705     """Move this method into group of calls which may be called multiple times.
706
707     A group of repeating calls must be defined together, and must be executed in
708     full before the next expected mehtod can be called.
709
710     Args:
711       group_name: the name of the unordered group.
712
713     Returns:
714       self
715     """
716     return self._CheckAndCreateNewGroup(group_name, MultipleTimesGroup)
717
718   def AndReturn(self, return_value):
719     """Set the value to return when this method is called.
720
721     Args:
722       # return_value can be anything.
723     """
724
725     self._return_value = return_value
726     return return_value
727
728   def AndRaise(self, exception):
729     """Set the exception to raise when this method is called.
730
731     Args:
732       # exception: the exception to raise when this method is called.
733       exception: Exception
734     """
735
736     self._exception = exception
737
738   def WithSideEffects(self, side_effects):
739     """Set the side effects that are simulated when this method is called.
740
741     Args:
742       side_effects: A callable which modifies the parameters or other relevant
743         state which a given test case depends on.
744
745     Returns:
746       Self for chaining with AndReturn and AndRaise.
747     """
748     self._side_effects = side_effects
749     return self
750
751 class Comparator:
752   """Base class for all Mox comparators.
753
754   A Comparator can be used as a parameter to a mocked method when the exact
755   value is not known.  For example, the code you are testing might build up a
756   long SQL string that is passed to your mock DAO. You're only interested that
757   the IN clause contains the proper primary keys, so you can set your mock
758   up as follows:
759
760   mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result)
761
762   Now whatever query is passed in must contain the string 'IN (1, 2, 4, 5)'.
763
764   A Comparator may replace one or more parameters, for example:
765   # return at most 10 rows
766   mock_dao.RunQuery(StrContains('SELECT'), 10)
767
768   or
769
770   # Return some non-deterministic number of rows
771   mock_dao.RunQuery(StrContains('SELECT'), IsA(int))
772   """
773
774   def equals(self, rhs):
775     """Special equals method that all comparators must implement.
776
777     Args:
778       rhs: any python object
779     """
780
781     raise NotImplementedError, 'method must be implemented by a subclass.'
782
783   def __eq__(self, rhs):
784     return self.equals(rhs)
785
786   def __ne__(self, rhs):
787     return not self.equals(rhs)
788
789
790 class IsA(Comparator):
791   """This class wraps a basic Python type or class.  It is used to verify
792   that a parameter is of the given type or class.
793
794   Example:
795   mock_dao.Connect(IsA(DbConnectInfo))
796   """
797
798   def __init__(self, class_name):
799     """Initialize IsA
800
801     Args:
802       class_name: basic python type or a class
803     """
804
805     self._class_name = class_name
806
807   def equals(self, rhs):
808     """Check to see if the RHS is an instance of class_name.
809
810     Args:
811       # rhs: the right hand side of the test
812       rhs: object
813
814     Returns:
815       bool
816     """
817
818     try:
819       return isinstance(rhs, self._class_name)
820     except TypeError:
821       # Check raw types if there was a type error.  This is helpful for
822       # things like cStringIO.StringIO.
823       return type(rhs) == type(self._class_name)
824
825   def __repr__(self):
826     return str(self._class_name)
827
828 class IsAlmost(Comparator):
829   """Comparison class used to check whether a parameter is nearly equal
830   to a given value.  Generally useful for floating point numbers.
831
832   Example mock_dao.SetTimeout((IsAlmost(3.9)))
833   """
834
835   def __init__(self, float_value, places=7):
836     """Initialize IsAlmost.
837
838     Args:
839       float_value: The value for making the comparison.
840       places: The number of decimal places to round to.
841     """
842
843     self._float_value = float_value
844     self._places = places
845
846   def equals(self, rhs):
847     """Check to see if RHS is almost equal to float_value
848
849     Args:
850       rhs: the value to compare to float_value
851
852     Returns:
853       bool
854     """
855
856     try:
857       return round(rhs-self._float_value, self._places) == 0
858     except TypeError:
859       # This is probably because either float_value or rhs is not a number.
860       return False
861
862   def __repr__(self):
863     return str(self._float_value)
864
865 class StrContains(Comparator):
866   """Comparison class used to check whether a substring exists in a
867   string parameter.  This can be useful in mocking a database with SQL
868   passed in as a string parameter, for example.
869
870   Example:
871   mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result)
872   """
873
874   def __init__(self, search_string):
875     """Initialize.
876
877     Args:
878       # search_string: the string you are searching for
879       search_string: str
880     """
881
882     self._search_string = search_string
883
884   def equals(self, rhs):
885     """Check to see if the search_string is contained in the rhs string.
886
887     Args:
888       # rhs: the right hand side of the test
889       rhs: object
890
891     Returns:
892       bool
893     """
894
895     try:
896       return rhs.find(self._search_string) > -1
897     except Exception:
898       return False
899
900   def __repr__(self):
901     return '<str containing \'%s\'>' % self._search_string
902
903
904 class Regex(Comparator):
905   """Checks if a string matches a regular expression.
906
907   This uses a given regular expression to determine equality.
908   """
909
910   def __init__(self, pattern, flags=0):
911     """Initialize.
912
913     Args:
914       # pattern is the regular expression to search for
915       pattern: str
916       # flags passed to re.compile function as the second argument
917       flags: int
918     """
919
920     self.regex = re.compile(pattern, flags=flags)
921
922   def equals(self, rhs):
923     """Check to see if rhs matches regular expression pattern.
924
925     Returns:
926       bool
927     """
928
929     return self.regex.search(rhs) is not None
930
931   def __repr__(self):
932     s = '<regular expression \'%s\'' % self.regex.pattern
933     if self.regex.flags:
934       s += ', flags=%d' % self.regex.flags
935     s += '>'
936     return s
937
938
939 class In(Comparator):
940   """Checks whether an item (or key) is in a list (or dict) parameter.
941
942   Example:
943   mock_dao.GetUsersInfo(In('expectedUserName')).AndReturn(mock_result)
944   """
945
946   def __init__(self, key):
947     """Initialize.
948
949     Args:
950       # key is any thing that could be in a list or a key in a dict
951     """
952
953     self._key = key
954
955   def equals(self, rhs):
956     """Check to see whether key is in rhs.
957
958     Args:
959       rhs: dict
960
961     Returns:
962       bool
963     """
964
965     return self._key in rhs
966
967   def __repr__(self):
968     return '<sequence or map containing \'%s\'>' % self._key
969
970
971 class ContainsKeyValue(Comparator):
972   """Checks whether a key/value pair is in a dict parameter.
973
974   Example:
975   mock_dao.UpdateUsers(ContainsKeyValue('stevepm', stevepm_user_info))
976   """
977
978   def __init__(self, key, value):
979     """Initialize.
980
981     Args:
982       # key: a key in a dict
983       # value: the corresponding value
984     """
985
986     self._key = key
987     self._value = value
988
989   def equals(self, rhs):
990     """Check whether the given key/value pair is in the rhs dict.
991
992     Returns:
993       bool
994     """
995
996     try:
997       return rhs[self._key] == self._value
998     except Exception:
999       return False
1000
1001   def __repr__(self):
1002     return '<map containing the entry \'%s: %s\'>' % (self._key, self._value)
1003
1004
1005 class SameElementsAs(Comparator):
1006   """Checks whether iterables contain the same elements (ignoring order).
1007
1008   Example:
1009   mock_dao.ProcessUsers(SameElementsAs('stevepm', 'salomaki'))
1010   """
1011
1012   def __init__(self, expected_seq):
1013     """Initialize.
1014
1015     Args:
1016       expected_seq: a sequence
1017     """
1018
1019     self._expected_seq = expected_seq
1020
1021   def equals(self, actual_seq):
1022     """Check to see whether actual_seq has same elements as expected_seq.
1023
1024     Args:
1025       actual_seq: sequence
1026
1027     Returns:
1028       bool
1029     """
1030
1031     try:
1032       expected = dict([(element, None) for element in self._expected_seq])
1033       actual = dict([(element, None) for element in actual_seq])
1034     except TypeError:
1035       # Fall back to slower list-compare if any of the objects are unhashable.
1036       expected = list(self._expected_seq)
1037       actual = list(actual_seq)
1038       expected.sort()
1039       actual.sort()
1040     return expected == actual
1041
1042   def __repr__(self):
1043     return '<sequence with same elements as \'%s\'>' % self._expected_seq
1044
1045
1046 class And(Comparator):
1047   """Evaluates one or more Comparators on RHS and returns an AND of the results.
1048   """
1049
1050   def __init__(self, *args):
1051     """Initialize.
1052
1053     Args:
1054       *args: One or more Comparator
1055     """
1056
1057     self._comparators = args
1058
1059   def equals(self, rhs):
1060     """Checks whether all Comparators are equal to rhs.
1061
1062     Args:
1063       # rhs: can be anything
1064
1065     Returns:
1066       bool
1067     """
1068
1069     for comparator in self._comparators:
1070       if not comparator.equals(rhs):
1071         return False
1072
1073     return True
1074
1075   def __repr__(self):
1076     return '<AND %s>' % str(self._comparators)
1077
1078
1079 class Or(Comparator):
1080   """Evaluates one or more Comparators on RHS and returns an OR of the results.
1081   """
1082
1083   def __init__(self, *args):
1084     """Initialize.
1085
1086     Args:
1087       *args: One or more Mox comparators
1088     """
1089
1090     self._comparators = args
1091
1092   def equals(self, rhs):
1093     """Checks whether any Comparator is equal to rhs.
1094
1095     Args:
1096       # rhs: can be anything
1097
1098     Returns:
1099       bool
1100     """
1101
1102     for comparator in self._comparators:
1103       if comparator.equals(rhs):
1104         return True
1105
1106     return False
1107
1108   def __repr__(self):
1109     return '<OR %s>' % str(self._comparators)
1110
1111
1112 class Func(Comparator):
1113   """Call a function that should verify the parameter passed in is correct.
1114
1115   You may need the ability to perform more advanced operations on the parameter
1116   in order to validate it.  You can use this to have a callable validate any
1117   parameter. The callable should return either True or False.
1118
1119
1120   Example:
1121
1122   def myParamValidator(param):
1123     # Advanced logic here
1124     return True
1125
1126   mock_dao.DoSomething(Func(myParamValidator), true)
1127   """
1128
1129   def __init__(self, func):
1130     """Initialize.
1131
1132     Args:
1133       func: callable that takes one parameter and returns a bool
1134     """
1135
1136     self._func = func
1137
1138   def equals(self, rhs):
1139     """Test whether rhs passes the function test.
1140
1141     rhs is passed into func.
1142
1143     Args:
1144       rhs: any python object
1145
1146     Returns:
1147       the result of func(rhs)
1148     """
1149
1150     return self._func(rhs)
1151
1152   def __repr__(self):
1153     return str(self._func)
1154
1155
1156 class IgnoreArg(Comparator):
1157   """Ignore an argument.
1158
1159   This can be used when we don't care about an argument of a method call.
1160
1161   Example:
1162   # Check if CastMagic is called with 3 as first arg and 'disappear' as third.
1163   mymock.CastMagic(3, IgnoreArg(), 'disappear')
1164   """
1165
1166   def equals(self, unused_rhs):
1167     """Ignores arguments and returns True.
1168
1169     Args:
1170       unused_rhs: any python object
1171
1172     Returns:
1173       always returns True
1174     """
1175
1176     return True
1177
1178   def __repr__(self):
1179     return '<IgnoreArg>'
1180
1181
1182 class MethodGroup(object):
1183   """Base class containing common behaviour for MethodGroups."""
1184
1185   def __init__(self, group_name):
1186     self._group_name = group_name
1187
1188   def group_name(self):
1189     return self._group_name
1190
1191   def __str__(self):
1192     return '<%s "%s">' % (self.__class__.__name__, self._group_name)
1193
1194   def AddMethod(self, mock_method):
1195     raise NotImplementedError
1196
1197   def MethodCalled(self, mock_method):
1198     raise NotImplementedError
1199
1200   def IsSatisfied(self):
1201     raise NotImplementedError
1202
1203 class UnorderedGroup(MethodGroup):
1204   """UnorderedGroup holds a set of method calls that may occur in any order.
1205
1206   This construct is helpful for non-deterministic events, such as iterating
1207   over the keys of a dict.
1208   """
1209
1210   def __init__(self, group_name):
1211     super(UnorderedGroup, self).__init__(group_name)
1212     self._methods = []
1213
1214   def AddMethod(self, mock_method):
1215     """Add a method to this group.
1216
1217     Args:
1218       mock_method: A mock method to be added to this group.
1219     """
1220
1221     self._methods.append(mock_method)
1222
1223   def MethodCalled(self, mock_method):
1224     """Remove a method call from the group.
1225
1226     If the method is not in the set, an UnexpectedMethodCallError will be
1227     raised.
1228
1229     Args:
1230       mock_method: a mock method that should be equal to a method in the group.
1231
1232     Returns:
1233       The mock method from the group
1234
1235     Raises:
1236       UnexpectedMethodCallError if the mock_method was not in the group.
1237     """
1238
1239     # Check to see if this method exists, and if so, remove it from the set
1240     # and return it.
1241     for method in self._methods:
1242       if method == mock_method:
1243         # Remove the called mock_method instead of the method in the group.
1244         # The called method will match any comparators when equality is checked
1245         # during removal.  The method in the group could pass a comparator to
1246         # another comparator during the equality check.
1247         self._methods.remove(mock_method)
1248
1249         # If this group is not empty, put it back at the head of the queue.
1250         if not self.IsSatisfied():
1251           mock_method._call_queue.appendleft(self)
1252
1253         return self, method
1254
1255     raise UnexpectedMethodCallError(mock_method, self)
1256
1257   def IsSatisfied(self):
1258     """Return True if there are not any methods in this group."""
1259
1260     return len(self._methods) == 0
1261
1262
1263 class MultipleTimesGroup(MethodGroup):
1264   """MultipleTimesGroup holds methods that may be called any number of times.
1265
1266   Note: Each method must be called at least once.
1267
1268   This is helpful, if you don't know or care how many times a method is called.
1269   """
1270
1271   def __init__(self, group_name):
1272     super(MultipleTimesGroup, self).__init__(group_name)
1273     self._methods = set()
1274     self._methods_called = set()
1275
1276   def AddMethod(self, mock_method):
1277     """Add a method to this group.
1278
1279     Args:
1280       mock_method: A mock method to be added to this group.
1281     """
1282
1283     self._methods.add(mock_method)
1284
1285   def MethodCalled(self, mock_method):
1286     """Remove a method call from the group.
1287
1288     If the method is not in the set, an UnexpectedMethodCallError will be
1289     raised.
1290
1291     Args:
1292       mock_method: a mock method that should be equal to a method in the group.
1293
1294     Returns:
1295       The mock method from the group
1296
1297     Raises:
1298       UnexpectedMethodCallError if the mock_method was not in the group.
1299     """
1300
1301     # Check to see if this method exists, and if so add it to the set of
1302     # called methods.
1303
1304     for method in self._methods:
1305       if method == mock_method:
1306         self._methods_called.add(mock_method)
1307         # Always put this group back on top of the queue, because we don't know
1308         # when we are done.
1309         mock_method._call_queue.appendleft(self)
1310         return self, method
1311
1312     if self.IsSatisfied():
1313       next_method = mock_method._PopNextMethod();
1314       return next_method, None
1315     else:
1316       raise UnexpectedMethodCallError(mock_method, self)
1317
1318   def IsSatisfied(self):
1319     """Return True if all methods in this group are called at least once."""
1320     # NOTE(psycho): We can't use the simple set difference here because we want
1321     # to match different parameters which are considered the same e.g. IsA(str)
1322     # and some string. This solution is O(n^2) but n should be small.
1323     tmp = self._methods.copy()
1324     for called in self._methods_called:
1325       for expected in tmp:
1326         if called == expected:
1327           tmp.remove(expected)
1328           if not tmp:
1329             return True
1330           break
1331     return False
1332
1333
1334 class MoxMetaTestBase(type):
1335   """Metaclass to add mox cleanup and verification to every test.
1336
1337   As the mox unit testing class is being constructed (MoxTestBase or a
1338   subclass), this metaclass will modify all test functions to call the
1339   CleanUpMox method of the test class after they finish. This means that
1340   unstubbing and verifying will happen for every test with no additional code,
1341   and any failures will result in test failures as opposed to errors.
1342   """
1343
1344   def __init__(cls, name, bases, d):
1345     type.__init__(cls, name, bases, d)
1346
1347     # also get all the attributes from the base classes to account
1348     # for a case when test class is not the immediate child of MoxTestBase
1349     for base in bases:
1350       for attr_name in dir(base):
1351         d[attr_name] = getattr(base, attr_name)
1352
1353     for func_name, func in d.items():
1354       if func_name.startswith('test') and callable(func):
1355         setattr(cls, func_name, MoxMetaTestBase.CleanUpTest(cls, func))
1356
1357   @staticmethod
1358   def CleanUpTest(cls, func):
1359     """Adds Mox cleanup code to any MoxTestBase method.
1360
1361     Always unsets stubs after a test. Will verify all mocks for tests that
1362     otherwise pass.
1363
1364     Args:
1365       cls: MoxTestBase or subclass; the class whose test method we are altering.
1366       func: method; the method of the MoxTestBase test class we wish to alter.
1367
1368     Returns:
1369       The modified method.
1370     """
1371     def new_method(self, *args, **kwargs):
1372       mox_obj = getattr(self, 'mox', None)
1373       cleanup_mox = False
1374       if mox_obj and isinstance(mox_obj, Mox):
1375         cleanup_mox = True
1376       try:
1377         func(self, *args, **kwargs)
1378       finally:
1379         if cleanup_mox:
1380           mox_obj.UnsetStubs()
1381       if cleanup_mox:
1382         mox_obj.VerifyAll()
1383     new_method.__name__ = func.__name__
1384     new_method.__doc__ = func.__doc__
1385     new_method.__module__ = func.__module__
1386     return new_method
1387
1388
1389 class MoxTestBase(unittest.TestCase):
1390   """Convenience test class to make stubbing easier.
1391
1392   Sets up a "mox" attribute which is an instance of Mox - any mox tests will
1393   want this. Also automatically unsets any stubs and verifies that all mock
1394   methods have been called at the end of each test, eliminating boilerplate
1395   code.
1396   """
1397
1398   __metaclass__ = MoxMetaTestBase
1399
1400   def setUp(self):
1401     self.mox = Mox()