Merge pull request #14827 from YashasSamaga:cuda4dnn-csl-low
diff --git a/modules/dnn/src/layers/convolution_layer.cpp b/modules/dnn/src/layers/convolution_layer.cpp
index 3b298e6..09bdd93 100644
 
 #include "../precomp.hpp"
 #include "layers_common.hpp"
+#include "../op_cuda.hpp"
 #include "../op_halide.hpp"
 #include "../op_inf_engine.hpp"
+#include "../op_vkcom.hpp"
 #include "opencv2/core/hal/hal.hpp"
 #include "opencv2/core/hal/intrin.hpp"
 #include <iostream>
+#include <numeric>
 
 #ifdef HAVE_OPENCL
 #include "opencl_kernels_dnn.hpp"
 using namespace cv::dnn::ocl4dnn;
 #endif
 
+#ifdef HAVE_CUDA
+#include "../cuda4dnn/primitives/convolution.hpp"
+#include "../cuda4dnn/primitives/transpose_convolution.hpp"
+using namespace cv::dnn::cuda4dnn;
+#endif
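+/* The cuda4dnn primitives included above are thin wrappers over the CSL (CUDA
+   Support Library) layer introduced by this PR; the stream and cuDNN handles
+   they need are passed in through csl::CSLContext in initCUDA() below. */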
+
 namespace cv
 {
 namespace dnn
@@ -61,12 +70,12 @@ namespace dnn
 class BaseConvolutionLayerImpl : public ConvolutionLayer
 {
 public:
-    bool newWeightAndBias;
+    bool fusedWeights, fusedBias;
     std::vector<double> weightsMultipliers;
     BaseConvolutionLayerImpl(const LayerParams &params)
     {
         setParamsFrom(params);
-        getConvolutionKernelParams(params, kernel_size, pads_begin, pads_end, strides, dilations, padMode);
+        getConvolutionKernelParams(params, kernel_size, pads_begin, pads_end, strides, dilations, padMode, adjust_pads);
 
         numOutput = params.get<int>("num_output");
         int ngroups = params.get<int>("group", 1);
@@ -82,15 +91,16 @@ public:
             pad = Size(pads_begin[1], pads_begin[0]);
             dilation = Size(dilations[1], dilations[0]);
 
-            adjust_pads.push_back(params.get<int>("adj_h", 0));
-            adjust_pads.push_back(params.get<int>("adj_w", 0));
-
             adjustPad.height = adjust_pads[0];
             adjustPad.width = adjust_pads[1];
-            CV_Assert(adjustPad.width < stride.width &&
-                      adjustPad.height < stride.height);
         }
-        newWeightAndBias = false;
+
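+        /* the output (adjustment) padding of a deconvolution must stay strictly
+           below the stride on every axis; larger values would make the output
+           size ambiguous, hence the per-axis check below */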
+        for (int i = 0; i < adjust_pads.size(); i++) {
+            CV_Assert(adjust_pads[i] < strides[i]);
+        }
+
+        fusedWeights = false;
+        fusedBias = false;
     }
 
     virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE
@@ -125,7 +135,7 @@ public:
             inpShape.push_back(inputs[0].size[i]);
             outShape.push_back(outputs[0].size[i]);
         }
-        getConvPoolPaddings(inpShape, outShape, kernel_size, strides, padMode, dilations, pads_begin, pads_end);
+        getConvPoolPaddings(inpShape, kernel_size, strides, padMode, pads_begin, pads_end);
         if (pads_begin.size() == 2) {
             for (int i = 0; i < pads_begin.size(); i++) {
                 if (pads_begin[i] != pads_end[i])
@@ -133,6 +143,8 @@ public:
             }
             pad = Size(pads_begin[1], pads_begin[0]);
         }
+        fusedWeights = false;
+        fusedBias = false;
     }
 
     bool hasBias() const
@@ -155,6 +167,8 @@ public:
         if (!w.empty() || !b.empty())
         {
             fuseWeights(w, b);
+            fusedWeights = fusedWeights || !w.empty();
+            fusedBias = fusedBias || (hasBias() && !w.empty()) || !b.empty();
             return true;
         }
         return false;
@@ -215,7 +229,6 @@ public:
     std::vector<float> biasvec;
     std::vector<float> reluslope;
     Ptr<ActivationLayer> activ;
-    bool fusedBias;
 
 #ifdef HAVE_OPENCL
     Ptr<OCL4DNNConvSpatial<float> > convolutionOp;
@@ -226,7 +239,6 @@ public:
 #endif
     ConvolutionLayerImpl(const LayerParams &params) : BaseConvolutionLayerImpl(params)
     {
-        fusedBias = false;
 #ifdef HAVE_OPENCL
         newActiv = false;
         activType = OCL4DNN_CONV_FUSED_ACTIV_NONE;
@@ -236,25 +248,46 @@ public:
 
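+    /* hedged reading: computeColRowShape reports the im2row scratch-buffer shape
+       for one sample and group: rows spanning the input spatial extent
+       (depth*height*width) and inpGroupCn * prod(kernel_size) kernel-tap columns
+       per row */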
     MatShape computeColRowShape(const MatShape &inpShape, const MatShape &outShape) const CV_OVERRIDE
     {
-        Size out(outShape[3], outShape[2]);
+        int dims = inpShape.size();
+        int inpD = dims == 5 ? inpShape[2] : 1;
+        int inpH = inpShape[dims - 2];
+        int inpW = inpShape.back();
         int inpGroupCn = blobs[0].size[1];
-        int ksize = inpGroupCn * kernel.height * kernel.width;
-        return shape(out.area(), ksize);
+        int ksize = inpGroupCn * std::accumulate(kernel_size.begin(), kernel_size.end(),
+                                                 1, std::multiplies<size_t>());
+        return shape(inpD * inpH * inpW, ksize);
     }
 
     virtual bool supportBackend(int backendId) CV_OVERRIDE
     {
+        if (backendId == DNN_BACKEND_CUDA)
+        {
+            /* only 2D and 3D convolutions are supported */
+            if (kernel_size.size() == 2 || kernel_size.size() == 3)
+                return true;
+
+            return false;
+        }
+
 #ifdef HAVE_INF_ENGINE
         if (backendId == DNN_BACKEND_INFERENCE_ENGINE)
         {
             if (kernel_size.size() == 3)
                 return preferableTarget == DNN_TARGET_CPU;
-            return INF_ENGINE_VER_MAJOR_GE(INF_ENGINE_RELEASE_2018R4) ||
-                   (preferableTarget != DNN_TARGET_MYRIAD || dilation.width == dilation.height);
+            return (preferableTarget != DNN_TARGET_MYRIAD || dilation.width == dilation.height);
         }
         else
 #endif
-            return (kernel_size.size() == 2) && (backendId == DNN_BACKEND_OPENCV || backendId == DNN_BACKEND_HALIDE);
+        {
+            if (kernel_size.size() == 3)
+                return (preferableTarget == DNN_TARGET_CPU && backendId == DNN_BACKEND_OPENCV);
+            else if (kernel_size.size() == 2)
+                return backendId == DNN_BACKEND_OPENCV ||
+                       backendId == DNN_BACKEND_HALIDE ||
+                       (backendId == DNN_BACKEND_VKCOM && haveVulkan());
+            else
+                return false;
+        }
     }
 
     bool getMemoryShapes(const std::vector<MatShape> &inputs,
@@ -406,12 +439,74 @@ public:
             for (int i = 0; i < outCn; ++i)
                 biasvec[i] += b.at<float>(i);
         }
-
-        newWeightAndBias = !w.empty() || !b.empty();
-        fusedBias = hasBias() || !b.empty();
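+        // hedged reading: biasvec carries two spare tail entries (indices outCn and
+        // outCn+1), kept equal to the last real bias so vectorized kernels can
+        // safely read bias values past outCn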
         biasvec[outCn] = biasvec[outCn+1] = biasvec[outCn-1];
     }
 
+    virtual Ptr<BackendNode> initVkCom(const std::vector<Ptr<BackendWrapper> > &inputs) CV_OVERRIDE
+    {
+#ifdef HAVE_VULKAN
+        int out_channel = blobs[0].size[0];
+        bool has_bias = hasBias() || fusedBias;
+        int filter_size[2] = {kernel.height, kernel.width};
+        int pad_size[2] = {pad.height, pad.width};
+        int stride_size[2] = {stride.height, stride.width};
+        int dilation_size[2] = {dilation.height, dilation.width};
+        int activation = 0;
+        vkcom::Tensor input_tensor = VkComTensor(inputs[0]);
+        int in_channel = input_tensor.dimSize(1);
+        int group = in_channel / blobs[0].size[1];
+
+        // TODO: support group > 1
+        if (group != 1)
+            return Ptr<BackendNode>();
+
+        int padding_mode;
+        if (padMode.empty())
+        {
+            padding_mode = vkcom::kPaddingModeCaffe;
+        }
+        else if (padMode == "VALID")
+        {
+            padding_mode = vkcom::kPaddingModeValid;
+        }
+        else if (padMode == "SAME")
+        {
+            padding_mode = vkcom::kPaddingModeSame;
+        }
+        else
+            CV_Error(Error::StsError, "Unsupported padding mode " + padMode);
+
+        std::shared_ptr<vkcom::OpBase> op(new vkcom::OpConv(out_channel, has_bias,
+                    filter_size, pad_size,
+                    stride_size, dilation_size,
+                    activation, group,
+                    padding_mode));
+
+        std::vector<Ptr<BackendWrapper> > blobsWrapper;
+
+        if (fusedWeights)
+        {
+            Mat wm;
+            weightsMat.copyTo(wm); // to handle the case of isContinuous() == false
+            wm = wm.reshape(1, blobs[0].dims, blobs[0].size);
+            blobsWrapper.push_back(Ptr<BackendWrapper>(new VkComBackendWrapper(wm)));
+        }
+        else
+        {
+            blobsWrapper.push_back(Ptr<BackendWrapper>(new VkComBackendWrapper(blobs[0])));
+        }
+
+        if (has_bias)
+        {
+            Mat biasesMat({out_channel}, CV_32F, &biasvec[0]);
+            blobsWrapper.push_back(Ptr<BackendWrapper>(new VkComBackendWrapper(biasesMat)));
+        }
+
+        return Ptr<BackendNode>(new VkComBackendNode(inputs, op, blobsWrapper));
+#endif  // HAVE_VULKAN
+        return Ptr<BackendNode>();
+    }
+
     virtual Ptr<BackendNode> initHalide(const std::vector<Ptr<BackendWrapper> > &inputs) CV_OVERRIDE
     {
 #ifdef HAVE_HALIDE
@@ -460,38 +555,38 @@ public:
         return Ptr<BackendNode>();
     }
 
+#ifdef HAVE_INF_ENGINE
     virtual Ptr<BackendNode> initInfEngine(const std::vector<Ptr<BackendWrapper> > &inputs) CV_OVERRIDE
     {
-#ifdef HAVE_INF_ENGINE
         InferenceEngine::DataPtr input = infEngineDataNode(inputs[0]);
-        CV_Assert(input->dims.size() == 4 || input->dims.size() == 5);
-
-        const int inpCn = input->dims[input->dims.size() - 2];  // NOTE: input->dims are reversed (WHIO or WHDIO)
+        std::vector<size_t> dims = input->getDims();
+        CV_Assert(dims.size() == 4 || dims.size() == 5);
+        const int inpCn = dims[1];
         const int outCn = blobs[0].size[0];
         const int inpGroupCn = blobs[0].size[1];
         const int group = inpCn / inpGroupCn;
-
-        InferenceEngine::Layout layout = (input->dims.size() == 4) ? InferenceEngine::Layout::OIHW :
-                                                                     InferenceEngine::Layout::NCDHW;
+        InferenceEngine::Layout layout = (dims.size() == 4) ? InferenceEngine::Layout::OIHW :
+                                                              InferenceEngine::Layout::NCDHW;
 
         auto ieWeights = wrapToInfEngineBlob(blobs[0], layout);
-        if (newWeightAndBias)
+        if (fusedWeights)
         {
             if (weightsMat.isContinuous())
             {
-                Mat fusedWeights = weightsMat.reshape(1, blobs[0].dims, blobs[0].size);
-                ieWeights = wrapToInfEngineBlob(fusedWeights, layout);
+                Mat cvWeights = weightsMat.reshape(1, blobs[0].dims, blobs[0].size);
+                ieWeights = wrapToInfEngineBlob(cvWeights, layout);
             }
             else
             {
-                ieWeights = InferenceEngine::make_shared_blob<float>(
-                                    InferenceEngine::Precision::FP32, layout,
-                                    ieWeights->dims());
+                ieWeights = InferenceEngine::make_shared_blob<float>({
+                                InferenceEngine::Precision::FP32,
+                                ieWeights->getTensorDesc().getDims(), layout
+                            });
                 ieWeights->allocate();
 
                 Mat newWeights = infEngineBlobToMat(ieWeights).reshape(1, outCn);
-                Mat fusedWeights = weightsMat.colRange(0, newWeights.cols);
-                fusedWeights.copyTo(newWeights);
+                Mat cvWeights = weightsMat.colRange(0, newWeights.cols);
+                cvWeights.copyTo(newWeights);
             }
         }
         InferenceEngine::Blob::Ptr ieBiases;
@@ -501,7 +596,6 @@ public:
             ieBiases = wrapToInfEngineBlob(biasesMat, {(size_t)outCn}, InferenceEngine::Layout::C);
         }
 
-#if INF_ENGINE_VER_MAJOR_GE(INF_ENGINE_RELEASE_2018R5)
         InferenceEngine::Builder::ConvolutionLayer ieLayer(name);
 
         ieLayer.setKernel(kernel_size);
@@ -521,51 +615,8 @@ public:
             l.getParameters()["auto_pad"] = padMode == "VALID" ? std::string("valid") : std::string("same_upper");
 
         return Ptr<BackendNode>(new InfEngineBackendNode(l));
-#else
-        InferenceEngine::LayerParams lp;
-        lp.name = name;
-        lp.type = "Convolution";
-        lp.precision = InferenceEngine::Precision::FP32;
-        std::shared_ptr<InferenceEngine::ConvolutionLayer> ieLayer(new InferenceEngine::ConvolutionLayer(lp));
-
-#if INF_ENGINE_VER_MAJOR_GT(INF_ENGINE_RELEASE_2018R3)
-        ieLayer->_kernel.insert(InferenceEngine::X_AXIS, kernel.width);
-        ieLayer->_kernel.insert(InferenceEngine::Y_AXIS, kernel.height);
-        ieLayer->_stride.insert(InferenceEngine::X_AXIS, stride.width);
-        ieLayer->_stride.insert(InferenceEngine::Y_AXIS, stride.height);
-        ieLayer->_padding.insert(InferenceEngine::X_AXIS, pad.width);
-        ieLayer->_padding.insert(InferenceEngine::Y_AXIS, pad.height);
-        ieLayer->_pads_end.insert(InferenceEngine::X_AXIS, pad.width);
-        ieLayer->_pads_end.insert(InferenceEngine::Y_AXIS, pad.height);
-        ieLayer->_dilation.insert(InferenceEngine::X_AXIS, dilation.width);
-        ieLayer->_dilation.insert(InferenceEngine::Y_AXIS, dilation.height);
-        ieLayer->params["output"] = format("%d", outCn);
-        ieLayer->params["kernel"] = format("%d,%d,%d,%d", outCn, inpGroupCn, kernel.height, kernel.width);
-        ieLayer->params["pads_begin"] = format("%d,%d", pad.height, pad.width);
-        ieLayer->params["pads_end"] = format("%d,%d", pad.height, pad.width);
-        ieLayer->params["strides"] = format("%d,%d", stride.height, stride.width);
-        ieLayer->params["dilations"] = format("%d,%d", dilation.height, dilation.width);
-#else
-        ieLayer->_kernel_x = kernel.width;
-        ieLayer->_kernel_y = kernel.height;
-        ieLayer->_stride_x = stride.width;
-        ieLayer->_stride_y = stride.height;
-        ieLayer->_padding_x = pad.width;
-        ieLayer->_padding_y = pad.height;
-        ieLayer->_dilation_x = dilation.width;
-        ieLayer->_dilation_y = dilation.height;
-#endif
-        ieLayer->_out_depth = outCn;
-        ieLayer->_group = group;
-
-        ieLayer->_weights = ieWeights;
-        if (ieBiases)
-            ieLayer->_biases = ieBiases;
-        return Ptr<BackendNode>(new InfEngineBackendNode(ieLayer));
-#endif
-#endif  // HAVE_INF_ENGINE
-        return Ptr<BackendNode>();
     }
+#endif  // HAVE_INF_ENGINE
 
     class ParallelConv : public cv::ParallelLoopBody
     {
@@ -575,8 +626,8 @@ public:
         const Mat* input_;
         const Mat* weights_;
         Mat* output_;
-        int outShape[4];
-        Size kernel_, pad_, stride_, dilation_;
+        int outShape[4]; // used only for conv2d
+        std::vector<size_t> kernel_size, pads_begin, pads_end, strides, dilations;
         int ngroups_, nstripes_;
         std::vector<int> ofstab_;
         const std::vector<float>* biasvec_;
@@ -595,14 +646,18 @@ public:
         static void run( const Mat& input, Mat& output, const Mat& weights,
                          const std::vector<float>& biasvec,
                          const std::vector<float>& reluslope,
-                         Size kernel, Size pad, Size stride, Size dilation,
+                         const std::vector<size_t>& kernel_size, const std::vector<size_t>& strides,
+                         const std::vector<size_t>& pads_begin, const std::vector<size_t>& pads_end,
+                         const std::vector<size_t>& dilations,
                          const ActivationLayer* activ, int ngroups, int nstripes )
         {
+            size_t karea = std::accumulate(kernel_size.begin(), kernel_size.end(),
+                                           1, std::multiplies<size_t>());
             CV_Assert_N(
-                       input.dims == 4 && output.dims == 4,
+                       (input.dims == 4 || input.dims == 5) && (input.dims == output.dims),
                        input.size[0] == output.size[0],
                        weights.rows == output.size[1],
-                       weights.cols == (input.size[1]/ngroups)*kernel.width*kernel.height,
+                       weights.cols == (input.size[1]/ngroups)*karea,
                        input.type() == output.type(),
                        input.type() == weights.type(),
                        input.type() == CV_32FC1,
@@ -616,26 +671,58 @@ public:
             p.output_ = &output;
             for( int i = 0; i < 4; i++ ) p.outShape[i] = output.size[i];
             p.outShape[1] /= ngroups;
-            p.kernel_ = kernel; p.pad_ = pad; p.stride_ = stride; p.dilation_ = dilation;
+
+            p.kernel_size = kernel_size; p.strides = strides; p.dilations = dilations;
+            p.pads_begin = pads_begin; p.pads_end = pads_end;
+
             p.ngroups_ = ngroups;
             p.nstripes_ = nstripes;
 
-            int inpCnAll = input.size[1], width = input.size[3], height = input.size[2];
+            int inpCnAll = input.size[1];
+            int depth = (input.dims == 5) ? input.size[2] : 1;
+            int width = input.size[input.dims - 1];
+            int height = input.size[input.dims - 2];
             int inpCn = inpCnAll / ngroups;
-            p.is1x1_ = kernel == Size(1,1) && pad == Size(0, 0);
-            p.useAVX = checkHardwareSupport(CPU_AVX);
-            p.useAVX2 = checkHardwareSupport(CPU_AVX2);
-            p.useAVX512 = CV_CPU_HAS_SUPPORT_AVX512_SKX;
+
+            bool isConv2D = kernel_size.size() == 2;
+
+            p.is1x1_ = isConv2D && kernel_size[0] == 1 && kernel_size[1] == 1 &&
+                       pads_begin[0] == 0 && pads_begin[1] == 0;
+
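+            /* the hand-written AVX/AVX2/AVX-512 micro-kernels only implement the 2D
+               layout, so the vectorized paths stay disabled for 3D convolution */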
+            p.useAVX    = checkHardwareSupport(CPU_AVX)  && isConv2D;
+            p.useAVX2   = checkHardwareSupport(CPU_AVX2) && isConv2D;
+            p.useAVX512 = CV_CPU_HAS_SUPPORT_AVX512_SKX  && isConv2D;
 
             int ncn = std::min(inpCn, (int)BLK_SIZE_CN);
-            p.ofstab_.resize(kernel.width*kernel.height*ncn);
+
+            int kernel_d = !isConv2D ? kernel_size[0] : 1;
+            int kernel_h = kernel_size[kernel_size.size() - 2];
+            int kernel_w = kernel_size.back();
+
+            int dil_d = !isConv2D ? dilations[0] : 1;
+            int dil_h = dilations[dilations.size() - 2];
+            int dil_w = dilations.back();
+
+            p.ofstab_.resize(karea * ncn);
             int* ofstab = &p.ofstab_[0];
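+            /* ofstab caches, for every channel and kernel tap, the tap's offset into
+             * the flattened input so the im2row loops below can gather taps with a
+             * single indexed load each:
+             *   2D: (k*height + k_r*dil_h)*width + k_c*dil_w
+             *   3D: (k*depth*height + k_d*dil_d*height + k_r*dil_h)*width + k_c*dil_w */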
 
-            for( int k = 0; k < ncn; k++ )
-                for( int k_r = 0; k_r < kernel.height; k_r++ )
-                    for( int k_c = 0; k_c < kernel.width; k_c++ )
-                        ofstab[(k*kernel.height + k_r)*kernel.width + k_c] =
-                        (k*height + k_r*dilation.height)*width + k_c*dilation.width;
+            if (isConv2D)
+            {
+                for( int k = 0; k < ncn; k++ )
+                    for( int k_r = 0; k_r < kernel_h; k_r++ )
+                        for( int k_c = 0; k_c < kernel_w; k_c++ )
+                            ofstab[(k*kernel_h + k_r)*kernel_w + k_c] =
+                                   (k*height + k_r*dil_h)*width + k_c*dil_w;
+            }
+            else
+            {
+                for( int k = 0; k < ncn; k++ )
+                    for (int k_d = 0; k_d < kernel_d; k_d++)
+                        for( int k_r = 0; k_r < kernel_h; k_r++ )
+                            for( int k_c = 0; k_c < kernel_w; k_c++ )
+                                ofstab[(k*kernel_d*kernel_h + k_d*kernel_h + k_r)*kernel_w + k_c] =
+                                       (k*depth*height + k_d*dil_d*height + k_r*dil_h)*width + k_c*dil_w;
+            }
 
             p.biasvec_ = &biasvec;
             p.reluslope_ = &reluslope;
@@ -648,17 +735,39 @@ public:
         {
             const int valign = ConvolutionLayerImpl::VEC_ALIGN;
             int ngroups = ngroups_, batchSize = input_->size[0]*ngroups;
-            int outW = output_->size[3], outH = output_->size[2], outCn = output_->size[1]/ngroups;
-            int width = input_->size[3], height = input_->size[2], inpCn = input_->size[1]/ngroups;
+            bool isConv2D = input_->dims == 4;
+
+            int outW = output_->size[output_->dims - 1];
+            int outH = output_->size[output_->dims - 2];
+            int outCn = output_->size[1]/ngroups;
+
+            int depth = !isConv2D ? input_->size[2] : 1;
+            int height = input_->size[input_->dims - 2];
+            int width = input_->size[input_->dims - 1];
+            int inpCn = input_->size[1]/ngroups;
+
             const int nstripes = nstripes_;
-            int kernel_w = kernel_.width, kernel_h = kernel_.height;
-            int pad_w = pad_.width, pad_h = pad_.height;
-            int stride_w = stride_.width, stride_h = stride_.height;
-            int dilation_w = dilation_.width, dilation_h = dilation_.height;
-            int karea = kernel_w*kernel_h;
-            int i, j, k;
-            size_t inpPlaneSize = width*height;
-            size_t outPlaneSize = outW*outH;
+
+            int kernel_d = !isConv2D ? kernel_size[0] : 1;
+            int kernel_h = kernel_size[kernel_size.size() - 2];
+            int kernel_w = kernel_size.back();
+            int karea = kernel_w*kernel_h*kernel_d;
+
+            int pad_d = !isConv2D ? pads_begin[0] : 0;
+            int pad_t = pads_begin[pads_begin.size() - 2];
+            int pad_l = pads_begin.back();
+
+            int stride_d = !isConv2D ? strides[0] : 0;
+            int stride_h = strides[strides.size() - 2];
+            int stride_w = strides.back();
+
+            int dilation_d = !isConv2D ? dilations[0] : 1;
+            int dilation_h = dilations[dilations.size() - 2];
+            int dilation_w = dilations.back();
+
+            int i, j, k, d;
+            size_t inpPlaneSize = input_->total(2);
+            size_t outPlaneSize = output_->total(2);
             bool is1x1 = is1x1_;
 
             int stripesPerSample;
@@ -727,72 +836,125 @@ public:
                     for( int ofs0 = stripeStart; ofs0 < stripeEnd; ofs0 += BLK_SIZE )
                     {
                         int ofs, ofs1 = std::min(ofs0 + BLK_SIZE, stripeEnd);
-                        int out_i = ofs0 / outW;
-                        int out_j = ofs0 - out_i * outW;
+
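+                        /* decompose the flat output offset into (d, i, j):
+                         * ofs0 == (out_d*outH + out_i)*outW + out_j; for 2D
+                         * convolution ofs0 < outH*outW, so out_d stays 0 */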
+                        int out_d = ofs0 / (outH * outW);
+                        int out_i = (ofs0 - out_d * outH * outW) / outW;
+                        int out_j = ofs0 % outW;
 
                         // do im2row for a part of input tensor
                         float* rowbuf = rowbuf0;
-                        for( ofs = ofs0; ofs < ofs1; out_j = 0, ++out_i )
+
+                        if (isConv2D)
                         {
-                            int delta = std::min(ofs1 - ofs, outW - out_j);
-                            int out_j1 = out_j + delta;
-                            int in_i = out_i * stride_h - pad_h;
-                            int in_j = out_j * stride_w - pad_w;
-                            const float* imgptr = data_inp0 + (cn0*height + in_i)*width + in_j;
-                            ofs += delta;
-
-                            // do im2row for a part of input tensor
-                            if( is1x1 )
+                            for( ofs = ofs0; ofs < ofs1; out_j = 0, ++out_i )
                             {
-                                for( ; out_j < out_j1; out_j++, rowbuf += vsz_a, imgptr += stride_w )
+                                int delta = std::min(ofs1 - ofs, outW - out_j);
+                                int out_j1 = out_j + delta;
+
+                                int in_i = out_i * stride_h - pad_t;
+                                int in_j = out_j * stride_w - pad_l;
+                                const float* imgptr = data_inp0 + (cn0*height + in_i)*width + in_j;
+                                ofs += delta;
+
+                                // do im2row for a part of input tensor
+                                if( is1x1 )
                                 {
-                                    for( k = 0; k < vsz; k++ )
-                                        rowbuf[k] = imgptr[k*inpPlaneSize];
+                                    for( ; out_j < out_j1; out_j++, rowbuf += vsz_a, imgptr += stride_w )
+                                    {
+                                        for( k = 0; k < vsz; k++ )
+                                            rowbuf[k] = imgptr[k*inpPlaneSize];
+                                    }
+                                }
+                                else
+                                {
+                                    bool ok_i = 0 <= in_i && in_i < height - (kernel_h-1)*dilation_h;
+                                    int i0 = std::max(0, (-in_i + dilation_h-1)/dilation_h);
+                                    int i1 = std::min(kernel_h, (height - in_i + dilation_h-1)/dilation_h);
+
+                                    for( ; out_j < out_j1; out_j++, rowbuf += vsz_a, imgptr += stride_w, in_j += stride_w )
+                                    {
+                                        // this condition should be true for most of the tensor elements, i.e.
+                                        // most of the time the kernel aperture is inside the tensor X-Y plane.
+                                        if( ok_i && out_j + 2 <= out_j1 && 0 <= in_j && in_j + stride_w*2 <= width - (kernel_w-1)*dilation_w )
+                                        {
+                                            for( k = 0; k < vsz; k++ )
+                                            {
+                                                int k1 = ofstab[k];
+                                                float v0 = imgptr[k1];
+                                                float v1 = imgptr[k1 + stride_w];
+                                                rowbuf[k] = v0;
+                                                rowbuf[k+vsz_a] = v1;
+                                            }
+                                            out_j++;
+                                            rowbuf += vsz_a;
+                                            imgptr += stride_w;
+                                            in_j += stride_w;
+                                        }
+                                        else
+                                        {
+                                            int j0 = std::max(0, (-in_j + dilation_w-1)/dilation_w);
+                                            int j1 = std::min(kernel_w, (width - in_j + dilation_w-1)/dilation_w);
+
+                                            // here some non-continuous sub-row of the row will not be
+                                            // filled from the tensor; we need to make sure that the uncovered
+                                            // elements are explicitly set to 0's. the easiest way is to
+                                            // set all the elements to 0's before the loop.
+                                            memset(rowbuf, 0, vsz*sizeof(rowbuf[0]));
+                                            for( k = 0; k < ncn; k++ )
+                                            {
+                                                for( i = i0; i < i1; i++ )
+                                                {
+                                                    for( j = j0; j < j1; j++ )
+                                                    {
+                                                        int imgofs = k*(width*height) + i*(dilation_h*width) + j*dilation_w;
+                                                        rowbuf[(k*kernel_h + i)*kernel_w + j] = imgptr[imgofs];
+                                                    }
+                                                }
+                                            }
+                                        }
+                                    }
                                 }
                             }
-                            else
+                        }
+                        else
+                        {
+                            for( ofs = ofs0; ofs < ofs1; out_d += (out_i + 1) / outH, out_i = (out_i + 1) % outH, out_j = 0 )
                             {
-                                bool ok_i = 0 <= in_i && in_i < height - (kernel_h-1)*dilation_h;
+                                int delta = std::min(ofs1 - ofs, outW - out_j);
+                                int out_j1 = out_j + delta;
+
+                                int in_d = out_d * stride_d - pad_d;
+                                int in_i = out_i * stride_h - pad_t;
+                                int in_j = out_j * stride_w - pad_l;
+                                const float* imgptr = data_inp0 + (cn0*depth*height + in_d*height + in_i)*width + in_j;
+                                ofs += delta;
+
+                                int d0 = std::max(0, (-in_d + dilation_d - 1) / dilation_d);
+                                int d1 = std::min(kernel_d, (depth - in_d + dilation_d - 1) / dilation_d);
+
                                 int i0 = std::max(0, (-in_i + dilation_h-1)/dilation_h);
                                 int i1 = std::min(kernel_h, (height - in_i + dilation_h-1)/dilation_h);
 
                                 for( ; out_j < out_j1; out_j++, rowbuf += vsz_a, imgptr += stride_w, in_j += stride_w )
                                 {
-                                    // this condition should be true for most of the tensor elements, i.e.
-                                    // most of the time the kernel aperture is inside the tensor X-Y plane.
-                                    if( ok_i && out_j + 2 <= out_j1 && 0 <= in_j && in_j + stride_w*2 <= width - (kernel_w-1)*dilation_w )
+                                    int j0 = std::max(0, (-in_j + dilation_w-1)/dilation_w);
+                                    int j1 = std::min(kernel_w, (width - in_j + dilation_w-1)/dilation_w);
+
+                                    // here some non-continuous sub-row of the row will not be
+                                    // filled from the tensor; we need to make sure that the uncovered
+                                    // elements are explicitly set to 0's. the easiest way is to
+                                    // set all the elements to 0's before the loop.
+                                    memset(rowbuf, 0, vsz*sizeof(rowbuf[0]));
+                                    for( k = 0; k < ncn; k++ )
                                     {
-                                        for( k = 0; k < vsz; k++ )
-                                        {
-                                            int k1 = ofstab[k];
-                                            float v0 = imgptr[k1];
-                                            float v1 = imgptr[k1 + stride_w];
-                                            rowbuf[k] = v0;
-                                            rowbuf[k+vsz_a] = v1;
-                                        }
-                                        out_j++;
-                                        rowbuf += vsz_a;
-                                        imgptr += stride_w;
-                                        in_j += stride_w;
-                                    }
-                                    else
-                                    {
-                                        int j0 = std::max(0, (-in_j + dilation_w-1)/dilation_w);
-                                        int j1 = std::min(kernel_w, (width - in_j + dilation_w-1)/dilation_w);
-
-                                        // here some non-continuous sub-row of the row will not be
-                                        // filled from the tensor; we need to make sure that the uncovered
-                                        // elements are explicitly set to 0's. the easiest way is to
-                                        // set all the elements to 0's before the loop.
-                                        memset(rowbuf, 0, vsz*sizeof(rowbuf[0]));
-                                        for( k = 0; k < ncn; k++ )
+                                        for ( d = d0; d < d1; d++)
                                         {
                                             for( i = i0; i < i1; i++ )
                                             {
                                                 for( j = j0; j < j1; j++ )
                                                 {
-                                                    int imgofs = k*(width*height) + i*(dilation_h*width) + j*dilation_w;
-                                                    rowbuf[(k*kernel_h + i)*kernel_w + j] = imgptr[imgofs];
+                                                    int imgofs = k*(depth*width*height) + d*dilation_d*width*height + i*(dilation_h*width) + j*dilation_w;
+                                                    rowbuf[(k*kernel_d*kernel_h + d*kernel_h + i)*kernel_w + j] = imgptr[imgofs];
                                                 }
                                             }
                                         }
@@ -1015,17 +1177,18 @@ public:
             }
         }
 
-        if ( newWeightAndBias )
+        if (fusedWeights)
         {
             weightsMat.copyTo(umat_blobs[0]);
-            if ( fusedBias )
-            {
-                if ( umat_blobs.size() < 2 )
-                    umat_blobs.resize(2);
-                umat_blobs[1] = UMat(biasvec, true);
-            }
-            convolutionOp->setBias(fusedBias || hasBias());
-            newWeightAndBias = false;
+            fusedWeights = false;
+        }
+        if (fusedBias)
+        {
+            if ( umat_blobs.size() < 2 )
+                umat_blobs.resize(2);
+            umat_blobs[1] = UMat(biasvec, true);
+            convolutionOp->setBias(true);
+            fusedBias = false;
         }
 
         if ( newActiv )
@@ -1070,7 +1233,7 @@ public:
         return convolutionOp->Forward(inpMat,
                                       inputs.size() == 2 ? inputs[1] : UMat(),
                                       umat_blobs[0],
-                                      (hasBias() || fusedBias) ? umat_blobs[1] : UMat(),
+                                      umat_blobs.size() > 1 ? umat_blobs[1] : UMat(),
                                       outMat,
                                       batch_size);
     }
@@ -1101,10 +1264,6 @@ public:
         CV_Assert_N(inputs.size() == (size_t)1, inputs[0].size[1] % blobs[0].size[1] == 0,
                     outputs.size() == 1, inputs[0].data != outputs[0].data);
 
-        if (inputs[0].dims == 5) {
-            CV_Error(Error::StsNotImplemented, "Convolution3D layer is not supported on OCV backend");
-        }
-
         int ngroups = inputs[0].size[1]/blobs[0].size[1];
         CV_Assert(outputs[0].size[1] % ngroups == 0);
         int outCn = blobs[0].size[0];
@@ -1133,18 +1292,79 @@ public:
         int nstripes = std::max(getNumThreads(), 1);
 
         ParallelConv::run(inputs[0], outputs[0], weightsMat, biasvec, reluslope,
-                          kernel, pad, stride, dilation, activ.get(), ngroups, nstripes);
+                          kernel_size, strides, pads_begin, pads_end, dilations, activ.get(), ngroups, nstripes);
     }
 
+#ifdef HAVE_CUDA
+    Ptr<BackendNode> initCUDA(
+        void *context_,
+        const std::vector<Ptr<BackendWrapper>>& inputs,
+        const std::vector<Ptr<BackendWrapper>>& outputs
+    ) override
+    {
+        auto context = reinterpret_cast<csl::CSLContext*>(context_);
+
+        CV_Assert(inputs.size() == 1);
+        auto input_wrapper = inputs[0].dynamicCast<CUDABackendWrapper>();
+        auto input_shape = input_wrapper->getShape();
+
+        CV_Assert(outputs.size() == 1);
+        auto output_wrapper = outputs[0].dynamicCast<CUDABackendWrapper>();
+        auto output_shape = output_wrapper->getShape();
+
+        const auto output_feature_maps = blobs[0].size[0];
+        const auto input_feature_maps = input_shape[1];
+        const auto input_feature_maps_per_group = blobs[0].size[1];
+        const auto groups = input_feature_maps / input_feature_maps_per_group;
+
+        ConvolutionConfiguration config;
+        config.kernel_size.assign(std::begin(kernel_size), std::end(kernel_size));
+        config.dilations.assign(std::begin(dilations), std::end(dilations));
+        config.strides.assign(std::begin(strides), std::end(strides));
+
+        if (padMode.empty())
+        {
+            config.padMode = ConvolutionConfiguration::PaddingMode::MANUAL;
+            config.pads_begin.assign(std::begin(pads_begin), std::end(pads_begin));
+            config.pads_end.assign(std::begin(pads_end), std::end(pads_end));
+        }
+        else if (padMode == "VALID")
+        {
+            config.padMode = ConvolutionConfiguration::PaddingMode::VALID;
+        }
+        else if (padMode == "SAME")
+        {
+            config.padMode = ConvolutionConfiguration::PaddingMode::SAME;
+        }
+        else
+        {
+            CV_Error(Error::StsNotImplemented, padMode + " padding mode not supported by ConvolutionLayer");
+        }
+
+        config.input_shape.assign(std::begin(input_shape), std::end(input_shape));
+        config.output_shape.assign(std::begin(output_shape), std::end(output_shape));
+        config.groups = groups;
+
+        Mat filtersMat = fusedWeights ? weightsMat : blobs[0];
+        Mat biasMat = (hasBias() || fusedBias) ? Mat(output_feature_maps, 1, CV_32F, biasvec.data()) : Mat();
+        if (countNonZero(biasMat) == 0)
+            biasMat = Mat();
+
+        return make_cuda_node<cuda4dnn::ConvolutionOp>(
+            preferableTarget, std::move(context->stream), std::move(context->cudnn_handle), config, filtersMat, biasMat);
+    }
+#endif
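+
+    /* a usage sketch, not part of this layer (assumes an OpenCV build with CUDA
+     * and cuDNN): callers opt in to the new backend through the public dnn API,
+     *
+     *     cv::dnn::Net net = cv::dnn::readNet("model.onnx");  // hypothetical model
+     *     net.setPreferableBackend(cv::dnn::DNN_BACKEND_CUDA);
+     *     net.setPreferableTarget(cv::dnn::DNN_TARGET_CUDA);
+     *
+     * which routes supported layers through initCUDA() above */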
+
     virtual int64 getFLOPS(const std::vector<MatShape> &inputs,
                            const std::vector<MatShape> &outputs) const CV_OVERRIDE
     {
         CV_Assert(inputs.size() == outputs.size());
 
         int64 flops = 0;
+        int karea = std::accumulate(kernel_size.begin(), kernel_size.end(), 1, std::multiplies<size_t>());
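+        // e.g. a 3x3 kernel over 64 input channels: karea = 9, so each output
+        // element costs 2*9*64 + 1 = 1153 FLOPs (multiplies and adds counted
+        // separately, plus one bias add)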
         for (int i = 0; i < inputs.size(); i++)
         {
-            flops += total(outputs[i])*(CV_BIG_INT(2)*kernel.area()*inputs[i][1] + 1);
+            flops += total(outputs[i])*(CV_BIG_INT(2)*karea*inputs[i][1] + 1);
         }
 
         return flops;
@@ -1162,40 +1382,81 @@ public:
 
     MatShape computeColRowShape(const MatShape &inpShape, const MatShape &outShape) const CV_OVERRIDE
     {
+        int dims = inpShape.size();
         int inpCn = inpShape[1];
-        int inpH = inpShape[2];
-        int inpW = inpShape[3];
+        int inpD = dims == 5 ? inpShape[2] : 1;
+        int inpH = inpShape[dims - 2];
+        int inpW = inpShape.back();
         int outCn = outShape[1];
         int ngroups = inpCn / blobs[0].size[0];
         int outGroupCn = outCn / ngroups;
-        int ksize = outGroupCn * kernel.height * kernel.width;
-        return shape(ksize, inpH * inpW);
+        int ksize = outGroupCn * std::accumulate(kernel_size.begin(), kernel_size.end(),
+                                                 1, std::multiplies<size_t>());
+        return shape(ksize, inpD * inpH * inpW);
     }
 
     virtual bool supportBackend(int backendId) CV_OVERRIDE
     {
+        if (backendId == DNN_BACKEND_CUDA)
+        {
+            /* only 2D and 3D deconvolutions are supported */
+            if (kernel_size.size() == 2 || kernel_size.size() == 3)
+                return true;
+
+            return false;
+        }
+
 #ifdef HAVE_INF_ENGINE
+        const int outGroupCn = blobs[0].size[1];  // Weights are in IOHW or IODHW layout
+        const int group = numOutput / outGroupCn;
+
         if (backendId == DNN_BACKEND_INFERENCE_ENGINE)
         {
-            if (kernel_size.size() == 3)
-                CV_Error(Error::StsNotImplemented, "Unsupported deconvolution3D layer");
-
-            if (INF_ENGINE_RELEASE >= 2018050000 && (adjustPad.height || adjustPad.width))
+            if (kernel_size.size() == 3 && preferableTarget != DNN_TARGET_CPU) {
                 return false;
+            }
+
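+            /* hedged reading: IE has no separate output-adjustment attribute, so the
+             * adjustment is folded into paddings_end in initInfEngine(); the checks
+             * below reject configurations where that subtraction would go negative */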
+            if (std::accumulate(adjust_pads.begin(), adjust_pads.end(), 0, std::plus<size_t>()) > 0)
+            {
+                if (padMode.empty())
+                {
+                    if (preferableTarget != DNN_TARGET_CPU && group != 1)
+                    {
+                        for (int i = 0; i < adjust_pads.size(); i++) {
+                            if (adjust_pads[i] && pads_begin[i])
+                                return false;
+                        }
+                    }
+                    for (int i = 0; i < adjust_pads.size(); i++) {
+                        if (pads_end[i] < adjust_pads[i])
+                            return false;
+                    }
+                    return true;
+                }
+                else if (padMode == "SAME")
+                {
+                    for (int i = 0; i < adjust_pads.size(); i++) {
+                        if (kernel_size[i] < pads_begin[i] + 1 + adjust_pads[i])
+                            return false;
+                    }
+                    return true;
+                }
+                else if (padMode == "VALID")
+                    return false;
+            }
 
-            const int outGroupCn = blobs[0].size[1];  // Weights are in IOHW layout
-            const int group = numOutput / outGroupCn;
             if (group != 1)
             {
                 return preferableTarget == DNN_TARGET_CPU;
             }
             if (preferableTarget == DNN_TARGET_OPENCL || preferableTarget == DNN_TARGET_OPENCL_FP16)
-                return dilation.width == 1 && dilation.height == 1;
+                return std::accumulate(dilations.begin(), dilations.end(), 1, std::multiplies<size_t>()) == 1;
             return true;
         }
         else
 #endif  // HAVE_INF_ENGINE
-            return kernel_size.size() == 2 && (backendId == DNN_BACKEND_OPENCV || backendId == DNN_BACKEND_HALIDE);
+            return backendId == DNN_BACKEND_CUDA ||
+                   (kernel_size.size() == 2 && (backendId == DNN_BACKEND_OPENCV || backendId == DNN_BACKEND_HALIDE));
     }
 
     bool getMemoryShapes(const std::vector<MatShape> &inputs,
@@ -1257,7 +1518,7 @@ public:
             inpShape.push_back(inputs[0].size[i]);
             outShape.push_back(outputs[0].size[i]);
         }
-        getConvPoolPaddings(outShape, inpShape, kernel_size, strides, padMode, dilations, pads_begin, pads_end);
+        getConvPoolPaddings(outShape, kernel_size, strides, padMode, pads_begin, pads_end);
         if (pads_begin.size() == 2) {
             for (int i = 0; i < pads_begin.size(); i++) {
                 if (pads_begin[i] != pads_end[i])
@@ -1302,8 +1563,6 @@ public:
         {
             cv::add(biasesMat, b.reshape(1, numOutput), biasesMat);
         }
-
-        newWeightAndBias = !w.empty() || !b.empty();
     }
 
     class MatMulInvoker : public ParallelLoopBody
@@ -1571,14 +1830,15 @@ public:
 
         if (umat_weights.empty())
         {
-            if (newWeightAndBias)
-            {
+            if (fusedWeights)
                 weightsMat.copyTo(umat_weights);
+            else
+                transpose(blobs[0].reshape(1, inpCn), umat_weights);
+
+            if (fusedBias)
                 biasesMat.copyTo(umat_biases);
-            }
             else
             {
-                transpose(blobs[0].reshape(1, inpCn), umat_weights);
                 if (hasBias())
                     blobs[1].reshape(1, outCn).copyTo(umat_biases);
                 else
@@ -1722,6 +1982,67 @@ public:
         }
     }
 
+#ifdef HAVE_CUDA
+    Ptr<BackendNode> initCUDA(
+        void *context_,
+        const std::vector<Ptr<BackendWrapper>>& inputs,
+        const std::vector<Ptr<BackendWrapper>>& outputs
+    ) override
+    {
+        auto context = reinterpret_cast<csl::CSLContext*>(context_);
+
+        CV_Assert(inputs.size() == 1);
+        auto input_wrapper = inputs[0].dynamicCast<CUDABackendWrapper>();
+        auto input_shape = input_wrapper->getShape();
+
+        CV_Assert(outputs.size() == 1);
+        auto output_wrapper = outputs[0].dynamicCast<CUDABackendWrapper>();
+        auto output_shape = output_wrapper->getShape();
+
+        const auto output_feature_maps = numOutput;
+        const auto output_feature_maps_per_group = blobs[0].size[1];
+        const auto groups = output_feature_maps / output_feature_maps_per_group;
+
+        TransposeConvolutionConfiguration config;
+        config.kernel_size.assign(std::begin(kernel_size), std::end(kernel_size));
+        config.dilations.assign(std::begin(dilations), std::end(dilations));
+        config.strides.assign(std::begin(strides), std::end(strides));
+
+        if (padMode.empty())
+        {
+            config.padMode = TransposeConvolutionConfiguration::PaddingMode::MANUAL;
+            config.pads_begin.assign(std::begin(pads_begin), std::end(pads_begin));
+            config.pads_end.assign(std::begin(pads_end), std::end(pads_end));
+        }
+        else if (padMode == "VALID")
+        {
+            config.padMode = TransposeConvolutionConfiguration::PaddingMode::VALID;
+        }
+        else if (padMode == "SAME")
+        {
+            config.padMode = TransposeConvolutionConfiguration::PaddingMode::SAME;
+        }
+        else
+        {
+            CV_Error(Error::StsNotImplemented, padMode + " padding mode not supported by DeconvolutionLayer");
+        }
+
+        config.input_shape.assign(std::begin(input_shape), std::end(input_shape));
+        config.output_shape.assign(std::begin(output_shape), std::end(output_shape));
+        config.groups = groups;
+
+        CV_Assert(blobs.size() >= 1);
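+        /* hedged reading: for deconvolution, weightsMat holds the transposed weights
+         * used by the CPU GEMM path (cf. the transpose in the OpenCL branch above),
+         * so .t() restores the original blob layout */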
+        Mat filtersMat = fusedWeights ? weightsMat.t() : blobs[0];
+
+        Mat biasMat = (hasBias() || fusedBias) ? biasesMat : Mat();
+        if (countNonZero(biasMat) == 0)
+            biasMat = Mat();
+
+        return make_cuda_node<cuda4dnn::TransposeConvolutionOp>(
+            preferableTarget, std::move(context->stream), std::move(context->cudnn_handle), config, filtersMat, biasMat);
+    }
+#endif
+
     virtual Ptr<BackendNode> initHalide(const std::vector<Ptr<BackendWrapper> > &inputs) CV_OVERRIDE
     {
 #ifdef HAVE_HALIDE
@@ -1775,11 +2096,27 @@ public:
         return Ptr<BackendNode>();
     }
 
+#ifdef HAVE_INF_ENGINE
     virtual Ptr<BackendNode> initInfEngine(const std::vector<Ptr<BackendWrapper> > &) CV_OVERRIDE
     {
-#ifdef HAVE_INF_ENGINE
-#if INF_ENGINE_VER_MAJOR_GE(INF_ENGINE_RELEASE_2018R5)
-        const int outGroupCn = blobs[0].size[1];  // Weights are in IOHW layout
+        InferenceEngine::Layout layout = blobs[0].dims == 5 ? InferenceEngine::Layout::NCDHW :
+                                                              InferenceEngine::Layout::OIHW;
+
+        auto ieWeights = wrapToInfEngineBlob(blobs[0], layout);
+        if (fusedWeights)
+        {
+            ieWeights = InferenceEngine::make_shared_blob<float>({
+                            InferenceEngine::Precision::FP32,
+                            ieWeights->getTensorDesc().getDims(), layout
+                        });
+            ieWeights->allocate();
+
+            int inpCn = blobs[0].size[0];
+            Mat newWeights = infEngineBlobToMat(ieWeights).reshape(1, inpCn);
+            transpose(weightsMat, newWeights);
+        }
+
+        const int outGroupCn = blobs[0].size[1];  // Weights are in IOHW or IODHW layout
         const int group = numOutput / outGroupCn;
 
         InferenceEngine::Builder::DeconvolutionLayer ieLayer(name);
@@ -1788,59 +2125,33 @@ public:
         ieLayer.setStrides(strides);
         ieLayer.setDilation(dilations);
         ieLayer.setPaddingsBegin(pads_begin);
-        ieLayer.setPaddingsEnd(pads_end);
+
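+        /* IE expects explicit end paddings; the output-adjustment padding is folded
+         * in by subtracting it. Worked SAME example: kernel 4, pads_begin 1,
+         * adjust_pads 0 gives paddings_end = 4 - 1 - 1 - 0 = 2. */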
+        if (padMode.empty())
+        {
+            std::vector<size_t> paddings_end;
+            for (int i = 0; i < pads_end.size(); i++) {
+                paddings_end.push_back(pads_end[i] - adjust_pads[i]);
+            }
+            ieLayer.setPaddingsEnd(paddings_end);
+        }
+        else if (padMode == "SAME")
+        {
+            std::vector<size_t> paddings_end;
+            for (int i = 0; i < pads_begin.size(); i++) {
+                paddings_end.push_back(kernel_size[i] - pads_begin[i] - 1 - adjust_pads[i]);
+            }
+            ieLayer.setPaddingsEnd(paddings_end);
+        }
         ieLayer.setGroup((size_t)group);
         ieLayer.setOutDepth((size_t)numOutput);
 
         InferenceEngine::Builder::Layer l = ieLayer;
-        addConstantData("weights", wrapToInfEngineBlob(blobs[0], InferenceEngine::Layout::OIHW), l);
+        addConstantData("weights", ieWeights, l);
         if (hasBias())
-            addConstantData("biases", wrapToInfEngineBlob(blobs[1], {(size_t)numOutput}, InferenceEngine::Layout::C), l);
+            addConstantData("biases", wrapToInfEngineBlob(biasesMat, {(size_t)numOutput}, InferenceEngine::Layout::C), l);
         return Ptr<BackendNode>(new InfEngineBackendNode(l));
-#else
-        const int outGroupCn = blobs[0].size[1];  // Weights are in IOHW layout
-        const int group = numOutput / outGroupCn;
-
-        InferenceEngine::LayerParams lp;
-        lp.name = name;
-        lp.type = "Deconvolution";
-        lp.precision = InferenceEngine::Precision::FP32;
-        std::shared_ptr<InferenceEngine::DeconvolutionLayer> ieLayer(new InferenceEngine::DeconvolutionLayer(lp));
-
-#if INF_ENGINE_VER_MAJOR_GT(INF_ENGINE_RELEASE_2018R3)
-        ieLayer->_kernel.insert(InferenceEngine::X_AXIS, kernel.width);
-        ieLayer->_kernel.insert(InferenceEngine::Y_AXIS, kernel.height);
-        ieLayer->_stride.insert(InferenceEngine::X_AXIS, stride.width);
-        ieLayer->_stride.insert(InferenceEngine::Y_AXIS, stride.height);
-        ieLayer->_padding.insert(InferenceEngine::X_AXIS, pad.width);
-        ieLayer->_padding.insert(InferenceEngine::Y_AXIS, pad.height);
-        ieLayer->_pads_end.insert(InferenceEngine::X_AXIS, pad.width);
-        ieLayer->_pads_end.insert(InferenceEngine::Y_AXIS, pad.height);
-        ieLayer->_dilation.insert(InferenceEngine::X_AXIS, dilation.width);
-        ieLayer->_dilation.insert(InferenceEngine::Y_AXIS, dilation.height);
-#else
-        ieLayer->_kernel_x = kernel.width;
-        ieLayer->_kernel_y = kernel.height;
-        ieLayer->_stride_x = stride.width;
-        ieLayer->_stride_y = stride.height;
-        ieLayer->_padding_x = pad.width;
-        ieLayer->_padding_y = pad.height;
-        ieLayer->_dilation_x = dilation.width;
-        ieLayer->_dilation_y = dilation.height;
-#endif
-        ieLayer->_out_depth = numOutput;
-        ieLayer->_group = group;
-
-        ieLayer->_weights = wrapToInfEngineBlob(blobs[0], InferenceEngine::Layout::OIHW);
-        if (hasBias())
-        {
-            ieLayer->_biases = wrapToInfEngineBlob(blobs[1], {(size_t)numOutput}, InferenceEngine::Layout::C);
-        }
-        return Ptr<BackendNode>(new InfEngineBackendNode(ieLayer));
-#endif
-#endif  // HAVE_INF_ENGINE
-        return Ptr<BackendNode>();
     }
+#endif  // HAVE_INF_ENGINE
 
     virtual int64 getFLOPS(const std::vector<MatShape> &inputs,
                            const std::vector<MatShape> &outputs) const CV_OVERRIDE
@@ -1849,10 +2160,12 @@ public:
 
         int64 flops = 0;
         int outChannels = blobs[0].size[0];
+        size_t karea = std::accumulate(kernel_size.begin(), kernel_size.end(),
+                                       1, std::multiplies<size_t>());
 
         for (int i = 0; i < inputs.size(); i++)
         {
-            flops += CV_BIG_INT(2)*outChannels*kernel.area()*total(inputs[i]);
+            flops += CV_BIG_INT(2)*outChannels*karea*total(inputs[i]);
         }
 
         return flops;