Update release_notes.md
[platform/upstream/caffeonacl.git] / python / detect.py
1 #!/usr/bin/env python
2 """
3 detector.py is an out-of-the-box windowed detector
4 callable from the command line.
5
6 By default it configures and runs the Caffe reference ImageNet model.
7 Note that this model was trained for image classification and not detection,
8 and finetuning for detection can be expected to improve results.
9
10 The selective_search_ijcv_with_python code required for the selective search
11 proposal mode is available at
12     https://github.com/sergeyk/selective_search_ijcv_with_python
13
14 TODO:
15 - batch up image filenames as well: don't want to load all of them into memory
16 - come up with a batching scheme that preserved order / keeps a unique ID
17 """
18 import numpy as np
19 import pandas as pd
20 import os
21 import argparse
22 import time
23
24 import caffe
25
26 CROP_MODES = ['list', 'selective_search']
27 COORD_COLS = ['ymin', 'xmin', 'ymax', 'xmax']
28
29
30 def main(argv):
31     pycaffe_dir = os.path.dirname(__file__)
32
33     parser = argparse.ArgumentParser()
34     # Required arguments: input and output.
35     parser.add_argument(
36         "input_file",
37         help="Input txt/csv filename. If .txt, must be list of filenames.\
38         If .csv, must be comma-separated file with header\
39         'filename, xmin, ymin, xmax, ymax'"
40     )
41     parser.add_argument(
42         "output_file",
43         help="Output h5/csv filename. Format depends on extension."
44     )
45     # Optional arguments.
46     parser.add_argument(
47         "--model_def",
48         default=os.path.join(pycaffe_dir,
49                 "../models/bvlc_reference_caffenet/deploy.prototxt"),
50         help="Model definition file."
51     )
52     parser.add_argument(
53         "--pretrained_model",
54         default=os.path.join(pycaffe_dir,
55                 "../models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel"),
56         help="Trained model weights file."
57     )
58     parser.add_argument(
59         "--crop_mode",
60         default="selective_search",
61         choices=CROP_MODES,
62         help="How to generate windows for detection."
63     )
64     parser.add_argument(
65         "--gpu",
66         action='store_true',
67         help="Switch for gpu computation."
68     )
69     parser.add_argument(
70         "--mean_file",
71         default=os.path.join(pycaffe_dir,
72                              'caffe/imagenet/ilsvrc_2012_mean.npy'),
73         help="Data set image mean of H x W x K dimensions (numpy array). " +
74              "Set to '' for no mean subtraction."
75     )
76     parser.add_argument(
77         "--input_scale",
78         type=float,
79         help="Multiply input features by this scale to finish preprocessing."
80     )
81     parser.add_argument(
82         "--raw_scale",
83         type=float,
84         default=255.0,
85         help="Multiply raw input by this scale before preprocessing."
86     )
87     parser.add_argument(
88         "--channel_swap",
89         default='2,1,0',
90         help="Order to permute input channels. The default converts " +
91              "RGB -> BGR since BGR is the Caffe default by way of OpenCV."
92
93     )
94     parser.add_argument(
95         "--context_pad",
96         type=int,
97         default='16',
98         help="Amount of surrounding context to collect in input window."
99     )
100     args = parser.parse_args()
101
102     mean, channel_swap = None, None
103     if args.mean_file:
104         mean = np.load(args.mean_file)
105         if mean.shape[1:] != (1, 1):
106             mean = mean.mean(1).mean(1)
107     if args.channel_swap:
108         channel_swap = [int(s) for s in args.channel_swap.split(',')]
109
110     if args.gpu:
111         caffe.set_mode_gpu()
112         print("GPU mode")
113     else:
114         caffe.set_mode_cpu()
115         print("CPU mode")
116
117     # Make detector.
118     detector = caffe.Detector(args.model_def, args.pretrained_model, mean=mean,
119             input_scale=args.input_scale, raw_scale=args.raw_scale,
120             channel_swap=channel_swap,
121             context_pad=args.context_pad)
122
123     # Load input.
124     t = time.time()
125     print("Loading input...")
126     if args.input_file.lower().endswith('txt'):
127         with open(args.input_file) as f:
128             inputs = [_.strip() for _ in f.readlines()]
129     elif args.input_file.lower().endswith('csv'):
130         inputs = pd.read_csv(args.input_file, sep=',', dtype={'filename': str})
131         inputs.set_index('filename', inplace=True)
132     else:
133         raise Exception("Unknown input file type: not in txt or csv.")
134
135     # Detect.
136     if args.crop_mode == 'list':
137         # Unpack sequence of (image filename, windows).
138         images_windows = [
139             (ix, inputs.iloc[np.where(inputs.index == ix)][COORD_COLS].values)
140             for ix in inputs.index.unique()
141         ]
142         detections = detector.detect_windows(images_windows)
143     else:
144         detections = detector.detect_selective_search(inputs)
145     print("Processed {} windows in {:.3f} s.".format(len(detections),
146                                                      time.time() - t))
147
148     # Collect into dataframe with labeled fields.
149     df = pd.DataFrame(detections)
150     df.set_index('filename', inplace=True)
151     df[COORD_COLS] = pd.DataFrame(
152         data=np.vstack(df['window']), index=df.index, columns=COORD_COLS)
153     del(df['window'])
154
155     # Save results.
156     t = time.time()
157     if args.output_file.lower().endswith('csv'):
158         # csv
159         # Enumerate the class probabilities.
160         class_cols = ['class{}'.format(x) for x in range(NUM_OUTPUT)]
161         df[class_cols] = pd.DataFrame(
162             data=np.vstack(df['feat']), index=df.index, columns=class_cols)
163         df.to_csv(args.output_file, cols=COORD_COLS + class_cols)
164     else:
165         # h5
166         df.to_hdf(args.output_file, 'df', mode='w')
167     print("Saved to {} in {:.3f} s.".format(args.output_file,
168                                             time.time() - t))
169
170
171 if __name__ == "__main__":
172     import sys
173     main(sys.argv)