Merge pull request #733 from longjon/pycaffe-tweaks
authorEvan Shelhamer <shelhamer@imaginarynumber.net>
Mon, 28 Jul 2014 21:14:03 +0000 (14:14 -0700)
committerEvan Shelhamer <shelhamer@imaginarynumber.net>
Mon, 28 Jul 2014 21:14:03 +0000 (14:14 -0700)
pycaffe fixes

1  2 
python/caffe/pycaffe.py

@@@ -70,19 -57,20 +70,19 @@@ def _Net_forward(self, blobs=None, star
          # Set input according to defined shapes and make arrays single and
          # C-contiguous as Caffe expects.
          for in_, blob in kwargs.iteritems():
-             if blob.shape[0] != self.blobs[in_].num:
-                 raise Exception('Input is not batch sized')
              if blob.ndim != 4:
                  raise Exception('{} blob is not 4-d'.format(in_))
+             if blob.shape[0] != self.blobs[in_].num:
+                 raise Exception('Input is not batch sized')
              self.blobs[in_].data[...] = blob
  
 -    self._forward()
 +    self._forward(start_ind, end_ind)
  
      # Unpack blobs to extract
 -    outs = {out: self.blobs[out].data for out in set(self.outputs + blobs)}
 -    return outs
 +    return {out: self.blobs[out].data for out in outputs}
  
  
 -def _Net_backward(self, diffs=None, **kwargs):
 +def _Net_backward(self, diffs=None, start=None, end=None, **kwargs):
      """
      Backward pass: prepare diffs and run the net backward.
  
          # Set top diffs according to defined shapes and make arrays single and
          # C-contiguous as Caffe expects.
          for top, diff in kwargs.iteritems():
-             if diff.shape[0] != self.blobs[top].num:
-                 raise Exception('Diff is not batch sized')
              if diff.ndim != 4:
                  raise Exception('{} diff is not 4-d'.format(top))
+             if diff.shape[0] != self.blobs[top].num:
+                 raise Exception('Diff is not batch sized')
              self.blobs[top].diff[...] = diff
  
 -    self._backward()
 +    self._backward(start_ind, end_ind)
  
      # Unpack diffs to extract
 -    outs = {out: self.blobs[out].diff for out in set(self.inputs + diffs)}
 -    return outs
 +    return {out: self.blobs[out].diff for out in outputs}
  
  
  def _Net_forward_all(self, blobs=None, **kwargs):