2 # Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
3 # Use of this source code is governed by a BSD-style license that can be
4 # found in the LICENSE file.
6 """Contains functionality used to implement a partial mock."""
14 from chromite.lib import cros_build_lib
15 from chromite.lib import osutils
18 class Comparator(object):
19 """Base class for argument matchers; subclasses must implement Match()."""
22 """Match the comparator against an argument."""  # subclasses return a truthy value when the argument is accepted
23 raise NotImplementedError, 'method must be implemented by a subclass.'  # Python 2 raise syntax; abstract hook
25 def Equals(self, rhs):
26 """Returns whether rhs is a comparator of the same type and configuration."""
27 return type(self) == type(rhs) and self.__dict__ == rhs.__dict__  # compares stored state only; never calls Match()
29 def __eq__(self, rhs):
30 return self.Equals(rhs)  # == means "same comparator spec", not "argument matches"
32 def __ne__(self, rhs):
33 return not self.Equals(rhs)  # Python 2 does not derive != from __eq__, so define it explicitly
37 """Checks whether an item (or key) is in a list (or dict) parameter."""
39 def __init__(self, key):
43 key: Anything that could be an element of a list or a key in a dict.
45 Comparator.__init__(self)
50 return self._key in arg  # plain `in` test: list element, dict key (or substring when arg is a str)
55 return '<sequence or map containing %r>' % str(self._key)  # NOTE(review): str() makes 5 and '5' print alike; consider % (self._key,)
58 class Regex(Comparator):
59 """Checks if a string matches a regular expression (anywhere in the string)."""
61 def __init__(self, pattern, flags=0):
65 pattern: The regular expression to search for (re.search semantics).
66 flags: Passed to re.compile() as its |flags| keyword argument.
68 Comparator.__init__(self)
69 self.pattern = pattern
71 self.regex = re.compile(pattern, flags=flags)  # compiled once and reused on every match
75 return self.regex.search(arg) is not None  # search(), not match(): an unanchored hit is enough
80 s = '<regular expression %r' % self.regex.pattern
82 s += ', flags=%d' % self.regex.flags
87 class ListRegex(Regex):
88 """Checks if an iterable of strings matches a regular expression."""
92 if not isinstance(arg, basestring):  # NOTE(review): non-string args are presumably joined into one string here (branch body not shown) - confirm
98 return self.regex.search(self._ProcessArg(arg)) is not None  # match against the processed (string) form of arg
103 class Ignore(Comparator):
104 """Wildcard comparator: used when we don't care about an argument of a method call."""
106 def Match(self, _arg):  # leading underscore conventionally marks the argument as unused (body not shown here)
113 def _RecursiveCompare(lhs, rhs):
114 """Compare parameter specs recursively.
117 lhs: Left Hand Side parameter spec to compare; may contain Comparators.
118 rhs: Right Hand Side value (e.g. the actual call arguments) to compare.
119 Comparator instances in lhs are applied to rhs via Match(); tuples and
120 lists must agree in type and length and then match element by element.
122 if isinstance(lhs, Comparator):
123 return lhs.Match(rhs)  # lhs is the spec, so the comparator decides
124 elif isinstance(lhs, (tuple, list)):
125 return (type(lhs) == type(rhs) and  # exact type: a list spec never matches a tuple value
126 len(lhs) == len(rhs) and  # needed because zip() silently truncates to the shorter input
127 all(_RecursiveCompare(i, j) for i, j in zip(lhs, rhs)))
128 elif isinstance(lhs, dict):
129 return _RecursiveCompare(sorted(lhs.iteritems()), sorted(rhs.iteritems()))  # keys are unique, so sorting pairs entries by key; NOTE(review): AttributeError if rhs is not a dict
134 def ListContains(small, big, strict=False):
135 """Looks for a sublist within a bigger list.
138 small: The sublist to search for; items may be Comparator instances.
139 big: The list to search in.
140 strict: If True, the items of small must be adjacent in big; otherwise they only need to appear in the same relative order.
143 for i in xrange(len(big) - len(small) + 1):  # slide a len(small)-wide window over big
144 if _RecursiveCompare(small, big[i:i + len(small)]):  # NOTE(review): type-sensitive, so a tuple small never matches a list big in strict mode
149 for i in xrange(len(small)):
150 for j in xrange(j, len(big)):  # resume scanning where the previous item stopped (j is initialised on a line not shown)
151 if _RecursiveCompare(small[i], big[j]):
159 def DictContains(small, big):
160 """Looks for a subset within a dictionary.
163 small: The sub-dict to search for; values may be Comparator instances.
164 big: The dict to search in; keys not present in small are ignored.
166 for k, v in small.iteritems():
167 if k not in big or not _RecursiveCompare(v, big[k]):  # every key of small must exist in big with a matching value
172 class MockedCallResults(object):
173 """Implements internal result specification for partial mocks.
175 Used with the PartialMock class.
177 Internal results are different from external results (return values,
178 side effects, exceptions, etc.) for functions. Internal results are
179 *used* by the partial mock to generate external results. Often internal
180 results represent the external results of the dependencies of the function
181 being partially mocked. Of course, the partial mock can just pass through
182 the internal results to become external results.
185 Params = collections.namedtuple('Params', ['args', 'kwargs'])  # (positional tuple, keyword dict) of a call or of a call spec
186 MockedCall = collections.namedtuple(
187 'MockedCall', ['params', 'strict', 'result', 'side_effect'])  # one recorded entry, see AddResultForParams()
189 def __init__(self, name):
193 name: The name given to the mock. Used in debug logging and assertion messages.
196 self.mocked_calls = []  # MockedCall entries; re-adding identical params replaces the old entry
197 self.default_result, self.default_side_effect = None, None  # fallback used by LookupResult() when nothing matches
200 def AssertArgs(args, kwargs):
201 """Verify args is a tuple and kwargs is a dict (assert-based; skipped under python -O)."""
202 assert isinstance(args, (tuple))  # `(tuple)` is just tuple, not a 1-tuple; lists are rejected
204 assert isinstance(kwargs, dict)  # NOTE(review): callers may pass kwargs=None, so the hidden line above presumably guards this check
206 def AddResultForParams(self, args, result, kwargs=None, side_effect=None,
208 """Record the internal results of a given partial mock call.
211 args: A tuple containing the positional args an invocation must have for
212 it to match the internal result. The tuple can contain instances of
213 meta-args (such as IgnoreArg, Regex, In, etc.). Positional argument
214 matching is always *strict*, meaning extra positional arguments in
215 the invocation are not allowed.
216 result: The internal result that will be matched for the command
217 invocation specified.
218 kwargs: A dictionary containing the keyword args an invocation must have
219 for it to match the internal result. The dictionary can contain
220 instances of meta-args (such as IgnoreArg, Regex, In, etc.). Keyword
221 argument matching is by default *strict*, but can be modified by the
223 side_effect: A functor that gets called every time a partially mocked
224 function is invoked. The arguments the partial mock is invoked with are
225 passed to the functor. This is similar to how side effects work for
227 strict: Specifies whether keyword args are matched strictly. With strict
228 matching turned on, any keyword args a partial mock is invoked with that
229 are not specified in |kwargs| will cause the match to fail.
231 self.AssertArgs(args, kwargs)  # args must be a tuple (asserted)
235 params = self.Params(args=args, kwargs=kwargs)
236 dup, filtered = cros_build_lib.PredicateSplit(
237 lambda mc: mc.params == params, self.mocked_calls)  # == compares Comparators via Equals() (same spec), not Match()
239 new = self.MockedCall(params=params, strict=strict, result=result,
240 side_effect=side_effect)
242 self.mocked_calls = filtered  # drop any entry previously recorded for identical params
245 logging.debug('%s: replacing mock for arguments %r:\n%r -> %r',
246 self.name, params, dup, new)
248 def SetDefaultResult(self, result, side_effect=None):
249 """Set the default result for an unmatched partial mock call.
252 result: See AddResultForParams.
253 side_effect: See AddResultForParams.
255 self.default_result, self.default_side_effect = result, side_effect  # NOTE: (None, None) is indistinguishable from "no default set"
257 def LookupResult(self, args, kwargs=None, hook_args=None, hook_kwargs=None):
258 """For a given mocked function call lookup the recorded internal results.
261 args: A tuple containing positional args the function was called with.
262 kwargs: A dict containing keyword args the function was called with.
263 hook_args: A list of positional args to call the side_effect hook with.
264 hook_kwargs: A dict of key/value args to call the side_effect hook with.
267 The recorded result for the invocation.
270 AssertionError when the call is not mocked, or when there is more
271 than one mock that matches.
275 return _RecursiveCompare(mc.params, params)  # strict: args and the complete kwargs dict must match the spec
277 return (DictContains(mc.params.kwargs, kwargs) and  # non-strict: extra kwargs in the call are allowed...
278 _RecursiveCompare(mc.params.args, args))  # ...but positional args must still match exactly
280 self.AssertArgs(args, kwargs)
284 params = self.Params(args, kwargs)
285 matched, _ = cros_build_lib.PredicateSplit(filter_fn, self.mocked_calls)  # every recorded entry this call satisfies
287 raise AssertionError(
288 "%s: args %r matches more than one mock:\n%s"
289 % (self.name, params, '\n'.join([repr(c) for c in matched])))
291 side_effect, result = matched[0].side_effect, matched[0].result
292 elif (self.default_result, self.default_side_effect) != (None, None):  # fall back to SetDefaultResult(), if one was set
293 side_effect, result = self.default_side_effect, self.default_result
295 raise AssertionError("%s: %r not mocked!" % (self.name, params))
298 assert(hook_args is not None)  # invoking a side_effect requires hook args; assert is skipped under -O
299 assert(hook_kwargs is not None)
300 hook_result = side_effect(*hook_args, **hook_kwargs)
301 if hook_result is not None:  # NOTE(review): a non-None side_effect return presumably overrides result (following lines not shown)
306 class PartialMock(object):  # subclasses set TARGET ('module.path.Class') and ATTRS (names of attributes on it to mock)
307 """Provides functionality for partially mocking out a function or method.
309 Partial mocking is useful in cases where the side effects of a function or
310 method are complex, and so re-using the logic of the function with
311 *dependencies* mocked out is preferred over mocking out the entire function
312 and re-implementing the side effect (return value, state modification) logic
313 in the test. It is also useful for creating re-usable mocks.
319 def __init__(self, create_tempdir=False):
323 create_tempdir: If set to True, the partial mock will create its own
324 temporary directory when start() is called, and will set self.tempdir to
325 the path of the directory. The directory is deleted when stop() is
331 self.external_patchers = []  # patchers handed to StartPatcher(); stopped by stop()
332 self.create_tempdir = create_tempdir
334 # Set when start() is called.
335 self._tempdir_obj = None
337 self.__saved_env__ = None  # os.environ snapshot taken by start() and restored by stop()
342 if not all([self.TARGET, self.ATTRS]) and any([self.TARGET, self.ATTRS]):  # exactly one of TARGET/ATTRS is set
343 raise AssertionError('TARGET=%r but ATTRS=%r!'
344 % (self.TARGET, self.ATTRS))
346 if self.ATTRS is not None:
347 for attr in self.ATTRS:
348 self._results[attr] = MockedCallResults(attr)  # one result table per mocked attribute, named after it
353 def __exit__(self, exc_type, exc_value, traceback):  # context-manager exit; body not shown here (presumably calls stop()) - confirm
357 """Hook called early in start(). Child classes can override this.
359 If __init__ was called with |create_tempdir| set, then self.tempdir will
360 point to an existing temporary directory when this function is called.
364 """Hook run by stop() after os.environ is restored but before patchers stop.
366 If __init__ was called with |create_tempdir| set, then self.tempdir will
367 not be deleted until after this function returns. Child classes can override this.
370 def StartPatcher(self, patcher):
371 """Start |patcher| and return what its start() returns; PartialMock's stop() stops it."""
372 self.external_patchers.append(patcher)  # registered before start(), so it is tracked even if start() raises
373 return patcher.start()
375 def PatchObject(self, *args, **kwargs):
376 """Create and start a mock.patch.object().
378 The patcher is stopped automatically when this PartialMock's stop() runs.
380 return self.StartPatcher(mock.patch.object(*args, **kwargs))  # returns the started patcher's mock
383 if not all([self.TARGET, self.ATTRS]):  # nothing to patch unless both TARGET and ATTRS are set
386 chunks = self.TARGET.rsplit('.', 1)  # 'pkg.module.Class' -> ['pkg.module', 'Class']
387 module = cros_build_lib.load_module(chunks[0])
389 cls = getattr(module, chunks[1])
390 for attr in self.ATTRS:
391 self.backup[attr] = getattr(cls, attr)  # original attribute; used below to detect nesting and pick the patch style
392 src_attr = '_target%s' % attr if attr.startswith('__') else attr  # '__x' is implemented on self as '_target__x' (presumably to avoid name mangling)
393 if hasattr(self.backup[attr], 'reset_mock'):  # the attribute is already a mock object
394 raise AssertionError(
395 'You are trying to nest mock contexts - this is currently '
396 'unsupported by PartialMock.')
397 if callable(self.backup[attr]):  # callables: autospec'd mock whose side_effect is our implementation
398 patcher = mock.patch.object(cls, attr, autospec=True,
399 side_effect=getattr(self, src_attr))
401 patcher = mock.patch.object(cls, attr, getattr(self, src_attr))  # non-callables are replaced by our attribute outright
402 self.patched[attr] = patcher.start()  # keep the mock so call_count/call_args_list can inspect it
403 self.patchers[attr] = patcher  # stopped by stop() or UnMockAttr()
408 """Activates the mock context (snapshots os.environ, optionally creates self.tempdir)."""
410 self.__saved_env__ = os.environ.copy()  # snapshot restored by stop()
412 if self.create_tempdir:
413 self._tempdir_obj = osutils.TempDir(set_global=True)  # NOTE(review): set_global presumably makes this the process default tempdir - confirm in osutils
414 self.tempdir = self._tempdir_obj.tempdir  # removed again by stop()
424 """Restores namespace to the unmocked state (environment, patchers, temp dir)."""
426 if self.__saved_env__ is not None:  # only if start() took a snapshot
427 osutils.SetEnvironment(self.__saved_env__)
429 tasks = ([self.PreStop] + [p.stop for p in self.patchers.itervalues()] +  # PreStop is listed before any patcher is stopped
430 [p.stop for p in self.external_patchers])
431 if self._tempdir_obj is not None:
432 tasks += [self._tempdir_obj.Cleanup]  # temp dir is removed last
433 cros_build_lib.SafeRun(tasks)  # NOTE(review): presumably runs every task even if one raises - confirm SafeRun semantics
436 self.tempdir, self._tempdir_obj = None, None
438 def UnMockAttr(self, attr):
439 """Stop mocking |attr|, restoring the original attribute/function."""
440 self.patchers.pop(attr).stop()  # popped so a later stop() does not stop it twice; self.patched[attr] is left in place
444 """Automatically set mock_attr based on class default.
446 This function decorator automatically sets the mock_attr keyword argument
447 based on the class default. The mock_attr specifies which mocked attribute
448 a given function is referring to.
450 Raises an AssertionError if mock_attr is left unspecified and the class has no DEFAULT_ATTR.
452 def new_f(self, *args, **kwargs):
453 mock_attr = kwargs.pop('mock_attr', None)  # NOTE(review): only a keyword mock_attr is seen; a positional one would collide with the keyword passed below
454 if mock_attr is None:
455 mock_attr = self.DEFAULT_ATTR
456 if self.DEFAULT_ATTR is None:
457 raise AssertionError(
458 'mock_attr not specified, and no default configured.')
459 return f(self, *args, mock_attr=mock_attr, **kwargs)  # the resolved attr is always passed by keyword
463 class PartialCmdMock(PartialMock):
464 """Base class for mocking functions that wrap command line functionality.
466 Implements mocking for functions that shell out. The internal results are
467 'returncode', 'output', 'error'.
470 CmdResult = collections.namedtuple(
471 'MockResult', ['returncode', 'output', 'error'])  # NOTE: typename 'MockResult' (not CmdResult) is what repr() shows
476 def SetDefaultCmdResult(self, returncode=0, output='', error='',
477 side_effect=None, mock_attr=None):
478 """Specify the default command result if no command is matched.
481 returncode: See AddCmdResult.
482 output: See AddCmdResult.
483 error: See AddCmdResult.
484 side_effect: See MockedCallResults.AddResultForParams
485 mock_attr: Which attribute's mock is being referenced.
487 result = self.CmdResult(returncode, output, error)
488 self._results[mock_attr].SetDefaultResult(result, side_effect)  # used for any call with no matching AddCmdResult() entry
491 def AddCmdResult(self, cmd, returncode=0, output='', error='',
492 kwargs=None, strict=False, side_effect=None, mock_attr=None):
493 """Specify the result to simulate for a given command.
496 cmd: The command string or list to record a result for; may contain comparators such as Regex or Ignore.
497 returncode: The returncode of the command (on the command line).
498 output: The stdout output of the command.
499 error: The stderr output of the command.
500 kwargs: Keyword arguments that the function needs to be invoked with.
501 strict: Defaults to False (extra keyword args in the call are then allowed). See MockedCallResults.AddResultForParams.
502 side_effect: See MockedCallResults.AddResultForParams
503 mock_attr: Which attribute's mock is being referenced.
505 result = self.CmdResult(returncode, output, error)
506 self._results[mock_attr].AddResultForParams(
507 (cmd,), result, kwargs=kwargs, side_effect=side_effect, strict=strict)  # cmd is the sole positional-arg spec
510 def CommandContains(self, args, cmd_arg_index=-1, mock_attr=None, **kwargs):
511 """Verify that at least one command contains the specified args.
514 args: Expected command-line arguments; they must appear in the command in order, but need not be adjacent.
515 cmd_arg_index: The index of the command list in the positional call_args.
516 Defaults to the last positional argument.
517 kwargs: Set of expected keyword arguments.
518 mock_attr: Which attribute's mock is being referenced.
520 for call_args, call_kwargs in self.patched[mock_attr].call_args_list:  # each recorded call unpacks to (args tuple, kwargs dict)
521 if (ListContains(args, call_args[cmd_arg_index]) and  # NOTE(review): IndexError if a recorded call had no positional args
522 DictContains(kwargs, call_kwargs)):  # listed kwargs must be present and match; others are ignored
527 def assertCommandContains(self, args=(), expected=True, mock_attr=None,
529 """Assert that RunCommand was called with the specified args.
531 This verifies that at least one of the RunCommand calls contains the
532 specified arguments on the command line.
535 args: Set of expected command-line arguments.
536 expected: If False, instead verify that none of the RunCommand calls
537 contained the specified arguments.
538 **kwargs: Set of expected keyword arguments.
539 mock_attr: Which attributes's mock is being referenced.
541 if bool(expected) != self.CommandContains(args, **kwargs):
543 msg = 'Expected to find %r in any of:\n%s'
545 msg = 'Expected to not find %r in any of:\n%s'
546 patched = self.patched[mock_attr]
547 cmds = '\n'.join(repr(x) for x in patched.call_args_list)
548 raise AssertionError(msg % (mock.call(args, **kwargs), cmds))
551 def assertCommandCalled(self, args=(), mock_attr=None, **kwargs):
552 """Assert that RunCommand was called with the specified args.
554 This verifies that at least one of the RunCommand calls exactly
555 matches the specified command line and misc-arguments.
558 args: The exact expected command line (the call's only positional arg).
559 mock_attr: Which attribute's mock is being referenced.
560 **kwargs: Set of expected keyword arguments.
562 call = mock.call(args, **kwargs)  # expected call: the command as the sole positional argument
563 patched = self.patched[mock_attr]
565 for icall in patched.call_args_list:  # loop body (comparison with `call`) is not shown in this excerpt
569 cmds = '\n'.join(repr(x) for x in patched.call_args_list)
570 raise AssertionError('Expected to find %r in any of:\n%s' % (call, cmds))
574 def call_count(self, mock_attr=None):
575 """Return the number of times the mocked attribute has been called."""
576 return self.patched[mock_attr].call_count  # delegates to the mock stored in self.patched
580 def call_args_list(self, mock_attr=None):
581 """Return the list of (args, kwargs) calls the mocked attribute received."""
582 return self.patched[mock_attr].call_args_list  # delegates to the mock stored in self.patched