1 # Copyright 2013 The Swarming Authors. All rights reserved.
2 # Use of this source code is governed under the Apache License, Version 2.0 that
3 # can be found in the LICENSE file.
5 """Classes and functions related to threading."""
# Module-level priority constants for tasks submitted to AutoRetryThreadPool.
# Lower numeric value means the task is served earlier by the pool's
# PriorityQueue; the comment below stresses the exact values matter because
# the low bits are reserved for retry accounting (see INTERNAL_PRIORITY_BITS).
# NOTE(review): companion constants (e.g. PRIORITY_MED / PRIORITY_LOW) appear
# to be elided from this listing -- confirm against the full file.
18 # Priorities for tasks in AutoRetryThreadPool, particular values are important.
19 PRIORITY_HIGH = 1 << 8
class LockWithAssert(object):
  """Wrapper around (non recursive) Lock that tracks its owner.

  Use as a context manager ('with lock:'). Entering acquires the lock and
  records the acquiring thread as the owner; exiting asserts the releasing
  thread is the owner before releasing. assert_locked() lets code verify it
  is running with the lock held.
  """

  def __init__(self):
    self._lock = threading.Lock()
    # Thread currently holding the lock, or None when unlocked.
    self._owner = None

  def __enter__(self):
    self._lock.acquire()
    # Non-recursive: the same thread must not re-enter.
    assert self._owner is None
    self._owner = threading.current_thread()

  def __exit__(self, _exc_type, _exec_value, _traceback):
    self.assert_locked('Releasing unowned lock')
    # Clear ownership before releasing so a racing acquirer never observes a
    # stale owner.
    self._owner = None
    self._lock.release()
    # Do not swallow exceptions raised inside the 'with' block.
    return False

  def assert_locked(self, msg=None):
    """Asserts the lock is owned by running thread."""
    assert self._owner == threading.current_thread(), msg
class ThreadPoolError(Exception):
  """Root of the exception hierarchy raised by ThreadPool operations."""
class ThreadPoolEmpty(ThreadPoolError):
  """Raised when a result is requested from a pool with no pending tasks."""
class ThreadPoolClosed(ThreadPoolError):
  """Raised when operating on a ThreadPool after close() was called."""
# NOTE(review): this listing is elided -- several structural lines (method
# headers such as _run/join/abort/__enter__, 'try:' / 'finally:' lines and some
# statements) are missing between the numbered lines below. Comments here only
# describe what the visible lines establish; hedged where the gaps matter.
62 class ThreadPool(object):
63 """Multithreaded worker pool with priority support.
65 When the priority of tasks match, it works in strict FIFO mode.
# Queue class for the task queue; subclasses override this (see
# ThreadPoolWithProgress, which substitutes QueueWithProgress).
67 QUEUE_CLASS = Queue.PriorityQueue
69 def __init__(self, initial_threads, max_threads, queue_size, prefix=None):
70 """Immediately starts |initial_threads| threads.
73 initial_threads: Number of threads to start immediately. Can be 0 if it is
74 uncertain that threads will be needed.
75 max_threads: Maximum number of threads that will be started when all the
76 threads are busy working. Often the number of CPU cores.
77 queue_size: Maximum number of tasks to buffer in the queue. 0 for
78 unlimited queue. A non-zero value may make add_task()
80 prefix: Prefix to use for thread names. Pool's threads will be
81 named '<prefix>-<thread index>'.
# Default thread-name prefix derives from the pool's id() so concurrent pools
# are distinguishable in logs.
83 prefix = prefix or 'tp-0x%0x' % id(self)
85 'New ThreadPool(%d, %d, %d): %s', initial_threads, max_threads,
87 assert initial_threads <= max_threads
88 assert max_threads <= 1024
90 self.tasks = self.QUEUE_CLASS(queue_size)
91 self._max_threads = max_threads
94 # Used to assign indexes to tasks.
95 self._num_of_added_tasks_lock = threading.Lock()
96 self._num_of_added_tasks = 0
98 # Lock that protected everything below (including conditional variable).
99 self._lock = threading.Lock()
101 # Condition 'bool(_outputs) or bool(_exceptions) or _pending_count == 0'.
102 self._outputs_exceptions_cond = threading.Condition(self._lock)
# _exceptions stores sys.exc_info() 3-tuples captured in worker threads so the
# original tracebacks can be re-raised in the consumer thread.
104 self._exceptions = []
106 # Number of pending tasks (queued or being processed now).
107 self._pending_count = 0
111 # Number of threads that are waiting for new tasks.
113 # Number of threads already added to _workers, but not yet running the loop.
115 # True if close was called. Forbids adding new tasks.
116 self._is_closed = False
118 for _ in range(initial_threads):
121 def _add_worker(self):
122 """Adds one worker thread if there isn't too many. Thread-safe."""
124 if len(self._workers) >= self._max_threads or self._is_closed:
126 worker = threading.Thread(
127 name='%s-%d' % (self._prefix, len(self._workers)), target=self._run)
128 self._workers.append(worker)
130 logging.debug('Starting worker thread %s', worker.name)
135 def add_task(self, priority, func, *args, **kwargs):
136 """Adds a task, a function to be executed by a worker.
139 - priority: priority of the task versus others. Lower priority takes
141 - func: function to run. Can either return a return value to be added to the
142 output list or be a generator which can emit multiple values.
143 - args and kwargs: arguments to |func|. Note that if func mutates |args| or
144 |kwargs| and that the task is retried, see
145 AutoRetryThreadPool, the retry will use the mutated
149 Index of the item added, e.g. the total number of enqueued items up to
152 assert isinstance(priority, int)
153 assert callable(func)
156 raise ThreadPoolClosed('Can not add a task to a closed ThreadPool')
# Lazily grow the pool: only spawn another worker if the backlog exceeds the
# threads that are idle or still starting up.
158 # Pending task count plus new task > number of available workers.
159 self.tasks.qsize() + 1 > self._ready + self._starting and
161 len(self._workers) < self._max_threads
163 self._pending_count += 1
164 with self._num_of_added_tasks_lock:
165 self._num_of_added_tasks += 1
166 index = self._num_of_added_tasks
# Queue item layout consumed by the worker loop below:
# (priority, monotonically-increasing index, func, args, kwargs).
# The index makes equal-priority tasks strictly FIFO in the PriorityQueue.
167 self.tasks.put((priority, index, func, args, kwargs))
# Worker loop (header elided in this listing).
173 """Worker thread loop. Runs until a None task is queued."""
174 # Thread has started, adjust counters.
180 task = self.tasks.get()
188 _priority, _index, func, args, kwargs = task
# Generator tasks may emit multiple outputs; plain callables emit one.
189 if inspect.isgeneratorfunction(func):
190 for out in func(*args, **kwargs):
191 self._output_append(out)
193 out = func(*args, **kwargs)
194 self._output_append(out)
195 except Exception as e:
196 logging.warning('Caught exception: %s', e)
# Capture the full exc_info triple so the consumer can re-raise with the
# original traceback.
197 exc_info = sys.exc_info()
198 logging.info(''.join(traceback.format_tb(exc_info[2])))
199 with self._outputs_exceptions_cond:
200 self._exceptions.append(exc_info)
# notifyAll() is the legacy (pre-PEP 8) alias of notify_all().
201 self._outputs_exceptions_cond.notifyAll()
204 # Mark thread as ready again, mark task as processed. Do it before
205 # waking up threads waiting on self.tasks.join(). Otherwise they might
206 # find ThreadPool still 'busy' and perform unnecessary wait on CV.
207 with self._outputs_exceptions_cond:
209 self._pending_count -= 1
210 if self._pending_count == 0:
211 self._outputs_exceptions_cond.notifyAll()
212 self.tasks.task_done()
213 except Exception as e:
214 # We need to catch and log this error here because this is the root
215 # function for the thread, nothing higher will catch the error.
216 logging.exception('Caught exception while marking task as done: %s',
219 def _output_append(self, out):
# Thread-safe append to the shared output list; wakes any consumer blocked in
# iter_results()/join().
221 with self._outputs_exceptions_cond:
222 self._outputs.append(out)
223 self._outputs_exceptions_cond.notifyAll()
# join()-style drain method (header elided in this listing).
226 """Extracts all the results from each threads unordered.
228 Call repeatedly to extract all the exceptions if desired.
230 Note: will wait for all work items to be done before returning an exception.
231 To get an exception early, use get_one_result().
233 # TODO(maruel): Stop waiting as soon as an exception is caught.
235 with self._outputs_exceptions_cond:
237 e = self._exceptions.pop(0)
# Python 2 three-expression raise: re-raises (type, value, traceback)
# preserving the worker's original stack trace.
238 raise e[0], e[1], e[2]
243 def get_one_result(self):
244 """Returns the next item that was generated or raises an exception if one
248 ThreadPoolEmpty - no results available.
250 # Get first available result.
251 for result in self.iter_results():
253 # No results -> tasks queue is empty.
254 raise ThreadPoolEmpty('Task queue is empty')
256 def iter_results(self):
257 """Yields results as they appear until all tasks are processed."""
259 # Check for pending results.
261 self._on_iter_results_step()
262 with self._outputs_exceptions_cond:
264 e = self._exceptions.pop(0)
265 raise e[0], e[1], e[2]
267 # Remember the result to yield it outside of the lock.
268 result = self._outputs.pop(0)
270 # No pending tasks -> all tasks are done.
271 if not self._pending_count:
273 # Some task is queued, wait for its result to appear.
274 # Use non-None timeout so that process reacts to Ctrl+C and other
275 # signals, see http://bugs.python.org/issue8844.
276 self._outputs_exceptions_cond.wait(timeout=0.1)
# close() (header elided in this listing).
281 """Closes all the threads."""
282 # Ensure no new threads can be started, self._workers is effectively
283 # a constant after that and can be accessed outside the lock.
286 raise ThreadPoolClosed('Can not close already closed ThreadPool')
287 self._is_closed = True
288 for _ in range(len(self._workers)):
289 # Enqueueing None causes the worker to stop.
291 for t in self._workers:
294 'Thread pool \'%s\' closed: spawned %d threads total',
295 self._prefix, len(self._workers))
# abort()-style drain (header elided in this listing).
298 """Empties the queue.
300 To be used when the pool should stop early, like when Ctrl-C was detected.
303 Number of tasks cancelled.
308 self.tasks.get_nowait()
309 self.tasks.task_done()
# Hook point for subclasses; ThreadPoolWithProgress overrides it to refresh
# the progress display between iter_results() steps.
314 def _on_iter_results_step(self):
318 """Enables 'with' statement."""
321 def __exit__(self, _exc_type, _exc_value, _traceback):
322 """Enables 'with' statement."""
# NOTE(review): listing elided -- several lines are missing between the
# numbered lines below (assert messages, add_task argument lists, the retry
# re-enqueue arguments). Comments describe only what the visible lines show.
326 class AutoRetryThreadPool(ThreadPool):
327 """Automatically retries enqueued operations on exception."""
328 # See also PRIORITY_* module-level constants.
# The low 8 bits of a task's priority are reserved to count retries: each
# retry bumps those bits, which lowers effective priority (larger number).
329 INTERNAL_PRIORITY_BITS = (1<<8) - 1
331 def __init__(self, exceptions, retries, *args, **kwargs):
334 exceptions: list of exception classes that can be retried on.
335 retries: maximum number of retries to do.
337 assert exceptions and all(issubclass(e, Exception) for e in exceptions), (
339 assert 1 <= retries <= self.INTERNAL_PRIORITY_BITS
340 super(AutoRetryThreadPool, self).__init__(*args, **kwargs)
# tuple() so it can be used directly in an 'except ...' clause.
341 self._swallowed_exceptions = tuple(exceptions)
342 self._retries = retries
344 def add_task(self, priority, func, *args, **kwargs):
345 """Tasks added must not use the lower priority bits since they are reserved
348 assert (priority & self.INTERNAL_PRIORITY_BITS) == 0
349 return super(AutoRetryThreadPool, self).add_task(
358 def add_task_with_channel(self, channel, priority, func, *args, **kwargs):
359 """Tasks added must not use the lower priority bits since they are reserved
362 assert (priority & self.INTERNAL_PRIORITY_BITS) == 0
363 return super(AutoRetryThreadPool, self).add_task(
372 def _task_executer(self, priority, channel, func, *args, **kwargs):
373 """Wraps the function and automatically retry on exceptions."""
375 result = func(*args, **kwargs)
# channel may be None when the task was added via add_task(); results then go
# to the pool's output list instead -- TODO confirm against the elided lines.
378 channel.send_result(result)
379 except self._swallowed_exceptions as e:
380 # Retry a few times, lowering the priority.
# Retry count is stored in the reserved low bits of the priority itself.
381 actual_retries = priority & self.INTERNAL_PRIORITY_BITS
382 if actual_retries < self._retries:
385 'Swallowed exception \'%s\'. Retrying at lower priority %X',
387 super(AutoRetryThreadPool, self).add_task(
# Retry budget exhausted: propagate the exception through the channel.
398 channel.send_exception()
# Non-retryable exception path (the 'except' header is elided).
402 channel.send_exception()
405 class IOAutoRetryThreadPool(AutoRetryThreadPool):
406 """Thread pool that automatically retries on IOError.
408 Supposed to be used for IO bound tasks, and thus default maximum number of
409 worker threads is independent of number of CPU cores.
411 # Initial and maximum number of worker threads.
# NOTE(review): the INITIAL_WORKERS / MAX_WORKERS class constants and most of
# __init__ (including the retried-exceptions tuple, presumably (IOError,))
# are elided from this listing -- confirm against the full file.
417 super(IOAutoRetryThreadPool, self).__init__(
420 self.INITIAL_WORKERS,
# NOTE(review): listing elided -- method bodies below are missing lines (e.g.
# the handler.flush() call, the get_nowait() loop with its Queue.Empty guard,
# the prefix computation in _gen_line). Comments describe visible lines only.
426 class Progress(object):
427 """Prints progress and accepts updates thread-safely."""
428 def __init__(self, columns):
429 """Creates a Progress bar that will update asynchronously from the worker
433 columns: list of tuple(name, initialvalue), defines both the number of
434 columns and their initial values.
437 len(c) == 2 and isinstance(c[0], str) and isinstance(c[1], int)
438 for c in columns), columns
439 # Members to be used exclusively in the primary thread.
440 self.use_cr_only = True
441 self.unfinished_commands = set()
442 self.start = time.time()
443 self._last_printed_line = ''
# Column values are tracked by index; _columns_lookup maps column name -> index
# so update_item() kwargs can address columns by name.
444 self._columns = [c[1] for c in columns]
445 self._columns_lookup = dict((c[0], i) for i, c in enumerate(columns))
446 # Setting it to True forces a print on the first print_update() call.
447 self._value_changed = True
449 # To be used in all threads.
# Worker threads enqueue updates here; only the primary thread drains it in
# print_update(). This is what makes update_item() thread-safe.
450 self._queued_updates = Queue.Queue()
452 def update_item(self, name, raw=False, **kwargs):
453 """Queue information to print out.
456 name: string to print out to describe something that was completed.
457 raw: if True, prints the data without the header.
459 <kwargs>: argument name is a name of a column. its value is the increment
460 to the column, value is usually 0 or 1.
462 assert isinstance(name, str)
463 assert isinstance(raw, bool)
464 assert all(isinstance(v, int) for v in kwargs.itervalues())
# Resolve column names to indices now, in the calling thread, so bad column
# names fail fast; zero increments are dropped.
465 args = [(self._columns_lookup[k], v) for k, v in kwargs.iteritems() if v]
466 self._queued_updates.put((name, raw, args))
468 def print_update(self):
469 """Prints the current status."""
470 # Flush all the logging output so it doesn't appear within this output.
471 for handler in logging.root.handlers:
# Drain loop (header and Queue.Empty guard elided in this listing).
477 name, raw, args = self._queued_updates.get_nowait()
482 self._columns[k] += v
483 self._value_changed = bool(args)
485 # Even if raw=True, there's nothing to print.
490 # Prints the data as-is.
# Raw output resets _last_printed_line so the next status line is not padded
# against stale text.
491 self._last_printed_line = ''
492 sys.stdout.write('\n%s\n' % name.strip('\n'))
494 line, self._last_printed_line = self._gen_line(name)
495 sys.stdout.write(line)
497 if not got_one and self._value_changed:
498 # Make sure a line is printed in that case where statistics changes.
499 line, self._last_printed_line = self._gen_line('')
500 sys.stdout.write(line)
502 self._value_changed = False
504 # Ensure that all the output is flushed to prevent it from getting mixed
505 # with other output streams (like the logging streams).
508 if self.unfinished_commands:
509 logging.debug('Waiting for the following commands to finish:\n%s',
510 '\n'.join(self.unfinished_commands))
512 def _gen_line(self, name):
513 """Generates the line to be printed."""
514 next_line = ('[%s] %6.2fs %s') % (
515 self._render_columns(), time.time() - self.start, name)
516 # Fill it with whitespace only if self.use_cr_only is set.
518 if self.use_cr_only and self._last_printed_line:
# Padding overwrites leftovers of a longer previously-printed line when using
# carriage-return-only updates.
521 suffix = ' ' * max(0, len(self._last_printed_line) - len(next_line))
524 return '%s%s%s' % (prefix, next_line, suffix), next_line
526 def _render_columns(self):
527 """Renders the columns."""
# Right-justify every column to the widest one so the display stays aligned.
528 columns_as_str = map(str, self._columns)
529 max_len = max(map(len, columns_as_str))
530 return '/'.join(i.rjust(max_len) for i in columns_as_str)
# NOTE(review): listing elided -- the 'def task_done', 'def wake_up' and
# 'def join' method headers plus a few statements are missing between the
# numbered lines below.
533 class QueueWithProgress(Queue.PriorityQueue):
534 """Implements progress support in join()."""
535 def __init__(self, progress, *args, **kwargs):
536 Queue.PriorityQueue.__init__(self, *args, **kwargs)
537 self.progress = progress
# task_done() override (header elided).
540 """Contrary to Queue.task_done(), it wakes self.all_tasks_done at each task
543 with self.all_tasks_done:
545 unfinished = self.unfinished_tasks - 1
547 raise ValueError('task_done() called too many times')
548 self.unfinished_tasks = unfinished
549 # This is less efficient, because we want the Progress to be updated.
# Stock Queue.task_done() only notifies when unfinished_tasks hits 0; this
# notifies every time so join() can refresh the progress display.
550 self.all_tasks_done.notify_all()
551 except Exception as e:
552 logging.exception('task_done threw an exception.\n%s', e)
# wake_up() (header elided).
555 """Wakes up all_tasks_done.
557 Unlike task_done(), do not subtract one from self.unfinished_tasks.
559 # TODO(maruel): This is highly inefficient, since the listener is awaken
560 # twice; once per output, once per task. There should be no relationship
561 # between the number of output and the number of input task.
562 with self.all_tasks_done:
563 self.all_tasks_done.notify_all()
# join() override (header elided).
566 """Calls print_update() whenever possible."""
567 self.progress.print_update()
568 with self.all_tasks_done:
569 while self.unfinished_tasks:
570 self.progress.print_update()
571 # Use a short wait timeout so updates are printed in a timely manner.
572 # TODO(maruel): Find a way so Progress.queue and self.all_tasks_done
573 # share the same underlying event so no polling is necessary.
574 self.all_tasks_done.wait(0.1)
575 self.progress.print_update()
# ThreadPool variant wired to a Progress display: the task queue is a
# QueueWithProgress and the iter_results() hook refreshes the display.
578 class ThreadPoolWithProgress(ThreadPool):
579 QUEUE_CLASS = QueueWithProgress
581 def __init__(self, progress, *args, **kwargs):
# Bind |progress| as the first constructor argument of the queue class, on the
# instance, so the base class's QUEUE_CLASS(queue_size) call keeps working.
582 self.QUEUE_CLASS = functools.partial(self.QUEUE_CLASS, progress)
583 super(ThreadPoolWithProgress, self).__init__(*args, **kwargs)
585 def _output_append(self, out):
586 """Also wakes up the listener on new completed test_case."""
587 super(ThreadPoolWithProgress, self)._output_append(out)
# NOTE(review): the wake-up call after the super() call (presumably
# self.tasks.wake_up()) is elided from this listing -- confirm.
590 def _on_iter_results_step(self):
591 self.tasks.progress.print_update()
# NOTE(review): listing elided -- the _alive/_thread initializers, the
# __enter__ thread start, __exit__ thread join, parts of ping() and of the
# watcher loop are missing between the numbered lines below.
594 class DeadlockDetector(object):
595 """Context manager that can detect deadlocks.
597 It will dump stack frames of all running threads if its 'ping' method isn't
601 with DeadlockDetector(timeout=60) as detector:
602 for item in some_work():
608 timeout - maximum allowed time between calls to 'ping'.
611 def __init__(self, timeout):
612 self.timeout = timeout
614 # Thread stop condition. Also lock for shared variables below.
615 self._stop_cv = threading.Condition()
616 self._stop_flag = False
617 # Time when 'ping' was called last time.
618 self._last_ping = None
619 # True if pings are coming on time.
# __enter__ (header elided).
623 """Starts internal watcher thread."""
624 assert self._thread is None
626 self._thread = threading.Thread(name='deadlock-detector', target=self._run)
# Daemon so a hung main thread can still exit the process.
627 self._thread.daemon = True
631 def __exit__(self, *_args):
632 """Stops internal watcher thread."""
633 assert self._thread is not None
635 self._stop_flag = True
636 self._stop_cv.notify()
639 self._stop_flag = False
# ping() (header elided).
642 """Notify detector that main thread is still running.
644 Should be called periodically to inform the detector that everything is
645 running as it should.
648 self._last_ping = time.time()
# Watcher loop (header elided).
652 """Loop that watches for pings and dumps threads state if ping is late."""
654 while not self._stop_flag:
655 # Skipped deadline? Dump threads and switch to 'not alive' state.
656 if self._alive and time.time() > self._last_ping + self.timeout:
657 self.dump_threads(time.time() - self._last_ping, True)
662 # Wait until the moment we need to dump stack traces.
663 # Most probably some other thread will call 'ping' to move deadline
664 # further in time. We don't bother to wake up after each 'ping',
665 # only right before initial expected deadline.
666 self._stop_cv.wait(self._last_ping + self.timeout - time.time())
668 # Skipped some pings previously. Just periodically silently check
669 # for new pings with some arbitrary frequency.
670 self._stop_cv.wait(self.timeout * 0.1)
# NOTE(review): takes no 'self' -- presumably a @staticmethod of
# DeadlockDetector whose decorator line is elided; confirm against the full
# file. Also elided: the 'tracebacks = {}' init, the 'continue' after the
# current-thread check, and the 'def output(msg):' helper header.
673 def dump_threads(timeout=None, skip_current_thread=False):
674 """Dumps stack frames of all running threads."""
675 all_threads = threading.enumerate()
676 current_thread_id = threading.current_thread().ident
678 # Collect tracebacks: thread name -> traceback string.
681 # pylint: disable=W0212
# sys._current_frames() maps thread id -> topmost frame for every thread.
682 for thread_id, frame in sys._current_frames().iteritems():
683 # Don't dump deadlock detector's own thread, it's boring.
684 if thread_id == current_thread_id and skip_current_thread:
687 # Try to get more informative symbolic thread name.
689 for thread in all_threads:
690 if thread.ident == thread_id:
# Append the numeric id so names stay unique even if two threads share a name.
693 name += ' #%d' % (thread_id,)
694 tracebacks[name] = ''.join(traceback.format_stack(frame))
696 # Function to print a message. Makes it easier to change output destination.
698 logging.warning(msg.rstrip())
700 # Print tracebacks, sorting them by thread name. That way a thread pool's
701 # threads will be printed as one group.
702 output('=============== Potential deadlock detected ===============')
703 if timeout is not None:
704 output('No pings in last %d sec.' % (timeout,))
705 output('Dumping stack frames for all threads:')
706 for name in sorted(tracebacks):
707 output('Traceback for \'%s\':\n%s' % (name, tracebacks[name]))
708 output('===========================================================')
# NOTE(review): listing elided -- the _ITEM_RESULT/_ITEM_EXCEPTION class
# constants, the __init__ header, several 'try:'/'return' lines and the tail
# of wrap_task are missing between the numbered lines below.
711 class TaskChannel(object):
712 """Queue of results of async task execution."""
714 class Timeout(Exception):
715 """Raised by 'pull' in case of timeout."""
# __init__ (header elided): the channel is a plain FIFO of tagged items.
721 self._queue = Queue.Queue()
723 def send_result(self, result):
724 """Enqueues a result of task execution."""
725 self._queue.put((self._ITEM_RESULT, result))
727 def send_exception(self, exc_info=None):
728 """Enqueue an exception raised by a task.
731 exc_info: If given, should be 3-tuple returned by sys.exc_info(),
732 default is current value of sys.exc_info(). Use default in
733 'except' blocks to capture currently processed exception.
735 exc_info = exc_info or sys.exc_info()
736 assert isinstance(exc_info, tuple) and len(exc_info) == 3
737 # Transparently passing Timeout will break 'pull' contract, since a caller
738 # has no way to figure out that's an exception from the task and not from
739 # 'pull' itself. Transform Timeout into generic RuntimeError with
741 if isinstance(exc_info[1], TaskChannel.Timeout):
744 RuntimeError('Task raised Timeout exception'),
746 self._queue.put((self._ITEM_EXCEPTION, exc_info))
748 def pull(self, timeout=None):
749 """Dequeues available result or exception.
752 timeout: if not None will block no longer than |timeout| seconds and will
753 raise TaskChannel.Timeout exception if no results are available.
756 Whatever task pushes to the queue by calling 'send_result'.
759 TaskChannel.Timeout: waiting longer than |timeout|.
760 Whatever exception task raises.
# Queue.get raises Queue.Empty on timeout; the elided except clause maps that
# to TaskChannel.Timeout below.
763 item_type, value = self._queue.get(timeout=timeout)
765 raise TaskChannel.Timeout()
766 if item_type == self._ITEM_RESULT:
768 if item_type == self._ITEM_EXCEPTION:
769 # 'value' is captured sys.exc_info() 3-tuple. Use extended raise syntax
770 # to preserve stack frame of original exception (that was raised in
772 assert isinstance(value, tuple) and len(value) == 3
# Python 2 three-expression raise: re-raise with the original traceback.
773 raise value[0], value[1], value[2]
774 assert False, 'Impossible queue item type: %r' % item_type
776 def wrap_task(self, task):
777 """Decorator that makes a function push results into this channel."""
778 @functools.wraps(task)
779 def wrapped(*args, **kwargs):
781 self.send_result(task(*args, **kwargs))
# Any exception from the task is forwarded through the channel instead of
# escaping the worker thread.
783 self.send_exception()
# NOTE(review): the try/except structure and the final Windows fallback return
# are elided from this listing; only the fallback chain's visible steps are
# documented here.
787 def num_processors():
788 """Returns the number of processors.
790 Python on OSX 10.6 raises a NotImplementedError exception.
# First choice: multiprocessing.cpu_count(); imported lazily because it can
# fail on some platforms (see docstring above).
794 import multiprocessing
795 return multiprocessing.cpu_count()
796 except: # pylint: disable=W0702
# POSIX fallback via sysconf.
799 return int(os.sysconf('SC_NPROCESSORS_ONLN')) # pylint: disable=E1101
801 # Some of the windows builders seem to get here.
805 def enum_processes_win():
806 """Returns all processes on the system that are accessible to this process.
809 Win32_Process COM objects. See
810 http://msdn.microsoft.com/library/aa394372.aspx for more details.
# Lazy import: win32com is Windows-only and not always installed.
812 import win32com.client # pylint: disable=F0401
813 wmi_service = win32com.client.Dispatch('WbemScripting.SWbemLocator')
814 wbem = wmi_service.ConnectServer('.', 'root\\cimv2')
# Filter out processes with no ExecutablePath (system/idle pseudo-processes).
# NOTE(review): the surrounding 'return [' / ']' lines are elided.
816 proc for proc in wbem.ExecQuery('SELECT * FROM Win32_Process')
817 if proc.ExecutablePath
821 def filter_processes_dir_win(processes, root_dir):
822 """Returns all processes which has their main executable located inside
# Case-insensitive prefix match, as Windows paths are case-insensitive.
825 root_dir = root_dir.lower()
# NOTE(review): the surrounding 'return [' / ']' lines are elided.
827 proc for proc in processes
828 if proc.ExecutablePath.lower().startswith(root_dir)
# NOTE(review): this definition is truncated in this listing (the traversal
# loop header and final return are elided); comments cover visible lines only.
832 def filter_processes_tree_win(processes):
833 """Returns all the processes under the current process."""
# Index processes by pid for O(1) parent/child lookups.
835 processes = {p.ProcessId: p for p in processes}
836 root_pid = os.getpid()
# Seed the result with the current process; descendants are added by walking
# ParentProcessId links below.
837 out = {root_pid: processes[root_pid]}
842 p.ProcessId for p in processes.itervalues()
843 if p.ParentProcessId == pid)
847 out.update((p, processes[p]) for p in found)