pycaffe Net.forward() helper
authorEvan Shelhamer <shelhamer@imaginarynumber.net>
Wed, 14 May 2014 20:39:06 +0000 (13:39 -0700)
committerEvan Shelhamer <shelhamer@imaginarynumber.net>
Wed, 14 May 2014 20:44:02 +0000 (13:44 -0700)
Do forward pass by prefilled or packaging input + output blobs and
returning a {output blob name: output list} dict.

python/caffe/pycaffe.py

index a7bc278..4053815 100644 (file)
@@ -42,6 +42,42 @@ def _Net_params(self):
 Net.params = _Net_params
 
 
+def _Net_forward(self, **kwargs):
+  """
+  Forward pass: prepare inputs and run the net forward.
+
+  Take
+    kwargs: Keys are input blob names and values are lists of inputs.
+            Images must be (H x W x K) ndarrays.
+            If None, input is taken from data layers by ForwardPrefilled().
+
+  Give
+    out: {output blob name: list of output blobs} dict.
+  """
+  outs = {}
+  if not kwargs:
+    # Carry out prefilled forward pass and unpack output.
+    self.ForwardPrefilled()
+    out_blobs = [self.blobs[out].data for out in self.outputs]
+  else:
+    # Create input and output blobs according to net defined shapes
+    # and make arrays single and C-contiguous as Caffe expects.
+    in_blobs = [np.ascontiguousarray(np.concatenate(kwargs[in_]),
+                                     dtype=np.float32) for in_ in self.inputs]
+    out_blobs = [np.empty(self.blobs[out].data.shape, dtype=np.float32)
+                 for out in self.outputs]
+
+    self.Forward(in_blobs, out_blobs)
+
+  # Unpack output blobs
+  for out, out_blob in zip(self.outputs, out_blobs):
+    outs[out] = [out_blob[ix, :, :, :].squeeze()
+                  for ix in range(out_blob.shape[0])]
+  return outs
+
+Net.forward = _Net_forward
+
+
 def _Net_set_mean(self, input_, mean_f, mode='image'):
   """
   Set the mean to subtract for data centering.