added dual tvl1 optical flow gpu implementation
[profile/ivi/opencv.git] / modules / gpu / src / cuda / row_filter.h
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                           License Agreement
11 //                For Open Source Computer Vision Library
12 //
13 // Copyright (C) 2000-2008, Intel Corporation, all rights reserved.
14 // Copyright (C) 2009, Willow Garage Inc., all rights reserved.
15 // Copyright (C) 1993-2011, NVIDIA Corporation, all rights reserved.
16 // Third party copyrights are property of their respective owners.
17 //
18 // Redistribution and use in source and binary forms, with or without modification,
19 // are permitted provided that the following conditions are met:
20 //
21 //   * Redistribution's of source code must retain the above copyright notice,
22 //     this list of conditions and the following disclaimer.
23 //
24 //   * Redistribution's in binary form must reproduce the above copyright notice,
25 //     this list of conditions and the following disclaimer in the documentation
26 //     and/or other materials provided with the distribution.
27 //
28 //   * The name of the copyright holders may not be used to endorse or promote products
29 //     derived from this software without specific prior written permission.
30 //
31 // This software is provided by the copyright holders and contributors "as is" and
32 // any express or implied warranties, including, but not limited to, the implied
33 // warranties of merchantability and fitness for a particular purpose are disclaimed.
34 // In no event shall the Intel Corporation or contributors be liable for any direct,
35 // indirect, incidental, special, exemplary, or consequential damages
36 // (including, but not limited to, procurement of substitute goods or services;
37 // loss of use, data, or profits; or business interruption) however caused
38 // and on any theory of liability, whether in contract, strict liability,
39 // or tort (including negligence or otherwise) arising in any way out of
40 // the use of this software, even if advised of the possibility of such damage.
41 //
42 //M*/
43
44 #include "opencv2/gpu/device/common.hpp"
45 #include "opencv2/gpu/device/saturate_cast.hpp"
46 #include "opencv2/gpu/device/vec_math.hpp"
47 #include "opencv2/gpu/device/border_interpolate.hpp"
48
49 using namespace cv::gpu;
50 using namespace cv::gpu::device;
51
namespace row_filter
{
    #define MAX_KERNEL_SIZE 32

    // Filter taps read by linearRowFilter; only the first KSIZE entries are used.
    // filter::linearRow uploads them right before launching.
    __constant__ float c_kernel[MAX_KERNEL_SIZE];

    // Horizontal (row-wise) 1-D convolution with the KSIZE taps stored in c_kernel.
    //
    // Expected launch layout (produced by caller() below):
    //   block = (BLOCK_DIM_X, BLOCK_DIM_Y) - each threadIdx.y handles one image row
    //   grid  = (divUp(src.cols, BLOCK_DIM_X * PATCH_PER_BLOCK), divUp(src.rows, BLOCK_DIM_Y))
    // Each thread row stages PATCH_PER_BLOCK * BLOCK_DIM_X source pixels plus
    // HALO_SIZE * BLOCK_DIM_X pixels on each side in static shared memory, converted to float
    // vectors (sum_t). With HALO_SIZE * BLOCK_DIM_X == 32, that covers KSIZE <= 32 taps for any
    // anchor in [0, KSIZE).
    //
    // src    - source image
    // dst    - destination; written at (y, x) for every y < src.rows, x < src.cols
    // anchor - index of the tap aligned with the output pixel (assumed in [0, KSIZE))
    // brd    - border handler consulted for columns that may fall outside [0, src.cols)
    template <int KSIZE, typename T, typename D, typename B>
    __global__ void linearRowFilter(const PtrStepSz<T> src, PtrStep<D> dst, const int anchor, const B brd)
    {
        // These constants must agree with the ones caller() picks from `cc` on the host;
        // the actual blockDim is not checked here.
        #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 200)
            const int BLOCK_DIM_X = 32;
            const int BLOCK_DIM_Y = 8;
            const int PATCH_PER_BLOCK = 4;
            const int HALO_SIZE = 1;
        #else
            const int BLOCK_DIM_X = 32;
            const int BLOCK_DIM_Y = 4;
            const int PATCH_PER_BLOCK = 4;
            const int HALO_SIZE = 1;
        #endif

        typedef typename TypeVec<float, VecTraits<T>::cn>::vec_type sum_t;

        // One shared row per thread row: [left halo | PATCH_PER_BLOCK patches | right halo].
        __shared__ sum_t smem[BLOCK_DIM_Y][(PATCH_PER_BLOCK + 2 * HALO_SIZE) * BLOCK_DIM_X];

        const int y = blockIdx.y * BLOCK_DIM_Y + threadIdx.y;
        const int xStart = blockIdx.x * (PATCH_PER_BLOCK * BLOCK_DIM_X) + threadIdx.x;

        // Threads below the last image row must still reach the __syncthreads() below:
        // a barrier that only part of the block reaches is undefined behaviour. Instead of
        // returning early, they skip the loads here and the stores after the barrier.
        // Their smem row is never read by any other thread row.
        const bool rowInRange = y < src.rows;

        if (rowInRange)
        {
            const T* src_row = src.ptr(y);

            if (blockIdx.x > 0)
            {
                // Left halo starts at xStart - HALO_SIZE * BLOCK_DIM_X >= 0: plain loads.
                #pragma unroll
                for (int j = 0; j < HALO_SIZE; ++j)
                    smem[threadIdx.y][threadIdx.x + j * BLOCK_DIM_X] = saturate_cast<sum_t>(src_row[xStart - (HALO_SIZE - j) * BLOCK_DIM_X]);
            }
            else
            {
                // First block: the left halo lies at negative x and goes through the border handler.
                #pragma unroll
                for (int j = 0; j < HALO_SIZE; ++j)
                    smem[threadIdx.y][threadIdx.x + j * BLOCK_DIM_X] = saturate_cast<sum_t>(brd.at_low(xStart - (HALO_SIZE - j) * BLOCK_DIM_X, src_row));
            }

            if (blockIdx.x + 2 < gridDim.x)
            {
                // With the grid from caller(), block blockIdx.x + 2 exists, so it starts at a column
                // < src.cols. The right halo ends before that start, so every column read here is
                // in range: plain loads.
                #pragma unroll
                for (int j = 0; j < PATCH_PER_BLOCK; ++j)
                    smem[threadIdx.y][threadIdx.x + HALO_SIZE * BLOCK_DIM_X + j * BLOCK_DIM_X] = saturate_cast<sum_t>(src_row[xStart + j * BLOCK_DIM_X]);

                #pragma unroll
                for (int j = 0; j < HALO_SIZE; ++j)
                    smem[threadIdx.y][threadIdx.x + (PATCH_PER_BLOCK + HALO_SIZE) * BLOCK_DIM_X + j * BLOCK_DIM_X] = saturate_cast<sum_t>(src_row[xStart + (PATCH_PER_BLOCK + j) * BLOCK_DIM_X]);
            }
            else
            {
                // Last two blocks: the main data and the right halo may run past src.cols,
                // so every column goes through the border handler.
                #pragma unroll
                for (int j = 0; j < PATCH_PER_BLOCK; ++j)
                    smem[threadIdx.y][threadIdx.x + HALO_SIZE * BLOCK_DIM_X + j * BLOCK_DIM_X] = saturate_cast<sum_t>(brd.at_high(xStart + j * BLOCK_DIM_X, src_row));

                #pragma unroll
                for (int j = 0; j < HALO_SIZE; ++j)
                    smem[threadIdx.y][threadIdx.x + (PATCH_PER_BLOCK + HALO_SIZE) * BLOCK_DIM_X + j * BLOCK_DIM_X] = saturate_cast<sum_t>(brd.at_high(xStart + (PATCH_PER_BLOCK + j) * BLOCK_DIM_X, src_row));
            }
        }

        __syncthreads();

        if (!rowInRange)
            return;

        #pragma unroll
        for (int j = 0; j < PATCH_PER_BLOCK; ++j)
        {
            const int x = xStart + j * BLOCK_DIM_X;

            if (x < src.cols)
            {
                sum_t sum = VecTraits<sum_t>::all(0);

                // Tap k multiplies source column x - anchor + k.
                #pragma unroll
                for (int k = 0; k < KSIZE; ++k)
                    sum = sum + smem[threadIdx.y][threadIdx.x + HALO_SIZE * BLOCK_DIM_X + j * BLOCK_DIM_X - anchor + k] * c_kernel[k];

                dst(y, x) = saturate_cast<D>(sum);
            }
        }
    }

    // Host launcher for linearRowFilter<KSIZE, T, D, B<T>>.
    //
    // cc     - compute capability, compared against 20 to mirror the kernel's
    //          __CUDA_ARCH__ >= 200 switch (presumably major * 10 + minor; verify at call sites)
    // stream - launch stream; when 0, the call blocks until the kernel has finished
    //
    // An empty source image is a no-op.
    template <int KSIZE, typename T, typename D, template<typename> class B>
    void caller(PtrStepSz<T> src, PtrStepSz<D> dst, int anchor, int cc, cudaStream_t stream)
    {
        // divUp(0, n) == 0 would produce a zero-sized grid, which the runtime rejects as an
        // invalid launch configuration; there is nothing to filter anyway.
        if (src.rows <= 0 || src.cols <= 0)
            return;

        // Must match the constants linearRowFilter selects for the same architecture.
        const int BLOCK_DIM_X = 32;
        const int BLOCK_DIM_Y = (cc >= 20) ? 8 : 4;
        const int PATCH_PER_BLOCK = 4;

        const dim3 block(BLOCK_DIM_X, BLOCK_DIM_Y);
        const dim3 grid(divUp(src.cols, BLOCK_DIM_X * PATCH_PER_BLOCK), divUp(src.rows, BLOCK_DIM_Y));

        B<T> brd(src.cols);

        linearRowFilter<KSIZE, T, D><<<grid, block, 0, stream>>>(src, dst, anchor, brd);
        cudaSafeCall( cudaGetLastError() );

        if (stream == 0)
            cudaSafeCall( cudaDeviceSynchronize() );
    }
}
178
namespace filter
{
    // Applies a 1-D horizontal filter of `ksize` taps to `src`, writing `dst`.
    //
    // src, dst - type-erased images, reinterpreted as PtrStepSz<T> / PtrStepSz<D>
    // kernel   - device pointer to `ksize` floats (copied device-to-device into row_filter::c_kernel)
    // ksize    - number of taps; the dispatch table only has entries for 1..32
    // anchor   - tap aligned with the output pixel; must satisfy 0 <= anchor < ksize
    // brd_type - table row: 0 Reflect101, 1 Replicate, 2 Constant, 3 Reflect, 4 Wrap
    // cc       - compute capability forwarded to the launcher to choose the block shape
    // stream   - 0: synchronous upload and launch; otherwise both are queued on `stream`
    //
    // Out-of-range ksize / brd_type / anchor are reported via cudaSafeCall(cudaErrorInvalidValue)
    // rather than indexing past the dispatch table, calling its null entries, or overflowing
    // row_filter::c_kernel.
    //
    // NOTE(review): row_filter::c_kernel is a single global, so concurrent calls on different
    // streams may overwrite each other's taps before their kernels run.
    template <typename T, typename D>
    void linearRow(PtrStepSzb src, PtrStepSzb dst, const float* kernel, int ksize, int anchor, int brd_type, int cc, cudaStream_t stream)
    {
        typedef void (*caller_t)(PtrStepSz<T> src, PtrStepSz<D> dst, int anchor, int cc, cudaStream_t stream);

        // callers[brd_type][ksize]; column 0 is a null placeholder so ksize indexes directly.
        static const caller_t callers[5][33] =
        {
            {
                0,
                row_filter::caller< 1, T, D, BrdRowReflect101>,
                row_filter::caller< 2, T, D, BrdRowReflect101>,
                row_filter::caller< 3, T, D, BrdRowReflect101>,
                row_filter::caller< 4, T, D, BrdRowReflect101>,
                row_filter::caller< 5, T, D, BrdRowReflect101>,
                row_filter::caller< 6, T, D, BrdRowReflect101>,
                row_filter::caller< 7, T, D, BrdRowReflect101>,
                row_filter::caller< 8, T, D, BrdRowReflect101>,
                row_filter::caller< 9, T, D, BrdRowReflect101>,
                row_filter::caller<10, T, D, BrdRowReflect101>,
                row_filter::caller<11, T, D, BrdRowReflect101>,
                row_filter::caller<12, T, D, BrdRowReflect101>,
                row_filter::caller<13, T, D, BrdRowReflect101>,
                row_filter::caller<14, T, D, BrdRowReflect101>,
                row_filter::caller<15, T, D, BrdRowReflect101>,
                row_filter::caller<16, T, D, BrdRowReflect101>,
                row_filter::caller<17, T, D, BrdRowReflect101>,
                row_filter::caller<18, T, D, BrdRowReflect101>,
                row_filter::caller<19, T, D, BrdRowReflect101>,
                row_filter::caller<20, T, D, BrdRowReflect101>,
                row_filter::caller<21, T, D, BrdRowReflect101>,
                row_filter::caller<22, T, D, BrdRowReflect101>,
                row_filter::caller<23, T, D, BrdRowReflect101>,
                row_filter::caller<24, T, D, BrdRowReflect101>,
                row_filter::caller<25, T, D, BrdRowReflect101>,
                row_filter::caller<26, T, D, BrdRowReflect101>,
                row_filter::caller<27, T, D, BrdRowReflect101>,
                row_filter::caller<28, T, D, BrdRowReflect101>,
                row_filter::caller<29, T, D, BrdRowReflect101>,
                row_filter::caller<30, T, D, BrdRowReflect101>,
                row_filter::caller<31, T, D, BrdRowReflect101>,
                row_filter::caller<32, T, D, BrdRowReflect101>
            },
            {
                0,
                row_filter::caller< 1, T, D, BrdRowReplicate>,
                row_filter::caller< 2, T, D, BrdRowReplicate>,
                row_filter::caller< 3, T, D, BrdRowReplicate>,
                row_filter::caller< 4, T, D, BrdRowReplicate>,
                row_filter::caller< 5, T, D, BrdRowReplicate>,
                row_filter::caller< 6, T, D, BrdRowReplicate>,
                row_filter::caller< 7, T, D, BrdRowReplicate>,
                row_filter::caller< 8, T, D, BrdRowReplicate>,
                row_filter::caller< 9, T, D, BrdRowReplicate>,
                row_filter::caller<10, T, D, BrdRowReplicate>,
                row_filter::caller<11, T, D, BrdRowReplicate>,
                row_filter::caller<12, T, D, BrdRowReplicate>,
                row_filter::caller<13, T, D, BrdRowReplicate>,
                row_filter::caller<14, T, D, BrdRowReplicate>,
                row_filter::caller<15, T, D, BrdRowReplicate>,
                row_filter::caller<16, T, D, BrdRowReplicate>,
                row_filter::caller<17, T, D, BrdRowReplicate>,
                row_filter::caller<18, T, D, BrdRowReplicate>,
                row_filter::caller<19, T, D, BrdRowReplicate>,
                row_filter::caller<20, T, D, BrdRowReplicate>,
                row_filter::caller<21, T, D, BrdRowReplicate>,
                row_filter::caller<22, T, D, BrdRowReplicate>,
                row_filter::caller<23, T, D, BrdRowReplicate>,
                row_filter::caller<24, T, D, BrdRowReplicate>,
                row_filter::caller<25, T, D, BrdRowReplicate>,
                row_filter::caller<26, T, D, BrdRowReplicate>,
                row_filter::caller<27, T, D, BrdRowReplicate>,
                row_filter::caller<28, T, D, BrdRowReplicate>,
                row_filter::caller<29, T, D, BrdRowReplicate>,
                row_filter::caller<30, T, D, BrdRowReplicate>,
                row_filter::caller<31, T, D, BrdRowReplicate>,
                row_filter::caller<32, T, D, BrdRowReplicate>
            },
            {
                0,
                row_filter::caller< 1, T, D, BrdRowConstant>,
                row_filter::caller< 2, T, D, BrdRowConstant>,
                row_filter::caller< 3, T, D, BrdRowConstant>,
                row_filter::caller< 4, T, D, BrdRowConstant>,
                row_filter::caller< 5, T, D, BrdRowConstant>,
                row_filter::caller< 6, T, D, BrdRowConstant>,
                row_filter::caller< 7, T, D, BrdRowConstant>,
                row_filter::caller< 8, T, D, BrdRowConstant>,
                row_filter::caller< 9, T, D, BrdRowConstant>,
                row_filter::caller<10, T, D, BrdRowConstant>,
                row_filter::caller<11, T, D, BrdRowConstant>,
                row_filter::caller<12, T, D, BrdRowConstant>,
                row_filter::caller<13, T, D, BrdRowConstant>,
                row_filter::caller<14, T, D, BrdRowConstant>,
                row_filter::caller<15, T, D, BrdRowConstant>,
                row_filter::caller<16, T, D, BrdRowConstant>,
                row_filter::caller<17, T, D, BrdRowConstant>,
                row_filter::caller<18, T, D, BrdRowConstant>,
                row_filter::caller<19, T, D, BrdRowConstant>,
                row_filter::caller<20, T, D, BrdRowConstant>,
                row_filter::caller<21, T, D, BrdRowConstant>,
                row_filter::caller<22, T, D, BrdRowConstant>,
                row_filter::caller<23, T, D, BrdRowConstant>,
                row_filter::caller<24, T, D, BrdRowConstant>,
                row_filter::caller<25, T, D, BrdRowConstant>,
                row_filter::caller<26, T, D, BrdRowConstant>,
                row_filter::caller<27, T, D, BrdRowConstant>,
                row_filter::caller<28, T, D, BrdRowConstant>,
                row_filter::caller<29, T, D, BrdRowConstant>,
                row_filter::caller<30, T, D, BrdRowConstant>,
                row_filter::caller<31, T, D, BrdRowConstant>,
                row_filter::caller<32, T, D, BrdRowConstant>
            },
            {
                0,
                row_filter::caller< 1, T, D, BrdRowReflect>,
                row_filter::caller< 2, T, D, BrdRowReflect>,
                row_filter::caller< 3, T, D, BrdRowReflect>,
                row_filter::caller< 4, T, D, BrdRowReflect>,
                row_filter::caller< 5, T, D, BrdRowReflect>,
                row_filter::caller< 6, T, D, BrdRowReflect>,
                row_filter::caller< 7, T, D, BrdRowReflect>,
                row_filter::caller< 8, T, D, BrdRowReflect>,
                row_filter::caller< 9, T, D, BrdRowReflect>,
                row_filter::caller<10, T, D, BrdRowReflect>,
                row_filter::caller<11, T, D, BrdRowReflect>,
                row_filter::caller<12, T, D, BrdRowReflect>,
                row_filter::caller<13, T, D, BrdRowReflect>,
                row_filter::caller<14, T, D, BrdRowReflect>,
                row_filter::caller<15, T, D, BrdRowReflect>,
                row_filter::caller<16, T, D, BrdRowReflect>,
                row_filter::caller<17, T, D, BrdRowReflect>,
                row_filter::caller<18, T, D, BrdRowReflect>,
                row_filter::caller<19, T, D, BrdRowReflect>,
                row_filter::caller<20, T, D, BrdRowReflect>,
                row_filter::caller<21, T, D, BrdRowReflect>,
                row_filter::caller<22, T, D, BrdRowReflect>,
                row_filter::caller<23, T, D, BrdRowReflect>,
                row_filter::caller<24, T, D, BrdRowReflect>,
                row_filter::caller<25, T, D, BrdRowReflect>,
                row_filter::caller<26, T, D, BrdRowReflect>,
                row_filter::caller<27, T, D, BrdRowReflect>,
                row_filter::caller<28, T, D, BrdRowReflect>,
                row_filter::caller<29, T, D, BrdRowReflect>,
                row_filter::caller<30, T, D, BrdRowReflect>,
                row_filter::caller<31, T, D, BrdRowReflect>,
                row_filter::caller<32, T, D, BrdRowReflect>
            },
            {
                0,
                row_filter::caller< 1, T, D, BrdRowWrap>,
                row_filter::caller< 2, T, D, BrdRowWrap>,
                row_filter::caller< 3, T, D, BrdRowWrap>,
                row_filter::caller< 4, T, D, BrdRowWrap>,
                row_filter::caller< 5, T, D, BrdRowWrap>,
                row_filter::caller< 6, T, D, BrdRowWrap>,
                row_filter::caller< 7, T, D, BrdRowWrap>,
                row_filter::caller< 8, T, D, BrdRowWrap>,
                row_filter::caller< 9, T, D, BrdRowWrap>,
                row_filter::caller<10, T, D, BrdRowWrap>,
                row_filter::caller<11, T, D, BrdRowWrap>,
                row_filter::caller<12, T, D, BrdRowWrap>,
                row_filter::caller<13, T, D, BrdRowWrap>,
                row_filter::caller<14, T, D, BrdRowWrap>,
                row_filter::caller<15, T, D, BrdRowWrap>,
                row_filter::caller<16, T, D, BrdRowWrap>,
                row_filter::caller<17, T, D, BrdRowWrap>,
                row_filter::caller<18, T, D, BrdRowWrap>,
                row_filter::caller<19, T, D, BrdRowWrap>,
                row_filter::caller<20, T, D, BrdRowWrap>,
                row_filter::caller<21, T, D, BrdRowWrap>,
                row_filter::caller<22, T, D, BrdRowWrap>,
                row_filter::caller<23, T, D, BrdRowWrap>,
                row_filter::caller<24, T, D, BrdRowWrap>,
                row_filter::caller<25, T, D, BrdRowWrap>,
                row_filter::caller<26, T, D, BrdRowWrap>,
                row_filter::caller<27, T, D, BrdRowWrap>,
                row_filter::caller<28, T, D, BrdRowWrap>,
                row_filter::caller<29, T, D, BrdRowWrap>,
                row_filter::caller<30, T, D, BrdRowWrap>,
                row_filter::caller<31, T, D, BrdRowWrap>,
                row_filter::caller<32, T, D, BrdRowWrap>
            }
        };

        // Derive the valid ranges from the table itself so they cannot drift apart.
        const int brdCount   = static_cast<int>(sizeof(callers) / sizeof(callers[0]));
        const int ksizeLimit = static_cast<int>(sizeof(callers[0]) / sizeof(callers[0][0]));

        // Validate before touching c_kernel: a ksize beyond the table would also overflow the
        // constant-memory tap buffer, and column 0 of the table is a null pointer.
        if (brd_type < 0 || brd_type >= brdCount || ksize < 1 || ksize >= ksizeLimit || anchor < 0 || anchor >= ksize)
        {
            cudaSafeCall( cudaErrorInvalidValue );
            return;
        }

        // `kernel` lives in device memory, hence the device-to-device symbol copy.
        if (stream == 0)
            cudaSafeCall( cudaMemcpyToSymbol(row_filter::c_kernel, kernel, ksize * sizeof(float), 0, cudaMemcpyDeviceToDevice) );
        else
            cudaSafeCall( cudaMemcpyToSymbolAsync(row_filter::c_kernel, kernel, ksize * sizeof(float), 0, cudaMemcpyDeviceToDevice, stream) );

        callers[brd_type][ksize]((PtrStepSz<T>)src, (PtrStepSz<D>)dst, anchor, cc, stream);
    }
}