Upstream version 8.36.161.0
[platform/framework/web/crosswalk.git] / src / third_party / chromite / lib / partial_mock.py
1 #!/usr/bin/python
2 # Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
3 # Use of this source code is governed by a BSD-style license that can be
4 # found in the LICENSE file.
5
6 """Contains functionality used to implement a partial mock."""
7
8 import collections
9 import logging
10 import mock
11 import os
12 import re
13
14 from chromite.lib import cros_build_lib
15 from chromite.lib import osutils
16
17
18 class Comparator(object):
19   """Base class for all comparators."""
20
21   def Match(self, arg):
22     """Match the comparator against an argument."""
23     raise NotImplementedError, 'method must be implemented by a subclass.'
24
25   def Equals(self, rhs):
26     """Returns whether rhs compares the same thing."""
27     return type(self) == type(rhs) and self.__dict__ == rhs.__dict__
28
29   def __eq__(self, rhs):
30     return self.Equals(rhs)
31
32   def __ne__(self, rhs):
33     return not self.Equals(rhs)
34
35
36 class In(Comparator):
37   """Checks whether an item (or key) is in a list (or dict) parameter."""
38
39   def __init__(self, key):
40     """Initialize.
41
42     Args:
43       key: Any thing that could be in a list or a key in a dict
44     """
45     Comparator.__init__(self)
46     self._key = key
47
48   def Match(self, arg):
49     try:
50       return self._key in arg
51     except TypeError:
52       return False
53
54   def __repr__(self):
55     return '<sequence or map containing %r>' % str(self._key)
56
57
58 class Regex(Comparator):
59   """Checks if a string matches a regular expression."""
60
61   def __init__(self, pattern, flags=0):
62     """Initialize.
63
64     Args:
65       pattern: is the regular expression to search for
66       flags: passed to re.compile function as the second argument
67     """
68     Comparator.__init__(self)
69     self.pattern = pattern
70     self.flags = flags
71     self.regex = re.compile(pattern, flags=flags)
72
73   def Match(self, arg):
74     try:
75       return self.regex.search(arg) is not None
76     except TypeError:
77       return False
78
79   def __repr__(self):
80     s = '<regular expression %r' % self.regex.pattern
81     if self.regex.flags:
82       s += ', flags=%d' % self.regex.flags
83     s += '>'
84     return s
85
86
87 class ListRegex(Regex):
88   """Checks if an iterable of strings matches a regular expression."""
89
90   @staticmethod
91   def _ProcessArg(arg):
92     if not isinstance(arg, basestring):
93       return ' '.join(arg)
94     return arg
95
96   def Match(self, arg):
97     try:
98       return self.regex.search(self._ProcessArg(arg)) is not None
99     except TypeError:
100       return False
101
102
103 class Ignore(Comparator):
104   """Used when we don't care about an argument of a method call."""
105
106   def Match(self, _arg):
107     return True
108
109   def __repr__(self):
110     return '<IgnoreArg>'
111
112
113 def _RecursiveCompare(lhs, rhs):
114   """Compare parameter specs recursively.
115
116   Args:
117     lhs: Left Hand Side parameter spec to compare.
118     rhs: Right Hand Side parameter spec to compare.
119     equality: In the case of comparing Comparator objects, True means we call
120       the Equals() function.  We call Match() if set to False (default).
121   """
122   if isinstance(lhs, Comparator):
123     return lhs.Match(rhs)
124   elif isinstance(lhs, (tuple, list)):
125     return (type(lhs) == type(rhs) and
126             len(lhs) == len(rhs) and
127             all(_RecursiveCompare(i, j) for i, j in zip(lhs, rhs)))
128   elif isinstance(lhs, dict):
129     return _RecursiveCompare(sorted(lhs.iteritems()), sorted(rhs.iteritems()))
130   else:
131     return lhs == rhs
132
133
134 def ListContains(small, big, strict=False):
135   """Looks for a sublist within a bigger list.
136
137   Args:
138     small: The sublist to search for.
139     big: The list to search in.
140     strict: If True, all items in list must be adjacent.
141   """
142   if strict:
143     for i in xrange(len(big) - len(small) + 1):
144       if _RecursiveCompare(small, big[i:i + len(small)]):
145         return True
146     return False
147   else:
148     j = 0
149     for i in xrange(len(small)):
150       for j in xrange(j, len(big)):
151         if _RecursiveCompare(small[i], big[j]):
152           j += 1
153           break
154       else:
155         return False
156     return True
157
158
159 def DictContains(small, big):
160   """Looks for a subset within a dictionary.
161
162   Args:
163     small: The sub-dict to search for.
164     big: The dict to search in.
165   """
166   for k, v in small.iteritems():
167     if k not in big or not _RecursiveCompare(v, big[k]):
168       return False
169   return True
170
171
172 class MockedCallResults(object):
173   """Implements internal result specification for partial mocks.
174
175   Used with the PartialMock class.
176
177   Internal results are different from external results (return values,
178   side effects, exceptions, etc.) for functions.  Internal results are
179   *used* by the partial mock to generate external results.  Often internal
180   results represent the external results of the dependencies of the function
181   being partially mocked.  Of course, the partial mock can just pass through
182   the internal results to become external results.
183   """
184
185   Params = collections.namedtuple('Params', ['args', 'kwargs'])
186   MockedCall = collections.namedtuple(
187       'MockedCall', ['params', 'strict', 'result', 'side_effect'])
188
189   def __init__(self, name):
190     """Initialize.
191
192     Args:
193       name: The name given to the mock.  Will be used in debug output.
194     """
195     self.name = name
196     self.mocked_calls = []
197     self.default_result, self.default_side_effect = None, None
198
199   @staticmethod
200   def AssertArgs(args, kwargs):
201     """Verify arguments are of expected type."""
202     assert isinstance(args, (tuple))
203     if kwargs:
204       assert isinstance(kwargs, dict)
205
206   def AddResultForParams(self, args, result, kwargs=None, side_effect=None,
207                          strict=True):
208     """Record the internal results of a given partial mock call.
209
210     Args:
211       args: A list containing the positional args an invocation must have for
212         it to match the internal result.  The list can contain instances of
213         meta-args (such as IgnoreArg, Regex, In, etc.).  Positional argument
214         matching is always *strict*, meaning extra positional arguments in
215         the invocation are not allowed.
216       result: The internal result that will be matched for the command
217         invocation specified.
218       kwargs: A dictionary containing the keyword args an invocation must have
219         for it to match the internal result.  The dictionary can contain
220         instances of meta-args (such as IgnoreArg, Regex, In, etc.).  Keyword
221         argument matching is by default *strict*, but can be modified by the
222         |strict| argument.
223       side_effect: A functor that gets called every time a partially mocked
224         function is invoked.  The arguments the partial mock is invoked with are
225         passed to the functor.  This is similar to how side effects work for
226         mocks.
227       strict: Specifies whether keyword are matched strictly.  With strict
228         matching turned on, any keyword args a partial mock is invoked with that
229         are not specified in |kwargs| will cause the match to fail.
230     """
231     self.AssertArgs(args, kwargs)
232     if kwargs is None:
233       kwargs = {}
234
235     params = self.Params(args=args, kwargs=kwargs)
236     dup, filtered = cros_build_lib.PredicateSplit(
237         lambda mc: mc.params == params, self.mocked_calls)
238
239     new = self.MockedCall(params=params, strict=strict, result=result,
240                           side_effect=side_effect)
241     filtered.append(new)
242     self.mocked_calls = filtered
243
244     if dup:
245       logging.debug('%s: replacing mock for arguments %r:\n%r -> %r',
246                     self.name, params, dup, new)
247
248   def SetDefaultResult(self, result, side_effect=None):
249     """Set the default result for an unmatched partial mock call.
250
251     Args:
252       result: See AddResultsForParams.
253       side_effect: See AddResultsForParams.
254     """
255     self.default_result, self.default_side_effect = result, side_effect
256
257   def LookupResult(self, args, kwargs=None, hook_args=None, hook_kwargs=None):
258     """For a given mocked function call lookup the recorded internal results.
259
260     Args:
261       args: A list containing positional args the function was called with.
262       kwargs: A dict containing keyword args the function was called with.
263       hook_args: A list of positional args to call the hook with.
264       hook_kwargs: A dict of key/value args to call the hook with.
265
266     Returns:
267       The recorded result for the invocation.
268
269     Raises:
270       AssertionError when the call is not mocked, or when there is more
271       than one mock that matches.
272     """
273     def filter_fn(mc):
274       if mc.strict:
275         return _RecursiveCompare(mc.params, params)
276
277       return (DictContains(mc.params.kwargs, kwargs) and
278               _RecursiveCompare(mc.params.args, args))
279
280     self.AssertArgs(args, kwargs)
281     if kwargs is None:
282       kwargs = {}
283
284     params = self.Params(args, kwargs)
285     matched, _ = cros_build_lib.PredicateSplit(filter_fn, self.mocked_calls)
286     if len(matched) > 1:
287       raise AssertionError(
288           "%s: args %r matches more than one mock:\n%s"
289           % (self.name, params, '\n'.join([repr(c) for c in matched])))
290     elif matched:
291       side_effect, result = matched[0].side_effect, matched[0].result
292     elif (self.default_result, self.default_side_effect) != (None, None):
293       side_effect, result = self.default_side_effect, self.default_result
294     else:
295       raise AssertionError("%s: %r not mocked!" % (self.name, params))
296
297     if side_effect:
298       assert(hook_args is not None)
299       assert(hook_kwargs is not None)
300       hook_result = side_effect(*hook_args, **hook_kwargs)
301       if hook_result is not None:
302         return hook_result
303     return result
304
305
306 class PartialMock(object):
307   """Provides functionality for partially mocking out a function or method.
308
309   Partial mocking is useful in cases where the side effects of a function or
310   method are complex, and so re-using the logic of the function with
311   *dependencies* mocked out is preferred over mocking out the entire function
312   and re-implementing the side effect (return value, state modification) logic
313   in the test.  It is also useful for creating re-usable mocks.
314   """
315
316   TARGET = None
317   ATTRS = None
318
319   def __init__(self, create_tempdir=False):
320     """Initialize.
321
322     Args:
323       create_tempdir: If set to True, the partial mock will create its own
324         temporary directory when start() is called, and will set self.tempdir to
325         the path of the directory.  The directory is deleted when stop() is
326         called.
327     """
328     self.backup = {}
329     self.patchers = {}
330     self.patched = {}
331     self.external_patchers = []
332     self.create_tempdir = create_tempdir
333
334     # Set when start() is called.
335     self._tempdir_obj = None
336     self.tempdir = None
337     self.__saved_env__ = None
338     self.started = False
339
340     self._results = {}
341
342     if not all([self.TARGET, self.ATTRS]) and any([self.TARGET, self.ATTRS]):
343       raise AssertionError('TARGET=%r but ATTRS=%r!'
344                            % (self.TARGET, self.ATTRS))
345
346     if self.ATTRS is not None:
347       for attr in self.ATTRS:
348         self._results[attr] = MockedCallResults(attr)
349
350   def __enter__(self):
351     return self.start()
352
353   def __exit__(self, exc_type, exc_value, traceback):
354     self.stop()
355
356   def PreStart(self):
357     """Called at the beginning of start(). Child classes can override this.
358
359     If __init__ was called with |create_tempdir| set, then self.tempdir will
360     point to an existing temporary directory when this function is called.
361     """
362
363   def PreStop(self):
364     """Called at the beginning of stop().  Child classes can override this.
365
366     If __init__ was called with |create_tempdir| set, then self.tempdir will
367     not be deleted until after this function returns.
368     """
369
370   def StartPatcher(self, patcher):
371     """PartialMock will stop the patcher when stop() is called."""
372     self.external_patchers.append(patcher)
373     return patcher.start()
374
375   def PatchObject(self, *args, **kwargs):
376     """Create and start a mock.patch.object().
377
378     stop() will be called automatically during tearDown.
379     """
380     return self.StartPatcher(mock.patch.object(*args, **kwargs))
381
382   def _start(self):
383     if not all([self.TARGET, self.ATTRS]):
384       return
385
386     chunks = self.TARGET.rsplit('.', 1)
387     module = cros_build_lib.load_module(chunks[0])
388
389     cls = getattr(module, chunks[1])
390     for attr in self.ATTRS:
391       self.backup[attr] = getattr(cls, attr)
392       src_attr = '_target%s' % attr if attr.startswith('__') else attr
393       if hasattr(self.backup[attr], 'reset_mock'):
394         raise AssertionError(
395             'You are trying to nest mock contexts - this is currently '
396             'unsupported by PartialMock.')
397       if callable(self.backup[attr]):
398         patcher = mock.patch.object(cls, attr, autospec=True,
399                                     side_effect=getattr(self, src_attr))
400       else:
401         patcher = mock.patch.object(cls, attr, getattr(self, src_attr))
402       self.patched[attr] = patcher.start()
403       self.patchers[attr] = patcher
404
405     return self
406
407   def start(self):
408     """Activates the mock context."""
409     try:
410       self.__saved_env__ = os.environ.copy()
411       self.tempdir = None
412       if self.create_tempdir:
413         self._tempdir_obj = osutils.TempDir(set_global=True)
414         self.tempdir = self._tempdir_obj.tempdir
415
416       self.started = True
417       self.PreStart()
418       return self._start()
419     except:
420       self.stop()
421       raise
422
423   def stop(self):
424     """Restores namespace to the unmocked state."""
425     try:
426       if self.__saved_env__ is not None:
427         osutils.SetEnvironment(self.__saved_env__)
428
429       tasks = ([self.PreStop] + [p.stop for p in self.patchers.itervalues()] +
430                [p.stop for p in self.external_patchers])
431       if self._tempdir_obj is not None:
432         tasks += [self._tempdir_obj.Cleanup]
433       cros_build_lib.SafeRun(tasks)
434     finally:
435       self.started = False
436       self.tempdir, self._tempdir_obj = None, None
437
438   def UnMockAttr(self, attr):
439     """Unsetting the mock of an attribute/function."""
440     self.patchers.pop(attr).stop()
441
442
443 def CheckAttr(f):
444   """Automatically set mock_attr based on class default.
445
446   This function decorator automatically sets the mock_attr keyword argument
447   based on the class default. The mock_attr specifies which mocked attribute
448   a given function is referring to.
449
450   Raises an AssertionError if mock_attr is left unspecified.
451   """
452   def new_f(self, *args, **kwargs):
453     mock_attr = kwargs.pop('mock_attr', None)
454     if mock_attr is None:
455       mock_attr = self.DEFAULT_ATTR
456       if self.DEFAULT_ATTR is None:
457         raise AssertionError(
458             'mock_attr not specified, and no default configured.')
459     return f(self, *args, mock_attr=mock_attr, **kwargs)
460   return new_f
461
462
463 class PartialCmdMock(PartialMock):
464   """Base class for mocking functions that wrap command line functionality.
465
466   Implements mocking for functions that shell out.  The internal results are
467   'returncode', 'output', 'error'.
468   """
469
470   CmdResult = collections.namedtuple(
471     'MockResult', ['returncode', 'output', 'error'])
472
473   DEFAULT_ATTR = None
474
475   @CheckAttr
476   def SetDefaultCmdResult(self, returncode=0, output='', error='',
477                           side_effect=None, mock_attr=None):
478     """Specify the default command result if no command is matched.
479
480     Args:
481       returncode: See AddCmdResult.
482       output: See AddCmdResult.
483       error: See AddCmdResult.
484       side_effect: See MockedCallResults.AddResultForParams
485       mock_attr: Which attributes's mock is being referenced.
486     """
487     result = self.CmdResult(returncode, output, error)
488     self._results[mock_attr].SetDefaultResult(result, side_effect)
489
490   @CheckAttr
491   def AddCmdResult(self, cmd, returncode=0, output='', error='',
492                    kwargs=None, strict=False, side_effect=None, mock_attr=None):
493     """Specify the result to simulate for a given command.
494
495     Args:
496       cmd: The command string or list to record a result for.
497       returncode: The returncode of the command (on the command line).
498       output: The stdout output of the command.
499       error: The stderr output of the command.
500       kwargs: Keyword arguments that the function needs to be invoked with.
501       strict: Defaults to False.  See MockedCallResults.AddResultForParams.
502       side_effect: See MockedCallResults.AddResultForParams
503       mock_attr: Which attributes's mock is being referenced.
504     """
505     result = self.CmdResult(returncode, output, error)
506     self._results[mock_attr].AddResultForParams(
507         (cmd,), result, kwargs=kwargs, side_effect=side_effect, strict=strict)
508
509   @CheckAttr
510   def CommandContains(self, args, cmd_arg_index=-1, mock_attr=None, **kwargs):
511     """Verify that at least one command contains the specified args.
512
513     Args:
514       args: Set of expected command-line arguments.
515       cmd_arg_index: The index of the command list in the positional call_args.
516         Defaults to the last positional argument.
517       kwargs: Set of expected keyword arguments.
518       mock_attr: Which attributes's mock is being referenced.
519     """
520     for call_args, call_kwargs in self.patched[mock_attr].call_args_list:
521       if (ListContains(args, call_args[cmd_arg_index]) and
522           DictContains(kwargs, call_kwargs)):
523         return True
524     return False
525
526   @CheckAttr
527   def assertCommandContains(self, args=(), expected=True, mock_attr=None,
528                             **kwargs):
529     """Assert that RunCommand was called with the specified args.
530
531     This verifies that at least one of the RunCommand calls contains the
532     specified arguments on the command line.
533
534     Args:
535       args: Set of expected command-line arguments.
536       expected: If False, instead verify that none of the RunCommand calls
537           contained the specified arguments.
538       **kwargs: Set of expected keyword arguments.
539       mock_attr: Which attributes's mock is being referenced.
540     """
541     if bool(expected) != self.CommandContains(args, **kwargs):
542       if expected:
543         msg = 'Expected to find %r in any of:\n%s'
544       else:
545         msg = 'Expected to not find %r in any of:\n%s'
546       patched = self.patched[mock_attr]
547       cmds = '\n'.join(repr(x) for x in patched.call_args_list)
548       raise AssertionError(msg % (mock.call(args, **kwargs), cmds))
549
550   @CheckAttr
551   def assertCommandCalled(self, args=(), mock_attr=None, **kwargs):
552     """Assert that RunCommand was called with the specified args.
553
554     This verifies that at least one of the RunCommand calls exactly
555     matches the specified command line and misc-arguments.
556
557     Args:
558       args: Set of expected command-line arguments.
559       mock_attr: Which attributes's mock is being referenced.
560       **kwargs: Set of expected keyword arguments.
561     """
562     call = mock.call(args, **kwargs)
563     patched = self.patched[mock_attr]
564
565     for icall in patched.call_args_list:
566       if call == icall:
567         return
568
569     cmds = '\n'.join(repr(x) for x in patched.call_args_list)
570     raise AssertionError('Expected to find %r in any of:\n%s' % (call, cmds))
571
572   @property
573   @CheckAttr
574   def call_count(self, mock_attr=None):
575     """Return the number of times we've been called."""
576     return self.patched[mock_attr].call_count
577
578   @property
579   @CheckAttr
580   def call_args_list(self, mock_attr=None):
581     """Return the list of args we've been called with."""
582     return self.patched[mock_attr].call_args_list