Merge pull request #2887 from ilya-lavrenov:ipp_morph_fix
[platform/upstream/opencv.git] / modules / contrib / src / hybridtracker.cpp
1 //*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                                License Agreement
11 //                       For Open Source Computer Vision Library
12 //
13 // Copyright (C) 2000-2008, Intel Corporation, all rights reserved.
14 // Copyright (C) 2008-2011, Willow Garage Inc., all rights reserved.
15 // Third party copyrights are property of their respective owners.
16 //
17 // Redistribution and use in source and binary forms, with or without modification,
18 // are permitted provided that the following conditions are met:
19 //
20 //   * Redistribution's of source code must retain the above copyright notice,
21 //     this list of conditions and the following disclaimer.
22 //
23 //   * Redistribution's in binary form must reproduce the above copyright notice,
24 //     this list of conditions and the following disclaimer in the documentation
25 //     and/or other materials provided with the distribution.
26 //
27 //   * The name of Intel Corporation may not be used to endorse or promote products
28 //     derived from this software without specific prior written permission.
29 //
30 // This software is provided by the copyright holders and contributors "as is" and
31 // any express or implied warranties, including, but not limited to, the implied
32 // warranties of merchantability and fitness for a particular purpose are disclaimed.
33 // In no event shall the Intel Corporation or contributors be liable for any direct,
34 // indirect, incidental, special, exemplary, or consequential damages
35 // (including, but not limited to, procurement of substitute goods or services;
36 // loss of use, data, or profits; or business interruption) however caused
37 // and on any theory of liability, whether in contract, strict liability,
38 // or tort (including negligence or otherwise) arising in any way out of
39 // the use of this software, even if advised of the possibility of such damage.
40 //
41 //M*/
42 #include "precomp.hpp"
43 #include "opencv2/contrib/hybridtracker.hpp"
44
45 using namespace cv;
46
47 CvHybridTrackerParams::CvHybridTrackerParams(float _ft_tracker_weight, float _ms_tracker_weight,
48             CvFeatureTrackerParams _ft_params,
49             CvMeanShiftTrackerParams _ms_params,
50             CvMotionModel)
51 {
52     ft_tracker_weight = _ft_tracker_weight;
53     ms_tracker_weight = _ms_tracker_weight;
54     ft_params = _ft_params;
55     ms_params = _ms_params;
56 }
57
58 CvMeanShiftTrackerParams::CvMeanShiftTrackerParams(int _tracking_type, CvTermCriteria _term_crit)
59 {
60     tracking_type = _tracking_type;
61     term_crit = _term_crit;
62 }
63
64 CvHybridTracker::CvHybridTracker() {
65
66 }
67
68 CvHybridTracker::CvHybridTracker(HybridTrackerParams _params) :
69     params(_params) {
70     params.ft_params.feature_type = CvFeatureTrackerParams::SIFT;
71     mstracker = new CvMeanShiftTracker(params.ms_params);
72     fttracker = new CvFeatureTracker(params.ft_params);
73 }
74
75 CvHybridTracker::~CvHybridTracker() {
76     if (mstracker != NULL)
77         delete mstracker;
78     if (fttracker != NULL)
79         delete fttracker;
80 }
81
82 inline float CvHybridTracker::getL2Norm(Point2f p1, Point2f p2) {
83     float distance = (p1.x - p2.x) * (p1.x - p2.x) + (p1.y - p2.y) * (p1.y
84             - p2.y);
85     return std::sqrt(distance);
86 }
87
88 Mat CvHybridTracker::getDistanceProjection(Mat image, Point2f center) {
89     Mat hist(image.size(), CV_64F);
90
91     double lu = getL2Norm(Point(0, 0), center);
92     double ru = getL2Norm(Point(0, image.size().width), center);
93     double rd = getL2Norm(Point(image.size().height, image.size().width),
94             center);
95     double ld = getL2Norm(Point(image.size().height, 0), center);
96
97     double max = (lu < ru) ? lu : ru;
98     max = (max < rd) ? max : rd;
99     max = (max < ld) ? max : ld;
100
101     for (int i = 0; i < hist.rows; i++)
102         for (int j = 0; j < hist.cols; j++)
103             hist.at<double> (i, j) = 1.0 - (getL2Norm(Point(i, j), center)
104                     / max);
105
106     return hist;
107 }
108
109 Mat CvHybridTracker::getGaussianProjection(Mat image, int ksize, double sigma,
110         Point2f center) {
111     Mat kernel = getGaussianKernel(ksize, sigma, CV_64F);
112     double max = kernel.at<double> (ksize / 2);
113
114     Mat hist(image.size(), CV_64F);
115     for (int i = 0; i < hist.rows; i++)
116         for (int j = 0; j < hist.cols; j++) {
117             int pos = cvRound(getL2Norm(Point(i, j), center));
118             if (pos < ksize / 2.0)
119                 hist.at<double> (i, j) = 1.0 - (kernel.at<double> (pos) / max);
120         }
121
122     return hist;
123 }
124
125 void CvHybridTracker::newTracker(Mat image, Rect selection) {
126     prev_proj = Mat::zeros(image.size(), CV_64FC1);
127     prev_center = Point2f(selection.x + selection.width / 2.0f, selection.y
128             + selection.height / 2.0f);
129     prev_window = selection;
130
131     mstracker->newTrackingWindow(image, selection);
132     fttracker->newTrackingWindow(image, selection);
133
134     samples = cvCreateMat(2, 1, CV_32FC1);
135     labels = cvCreateMat(2, 1, CV_32SC1);
136
137     ittr = 0;
138 }
139
140 void CvHybridTracker::updateTracker(Mat image) {
141     ittr++;
142
143     //copy over clean images: TODO
144     mstracker->updateTrackingWindow(image);
145     fttracker->updateTrackingWindowWithFlow(image);
146
147     if (params.motion_model == CvMotionModel::EM)
148         updateTrackerWithEM(image);
149     else
150         updateTrackerWithLowPassFilter(image);
151
152     // Regression to find new weights
153     Point2f ms_center = mstracker->getTrackingEllipse().center;
154     Point2f ft_center = fttracker->getTrackingCenter();
155
156 #ifdef DEBUG_HYTRACKER
157     circle(image, ms_center, 3, Scalar(0, 0, 255), -1, 8);
158     circle(image, ft_center, 3, Scalar(255, 0, 0), -1, 8);
159     putText(image, "ms", Point(ms_center.x+2, ms_center.y), FONT_HERSHEY_PLAIN, 0.75, Scalar(255, 255, 255));
160     putText(image, "ft", Point(ft_center.x+2, ft_center.y), FONT_HERSHEY_PLAIN, 0.75, Scalar(255, 255, 255));
161 #endif
162
163     double ms_len = getL2Norm(ms_center, curr_center);
164     double ft_len = getL2Norm(ft_center, curr_center);
165     double total_len = ms_len + ft_len;
166
167     params.ms_tracker_weight *= (ittr - 1);
168     params.ms_tracker_weight += (float)((ms_len / total_len));
169     params.ms_tracker_weight /= ittr;
170     params.ft_tracker_weight *= (ittr - 1);
171     params.ft_tracker_weight += (float)((ft_len / total_len));
172     params.ft_tracker_weight /= ittr;
173
174     circle(image, prev_center, 3, Scalar(0, 0, 0), -1, 8);
175     circle(image, curr_center, 3, Scalar(255, 255, 255), -1, 8);
176
177     prev_center = curr_center;
178     prev_window.x = (int)(curr_center.x-prev_window.width/2.0);
179     prev_window.y = (int)(curr_center.y-prev_window.height/2.0);
180
181     mstracker->setTrackingWindow(prev_window);
182     fttracker->setTrackingWindow(prev_window);
183 }
184
185 void CvHybridTracker::updateTrackerWithEM(Mat image) {
186     Mat ms_backproj = mstracker->getHistogramProjection(CV_64F);
187     Mat ms_distproj = getDistanceProjection(image, mstracker->getTrackingCenter());
188     Mat ms_proj = ms_backproj.mul(ms_distproj);
189
190     float dist_err = getL2Norm(mstracker->getTrackingCenter(), fttracker->getTrackingCenter());
191     Mat ft_gaussproj = getGaussianProjection(image, cvRound(dist_err), -1, fttracker->getTrackingCenter());
192     Mat ft_distproj = getDistanceProjection(image, fttracker->getTrackingCenter());
193     Mat ft_proj = ft_gaussproj.mul(ft_distproj);
194
195     Mat proj = params.ms_tracker_weight * ms_proj + params.ft_tracker_weight * ft_proj + prev_proj;
196
197     int sample_count = countNonZero(proj);
198     cvReleaseMat(&samples);
199     cvReleaseMat(&labels);
200     samples = cvCreateMat(sample_count, 2, CV_32FC1);
201     labels = cvCreateMat(sample_count, 1, CV_32SC1);
202
203     int count = 0;
204     for (int i = 0; i < proj.rows; i++)
205         for (int j = 0; j < proj.cols; j++)
206             if (proj.at<double> (i, j) > 0) {
207                 samples->data.fl[count * 2] = (float)i;
208                 samples->data.fl[count * 2 + 1] = (float)j;
209                 count++;
210             }
211
212     cv::Mat lbls;
213
214     EM em_model(1, EM::COV_MAT_SPHERICAL, TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 10000, 0.001));
215     em_model.train(cvarrToMat(samples), noArray(), lbls);
216     if(labels)
217         lbls.copyTo(cvarrToMat(labels));
218
219     Mat em_means = em_model.get<Mat>("means");
220     curr_center.x = (float)em_means.at<float>(0, 0);
221     curr_center.y = (float)em_means.at<float>(0, 1);
222 }
223
224 void CvHybridTracker::updateTrackerWithLowPassFilter(Mat) {
225     RotatedRect ms_track = mstracker->getTrackingEllipse();
226     Point2f ft_center = fttracker->getTrackingCenter();
227
228     float a = params.low_pass_gain;
229     curr_center.x = (1 - a) * prev_center.x + a * (params.ms_tracker_weight * ms_track.center.x + params.ft_tracker_weight * ft_center.x);
230     curr_center.y = (1 - a) * prev_center.y + a * (params.ms_tracker_weight * ms_track.center.y + params.ft_tracker_weight * ft_center.y);
231 }
232
233 Rect CvHybridTracker::getTrackingWindow() {
234     return prev_window;
235 }