add half pixel centers and align corners param
authorYashasSamaga <yashas_2010@yahoo.com>
Sun, 27 Dec 2020 09:35:39 +0000 (15:05 +0530)
committerYashasSamaga <yashas_2010@yahoo.com>
Sun, 27 Dec 2020 09:35:39 +0000 (15:05 +0530)
modules/dnn/src/cuda/math.hpp
modules/dnn/src/cuda/resize.cu
modules/dnn/src/cuda4dnn/kernels/resize.hpp
modules/dnn/src/cuda4dnn/primitives/resize.hpp
modules/dnn/src/layers/resize_layer.cpp

index 1a9b221896fcb486dc8e245fbd07241be5ae8fac..273f3fe98e0ca49bc6db7e8bfb37753fffa72fcd 100644 (file)
@@ -108,6 +108,10 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace de
 
     template <class T> __device__ T clamp(T value, T lower, T upper) { return min(max(value, lower), upper); }
 
+    template <class T> __device__ long lround(T value);
+    template <> inline __device__ long lround(double value) { return ::lround(value); }
+    template <> inline __device__ long lround(float value) { return lroundf(value); }
+
     template <class T> __device__ T round(T value);
     template <> inline __device__ double round(double value) { return ::round(value); }
     template <> inline __device__ float round(float value) { return roundf(value); }
index 045b4f0a873975770969f2da45b310471b386d87..b780dab9f9948aba7f0708ca1162e7f58b5affca 100644 (file)
@@ -26,7 +26,8 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
         template <class T, std::size_t CHANNELS_PER_ITER>
         __global__ void resize_nn(
             Span<T> output, size_type out_height, size_type out_width,
-            View<T> input, size_type in_height, size_type in_width)
+            View<T> input, size_type in_height, size_type in_width,
+            float o2i_fy, float o2i_fx, bool round, bool half_pixel_centers)
         {
             auto in_image_size = in_height * in_width;
             auto out_image_size = out_height * out_width;
@@ -60,12 +61,16 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
                 const index_type y = (iter % out_image_size) / out_width;
                 const index_type x = iter % out_width;
 
-                /* o2i = output to input */
-                auto o2i_fy = static_cast<float>(in_height) / out_height;
-                auto o2i_fx = static_cast<float>(in_width) / out_width;
+                auto in_yf = half_pixel_centers ? (y + 0.5f) * o2i_fy : y * o2i_fy;
+                auto in_xf = half_pixel_centers ? (x + 0.5f) * o2i_fx : x * o2i_fx;
+
+                using device::lround;
+                index_type in_y = round ? lround(in_yf) : static_cast<index_type>(in_yf);
+                index_type in_x = round ? lround(in_xf) : static_cast<index_type>(in_xf);
 
-                auto in_y = static_cast<index_type>(y * o2i_fy);
-                auto in_x = static_cast<index_type>(x * o2i_fx);
+                using device::min;
+                in_y = min(in_y, in_height - 1);
+                in_x = min(in_x, in_width - 1);
 
                 index_type in_idx = c_start * in_image_size + in_y * in_width + in_x;
                 index_type out_idx = c_start * out_image_size + y * out_width + x;
@@ -83,7 +88,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
         __global__ void resize_bilinear(
             Span<T> output, size_type out_height, size_type out_width,
             View<T> input, size_type in_height, size_type in_width,
-            float o2i_fy, float o2i_fx)
+            float o2i_fy, float o2i_fx, bool half_pixel_centers)
         {
             auto in_image_size = in_height * in_width;
             auto out_image_size = out_height * out_width;
@@ -119,8 +124,9 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
                 const index_type y = (iter % out_image_size) / out_width;
                 const index_type x = iter % out_width;
 
-                auto in_x = x * o2i_fx;
-                auto in_y = y * o2i_fy;
+                using device::max;
+                auto in_x = half_pixel_centers ? max<float>((x + 0.5f) * o2i_fx - 0.5f, 0.0f) : x * o2i_fx;
+                auto in_y = half_pixel_centers ? max<float>((y + 0.5f) * o2i_fy - 0.5f, 0.0f) : y * o2i_fy;
 
                 auto in_x0 = static_cast<index_type>(in_x);
                 auto in_y0 = static_cast<index_type>(in_y);
@@ -157,15 +163,16 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
     template <class T, std::size_t CHANNELS_PER_ITER> static
     void launch_multichannel_resize_nn(const Stream& stream,
         Span<T> output, size_type out_height, size_type out_width,
-        View<T> input, size_type in_height, size_type in_width)
+        View<T> input, size_type in_height, size_type in_width,
+        float scale_y, float scale_x, bool round, bool half_pixel_centers)
     {
         auto kernel = raw::resize_nn<T, CHANNELS_PER_ITER>;
         auto policy = make_policy(kernel, output.size() / CHANNELS_PER_ITER, 0, stream);
-        launch_kernel(kernel, policy, output, out_height, out_width, input, in_height, in_width);
+        launch_kernel(kernel, policy, output, out_height, out_width, input, in_height, in_width,  scale_y, scale_x, round, half_pixel_centers);
     }
 
     template <class T>
-    void resize_nn(const Stream& stream, TensorSpan<T> output, TensorView<T> input) {
+    void resize_nn(const Stream& stream, TensorSpan<T> output, TensorView<T> input, float scale_y, float scale_x, bool round, bool half_pixel_centers) {
         auto out_height = output.get_axis_size(-2);
         auto out_width = output.get_axis_size(-1);
 
@@ -176,38 +183,38 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
         auto num_iters = num_effective_channels * out_height * out_width;
 
         if (num_effective_channels % 32 == 0 && num_iters > 655360) {
-            launch_multichannel_resize_nn<T, 32>(stream, output, out_height, out_width, input, in_height, in_width);
+            launch_multichannel_resize_nn<T, 32>(stream, output, out_height, out_width, input, in_height, in_width, scale_y, scale_x, round, half_pixel_centers);
         } else if (num_effective_channels % 16 == 0 && num_iters > 327680) {
-            launch_multichannel_resize_nn<T, 16>(stream, output, out_height, out_width, input, in_height, in_width);
+            launch_multichannel_resize_nn<T, 16>(stream, output, out_height, out_width, input, in_height, in_width, scale_y, scale_x, round, half_pixel_centers);
         } else if (num_effective_channels % 8 == 0 && num_iters > 163840) {
-            launch_multichannel_resize_nn<T, 8>(stream, output, out_height, out_width, input, in_height, in_width);
+            launch_multichannel_resize_nn<T, 8>(stream, output, out_height, out_width, input, in_height, in_width, scale_y, scale_x, round, half_pixel_centers);
         } else if (num_effective_channels % 4 == 0 && num_iters > 81920) {
-            launch_multichannel_resize_nn<T, 4>(stream, output, out_height, out_width, input, in_height, in_width);
+            launch_multichannel_resize_nn<T, 4>(stream, output, out_height, out_width, input, in_height, in_width, scale_y, scale_x, round, half_pixel_centers);
         } else if (num_effective_channels % 2 == 0) {
-            launch_multichannel_resize_nn<T, 2>(stream, output, out_height, out_width, input, in_height, in_width);
+            launch_multichannel_resize_nn<T, 2>(stream, output, out_height, out_width, input, in_height, in_width, scale_y, scale_x, round, half_pixel_centers);
         } else {
-            launch_multichannel_resize_nn<T, 1>(stream, output, out_height, out_width, input, in_height, in_width);
+            launch_multichannel_resize_nn<T, 1>(stream, output, out_height, out_width, input, in_height, in_width, scale_y, scale_x, round, half_pixel_centers);
         }
     }
 
 #if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-    template void resize_nn<__half>(const Stream&, TensorSpan<__half>, TensorView<__half>);
+    template void resize_nn<__half>(const Stream&, TensorSpan<__half>, TensorView<__half>, float, float, bool, bool);
 #endif
-    template void resize_nn<float>(const Stream&, TensorSpan<float>, TensorView<float>);
+    template void resize_nn<float>(const Stream&, TensorSpan<float>, TensorView<float>, float, float, bool,bool);
 
     template <class T, std::size_t CHANNELS_PER_ITER> static
     void launch_multichannel_resize_bilinear(const Stream& stream,
         Span<T> output, size_type out_height, size_type out_width,
         View<T> input, size_type in_height, size_type in_width,
-        float scale_y, float scale_x)
+        float scale_y, float scale_x, bool half_pixel_centers)
     {
         auto kernel = raw::resize_bilinear<T, CHANNELS_PER_ITER>;
         auto policy = make_policy(kernel, output.size() / CHANNELS_PER_ITER, 0, stream);
-        launch_kernel(kernel, policy, output, out_height, out_width, input, in_height, in_width, scale_y, scale_x);
+        launch_kernel(kernel, policy, output, out_height, out_width, input, in_height, in_width, scale_y, scale_x, half_pixel_centers);
     }
 
     template <class T>
-    void resize_bilinear(const Stream& stream, TensorSpan<T> output, TensorView<T> input, float scale_y, float scale_x) {
+    void resize_bilinear(const Stream& stream, TensorSpan<T> output, TensorView<T> input, float scale_y, float scale_x, bool half_pixel_centers) {
         auto out_height = output.get_axis_size(-2);
         auto out_width = output.get_axis_size(-1);
 
@@ -218,21 +225,21 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
         auto num_iters = num_effective_channels * out_height * out_width;
 
         if (num_effective_channels % 16 == 0 && num_iters > 163840) {
-            launch_multichannel_resize_bilinear<T, 16>(stream, output, out_height, out_width, input, in_height, in_width, scale_y, scale_x);
+            launch_multichannel_resize_bilinear<T, 16>(stream, output, out_height, out_width, input, in_height, in_width, scale_y, scale_x, half_pixel_centers);
         } else if (num_effective_channels % 8 == 0 && num_iters > 81920) {
-            launch_multichannel_resize_bilinear<T, 8>(stream, output, out_height, out_width, input, in_height, in_width, scale_y, scale_x);
+            launch_multichannel_resize_bilinear<T, 8>(stream, output, out_height, out_width, input, in_height, in_width, scale_y, scale_x, half_pixel_centers);
         } else if (num_effective_channels % 4 == 0 && num_iters > 40960) {
-            launch_multichannel_resize_bilinear<T, 4>(stream, output, out_height, out_width, input, in_height, in_width, scale_y, scale_x);
+            launch_multichannel_resize_bilinear<T, 4>(stream, output, out_height, out_width, input, in_height, in_width, scale_y, scale_x, half_pixel_centers);
         } else if (num_effective_channels % 2 == 0) {
-            launch_multichannel_resize_bilinear<T, 2>(stream, output, out_height, out_width, input, in_height, in_width, scale_y, scale_x);
+            launch_multichannel_resize_bilinear<T, 2>(stream, output, out_height, out_width, input, in_height, in_width, scale_y, scale_x, half_pixel_centers);
         } else {
-            launch_multichannel_resize_bilinear<T, 1>(stream, output, out_height, out_width, input, in_height, in_width, scale_y, scale_x);
+            launch_multichannel_resize_bilinear<T, 1>(stream, output, out_height, out_width, input, in_height, in_width, scale_y, scale_x, half_pixel_centers);
         }
     }
 
 #if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-    template void resize_bilinear<__half>(const Stream&, TensorSpan<__half>, TensorView<__half>, float, float);
+    template void resize_bilinear<__half>(const Stream&, TensorSpan<__half>, TensorView<__half>, float, float, bool);
 #endif
-    template void resize_bilinear<float>(const Stream&, TensorSpan<float>, TensorView<float>, float, float);
+    template void resize_bilinear<float>(const Stream&, TensorSpan<float>, TensorView<float>, float, float, bool);
 
 }}}} /* namespace cv::dnn::cuda4dnn::kernels */
index 31aee3d371e065f7f1eaf0c8e350036cd472c1a1..4a3768a70a87712f445fcbf28331d12ba96641f2 100644 (file)
 namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
 
     template <class T>
-    void resize_nn(const csl::Stream& stream, csl::TensorSpan<T> output, csl::TensorView<T> input);
+    void resize_nn(const csl::Stream& stream, csl::TensorSpan<T> output, csl::TensorView<T> input, float scale_y, float scale_x, bool round, bool half_pixel_centers);
 
     template <class T>
-    void resize_bilinear(const csl::Stream& stream, csl::TensorSpan<T> output, csl::TensorView<T> input, float scale_y, float scale_x);
+    void resize_bilinear(const csl::Stream& stream, csl::TensorSpan<T> output, csl::TensorView<T> input, float scale_y, float scale_x, bool half_pixel_centers);
 
 }}}} /* namespace cv::dnn::cuda4dnn::kernels */
 
index 0ac7b94e1966603420998e7f1006114e841bea8c..1465aa8867596661933bdeceaf4901f1634df333 100644 (file)
@@ -20,14 +20,23 @@ namespace cv { namespace dnn { namespace cuda4dnn {
         BILINEAR
     };
 
+    struct ResizeConfiguration {
+        InterpolationType type;
+        bool align_corners;
+        bool half_pixel_centers;
+    };
+
     template <class T>
     class ResizeOp final : public CUDABackendNode {
     public:
         using wrapper_type = GetCUDABackendWrapperType<T>;
 
-        ResizeOp(csl::Stream stream_, InterpolationType type_, float scaleHeight_, float scaleWidth_)
-            : stream(std::move(stream_)), type{ type_ }, scaleHeight{ scaleHeight_ }, scaleWidth{ scaleWidth_ }
+        ResizeOp(csl::Stream stream_, const ResizeConfiguration& config)
+            : stream(std::move(stream_))
         {
+            type = config.type;
+            align_corners = config.align_corners;
+            half_pixel_centers = config.half_pixel_centers;
         }
 
         void forward(
@@ -44,16 +53,27 @@ namespace cv { namespace dnn { namespace cuda4dnn {
             auto output_wrapper = outputs[0].dynamicCast<wrapper_type>();
             auto output = output_wrapper->getSpan();
 
+            const auto compute_scale = [this](std::size_t input_size, std::size_t output_size) {
+                return (align_corners && output_size > 1) ?
+                            static_cast<float>(input_size - 1) / (output_size - 1) :
+                            static_cast<float>(input_size) / output_size;
+            };
+
+            auto out_height = output.get_axis_size(-2), out_width = output.get_axis_size(-1);
+            auto in_height = input.get_axis_size(-2), in_width = input.get_axis_size(-1);
+            float scale_height = compute_scale(in_height, out_height),
+                  scale_width = compute_scale(in_width, out_width);
+
             if (type == InterpolationType::NEAREST_NEIGHBOUR)
-                kernels::resize_nn<T>(stream, output, input);
+                kernels::resize_nn<T>(stream, output, input, scale_height, scale_width, align_corners, half_pixel_centers);
             else if (type == InterpolationType::BILINEAR)
-                kernels::resize_bilinear<T>(stream, output, input, scaleHeight, scaleWidth);
+                kernels::resize_bilinear<T>(stream, output, input, scale_height, scale_width, half_pixel_centers);
         }
 
     private:
         csl::Stream stream;
         InterpolationType type;
-        float scaleHeight, scaleWidth; /* for bilinear interpolation */
+        bool align_corners, half_pixel_centers;
     };
 
 }}} /* namespace cv::dnn::cuda4dnn */
index 6c4ecd9379683c822a16039c91fdf77c0d20b75a..ac5d246c75e587d1fd12224ad169641a12c2b883 100644 (file)
@@ -72,7 +72,7 @@ public:
     virtual bool supportBackend(int backendId) CV_OVERRIDE
     {
         if (backendId == DNN_BACKEND_CUDA)
-            return interpolation == "nearest" || interpolation == "bilinear";
+            return interpolation == "nearest" || interpolation == "bilinear" || interpolation == "opencv_linear";
 
 #ifdef HAVE_INF_ENGINE
         if (backendId == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019 || backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH)
@@ -299,15 +299,28 @@ public:
     {
         auto context = reinterpret_cast<csl::CSLContext*>(context_);
 
-        cuda4dnn::InterpolationType itype;
+        cuda4dnn::ResizeConfiguration config;
         if (interpolation == "nearest")
-            itype = InterpolationType::NEAREST_NEIGHBOUR;
+        {
+            config.type = InterpolationType::NEAREST_NEIGHBOUR;
+            config.align_corners = alignCorners;
+            config.half_pixel_centers = halfPixelCenters;
+        }
         else if (interpolation == "bilinear")
-            itype = InterpolationType::BILINEAR;
+        {
+            config.type = InterpolationType::BILINEAR;
+            config.align_corners = alignCorners;
+            config.half_pixel_centers = halfPixelCenters;
+        }
+        else if (interpolation == "opencv_linear")
+        {
+            config.type = InterpolationType::BILINEAR;
+            config.align_corners = false;
+            config.half_pixel_centers = true;
+        }
         else
             CV_Error(Error::StsNotImplemented, "Requested interpolation mode is not available in resize layer.");
-
-        return make_cuda_node<cuda4dnn::ResizeOp>(preferableTarget, std::move(context->stream), itype, scaleHeight, scaleWidth);
+        return make_cuda_node<cuda4dnn::ResizeOp>(preferableTarget, std::move(context->stream), config);
     }
 #endif