Fixed calc_output_scale with NO_OUTPUT_SCALE flag set.
[profile/ivi/opencv.git] / samples / gpu / hog.cpp
1 #include <iostream>
2 #include <fstream>
3 #include <string>
4 #include <sstream>
5 #include <iomanip>
6 #include <stdexcept>
7 #include <opencv2/core/utility.hpp>
8 #include "opencv2/cuda.hpp"
9 #include "opencv2/highgui.hpp"
10 #include "opencv2/objdetect.hpp"
11 #include "opencv2/imgproc.hpp"
12
13 using namespace std;
14 using namespace cv;
15
16 bool help_showed = false;
17
18 class Args
19 {
20 public:
21     Args();
22     static Args read(int argc, char** argv);
23
24     string src;
25     bool src_is_video;
26     bool src_is_camera;
27     int camera_id;
28
29     bool write_video;
30     string dst_video;
31     double dst_video_fps;
32
33     bool make_gray;
34
35     bool resize_src;
36     int width, height;
37
38     double scale;
39     int nlevels;
40     int gr_threshold;
41
42     double hit_threshold;
43     bool hit_threshold_auto;
44
45     int win_width;
46     int win_stride_width, win_stride_height;
47
48     bool gamma_corr;
49 };
50
51
52 class App
53 {
54 public:
55     App(const Args& s);
56     void run();
57
58     void handleKey(char key);
59
60     void hogWorkBegin();
61     void hogWorkEnd();
62     string hogWorkFps() const;
63
64     void workBegin();
65     void workEnd();
66     string workFps() const;
67
68     string message() const;
69
70 private:
71     App operator=(App&);
72
73     Args args;
74     bool running;
75
76     bool use_gpu;
77     bool make_gray;
78     double scale;
79     int gr_threshold;
80     int nlevels;
81     double hit_threshold;
82     bool gamma_corr;
83
84     int64 hog_work_begin;
85     double hog_work_fps;
86
87     int64 work_begin;
88     double work_fps;
89 };
90
91 static void printHelp()
92 {
93     cout << "Histogram of Oriented Gradients descriptor and detector sample.\n"
94          << "\nUsage: hog_gpu\n"
95          << "  (<image>|--video <vide>|--camera <camera_id>) # frames source\n"
96          << "  [--make_gray <true/false>] # convert image to gray one or not\n"
97          << "  [--resize_src <true/false>] # do resize of the source image or not\n"
98          << "  [--width <int>] # resized image width\n"
99          << "  [--height <int>] # resized image height\n"
100          << "  [--hit_threshold <double>] # classifying plane distance threshold (0.0 usually)\n"
101          << "  [--scale <double>] # HOG window scale factor\n"
102          << "  [--nlevels <int>] # max number of HOG window scales\n"
103          << "  [--win_width <int>] # width of the window (48 or 64)\n"
104          << "  [--win_stride_width <int>] # distance by OX axis between neighbour wins\n"
105          << "  [--win_stride_height <int>] # distance by OY axis between neighbour wins\n"
106          << "  [--gr_threshold <int>] # merging similar rects constant\n"
107          << "  [--gamma_correct <int>] # do gamma correction or not\n"
108          << "  [--write_video <bool>] # write video or not\n"
109          << "  [--dst_video <path>] # output video path\n"
110          << "  [--dst_video_fps <double>] # output video fps\n";
111     help_showed = true;
112 }
113
114 int main(int argc, char** argv)
115 {
116     try
117     {
118         if (argc < 2)
119             printHelp();
120         Args args = Args::read(argc, argv);
121         if (help_showed)
122             return -1;
123         App app(args);
124         app.run();
125     }
126     catch (const Exception& e) { return cout << "error: "  << e.what() << endl, 1; }
127     catch (const exception& e) { return cout << "error: "  << e.what() << endl, 1; }
128     catch(...) { return cout << "unknown exception" << endl, 1; }
129     return 0;
130 }
131
132
133 Args::Args()
134 {
135     src_is_video = false;
136     src_is_camera = false;
137     camera_id = 0;
138
139     write_video = false;
140     dst_video_fps = 24.;
141
142     make_gray = false;
143
144     resize_src = false;
145     width = 640;
146     height = 480;
147
148     scale = 1.05;
149     nlevels = 13;
150     gr_threshold = 8;
151     hit_threshold = 1.4;
152     hit_threshold_auto = true;
153
154     win_width = 48;
155     win_stride_width = 8;
156     win_stride_height = 8;
157
158     gamma_corr = true;
159 }
160
161
162 Args Args::read(int argc, char** argv)
163 {
164     Args args;
165     for (int i = 1; i < argc; i++)
166     {
167         if (string(argv[i]) == "--make_gray") args.make_gray = (string(argv[++i]) == "true");
168         else if (string(argv[i]) == "--resize_src") args.resize_src = (string(argv[++i]) == "true");
169         else if (string(argv[i]) == "--width") args.width = atoi(argv[++i]);
170         else if (string(argv[i]) == "--height") args.height = atoi(argv[++i]);
171         else if (string(argv[i]) == "--hit_threshold")
172         {
173             args.hit_threshold = atof(argv[++i]);
174             args.hit_threshold_auto = false;
175         }
176         else if (string(argv[i]) == "--scale") args.scale = atof(argv[++i]);
177         else if (string(argv[i]) == "--nlevels") args.nlevels = atoi(argv[++i]);
178         else if (string(argv[i]) == "--win_width") args.win_width = atoi(argv[++i]);
179         else if (string(argv[i]) == "--win_stride_width") args.win_stride_width = atoi(argv[++i]);
180         else if (string(argv[i]) == "--win_stride_height") args.win_stride_height = atoi(argv[++i]);
181         else if (string(argv[i]) == "--gr_threshold") args.gr_threshold = atoi(argv[++i]);
182         else if (string(argv[i]) == "--gamma_correct") args.gamma_corr = (string(argv[++i]) == "true");
183         else if (string(argv[i]) == "--write_video") args.write_video = (string(argv[++i]) == "true");
184         else if (string(argv[i]) == "--dst_video") args.dst_video = argv[++i];
185         else if (string(argv[i]) == "--dst_video_fps") args.dst_video_fps = atof(argv[++i]);
186         else if (string(argv[i]) == "--help") printHelp();
187         else if (string(argv[i]) == "--video") { args.src = argv[++i]; args.src_is_video = true; }
188         else if (string(argv[i]) == "--camera") { args.camera_id = atoi(argv[++i]); args.src_is_camera = true; }
189         else if (args.src.empty()) args.src = argv[i];
190         else throw runtime_error((string("unknown key: ") + argv[i]));
191     }
192     return args;
193 }
194
195
196 App::App(const Args& s)
197 {
198     cv::cuda::printShortCudaDeviceInfo(cv::cuda::getDevice());
199
200     args = s;
201     cout << "\nControls:\n"
202          << "\tESC - exit\n"
203          << "\tm - change mode GPU <-> CPU\n"
204          << "\tg - convert image to gray or not\n"
205          << "\t1/q - increase/decrease HOG scale\n"
206          << "\t2/w - increase/decrease levels count\n"
207          << "\t3/e - increase/decrease HOG group threshold\n"
208          << "\t4/r - increase/decrease hit threshold\n"
209          << endl;
210
211     use_gpu = true;
212     make_gray = args.make_gray;
213     scale = args.scale;
214     gr_threshold = args.gr_threshold;
215     nlevels = args.nlevels;
216
217     if (args.hit_threshold_auto)
218         args.hit_threshold = args.win_width == 48 ? 1.4 : 0.;
219     hit_threshold = args.hit_threshold;
220
221     gamma_corr = args.gamma_corr;
222
223     if (args.win_width != 64 && args.win_width != 48)
224         args.win_width = 64;
225
226     cout << "Scale: " << scale << endl;
227     if (args.resize_src)
228         cout << "Resized source: (" << args.width << ", " << args.height << ")\n";
229     cout << "Group threshold: " << gr_threshold << endl;
230     cout << "Levels number: " << nlevels << endl;
231     cout << "Win width: " << args.win_width << endl;
232     cout << "Win stride: (" << args.win_stride_width << ", " << args.win_stride_height << ")\n";
233     cout << "Hit threshold: " << hit_threshold << endl;
234     cout << "Gamma correction: " << gamma_corr << endl;
235     cout << endl;
236 }
237
238
239 void App::run()
240 {
241     running = true;
242     cv::VideoWriter video_writer;
243
244     Size win_size(args.win_width, args.win_width * 2); //(64, 128) or (48, 96)
245     Size win_stride(args.win_stride_width, args.win_stride_height);
246
247     // Create HOG descriptors and detectors here
248     vector<float> detector;
249     if (win_size == Size(64, 128))
250         detector = cv::cuda::HOGDescriptor::getPeopleDetector64x128();
251     else
252         detector = cv::cuda::HOGDescriptor::getPeopleDetector48x96();
253
254     cv::cuda::HOGDescriptor gpu_hog(win_size, Size(16, 16), Size(8, 8), Size(8, 8), 9,
255                                    cv::cuda::HOGDescriptor::DEFAULT_WIN_SIGMA, 0.2, gamma_corr,
256                                    cv::cuda::HOGDescriptor::DEFAULT_NLEVELS);
257     cv::HOGDescriptor cpu_hog(win_size, Size(16, 16), Size(8, 8), Size(8, 8), 9, 1, -1,
258                               HOGDescriptor::L2Hys, 0.2, gamma_corr, cv::HOGDescriptor::DEFAULT_NLEVELS);
259     gpu_hog.setSVMDetector(detector);
260     cpu_hog.setSVMDetector(detector);
261
262     while (running)
263     {
264         VideoCapture vc;
265         Mat frame;
266
267         if (args.src_is_video)
268         {
269             vc.open(args.src.c_str());
270             if (!vc.isOpened())
271                 throw runtime_error(string("can't open video file: " + args.src));
272             vc >> frame;
273         }
274         else if (args.src_is_camera)
275         {
276             vc.open(args.camera_id);
277             if (!vc.isOpened())
278             {
279                 stringstream msg;
280                 msg << "can't open camera: " << args.camera_id;
281                 throw runtime_error(msg.str());
282             }
283             vc >> frame;
284         }
285         else
286         {
287             frame = imread(args.src);
288             if (frame.empty())
289                 throw runtime_error(string("can't open image file: " + args.src));
290         }
291
292         Mat img_aux, img, img_to_show;
293         cuda::GpuMat gpu_img;
294
295         // Iterate over all frames
296         while (running && !frame.empty())
297         {
298             workBegin();
299
300             // Change format of the image
301             if (make_gray) cvtColor(frame, img_aux, COLOR_BGR2GRAY);
302             else if (use_gpu) cvtColor(frame, img_aux, COLOR_BGR2BGRA);
303             else frame.copyTo(img_aux);
304
305             // Resize image
306             if (args.resize_src) resize(img_aux, img, Size(args.width, args.height));
307             else img = img_aux;
308             img_to_show = img;
309
310             gpu_hog.nlevels = nlevels;
311             cpu_hog.nlevels = nlevels;
312
313             vector<Rect> found;
314
315             // Perform HOG classification
316             hogWorkBegin();
317             if (use_gpu)
318             {
319                 gpu_img.upload(img);
320                 gpu_hog.detectMultiScale(gpu_img, found, hit_threshold, win_stride,
321                                          Size(0, 0), scale, gr_threshold);
322             }
323             else cpu_hog.detectMultiScale(img, found, hit_threshold, win_stride,
324                                           Size(0, 0), scale, gr_threshold);
325             hogWorkEnd();
326
327             // Draw positive classified windows
328             for (size_t i = 0; i < found.size(); i++)
329             {
330                 Rect r = found[i];
331                 rectangle(img_to_show, r.tl(), r.br(), Scalar(0, 255, 0), 3);
332             }
333
334             if (use_gpu)
335                 putText(img_to_show, "Mode: GPU", Point(5, 25), FONT_HERSHEY_SIMPLEX, 1., Scalar(255, 100, 0), 2);
336             else
337                 putText(img_to_show, "Mode: CPU", Point(5, 25), FONT_HERSHEY_SIMPLEX, 1., Scalar(255, 100, 0), 2);
338             putText(img_to_show, "FPS (HOG only): " + hogWorkFps(), Point(5, 65), FONT_HERSHEY_SIMPLEX, 1., Scalar(255, 100, 0), 2);
339             putText(img_to_show, "FPS (total): " + workFps(), Point(5, 105), FONT_HERSHEY_SIMPLEX, 1., Scalar(255, 100, 0), 2);
340             imshow("opencv_gpu_hog", img_to_show);
341
342             if (args.src_is_video || args.src_is_camera) vc >> frame;
343
344             workEnd();
345
346             if (args.write_video)
347             {
348                 if (!video_writer.isOpened())
349                 {
350                     video_writer.open(args.dst_video, VideoWriter::fourcc('x','v','i','d'), args.dst_video_fps,
351                                       img_to_show.size(), true);
352                     if (!video_writer.isOpened())
353                         throw std::runtime_error("can't create video writer");
354                 }
355
356                 if (make_gray) cvtColor(img_to_show, img, COLOR_GRAY2BGR);
357                 else cvtColor(img_to_show, img, COLOR_BGRA2BGR);
358
359                 video_writer << img;
360             }
361
362             handleKey((char)waitKey(3));
363         }
364     }
365 }
366
367
368 void App::handleKey(char key)
369 {
370     switch (key)
371     {
372     case 27:
373         running = false;
374         break;
375     case 'm':
376     case 'M':
377         use_gpu = !use_gpu;
378         cout << "Switched to " << (use_gpu ? "CUDA" : "CPU") << " mode\n";
379         break;
380     case 'g':
381     case 'G':
382         make_gray = !make_gray;
383         cout << "Convert image to gray: " << (make_gray ? "YES" : "NO") << endl;
384         break;
385     case '1':
386         scale *= 1.05;
387         cout << "Scale: " << scale << endl;
388         break;
389     case 'q':
390     case 'Q':
391         scale /= 1.05;
392         cout << "Scale: " << scale << endl;
393         break;
394     case '2':
395         nlevels++;
396         cout << "Levels number: " << nlevels << endl;
397         break;
398     case 'w':
399     case 'W':
400         nlevels = max(nlevels - 1, 1);
401         cout << "Levels number: " << nlevels << endl;
402         break;
403     case '3':
404         gr_threshold++;
405         cout << "Group threshold: " << gr_threshold << endl;
406         break;
407     case 'e':
408     case 'E':
409         gr_threshold = max(0, gr_threshold - 1);
410         cout << "Group threshold: " << gr_threshold << endl;
411         break;
412     case '4':
413         hit_threshold+=0.25;
414         cout << "Hit threshold: " << hit_threshold << endl;
415         break;
416     case 'r':
417     case 'R':
418         hit_threshold = max(0.0, hit_threshold - 0.25);
419         cout << "Hit threshold: " << hit_threshold << endl;
420         break;
421     case 'c':
422     case 'C':
423         gamma_corr = !gamma_corr;
424         cout << "Gamma correction: " << gamma_corr << endl;
425         break;
426     }
427 }
428
429
430 inline void App::hogWorkBegin() { hog_work_begin = getTickCount(); }
431
432 inline void App::hogWorkEnd()
433 {
434     int64 delta = getTickCount() - hog_work_begin;
435     double freq = getTickFrequency();
436     hog_work_fps = freq / delta;
437 }
438
439 inline string App::hogWorkFps() const
440 {
441     stringstream ss;
442     ss << hog_work_fps;
443     return ss.str();
444 }
445
446
447 inline void App::workBegin() { work_begin = getTickCount(); }
448
449 inline void App::workEnd()
450 {
451     int64 delta = getTickCount() - work_begin;
452     double freq = getTickFrequency();
453     work_fps = freq / delta;
454 }
455
456 inline string App::workFps() const
457 {
458     stringstream ss;
459     ss << work_fps;
460     return ss.str();
461 }