Merge pull request #16575 from l-bat:flownet2
authorLiubov Batanina <piccione-mail@yandex.ru>
Tue, 19 May 2020 12:29:50 +0000 (15:29 +0300)
committerGitHub <noreply@github.com>
Tue, 19 May 2020 12:29:50 +0000 (12:29 +0000)
Support FlowNet2 model

* Support DataAugmentation layer

* Fix warnings

* Fix comments

* Support Correlation layer

* TEST

* Support Correlation layer

* Supported Accum and FlowWarp layers

* Supported ChannelNorm layer

* Supported Resample with inputs.size() > 1

* Fixed comments

* Refactoring

* Added tests

* Add resample test

* Added asserts in resize layer

* Updated DataAugmentation layer

* Update convolution layer

* Refactoring

* Fix data augmentation layer

* Fix caffe importer

* Fix resize

* Switch to Mat ptr

* Remove useless resize type

* Used ResizeLayer in Accum

* Split ChannelNormLayer

* Delete duplicate assert

* Add sample

* Fix sample

* Added colormap

modules/dnn/include/opencv2/dnn/all_layers.hpp
modules/dnn/src/caffe/caffe_importer.cpp
modules/dnn/src/init.cpp
modules/dnn/src/layers/accum_layer.cpp [new file with mode: 0644]
modules/dnn/src/layers/correlation_layer.cpp [new file with mode: 0644]
modules/dnn/src/layers/flow_warp_layer.cpp [new file with mode: 0644]
modules/dnn/src/layers/resize_layer.cpp
modules/dnn/src/layers/scale_layer.cpp
modules/dnn/test/test_layers.cpp
samples/dnn/optical_flow.py [new file with mode: 0644]

index 249f82d..e1df918 100644 (file)
@@ -556,6 +556,30 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
         static Ptr<Layer> create(const LayerParams& params);
     };
 
+    class CV_EXPORTS DataAugmentationLayer : public Layer
+    {
+    public:
+        static Ptr<DataAugmentationLayer> create(const LayerParams& params);
+    };
+
+    class CV_EXPORTS CorrelationLayer : public Layer
+    {
+    public:
+        static Ptr<CorrelationLayer> create(const LayerParams& params);
+    };
+
+    class CV_EXPORTS AccumLayer : public Layer
+    {
+    public:
+        static Ptr<AccumLayer> create(const LayerParams& params);
+    };
+
+    class CV_EXPORTS FlowWarpLayer : public Layer
+    {
+    public:
+        static Ptr<FlowWarpLayer> create(const LayerParams& params);
+    };
+
     class CV_EXPORTS PriorBoxLayer : public Layer
     {
     public:
index 16860a9..8673be0 100644 (file)
@@ -465,6 +465,35 @@ public:
                 net.mutable_layer(li)->mutable_bottom()->RemoveLast();
                 type = "Eltwise";
             }
+            else if (type == "Resample")
+            {
+                CV_Assert(layer.bottom_size() == 1 || layer.bottom_size() == 2);
+                type = "Resize";
+                String interp = layerParams.get<String>("type").toLowerCase();
+                layerParams.set("interpolation", interp == "linear" ? "bilinear" : interp);
+
+                if (layerParams.has("factor"))
+                {
+                    float factor = layerParams.get<float>("factor");
+                    CV_Assert(layer.bottom_size() != 2 || factor == 1.0);
+                    layerParams.set("zoom_factor", factor);
+
+                    if ((interp == "linear" && factor != 1.0) ||
+                        (interp == "nearest" && factor < 1.0))
+                        CV_Error(Error::StsNotImplemented, "Unsupported Resample mode");
+                }
+            }
+            else if ("Convolution" == type)
+            {
+                CV_Assert(layer.bottom_size() == layer.top_size());
+                for (int i = 0; i < layer.bottom_size(); i++)
+                {
+                    int conv_id = dstNet.addLayer(layer.top(i), type, layerParams);
+                    addInput(layer.bottom(i), conv_id, 0, dstNet);
+                    addedBlobs.push_back(BlobNote(layer.top(i), conv_id, 0));
+                }
+                continue;
+            }
             else if ("ConvolutionDepthwise" == type)
             {
                 type = "Convolution";
index df3a716..be4e115 100644 (file)
@@ -132,6 +132,10 @@ void initializeLayerFactory()
     CV_DNN_REGISTER_LAYER_CLASS(Padding,        PaddingLayer);
     CV_DNN_REGISTER_LAYER_CLASS(Proposal,       ProposalLayer);
     CV_DNN_REGISTER_LAYER_CLASS(Scale,          ScaleLayer);
+    CV_DNN_REGISTER_LAYER_CLASS(DataAugmentation, DataAugmentationLayer);
+    CV_DNN_REGISTER_LAYER_CLASS(Correlation,    CorrelationLayer);
+    CV_DNN_REGISTER_LAYER_CLASS(Accum,          AccumLayer);
+    CV_DNN_REGISTER_LAYER_CLASS(FlowWarp,       FlowWarpLayer);
 
     CV_DNN_REGISTER_LAYER_CLASS(LSTM,           LSTMLayer);
 }
diff --git a/modules/dnn/src/layers/accum_layer.cpp b/modules/dnn/src/layers/accum_layer.cpp
new file mode 100644 (file)
index 0000000..72bbf04
--- /dev/null
@@ -0,0 +1,141 @@
+// This file is part of OpenCV project.
+// It is subject to the license terms in the LICENSE file found in the top-level directory
+// of this distribution and at http://opencv.org/license.html.
+
+// Copyright (C) 2020, Intel Corporation, all rights reserved.
+// Third party copyrights are property of their respective owners.
+
+#include "../precomp.hpp"
+#include "layers_common.hpp"
+
+
+namespace cv { namespace dnn {
+
+class AccumLayerImpl CV_FINAL : public AccumLayer
+{
+public:
+    AccumLayerImpl(const LayerParams& params)
+    {
+        setParamsFrom(params);
+        top_height = params.get<int>("top_height", 0);
+        top_width = params.get<int>("top_width", 0);
+        divisor = params.get<int>("size_divisible_by", 0);
+        have_reference = params.get<String>("have_reference", "false") == "true";
+    }
+
+    virtual bool getMemoryShapes(const std::vector<MatShape> &inputs,
+                                 const int requiredOutputs,
+                                 std::vector<MatShape> &outputs,
+                                 std::vector<MatShape> &internals) const CV_OVERRIDE
+    {
+        std::vector<int> outShape;
+        int batch = inputs[0][0];
+        outShape.push_back(batch);
+
+        if (have_reference)
+        {
+            CV_Assert(inputs.size() >= 2);
+            int totalchannels = 0;
+            for (int i = 0; i < inputs.size() - 1; i++) {
+                CV_Assert(inputs[i][0] == batch);
+                totalchannels += inputs[i][1];
+            }
+            outShape.push_back(totalchannels);
+
+            int height = inputs.back()[2];
+            int width = inputs.back()[3];
+
+            outShape.push_back(height);
+            outShape.push_back(width);
+        }
+        else
+        {
+            int maxwidth = -1;
+            int maxheight = -1;
+            int totalchannels = 0;
+
+            // Find largest blob size and count total channels
+            for (int i = 0; i < inputs.size(); ++i)
+            {
+                totalchannels += inputs[i][1];
+                maxheight = std::max(maxheight, inputs[i][2]);
+                maxwidth = std::max(maxwidth, inputs[i][3]);
+                CV_Assert(inputs[i][0] == batch);
+            }
+            outShape.push_back(totalchannels);
+
+            int out_h = divisor ? static_cast<int>(ceil(maxheight / divisor) * divisor) : top_height;
+            int out_w = divisor ? static_cast<int>(ceil(maxwidth / divisor) * divisor) : top_width;
+
+            // Layer can specify custom top size which is larger than default
+            if (out_h <= maxheight || out_w <= maxwidth)
+            {
+                out_h = maxheight;
+                out_w = maxwidth;
+            }
+
+            outShape.push_back(out_h);
+            outShape.push_back(out_w);
+        }
+
+        outputs.assign(1, outShape);
+        return false;
+    }
+
+    virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE
+    {
+        LayerParams resizeParams;
+        resizeParams.set("interpolation", "bilinear");
+        resizeParams.set("align_corners", true);
+        resize = ResizeLayer::create(resizeParams);
+    }
+
+    void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
+    {
+        CV_TRACE_FUNCTION();
+        CV_TRACE_ARG_VALUE(name, "name", name.c_str());
+
+        std::vector<Mat> inputs, outputs;
+        inputs_arr.getMatVector(inputs);
+        outputs_arr.getMatVector(outputs);
+
+        const int out_h = outputs[0].size[2];
+        const int out_w = outputs[0].size[3];
+        float* out_data = outputs[0].ptr<float>();
+        std::vector<int> sizes(&outputs[0].size[0], &outputs[0].size[0] + outputs[0].size.dims());
+        for (int i = 0; i < inputs.size() - have_reference; i++)
+        {
+            sizes[1] = inputs[i].size[1];
+            Mat outSlice(sizes, CV_32F, out_data);
+
+            if (out_h == inputs[i].size[2] && out_w == inputs[i].size[3])
+            {
+                inputs[i].copyTo(outSlice);
+            }
+            else
+            {
+                std::vector<Mat> inp_slices, out_slices;
+                inp_slices.push_back(inputs[i]);
+                out_slices.push_back(outSlice);
+
+                resize->finalize(inp_slices, out_slices);
+                resize->forward(inp_slices, out_slices, internals_arr);
+            }
+            out_data += outSlice.total(1);
+        }
+    }
+
+private:
+    int top_height;
+    int top_width;
+    int divisor;
+    bool have_reference;
+    Ptr<ResizeLayer> resize;
+};
+
+Ptr<AccumLayer> AccumLayer::create(const LayerParams& params)
+{
+    return Ptr<AccumLayer>(new AccumLayerImpl(params));
+}
+
+}}  // namespace cv::dnn
diff --git a/modules/dnn/src/layers/correlation_layer.cpp b/modules/dnn/src/layers/correlation_layer.cpp
new file mode 100644 (file)
index 0000000..cfb3b8e
--- /dev/null
@@ -0,0 +1,207 @@
+// This file is part of OpenCV project.
+// It is subject to the license terms in the LICENSE file found in the top-level directory
+// of this distribution and at http://opencv.org/license.html.
+
+// Copyright (C) 2020, Intel Corporation, all rights reserved.
+// Third party copyrights are property of their respective owners.
+
+#include "../precomp.hpp"
+#include "layers_common.hpp"
+
+
+namespace cv { namespace dnn {
+
+class CorrelationLayerImpl CV_FINAL : public CorrelationLayer
+{
+public:
+    CorrelationLayerImpl(const LayerParams& params)
+    {
+        setParamsFrom(params);
+        pad = params.get<int>("pad", 0);
+        CV_Assert_N(params.has("kernel_size"), params.has("max_displacement"));
+        max_displacement = params.get<int>("max_displacement");
+        kernel = params.get<int>("kernel_size");
+        if (kernel % 2 == 0)
+            CV_Error(Error::StsNotImplemented, "Odd kernel size required.");
+
+        stride_1 = params.get<int>("stride_1", 1);
+        stride_2 = params.get<int>("stride_2", 1);
+    }
+
+    virtual bool getMemoryShapes(const std::vector<MatShape> &inputs,
+                                 const int requiredOutputs,
+                                 std::vector<MatShape> &outputs,
+                                 std::vector<MatShape> &internals) const CV_OVERRIDE
+    {
+        CV_Assert_N(inputs.size() == 2, inputs[0].size() == 4, inputs[1].size() == 4);
+
+        int padded_height = inputs[0][2] + 2 * pad;
+        int padded_width  = inputs[0][3] + 2 * pad;
+
+        int kernel_radius = (kernel - 1) / 2;
+        int border_size = max_displacement + kernel_radius;
+
+        int neighborhood_grid_radius = max_displacement / stride_2;
+        int neighborhood_grid_width = neighborhood_grid_radius * 2 + 1;
+
+        std::vector<int> outShape;
+
+        int num = inputs[0][0];
+        outShape.push_back(num);
+
+        int out_c = neighborhood_grid_width * neighborhood_grid_width;
+        outShape.push_back(out_c);
+
+        int out_h = ceil(static_cast<float>(padded_height - border_size * 2) / stride_1);
+        int out_w = ceil(static_cast<float>(padded_width - border_size * 2)  / stride_1);
+        CV_Assert_N(out_h >= 1, out_w >= 1);
+
+        outShape.push_back(out_h);
+        outShape.push_back(out_w);
+        outputs.assign(1, outShape);
+        return false;
+    }
+
+    virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays) CV_OVERRIDE
+    {
+        std::vector<Mat> inputs;
+        inputs_arr.getMatVector(inputs);
+
+        int padded_height = inputs[0].size[2] + 2 * pad;
+        int padded_width  = inputs[0].size[3] + 2 * pad;
+
+        int size[] = {inputs[0].size[0], padded_height, padded_width, inputs[0].size[1]};
+        rbot0 = Mat(4, &size[0], CV_32F, float(0));
+        rbot1 = Mat(4, &size[0], CV_32F, float(0));
+    }
+
+    void blobRearrangeKernel2(const Mat& input, Mat& output)
+    {
+        const int num      = input.size[0];
+        const int channels = input.size[1];
+        const int height   = input.size[2];
+        const int width    = input.size[3];
+        const int area     = height * width;
+        const int pad_area = (width + 2 * pad) * (height + 2 * pad);
+
+        const float* in = input.ptr<float>();
+        float* out = output.ptr<float>();
+        for (int n = 0; n < num; n++)
+        {
+            for (int ch = 0; ch < channels; ch++)
+            {
+                for (int xy = 0; xy < area; xy++)
+                {
+                    float value = in[(n * channels + ch) * area + xy];
+                    int xpad  = (xy % width + pad);
+                    int ypad  = (xy / width + pad);
+                    int xypad = ypad * (width + 2 * pad) + xpad;
+                    out[(n * pad_area + xypad) * channels + ch] = value;
+                }
+            }
+        }
+    }
+
+    void correlationKernelSubtraction(const Mat& input0, const Mat& input1, Mat& output, int item)
+    {
+        const int inp_h = input0.size[1];
+        const int inp_w = input0.size[2];
+        const int inp_c = input0.size[3];
+
+        const int out_c = output.size[1];
+        const int out_h = output.size[2];
+        const int out_w = output.size[3];
+
+        int topcount = output.total(1);
+        int neighborhood_grid_radius = max_displacement / stride_2;
+        int neighborhood_grid_width  = neighborhood_grid_radius * 2 + 1;
+
+        const float* inp0_data = input0.ptr<float>();
+        const float* inp1_data = input1.ptr<float>();
+        float* out_data  = output.ptr<float>();
+        int sumelems = kernel * kernel * inp_c;
+        std::vector<float> patch_data(sumelems, 0);
+        for (int y = 0; y < out_h; y++)
+        {
+            for (int x = 0; x < out_w; x++)
+            {
+                int x1 = x * stride_1 + max_displacement;
+                int y1 = y * stride_1 + max_displacement;
+
+                for (int j = 0; j < kernel; j++)
+                {
+                    for (int i = 0; i < kernel; i++)
+                    {
+                        int ji_off = ((j * kernel) + i) * inp_c;
+                        for (int ch = 0; ch < inp_c; ch++)
+                        {
+                            int idx1 = ((item * inp_h + y1 + j) * inp_w + x1 + i) * inp_c + ch;
+                            int idxPatchData = ji_off + ch;
+                            patch_data[idxPatchData] = inp0_data[idx1];
+                        }
+                    }
+                }
+
+                for (int out_ch = 0; out_ch < out_c; out_ch++)
+                {
+                    float sum = 0;
+                    int s2o = (out_ch % neighborhood_grid_width - neighborhood_grid_radius) * stride_2;
+                    int s2p = (out_ch / neighborhood_grid_width - neighborhood_grid_radius) * stride_2;
+
+                    int x2 = x1 + s2o;
+                    int y2 = y1 + s2p;
+                    for (int j = 0; j < kernel; j++)
+                    {
+                        for (int i = 0; i < kernel; i++)
+                        {
+                            int ji_off = ((j * kernel) + i) * inp_c;
+                            for (int ch = 0; ch < inp_c; ch++)
+                            {
+                                int idxPatchData = ji_off + ch;
+                                int idx2 = ((item * inp_h + y2 + j) * inp_w + x2 + i) * inp_c + ch;
+                                sum += patch_data[idxPatchData] * inp1_data[idx2];
+                            }
+                        }
+                    }
+                    int index = ((out_ch * out_h + y) * out_w) + x;
+                    out_data[index + item * topcount] = static_cast<float>(sum) / sumelems;
+                }
+            }
+        }
+    }
+
+
+    void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
+    {
+        CV_TRACE_FUNCTION();
+        CV_TRACE_ARG_VALUE(name, "name", name.c_str());
+
+        std::vector<Mat> inputs, outputs, internals;
+        inputs_arr.getMatVector(inputs);
+        outputs_arr.getMatVector(outputs);
+        internals_arr.getMatVector(internals);
+
+        blobRearrangeKernel2(inputs[0], rbot0);
+        blobRearrangeKernel2(inputs[1], rbot1);
+        for (int i = 0; i < inputs[0].size[0]; i++)
+        {
+            correlationKernelSubtraction(rbot0, rbot1, outputs[0], i);
+        }
+    }
+
+private:
+    int pad;
+    int kernel;
+    int max_displacement;
+    int stride_1;
+    int stride_2;
+    Mat rbot0;
+    Mat rbot1;
+};
+
+Ptr<CorrelationLayer> CorrelationLayer::create(const LayerParams& params)
+{
+    return Ptr<CorrelationLayer>(new CorrelationLayerImpl(params));
+}
+
+}}  // namespace cv::dnn
diff --git a/modules/dnn/src/layers/flow_warp_layer.cpp b/modules/dnn/src/layers/flow_warp_layer.cpp
new file mode 100644 (file)
index 0000000..5d0f9f4
--- /dev/null
@@ -0,0 +1,117 @@
+// This file is part of OpenCV project.
+// It is subject to the license terms in the LICENSE file found in the top-level directory
+// of this distribution and at http://opencv.org/license.html.
+
+// Copyright (C) 2020, Intel Corporation, all rights reserved.
+// Third party copyrights are property of their respective owners.
+
+#include "../precomp.hpp"
+#include "layers_common.hpp"
+
+
+namespace cv { namespace dnn {
+
+class FlowWarpLayerImpl CV_FINAL : public FlowWarpLayer
+{
+public:
+    FlowWarpLayerImpl(const LayerParams& params)
+    {
+        setParamsFrom(params);
+        String fill_string = params.get<String>("FillParameter", "ZERO").toLowerCase();
+        if (fill_string != "zero")
+            CV_Error(Error::StsNotImplemented, "Only zero filling supported.");
+        fill_value = 0;
+    }
+
+    virtual bool getMemoryShapes(const std::vector<MatShape> &inputs,
+                                 const int requiredOutputs,
+                                 std::vector<MatShape> &outputs,
+                                 std::vector<MatShape> &internals) const CV_OVERRIDE
+    {
+        CV_Assert(inputs.size() == 2);
+        CV_Assert_N(inputs[0][0] == inputs[1][0], inputs[1][1] == 2,
+                    inputs[0][2] == inputs[1][2], inputs[0][3] == inputs[1][3]);
+
+        outputs.assign(1, inputs[0]);
+        return false;
+    }
+
+    void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
+    {
+        CV_TRACE_FUNCTION();
+        CV_TRACE_ARG_VALUE(name, "name", name.c_str());
+
+        std::vector<Mat> inputs, outputs;
+        inputs_arr.getMatVector(inputs);
+        outputs_arr.getMatVector(outputs);
+
+        const int out_n = outputs[0].size[0];
+        const int out_c = outputs[0].size[1];
+        const int out_h = outputs[0].size[2];
+        const int out_w = outputs[0].size[3];
+
+        const int area = out_w * out_h;
+        const int total = area * out_c;
+
+        const float* image_data = inputs[0].ptr<float>();
+        const float* flow_data  = inputs[1].ptr<float>();
+        float* out_data = outputs[0].ptr<float>();
+
+        for (int n = 0; n < out_n; n++)
+        {
+            int off = total * n;
+            for (int x = 0; x < out_w; x++)
+            {
+                for (int y = 0; y < out_h; y++)
+                {
+                    int idx = 2 * area * n + y * out_w + x;
+                    float fx = flow_data[idx];
+                    float fy = flow_data[idx + area];
+
+                    float x2 = x + fx;
+                    float y2 = y + fy;
+
+                    if (x2 >= 0 && y2 >= 0 && x2 < out_w && y2 < out_h)
+                    {
+                        int ix2_L = x2;
+                        float alpha = x2 - ix2_L;
+
+                        int iy2_T = y2;
+                        float beta = y2 - iy2_T;
+
+                        int ix2_R = std::min(ix2_L + 1, out_w - 1);
+                        int iy2_B = std::min(iy2_T + 1, out_h - 1);
+
+                        for (int c = 0; c < out_c; c++)
+                        {
+                            float TL = image_data[off + c * area + iy2_T * out_w + ix2_L];
+                            float TR = image_data[off + c * area + iy2_T * out_w + ix2_R];
+                            float BL = image_data[off + c * area + iy2_B * out_w + ix2_L];
+                            float BR = image_data[off + c * area + iy2_B * out_w + ix2_R];
+
+                            out_data[off + c * area + y * out_w + x] = (1 - alpha) * (1 - beta) * TL +
+                                                                       (1 - alpha) * beta       * BL +
+                                                                        alpha      * (1 - beta) * TR +
+                                                                        alpha      * beta       * BR;
+                        }
+                    }
+                    else
+                    {
+                        for (int c = 0; c < out_c; c++)
+                            out_data[off + c * area + y * out_w + x] = fill_value;
+                    }
+                }
+            }
+        }
+    }
+
+private:
+    float fill_value;
+};
+
+Ptr<FlowWarpLayer> FlowWarpLayer::create(const LayerParams& params)
+{
+    return Ptr<FlowWarpLayer>(new FlowWarpLayerImpl(params));
+}
+
+}}  // namespace cv::dnn
index 09e68ee..3679c9e 100644 (file)
@@ -45,10 +45,15 @@ public:
                          std::vector<MatShape> &outputs,
                          std::vector<MatShape> &internals) const CV_OVERRIDE
     {
-        CV_Assert_N(inputs.size() == 1, inputs[0].size() == 4);
+        CV_Assert_N(inputs.size() == 1 || inputs.size() == 2, inputs[0].size() == 4);
         outputs.resize(1, inputs[0]);
-        outputs[0][2] = zoomFactorHeight > 0 ? (outputs[0][2] * zoomFactorHeight) : outHeight;
-        outputs[0][3] = zoomFactorWidth > 0 ? (outputs[0][3] * zoomFactorWidth) : outWidth;
+        if (inputs.size() == 1) {
+            outputs[0][2] = zoomFactorHeight > 0 ? (outputs[0][2] * zoomFactorHeight) : outHeight;
+            outputs[0][3] = zoomFactorWidth > 0 ? (outputs[0][3] * zoomFactorWidth) : outWidth;
+        } else {
+            outputs[0][2] = inputs[1][2];
+            outputs[0][3] = inputs[1][3];
+        }
         // We can work in-place (do nothing) if input shape == output shape.
         return (outputs[0][2] == inputs[0][2]) && (outputs[0][3] == inputs[0][3]);
     }
index a53618f..27aefd9 100644 (file)
@@ -307,5 +307,118 @@ Ptr<Layer> ShiftLayer::create(const LayerParams& params)
     return Ptr<ScaleLayer>(new ScaleLayerImpl(scaleParams));
 }
 
+class DataAugmentationLayerImpl CV_FINAL : public DataAugmentationLayer
+{
+public:
+    DataAugmentationLayerImpl(const LayerParams& params)
+    {
+        setParamsFrom(params);
+        recompute_mean = params.get<int>("recompute_mean", 1);
+        CV_CheckGT(recompute_mean, 0, "");
+        mean_per_pixel = params.get<bool>("mean_per_pixel", false);
+    }
+
+    bool getMemoryShapes(const std::vector<MatShape> &inputs,
+                         const int requiredOutputs,
+                         std::vector<MatShape> &outputs,
+                         std::vector<MatShape> &internals) const CV_OVERRIDE
+    {
+        CV_Assert_N(inputs.size() == 1, blobs.size() == 3);
+        CV_Assert_N(blobs[0].total() == 1, blobs[1].total() == total(inputs[0], 1),
+                    blobs[2].total() == inputs[0][1]);
+
+        outputs.assign(1, inputs[0]);
+        return true;
+    }
+
+    void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
+    {
+        CV_TRACE_FUNCTION();
+        CV_TRACE_ARG_VALUE(name, "name", name.c_str());
+
+        std::vector<Mat> inputs, outputs;
+        inputs_arr.getMatVector(inputs);
+        outputs_arr.getMatVector(outputs);
+
+        CV_Assert_N(outputs.size() == 1, blobs.size() == 3, inputs.size() == 1);
+        int num_iter = 0;
+
+        float* inpData = inputs[0].ptr<float>();
+        float* outData = outputs[0].ptr<float>();
+
+        Mat data_mean_cpu = blobs[1].clone();
+        Mat data_mean_per_channel_cpu = blobs[2].clone();
+
+        const int numWeights = data_mean_cpu.total();
+        CV_Assert(numWeights != 0);
+
+        ++num_iter;
+        if (num_iter <= recompute_mean)
+        {
+            data_mean_cpu *= (num_iter - 1);
+            const int batch = inputs[0].size[0];
+            float alpha = 1.0 / batch;
+
+            for (int i = 0; i < batch; ++i)
+            {
+                Mat inpSlice(1, numWeights, CV_32F, inpData);
+                inpSlice = alpha * inpSlice;
+
+                add(data_mean_cpu.reshape(1, 1), inpSlice, data_mean_cpu.reshape(1, 1));
+                inpData += numWeights;
+            }
+            data_mean_cpu *= (1.0 / num_iter);
+
+            int newsize[] = {blobs[1].size[1], (int)blobs[1].total(2)};
+            reduce(data_mean_cpu.reshape(1, 2, &newsize[0]), data_mean_per_channel_cpu, 1, REDUCE_SUM, CV_32F);
+
+            int area = blobs[1].total(2);
+            data_mean_per_channel_cpu *= (1.0 / area);
+        }
+
+        MatShape inpShape = shape(inputs[0]);
+
+        inpData = inputs[0].ptr<float>();
+        if (mean_per_pixel)
+        {
+            int numSlices = inputs[0].size[0];
+            for (int i = 0; i < numSlices; ++i)
+            {
+                Mat inpSlice(1, numWeights, CV_32F, inpData);
+                Mat outSlice(1, numWeights, CV_32F, outData);
+
+                add(inpSlice, (-1) * data_mean_cpu, outSlice);
+                inpData += numWeights;
+                outData += numWeights;
+            }
+        }
+        else
+        {
+            int numSlices = inpShape[1];
+            int count = numWeights / numSlices;
+
+            for (int i = 0; i < numSlices; ++i)
+            {
+                Mat inpSlice(1, count, CV_32F, inpData);
+                Mat outSlice(1, count, CV_32F, outData);
+                float coeff = data_mean_per_channel_cpu.reshape(1, 1).at<float>(0, i);
+                outSlice = inpSlice - coeff;
+
+                inpData += count;
+                outData += count;
+            }
+        }
+    }
+
+private:
+    int recompute_mean;
+    bool mean_per_pixel;
+};
+
+Ptr<DataAugmentationLayer> DataAugmentationLayer::create(const LayerParams& params)
+{
+    return Ptr<DataAugmentationLayer>(new DataAugmentationLayerImpl(params));
+}
+
 }  // namespace dnn
 }  // namespace cv
index c31b9f3..88f44d3 100644 (file)
@@ -97,29 +97,68 @@ class Test_Caffe_layers : public DNNTestLayer
 {
 public:
     void testLayerUsingCaffeModels(const String& basename, bool useCaffeModel = false,
-                                   bool useCommonInputBlob = true, double l1 = 0.0,
-                                   double lInf = 0.0)
+                                   bool useCommonInputBlob = true, double l1 = 0.0, double lInf = 0.0,
+                                   int numInps = 1, int numOuts = 1)
     {
+        CV_Assert_N(numInps >= 1, numInps <= 10, numOuts >= 1, numOuts <= 10);
         String prototxt = _tf(basename + ".prototxt");
         String caffemodel = _tf(basename + ".caffemodel");
 
-        String inpfile = (useCommonInputBlob) ? _tf("blob.npy") : _tf(basename + ".input.npy");
-        String outfile = _tf(basename + ".npy");
+        std::vector<Mat> inps, refs, outs;
 
-        Mat inp = blobFromNPY(inpfile);
-        Mat ref = blobFromNPY(outfile);
-        checkBackend(&inp, &ref);
+        if (numInps > 1)
+        {
+            for (int i = 0; i < numInps; i++)
+            {
+                String inpfile = _tf(basename + ".input_" + (i + '0') + ".npy");
+                inps.push_back(blobFromNPY(inpfile));
+            }
+        }
+        else
+        {
+            String inpfile = (useCommonInputBlob) ? _tf("blob.npy") : _tf(basename + ".input.npy");
+            inps.push_back(blobFromNPY(inpfile));
+        }
+
+        if (numOuts > 1)
+        {
+            for (int i = 0; i < numOuts; i++)
+            {
+                String outfile = _tf(basename + "_" + (i + '0') + ".npy");
+                refs.push_back(blobFromNPY(outfile));
+            }
+        }
+        else
+        {
+            String outfile = _tf(basename + ".npy");
+            refs.push_back(blobFromNPY(outfile));
+        }
 
         Net net = readNetFromCaffe(prototxt, (useCaffeModel) ? caffemodel : String());
         ASSERT_FALSE(net.empty());
+        checkBackend(&inps[0], &refs[0]);
 
         net.setPreferableBackend(backend);
         net.setPreferableTarget(target);
 
-        net.setInput(inp, "input");
-        Mat out = net.forward("output");
+        String inp_name = "input";
+        if (numInps > 1)
+        {
+            for (int i = 0; i < numInps; i++)
+            {
+                net.setInput(inps[i], inp_name + "_" + (i + '0'));
+            }
+        }
+        else
+        {
+            net.setInput(inps.back(), inp_name);
+        }
 
-        normAssert(ref, out, "", l1 ? l1 : default_l1, lInf ? lInf : default_lInf);
+        net.forward(outs);
+        for (int i = 0; i < refs.size(); i++)
+        {
+            normAssert(refs[i], outs[i], "", l1 ? l1 : default_l1, lInf ? lInf : default_lInf);
+        }
     }
 };
 
@@ -568,6 +607,58 @@ TEST_F(Layer_RNN_Test, get_set_test)
     EXPECT_EQ(shape(outputs[1]), shape(nT, nS, nH));
 }
 
+TEST_P(Test_Caffe_layers, Accum)
+{
+    if (backend == DNN_BACKEND_OPENCV && target != DNN_TARGET_CPU)
+        applyTestTag(CV_TEST_TAG_DNN_SKIP_OPENCL, CV_TEST_TAG_DNN_SKIP_OPENCL_FP16);
+
+    testLayerUsingCaffeModels("accum", false, false, 0.0, 0.0, 2);
+    testLayerUsingCaffeModels("accum_ref", false, false, 0.0, 0.0, 2);
+}
+
+TEST_P(Test_Caffe_layers, FlowWarp)
+{
+    if (backend == DNN_BACKEND_OPENCV && target == DNN_TARGET_OPENCL_FP16)
+        applyTestTag(CV_TEST_TAG_DNN_SKIP_OPENCL_FP16);
+
+    testLayerUsingCaffeModels("flow_warp", false, false, 0.0, 0.0, 2);
+}
+
+TEST_P(Test_Caffe_layers, ChannelNorm)
+{
+    if (backend == DNN_BACKEND_OPENCV && target == DNN_TARGET_OPENCL_FP16)
+        applyTestTag(CV_TEST_TAG_DNN_SKIP_OPENCL_FP16);
+    testLayerUsingCaffeModels("channel_norm", false, false);
+}
+
+TEST_P(Test_Caffe_layers, DataAugmentation)
+{
+    if (backend == DNN_BACKEND_OPENCV && target == DNN_TARGET_OPENCL_FP16)
+        applyTestTag(CV_TEST_TAG_DNN_SKIP_OPENCL_FP16);
+    testLayerUsingCaffeModels("data_augmentation", true, false);
+}
+
+TEST_P(Test_Caffe_layers, Resample)
+{
+    if (backend != DNN_BACKEND_OPENCV)
+        applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER, CV_TEST_TAG_DNN_SKIP_IE_NGRAPH);
+    testLayerUsingCaffeModels("nearest_2inps", false, false, 0.0, 0.0, 2);
+    testLayerUsingCaffeModels("nearest", false, false);
+}
+
+TEST_P(Test_Caffe_layers, Correlation)
+{
+    if (backend == DNN_BACKEND_OPENCV && target == DNN_TARGET_OPENCL_FP16)
+        applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NGRAPH, CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER,
+                     CV_TEST_TAG_DNN_SKIP_OPENCL, CV_TEST_TAG_DNN_SKIP_OPENCL_FP16);
+    testLayerUsingCaffeModels("correlation", false, false, 0.0, 0.0, 2);
+}
+
+TEST_P(Test_Caffe_layers, Convolution2Inputs)
+{
+    testLayerUsingCaffeModels("conv_2_inps", true, false, 0.0, 0.0, 2);
+}
+
 TEST_P(Test_Caffe_layers, ROIPooling_Accuracy)
 {
     Net net = readNetFromCaffe(_tf("net_roi_pooling.prototxt"));
diff --git a/samples/dnn/optical_flow.py b/samples/dnn/optical_flow.py
new file mode 100644 (file)
index 0000000..5d0d831
--- /dev/null
@@ -0,0 +1,85 @@
+#!/usr/bin/env python
+'''
+This sample using FlowNet v2 model to calculate optical flow.
+Original paper: https://arxiv.org/abs/1612.01925.
+Original repo:  https://github.com/lmb-freiburg/flownet2.
+
+Download the converted .caffemodel model from https://drive.google.com/open?id=16qvE9VNmU39NttpZwZs81Ga8VYQJDaWZ
+and .prototxt from https://drive.google.com/open?id=19bo6SWU2p8ZKvjXqMKiCPdK8mghwDy9b.
+Otherwise download original model from https://lmb.informatik.uni-freiburg.de/resources/binaries/flownet2/flownet2-models.tar.gz,
+convert .h5 model to .caffemodel and modify original .prototxt using .prototxt from link above.
+'''
+
+import argparse
+import os.path
+import numpy as np
+import cv2 as cv
+
+
+class OpticalFlow(object):
+    def __init__(self, proto, model, height, width):
+        self.net = cv.dnn.readNet(proto, model)
+        self.net.setPreferableBackend(cv.dnn.DNN_BACKEND_OPENCV)
+        self.height = height
+        self.width = width
+
+    def compute_flow(self, first_img, second_img):
+        inp0 = cv.dnn.blobFromImage(first_img, size=(self.width, self.height))
+        inp1 = cv.dnn.blobFromImage(second_img, size=(self.width, self.height))
+        self.net.setInput(inp0, "img0")
+        self.net.setInput(inp1, "img1")
+        flow = self.net.forward()
+        output = self.motion_to_color(flow)
+        return output
+
+    def motion_to_color(self, flow):
+        arr = np.arange(0, 255, dtype=np.uint8)
+        colormap = cv.applyColorMap(arr, cv.COLORMAP_HSV)
+        colormap = colormap.squeeze(1)
+
+        flow = flow.squeeze(0)
+        fx, fy = flow[0, ...], flow[1, ...]
+        rad = np.sqrt(fx**2 + fy**2)
+        maxrad = rad.max() if rad.max() != 0 else 1
+
+        ncols = arr.size
+        rad = rad[..., np.newaxis] / maxrad
+        a = np.arctan2(-fy / maxrad, -fx / maxrad) / np.pi
+        fk = (a + 1) / 2.0 * (ncols - 1)
+        k0 = fk.astype(np.int)
+        k1 = (k0 + 1) % ncols
+        f = fk[..., np.newaxis] - k0[..., np.newaxis]
+
+        col0 = colormap[k0] / 255.0
+        col1 = colormap[k1] / 255.0
+        col = (1 - f) * col0 + f * col1
+        col = np.where(rad <= 1, 1 - rad * (1 - col), col * 0.75)
+        output = (255.0 * col).astype(np.uint8)
+        return output
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Use this script to calculate optical flow using FlowNetv2',
+                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument('-input', '-i', required=True, help='Path to input video file. Skip this argument to capture frames from a camera.')
+    parser.add_argument('--height', default=320, help='Input height')
+    parser.add_argument('--width',  default=448, help='Input width')
+    parser.add_argument('--proto', '-p', default='FlowNet2_deploy.prototxt', help='Path to prototxt.')
+    parser.add_argument('--model', '-m', default='FlowNet2_weights.caffemodel', help='Path to caffemodel.')
+    args, _ = parser.parse_known_args()
+
+    if not os.path.isfile(args.model) or not os.path.isfile(args.proto):
+        raise OSError("Prototxt or caffemodel not exist")
+
+    winName = 'Calculation optical flow in OpenCV'
+    cv.namedWindow(winName, cv.WINDOW_NORMAL)
+    cap = cv.VideoCapture(args.input if args.input else 0)
+    hasFrame, first_frame = cap.read()
+    opt_flow = OpticalFlow(args.proto, args.model, args.height, args.width)
+    while cv.waitKey(1) < 0:
+        hasFrame, second_frame = cap.read()
+        if not hasFrame:
+            break
+        flow = opt_flow.compute_flow(first_frame, second_frame)
+        first_frame = second_frame
+        cv.imshow(winName, flow)