1 # Copyright 2013 The Swarming Authors. All rights reserved.
2 # Use of this source code is governed under the Apache License, Version 2.0 that
3 # can be found in the LICENSE file.
5 """Classes and functions related to threading."""
18 class LockWithAssert(object):
19 """Wrapper around (non recursive) Lock that tracks its owner."""
22 self._lock = threading.Lock()
27 assert self._owner is None
28 self._owner = threading.current_thread()
30 def __exit__(self, _exc_type, _exec_value, _traceback):
31 self.assert_locked('Releasing unowned lock')
def assert_locked(self, msg=None):
  """Checks that the calling thread is the recorded owner of the lock.

  Arguments:
    msg: optional message attached to the AssertionError on failure.
  """
  caller = threading.current_thread()
  assert self._owner == caller, msg
41 class ThreadPoolError(Exception):
42 """Base class for exceptions raised by ThreadPool."""
46 class ThreadPoolEmpty(ThreadPoolError):
47 """Trying to get task result from a thread pool with no pending tasks."""
51 class ThreadPoolClosed(ThreadPoolError):
52 """Trying to do something with a closed thread pool."""
56 class ThreadPool(object):
57 """Multithreaded worker pool with priority support.
59 When the priorities of tasks match, it works in strict FIFO mode.
61 QUEUE_CLASS = Queue.PriorityQueue
63 def __init__(self, initial_threads, max_threads, queue_size, prefix=None):
64 """Immediately starts |initial_threads| threads.
67 initial_threads: Number of threads to start immediately. Can be 0 if it is
68 uncertain that threads will be needed.
69 max_threads: Maximum number of threads that will be started when all the
70 threads are busy working. Often the number of CPU cores.
71 queue_size: Maximum number of tasks to buffer in the queue. 0 for
72 unlimited queue. A non-zero value may make add_task()
74 prefix: Prefix to use for thread names. Pool's threads will be
75 named '<prefix>-<thread index>'.
77 prefix = prefix or 'tp-0x%0x' % id(self)
79 'New ThreadPool(%d, %d, %d): %s', initial_threads, max_threads,
81 assert initial_threads <= max_threads
82 assert max_threads <= 1024
84 self.tasks = self.QUEUE_CLASS(queue_size)
85 self._max_threads = max_threads
88 # Used to assign indexes to tasks.
89 self._num_of_added_tasks_lock = threading.Lock()
90 self._num_of_added_tasks = 0
92 # Lock that protects everything below (including the condition variable).
93 self._lock = threading.Lock()
95 # Condition 'bool(_outputs) or bool(_exceptions) or _pending_count == 0'.
96 self._outputs_exceptions_cond = threading.Condition(self._lock)
100 # Number of pending tasks (queued or being processed now).
101 self._pending_count = 0
105 # Number of threads that are waiting for new tasks.
107 # Number of threads already added to _workers, but not yet running the loop.
109 # True if close was called. Forbids adding new tasks.
110 self._is_closed = False
112 for _ in range(initial_threads):
115 def _add_worker(self):
116 """Adds one worker thread if there aren't too many. Thread-safe."""
118 if len(self._workers) >= self._max_threads or self._is_closed:
120 worker = threading.Thread(
121 name='%s-%d' % (self._prefix, len(self._workers)), target=self._run)
122 self._workers.append(worker)
124 logging.debug('Starting worker thread %s', worker.name)
129 def add_task(self, priority, func, *args, **kwargs):
130 """Adds a task, a function to be executed by a worker.
133 - priority: priority of the task versus others. Lower priority takes
135 - func: function to run. Can either return a return value to be added to the
136 output list or be a generator which can emit multiple values.
137 - args and kwargs: arguments to |func|. Note that if func mutates |args| or
138 |kwargs| and that the task is retried, see
139 AutoRetryThreadPool, the retry will use the mutated
143 Index of the item added, e.g. the total number of enqueued items up to
146 assert isinstance(priority, int)
147 assert callable(func)
150 raise ThreadPoolClosed('Can not add a task to a closed ThreadPool')
152 # Pending task count plus new task > number of available workers.
153 self.tasks.qsize() + 1 > self._ready + self._starting and
155 len(self._workers) < self._max_threads
157 self._pending_count += 1
158 with self._num_of_added_tasks_lock:
159 self._num_of_added_tasks += 1
160 index = self._num_of_added_tasks
161 self.tasks.put((priority, index, func, args, kwargs))
167 """Worker thread loop. Runs until a None task is queued."""
168 # Thread has started, adjust counters.
174 task = self.tasks.get()
182 _priority, _index, func, args, kwargs = task
183 if inspect.isgeneratorfunction(func):
184 for out in func(*args, **kwargs):
185 self._output_append(out)
187 out = func(*args, **kwargs)
188 self._output_append(out)
189 except Exception as e:
190 logging.warning('Caught exception: %s', e)
191 exc_info = sys.exc_info()
192 logging.info(''.join(traceback.format_tb(exc_info[2])))
193 with self._outputs_exceptions_cond:
194 self._exceptions.append(exc_info)
195 self._outputs_exceptions_cond.notifyAll()
198 # Mark thread as ready again, mark task as processed. Do it before
199 # waking up threads waiting on self.tasks.join(). Otherwise they might
200 # find ThreadPool still 'busy' and perform unnecessary wait on CV.
201 with self._outputs_exceptions_cond:
203 self._pending_count -= 1
204 if self._pending_count == 0:
205 self._outputs_exceptions_cond.notifyAll()
206 self.tasks.task_done()
207 except Exception as e:
208 # We need to catch and log this error here because this is the root
209 # function for the thread, nothing higher will catch the error.
210 logging.exception('Caught exception while marking task as done: %s',
213 def _output_append(self, out):
215 with self._outputs_exceptions_cond:
216 self._outputs.append(out)
217 self._outputs_exceptions_cond.notifyAll()
220 """Extracts all the results from each thread, unordered.
222 Call repeatedly to extract all the exceptions if desired.
224 Note: will wait for all work items to be done before returning an exception.
225 To get an exception early, use get_one_result().
227 # TODO(maruel): Stop waiting as soon as an exception is caught.
229 with self._outputs_exceptions_cond:
231 e = self._exceptions.pop(0)
232 raise e[0], e[1], e[2]
237 def get_one_result(self):
238 """Returns the next item that was generated or raises an exception if one
242 ThreadPoolEmpty - no results available.
244 # Get first available result.
245 for result in self.iter_results():
247 # No results -> tasks queue is empty.
248 raise ThreadPoolEmpty('Task queue is empty')
250 def iter_results(self):
251 """Yields results as they appear until all tasks are processed."""
253 # Check for pending results.
255 self._on_iter_results_step()
256 with self._outputs_exceptions_cond:
258 e = self._exceptions.pop(0)
259 raise e[0], e[1], e[2]
261 # Remember the result to yield it outside of the lock.
262 result = self._outputs.pop(0)
264 # No pending tasks -> all tasks are done.
265 if not self._pending_count:
267 # Some task is queued, wait for its result to appear.
268 # Use non-None timeout so that process reacts to Ctrl+C and other
269 # signals, see http://bugs.python.org/issue8844.
270 self._outputs_exceptions_cond.wait(timeout=0.1)
275 """Closes all the threads."""
276 # Ensure no new threads can be started, self._workers is effectively
277 # a constant after that and can be accessed outside the lock.
280 raise ThreadPoolClosed('Can not close already closed ThreadPool')
281 self._is_closed = True
282 for _ in range(len(self._workers)):
283 # Enqueueing None causes the worker to stop.
285 for t in self._workers:
288 'Thread pool \'%s\' closed: spawned %d threads total',
289 self._prefix, len(self._workers))
292 """Empties the queue.
294 To be used when the pool should stop early, like when Ctrl-C was detected.
297 Number of tasks cancelled.
302 self.tasks.get_nowait()
303 self.tasks.task_done()
308 def _on_iter_results_step(self):
312 """Enables 'with' statement."""
315 def __exit__(self, _exc_type, _exc_value, _traceback):
316 """Enables 'with' statement."""
320 class AutoRetryThreadPool(ThreadPool):
321 """Automatically retries enqueued operations on exception."""
322 INTERNAL_PRIORITY_BITS = (1<<8) - 1
323 HIGH, MED, LOW = (1<<8, 2<<8, 3<<8)
325 def __init__(self, exceptions, retries, *args, **kwargs):
328 exceptions: list of exception classes that can be retried on.
329 retries: maximum number of retries to do.
331 assert exceptions and all(issubclass(e, Exception) for e in exceptions), (
333 assert 1 <= retries <= self.INTERNAL_PRIORITY_BITS
334 super(AutoRetryThreadPool, self).__init__(*args, **kwargs)
335 self._swallowed_exceptions = tuple(exceptions)
336 self._retries = retries
338 def add_task(self, priority, func, *args, **kwargs):
339 """Tasks added must not use the lower priority bits since they are reserved
342 assert (priority & self.INTERNAL_PRIORITY_BITS) == 0
343 return super(AutoRetryThreadPool, self).add_task(
352 def add_task_with_channel(self, channel, priority, func, *args, **kwargs):
353 """Tasks added must not use the lower priority bits since they are reserved
356 assert (priority & self.INTERNAL_PRIORITY_BITS) == 0
357 return super(AutoRetryThreadPool, self).add_task(
366 def _task_executer(self, priority, channel, func, *args, **kwargs):
367 """Wraps the function and automatically retry on exceptions."""
369 result = func(*args, **kwargs)
372 channel.send_result(result)
373 except self._swallowed_exceptions as e:
374 # Retry a few times, lowering the priority.
375 actual_retries = priority & self.INTERNAL_PRIORITY_BITS
376 if actual_retries < self._retries:
379 'Swallowed exception \'%s\'. Retrying at lower priority %X',
381 super(AutoRetryThreadPool, self).add_task(
392 channel.send_exception()
396 channel.send_exception()
399 class Progress(object):
400 """Prints progress and accepts updates thread-safely."""
401 def __init__(self, columns):
402 """Creates a Progress bar that will update asynchronously from the worker
406 columns: list of tuple(name, initialvalue), defines both the number of
407 columns and their initial values.
410 len(c) == 2 and isinstance(c[0], str) and isinstance(c[1], int)
411 for c in columns), columns
412 # Members to be used exclusively in the primary thread.
413 self.use_cr_only = True
414 self.unfinished_commands = set()
415 self.start = time.time()
416 self._last_printed_line = ''
417 self._columns = [c[1] for c in columns]
418 self._columns_lookup = dict((c[0], i) for i, c in enumerate(columns))
419 # Setting it to True forces a print on the first print_update() call.
420 self._value_changed = True
422 # To be used in all threads.
423 self._queued_updates = Queue.Queue()
def update_item(self, name, raw=False, **kwargs):
  """Queues information to print out.

  Arguments:
    name: string to print out to describe something that was completed.
    raw: if True, prints the data without the header.
    <kwargs>: argument name is the name of a column. Its value is the
              increment to the column, usually 0 or 1. Zero increments are
              dropped before the update is queued.

  Raises:
    KeyError: a keyword argument with a non-zero value does not name a known
              column.
  """
  assert isinstance(name, str)
  assert isinstance(raw, bool)
  # values()/items() exist on both Python 2 and Python 3 dicts, unlike the
  # iter*() variants which are Python 2 only.
  assert all(isinstance(v, int) for v in kwargs.values())
  # Translate column names to column indexes so the consumer of the queue
  # only has to deal with positions.
  args = [(self._columns_lookup[k], v) for k, v in kwargs.items() if v]
  self._queued_updates.put((name, raw, args))
441 def print_update(self):
442 """Prints the current status."""
443 # Flush all the logging output so it doesn't appear within this output.
444 for handler in logging.root.handlers:
450 name, raw, args = self._queued_updates.get_nowait()
455 self._columns[k] += v
456 self._value_changed = bool(args)
458 # Even if raw=True, there's nothing to print.
463 # Prints the data as-is.
464 self._last_printed_line = ''
465 sys.stdout.write('\n%s\n' % name.strip('\n'))
467 line, self._last_printed_line = self._gen_line(name)
468 sys.stdout.write(line)
470 if not got_one and self._value_changed:
471 # Make sure a line is printed in the case where statistics change.
472 line, self._last_printed_line = self._gen_line('')
473 sys.stdout.write(line)
475 self._value_changed = False
477 # Ensure that all the output is flushed to prevent it from getting mixed
478 # with other output streams (like the logging streams).
481 if self.unfinished_commands:
482 logging.debug('Waiting for the following commands to finish:\n%s',
483 '\n'.join(self.unfinished_commands))
485 def _gen_line(self, name):
486 """Generates the line to be printed."""
487 next_line = ('[%s] %6.2fs %s') % (
488 self._render_columns(), time.time() - self.start, name)
489 # Fill it with whitespace only if self.use_cr_only is set.
491 if self.use_cr_only and self._last_printed_line:
494 suffix = ' ' * max(0, len(self._last_printed_line) - len(next_line))
497 return '%s%s%s' % (prefix, next_line, suffix), next_line
499 def _render_columns(self):
500 """Renders the columns."""
501 columns_as_str = map(str, self._columns)
502 max_len = max(map(len, columns_as_str))
503 return '/'.join(i.rjust(max_len) for i in columns_as_str)
506 class QueueWithProgress(Queue.PriorityQueue):
507 """Implements progress support in join()."""
508 def __init__(self, progress, *args, **kwargs):
509 Queue.PriorityQueue.__init__(self, *args, **kwargs)
510 self.progress = progress
513 """Contrary to Queue.task_done(), it wakes self.all_tasks_done at each task
516 with self.all_tasks_done:
518 unfinished = self.unfinished_tasks - 1
520 raise ValueError('task_done() called too many times')
521 self.unfinished_tasks = unfinished
522 # This is less efficient, because we want the Progress to be updated.
523 self.all_tasks_done.notify_all()
524 except Exception as e:
525 logging.exception('task_done threw an exception.\n%s', e)
528 """Wakes up all_tasks_done.
530 Unlike task_done(), does not subtract one from self.unfinished_tasks.
532 # TODO(maruel): This is highly inefficient, since the listener is awakened
533 # twice; once per output, once per task. There should be no relationship
534 # between the number of outputs and the number of input tasks.
535 with self.all_tasks_done:
536 self.all_tasks_done.notify_all()
539 """Calls print_update() whenever possible."""
540 self.progress.print_update()
541 with self.all_tasks_done:
542 while self.unfinished_tasks:
543 self.progress.print_update()
544 # Use a short wait timeout so updates are printed in a timely manner.
545 # TODO(maruel): Find a way so Progress.queue and self.all_tasks_done
546 # share the same underlying event so no polling is necessary.
547 self.all_tasks_done.wait(0.1)
548 self.progress.print_update()
551 class ThreadPoolWithProgress(ThreadPool):
552 QUEUE_CLASS = QueueWithProgress
def __init__(self, progress, *args, **kwargs):
  """Binds |progress| to the queue class, then initializes the pool.

  Arguments:
    progress: passed as the first argument whenever QUEUE_CLASS is called.
    *args, **kwargs: forwarded unchanged to the base class constructor.
  """
  # Shadow the class attribute on this instance only, so that anything
  # calling self.QUEUE_CLASS(...) gets |progress| already bound.
  queue_factory = functools.partial(self.QUEUE_CLASS, progress)
  self.QUEUE_CLASS = queue_factory
  super(ThreadPoolWithProgress, self).__init__(*args, **kwargs)
558 def _output_append(self, out):
559 """Also wakes up the listener on new completed test_case."""
560 super(ThreadPoolWithProgress, self)._output_append(out)
563 def _on_iter_results_step(self):
564 self.tasks.progress.print_update()
567 class DeadlockDetector(object):
568 """Context manager that can detect deadlocks.
570 It will dump stack frames of all running threads if its 'ping' method isn't
574 with DeadlockDetector(timeout=60) as detector:
575 for item in some_work():
581 timeout - maximum allowed time between calls to 'ping'.
584 def __init__(self, timeout):
585 self.timeout = timeout
587 # Thread stop condition. Also lock for shared variables below.
588 self._stop_cv = threading.Condition()
589 self._stop_flag = False
590 # Time when 'ping' was called last time.
591 self._last_ping = None
592 # True if pings are coming on time.
596 """Starts internal watcher thread."""
597 assert self._thread is None
599 self._thread = threading.Thread(name='deadlock-detector', target=self._run)
600 self._thread.daemon = True
604 def __exit__(self, *_args):
605 """Stops internal watcher thread."""
606 assert self._thread is not None
608 self._stop_flag = True
609 self._stop_cv.notify()
612 self._stop_flag = False
615 """Notify detector that main thread is still running.
617 Should be called periodically to inform the detector that everything is
618 running as it should.
621 self._last_ping = time.time()
625 """Loop that watches for pings and dumps threads state if ping is late."""
627 while not self._stop_flag:
628 # Skipped deadline? Dump threads and switch to 'not alive' state.
629 if self._alive and time.time() > self._last_ping + self.timeout:
630 self.dump_threads(time.time() - self._last_ping, True)
635 # Wait until the moment we need to dump stack traces.
636 # Most probably some other thread will call 'ping' to move deadline
637 # further in time. We don't bother to wake up after each 'ping',
638 # only right before initial expected deadline.
639 self._stop_cv.wait(self._last_ping + self.timeout - time.time())
641 # Skipped some pings previously. Just periodically silently check
642 # for new pings with some arbitrary frequency.
643 self._stop_cv.wait(self.timeout * 0.1)
646 def dump_threads(timeout=None, skip_current_thread=False):
647 """Dumps stack frames of all running threads."""
648 all_threads = threading.enumerate()
649 current_thread_id = threading.current_thread().ident
651 # Collect tracebacks: thread name -> traceback string.
654 # pylint: disable=W0212
655 for thread_id, frame in sys._current_frames().iteritems():
656 # Don't dump deadlock detector's own thread, it's boring.
657 if thread_id == current_thread_id and skip_current_thread:
660 # Try to get more informative symbolic thread name.
662 for thread in all_threads:
663 if thread.ident == thread_id:
666 name += ' #%d' % (thread_id,)
667 tracebacks[name] = ''.join(traceback.format_stack(frame))
669 # Function to print a message. Makes it easier to change output destination.
671 logging.warning(msg.rstrip())
673 # Print tracebacks, sorting them by thread name. That way a thread pool's
674 # threads will be printed as one group.
675 output('=============== Potential deadlock detected ===============')
676 if timeout is not None:
677 output('No pings in last %d sec.' % (timeout,))
678 output('Dumping stack frames for all threads:')
679 for name in sorted(tracebacks):
680 output('Traceback for \'%s\':\n%s' % (name, tracebacks[name]))
681 output('===========================================================')
684 class TaskChannel(object):
685 """Queue of results of async task execution."""
687 class Timeout(Exception):
688 """Raised by 'pull' in case of timeout."""
694 self._queue = Queue.Queue()
def send_result(self, result):
  """Puts |result| on the channel's queue, tagged as a task result."""
  tagged = (self._ITEM_RESULT, result)
  self._queue.put(tagged)
700 def send_exception(self, exc_info=None):
701 """Enqueues an exception raised by a task.
704 exc_info: If given, should be 3-tuple returned by sys.exc_info(),
705 default is current value of sys.exc_info(). Use default in
706 'except' blocks to capture currently processed exception.
708 exc_info = exc_info or sys.exc_info()
709 assert isinstance(exc_info, tuple) and len(exc_info) == 3
710 # Transparently passing Timeout will break 'pull' contract, since a caller
711 # has no way to figure out that's an exception from the task and not from
712 # 'pull' itself. Transform Timeout into generic RuntimeError with
714 if isinstance(exc_info[1], TaskChannel.Timeout):
717 RuntimeError('Task raised Timeout exception'),
719 self._queue.put((self._ITEM_EXCEPTION, exc_info))
721 def pull(self, timeout=None):
722 """Dequeues available result or exception.
725 timeout: if not None will block no longer than |timeout| seconds and will
726 raise TaskChannel.Timeout exception if no results are available.
729 Whatever task pushes to the queue by calling 'send_result'.
732 TaskChannel.Timeout: waiting longer than |timeout|.
733 Whatever exception task raises.
736 item_type, value = self._queue.get(timeout=timeout)
738 raise TaskChannel.Timeout()
739 if item_type == self._ITEM_RESULT:
741 if item_type == self._ITEM_EXCEPTION:
742 # 'value' is captured sys.exc_info() 3-tuple. Use extended raise syntax
743 # to preserve stack frame of original exception (that was raised in
745 assert isinstance(value, tuple) and len(value) == 3
746 raise value[0], value[1], value[2]
747 assert False, 'Impossible queue item type: %r' % item_type
749 def wrap_task(self, task):
750 """Decorator that makes a function push results into this channel."""
751 @functools.wraps(task)
752 def wrapped(*args, **kwargs):
754 self.send_result(task(*args, **kwargs))
756 self.send_exception()
760 def num_processors():
761 """Returns the number of processors.
763 Python on OSX 10.6 raises a NotImplementedError exception.
767 import multiprocessing
768 return multiprocessing.cpu_count()
769 except: # pylint: disable=W0702
772 return int(os.sysconf('SC_NPROCESSORS_ONLN')) # pylint: disable=E1101
774 # Some of the windows builders seem to get here.