2 # Copyright 2013 The Swarming Authors. All rights reserved.
3 # Use of this source code is governed under the Apache License, Version 2.0 that
4 # can be found in the LICENSE file.
6 # Disable pylint W0108 ("Lambda may not be necessary"); many tasks below are small lambdas.
7 # pylint: disable=W0108
# Put the parent of this file's directory at the front of sys.path so that
# top-level packages there (such as `utils`) are importable ahead of any other
# sys.path entries.
19 ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
20 sys.path.insert(0, ROOT_DIR)
22 from utils import threading_utils
def timeout(max_running_time):
  """Test method decorator that fails the test if it executes longer
  than |max_running_time| seconds.

  It exists to terminate tests in case of deadlocks. There's a high chance that
  process is broken after such timeout (due to hanging deadlocked threads that
  can own some shared resources). But failing early (maybe not in a cleanest
  way) due to timeout is generally better than hanging indefinitely.

  |max_running_time| should be an order of magnitude (or even two orders) larger
  than the expected run time of the test to compensate for slow machine, high
  CPU utilization by some other processes, etc.

  Can not be nested: ITIMER_REAL and the SIGALRM handler are process-wide, so an
  inner decorated call would replace the outer timer.

  Noop on windows (since win32 doesn't support signal.setitimer).

  Args:
    max_running_time: timeout in seconds (int or float), handed to
        signal.setitimer(signal.ITIMER_REAL, ...).

  Returns:
    A decorator for unittest.TestCase methods. While the wrapped method runs, a
    SIGALRM timer is armed; if it fires, self.fail('Timeout') is raised from the
    signal handler (Python runs signal handlers in the main thread). On exit,
    normal or exceptional, the timer is disarmed and the previous SIGALRM
    handler is restored.
  """
  if sys.platform == 'win32':
    return lambda method: method

  def decorator(method):
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
      # signal.signal() returns the previously installed handler; it is None
      # when that handler was not installed from Python, which signal.signal()
      # would not accept back, so fall back to the default action.
      previous_handler = signal.signal(
          signal.SIGALRM, lambda *_args: self.fail('Timeout'))
      if previous_handler is None:
        previous_handler = signal.SIG_DFL
      signal.setitimer(signal.ITIMER_REAL, max_running_time)
      try:
        return method(self, *args, **kwargs)
      finally:
        # Disarm the timer *before* touching the handler: restoring SIG_DFL
        # first leaves a window where a late SIGALRM would kill the process.
        signal.setitimer(signal.ITIMER_REAL, 0)
        signal.signal(signal.SIGALRM, previous_handler)
    return wrapper

  return decorator
# Tests for threading_utils.ThreadPool: draining results with join(),
# get_one_result() and iter_results(); tasks that queue more tasks; tasks that
# return None; generator tasks; concurrent iter_results() callers; close();
# priority ordering; and abort().
60 class ThreadPoolTest(unittest.TestCase):
64 # Append custom assert messages to default ones (works with python >= 2.7).
# Factory for the basic task used by the result tests; per its docstring the
# task echoes its argument, which run_results_test() relies on.
68 def sleep_task(duration=0.01):
69 """Returns function that sleeps |duration| sec and returns its argument."""
# The returned task queues a sleep_task for the same `arg` (priority 0) from
# inside the pool. Since run_results_test() expects each argument exactly once,
# the outer task presumably returns None and its result is dropped - confirm.
75 def retrying_sleep_task(self, duration=0.01):
76 """Returns function that adds sleep_task to the thread pool."""
78 self.thread_pool.add_task(0, self.sleep_task(duration), arg)
# none_task: its tasks return None; test_none_task_and_get_one_result expects
# no result to be reported for them (get_one_result() raises ThreadPoolEmpty).
83 """Returns function that returns None."""
84 return lambda _arg: None
# setUp: a fresh pool per test, sized by self.MIN_THREADS / self.MAX_THREADS.
87 super(ThreadPoolTest, self).setUp()
88 self.thread_pool = threading_utils.ThreadPool(
89 self.MIN_THREADS, self.MAX_THREADS, 0)
# tearDown: close the per-test pool.
93 super(ThreadPoolTest, self).tearDown()
94 self.thread_pool.close()
# Result-draining strategies for run_results_test(). Each receives the expected
# results; only the get_one_result() variant uses them, to pull exactly
# len(expected) results.
96 def get_results_via_join(self, _expected):
97 return self.thread_pool.join()
99 def get_results_via_get_one_result(self, expected):
100 return [self.thread_pool.get_one_result() for _ in expected]
102 def get_results_via_iter_results(self, _expected):
103 return list(self.thread_pool.iter_results())
# Queues `task` at priority 0 for the arguments in `args` (default
# range(0, 100)), drains the pool with `results_getter`, and requires the
# results to equal `expected` (default: `args`) both as a set and by count,
# then checks the pool has nothing left. `msg` names the getter so a failure
# shows which draining strategy broke.
105 def run_results_test(self, task, results_getter, args=None, expected=None):
106 """Template function for tests checking that pool returns all results.
108 Will add multiple instances of |task| to the thread pool, then call
109 |results_getter| to get back all results and compare them to expected ones.
111 args = range(0, 100) if args is None else args
112 expected = args if expected is None else expected
113 msg = 'Using \'%s\' to get results.' % (results_getter.__name__,)
116 self.thread_pool.add_task(0, task, i)
117 results = results_getter(expected)
119 # Check that got all results back (exact same set, no duplicates).
120 self.assertEqual(set(expected), set(results), msg)
121 self.assertEqual(len(expected), len(results), msg)
123 # Queue is empty, result request should fail.
124 with self.assertRaises(threading_utils.ThreadPoolEmpty):
125 self.thread_pool.get_one_result()
# A single task's return value comes back from get_one_result().
128 def test_get_one_result_ok(self):
129 self.thread_pool.add_task(0, lambda: 'OK')
130 self.assertEqual(self.thread_pool.get_one_result(), 'OK')
133 def test_get_one_result_fail(self):
134 # No tasks added -> get_one_result raises an exception.
135 with self.assertRaises(threading_utils.ThreadPoolEmpty):
136 self.thread_pool.get_one_result()
# Drain sleep_task results via join().
140 self.run_results_test(self.sleep_task(),
141 self.get_results_via_join)
# Same check, draining via get_one_result().
144 def test_get_one_result(self):
145 self.run_results_test(self.sleep_task(),
146 self.get_results_via_get_one_result)
# Same check, draining via iter_results().
149 def test_iter_results(self):
150 self.run_results_test(self.sleep_task(),
151 self.get_results_via_iter_results)
# Tasks that queue further tasks from inside the pool: every draining strategy
# must still return each argument exactly once.
154 def test_retry_and_join(self):
155 self.run_results_test(self.retrying_sleep_task(),
156 self.get_results_via_join)
159 def test_retry_and_get_one_result(self):
160 self.run_results_test(self.retrying_sleep_task(),
161 self.get_results_via_get_one_result)
164 def test_retry_and_iter_results(self):
165 self.run_results_test(self.retrying_sleep_task(),
166 self.get_results_via_iter_results)
# Tasks that return None contribute no results.
169 def test_none_task_and_join(self):
170 self.run_results_test(self.none_task(),
171 self.get_results_via_join,
175 def test_none_task_and_get_one_result(self):
176 self.thread_pool.add_task(0, self.none_task(), 0)
177 with self.assertRaises(threading_utils.ThreadPoolEmpty):
178 self.thread_pool.get_one_result()
181 def test_none_task_and_and_iter_results(self):
182 self.run_results_test(self.none_task(),
183 self.get_results_via_iter_results,
# Every value yielded by a generator task is reported as a separate result,
# whichever draining strategy is used; the same pool (self.thread_pool) is
# reused for all three strategies.
187 def test_generator_task(self):
191 # Generator that yields [i * MULTIPLIER, i * MULTIPLIER + COUNT).
192 def generator_task(i):
193 for j in xrange(COUNT):
195 yield i * MULTIPLIER + j
197 # Arguments for tasks and expected results.
199 expected = [i * MULTIPLIER + j for i in args for j in xrange(COUNT)]
201 # Test all possible ways to pull results from the thread pool.
202 getters = (self.get_results_via_join,
203 self.get_results_via_iter_results,
204 self.get_results_via_get_one_result,)
205 for results_getter in getters:
206 self.run_results_test(generator_task, results_getter, args, expected)
# Four threads call iter_results() concurrently; together they must see every
# result exactly once.
209 def test_concurrent_iter_results(self):
210 def poller_proc(result):
211 result.extend(self.thread_pool.iter_results())
215 self.thread_pool.add_task(0, self.sleep_task(), i)
217 # Start a bunch of threads, all calling iter_results in parallel.
219 for _ in xrange(0, 4):
221 poller = threading.Thread(target=poller_proc, args=(result,))
223 pollers.append((poller, result))
225 # Collects results from all polling threads.
227 for poller, results in pollers:
229 all_results.extend(results)
231 # Check that got all results back (exact same set, no duplicates).
232 self.assertEqual(set(args), set(all_results))
233 self.assertEqual(len(args), len(all_results))
# add_task() on a closed pool raises ThreadPoolClosed.
# NOTE(review): the close() call between the two add_task() calls is not
# visible here - confirm it is present.
236 def test_adding_tasks_after_close(self):
237 pool = threading_utils.ThreadPool(1, 1, 0)
238 pool.add_task(0, lambda: None)
240 with self.assertRaises(threading_utils.ThreadPoolClosed):
241 pool.add_task(0, lambda: None)
# Closing a pool a second time is expected to raise ThreadPoolClosed.
244 def test_double_close(self):
245 pool = threading_utils.ThreadPool(1, 1, 0)
247 with self.assertRaises(threading_utils.ThreadPoolClosed):
# Lower numeric priority runs first: 'a' is queued first on a single-worker
# pool, then 'c' (priority 1) is expected to run before 'b' (priority 2).
250 def test_priority(self):
251 # Verifies that a lower priority is run first.
252 with threading_utils.ThreadPool(1, 1, 0) as pool:
253 lock = threading.Lock()
255 def wait_and_return(x):
263 pool.add_task(0, wait_and_return, 'a')
264 pool.add_task(2, return_x, 'b')
265 pool.add_task(1, return_x, 'c')
268 self.assertEqual(['a', 'c', 'b'], actual)
# abort() is expected to return 88 here - presumably the number of queued tasks
# it discarded (10 completed + 2 hanging + 88 = 100, per the comments below).
271 def test_abort(self):
272 # Trigger a ridiculous amount of tasks, and abort the remaining.
273 with threading_utils.ThreadPool(2, 2, 0) as pool:
274 # Allow 10 tasks to run initially.
275 sem = threading.Semaphore(10)
277 def grab_and_return(x):
282 pool.add_task(0, grab_and_return, i)
284 # Running at 11 would hang.
285 results = [pool.get_one_result() for _ in xrange(10)]
286 # At that point, there's 10 completed tasks and 2 tasks hanging, 88
288 self.assertEqual(88, pool.abort())
289 # Calling .join() before these 2 .release() would hang.
292 results.extend(pool.join())
293 # The results *may* be out of order. Even if the calls are processed
294 # strictly in FIFO mode, a thread may preempt another one when returning the
296 self.assertEqual(range(12), sorted(results))
# Tests for threading_utils.AutoRetryThreadPool: constructor validation,
# accepted priorities, HIGH/MED/LOW ordering, automatic retry of tasks that
# raise one of the configured exception types, and add_task_with_channel().
# The second constructor argument behaves as the retry budget here (256 is
# rejected; 2 tolerates two failed attempts); the meaning of the remaining
# positional arguments is not visible in this file.
299 class AutoRetryThreadPoolTest(unittest.TestCase):
# Entries of the exceptions list must be exception classes: a TestCase
# subclass is rejected with AssertionError.
300 def test_bad_class(self):
301 exceptions = [AutoRetryThreadPoolTest]
302 with self.assertRaises(AssertionError):
303 threading_utils.AutoRetryThreadPool(exceptions, 1, 0, 1, 0)
# An empty exceptions list is rejected.
305 def test_no_exception(self):
306 with self.assertRaises(AssertionError):
307 threading_utils.AutoRetryThreadPool([], 1, 0, 1, 0)
# A retry count of 256 is rejected.
309 def test_bad_retry(self):
310 exceptions = [IOError]
311 with self.assertRaises(AssertionError):
312 threading_utils.AutoRetryThreadPool(exceptions, 256, 0, 1, 0)
# Priorities 0, 256 and 512 are accepted; 1 and 255 raise AssertionError.
# NOTE(review): presumably the low 8 bits of a priority are reserved by the
# pool (e.g. for retry bookkeeping) - confirm against add_task().
314 def test_bad_priority(self):
315 exceptions = [IOError]
316 with threading_utils.AutoRetryThreadPool(exceptions, 1, 1, 1, 0) as pool:
317 pool.add_task(0, lambda x: x, 0)
318 pool.add_task(256, lambda x: x, 0)
319 pool.add_task(512, lambda x: x, 0)
320 with self.assertRaises(AssertionError):
321 pool.add_task(1, lambda x: x, 0)
322 with self.assertRaises(AssertionError):
323 pool.add_task(255, lambda x: x, 0)
# Named priorities: 'a' (HIGH) is queued first, then MED ('c') is expected to
# run before LOW ('b').
325 def test_priority(self):
326 # Verifies that a lower priority is run first.
327 exceptions = [IOError]
328 with threading_utils.AutoRetryThreadPool(exceptions, 1, 1, 1, 0) as pool:
329 lock = threading.Lock()
331 def wait_and_return(x):
339 pool.add_task(pool.HIGH, wait_and_return, 'a')
340 pool.add_task(pool.LOW, return_x, 'b')
341 pool.add_task(pool.MED, return_x, 'c')
344 self.assertEqual(['a', 'c', 'b'], actual)
# A subclass of a listed exception (CustomException derives from IOError) is
# retried too: `ran` is expected to hold one entry per attempt (two), while
# `actual` holds the single successful result.
346 def test_retry_inherited(self):
347 # Exception class inheritance works.
348 class CustomException(IOError):
351 def throw(to_throw, x):
354 raise to_throw.pop(0)
356 with threading_utils.AutoRetryThreadPool([IOError], 1, 1, 1, 0) as pool:
357 pool.add_task(pool.MED, throw, [CustomException('a')], 'yay')
359 self.assertEqual(['yay'], actual)
360 self.assertEqual(['yay', 'yay'], ran)
# With a retry budget of 2, two retryable failures (OSError, then IOError) are
# followed by a successful attempt.
362 def test_retry_2_times(self):
363 exceptions = [IOError, OSError]
364 to_throw = [OSError('a'), IOError('b')]
367 raise to_throw.pop(0)
369 with threading_utils.AutoRetryThreadPool(exceptions, 2, 1, 1, 0) as pool:
370 pool.add_task(pool.MED, throw, 'yay')
372 self.assertEqual(['yay'], actual)
# With a retry budget of 1, the second failure (IOError('b')) is not retried
# and is expected to surface as IOError.
374 def test_retry_too_many_times(self):
375 exceptions = [IOError, OSError]
376 to_throw = [OSError('a'), IOError('b')]
379 raise to_throw.pop(0)
381 with threading_utils.AutoRetryThreadPool(exceptions, 1, 1, 1, 0) as pool:
382 pool.add_task(pool.MED, throw, 'yay')
383 with self.assertRaises(IOError):
# The to_throw list is a task argument and each attempt pops from it, so retries
# see the mutated list: budget 1 ends in IOError here, budget 2 succeeds below.
386 def test_retry_mutation_1(self):
387 # This is to warn that mutable arguments WILL be mutated.
388 def throw(to_throw, x):
390 raise to_throw.pop(0)
392 exceptions = [IOError, OSError]
393 with threading_utils.AutoRetryThreadPool(exceptions, 1, 1, 1, 0) as pool:
394 pool.add_task(pool.MED, throw, [OSError('a'), IOError('b')], 'yay')
395 with self.assertRaises(IOError):
398 def test_retry_mutation_2(self):
399 # This is to warn that mutable arguments WILL be mutated.
400 def throw(to_throw, x):
402 raise to_throw.pop(0)
404 exceptions = [IOError, OSError]
405 with threading_utils.AutoRetryThreadPool(exceptions, 2, 1, 1, 0) as pool:
406 pool.add_task(pool.MED, throw, [OSError('a'), IOError('b')], 'yay')
408 self.assertEqual(['yay'], actual)
# Two tasks with a retry budget of 2 each make three attempts; the expected
# `ran` order shows their attempts alternating instead of one task's retries
# running back to back.
410 def test_retry_interleaved(self):
411 # Verifies that retries are interleaved. This is important, we don't want a
412 # retried task to take all the pool during retries.
413 exceptions = [IOError, OSError]
414 lock = threading.Lock()
416 with threading_utils.AutoRetryThreadPool(exceptions, 2, 1, 1, 0) as pool:
417 def lock_and_throw(to_throw, x):
421 raise to_throw.pop(0)
425 pool.MED, lock_and_throw, [OSError('a'), IOError('b')], 'A')
427 pool.MED, lock_and_throw, [OSError('a'), IOError('b')], 'B')
430 self.assertEqual(['A', 'B'], actual)
431 # Retries are properly interleaved:
432 self.assertEqual(['A', 'B', 'A', 'B', 'A', 'B'], ran)
# The task's return value is delivered through the channel.
434 def test_add_task_with_channel_success(self):
435 with threading_utils.AutoRetryThreadPool([OSError], 2, 1, 1, 0) as pool:
436 channel = threading_utils.TaskChannel()
437 pool.add_task_with_channel(channel, 0, lambda: 0)
438 self.assertEqual(0, channel.pull())
# A non-retryable exception (ValueError is not in [OSError]) is raised to the
# channel consumer.
440 def test_add_task_with_channel_fatal_error(self):
441 with threading_utils.AutoRetryThreadPool([OSError], 2, 1, 1, 0) as pool:
442 channel = threading_utils.TaskChannel()
445 pool.add_task_with_channel(channel, 0, throw, ValueError())
446 with self.assertRaises(ValueError):
# A retryable OSError is expected to reach the consumer once the retry budget
# is exhausted (presumably `throw` raises its argument on every attempt).
449 def test_add_task_with_channel_retryable_error(self):
450 with threading_utils.AutoRetryThreadPool([OSError], 2, 1, 1, 0) as pool:
451 channel = threading_utils.TaskChannel()
454 pool.add_task_with_channel(channel, 0, throw, OSError())
455 with self.assertRaises(OSError):
# The exception delivered via the channel keeps the task-side traceback: the
# formatted traceback must mention the helper's deliberately unusual name.
458 def test_add_task_with_channel_captures_stack_trace(self):
459 with threading_utils.AutoRetryThreadPool([OSError], 2, 1, 1, 0) as pool:
460 channel = threading_utils.TaskChannel()
462 def function_with_some_unusual_name():
464 function_with_some_unusual_name()
465 pool.add_task_with_channel(channel, 0, throw, OSError())
470 exc_traceback = traceback.format_exc()
471 self.assertIn('function_with_some_unusual_name', exc_traceback)
# Stand-in progress object passed to ThreadPoolWithProgress by WorkerPoolTest.
# NOTE(review): its body is not visible here - confirm it implements whatever
# progress interface ThreadPoolWithProgress calls.
474 class FakeProgress(object):
# Tests for threading_utils.ThreadPoolWithProgress, using FakeProgress as the
# progress object.
480 class WorkerPoolTest(unittest.TestCase):
# Every mapped value is collected by join(): the sorted results must equal
# range(-31, 1), i.e. the negations of 32 task arguments.
481 def test_normal(self):
482 mapper = lambda value: -value
483 progress = FakeProgress()
484 with threading_utils.ThreadPoolWithProgress(progress, 8, 8, 0) as pool:
486 pool.add_task(0, mapper, i)
487 results = pool.join()
488 self.assertEqual(range(-31, 1), sorted(results))
# An exception raised inside a task propagates out of the pool and is caught
# as FearsomeException; `task_added` presumably records that the task was
# queued before the failure surfaced - confirm.
490 def test_exception(self):
491 class FearsomeException(Exception):
494 raise FearsomeException(value)
497 progress = FakeProgress()
498 with threading_utils.ThreadPoolWithProgress(progress, 8, 8, 0) as pool:
499 pool.add_task(0, mapper, 0)
503 except FearsomeException:
504 self.assertEqual(True, task_added)
# Tests for threading_utils.TaskChannel: delivering values and exceptions from
# pool tasks to a consumer via send_result(), send_exception(), wrap_task() and
# pull(), including pull() timeouts.
507 class TaskChannelTest(unittest.TestCase):
# A value passed to send_result() is returned by pull().
508 def test_passes_simple_value(self):
509 with threading_utils.ThreadPool(1, 1, 0) as tp:
510 channel = threading_utils.TaskChannel()
511 tp.add_task(0, lambda: channel.send_result(0))
512 self.assertEqual(0, channel.pull())
# An exception *instance* passed to send_result() is returned by pull() as a
# value, not raised.
514 def test_passes_exception_value(self):
515 with threading_utils.ThreadPool(1, 1, 0) as tp:
516 channel = threading_utils.TaskChannel()
517 tp.add_task(0, lambda: channel.send_result(Exception()))
518 self.assertTrue(isinstance(channel.pull(), Exception))
# wrap_task() forwards the wrapped callable's return value to pull().
520 def test_wrap_task_passes_simple_value(self):
521 with threading_utils.ThreadPool(1, 1, 0) as tp:
522 channel = threading_utils.TaskChannel()
523 tp.add_task(0, channel.wrap_task(lambda: 0))
524 self.assertEqual(0, channel.pull())
# A returned (not raised) exception object is delivered as an ordinary value.
526 def test_wrap_task_passes_exception_value(self):
527 with threading_utils.ThreadPool(1, 1, 0) as tp:
528 channel = threading_utils.TaskChannel()
529 tp.add_task(0, channel.wrap_task(lambda: Exception()))
530 self.assertTrue(isinstance(channel.pull(), Exception))
# send_exception() takes an exc_info-style (type, value, traceback) tuple; the
# consumer (presumably via pull()) then sees CustomError raised.
532 def test_send_exception_raises_exception(self):
533 class CustomError(Exception):
535 with threading_utils.ThreadPool(1, 1, 0) as tp:
536 channel = threading_utils.TaskChannel()
537 exc_info = (CustomError, CustomError(), None)
538 tp.add_task(0, lambda: channel.send_exception(exc_info))
539 with self.assertRaises(CustomError):
# An exception raised inside a wrap_task()-wrapped callable is raised to the
# consumer.
542 def test_wrap_task_raises_exception(self):
543 class CustomError(Exception):
545 with threading_utils.ThreadPool(1, 1, 0) as tp:
546 channel = threading_utils.TaskChannel()
549 tp.add_task(0, channel.wrap_task(task_func))
550 with self.assertRaises(CustomError):
# The re-raised exception keeps the task-side traceback: the formatted
# traceback must mention the helper's deliberately unusual name.
553 def test_wrap_task_exception_captures_stack_trace(self):
554 class CustomError(Exception):
556 with threading_utils.ThreadPool(1, 1, 0) as tp:
557 channel = threading_utils.TaskChannel()
559 def function_with_some_unusual_name():
561 function_with_some_unusual_name()
562 tp.add_task(0, channel.wrap_task(task_func))
567 exc_traceback = traceback.format_exc()
568 self.assertIn('function_with_some_unusual_name', exc_traceback)
# pull(timeout=0.001) raises TaskChannel.Timeout while no result is ready; a
# subsequent blocking pull() still returns the task's value (123).
570 def test_pull_timeout(self):
571 with threading_utils.ThreadPool(1, 1, 0) as tp:
572 channel = threading_utils.TaskChannel()
574 # This test ultimately relies on the condition variable primitive
575 # provided by pthreads. There's no easy way to mock time for it.
576 # Increase this duration if the test is flaky.
579 tp.add_task(0, channel.wrap_task(task_func))
580 with self.assertRaises(threading_utils.TaskChannel.Timeout):
581 channel.pull(timeout=0.001)
582 self.assertEqual(123, channel.pull())
# A TaskChannel.Timeout raised *by the task itself* must not be confused with
# a pull() timeout, so the consumer sees RuntimeError instead.
584 def test_timeout_exception_from_task(self):
585 with threading_utils.ThreadPool(1, 1, 0) as tp:
586 channel = threading_utils.TaskChannel()
588 raise threading_utils.TaskChannel.Timeout()
589 tp.add_task(0, channel.wrap_task(task_func))
590 # 'Timeout' raised by task gets transformed into 'RuntimeError'.
591 with self.assertRaises(RuntimeError):
# Script entry point: '-v' on the command line enables DEBUG logging;
# otherwise only ERROR and above are logged.
# NOTE(review): no test-runner call (e.g. unittest.main()) is visible here -
# confirm one follows, or running this file directly executes no tests.
595 if __name__ == '__main__':
596 VERBOSE = '-v' in sys.argv
597 logging.basicConfig(level=logging.DEBUG if VERBOSE else logging.ERROR)