2 Wrap the internal caffe C++ module (_caffe.so) with a clean, Pythonic
from collections import OrderedDict
try:
    # Python 2
    from itertools import izip_longest
except ImportError:
    # Python 3
    from itertools import zip_longest as izip_longest

import numpy as np
import six

from ._caffe import Net, SGDSolver, NesterovSolver, AdaGradSolver, \
    RMSPropSolver, AdaDeltaSolver, AdamSolver, NCCL, Timer
19 # We directly update methods from Net here (rather than using composition or
20 # inheritance) so that nets created by caffe (e.g., by SGDSolver) will
21 # automatically have the improved interface.
27 An OrderedDict (bottom to top, i.e., input to output) of network
30 if not hasattr(self, '_blobs_dict'):
31 self._blobs_dict = OrderedDict(zip(self._blob_names, self._blobs))
32 return self._blobs_dict
36 def _Net_blob_loss_weights(self):
38 An OrderedDict (bottom to top, i.e., input to output) of network
39 blob loss weights indexed by name
41 if not hasattr(self, '_blobs_loss_weights_dict'):
42 self._blob_loss_weights_dict = OrderedDict(zip(self._blob_names,
43 self._blob_loss_weights))
44 return self._blob_loss_weights_dict
47 def _Net_layer_dict(self):
49 An OrderedDict (bottom to top, i.e., input to output) of network
50 layers indexed by name
52 if not hasattr(self, '_layer_dict'):
53 self._layer_dict = OrderedDict(zip(self._layer_names, self.layers))
54 return self._layer_dict
58 def _Net_params(self):
60 An OrderedDict (bottom to top, i.e., input to output) of network
61 parameters indexed by name; each is a list of multiple blobs (e.g.,
64 if not hasattr(self, '_params_dict'):
65 self._params_dict = OrderedDict([(name, lr.blobs)
67 self._layer_names, self.layers)
68 if len(lr.blobs) > 0])
69 return self._params_dict
73 def _Net_inputs(self):
74 if not hasattr(self, '_input_list'):
75 keys = list(self.blobs.keys())
76 self._input_list = [keys[i] for i in self._inputs]
77 return self._input_list
81 def _Net_outputs(self):
82 if not hasattr(self, '_output_list'):
83 keys = list(self.blobs.keys())
84 self._output_list = [keys[i] for i in self._outputs]
85 return self._output_list
88 def _Net_forward(self, blobs=None, start=None, end=None, **kwargs):
90 Forward pass: prepare inputs and run the net forward.
94 blobs : list of blobs to return in addition to output blobs.
95 kwargs : Keys are input blob names and values are blob ndarrays.
96 For formatting inputs for Caffe, see Net.preprocess().
97 If None, input is taken from data layers.
98 start : optional name of layer at which to begin the forward pass
99 end : optional name of layer at which to finish the forward pass
104 outs : {blob name: blob ndarray} dict.
109 if start is not None:
110 start_ind = list(self._layer_names).index(start)
115 end_ind = list(self._layer_names).index(end)
116 outputs = set([end] + blobs)
118 end_ind = len(self.layers) - 1
119 outputs = set(self.outputs + blobs)
122 if set(kwargs.keys()) != set(self.inputs):
123 raise Exception('Input blob arguments do not match net inputs.')
124 # Set input according to defined shapes and make arrays single and
125 # C-contiguous as Caffe expects.
126 for in_, blob in six.iteritems(kwargs):
127 if blob.shape[0] != self.blobs[in_].shape[0]:
128 raise Exception('Input is not batch sized')
129 self.blobs[in_].data[...] = blob
131 self._forward(start_ind, end_ind)
133 # Unpack blobs to extract
134 return {out: self.blobs[out].data for out in outputs}
137 def _Net_backward(self, diffs=None, start=None, end=None, **kwargs):
139 Backward pass: prepare diffs and run the net backward.
143 diffs : list of diffs to return in addition to bottom diffs.
144 kwargs : Keys are output blob names and values are diff ndarrays.
145 If None, top diffs are taken from forward loss.
146 start : optional name of layer at which to begin the backward pass
147 end : optional name of layer at which to finish the backward pass
152 outs: {blob name: diff ndarray} dict.
157 if start is not None:
158 start_ind = list(self._layer_names).index(start)
160 start_ind = len(self.layers) - 1
163 end_ind = list(self._layer_names).index(end)
164 outputs = set([end] + diffs)
167 outputs = set(self.inputs + diffs)
170 if set(kwargs.keys()) != set(self.outputs):
171 raise Exception('Top diff arguments do not match net outputs.')
172 # Set top diffs according to defined shapes and make arrays single and
173 # C-contiguous as Caffe expects.
174 for top, diff in six.iteritems(kwargs):
175 if diff.shape[0] != self.blobs[top].shape[0]:
176 raise Exception('Diff is not batch sized')
177 self.blobs[top].diff[...] = diff
179 self._backward(start_ind, end_ind)
181 # Unpack diffs to extract
182 return {out: self.blobs[out].diff for out in outputs}
185 def _Net_forward_all(self, blobs=None, **kwargs):
187 Run net forward in batches.
191 blobs : list of blobs to extract as in forward()
192 kwargs : Keys are input blob names and values are blob ndarrays.
197 all_outs : {blob name: list of blobs} dict.
199 # Collect outputs from batches
200 all_outs = {out: [] for out in set(self.outputs + (blobs or []))}
201 for batch in self._batch(kwargs):
202 outs = self.forward(blobs=blobs, **batch)
203 for out, out_blob in six.iteritems(outs):
204 all_outs[out].extend(out_blob.copy())
205 # Package in ndarray.
207 all_outs[out] = np.asarray(all_outs[out])
209 pad = len(six.next(six.itervalues(all_outs))) - len(six.next(six.itervalues(kwargs)))
212 all_outs[out] = all_outs[out][:-pad]
216 def _Net_forward_backward_all(self, blobs=None, diffs=None, **kwargs):
218 Run net forward + backward in batches.
222 blobs: list of blobs to extract as in forward()
223 diffs: list of diffs to extract as in backward()
224 kwargs: Keys are input (for forward) and output (for backward) blob names
225 and values are ndarrays. Refer to forward() and backward().
226 Prefilled variants are called for lack of input or output blobs.
230 all_blobs: {blob name: blob ndarray} dict.
231 all_diffs: {blob name: diff ndarray} dict.
233 # Batch blobs and diffs.
234 all_outs = {out: [] for out in set(self.outputs + (blobs or []))}
235 all_diffs = {diff: [] for diff in set(self.inputs + (diffs or []))}
236 forward_batches = self._batch({in_: kwargs[in_]
237 for in_ in self.inputs if in_ in kwargs})
238 backward_batches = self._batch({out: kwargs[out]
239 for out in self.outputs if out in kwargs})
240 # Collect outputs from batches (and heed lack of forward/backward batches).
241 for fb, bb in izip_longest(forward_batches, backward_batches, fillvalue={}):
242 batch_blobs = self.forward(blobs=blobs, **fb)
243 batch_diffs = self.backward(diffs=diffs, **bb)
244 for out, out_blobs in six.iteritems(batch_blobs):
245 all_outs[out].extend(out_blobs.copy())
246 for diff, out_diffs in six.iteritems(batch_diffs):
247 all_diffs[diff].extend(out_diffs.copy())
248 # Package in ndarray.
249 for out, diff in zip(all_outs, all_diffs):
250 all_outs[out] = np.asarray(all_outs[out])
251 all_diffs[diff] = np.asarray(all_diffs[diff])
252 # Discard padding at the end and package in ndarray.
253 pad = len(six.next(six.itervalues(all_outs))) - len(six.next(six.itervalues(kwargs)))
255 for out, diff in zip(all_outs, all_diffs):
256 all_outs[out] = all_outs[out][:-pad]
257 all_diffs[diff] = all_diffs[diff][:-pad]
258 return all_outs, all_diffs
261 def _Net_set_input_arrays(self, data, labels):
263 Set input arrays of the in-memory MemoryDataLayer.
264 (Note: this is only for networks declared with the memory data layer.)
267 labels = np.ascontiguousarray(labels[:, np.newaxis, np.newaxis,
269 return self._set_input_arrays(data, labels)
272 def _Net_batch(self, blobs):
274 Batch blob lists according to net's batch size.
278 blobs: Keys blob names and values are lists of blobs (of any length).
279 Naturally, all the lists should have the same length.
283 batch: {blob name: list of blobs} dict for a single batch.
285 num = len(six.next(six.itervalues(blobs)))
286 batch_size = six.next(six.itervalues(self.blobs)).shape[0]
287 remainder = num % batch_size
288 num_batches = num // batch_size
290 # Yield full batches.
291 for b in range(num_batches):
293 yield {name: blobs[name][i:i + batch_size] for name in blobs}
295 # Yield last padded batch, if any.
299 padding = np.zeros((batch_size - remainder,)
300 + blobs[name].shape[1:])
301 padded_batch[name] = np.concatenate([blobs[name][-remainder:],
305 def _Net_get_id_name(func, field):
307 Generic property that maps func to the layer names into an OrderedDict.
309 Used for top_names and bottom_names.
313 func: function id -> [id]
314 field: implementation field name (cache)
318 A one-parameter function that can be set as a property.
321 def get_id_name(self):
322 if not hasattr(self, field):
323 id_to_name = list(self.blobs)
324 res = OrderedDict([(self._layer_names[i],
325 [id_to_name[j] for j in func(self, i)])
326 for i in range(len(self.layers))])
327 setattr(self, field, res)
328 return getattr(self, field)
# Attach methods to Net.
# Monkey-patching the extension class (see header note above the accessors)
# makes every Net instance — including ones constructed inside C++ by the
# solvers — carry this pythonic interface.
# NOTE(review): the _Net_* accessors read like properties; their @property
# decorators are not visible in this view of the file — confirm upstream
# before relying on attribute-style access (net.blobs vs net.blobs()).
Net.blobs = _Net_blobs
Net.blob_loss_weights = _Net_blob_loss_weights
Net.layer_dict = _Net_layer_dict
Net.params = _Net_params
# Batched forward/backward drivers.
Net.forward = _Net_forward
Net.backward = _Net_backward
Net.forward_all = _Net_forward_all
Net.forward_backward_all = _Net_forward_backward_all
Net.set_input_arrays = _Net_set_input_arrays
Net._batch = _Net_batch
Net.inputs = _Net_inputs
Net.outputs = _Net_outputs
# top_names/bottom_names map layer name -> list of blob names, built from
# the C++-exposed id lists and cached under the given private field.
Net.top_names = _Net_get_id_name(Net._top_ids, "_top_names")
Net.bottom_names = _Net_get_id_name(Net._bottom_ids, "_bottom_names")