Merge pull request #3589 from JBosch:master
[platform/upstream/opencv.git] / samples / gpu / hog.cpp
1 #include <iostream>
2 #include <fstream>
3 #include <string>
4 #include <sstream>
5 #include <iomanip>
6 #include <stdexcept>
7 #include <opencv2/core/utility.hpp>
8 #include "opencv2/cudaobjdetect.hpp"
9 #include "opencv2/highgui.hpp"
10 #include "opencv2/objdetect.hpp"
11 #include "opencv2/imgproc.hpp"
12
13 using namespace std;
14 using namespace cv;
15
16 bool help_showed = false;
17
18 class Args
19 {
20 public:
21     Args();
22     static Args read(int argc, char** argv);
23
24     string src;
25     bool src_is_video;
26     bool src_is_camera;
27     int camera_id;
28
29     bool write_video;
30     string dst_video;
31     double dst_video_fps;
32
33     bool make_gray;
34
35     bool resize_src;
36     int width, height;
37
38     double scale;
39     int nlevels;
40     int gr_threshold;
41
42     double hit_threshold;
43     bool hit_threshold_auto;
44
45     int win_width;
46     int win_stride_width, win_stride_height;
47
48     bool gamma_corr;
49 };
50
51
52 class App
53 {
54 public:
55     App(const Args& s);
56     void run();
57
58     void handleKey(char key);
59
60     void hogWorkBegin();
61     void hogWorkEnd();
62     string hogWorkFps() const;
63
64     void workBegin();
65     void workEnd();
66     string workFps() const;
67
68     string message() const;
69
70 private:
71     App operator=(App&);
72
73     Args args;
74     bool running;
75
76     bool use_gpu;
77     bool make_gray;
78     double scale;
79     int gr_threshold;
80     int nlevels;
81     double hit_threshold;
82     bool gamma_corr;
83
84     int64 hog_work_begin;
85     double hog_work_fps;
86
87     int64 work_begin;
88     double work_fps;
89 };
90
91 static void printHelp()
92 {
93     cout << "Histogram of Oriented Gradients descriptor and detector sample.\n"
94          << "\nUsage: hog_gpu\n"
95          << "  (<image>|--video <vide>|--camera <camera_id>) # frames source\n"
96          << "  [--make_gray <true/false>] # convert image to gray one or not\n"
97          << "  [--resize_src <true/false>] # do resize of the source image or not\n"
98          << "  [--width <int>] # resized image width\n"
99          << "  [--height <int>] # resized image height\n"
100          << "  [--hit_threshold <double>] # classifying plane distance threshold (0.0 usually)\n"
101          << "  [--scale <double>] # HOG window scale factor\n"
102          << "  [--nlevels <int>] # max number of HOG window scales\n"
103          << "  [--win_width <int>] # width of the window (48 or 64)\n"
104          << "  [--win_stride_width <int>] # distance by OX axis between neighbour wins\n"
105          << "  [--win_stride_height <int>] # distance by OY axis between neighbour wins\n"
106          << "  [--gr_threshold <int>] # merging similar rects constant\n"
107          << "  [--gamma_correct <int>] # do gamma correction or not\n"
108          << "  [--write_video <bool>] # write video or not\n"
109          << "  [--dst_video <path>] # output video path\n"
110          << "  [--dst_video_fps <double>] # output video fps\n";
111     help_showed = true;
112 }
113
114 int main(int argc, char** argv)
115 {
116     try
117     {
118         Args args;
119         if (argc < 2)
120         {
121             printHelp();
122             args.camera_id = 0;
123             args.src_is_camera = true;
124         }
125         else
126         {
127             args = Args::read(argc, argv);
128             if (help_showed)
129                 return -1;
130         }
131         App app(args);
132         app.run();
133     }
134     catch (const Exception& e) { return cout << "error: "  << e.what() << endl, 1; }
135     catch (const exception& e) { return cout << "error: "  << e.what() << endl, 1; }
136     catch(...) { return cout << "unknown exception" << endl, 1; }
137     return 0;
138 }
139
140
141 Args::Args()
142 {
143     src_is_video = false;
144     src_is_camera = false;
145     camera_id = 0;
146
147     write_video = false;
148     dst_video_fps = 24.;
149
150     make_gray = false;
151
152     resize_src = false;
153     width = 640;
154     height = 480;
155
156     scale = 1.05;
157     nlevels = 13;
158     gr_threshold = 8;
159     hit_threshold = 1.4;
160     hit_threshold_auto = true;
161
162     win_width = 48;
163     win_stride_width = 8;
164     win_stride_height = 8;
165
166     gamma_corr = true;
167 }
168
169
170 Args Args::read(int argc, char** argv)
171 {
172     Args args;
173     for (int i = 1; i < argc; i++)
174     {
175         if (string(argv[i]) == "--make_gray") args.make_gray = (string(argv[++i]) == "true");
176         else if (string(argv[i]) == "--resize_src") args.resize_src = (string(argv[++i]) == "true");
177         else if (string(argv[i]) == "--width") args.width = atoi(argv[++i]);
178         else if (string(argv[i]) == "--height") args.height = atoi(argv[++i]);
179         else if (string(argv[i]) == "--hit_threshold")
180         {
181             args.hit_threshold = atof(argv[++i]);
182             args.hit_threshold_auto = false;
183         }
184         else if (string(argv[i]) == "--scale") args.scale = atof(argv[++i]);
185         else if (string(argv[i]) == "--nlevels") args.nlevels = atoi(argv[++i]);
186         else if (string(argv[i]) == "--win_width") args.win_width = atoi(argv[++i]);
187         else if (string(argv[i]) == "--win_stride_width") args.win_stride_width = atoi(argv[++i]);
188         else if (string(argv[i]) == "--win_stride_height") args.win_stride_height = atoi(argv[++i]);
189         else if (string(argv[i]) == "--gr_threshold") args.gr_threshold = atoi(argv[++i]);
190         else if (string(argv[i]) == "--gamma_correct") args.gamma_corr = (string(argv[++i]) == "true");
191         else if (string(argv[i]) == "--write_video") args.write_video = (string(argv[++i]) == "true");
192         else if (string(argv[i]) == "--dst_video") args.dst_video = argv[++i];
193         else if (string(argv[i]) == "--dst_video_fps") args.dst_video_fps = atof(argv[++i]);
194         else if (string(argv[i]) == "--help") printHelp();
195         else if (string(argv[i]) == "--video") { args.src = argv[++i]; args.src_is_video = true; }
196         else if (string(argv[i]) == "--camera") { args.camera_id = atoi(argv[++i]); args.src_is_camera = true; }
197         else if (args.src.empty()) args.src = argv[i];
198         else throw runtime_error((string("unknown key: ") + argv[i]));
199     }
200     return args;
201 }
202
203
204 App::App(const Args& s)
205 {
206     cv::cuda::printShortCudaDeviceInfo(cv::cuda::getDevice());
207
208     args = s;
209     cout << "\nControls:\n"
210          << "\tESC - exit\n"
211          << "\tm - change mode GPU <-> CPU\n"
212          << "\tg - convert image to gray or not\n"
213          << "\t1/q - increase/decrease HOG scale\n"
214          << "\t2/w - increase/decrease levels count\n"
215          << "\t3/e - increase/decrease HOG group threshold\n"
216          << "\t4/r - increase/decrease hit threshold\n"
217          << endl;
218
219     use_gpu = true;
220     make_gray = args.make_gray;
221     scale = args.scale;
222     gr_threshold = args.gr_threshold;
223     nlevels = args.nlevels;
224
225     if (args.hit_threshold_auto)
226         args.hit_threshold = args.win_width == 48 ? 1.4 : 0.;
227     hit_threshold = args.hit_threshold;
228
229     gamma_corr = args.gamma_corr;
230
231     if (args.win_width != 64 && args.win_width != 48)
232         args.win_width = 64;
233
234     cout << "Scale: " << scale << endl;
235     if (args.resize_src)
236         cout << "Resized source: (" << args.width << ", " << args.height << ")\n";
237     cout << "Group threshold: " << gr_threshold << endl;
238     cout << "Levels number: " << nlevels << endl;
239     cout << "Win width: " << args.win_width << endl;
240     cout << "Win stride: (" << args.win_stride_width << ", " << args.win_stride_height << ")\n";
241     cout << "Hit threshold: " << hit_threshold << endl;
242     cout << "Gamma correction: " << gamma_corr << endl;
243     cout << endl;
244 }
245
246
247 void App::run()
248 {
249     running = true;
250     cv::VideoWriter video_writer;
251
252     Size win_size(args.win_width, args.win_width * 2); //(64, 128) or (48, 96)
253     Size win_stride(args.win_stride_width, args.win_stride_height);
254
255     cv::Ptr<cv::cuda::HOG> gpu_hog = cv::cuda::HOG::create(win_size);
256     cv::HOGDescriptor cpu_hog(win_size, Size(16, 16), Size(8, 8), Size(8, 8), 9);
257
258     // Create HOG descriptors and detectors here
259     Mat detector = gpu_hog->getDefaultPeopleDetector();
260
261     gpu_hog->setSVMDetector(detector);
262     cpu_hog.setSVMDetector(detector);
263
264     while (running)
265     {
266         VideoCapture vc;
267         Mat frame;
268
269         if (args.src_is_video)
270         {
271             vc.open(args.src.c_str());
272             if (!vc.isOpened())
273                 throw runtime_error(string("can't open video file: " + args.src));
274             vc >> frame;
275         }
276         else if (args.src_is_camera)
277         {
278             vc.open(args.camera_id);
279             if (!vc.isOpened())
280             {
281                 stringstream msg;
282                 msg << "can't open camera: " << args.camera_id;
283                 throw runtime_error(msg.str());
284             }
285             vc >> frame;
286         }
287         else
288         {
289             frame = imread(args.src);
290             if (frame.empty())
291                 throw runtime_error(string("can't open image file: " + args.src));
292         }
293
294         Mat img_aux, img, img_to_show;
295         cuda::GpuMat gpu_img;
296
297         // Iterate over all frames
298         while (running && !frame.empty())
299         {
300             workBegin();
301
302             // Change format of the image
303             if (make_gray) cvtColor(frame, img_aux, COLOR_BGR2GRAY);
304             else if (use_gpu) cvtColor(frame, img_aux, COLOR_BGR2BGRA);
305             else frame.copyTo(img_aux);
306
307             // Resize image
308             if (args.resize_src) resize(img_aux, img, Size(args.width, args.height));
309             else img = img_aux;
310             img_to_show = img;
311
312             vector<Rect> found;
313
314             // Perform HOG classification
315             hogWorkBegin();
316             if (use_gpu)
317             {
318                 gpu_img.upload(img);
319                 gpu_hog->setNumLevels(nlevels);
320                 gpu_hog->setHitThreshold(hit_threshold);
321                 gpu_hog->setWinStride(win_stride);
322                 gpu_hog->setScaleFactor(scale);
323                 gpu_hog->setGroupThreshold(gr_threshold);
324                 gpu_hog->detectMultiScale(gpu_img, found);
325             }
326             else
327             {
328                 cpu_hog.nlevels = nlevels;
329                 cpu_hog.detectMultiScale(img, found, hit_threshold, win_stride,
330                                           Size(0, 0), scale, gr_threshold);
331             }
332             hogWorkEnd();
333
334             // Draw positive classified windows
335             for (size_t i = 0; i < found.size(); i++)
336             {
337                 Rect r = found[i];
338                 rectangle(img_to_show, r.tl(), r.br(), Scalar(0, 255, 0), 3);
339             }
340
341             if (use_gpu)
342                 putText(img_to_show, "Mode: GPU", Point(5, 25), FONT_HERSHEY_SIMPLEX, 1., Scalar(255, 100, 0), 2);
343             else
344                 putText(img_to_show, "Mode: CPU", Point(5, 25), FONT_HERSHEY_SIMPLEX, 1., Scalar(255, 100, 0), 2);
345             putText(img_to_show, "FPS (HOG only): " + hogWorkFps(), Point(5, 65), FONT_HERSHEY_SIMPLEX, 1., Scalar(255, 100, 0), 2);
346             putText(img_to_show, "FPS (total): " + workFps(), Point(5, 105), FONT_HERSHEY_SIMPLEX, 1., Scalar(255, 100, 0), 2);
347             imshow("opencv_gpu_hog", img_to_show);
348
349             if (args.src_is_video || args.src_is_camera) vc >> frame;
350
351             workEnd();
352
353             if (args.write_video)
354             {
355                 if (!video_writer.isOpened())
356                 {
357                     video_writer.open(args.dst_video, VideoWriter::fourcc('x','v','i','d'), args.dst_video_fps,
358                                       img_to_show.size(), true);
359                     if (!video_writer.isOpened())
360                         throw std::runtime_error("can't create video writer");
361                 }
362
363                 if (make_gray) cvtColor(img_to_show, img, COLOR_GRAY2BGR);
364                 else cvtColor(img_to_show, img, COLOR_BGRA2BGR);
365
366                 video_writer << img;
367             }
368
369             handleKey((char)waitKey(3));
370         }
371     }
372 }
373
374
375 void App::handleKey(char key)
376 {
377     switch (key)
378     {
379     case 27:
380         running = false;
381         break;
382     case 'm':
383     case 'M':
384         use_gpu = !use_gpu;
385         cout << "Switched to " << (use_gpu ? "CUDA" : "CPU") << " mode\n";
386         break;
387     case 'g':
388     case 'G':
389         make_gray = !make_gray;
390         cout << "Convert image to gray: " << (make_gray ? "YES" : "NO") << endl;
391         break;
392     case '1':
393         scale *= 1.05;
394         cout << "Scale: " << scale << endl;
395         break;
396     case 'q':
397     case 'Q':
398         scale /= 1.05;
399         cout << "Scale: " << scale << endl;
400         break;
401     case '2':
402         nlevels++;
403         cout << "Levels number: " << nlevels << endl;
404         break;
405     case 'w':
406     case 'W':
407         nlevels = max(nlevels - 1, 1);
408         cout << "Levels number: " << nlevels << endl;
409         break;
410     case '3':
411         gr_threshold++;
412         cout << "Group threshold: " << gr_threshold << endl;
413         break;
414     case 'e':
415     case 'E':
416         gr_threshold = max(0, gr_threshold - 1);
417         cout << "Group threshold: " << gr_threshold << endl;
418         break;
419     case '4':
420         hit_threshold+=0.25;
421         cout << "Hit threshold: " << hit_threshold << endl;
422         break;
423     case 'r':
424     case 'R':
425         hit_threshold = max(0.0, hit_threshold - 0.25);
426         cout << "Hit threshold: " << hit_threshold << endl;
427         break;
428     case 'c':
429     case 'C':
430         gamma_corr = !gamma_corr;
431         cout << "Gamma correction: " << gamma_corr << endl;
432         break;
433     }
434 }
435
436
437 inline void App::hogWorkBegin() { hog_work_begin = getTickCount(); }
438
439 inline void App::hogWorkEnd()
440 {
441     int64 delta = getTickCount() - hog_work_begin;
442     double freq = getTickFrequency();
443     hog_work_fps = freq / delta;
444 }
445
446 inline string App::hogWorkFps() const
447 {
448     stringstream ss;
449     ss << hog_work_fps;
450     return ss.str();
451 }
452
453
454 inline void App::workBegin() { work_begin = getTickCount(); }
455
456 inline void App::workEnd()
457 {
458     int64 delta = getTickCount() - work_begin;
459     double freq = getTickFrequency();
460     work_fps = freq / delta;
461 }
462
463 inline string App::workFps() const
464 {
465     stringstream ss;
466     ss << work_fps;
467     return ss.str();
468 }