1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
5 #ifndef OPENCV_DNN_SRC_CUDA_KERNEL_DISPATCHER_HPP
6 #define OPENCV_DNN_SRC_CUDA_KERNEL_DISPATCHER_HPP
/* The performance of many kernels is highly dependent on the tensor rank. Instead of having
12 * one kernel which can work with the maximally ranked tensors, we make one kernel for each supported
13 * tensor rank. This is to ensure that the requirements of the maximally ranked tensors do not take a
14 * toll on the performance of the operation for low ranked tensors. Hence, many kernels take the tensor
15 * rank as a template parameter.
17 * The kernel is a template and we have different instantiations for each rank. This causes the following pattern
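 * to arise frequently:
 *
 *     if (rank == 3)
 *         kernel<T, 3>();
 *     else if (rank == 2)
 *         kernel<T, 2>();
 *     else
 *         kernel<T, 1>();
 *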
27 * The rank is a runtime variable. To facilitate creation of such structures, we use GENERATE_KERNEL_DISPATCHER.
28 * This macro creates a function which selects the correct kernel instantiation at runtime.
 *
 * Example:
 *
 * // function which sets up the kernel and launches it
33 * template <class T, std::size_t Rank>
34 * void launch_some_kernel(...);
 *
 * // creates the dispatcher named "some_dispatcher" which invokes the correct instantiation of "launch_some_kernel"
37 * GENERATE_KERNEL_DISPATCHER(some_dispatcher, launch_some_kernel);
 *
 * // internal API function
 * template <class T>
 * void some(...) {
 *     // ...
 *     auto rank = input.rank();
 *     some_dispatcher<T, MIN_RANK, MAX_RANK>(rank, ...);
 * }
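 *
 * Macro parameters:
 *  name      name of the dispatcher function that is generated
 *  func      template function that requires runtime selection
 *
 * Template parameters of the generated dispatcher:
 *  T         first template parameter to `func`
 *  start     starting rank (inclusive)
 *  end       ending rank (inclusive)
 *
 * Function parameters of the generated dispatcher:
 *  selector  runtime value (e.g. the tensor rank) that picks the instantiation
 *  args...   perfectly forwarded to the selected instantiation of `func`
 *
 * Executes func<T, selector> based on the runtime `selector` argument, provided that `selector`
 * lies within the range [start, end]. If it lies outside the range, no instantiation of `func` is executed.
 */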
// GENERATE_KERNEL_DISPATCHER(name, func) defines a static function template
//     name<T, start, end>(int selector, args...)
// that calls func<T, selector>(args...) when `selector` lies in [start, end], and
// does nothing otherwise. The candidates are tried in increasing order; the chain
// is unrolled at compile time through overloads selected with std::enable_if.
//
// Notes:
//  - The expansion uses std::size_t, std::enable_if and std::forward, so
//    <cstddef>, <type_traits> and <utility> must be visible where it is expanded.
//  - The replacement list intentionally starts right after the parameter list
//    (no leading ';'), so the expansion does not inject an empty declaration.
//  - Negative selectors never match any candidate (they are rejected before the
//    unsigned comparison instead of being wrapped around).
//  - A range with start > end is rejected by a static_assert rather than
//    recursing on template instantiations until the depth limit is hit.
#define GENERATE_KERNEL_DISPATCHER(name, func)                                              \
    template <class T, std::size_t start, std::size_t end, class... Args> static           \
    typename std::enable_if<start == end, void>                                             \
    ::type name(int selector, Args&& ...args) {                                             \
        if (selector >= 0 && static_cast<std::size_t>(selector) == start)                   \
            func<T, start>(std::forward<Args>(args)...);                                    \
    }                                                                                       \
                                                                                            \
    template <class T, std::size_t start, std::size_t end, class... Args> static           \
    typename std::enable_if<(start < end), void>                                            \
    ::type name(int selector, Args&& ...args) {                                             \
        if (selector >= 0 && static_cast<std::size_t>(selector) == start)                   \
            func<T, start>(std::forward<Args>(args)...);                                    \
        else                                                                                \
            name<T, start + 1, end, Args...>(selector, std::forward<Args>(args)...);        \
    }                                                                                       \
                                                                                            \
    template <class T, std::size_t start, std::size_t end, class... Args> static           \
    typename std::enable_if<(start > end), void>                                            \
    ::type name(int, Args&& ...) {                                                          \
        static_assert(start <= end, #name ": the dispatch range requires start <= end");    \
    }
76 #endif /* OPENCV_DNN_SRC_CUDA_KERNEL_DISPATCHER_HPP */