automagically set detection batch size from network
authorEvan Shelhamer <shelhamer@imaginarynumber.net>
Sun, 26 Jan 2014 04:53:00 +0000 (20:53 -0800)
committerEvan Shelhamer <shelhamer@imaginarynumber.net>
Sun, 26 Jan 2014 04:53:00 +0000 (20:53 -0800)
python/caffe/detection/detector.py

index 2b71336..2651538 100644 (file)
@@ -36,10 +36,12 @@ IMAGE_CENTER = None
 IMAGE_MEAN = None
 CROPPED_IMAGE_MEAN = None
 
+BATCH_SIZE = None
 NUM_OUTPUT = None
 
 CROP_MODES = ['center_only', 'corners', 'selective_search']
 
+
 def load_image(filename):
   """
   Input:
@@ -181,7 +183,7 @@ def _assemble_images_selective_search(image_fnames):
   return images_df
 
 
-def assemble_batches(image_fnames, crop_mode='center_only', batch_size=10):
+def assemble_batches(image_fnames, crop_mode='center_only'):
   """
   Assemble DataFrame of image crops for feature computation.
 
@@ -195,7 +197,7 @@ def assemble_batches(image_fnames, crop_mode='center_only', batch_size=10):
         image, and take each enclosing subwindow.
 
   Output:
-    df_batches: list of DataFrames, each one of batch_size rows.
+    df_batches: list of DataFrames, each one of BATCH_SIZE rows.
       Each row has 'image', 'filename', and 'window' info.
       Column 'image' contains (X x 3 x 227 x 227) ndarrays.
       Column 'filename' contains source filenames.
@@ -216,23 +218,23 @@ def assemble_batches(image_fnames, crop_mode='center_only', batch_size=10):
   else:
     raise Exception("Unknown mode: not in {}".format(CROP_MODES))
 
-  # Make sure the DataFrame has a multiple of batch_size rows:
+  # Make sure the DataFrame has a multiple of BATCH_SIZE rows:
   # just fill the extra rows with NaN filenames and all-zero images.
   N = images_df.shape[0]
-  remainder = N % batch_size
+  remainder = N % BATCH_SIZE
   if remainder > 0:
     zero_image = np.zeros_like(images_df['image'].iloc[0])
     remainder_df = pd.DataFrame([{
       'filename': None,
       'image': zero_image,
       'window': [0, 0, 0, 0]
-    }] * (batch_size - remainder))
+    }] * (BATCH_SIZE - remainder))
     images_df = images_df.append(remainder_df)
     N = images_df.shape[0]
 
-  # Split into batches of batch_size.
-  ind = np.arange(N) / batch_size
-  df_batches = [images_df[ind == i] for i in range(N / batch_size)]
+  # Split into batches of BATCH_SIZE.
+  ind = np.arange(N) / BATCH_SIZE
+  df_batches = [images_df[ind == i] for i in range(N / BATCH_SIZE)]
   return df_batches
 
 
@@ -254,7 +256,7 @@ def compute_feats(images_df):
 
 def config(model_def, pretrained_model, gpu, image_dim, image_mean_file):
   global IMAGE_DIM, CROPPED_DIM, IMAGE_CENTER, IMAGE_MEAN, CROPPED_IMAGE_MEAN
-  global NET, NUM_OUTPUT
+  global NET, BATCH_SIZE, NUM_OUTPUT
 
   # Initialize network by loading model definition and weights.
   t = time.time()
@@ -273,11 +275,11 @@ def config(model_def, pretrained_model, gpu, image_dim, image_mean_file):
     # Load the data set mean file
   IMAGE_MEAN = np.load(image_mean_file)
 
-
   CROPPED_IMAGE_MEAN = IMAGE_MEAN[IMAGE_CENTER:IMAGE_CENTER + CROPPED_DIM,
                                   IMAGE_CENTER:IMAGE_CENTER + CROPPED_DIM,
                                   :]
-  NUM_OUTPUT = NET.blobs()[-1].channels # number of output classes
+  BATCH_SIZE = NET.blobs()[0].num  # network batch size
+  NUM_OUTPUT = NET.blobs()[-1].channels  # number of output classes
 
 
 if __name__ == "__main__":
@@ -293,8 +295,6 @@ if __name__ == "__main__":
   gflags.DEFINE_string(
     "images_file", "", "Image filenames file.")
   gflags.DEFINE_string(
-    "batch_size", 10, "Number of image crops to let through in one go")
-  gflags.DEFINE_string(
     "output_file", "", "Output DataFrame HDF5 filename.")
   gflags.DEFINE_string(
     "images_dim", 256, "Canonical dimension of (square) images.")
@@ -305,7 +305,6 @@ if __name__ == "__main__":
   FLAGS = gflags.FLAGS
   FLAGS(sys.argv)
 
-
   # Configure network, input, output
   config(FLAGS.model_def, FLAGS.pretrained_model, FLAGS.gpu, FLAGS.images_dim,
          FLAGS.images_mean_file)
@@ -315,8 +314,7 @@ if __name__ == "__main__":
   print('Assembling batches...')
   with open(FLAGS.images_file) as f:
     image_fnames = [_.strip() for _ in f.readlines()]
-  image_batches = assemble_batches(image_fnames, FLAGS.crop_mode,
-                                   FLAGS.batch_size)
+  image_batches = assemble_batches(image_fnames, FLAGS.crop_mode)
   print('{} batches assembled in {:.3f} s'.format(len(image_batches),
                                                   time.time() - t))