Merge pull request #16658 from YashasSamaga:cuda4dnn-refactor-activations
authorYashas Samaga B L <yashas_2010@yahoo.com>
Sat, 29 Feb 2020 08:46:14 +0000 (14:16 +0530)
committerGitHub <noreply@github.com>
Sat, 29 Feb 2020 08:46:14 +0000 (11:46 +0300)
cuda4dnn(activations, eltwise, scale_shift): refactor to reduce code duplication

* refactor activations

* refactor eltwise kernels

* move all functors to functors.hpp

* remove bias1 and scale1 kernels

modules/dnn/src/cuda/activations.cu
modules/dnn/src/cuda/bias_activation.cu
modules/dnn/src/cuda/eltwise_ops.cu
modules/dnn/src/cuda/execution.hpp
modules/dnn/src/cuda/functors.hpp [new file with mode: 0644]
modules/dnn/src/cuda/scale_shift.cu
modules/dnn/src/cuda4dnn/kernels/bias_activation.hpp
modules/dnn/src/cuda4dnn/kernels/scale_shift.hpp
modules/dnn/src/cuda4dnn/primitives/convolution.hpp
modules/dnn/src/cuda4dnn/primitives/normalize_bbox.hpp

index 143361c..221516d 100644 (file)
@@ -5,7 +5,7 @@
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
 
-#include "math.hpp"
+#include "functors.hpp"
 #include "types.hpp"
 #include "vector_traits.hpp"
 #include "grid_stride_range.hpp"
@@ -25,519 +25,178 @@ using namespace cv::dnn::cuda4dnn::csl::device;
 
 namespace cv { namespace dnn { namespace cuda4dnn  { namespace kernels {
 
-    namespace raw {
-        template <class T, std::size_t N>
-        __global__ void abs_vec(Span<T> output, View<T> input) {
-            using vector_type = get_vector_type_t<T, N>;
-
-            auto output_vPtr = vector_type::get_pointer(output.data());
-            auto input_vPtr = vector_type::get_pointer(input.data());
-
-            for (auto i : grid_stride_range(output.size() / vector_type::size())) {
-                vector_type vec;
-                v_load(vec, input_vPtr[i]);
-                for (int j = 0; j < vector_type::size(); j++) {
-                    using device::abs;
-                    vec.data[j] = abs(vec.data[j]);
-                }
-                v_store(output_vPtr[i], vec);
-            }
-        }
-
-        template <class T, std::size_t N>
-        __global__ void tanh_vec(Span<T> output, View<T> input) {
-            using vector_type = get_vector_type_t<T, N>;
-
-            auto output_vPtr = vector_type::get_pointer(output.data());
-            auto input_vPtr = vector_type::get_pointer(input.data());
-
-            for (auto i : grid_stride_range(output.size() / vector_type::size())) {
-                vector_type vec;
-                v_load(vec, input_vPtr[i]);
-                for (int j = 0; j < vector_type::size(); j++) {
-                    using device::tanh;
-                    vec.data[j] = tanh(vec.data[j]);
-                }
-                v_store(output_vPtr[i], vec);
-            }
-        }
-
-        template <class T, std::size_t N>
-        __global__ void swish_vec(Span<T> output, View<T> input) {
-            using vector_type = get_vector_type_t<T, N>;
-
-            auto output_vPtr = vector_type::get_pointer(output.data());
-            auto input_vPtr = vector_type::get_pointer(input.data());
-
-            for (auto i : grid_stride_range(output.size() / vector_type::size())) {
-                vector_type vec;
-                v_load(vec, input_vPtr[i]);
-                for (int j = 0; j < vector_type::size(); j++) {
-                    using device::sigmoid;
-                    vec.data[j] = vec.data[j] * sigmoid(vec.data[j]);
-                }
-                v_store(output_vPtr[i], vec);
-            }
-        }
-
-        template <class T, std::size_t N>
-        __global__ void mish_vec(Span<T> output, View<T> input) {
-            using vector_type = get_vector_type_t<T, N>;
-
-            auto output_vPtr = vector_type::get_pointer(output.data());
-            auto input_vPtr = vector_type::get_pointer(input.data());
-
-            for (auto i : grid_stride_range(output.size() / vector_type::size())) {
-                vector_type vec;
-                v_load(vec, input_vPtr[i]);
-                for (int j = 0; j < vector_type::size(); j++) {
-                    using device::tanh;
-                    using device::log1pexp;
-                    vec.data[j] = vec.data[j] * tanh(log1pexp(vec.data[j]));
-                }
-                v_store(output_vPtr[i], vec);
-            }
-        }
-
-        template <class T, std::size_t N>
-        __global__ void sigmoid_vec(Span<T> output, View<T> input) {
-            using vector_type = get_vector_type_t<T, N>;
-
-            auto output_vPtr = vector_type::get_pointer(output.data());
-            auto input_vPtr = vector_type::get_pointer(input.data());
-
-            for (auto i : grid_stride_range(output.size() / vector_type::size())) {
-                vector_type vec;
-                v_load(vec, input_vPtr[i]);
-                for (int j = 0; j < vector_type::size(); j++) {
-                    using device::sigmoid;
-                    vec.data[j] = sigmoid(vec.data[j]);
-                }
-                v_store(output_vPtr[i], vec);
-            }
-        }
-
-        template <class T, std::size_t N>
-        __global__ void bnll_vec(Span<T> output, View<T> input) {
-            using vector_type = get_vector_type_t<T, N>;
-
-            auto output_vPtr = vector_type::get_pointer(output.data());
-            auto input_vPtr = vector_type::get_pointer(input.data());
-
-            for (auto i : grid_stride_range(output.size() / vector_type::size())) {
-                vector_type vec;
-                v_load(vec, input_vPtr[i]);
-                for (int j = 0; j < vector_type::size(); j++) {
-                    using device::log1pexp;
-                    vec.data[j] = vec.data[j] > T(0) ? vec.data[j] + log1pexp(-vec.data[j]) : log1pexp(vec.data[j]);
-                }
-                v_store(output_vPtr[i], vec);
-            }
-        }
-
-        template <class T, std::size_t N>
-        __global__ void elu_vec(Span<T> output, View<T> input) {
-            using vector_type = get_vector_type_t<T, N>;
-
-            auto output_vPtr = vector_type::get_pointer(output.data());
-            auto input_vPtr = vector_type::get_pointer(input.data());
-
-            for (auto i : grid_stride_range(output.size() / vector_type::size())) {
-                vector_type vec;
-                v_load(vec, input_vPtr[i]);
-                for (int j = 0; j < vector_type::size(); j++) {
-                    using device::expm1;
-                    vec.data[j] = vec.data[j] >= T(0) ? vec.data[j] : expm1(vec.data[j]);
-                }
-                v_store(output_vPtr[i], vec);
-            }
-        }
-
-        template <class T, std::size_t N>
-        __global__ void relu_vec(Span<T> output, View<T> input, T slope) {
-            using vector_type = get_vector_type_t<T, N>;
-
-            auto output_vPtr = vector_type::get_pointer(output.data());
-            auto input_vPtr = vector_type::get_pointer(input.data());
-
-            for (auto i : grid_stride_range(output.size() / vector_type::size())) {
-                vector_type vec;
-                v_load(vec, input_vPtr[i]);
-                for(int j = 0; j < vector_type::size(); j++)
-                    vec.data[j] = vec.data[j] >= T(0) ? vec.data[j] : slope * vec.data[j];
-                v_store(output_vPtr[i], vec);
-            }
-        }
-
-        template <class T, std::size_t N>
-        __global__ void clipped_relu_vec(Span<T> output, View<T> input, T floor, T ceiling) {
-            using vector_type = get_vector_type_t<T, N>;
-
-            auto output_vPtr = vector_type::get_pointer(output.data());
-            auto input_vPtr = vector_type::get_pointer(input.data());
-
-            for (auto i : grid_stride_range(output.size() / vector_type::size())) {
-                using device::clamp;
-
-                vector_type vec;
-                v_load(vec, input_vPtr[i]);
-                for (int j = 0; j < vector_type::size(); j++)
-                    vec.data[j] = clamp(vec.data[j], floor, ceiling);
-                v_store(output_vPtr[i], vec);
-            }
-        }
-
-        template <class T, std::size_t N>
-        __global__ void axiswise_relu_vec(Span<T> output, View<T> input, size_type inner_size, View<T> slope) {
-            using vector_type = get_vector_type_t<T, N>;
-
-            auto output_vPtr = vector_type::get_pointer(output.data());
-            auto input_vPtr = vector_type::get_pointer(input.data());
-
-            inner_size /= vector_type::size();
-            for (auto i : grid_stride_range(output.size() / vector_type::size())) {
-                const index_type c = (i / inner_size) % static_cast<size_type>(slope.size());
-
-                vector_type vec;
-                v_load(vec, input_vPtr[i]);
-                for (int j = 0; j < vector_type::size(); j++)
-                    vec.data[j] = vec.data[j] > T(0) ? vec.data[j] : vec.data[j] * slope[c];
-                v_store(output_vPtr[i], vec);
-            }
-        }
-
-        template <class T, std::size_t N>
-        __global__ void power_vec(Span<T> output, View<T> input, T exp, T scale, T shift) {
-            using vector_type = get_vector_type_t<T, N>;
-
-            auto output_vPtr = vector_type::get_pointer(output.data());
-            auto input_vPtr = vector_type::get_pointer(input.data());
-
-            for (auto i : grid_stride_range(output.size() / vector_type::size())) {
-                using device::pow;
-
-                vector_type vec;
-                v_load(vec, input_vPtr[i]);
-                for (int j = 0; j < vector_type::size(); j++)
-                    vec.data[j] = pow(shift + scale * vec.data[j], exp);
-                v_store(output_vPtr[i], vec);
-            }
-        }
-    }
-
-    template <class T, std::size_t N>
-    void launch_vectorized_abs(const Stream& stream, Span<T> output, View<T> input) {
-        CV_Assert(is_fully_aligned<T>(output, N));
-        CV_Assert(is_fully_aligned<T>(input, N));
-
-        auto kernel = raw::abs_vec<T, N>;
-        auto policy = make_policy(kernel, output.size() / N, 0, stream);
-        launch_kernel(kernel, policy, output, input);
-    }
-
-    template <class T>
-    void abs(const Stream& stream, Span<T> output, View<T> input) {
-        CV_Assert(input.size() == output.size());
-
-        if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
-            launch_vectorized_abs<T, 4>(stream, output, input);
-        } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
-            launch_vectorized_abs<T, 2>(stream, output, input);
-        } else {
-            launch_vectorized_abs<T, 1>(stream, output, input);
-        }
-    }
-
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-    template void abs<__half>(const Stream& stream, Span<__half> output, View<__half> input);
-#endif
-    template void abs<float>(const Stream& stream, Span<float> output, View<float> input);
-
-    template <class T, std::size_t N>
-    void launch_vectorized_tanh(const Stream& stream, Span<T> output, View<T> input) {
-        CV_Assert(is_fully_aligned<T>(output, N));
-        CV_Assert(is_fully_aligned<T>(input, N));
-
-        auto kernel = raw::tanh_vec<T, N>;
-        auto policy = make_policy(kernel, output.size() / N, 0, stream);
-        launch_kernel(kernel, policy, output, input);
-    }
-
-    template <class T>
-    void tanh(const Stream& stream, Span<T> output, View<T> input) {
-        CV_Assert(input.size() == output.size());
-
-        if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
-            launch_vectorized_tanh<T, 4>(stream, output, input);
-        } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
-            launch_vectorized_tanh<T, 2>(stream, output, input);
-        } else {
-            launch_vectorized_tanh<T, 1>(stream, output, input);
-        }
-    }
-
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-    template void tanh<__half>(const Stream&, Span<__half>, View<__half>);
-#endif
-    template void tanh<float>(const Stream&, Span<float>, View<float>);
-
-    template <class T, std::size_t N>
-    void launch_vectorized_swish(const Stream& stream, Span<T> output, View<T> input) {
-        CV_Assert(is_fully_aligned<T>(output, N));
-        CV_Assert(is_fully_aligned<T>(input, N));
-
-        auto kernel = raw::swish_vec<T, N>;
-        auto policy = make_policy(kernel, output.size() / N, 0, stream);
-        launch_kernel(kernel, policy, output, input);
-    }
-
-    template <class T>
-    void swish(const Stream& stream, Span<T> output, View<T> input) {
-        CV_Assert(input.size() == output.size());
-
-        if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
-            launch_vectorized_swish<T, 4>(stream, output, input);
-        } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
-            launch_vectorized_swish<T, 2>(stream, output, input);
-        } else {
-            launch_vectorized_swish<T, 1>(stream, output, input);
-        }
-    }
+namespace raw {
+    template <class T, class Functor, std::size_t N, class ...FunctorArgs>
+    __global__ void generic_op_vec(Span<T> output, View<T> input, FunctorArgs ...functorArgs) {
+        using vector_type = get_vector_type_t<T, N>;
 
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-    template void swish<__half>(const Stream&, Span<__half>, View<__half>);
-#endif
-    template void swish<float>(const Stream&, Span<float>, View<float>);
+        auto output_vPtr = vector_type::get_pointer(output.data());
+        auto input_vPtr = vector_type::get_pointer(input.data());
 
-    template <class T, std::size_t N>
-    void launch_vectorized_mish(const Stream& stream, Span<T> output, View<T> input) {
-        CV_Assert(is_fully_aligned<T>(output, N));
-        CV_Assert(is_fully_aligned<T>(input, N));
+        Functor functor(functorArgs...);
 
-        auto kernel = raw::mish_vec<T, N>;
-        auto policy = make_policy(kernel, output.size() / N, 0, stream);
-        launch_kernel(kernel, policy, output, input);
-    }
-
-    template <class T>
-    void mish(const Stream& stream, Span<T> output, View<T> input) {
-        CV_Assert(input.size() == output.size());
-
-        if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
-            launch_vectorized_mish<T, 4>(stream, output, input);
-        } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
-            launch_vectorized_mish<T, 2>(stream, output, input);
-        } else {
-            launch_vectorized_mish<T, 1>(stream, output, input);
+        for (auto i : grid_stride_range(output.size() / vector_type::size())) {
+            vector_type vec;
+            v_load(vec, input_vPtr[i]);
+            for (int j = 0; j < vector_type::size(); j++)
+                vec.data[j] = functor(vec.data[j]);
+            v_store(output_vPtr[i], vec);
         }
     }
 
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-    template void mish<__half>(const Stream&, Span<__half>, View<__half>);
-#endif
-    template void mish<float>(const Stream&, Span<float>, View<float>);
-
     template <class T, std::size_t N>
-    void launch_vectorized_sigmoid(const Stream& stream, Span<T> output, View<T> input) {
-        CV_Assert(is_fully_aligned<T>(output, N));
-        CV_Assert(is_fully_aligned<T>(input, N));
+    __global__ void axiswise_relu_vec(Span<T> output, View<T> input, size_type inner_size, View<T> slope) {
+        using vector_type = get_vector_type_t<T, N>;
 
-        auto kernel = raw::sigmoid_vec<T, N>;
-        auto policy = make_policy(kernel, output.size() / N, 0, stream);
-        launch_kernel(kernel, policy, output, input);
-    }
+        auto output_vPtr = vector_type::get_pointer(output.data());
+        auto input_vPtr = vector_type::get_pointer(input.data());
 
-    template <class T>
-    void sigmoid(const Stream& stream, Span<T> output, View<T> input) {
-        CV_Assert(input.size() == output.size());
+        inner_size /= vector_type::size();
+        for (auto i : grid_stride_range(output.size() / vector_type::size())) {
+            const index_type c = (i / inner_size) % static_cast<size_type>(slope.size());
 
-        if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
-            launch_vectorized_sigmoid<T, 4>(stream, output, input);
-        } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
-            launch_vectorized_sigmoid<T, 2>(stream, output, input);
-        } else {
-            launch_vectorized_sigmoid<T, 1>(stream, output, input);
+            vector_type vec;
+            v_load(vec, input_vPtr[i]);
+            for (int j = 0; j < vector_type::size(); j++)
+                vec.data[j] = vec.data[j] > T(0) ? vec.data[j] : vec.data[j] * slope[c];
+            v_store(output_vPtr[i], vec);
         }
     }
 
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-    template void sigmoid<__half>(const Stream&, Span<__half>, View<__half>);
-#endif
-    template void sigmoid<float>(const Stream&, Span<float>, View<float>);
+} /* namespace raw */
 
-    template <class T, std::size_t N>
-    void launch_vectorized_bnll(const Stream& stream, Span<T> output, View<T> input) {
-        CV_Assert(is_fully_aligned<T>(output, N));
-        CV_Assert(is_fully_aligned<T>(input, N));
+template <class T, template <class> class Activation, std::size_t N, class ...ActivationArgs> static
+void launch_vectorized_generic_op(const Stream& stream, Span<T> output, View<T> input, ActivationArgs ...activationArgs) {
+    CV_Assert(is_fully_aligned<T>(output, N));
+    CV_Assert(is_fully_aligned<T>(input, N));
 
-        auto kernel = raw::bnll_vec<T, N>;
-        auto policy = make_policy(kernel, output.size() / N, 0, stream);
-        launch_kernel(kernel, policy, output, input);
-    }
+    auto kernel = raw::generic_op_vec<T, Activation<T>, N, ActivationArgs...>;
+    auto policy = make_policy(kernel, output.size() / N, 0, stream);
+    launch_kernel(kernel, policy, output, input, activationArgs...);
+}
 
-    template <class T>
-    void bnll(const Stream& stream, Span<T> output, View<T> input) {
-        CV_Assert(input.size() == output.size());
+template <class T, template <class> class Activation, class ...ActivationArgs> static
+void generic_op(const Stream& stream, Span<T> output, View<T> input, ActivationArgs ...activationArgs) {
+    CV_Assert(input.size() == output.size());
 
-        if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
-            launch_vectorized_bnll<T, 4>(stream, output, input);
-        } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
-            launch_vectorized_bnll<T, 2>(stream, output, input);
-        } else {
-            launch_vectorized_bnll<T, 1>(stream, output, input);
-        }
+    if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
+        launch_vectorized_generic_op<T, Activation, 4>(stream, output, input, activationArgs...);
+    } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
+        launch_vectorized_generic_op<T, Activation, 2>(stream, output, input, activationArgs...);
+    } else {
+        launch_vectorized_generic_op<T, Activation, 1>(stream, output, input, activationArgs...);
     }
+}
 
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-    template void bnll<__half>(const Stream&, Span<__half>, View<__half>);
-#endif
-    template void bnll<float>(const Stream&, Span<float>, View<float>);
-
-    template <class T, std::size_t N>
-    void launch_vectorized_elu(const Stream& stream, Span<T> output, View<T> input) {
-        CV_Assert(is_fully_aligned<T>(output, N));
-        CV_Assert(is_fully_aligned<T>(input, N));
+template <class T>
+void abs(const Stream& stream, Span<T> output, View<T> input) {
+    generic_op<T, abs_functor>(stream, output, input);
+}
 
-        auto kernel = raw::elu_vec<T, N>;
-        auto policy = make_policy(kernel, output.size() / N, 0, stream);
-        launch_kernel(kernel, policy, output, input);
-    }
+template <class T>
+void tanh(const Stream& stream, Span<T> output, View<T> input) {
+    generic_op<T, tanh_functor>(stream, output, input);
+}
 
-    template <class T>
-    void elu(const Stream& stream, Span<T> output, View<T> input) {
-        CV_Assert(input.size() == output.size());
+template <class T>
+void swish(const Stream& stream, Span<T> output, View<T> input) {
+    generic_op<T, swish_functor>(stream, output, input);
+}
 
-        if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
-            launch_vectorized_elu<T, 4>(stream, output, input);
-        } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
-            launch_vectorized_elu<T, 2>(stream, output, input);
-        } else {
-            launch_vectorized_elu<T, 1>(stream, output, input);
-        }
-    }
-
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-    template void elu<__half>(const Stream&, Span<__half>, View<__half>);
-#endif
-    template void elu<float>(const Stream&, Span<float>, View<float>);
+template <class T>
+void mish(const Stream& stream, Span<T> output, View<T> input) {
+    generic_op<T, mish_functor>(stream, output, input);
+}
 
-    template <class T, std::size_t N>
-    void launch_vectorized_relu(const Stream& stream, Span<T> output, View<T> input, T slope) {
-        CV_Assert(is_fully_aligned<T>(output, N));
-        CV_Assert(is_fully_aligned<T>(input, N));
+template <class T>
+void sigmoid(const Stream& stream, Span<T> output, View<T> input) {
+    generic_op<T, sigmoid_functor>(stream, output, input);
+}
 
-        auto kernel = raw::relu_vec<T, N>;
-        auto policy = make_policy(kernel, output.size() / N, 0, stream);
-        launch_kernel(kernel, policy, output, input, slope);
-    }
+template <class T>
+void bnll(const Stream& stream, Span<T> output, View<T> input) {
+    generic_op<T, bnll_functor>(stream, output, input);
+}
 
-    template <class T>
-    void relu(const Stream& stream, Span<T> output, View<T> input, T slope) {
-        CV_Assert(input.size() == output.size());
+template <class T>
+void elu(const Stream& stream, Span<T> output, View<T> input) {
+    generic_op<T, elu_functor>(stream, output, input);
+}
 
-        if(is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
-            launch_vectorized_relu<T, 4>(stream, output, input, slope);
-        } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
-            launch_vectorized_relu<T, 2>(stream, output, input, slope);
-        } else {
-            launch_vectorized_relu<T, 1>(stream, output, input, slope);
-        }
-    }
+template <class T>
+void relu(const Stream& stream, Span<T> output, View<T> input, T slope) {
+    generic_op<T, relu_functor>(stream, output, input, slope);
+}
 
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-    template void relu<__half>(const Stream&, Span<__half>, View<__half>, __half);
-#endif
-    template void relu<float>(const Stream&, Span<float>, View<float>, float);
+template <class T>
+void clipped_relu(const Stream& stream, Span<T> output, View<T> input, T floor, T ceiling) {
+    CV_Assert(static_cast<double>(floor) <= static_cast<double>(ceiling));
+    generic_op<T, clipped_relu_functor>(stream, output, input, floor, ceiling);
+}
 
-    template <class T, std::size_t N>
-    void launch_vectorized_clipped_relu(const Stream& stream, Span<T> output, View<T> input, T floor, T ceiling) {
-        CV_Assert(is_fully_aligned<T>(output, N));
-        CV_Assert(is_fully_aligned<T>(input, N));
+template <class T>
+void power(const Stream& stream, Span<T> output, View<T> input, T exp, T scale, T shift) {
+    CV_Assert(input.size() == output.size());
 
-        auto kernel = raw::clipped_relu_vec<T, N>;
-        auto policy = make_policy(kernel, output.size() / N, 0, stream);
-        launch_kernel(kernel, policy, output, input, floor, ceiling);
+    if (static_cast<float>(exp) == 1.0f) {
+        scale1_with_bias1(stream, output, input, scale, shift);
+        return;
     }
 
-    template <class T>
-    void clipped_relu(const Stream& stream, Span<T> output, View<T> input, T floor, T ceiling) {
-        CV_Assert(input.size() == output.size());
-        CV_Assert(static_cast<double>(floor) <= static_cast<double>(ceiling));
-
-        if(is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
-            launch_vectorized_clipped_relu<T, 4>(stream, output, input, floor, ceiling);
-        } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
-            launch_vectorized_clipped_relu<T, 2>(stream, output, input, floor, ceiling);
-        } else {
-            launch_vectorized_clipped_relu<T, 1>(stream, output, input, floor, ceiling);
-        }
-    }
+    generic_op<T, power_functor>(stream, output, input, exp, scale, shift);
+}
 
 #if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-    template void clipped_relu<__half>(const Stream&, Span<__half>, View<__half>, __half, __half);
+template void abs<__half>(const Stream& stream, Span<__half> output, View<__half> input);
+template void tanh<__half>(const Stream&, Span<__half>, View<__half>);
+template void swish<__half>(const Stream&, Span<__half>, View<__half>);
+template void mish<__half>(const Stream&, Span<__half>, View<__half>);
+template void sigmoid<__half>(const Stream&, Span<__half>, View<__half>);
+template void bnll<__half>(const Stream&, Span<__half>, View<__half>);
+template void elu<__half>(const Stream&, Span<__half>, View<__half>);
+template void relu<__half>(const Stream&, Span<__half>, View<__half>, __half);
+template void clipped_relu<__half>(const Stream&, Span<__half>, View<__half>, __half, __half);
+template void power<__half>(const Stream&, Span<__half>, View<__half>, __half, __half, __half);
 #endif
-    template void clipped_relu<float>(const Stream&, Span<float>, View<float>, float, float);
 
-    template <class T, std::size_t N>
-    void launch_vectorized_axiswise_relu(const Stream& stream, Span<T> output, View<T> input, std::size_t inner_size, View<T> slope) {
-        CV_Assert(is_fully_aligned<T>(output, N));
-        CV_Assert(is_fully_aligned<T>(input, N));
-        CV_Assert(inner_size % N == 0);
-
-        auto kernel = raw::axiswise_relu_vec<T, N>;
-        auto policy = make_policy(kernel, output.size() / N, 0, stream);
-        launch_kernel(kernel, policy, output, input, inner_size, slope);
-    }
-
-    template <class T>
-    void axiswise_relu(const Stream& stream, Span<T> output, View<T> input, std::size_t inner_size, View<T> slope) {
-        CV_Assert(input.size() == output.size());
-
-        if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4) && inner_size % 4 == 0) {
-            launch_vectorized_axiswise_relu<T, 4>(stream, output, input, inner_size, slope);
-        } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2) && inner_size % 2 == 0) {
-            launch_vectorized_axiswise_relu<T, 2>(stream, output, input, inner_size, slope);
-        } else {
-            launch_vectorized_axiswise_relu<T, 1>(stream, output, input, inner_size, slope);
-        }
-    }
+template void abs<float>(const Stream& stream, Span<float> output, View<float> input);
+template void tanh<float>(const Stream&, Span<float>, View<float>);
+template void swish<float>(const Stream&, Span<float>, View<float>);
+template void mish<float>(const Stream&, Span<float>, View<float>);
+template void sigmoid<float>(const Stream&, Span<float>, View<float>);
+template void bnll<float>(const Stream&, Span<float>, View<float>);
+template void elu<float>(const Stream&, Span<float>, View<float>);
+template void relu<float>(const Stream&, Span<float>, View<float>, float);
+template void clipped_relu<float>(const Stream&, Span<float>, View<float>, float, float);
+template void power<float>(const Stream&, Span<float>, View<float>, float, float, float);
+
+template <class T, std::size_t N> static
+void launch_vectorized_axiswise_relu(const Stream& stream, Span<T> output, View<T> input, std::size_t inner_size, View<T> slope) {
+    CV_Assert(is_fully_aligned<T>(output, N));
+    CV_Assert(is_fully_aligned<T>(input, N));
+    CV_Assert(inner_size % N == 0);
+
+    auto kernel = raw::axiswise_relu_vec<T, N>;
+    auto policy = make_policy(kernel, output.size() / N, 0, stream);
+    launch_kernel(kernel, policy, output, input, inner_size, slope);
+}
+
+template <class T>
+void axiswise_relu(const Stream& stream, Span<T> output, View<T> input, std::size_t inner_size, View<T> slope) {
+    CV_Assert(input.size() == output.size());
+
+    if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4) && inner_size % 4 == 0) {
+        launch_vectorized_axiswise_relu<T, 4>(stream, output, input, inner_size, slope);
+    } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2) && inner_size % 2 == 0) {
+        launch_vectorized_axiswise_relu<T, 2>(stream, output, input, inner_size, slope);
+    } else {
+        launch_vectorized_axiswise_relu<T, 1>(stream, output, input, inner_size, slope);
+    }
+}
 
 #if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
     template void axiswise_relu<__half>(const Stream&, Span<__half>, View<__half>, std::size_t, View<__half>);
 #endif
     template void axiswise_relu<float>(const Stream&, Span<float>, View<float>, std::size_t, View<float>);
 
-    template <class T, std::size_t N>
-    void launch_vectorized_power(const Stream& stream, Span<T> output, View<T> input, T exp, T scale, T shift) {
-        CV_Assert(is_fully_aligned<T>(output, N));
-        CV_Assert(is_fully_aligned<T>(input, N));
-
-        auto kernel = raw::power_vec<T, N>;
-        auto policy = make_policy(kernel, output.size() / N, 0, stream);
-        launch_kernel(kernel, policy, output, input, exp, scale, shift);
-    }
-
-    template <class T>
-    void power(const Stream& stream, Span<T> output, View<T> input, T exp, T scale, T shift) {
-        CV_Assert(input.size() == output.size());
-
-        if (static_cast<float>(exp) == 1.0f) {
-            scale1_with_bias1(stream, output, input, scale, shift);
-            return;
-        }
-
-        if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4) && output.size()) {
-            launch_vectorized_power<T, 4>(stream, output, input, exp, scale, shift);
-        } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2) && output.size()) {
-            launch_vectorized_power<T, 2>(stream, output, input, exp, scale, shift);
-        } else {
-            launch_vectorized_power<T, 1>(stream, output, input, exp, scale, shift);
-        }
-    }
-
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-    template void power<__half>(const Stream&, Span<__half>, View<__half>, __half, __half, __half);
-#endif
-    template void power<float>(const Stream&, Span<float>, View<float>, float, float, float);
-
 }}}} /* namespace cv::dnn::cuda4dnn::kernels */
index 6a5229c..0acc2ff 100644 (file)
@@ -5,8 +5,8 @@
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
 
+#include "functors.hpp"
 #include "types.hpp"
-#include "math.hpp"
 #include "vector_traits.hpp"
 #include "grid_stride_range.hpp"
 #include "execution.hpp"
@@ -20,32 +20,13 @@ using namespace cv::dnn::cuda4dnn::csl::device;
 namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
 
 namespace raw {
-
-    template <class T, std::size_t N>
-    __global__ void biasN_relu_inplace_vec(Span<T> inplace_output, size_type inner_size, View<T> bias, T slope) {
+    template <class T, class Functor, std::size_t N, class ...FunctorArgs>
+    __global__ void biasN_generic_op_inplace_vec(Span<T> inplace_output, size_type inner_size, View<T> bias, FunctorArgs ...functorArgs) {
         using vector_type = get_vector_type_t<T, N>;
 
         auto inplace_output_vPtr = vector_type::get_pointer(inplace_output.data());
 
-        inner_size /= vector_type::size();
-        for (auto i : grid_stride_range(inplace_output.size() / vector_type::size())) {
-            const index_type bias_idx = (i / inner_size) % static_cast<size_type>(bias.size());
-
-            vector_type vec;
-            v_load(vec, inplace_output_vPtr[i]);
-            for(int j = 0; j < vec.size(); j++) {
-                vec.data[j] += bias[bias_idx];
-                vec.data[j] = vec.data[j] >= T(0) ? vec.data[j] : slope * vec.data[j];
-            }
-            v_store(inplace_output_vPtr[i], vec);
-        }
-    }
-
-    template <class T, std::size_t N>
-    __global__ void biasN_clipped_relu_inplace_vec(Span<T> inplace_output, size_type inner_size, View<T> bias, T floor, T ceil) {
-        using vector_type = get_vector_type_t<T, N>;
-
-        auto inplace_output_vPtr = vector_type::get_pointer(inplace_output.data());
+        Functor functor(functorArgs...);
 
         inner_size /= vector_type::size();
         for (auto i : grid_stride_range(inplace_output.size() / vector_type::size())) {
@@ -53,298 +34,89 @@ namespace raw {
 
             vector_type vec;
             v_load(vec, inplace_output_vPtr[i]);
-            for(int j = 0; j < vec.size(); j++) {
-                using device::clamp;
-                vec.data[j] = clamp(vec.data[j] + bias[bias_idx], floor, ceil);
-            }
+            for(int j = 0; j < vec.size(); j++)
+                vec.data[j] = functor(vec.data[j] + bias[bias_idx]);
             v_store(inplace_output_vPtr[i], vec);
         }
     }
 
-    template <class T, std::size_t N>
-    __global__ void biasN_power_inplace_vec(Span<T> inplace_output, size_type inner_size, View<T> bias, T power) {
-        using vector_type = get_vector_type_t<T, N>;
-
-        auto inplace_output_vPtr = vector_type::get_pointer(inplace_output.data());
-
-        inner_size /= vector_type::size();
-        for (auto i : grid_stride_range(inplace_output.size() / vector_type::size())) {
-            const index_type bias_idx = (i / inner_size) % static_cast<size_type>(bias.size());
+} /* namespace raw */
 
-            vector_type vec;
-            v_load(vec, inplace_output_vPtr[i]);
-            for(int j = 0; j < vec.size(); j++) {
-                using device::pow;
-                vec.data[j] = pow(vec.data[j] + bias[bias_idx], power);
-            }
-            v_store(inplace_output_vPtr[i], vec);
-        }
-    }
-
-    template <class T, std::size_t N>
-    __global__ void biasN_tanh_inplace_vec(Span<T> inplace_output, size_type inner_size, View<T> bias) {
-        using vector_type = get_vector_type_t<T, N>;
-
-        auto inplace_output_vPtr = vector_type::get_pointer(inplace_output.data());
-
-        inner_size /= vector_type::size();
-        for (auto i : grid_stride_range(inplace_output.size() / vector_type::size())) {
-            const index_type bias_idx = (i / inner_size) % static_cast<size_type>(bias.size());
-
-            vector_type vec;
-            v_load(vec, inplace_output_vPtr[i]);
-            for(int j = 0; j < vec.size(); j++) {
-                using device::tanh;
-                vec.data[j] = tanh(vec.data[j] + bias[bias_idx]);
-            }
-            v_store(inplace_output_vPtr[i], vec);
-        }
-    }
-
-    template <class T, std::size_t N>
-    __global__ void biasN_sigmoid_inplace_vec(Span<T> inplace_output, size_type inner_size, View<T> bias) {
-        using vector_type = get_vector_type_t<T, N>;
-
-        auto inplace_output_vPtr = vector_type::get_pointer(inplace_output.data());
-
-        inner_size /= vector_type::size();
-        for (auto i : grid_stride_range(inplace_output.size() / vector_type::size())) {
-            const index_type bias_idx = (i / inner_size) % static_cast<size_type>(bias.size());
-
-            vector_type vec;
-            v_load(vec, inplace_output_vPtr[i]);
-            for(int j = 0; j < vec.size(); j++) {
-                using device::sigmoid;
-                vec.data[j] = sigmoid(vec.data[j] + bias[bias_idx]);
-            }
-            v_store(inplace_output_vPtr[i], vec);
-        }
-    }
-
-    template <class T, std::size_t N>
-    __global__ void biasN_swish_inplace_vec(Span<T> inplace_output, size_type inner_size, View<T> bias) {
-        using vector_type = get_vector_type_t<T, N>;
-
-        auto inplace_output_vPtr = vector_type::get_pointer(inplace_output.data());
-
-        inner_size /= vector_type::size();
-        for (auto i : grid_stride_range(inplace_output.size() / vector_type::size())) {
-            const index_type bias_idx = (i / inner_size) % static_cast<size_type>(bias.size());
-
-            vector_type vec;
-            v_load(vec, inplace_output_vPtr[i]);
-            for(int j = 0; j < vec.size(); j++) {
-                using device::sigmoid;
-                vec.data[j] += bias[bias_idx];
-                vec.data[j] = vec.data[j] * sigmoid(vec.data[j]);
-            }
-            v_store(inplace_output_vPtr[i], vec);
-        }
-    }
-
-    template <class T, std::size_t N>
-    __global__ void biasN_mish_inplace_vec(Span<T> inplace_output, size_type inner_size, View<T> bias) {
-        using vector_type = get_vector_type_t<T, N>;
-
-        auto inplace_output_vPtr = vector_type::get_pointer(inplace_output.data());
-
-        inner_size /= vector_type::size();
-        for (auto i : grid_stride_range(inplace_output.size() / vector_type::size())) {
-            const index_type bias_idx = (i / inner_size) % static_cast<size_type>(bias.size());
-
-            vector_type vec;
-            v_load(vec, inplace_output_vPtr[i]);
-            for(int j = 0; j < vec.size(); j++) {
-                using device::tanh;
-                using device::log1pexp;
-                vec.data[j] += bias[bias_idx];
-                vec.data[j] = vec.data[j] * tanh(log1pexp(vec.data[j]));
-            }
-            v_store(inplace_output_vPtr[i], vec);
-        }
-    }
-}
-
-template <class T, std::size_t N> static
-void launch_biasN_relu_inplace_vec_kernel(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, T slope) {
+template <class T, template <class> class Activation, std::size_t N, class ...ActivationArgs> static
+void launch_vectorized_biasN_generic_op_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, ActivationArgs ...activationArgs) {
+    CV_Assert(inplace_output.size() % inner_size == 0);
+    CV_Assert(inplace_output.size() % bias.size() == 0);
     CV_Assert(is_fully_aligned<T>(inplace_output, N));
     CV_Assert(inner_size % N == 0);
 
-    auto kernel = raw::biasN_relu_inplace_vec<T, N>;
+    auto kernel = raw::biasN_generic_op_inplace_vec<T, Activation<T>, N, ActivationArgs...>;
     auto policy = make_policy(kernel, inplace_output.size() / N, 0, stream);
-    launch_kernel(kernel, policy, inplace_output, inner_size, bias, slope);
+    launch_kernel(kernel, policy, inplace_output, inner_size, bias, activationArgs...);
 }
 
-template <class T>
-void biasN_relu_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, T slope) {
+template <class T, template <class> class Activation, class ...ActivationArgs> static
+void biasN_generic_op_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, ActivationArgs ...activationArgs) {
     if (is_fully_aligned<T>(inplace_output, 4) && inner_size % 4 == 0) {
-        launch_biasN_relu_inplace_vec_kernel<T, 4>(stream, inplace_output, inner_size, bias, slope);
+        launch_vectorized_biasN_generic_op_inplace<T, Activation, 4>(stream, inplace_output, inner_size, bias, activationArgs...);
     } else if (is_fully_aligned<T>(inplace_output, 2) && inner_size % 2 == 0) {
-        launch_biasN_relu_inplace_vec_kernel<T, 2>(stream, inplace_output, inner_size, bias, slope);
+        launch_vectorized_biasN_generic_op_inplace<T, Activation, 2>(stream, inplace_output, inner_size, bias, activationArgs...);
     } else {
-        launch_biasN_relu_inplace_vec_kernel<T, 1>(stream, inplace_output, inner_size, bias, slope);
+        launch_vectorized_biasN_generic_op_inplace<T, Activation, 1>(stream, inplace_output, inner_size, bias, activationArgs...);
     }
 }
 
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-template void biasN_relu_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>, __half);
-#endif
-template void biasN_relu_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>, float);
-
-template <class T, std::size_t N> static
-void launch_biasN_clipped_relu_inplace_vec_kernel(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, T floor, T ceil) {
-    CV_Assert(is_fully_aligned<T>(inplace_output, N));
-    CV_Assert(inner_size % N == 0);
-
-    auto kernel = raw::biasN_clipped_relu_inplace_vec<T, N>;
-    auto policy = make_policy(kernel, inplace_output.size() / N, 0, stream);
-    launch_kernel(kernel, policy, inplace_output, inner_size, bias, floor, ceil);
+template <class T>
+void biasN_relu_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, T slope) {
+    biasN_generic_op_inplace<T, relu_functor>(stream, inplace_output, inner_size, bias, slope);
 }
 
 template <class T>
 void biasN_clipped_relu_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, T floor, T ceil) {
-    if (is_fully_aligned<T>(inplace_output, 4) && inner_size % 4 == 0) {
-        launch_biasN_clipped_relu_inplace_vec_kernel<T, 4>(stream, inplace_output, inner_size, bias, floor, ceil);
-    } else if (is_fully_aligned<T>(inplace_output, 2) && inner_size % 2 == 0) {
-        launch_biasN_clipped_relu_inplace_vec_kernel<T, 2>(stream, inplace_output, inner_size, bias, floor, ceil);
-    } else {
-        launch_biasN_clipped_relu_inplace_vec_kernel<T, 1>(stream, inplace_output, inner_size, bias, floor, ceil);
-    }
-}
-
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-template void biasN_clipped_relu_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>, __half, __half);
-#endif
-template void biasN_clipped_relu_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>, float, float);
-
-template <class T, std::size_t N> static
-void launch_biasN_power_inplace_vec_kernel(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, T power) {
-    CV_Assert(is_fully_aligned<T>(inplace_output, N));
-    CV_Assert(inner_size % N == 0);
-
-    auto kernel = raw::biasN_power_inplace_vec<T, N>;
-    auto policy = make_policy(kernel, inplace_output.size() / N, 0, stream);
-    launch_kernel(kernel, policy, inplace_output, inner_size, bias, power);
+    CV_Assert(static_cast<double>(floor) <= static_cast<double>(ceil));
+    biasN_generic_op_inplace<T, clipped_relu_functor>(stream, inplace_output, inner_size, bias, floor, ceil);
 }
 
 template <class T>
-void biasN_power_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, T power) {
-    if (is_fully_aligned<T>(inplace_output, 4) && inner_size % 4 == 0) {
-        launch_biasN_power_inplace_vec_kernel<T, 4>(stream, inplace_output, inner_size, bias, power);
-    } else if (is_fully_aligned<T>(inplace_output, 2) && inner_size % 2 == 0) {
-        launch_biasN_power_inplace_vec_kernel<T, 2>(stream, inplace_output, inner_size, bias, power);
-    } else {
-        launch_biasN_power_inplace_vec_kernel<T, 1>(stream, inplace_output, inner_size, bias, power);
-    }
-}
-
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-template void biasN_power_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>, __half);
-#endif
-template void biasN_power_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>, float);
-
-template <class T, std::size_t N> static
-void launch_biasN_tanh_inplace_vec_kernel(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias) {
-    CV_Assert(is_fully_aligned<T>(inplace_output, N));
-    CV_Assert(inner_size % N == 0);
-
-    auto kernel = raw::biasN_tanh_inplace_vec<T, N>;
-    auto policy = make_policy(kernel, inplace_output.size() / N, 0, stream);
-    launch_kernel(kernel, policy, inplace_output, inner_size, bias);
+void biasN_power_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, T power, T scale, T shift) {
+    biasN_generic_op_inplace<T, power_functor>(stream, inplace_output, inner_size, bias, power, scale, shift);
 }
 
 template <class T>
 void biasN_tanh_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias) {
-    if (is_fully_aligned<T>(inplace_output, 4) && inner_size % 4 == 0) {
-        launch_biasN_tanh_inplace_vec_kernel<T, 4>(stream, inplace_output, inner_size, bias);
-    } else if (is_fully_aligned<T>(inplace_output, 2) && inner_size % 2 == 0) {
-        launch_biasN_tanh_inplace_vec_kernel<T, 2>(stream, inplace_output, inner_size, bias);
-    } else {
-        launch_biasN_tanh_inplace_vec_kernel<T, 1>(stream, inplace_output, inner_size, bias);
-    }
-}
-
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-template void biasN_tanh_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>);
-#endif
-template void biasN_tanh_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>);
-
-template <class T, std::size_t N> static
-void launch_biasN_sigmoid_inplace_vec_kernel(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias) {
-    CV_Assert(is_fully_aligned<T>(inplace_output, N));
-    CV_Assert(inner_size % N == 0);
-
-    auto kernel = raw::biasN_sigmoid_inplace_vec<T, N>;
-    auto policy = make_policy(kernel, inplace_output.size() / N, 0, stream);
-    launch_kernel(kernel, policy, inplace_output, inner_size, bias);
+    biasN_generic_op_inplace<T, tanh_functor>(stream, inplace_output, inner_size, bias);
 }
 
 template <class T>
 void biasN_sigmoid_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias) {
-    if (is_fully_aligned<T>(inplace_output, 4) && inner_size % 4 == 0) {
-        launch_biasN_sigmoid_inplace_vec_kernel<T, 4>(stream, inplace_output, inner_size, bias);
-    } else if (is_fully_aligned<T>(inplace_output, 2) && inner_size % 2 == 0) {
-        launch_biasN_sigmoid_inplace_vec_kernel<T, 2>(stream, inplace_output, inner_size, bias);
-    } else {
-        launch_biasN_sigmoid_inplace_vec_kernel<T, 1>(stream, inplace_output, inner_size, bias);
-    }
-}
-
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-template void biasN_sigmoid_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>);
-#endif
-template void biasN_sigmoid_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>);
-
-template <class T, std::size_t N> static
-void launch_biasN_swish_inplace_vec_kernel(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias) {
-    CV_Assert(is_fully_aligned<T>(inplace_output, N));
-    CV_Assert(inner_size % N == 0);
-
-    auto kernel = raw::biasN_swish_inplace_vec<T, N>;
-    auto policy = make_policy(kernel, inplace_output.size() / N, 0, stream);
-    launch_kernel(kernel, policy, inplace_output, inner_size, bias);
+    biasN_generic_op_inplace<T, sigmoid_functor>(stream, inplace_output, inner_size, bias);
 }
 
 template <class T>
 void biasN_swish_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias) {
-    if (is_fully_aligned<T>(inplace_output, 4) && inner_size % 4 == 0) {
-        launch_biasN_swish_inplace_vec_kernel<T, 4>(stream, inplace_output, inner_size, bias);
-    } else if (is_fully_aligned<T>(inplace_output, 2) && inner_size % 2 == 0) {
-        launch_biasN_swish_inplace_vec_kernel<T, 2>(stream, inplace_output, inner_size, bias);
-    } else {
-        launch_biasN_swish_inplace_vec_kernel<T, 1>(stream, inplace_output, inner_size, bias);
-    }
-}
-
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-template void biasN_swish_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>);
-#endif
-template void biasN_swish_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>);
-
-template <class T, std::size_t N> static
-void launch_biasN_mish_inplace_vec_kernel(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias) {
-    CV_Assert(is_fully_aligned<T>(inplace_output, N));
-    CV_Assert(inner_size % N == 0);
-
-    auto kernel = raw::biasN_mish_inplace_vec<T, N>;
-    auto policy = make_policy(kernel, inplace_output.size() / N, 0, stream);
-    launch_kernel(kernel, policy, inplace_output, inner_size, bias);
+    biasN_generic_op_inplace<T, swish_functor>(stream, inplace_output, inner_size, bias);
 }
 
 template <class T>
 void biasN_mish_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias) {
-    if (is_fully_aligned<T>(inplace_output, 4) && inner_size % 4 == 0) {
-        launch_biasN_mish_inplace_vec_kernel<T, 4>(stream, inplace_output, inner_size, bias);
-    } else if (is_fully_aligned<T>(inplace_output, 2) && inner_size % 2 == 0) {
-        launch_biasN_mish_inplace_vec_kernel<T, 2>(stream, inplace_output, inner_size, bias);
-    } else {
-        launch_biasN_mish_inplace_vec_kernel<T, 1>(stream, inplace_output, inner_size, bias);
-    }
+    biasN_generic_op_inplace<T, mish_functor>(stream, inplace_output, inner_size, bias);
 }
 
 #if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
+template void biasN_relu_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>, __half);
+template void biasN_clipped_relu_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>, __half, __half);
+template void biasN_power_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>, __half, __half, __half);
+template void biasN_tanh_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>);
+template void biasN_sigmoid_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>);
+template void biasN_swish_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>);
 template void biasN_mish_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>);
 #endif
+
+template void biasN_relu_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>, float);
+template void biasN_clipped_relu_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>, float, float);
+template void biasN_power_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>, float, float, float);
+template void biasN_tanh_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>);
+template void biasN_sigmoid_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>);
+template void biasN_swish_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>);
 template void biasN_mish_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>);
 
 }}}} /* namespace cv::dnn::cuda4dnn::kernels */
index 521bb43..a7d06e6 100644 (file)
@@ -5,7 +5,7 @@
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
 
-#include "math.hpp"
+#include "functors.hpp"
 #include "grid_stride_range.hpp"
 #include "execution.hpp"
 #include "vector_traits.hpp"
@@ -20,263 +20,91 @@ using namespace cv::dnn::cuda4dnn::csl::device;
 
 namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
 
-    namespace raw {
-        template <class T, std::size_t N>
-        __global__ void eltwise_max_2_vec(Span<T> output, View<T> x, View<T> y) {
-            using vector_type = get_vector_type_t<T, N>;
+namespace raw {
+    template <class T, class Functor, std::size_t N, class ...FunctorArgs>
+    __global__ void eltwise_op_vec(Span<T> output, View<T> x, View<T> y, FunctorArgs ...functorArgs) {
+        using vector_type = get_vector_type_t<T, N>;
 
-            auto output_vPtr = vector_type::get_pointer(output.data());
-            auto x_vPtr = vector_type::get_pointer(x.data());
-            auto y_vPtr = vector_type::get_pointer(y.data());
+        auto output_vPtr = vector_type::get_pointer(output.data());
+        auto x_vPtr = vector_type::get_pointer(x.data());
+        auto y_vPtr = vector_type::get_pointer(y.data());
 
-            for (auto i : grid_stride_range(output.size() / vector_type::size())) {
-                vector_type vec_x, vec_y;
-                v_load(vec_x, x_vPtr[i]);
-                v_load(vec_y, y_vPtr[i]);
+        Functor functor(functorArgs...);
 
-                for (int j = 0; j < vector_type::size(); j++) {
-                    using device::max;
-                    vec_x.data[j] = max(vec_x.data[j], vec_y.data[j]);
-                }
-
-                v_store(output_vPtr[i], vec_x);
-            }
-        }
-
-        template <class T, std::size_t N>
-        __global__ void eltwise_sum_2_vec(Span<T> output, View<T> x, View<T> y) {
-            using vector_type = get_vector_type_t<T, N>;
-
-            auto output_vPtr = vector_type::get_pointer(output.data());
-            auto x_vPtr = vector_type::get_pointer(x.data());
-            auto y_vPtr = vector_type::get_pointer(y.data());
-
-            for (auto i : grid_stride_range(output.size() / vector_type::size())) {
-                vector_type vec_x, vec_y;
-                v_load(vec_x, x_vPtr[i]);
-                v_load(vec_y, y_vPtr[i]);
-
-                for (int j = 0; j < vector_type::size(); j++)
-                    vec_x.data[j] = vec_x.data[j] + vec_y.data[j];
-
-                v_store(output_vPtr[i], vec_x);
-            }
-        }
-
-        template <class T, std::size_t N>
-        __global__ void eltwise_sum_coeff_2_vec(Span<T> output, T coeff_x, View<T> x, T coeff_y, View<T> y) {
-            using vector_type = get_vector_type_t<T, N>;
-
-            auto output_vPtr = vector_type::get_pointer(output.data());
-            auto x_vPtr = vector_type::get_pointer(x.data());
-            auto y_vPtr = vector_type::get_pointer(y.data());
-
-            for (auto i : grid_stride_range(output.size() / vector_type::size())) {
-                vector_type vec_x, vec_y;
-                v_load(vec_x, x_vPtr[i]);
-                v_load(vec_y, y_vPtr[i]);
-
-                for (int j = 0; j < vector_type::size(); j++)
-                    vec_x.data[j] = coeff_x * vec_x.data[j] + coeff_y * vec_y.data[j];
-
-                v_store(output_vPtr[i], vec_x);
-            }
-        }
-
-        template <class T, std::size_t N>
-        __global__ void eltwise_prod_2_vec(Span<T> output, View<T> x, View<T> y) {
-            using vector_type = get_vector_type_t<T, N>;
-
-            auto output_vPtr = vector_type::get_pointer(output.data());
-            auto x_vPtr = vector_type::get_pointer(x.data());
-            auto y_vPtr = vector_type::get_pointer(y.data());
-
-            for (auto i : grid_stride_range(output.size() / vector_type::size())) {
-                vector_type vec_x, vec_y;
-                v_load(vec_x, x_vPtr[i]);
-                v_load(vec_y, y_vPtr[i]);
-
-                for (int j = 0; j < vector_type::size(); j++)
-                    vec_x.data[j] = vec_x.data[j] * vec_y.data[j];
-
-                v_store(output_vPtr[i], vec_x);
-            }
-        }
-
-        template <class T, std::size_t N>
-        __global__ void eltwise_div_2_vec(Span<T> output, View<T> x, View<T> y) {
-            using vector_type = get_vector_type_t<T, N>;
-
-            auto output_vPtr = vector_type::get_pointer(output.data());
-            auto x_vPtr = vector_type::get_pointer(x.data());
-            auto y_vPtr = vector_type::get_pointer(y.data());
-
-            for (auto i : grid_stride_range(output.size() / vector_type::size())) {
-                vector_type vec_x, vec_y;
-                v_load(vec_x, x_vPtr[i]);
-                v_load(vec_y, y_vPtr[i]);
-
-                for (int j = 0; j < vector_type::size(); j++)
-                    vec_x.data[j] = vec_x.data[j] / vec_y.data[j];
-
-                v_store(output_vPtr[i], vec_x);
-            }
-        }
-    }
-
-    template <class T, std::size_t N>
-    void launch_vectorized_eltwise_max_2(const Stream& stream, Span<T> output, View<T> x, View<T> y) {
-        CV_Assert(is_fully_aligned<T>(output, N));
-        CV_Assert(is_fully_aligned<T>(x, N));
-        CV_Assert(is_fully_aligned<T>(y, N));
-
-        auto kernel = raw::eltwise_max_2_vec<T, N>;
-        auto policy = make_policy(kernel, output.size() / N, 0, stream);
-        launch_kernel(kernel, policy, output, x, y);
-    }
-
-    template <class T>
-    void eltwise_max_2(const Stream& stream, Span<T> output, View<T> x, View<T> y) {
-        CV_Assert(x.size() == y.size());
-        CV_Assert(x.size() == output.size());
-
-        if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(x, 4) && is_fully_aligned<T>(y, 4)) {
-            launch_vectorized_eltwise_max_2<T, 4>(stream, output, x, y);
-        } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(x, 2) && is_fully_aligned<T>(y, 2)) {
-            launch_vectorized_eltwise_max_2<T, 2>(stream, output, x, y);
-        } else {
-            launch_vectorized_eltwise_max_2<T, 1>(stream, output, x, y);
-        }
-    }
-
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-    template void eltwise_max_2(const Stream& stream, Span<__half> output, View<__half> x, View<__half> y);
-#endif
-    template void eltwise_max_2(const Stream& stream, Span<float> output, View<float> x, View<float> y);
-
-    template <class T, std::size_t N>
-    void launch_vectorized_eltwise_sum_2(const Stream& stream, Span<T> output, View<T> x, View<T> y) {
-        CV_Assert(is_fully_aligned<T>(output, N));
-        CV_Assert(is_fully_aligned<T>(x, N));
-        CV_Assert(is_fully_aligned<T>(y, N));
-
-        auto kernel = raw::eltwise_sum_2_vec<T, N>;
-        auto policy = make_policy(kernel, output.size() / N, 0, stream);
-        launch_kernel(kernel, policy, output, x, y);
-    }
-
-    template <class T>
-    void eltwise_sum_2(const Stream& stream, Span<T> output, View<T> x, View<T> y) {
-        CV_Assert(x.size() == y.size());
-        CV_Assert(x.size() == output.size());
-
-        if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(x, 4) && is_fully_aligned<T>(y, 4)) {
-            launch_vectorized_eltwise_sum_2<T, 4>(stream, output, x, y);
-        } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(x, 2) && is_fully_aligned<T>(y, 2)) {
-            launch_vectorized_eltwise_sum_2<T, 2>(stream, output, x, y);
-        } else {
-            launch_vectorized_eltwise_sum_2<T, 1>(stream, output, x, y);
+        for (auto i : grid_stride_range(output.size() / vector_type::size())) {
+            vector_type vec_x, vec_y;
+            v_load(vec_x, x_vPtr[i]);
+            v_load(vec_y, y_vPtr[i]);
+            for (int j = 0; j < vector_type::size(); j++)
+                vec_x.data[j] = functor(vec_x.data[j], vec_y.data[j]);
+            v_store(output_vPtr[i], vec_x);
         }
     }
-
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-    template void eltwise_sum_2(const Stream& stream, Span<__half> output, View<__half> x, View<__half> y);
-#endif
-    template void eltwise_sum_2(const Stream& stream, Span<float> output, View<float> x, View<float> y);
-
-    template <class T, std::size_t N>
-    void launch_vectorized_eltwise_sum_coeff_2(const Stream& stream, Span<T> output, T coeff_x, View<T> x, T coeff_y, View<T> y) {
-        CV_Assert(is_fully_aligned<T>(output, N));
-        CV_Assert(is_fully_aligned<T>(x, N));
-        CV_Assert(is_fully_aligned<T>(y, N));
-
-        auto kernel = raw::eltwise_sum_coeff_2_vec<T, N>;
-        auto policy = make_policy(kernel, output.size() / N, 0, stream);
-        launch_kernel(kernel, policy, output, coeff_x, x, coeff_y, y);
-    }
-
-    template <class T>
-    void eltwise_sum_coeff_2(const Stream& stream, Span<T> output, T coeff_x, View<T> x, T coeff_y, View<T> y) {
-        CV_Assert(x.size() == y.size());
-        CV_Assert(x.size() == output.size());
-
-        if (static_cast<float>(coeff_x) == 1.0f && static_cast<float>(coeff_y) == 1.0f) {
-            eltwise_sum_2(stream, output, x, y);
-            return;
-        }
-
-        if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(x, 4) && is_fully_aligned<T>(y, 4)) {
-            launch_vectorized_eltwise_sum_coeff_2<T, 4>(stream, output, coeff_x, x, coeff_y, y);
-        } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(x, 2) && is_fully_aligned<T>(y, 2)) {
-            launch_vectorized_eltwise_sum_coeff_2<T, 2>(stream, output, coeff_x, x, coeff_y, y);
-        } else {
-            launch_vectorized_eltwise_sum_coeff_2<T, 1>(stream, output, coeff_x, x, coeff_y, y);
-        }
-    }
-
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-    template void eltwise_sum_coeff_2(const Stream&, Span<__half>, __half, View<__half>, __half, View<__half>);
-#endif
-    template void eltwise_sum_coeff_2(const Stream&, Span<float>, float, View<float>, float, View<float>);
-
-    template <class T, std::size_t N>
-    void launch_vectorized_eltwise_prod_2(const Stream& stream, Span<T> output, View<T> x, View<T> y) {
-        CV_Assert(is_fully_aligned<T>(output, N));
-        CV_Assert(is_fully_aligned<T>(x, N));
-        CV_Assert(is_fully_aligned<T>(y, N));
-
-        auto kernel = raw::eltwise_prod_2_vec<T, N>;
-        auto policy = make_policy(kernel, output.size() / N, 0, stream);
-        launch_kernel(kernel, policy, output, x, y);
-    }
-
-    template <class T>
-    void eltwise_prod_2(const Stream& stream, Span<T> output, View<T> x, View<T> y) {
-        CV_Assert(x.size() == y.size());
-        CV_Assert(x.size() == output.size());
-
-        if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(x, 4) && is_fully_aligned<T>(y, 4)) {
-            launch_vectorized_eltwise_prod_2<T, 4>(stream, output, x, y);
-        } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(x, 2) && is_fully_aligned<T>(y, 2)) {
-            launch_vectorized_eltwise_prod_2<T, 2>(stream, output, x, y);
-        } else {
-            launch_vectorized_eltwise_prod_2<T, 1>(stream, output, x, y);
-        }
+}
+
+template <class T, template <class> class EltwiseOp, std::size_t N, class ...EltwiseOpArgs> static
+void launch_vectorized_eltwise_op(const Stream& stream, Span<T> output, View<T> x, View<T> y, EltwiseOpArgs ...eltwiseOpArgs) {
+    CV_Assert(x.size() == y.size());
+    CV_Assert(x.size() == output.size());
+    CV_Assert(is_fully_aligned<T>(output, N));
+    CV_Assert(is_fully_aligned<T>(x, N));
+    CV_Assert(is_fully_aligned<T>(y, N));
+
+    auto kernel = raw::eltwise_op_vec<T, EltwiseOp<T>, N, EltwiseOpArgs...>;
+    auto policy = make_policy(kernel, output.size() / N, 0, stream);
+    launch_kernel(kernel, policy, output, x, y, eltwiseOpArgs...);
+}
+
+template <class T, template <class> class EltwiseOp, class ...EltwiseOpArgs> static
+void eltwise_op(const Stream& stream, Span<T> output, View<T> x, View<T> y, EltwiseOpArgs ...eltwiseOpArgs) {
+    CV_Assert(x.size() == y.size());
+    CV_Assert(x.size() == output.size());
+
+    if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(x, 4) && is_fully_aligned<T>(y, 4)) {
+        launch_vectorized_eltwise_op<T, EltwiseOp, 4>(stream, output, x, y, eltwiseOpArgs...);
+    } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(x, 2) && is_fully_aligned<T>(y, 2)) {
+        launch_vectorized_eltwise_op<T, EltwiseOp, 2>(stream, output, x, y, eltwiseOpArgs...);
+    } else {
+        launch_vectorized_eltwise_op<T, EltwiseOp, 1>(stream, output, x, y, eltwiseOpArgs...);
     }
+}
 
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-    template void eltwise_prod_2(const Stream& stream, Span<__half> output, View<__half> x, View<__half> y);
-#endif
-    template void eltwise_prod_2(const Stream& stream, Span<float> output, View<float> x, View<float> y);
+template <class T>
+void eltwise_max_2(const Stream& stream, Span<T> output, View<T> x, View<T> y) {
+    eltwise_op<T, max_functor>(stream, output, x, y);
+}
 
-    template <class T, std::size_t N>
-    void launch_vectorized_eltwise_div_2(const Stream& stream, Span<T> output, View<T> x, View<T> y) {
-        CV_Assert(is_fully_aligned<T>(output, N));
-        CV_Assert(is_fully_aligned<T>(x, N));
-        CV_Assert(is_fully_aligned<T>(y, N));
+template <class T>
+void eltwise_sum_2(const Stream& stream, Span<T> output, View<T> x, View<T> y) {
+    eltwise_op<T, sum_functor>(stream, output, x, y);
+}
 
-        auto kernel = raw::eltwise_div_2_vec<T, N>;
-        auto policy = make_policy(kernel, output.size() / N, 0, stream);
-        launch_kernel(kernel, policy, output, x, y);
-    }
+template <class T>
+void eltwise_sum_coeff_2(const Stream& stream, Span<T> output, T coeff_x, View<T> x, T coeff_y, View<T> y) {
+    eltwise_op<T, scaled_sum_functor>(stream, output, x, y, coeff_x, coeff_y);
+}
 
-    template <class T>
-    void eltwise_div_2(const Stream& stream, Span<T> output, View<T> x, View<T> y) {
-        CV_Assert(x.size() == y.size());
-        CV_Assert(x.size() == output.size());
+template <class T>
+void eltwise_prod_2(const Stream& stream, Span<T> output, View<T> x, View<T> y) {
+    eltwise_op<T, product_functor>(stream, output, x, y);
+}
 
-        if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(x, 4) && is_fully_aligned<T>(y, 4)) {
-            launch_vectorized_eltwise_div_2<T, 4>(stream, output, x, y);
-        } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(x, 2) && is_fully_aligned<T>(y, 2)) {
-            launch_vectorized_eltwise_div_2<T, 2>(stream, output, x, y);
-        } else {
-            launch_vectorized_eltwise_div_2<T, 1>(stream, output, x, y);
-        }
-    }
+template <class T>
+void eltwise_div_2(const Stream& stream, Span<T> output, View<T> x, View<T> y) {
+    eltwise_op<T, div_functor>(stream, output, x, y);
+}
 
 #if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
     template void eltwise_div_2(const Stream& stream, Span<__half> output, View<__half> x, View<__half> y);
+    template void eltwise_prod_2(const Stream& stream, Span<__half> output, View<__half> x, View<__half> y);
+    template void eltwise_sum_coeff_2(const Stream&, Span<__half>, __half, View<__half>, __half, View<__half>);
+    template void eltwise_sum_2(const Stream& stream, Span<__half> output, View<__half> x, View<__half> y);
+    template void eltwise_max_2(const Stream& stream, Span<__half> output, View<__half> x, View<__half> y);
 #endif
     template void eltwise_div_2(const Stream& stream, Span<float> output, View<float> x, View<float> y);
+    template void eltwise_prod_2(const Stream& stream, Span<float> output, View<float> x, View<float> y);
+    template void eltwise_sum_coeff_2(const Stream&, Span<float>, float, View<float>, float, View<float>);
+    template void eltwise_sum_2(const Stream& stream, Span<float> output, View<float> x, View<float> y);
+    template void eltwise_max_2(const Stream& stream, Span<float> output, View<float> x, View<float> y);
 
 }}}} /* namespace cv::dnn::cuda4dnn::kernels */
index 57d1e30..27b86ef 100644 (file)
@@ -63,17 +63,17 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
     template <class Kernel, typename ...Args> inline
     void launch_kernel(Kernel kernel, Args ...args) {
         auto policy = make_policy(kernel);
-        kernel <<<policy.grid, policy.block>>> (std::forward<Args>(args)...);
+        kernel <<<policy.grid, policy.block>>> (args...);
     }
 
     template <class Kernel, typename ...Args> inline
     void launch_kernel(Kernel kernel, dim3 grid, dim3 block, Args ...args) {
-        kernel <<<grid, block>>> (std::forward<Args>(args)...);
+        kernel <<<grid, block>>> (args...);
     }
 
     template <class Kernel, typename ...Args> inline
     void launch_kernel(Kernel kernel, execution_policy policy, Args ...args) {
-        kernel <<<policy.grid, policy.block, policy.sharedMem, policy.stream>>> (std::forward<Args>(args)...);
+        kernel <<<policy.grid, policy.block, policy.sharedMem, policy.stream>>> (args...);
     }
 
 }}}} /* namespace cv::dnn::cuda4dnn::csl */
diff --git a/modules/dnn/src/cuda/functors.hpp b/modules/dnn/src/cuda/functors.hpp
new file mode 100644 (file)
index 0000000..c35a854
--- /dev/null
@@ -0,0 +1,139 @@
+// This file is part of OpenCV project.
+// It is subject to the license terms in the LICENSE file found in the top-level directory
+// of this distribution and at http://opencv.org/license.html.
+
+#ifndef OPENCV_DNN_SRC_CUDA_FUNCTORS_HPP
+#define OPENCV_DNN_SRC_CUDA_FUNCTORS_HPP
+
+#include <cuda_runtime.h>
+
+#include "math.hpp"
+
+namespace cv { namespace dnn { namespace cuda4dnn  { namespace kernels {
+
+template <class T>
+struct abs_functor {
+    __device__ T operator()(T value) {
+        using csl::device::abs;
+        return abs(value);
+    }
+};
+
+template <class T>
+struct tanh_functor {
+    __device__ T operator()(T value) {
+        using csl::device::tanh;
+        return tanh(value);
+    }
+};
+
+template <class T>
+struct swish_functor {
+    __device__ T operator()(T value) {
+        using csl::device::sigmoid;
+        return value * sigmoid(value);
+    }
+};
+
+template <class T>
+struct mish_functor {
+    __device__ T operator()(T value) {
+        using csl::device::tanh;
+        using csl::device::log1pexp;
+        return value * tanh(log1pexp(value));
+    }
+};
+
+template <class T>
+struct sigmoid_functor {
+    __device__ T operator()(T value) {
+        using csl::device::sigmoid;
+        return sigmoid(value);
+    }
+};
+
+template <class T>
+struct bnll_functor {
+    __device__ T operator()(T value) {
+        using csl::device::log1pexp;
+        return value > T(0) ? value + log1pexp(-value) : log1pexp(value);
+    }
+};
+
+template <class T>
+struct elu_functor {
+    __device__ T operator()(T value) {
+        using csl::device::expm1;
+        return value >= T(0) ? value : expm1(value);
+    }
+};
+
+template <class T>
+struct relu_functor {
+    __device__ relu_functor(T slope_) : slope{slope_} { }
+    __device__ T operator()(T value) {
+        using csl::device::log1pexp;
+        return value >= T(0) ? value : slope * value;
+    }
+
+    T slope;
+};
+
+template <class T>
+struct clipped_relu_functor {
+    __device__ clipped_relu_functor(T floor_, T ceiling_) : floor{floor_}, ceiling{ceiling_} { }
+    __device__ T operator()(T value) {
+        using csl::device::clamp;
+        return clamp(value, floor, ceiling);
+    }
+
+    T floor, ceiling;
+};
+
+template <class T>
+struct power_functor {
+    __device__ power_functor(T exp_, T scale_, T shift_) : exp{exp_}, scale{scale_}, shift{shift_} { }
+    __device__ T operator()(T value) {
+        using csl::device::pow;
+        return pow(shift + scale * value, exp);
+    }
+
+    T exp, scale, shift;
+};
+
+template <class T>
+struct max_functor {
+    __device__ T operator()(T x, T y) {
+        using csl::device::max;
+        return max(x, y);
+    }
+};
+
+template <class T>
+struct sum_functor {
+    __device__ T operator()(T x, T y) { return x + y; }
+};
+
+template <class T>
+struct scaled_sum_functor {
+    __device__ scaled_sum_functor(T scale_x_, T scale_y_)
+        : scale_x{scale_x_}, scale_y{scale_y_} { }
+
+    __device__ T operator()(T x, T y) { return scale_x * x + scale_y * y; }
+
+    T scale_x, scale_y;
+};
+
+template <class T>
+struct product_functor {
+    __device__ T operator()(T x, T y) { return x * y; }
+};
+
+template <class T>
+struct div_functor {
+    __device__ T operator()(T x, T y) { return x / y; }
+};
+
+}}}} /* namespace cv::dnn::cuda4dnn::kernels */
+
+#endif /* OPENCV_DNN_SRC_CUDA_FUNCTORS_HPP */
\ No newline at end of file
index 31fa471..36bdb7a 100644 (file)
@@ -25,22 +25,6 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
 
     namespace raw {
         template <class T, std::size_t N>
-        __global__ void bias1_vec(Span<T> output, View<T> input, T beta) {
-            using vector_type = get_vector_type_t<T, N>;
-
-            auto output_vPtr = vector_type::get_pointer(output.data());
-            auto input_vPtr = vector_type::get_pointer(input.data());
-
-            for (auto i : grid_stride_range(output.size() / vector_type::size())) {
-                vector_type vec;
-                v_load(vec, input_vPtr[i]);
-                for (int j = 0; j < vec.size(); j++)
-                    vec.data[j] = vec.data[j] + beta;
-                v_store(output_vPtr[i], vec);
-            }
-        }
-
-        template <class T, std::size_t N>
         __global__ void biasN_vec(Span<T> output, View<T> input, size_type inner_size, View<T> bias) {
             using vector_type = get_vector_type_t<T, N>;
 
@@ -60,22 +44,6 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
         }
 
         template <class T, std::size_t N>
-        __global__ void scale1_vec(Span<T> output, View<T> input, T alpha) {
-            using vector_type = get_vector_type_t<T, N>;
-
-            auto output_vPtr = vector_type::get_pointer(output.data());
-            auto input_vPtr = vector_type::get_pointer(input.data());
-
-            for (auto i : grid_stride_range(output.size() / vector_type::size())) {
-                vector_type vec;
-                v_load(vec, input_vPtr[i]);
-                for (int j = 0; j < vec.size(); j++)
-                    vec.data[j] = vec.data[j] * alpha;
-                v_store(output_vPtr[i], vec);
-            }
-        }
-
-        template <class T, std::size_t N>
         __global__ void scaleN_vec(Span<T> output, View<T> input, size_type inner_size, View<T> weights)
         {
             using vector_type = get_vector_type_t<T, N>;
@@ -134,34 +102,6 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
     }
 
     template <class T, std::size_t N> static
-    void launch_bias1_vec_kernel(const Stream& stream, Span<T> output, View<T> input, T beta) {
-        CV_Assert(is_fully_aligned<T>(output, N));
-        CV_Assert(is_fully_aligned<T>(input, N));
-
-        auto kernel = raw::bias1_vec<T, N>;
-        auto policy = make_policy(kernel, output.size() / N, 0, stream);
-        launch_kernel(kernel, policy, output, input, beta);
-    }
-
-    template <class T>
-    void bias1(const Stream& stream, TensorSpan<T> output, TensorView<T> input, T beta) {
-        CV_Assert(is_shape_same(input, output));
-
-        if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
-            launch_bias1_vec_kernel<T, 4>(stream, output, input, beta);
-        } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
-            launch_bias1_vec_kernel<T, 2>(stream, output, input, beta);
-        } else {
-            launch_bias1_vec_kernel<T, 1>(stream, output, input, beta);
-        }
-    }
-
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-    template void bias1<__half>(const Stream&, TensorSpan<__half>, TensorView<__half>, __half);
-#endif
-    template void bias1<float>(const Stream&, TensorSpan<float>, TensorView<float>, float);
-
-    template <class T, std::size_t N> static
     void launch_biasN_vec_kernel(const Stream& stream, Span<T> output, View<T> input, std::size_t inner_size, View<T> bias){
         CV_Assert(is_fully_aligned<T>(output, N));
         CV_Assert(is_fully_aligned<T>(input, N));
@@ -196,34 +136,6 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
     template void biasN<float>(const Stream&, TensorSpan<float>, TensorView<float>, std::size_t, TensorView<float>);
 
     template <class T, std::size_t N> static
-    void launch_scale1_vec_kernel(const Stream& stream, Span<T> output, View<T> input, T alpha) {
-        CV_Assert(is_fully_aligned<T>(output, N));
-        CV_Assert(is_fully_aligned<T>(input, N));
-
-        auto kernel = raw::scale1_vec<T, N>;
-        auto policy = make_policy(kernel, output.size() / N, 0, stream);
-        launch_kernel(kernel, policy, output, input, alpha);
-    }
-
-    template <class T>
-    void scale1(const Stream& stream, TensorSpan<T> output, TensorView<T> input, T alpha) {
-        CV_Assert(is_shape_same(input, output));
-
-        if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
-            launch_scale1_vec_kernel<T, 4>(stream, output, input, alpha);
-        } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
-            launch_scale1_vec_kernel<T, 2>(stream, output, input, alpha);
-        } else {
-            launch_scale1_vec_kernel<T, 1>(stream, output, input, alpha);
-        }
-    }
-
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-    template void scale1<__half>(const Stream&, TensorSpan<__half>, TensorView<__half>, __half);
-#endif
-    template void scale1<float>(const Stream&, TensorSpan<float>, TensorView<float>, float);
-
-    template <class T, std::size_t N> static
     void launch_scaleN_vec_kernel(const Stream& stream, Span<T> output, View<T> input, std::size_t inner_size, View<T> weights) {
         CV_Assert(is_fully_aligned<T>(output, N));
         CV_Assert(is_fully_aligned<T>(input, N));
index 93660a8..500f9bb 100644 (file)
@@ -19,7 +19,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
     void biasN_clipped_relu_inplace(const csl::Stream& stream, csl::Span<T> inplace_output, std::size_t inner_size, csl::View<T> bias, T floor, T ceiling);
 
     template <class T>
-    void biasN_power_inplace(const csl::Stream& stream, csl::Span<T> inplace_output, std::size_t inner_size, csl::View<T> bias, T exp);
+    void biasN_power_inplace(const csl::Stream& stream, csl::Span<T> inplace_output, std::size_t inner_size, csl::View<T> bias, T exp, T scale, T shift);
 
     template <class T>
     void biasN_tanh_inplace(const csl::Stream& stream, csl::Span<T> inplace_output, std::size_t inner_size, csl::View<T> bias);
index 32fa1d8..7b7da3b 100644 (file)
 namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
 
     template <class T>
-    void bias1(const csl::Stream& stream, csl::TensorSpan<T> output, csl::TensorView<T> input, T alpha);
-
-    template <class T>
     void biasN(const csl::Stream& stream,
         csl::TensorSpan<T> output,
         csl::TensorView<T> input, std::size_t inner_size,
         csl::TensorView<T> bias);
 
     template <class T>
-    void scale1(const csl::Stream& stream, csl::TensorSpan<T> output, csl::TensorView<T> input, T alpha);
-
-    template <class T>
     void scaleN(const csl::Stream& stream,
         csl::TensorSpan<T> output,
         csl::TensorView<T> input, std::size_t inner_size,
index 0a0050b..b003952 100644 (file)
@@ -286,7 +286,7 @@ namespace cv { namespace dnn { namespace cuda4dnn {
                             kernels::biasN_clipped_relu_inplace<T>(stream, output, inner_size, biasTensor, crelu_floor, crelu_ceil);
                             break;
                         case ConvolutionConfiguration::ActivationType::POWER:
-                            kernels::biasN_power_inplace<T>(stream, output, inner_size, biasTensor, power_exp);
+                            kernels::biasN_power_inplace<T>(stream, output, inner_size, biasTensor, power_exp, T(1.0), T(0.0));
                             break;
                         case ConvolutionConfiguration::ActivationType::TANH:
                             kernels::biasN_tanh_inplace<T>(stream, output, inner_size, biasTensor);
index ecef608..f067ddd 100644 (file)
@@ -113,7 +113,7 @@ namespace cv { namespace dnn { namespace cuda4dnn {
              */
             if (weight != 1.0)
             {
-                kernels::scale1<T>(stream, output, input, weight);
+                kernels::scale1_with_bias1<T>(stream, output, input, weight, 1.0);
             }
             else if (!weightsTensor.empty())
             {