# Copyright 2014 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import multiprocessing
import pickle
import traceback

from typ.host import Host
def make_pool(host, jobs, callback, context, pre_fn, post_fn):
    """Create a pool that applies `callback` to messages sent to it.

    Args:
        host: Host for the pool. The in-process pool substitutes a default
            Host() when this is falsy; the multi-process pool calls
            host.for_mp(), so it needs a real Host.
        jobs: number of workers. More than one selects a pool of worker
            processes; otherwise all work runs in the calling process.
        callback: called as callback(per_worker_context, msg) per message.
        context: initial state handed to pre_fn.
        pre_fn: called as pre_fn(host, worker_num, context); its result is
            the per-worker context given to callback and post_fn.
        post_fn: called with the per-worker context when the pool closes.

    Raises:
        ValueError: raised by _validate_args when an argument is found to
            be unpicklable.
    """
    _validate_args(context, pre_fn, post_fn)
    if jobs > 1:
        return _ProcessPool(host, jobs, callback, context, pre_fn, post_fn)
    # _AsyncPool runs everything in-process as a single worker (it always
    # passes worker_num 1 to pre_fn), which is all that one job needs.
    return _AsyncPool(host, jobs, callback, context, pre_fn, post_fn)
30 class _MessageType(object):
36 Interrupt = 'Interrupt'
38 values = [Request, Response, Close, Done, Error, Interrupt]
41 def _validate_args(context, pre_fn, post_fn):
43 _ = pickle.dumps(context)
44 except Exception as e:
45 raise ValueError('context passed to make_pool is not picklable: %s'
48 _ = pickle.dumps(pre_fn)
49 except pickle.PickleError:
50 raise ValueError('pre_fn passed to make_pool is not picklable')
52 _ = pickle.dumps(post_fn)
53 except pickle.PickleError:
54 raise ValueError('post_fn passed to make_pool is not picklable')
57 class _ProcessPool(object):
# Pool backed by `jobs` multiprocessing.Process workers, each running
# _loop. Work goes out on self.requests and results come back on
# self.responses as (_MessageType, payload) tuples.
# NOTE(review): this listing is lossy. Indentation is stripped and the
# embedded source numbering skips lines (e.g. 60-61, 72-74, 76-77,
# 87-89, 106, 110-116). self.host and self.workers are read below but
# never assigned in the visible text. The `def` lines enclosing the
# statements after __init__ are also missing. Confirm against the real
# file before relying on this block.
59 def __init__(self, host, jobs, callback, context, pre_fn, post_fn):
# Shared queues: requests from this process to the workers, and
# responses back from them.
62 self.requests = multiprocessing.Queue()
63 self.responses = multiprocessing.Queue()
# Messages read while waiting for a worker's Done (see the drain loop
# below) that were not Done/Error/Interrupt.
65 self.discarded_responses = []
# One worker process per job, numbered 1..jobs, each targeting _loop
# with both queues and a host from host.for_mp().
# NOTE(review): the rest of the args tuple and the process start call
# (source lines 72-74) are missing here.
68 for worker_num in range(1, jobs + 1):
69 w = multiprocessing.Process(target=_loop,
70 args=(self.requests, self.responses,
71 host.for_mp(), worker_num,
75 self.workers.append(w)
# Enqueue one work item; whichever worker reads it next passes it to
# callback.
78 self.requests.put((_MessageType.Request, msg))
# Block until a worker reports back. An Error payload is re-raised in
# this process via _handle_error, an Interrupt becomes KeyboardInterrupt,
# and anything else must be a Response. What happens to resp next is on
# a line missing from this listing.
81 msg_type, resp = self.responses.get()
82 if msg_type == _MessageType.Error:
83 self._handle_error(resp)
84 elif msg_type == _MessageType.Interrupt:
85 raise KeyboardInterrupt
86 assert msg_type == _MessageType.Response
# Ask every worker to finish: one Close message per worker.
90 for _ in self.workers:
91 self.requests.put((_MessageType.Close, None))
95 # TODO: one would think that we could close self.requests in close(),
96 # above, and close self.responses below, but if we do, we get
97 # weird tracebacks in the daemon threads multiprocessing starts up.
98 # Instead, we have to hack the innards of multiprocessing. It
99 # seems likely that there's a bug somewhere, either in this module or
100 # in multiprocessing.
101 if self.host.is_python3: # pragma: python3
102 multiprocessing.queues.is_exiting = lambda: True
103 else: # pragma: python2
104 multiprocessing.util._exiting = True
# NOTE(review): the condition guarding this abort path (source line 106)
# and the body of its loop (source lines 110-116) are missing from this
# listing.
107 # We must be aborting; terminate the workers rather than
108 # shutting down cleanly.
109 for w in self.workers:
# Drain self.responses once per worker. A Done payload is
# (worker_num, post_fn result), so resp[1] is the post_fn result that
# is collected. Error and Interrupt outcomes are presumably recorded for
# the checks at the end (their handling lines 121-122 and 124-125 are
# missing). Other messages appear to be kept in
# self.discarded_responses.
117 for w in self.workers:
119 msg_type, resp = self.responses.get()
120 if msg_type == _MessageType.Error:
123 if msg_type == _MessageType.Interrupt:
126 if msg_type == _MessageType.Done:
127 final_responses.append(resp[1])
129 self.discarded_responses.append(resp)
# Second pass over the workers; its body (source line 132) is missing
# from this listing.
131 for w in self.workers:
134 # TODO: See comment above at the beginning of the function for
135 # why this is commented out.
136 # self.responses.close()
# Surface a worker error first, then an interrupt. The guarding
# conditions are on lines missing from this listing. Otherwise return
# the collected post_fn results.
139 self._handle_error(error)
141 raise KeyboardInterrupt
142 return final_responses
# Turn a worker's (worker_num, error text) payload into a generic
# Exception raised in this process.
144 def _handle_error(self, msg):
145 worker_num, ex_str = msg
147 raise Exception("error from worker %d: %s" % (worker_num, ex_str))
# 'Too many arguments' pylint: disable=R0913

def _loop(requests, responses, host, worker_num,
          callback, context, pre_fn, post_fn, should_loop=True):
    """Run one pool worker until it is told to close or something fails.

    Runs pre_fn once, then serves (type, payload) messages from
    `requests`. Each Request payload is passed to callback and the result
    is put on `responses` as a Response. A Close makes the worker run
    post_fn, report (worker_num, post_fn result) as Done, and return.

    Args:
        requests: queue of (_MessageType, payload) tuples for this worker.
        responses: queue this worker reports (_MessageType, payload) on.
        host: Host handed to pre_fn; a default Host() is used if falsy.
        worker_num: 1-based id of this worker, included in Done/Error.
        callback: called as callback(per_worker_context, payload).
        context: initial state handed to pre_fn.
        pre_fn: called as pre_fn(host, worker_num, context); its result is
            the per-worker context.
        post_fn: called with the per-worker context on Close.
        should_loop: if False, return after serving a single Request.

    Failures are reported on `responses` instead of being raised.
    KeyboardInterrupt becomes an Interrupt message; any other exception
    becomes an Error message carrying the formatted traceback.
    """
    host = host or Host()
    try:
        context_after_pre = pre_fn(host, worker_num, context)
        keep_looping = True
        while keep_looping:
            message_type, args = requests.get(block=True)
            if message_type == _MessageType.Close:
                responses.put((_MessageType.Done,
                               (worker_num, post_fn(context_after_pre))))
                break
            assert message_type == _MessageType.Request
            resp = callback(context_after_pre, args)
            responses.put((_MessageType.Response, resp))
            keep_looping = should_loop
    except KeyboardInterrupt as e:
        responses.put((_MessageType.Interrupt, (worker_num, str(e))))
    except Exception:  # pylint: disable=broad-except
        # str(e) is empty for many exceptions (e.g. a bare AssertionError),
        # which left the parent with just "error from worker N: ". Send
        # the full traceback so the failure can actually be diagnosed.
        responses.put((_MessageType.Error,
                       (worker_num, traceback.format_exc())))
174 class _AsyncPool(object):
# In-process counterpart of _ProcessPool: a single "worker"
# (worker_num 1) that calls callback directly in the calling process.
# NOTE(review): this listing is lossy. Indentation is stripped and the
# embedded source numbering skips lines 175, 178, 181-182, 186-187,
# 189-190, 192-194 and 196-199, some of which must have held code.
# self.msgs is used below but its initialization is not visible. The
# statements after __init__ reference `msg` and `self`, so they
# presumably belong to other methods whose `def` lines are missing.
# Confirm against the real file.
176 def __init__(self, host, jobs, callback, context, pre_fn, post_fn):
# host falls back to a default Host() when falsy. `jobs` is not used in
# the visible lines.
177 self.host = host or Host()
179 self.callback = callback
# Deep copy, presumably so pre_fn/callback mutations do not leak back to
# the caller's context object.
180 self.context = copy.deepcopy(context)
183 self.post_fn = post_fn
# pre_fn runs immediately, acting as worker number 1.
184 self.context_after_pre = pre_fn(self.host, 1, self.context)
185 self.final_context = None
# Queue a message. It is processed lazily by the pop(0) call below.
188 self.msgs.append(msg)
# Process the oldest queued message (FIFO) and return callback's result.
# NOTE(review): list.pop(0) is O(n); collections.deque.popleft() would
# be O(1) if queues grow large.
191 return self.callback(self.context_after_pre, self.msgs.pop(0))
# Run post_fn on the per-worker context and keep its result.
195 self.final_context = self.post_fn(self.context_after_pre)
# Report the post_fn result as a one-element list, the same shape as
# _ProcessPool's list of per-worker post_fn results.
200 return [self.final_context]