python forward() and backward() extract any blobs and diffs
authorEvan Shelhamer <shelhamer@imaginarynumber.net>
Thu, 15 May 2014 00:38:33 +0000 (17:38 -0700)
committerEvan Shelhamer <shelhamer@imaginarynumber.net>
Thu, 15 May 2014 03:14:55 +0000 (20:14 -0700)
python/caffe/pycaffe.py

index 101deab..5275a07 100644 (file)
@@ -42,18 +42,22 @@ def _Net_params(self):
 Net.params = _Net_params
 
 
-def _Net_forward(self, **kwargs):
+def _Net_forward(self, blobs=None, **kwargs):
   """
   Forward pass: prepare inputs and run the net forward.
 
   Take
+    blobs: list of blobs to return in addition to output blobs.
     kwargs: Keys are input blob names and values are lists of inputs.
             Images must be (H x W x K) ndarrays.
             If None, input is taken from data layers by ForwardPrefilled().
 
   Give
-    outs: {output blob name: list of output blobs} dict.
+    outs: {blob name: list of blobs ndarrays} dict.
   """
+  if blobs is None:
+    blobs = []
+
   if not kwargs:
     # Carry out prefilled forward pass and unpack output.
     self.ForwardPrefilled()
@@ -68,9 +72,11 @@ def _Net_forward(self, **kwargs):
 
     self.Forward(in_blobs, out_blobs)
 
-  # Unpack output blobs
+  # Unpack blobs to extract
   outs = {}
-  for out, out_blob in zip(self.outputs, out_blobs):
+  out_blobs.extend([self.blobs[blob].data for blob in blobs])
+  out_blob_names = self.outputs + blobs
+  for out, out_blob in zip(out_blob_names, out_blobs):
     outs[out] = [out_blob[ix, :, :, :].squeeze()
                   for ix in range(out_blob.shape[0])]
   return outs
@@ -78,35 +84,43 @@ def _Net_forward(self, **kwargs):
 Net.forward = _Net_forward
 
 
-def _Net_backward(self, **kwargs):
+def _Net_backward(self, diffs=None, **kwargs):
   """
   Backward pass: prepare diffs and run the net backward.
 
   Take
+    diffs: list of diffs to return in addition to bottom diffs.
     kwargs: Keys are output blob names and values are lists of diffs.
-            If None, input is taken from data layers by BackwardPrefilled().
+            If None, top diffs are taken from loss by BackwardPrefilled().
 
   Give
-    bottom_diffs: {input blob name: list of diffs} dict.
+    outs: {blob name: list of diffs} dict.
   """
+  if diffs is None:
+    diffs = []
+
   if not kwargs:
+    # Carry out backward with forward loss diffs and unpack bottom diffs.
     self.BackwardPrefilled()
-    bottom_diffs = [self.blobs[in_].diff for in_ in self.inputs]
+    out_diffs = [self.blobs[in_].diff for in_ in self.inputs]
   else:
     # Create top and bottom diffs according to net defined shapes
     # and make arrays single and C-contiguous as Caffe expects.
     top_diffs = [np.ascontiguousarray(np.concatenate(kwargs[out]),
                                       dtype=np.float32) for out in self.outputs]
-    bottom_diffs = [np.empty(self.blobs[bottom].data.shape, dtype=np.float32)
-                    for bottom in self.inputs]
-    self.Backward(top_diffs, bottom_diffs)
-
-  # Unpack bottom diffs
-  bottom_diffs = {}
-  for bottom, bottom_diff in zip(self.inputs, bottom_diffs):
-    bottom_diffs[bottom] = [bottom_diff[ix, :, :, :].squeeze()
-                             for ix in range(bottom_diff.shape[0])]
-  return bottom_diffs
+    out_diffs = [np.empty(self.blobs[bottom].diff.shape, dtype=np.float32)
+                 for bottom in self.inputs]
+
+    self.Backward(top_diffs, out_diffs)
+
+  # Unpack diffs to extract
+  outs = {}
+  out_diffs.extend([self.blobs[diff].diff for diff in diffs])
+  out_diff_names = self.inputs + diffs
+  for out, out_diff in zip(out_diff_names, out_diffs):
+    outs[out] = [out_diff[ix, :, :, :].squeeze()
+                           for ix in range(out_diff.shape[0])]
+  return outs
 
 Net.backward = _Net_backward