Merge pull request #1663 from vpisarev:ocl_experiments3
[profile/ivi/opencv.git] / modules / ocl / src / mssegmentation.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                           License Agreement
11 //                For Open Source Computer Vision Library
12 //
13 // Copyright (C) 2010-2012, Institute Of Software Chinese Academy Of Science, all rights reserved.
14 // Copyright (C) 2010-2012, Advanced Micro Devices, Inc., all rights reserved.
15 // Third party copyrights are property of their respective owners.
16 //
17 // @Authors
18 //
19 // Redistribution and use in source and binary forms, with or without modification,
20 // are permitted provided that the following conditions are met:
21 //
22 //   * Redistribution's of source code must retain the above copyright notice,
23 //     this list of conditions and the following disclaimer.
24 //
25 //   * Redistribution's in binary form must reproduce the above copyright notice,
26 //     this list of conditions and the following disclaimer in the documentation
27 //     and/or other oclMaterials provided with the distribution.
28 //
29 //   * The name of the copyright holders may not be used to endorse or promote products
30 //     derived from this software without specific prior written permission.
31 //
32 // This software is provided by the copyright holders and contributors "as is" and
33 // any express or implied warranties, including, but not limited to, the implied
34 // warranties of merchantability and fitness for a particular purpose are disclaimed.
35 // In no event shall the Intel Corporation or contributors be liable for any direct,
36 // indirect, incidental, special, exemplary, or consequential damages
37 // (including, but not limited to, procurement of substitute goods or services;
38 // loss of use, data, or profits; or business interruption) however caused
39 // and on any theory of liability, whether in contract, strict liability,
40 // or tort (including negligence or otherwise) arising in any way out of
41 // the use of this software, even if advised of the possibility of such damage.
42 //
43 //M*/
44
45 #include "precomp.hpp"
46 #include "opencl_kernels.hpp"
47
48 using namespace cv;
49 using namespace cv::ocl;
50
51 // Auxiliray stuff
52 namespace
53 {
54
55     //
56     // Declarations
57     //
58
59     class DjSets
60     {
61     public:
62         DjSets(int n);
63         int find(int elem);
64         int merge(int set1, int set2);
65
66         std::vector<int> parent;
67         std::vector<int> rank;
68         std::vector<int> size;
69     private:
70         DjSets(const DjSets &) {}
71         DjSets operator =(const DjSets &);
72     };
73
74     template <typename T>
75     struct GraphEdge
76     {
77         GraphEdge() {}
78         GraphEdge(int to, int next, const T &val) : to(to), next(next), val(val) {}
79         int to;
80         int next;
81         T val;
82     };
83
84
85     template <typename T>
86     class Graph
87     {
88     public:
89         typedef GraphEdge<T> Edge;
90
91         Graph(int numv, int nume_max);
92
93         void addEdge(int from, int to, const T &val = T());
94
95         std::vector<int> start;
96         std::vector<Edge> edges;
97
98         int numv;
99         int nume_max;
100         int nume;
101     private:
102         Graph(const Graph &) {}
103         Graph operator =(const Graph &) {}
104     };
105
106
107     struct SegmLinkVal
108     {
109         SegmLinkVal() {}
110         SegmLinkVal(int dr, int dsp) : dr(dr), dsp(dsp) {}
111         bool operator <(const SegmLinkVal &other) const
112         {
113             return dr + dsp < other.dr + other.dsp;
114         }
115         int dr;
116         int dsp;
117     };
118
119
120     struct SegmLink
121     {
122         SegmLink() {}
123         SegmLink(int from, int to, const SegmLinkVal &val)
124             : from(from), to(to), val(val) {}
125         bool operator <(const SegmLink &other) const
126         {
127             return val < other.val;
128         }
129         int from;
130         int to;
131         SegmLinkVal val;
132     };
133
134     //
135     // Implementation
136     //
137
138     DjSets DjSets::operator = (const DjSets &/*obj*/)
139     {
140         //cout << "Invalid DjSets constructor\n";
141         CV_Error(-1, "Invalid DjSets constructor\n");
142         return *this;
143     }
144
145     DjSets::DjSets(int n) : parent(n), rank(n, 0), size(n, 1)
146     {
147         for (int i = 0; i < n; ++i)
148             parent[i] = i;
149     }
150
151
152     inline int DjSets::find(int elem)
153     {
154         int set = elem;
155         while (set != parent[set])
156             set = parent[set];
157         while (elem != parent[elem])
158         {
159             int next = parent[elem];
160             parent[elem] = set;
161             elem = next;
162         }
163         return set;
164     }
165
166
167     inline int DjSets::merge(int set1, int set2)
168     {
169         if (rank[set1] < rank[set2])
170         {
171             parent[set1] = set2;
172             size[set2] += size[set1];
173             return set2;
174         }
175         if (rank[set2] < rank[set1])
176         {
177             parent[set2] = set1;
178             size[set1] += size[set2];
179             return set1;
180         }
181         parent[set1] = set2;
182         rank[set2]++;
183         size[set2] += size[set1];
184         return set2;
185     }
186
187
188     template <typename T>
189     Graph<T>::Graph(int numv, int nume_max) : start(numv, -1), edges(nume_max)
190     {
191         this->numv = numv;
192         this->nume_max = nume_max;
193         nume = 0;
194     }
195
196
197     template <typename T>
198     inline void Graph<T>::addEdge(int from, int to, const T &val)
199     {
200         edges[nume] = Edge(to, start[from], val);
201         start[from] = nume;
202         nume++;
203     }
204
205
206     inline int pix(int y, int x, int ncols)
207     {
208         return y * ncols + x;
209     }
210
211
212     inline int sqr(int x)
213     {
214         return x * x;
215     }
216
217
218     inline int dist2(const cv::Vec4b &lhs, const cv::Vec4b &rhs)
219     {
220         return sqr(lhs[0] - rhs[0]) + sqr(lhs[1] - rhs[1]) + sqr(lhs[2] - rhs[2]);
221     }
222
223
224     inline int dist2(const cv::Vec2s &lhs, const cv::Vec2s &rhs)
225     {
226         return sqr(lhs[0] - rhs[0]) + sqr(lhs[1] - rhs[1]);
227     }
228
229 } // anonymous namespace
230
231 namespace cv
232 {
233     namespace ocl
234     {
235
236         void meanShiftSegmentation(const oclMat &src, Mat &dst, int sp, int sr, int minsize, TermCriteria criteria)
237         {
238             CV_Assert(src.type() == CV_8UC4);
239             const int nrows = src.rows;
240             const int ncols = src.cols;
241             const int hr = sr;
242             const int hsp = sp;
243
244             // Perform mean shift procedure and obtain region and spatial maps
245             oclMat h_rmap, h_spmap;
246             meanShiftProc(src, h_rmap, h_spmap, sp, sr, criteria);
247             Mat rmap = h_rmap;
248             Mat spmap = h_spmap;
249
250             Graph<SegmLinkVal> g(nrows * ncols, 4 * (nrows - 1) * (ncols - 1)
251                                  + (nrows - 1) + (ncols - 1));
252
253             // Make region adjacent graph from image
254             Vec4b r1;
255             Vec4b r2[4];
256             Vec2s sp1;
257             Vec2s sp2[4];
258             int dr[4];
259             int dsp[4];
260             for (int y = 0; y < nrows - 1; ++y)
261             {
262                 Vec4b *ry = rmap.ptr<Vec4b>(y);
263                 Vec4b *ryp = rmap.ptr<Vec4b>(y + 1);
264                 Vec2s *spy = spmap.ptr<Vec2s>(y);
265                 Vec2s *spyp = spmap.ptr<Vec2s>(y + 1);
266                 for (int x = 0; x < ncols - 1; ++x)
267                 {
268                     r1 = ry[x];
269                     sp1 = spy[x];
270
271                     r2[0] = ry[x + 1];
272                     r2[1] = ryp[x];
273                     r2[2] = ryp[x + 1];
274                     r2[3] = ryp[x];
275
276                     sp2[0] = spy[x + 1];
277                     sp2[1] = spyp[x];
278                     sp2[2] = spyp[x + 1];
279                     sp2[3] = spyp[x];
280
281                     dr[0] = dist2(r1, r2[0]);
282                     dr[1] = dist2(r1, r2[1]);
283                     dr[2] = dist2(r1, r2[2]);
284                     dsp[0] = dist2(sp1, sp2[0]);
285                     dsp[1] = dist2(sp1, sp2[1]);
286                     dsp[2] = dist2(sp1, sp2[2]);
287
288                     r1 = ry[x + 1];
289                     sp1 = spy[x + 1];
290
291                     dr[3] = dist2(r1, r2[3]);
292                     dsp[3] = dist2(sp1, sp2[3]);
293
294                     g.addEdge(pix(y, x, ncols), pix(y, x + 1, ncols), SegmLinkVal(dr[0], dsp[0]));
295                     g.addEdge(pix(y, x, ncols), pix(y + 1, x, ncols), SegmLinkVal(dr[1], dsp[1]));
296                     g.addEdge(pix(y, x, ncols), pix(y + 1, x + 1, ncols), SegmLinkVal(dr[2], dsp[2]));
297                     g.addEdge(pix(y, x + 1, ncols), pix(y + 1, x, ncols), SegmLinkVal(dr[3], dsp[3]));
298                 }
299             }
300             for (int y = 0; y < nrows - 1; ++y)
301             {
302                 r1 = rmap.at<Vec4b>(y, ncols - 1);
303                 r2[0] = rmap.at<Vec4b>(y + 1, ncols - 1);
304                 sp1 = spmap.at<Vec2s>(y, ncols - 1);
305                 sp2[0] = spmap.at<Vec2s>(y + 1, ncols - 1);
306                 dr[0] = dist2(r1, r2[0]);
307                 dsp[0] = dist2(sp1, sp2[0]);
308                 g.addEdge(pix(y, ncols - 1, ncols), pix(y + 1, ncols - 1, ncols), SegmLinkVal(dr[0], dsp[0]));
309             }
310             for (int x = 0; x < ncols - 1; ++x)
311             {
312                 r1 = rmap.at<Vec4b>(nrows - 1, x);
313                 r2[0] = rmap.at<Vec4b>(nrows - 1, x + 1);
314                 sp1 = spmap.at<Vec2s>(nrows - 1, x);
315                 sp2[0] = spmap.at<Vec2s>(nrows - 1, x + 1);
316                 dr[0] = dist2(r1, r2[0]);
317                 dsp[0] = dist2(sp1, sp2[0]);
318                 g.addEdge(pix(nrows - 1, x, ncols), pix(nrows - 1, x + 1, ncols), SegmLinkVal(dr[0], dsp[0]));
319             }
320
321             DjSets comps(g.numv);
322
323             // Find adjacent components
324             for (int v = 0; v < g.numv; ++v)
325             {
326                 for (int e_it = g.start[v]; e_it != -1; e_it = g.edges[e_it].next)
327                 {
328                     int c1 = comps.find(v);
329                     int c2 = comps.find(g.edges[e_it].to);
330                     if (c1 != c2 && g.edges[e_it].val.dr < hr && g.edges[e_it].val.dsp < hsp)
331                         comps.merge(c1, c2);
332                 }
333             }
334
335             std::vector<SegmLink> edges;
336             edges.reserve(g.numv);
337
338             // Prepare edges connecting differnet components
339             for (int v = 0; v < g.numv; ++v)
340             {
341                 int c1 = comps.find(v);
342                 for (int e_it = g.start[v]; e_it != -1; e_it = g.edges[e_it].next)
343                 {
344                     int c2 = comps.find(g.edges[e_it].to);
345                     if (c1 != c2)
346                         edges.push_back(SegmLink(c1, c2, g.edges[e_it].val));
347                 }
348             }
349
350             // Sort all graph's edges connecting differnet components (in asceding order)
351             sort(edges.begin(), edges.end());
352
353             // Exclude small components (starting from the nearest couple)
354             for (size_t i = 0; i < edges.size(); ++i)
355             {
356                 int c1 = comps.find(edges[i].from);
357                 int c2 = comps.find(edges[i].to);
358                 if (c1 != c2 && (comps.size[c1] < minsize || comps.size[c2] < minsize))
359                     comps.merge(c1, c2);
360             }
361
362             // Compute sum of the pixel's colors which are in the same segment
363             Mat h_src = src;
364             std::vector<Vec4i> sumcols(nrows * ncols, Vec4i(0, 0, 0, 0));
365             for (int y = 0; y < nrows; ++y)
366             {
367                 Vec4b *h_srcy = h_src.ptr<Vec4b>(y);
368                 for (int x = 0; x < ncols; ++x)
369                 {
370                     int parent = comps.find(pix(y, x, ncols));
371                     Vec4b col = h_srcy[x];
372                     Vec4i &sumcol = sumcols[parent];
373                     sumcol[0] += col[0];
374                     sumcol[1] += col[1];
375                     sumcol[2] += col[2];
376                 }
377             }
378
379             // Create final image, color of each segment is the average color of its pixels
380             dst.create(src.size(), src.type());
381
382             for (int y = 0; y < nrows; ++y)
383             {
384                 Vec4b *dsty = dst.ptr<Vec4b>(y);
385                 for (int x = 0; x < ncols; ++x)
386                 {
387                     int parent = comps.find(pix(y, x, ncols));
388                     const Vec4i &sumcol = sumcols[parent];
389                     Vec4b &dstcol = dsty[x];
390                     dstcol[0] = static_cast<uchar>(sumcol[0] / comps.size[parent]);
391                     dstcol[1] = static_cast<uchar>(sumcol[1] / comps.size[parent]);
392                     dstcol[2] = static_cast<uchar>(sumcol[2] / comps.size[parent]);
393                 }
394             }
395         }
396
397     }
398 }