From d991c220903270253518e9624d0ec5f9d70f64b1 Mon Sep 17 00:00:00 2001
From: Liubov Batanina
Date: Tue, 19 May 2020 15:29:50 +0300
Subject: [PATCH] Merge pull request #16575 from l-bat:flownet2

Support FlowNet2 model

* Support DataAugmentation layer
* Fix warnings
* Fix comments
* Support Correlation layer
* TEST
* Support Correlation layer
* Supported Accum and FlowWarp layers
* Supported ChannelNorm layer
* Supported Resample with inputs.size() > 1
* Fixed comments
* Refactoring
* Added tests
* Add resample test
* Added asserts in resize layer
* Updated DataAugmentation layer
* Update convolution layer
* Refactoring
* Fix data augmentation layer
* Fix caffe importer
* Fix resize
* Switch to Mat ptr
* Remove useless resize type
* Used ResizeLayer in Accum
* Split ChannelNormLayer
* Delete duplicate assert
* Add sample
* Fix sample
* Added colormap
---
 modules/dnn/include/opencv2/dnn/all_layers.hpp |  24 +++
 modules/dnn/src/caffe/caffe_importer.cpp       |  29 ++++
 modules/dnn/src/init.cpp                       |   4 +
 modules/dnn/src/layers/accum_layer.cpp         | 141 +++++++++++++++++
 modules/dnn/src/layers/correlation_layer.cpp   | 207 +++++++++++++++++++++++++
 modules/dnn/src/layers/flow_warp_layer.cpp     | 117 ++++++++++++++
 modules/dnn/src/layers/resize_layer.cpp        |  11 +-
 modules/dnn/src/layers/scale_layer.cpp         | 113 ++++++++++++++
 modules/dnn/test/test_layers.cpp               | 111 +++++++++++--
 samples/dnn/optical_flow.py                    |  85 ++++++++++
 10 files changed, 829 insertions(+), 13 deletions(-)
 create mode 100644 modules/dnn/src/layers/accum_layer.cpp
 create mode 100644 modules/dnn/src/layers/correlation_layer.cpp
 create mode 100644 modules/dnn/src/layers/flow_warp_layer.cpp
 create mode 100644 samples/dnn/optical_flow.py

diff --git a/modules/dnn/include/opencv2/dnn/all_layers.hpp b/modules/dnn/include/opencv2/dnn/all_layers.hpp
index 249f82d..e1df918 100644
--- a/modules/dnn/include/opencv2/dnn/all_layers.hpp
+++ b/modules/dnn/include/opencv2/dnn/all_layers.hpp
@@ -556,6 +556,30 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
         static Ptr<Layer> create(const LayerParams& params);
     };
 
+    class CV_EXPORTS DataAugmentationLayer : public Layer
+    {
+    public:
+        static Ptr<DataAugmentationLayer> create(const LayerParams& params);
+    };
+
+    class CV_EXPORTS CorrelationLayer : public Layer
+    {
+    public:
+        static Ptr<CorrelationLayer> create(const LayerParams& params);
+    };
+
+    class CV_EXPORTS AccumLayer : public Layer
+    {
+    public:
+        static Ptr<AccumLayer> create(const LayerParams& params);
+    };
+
+    class CV_EXPORTS FlowWarpLayer : public Layer
+    {
+    public:
+        static Ptr<FlowWarpLayer> create(const LayerParams& params);
+    };
+
     class CV_EXPORTS PriorBoxLayer : public Layer
     {
     public:
diff --git a/modules/dnn/src/caffe/caffe_importer.cpp b/modules/dnn/src/caffe/caffe_importer.cpp
index 16860a9..8673be0 100644
--- a/modules/dnn/src/caffe/caffe_importer.cpp
+++ b/modules/dnn/src/caffe/caffe_importer.cpp
@@ -465,6 +465,35 @@ public:
             net.mutable_layer(li)->mutable_bottom()->RemoveLast();
             type = "Eltwise";
         }
+        else if (type == "Resample")
+        {
+            CV_Assert(layer.bottom_size() == 1 || layer.bottom_size() == 2);
+            type = "Resize";
+            String interp = layerParams.get<String>("type").toLowerCase();
+            layerParams.set("interpolation", interp == "linear" ? "bilinear" : interp);
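+            // When Resample has two bottoms, the second blob carries the
+            // target spatial size (see the matching change in resize_layer.cpp).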
"bilinear" : interp); + + if (layerParams.has("factor")) + { + float factor = layerParams.get("factor"); + CV_Assert(layer.bottom_size() != 2 || factor == 1.0); + layerParams.set("zoom_factor", factor); + + if ((interp == "linear" && factor != 1.0) || + (interp == "nearest" && factor < 1.0)) + CV_Error(Error::StsNotImplemented, "Unsupported Resample mode"); + } + } + else if ("Convolution" == type) + { + CV_Assert(layer.bottom_size() == layer.top_size()); + for (int i = 0; i < layer.bottom_size(); i++) + { + int conv_id = dstNet.addLayer(layer.top(i), type, layerParams); + addInput(layer.bottom(i), conv_id, 0, dstNet); + addedBlobs.push_back(BlobNote(layer.top(i), conv_id, 0)); + } + continue; + } else if ("ConvolutionDepthwise" == type) { type = "Convolution"; diff --git a/modules/dnn/src/init.cpp b/modules/dnn/src/init.cpp index df3a716..be4e115 100644 --- a/modules/dnn/src/init.cpp +++ b/modules/dnn/src/init.cpp @@ -132,6 +132,10 @@ void initializeLayerFactory() CV_DNN_REGISTER_LAYER_CLASS(Padding, PaddingLayer); CV_DNN_REGISTER_LAYER_CLASS(Proposal, ProposalLayer); CV_DNN_REGISTER_LAYER_CLASS(Scale, ScaleLayer); + CV_DNN_REGISTER_LAYER_CLASS(DataAugmentation, DataAugmentationLayer); + CV_DNN_REGISTER_LAYER_CLASS(Correlation, CorrelationLayer); + CV_DNN_REGISTER_LAYER_CLASS(Accum, AccumLayer); + CV_DNN_REGISTER_LAYER_CLASS(FlowWarp, FlowWarpLayer); CV_DNN_REGISTER_LAYER_CLASS(LSTM, LSTMLayer); } diff --git a/modules/dnn/src/layers/accum_layer.cpp b/modules/dnn/src/layers/accum_layer.cpp new file mode 100644 index 0000000..72bbf04 --- /dev/null +++ b/modules/dnn/src/layers/accum_layer.cpp @@ -0,0 +1,141 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +// Copyright (C) 2020, Intel Corporation, all rights reserved. +// Third party copyrights are property of their respective owners. + +#include "../precomp.hpp" +#include "layers_common.hpp" + + +namespace cv { namespace dnn { + +class AccumLayerImpl CV_FINAL : public AccumLayer +{ +public: + AccumLayerImpl(const LayerParams& params) + { + setParamsFrom(params); + top_height = params.get("top_height", 0); + top_width = params.get("top_width", 0); + divisor = params.get("size_divisible_by", 0); + have_reference = params.get("have_reference", "false") == "true"; + } + + virtual bool getMemoryShapes(const std::vector &inputs, + const int requiredOutputs, + std::vector &outputs, + std::vector &internals) const CV_OVERRIDE + { + std::vector outShape; + int batch = inputs[0][0]; + outShape.push_back(batch); + + if (have_reference) + { + CV_Assert(inputs.size() >= 2); + int totalchannels = 0; + for (int i = 0; i < inputs.size() - 1; i++) { + CV_Assert(inputs[i][0] == batch); + totalchannels += inputs[i][1]; + } + outShape.push_back(totalchannels); + + int height = inputs.back()[2]; + int width = inputs.back()[3]; + + outShape.push_back(height); + outShape.push_back(width); + } + else + { + int maxwidth = -1; + int maxheight = -1; + int totalchannels = 0; + + // Find largest blob size and count total channels + for (int i = 0; i < inputs.size(); ++i) + { + totalchannels += inputs[i][1]; + maxheight = std::max(maxheight, inputs[i][2]); + maxwidth = std::max(maxwidth, inputs[i][3]); + CV_Assert(inputs[i][0] == batch); + } + outShape.push_back(totalchannels); + + int out_h = divisor ? static_cast(ceil(maxheight / divisor) * divisor) : top_height; + int out_w = divisor ? 
+            int out_h = divisor ? static_cast<int>(ceil(maxheight / divisor) * divisor) : top_height;
+            int out_w = divisor ? static_cast<int>(ceil(maxwidth / divisor) * divisor) : top_width;
+
+            // Layer can specify custom top size which is larger than default
+            if (out_h <= maxheight || out_w <= maxwidth)
+            {
+                out_h = maxheight;
+                out_w = maxwidth;
+            }
+
+            outShape.push_back(out_h);
+            outShape.push_back(out_w);
+        }
+
+        outputs.assign(1, outShape);
+        return false;
+    }
+
+    virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE
+    {
+        LayerParams resizeParams;
+        resizeParams.set("interpolation", "bilinear");
+        resizeParams.set("align_corners", true);
+        resize = ResizeLayer::create(resizeParams);
+    }
+
+    void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
+    {
+        CV_TRACE_FUNCTION();
+        CV_TRACE_ARG_VALUE(name, "name", name.c_str());
+
+        std::vector<Mat> inputs, outputs;
+        inputs_arr.getMatVector(inputs);
+        outputs_arr.getMatVector(outputs);
+
+        const int out_h = outputs[0].size[2];
+        const int out_w = outputs[0].size[3];
+        float* out_data = outputs[0].ptr<float>();
+        std::vector<int> sizes(&outputs[0].size[0], &outputs[0].size[0] + outputs[0].size.dims());
+        for (int i = 0; i < inputs.size() - have_reference; i++)
+        {
+            sizes[1] = inputs[i].size[1];
+            Mat outSlice(sizes, CV_32F, out_data);
+
+            if (out_h == inputs[i].size[2] && out_w == inputs[i].size[3])
+            {
+                inputs[i].copyTo(outSlice);
+            }
+            else
+            {
+                std::vector<Mat> inp_slices, out_slices;
+                inp_slices.push_back(inputs[i]);
+                out_slices.push_back(outSlice);
+
+                resize->finalize(inp_slices, out_slices);
+                resize->forward(inp_slices, out_slices, internals_arr);
+            }
+            out_data += outSlice.total(1);
+        }
+    }
+
+private:
+    int top_height;
+    int top_width;
+    int divisor;
+    bool have_reference;
+    Ptr<ResizeLayer> resize;
+};
+
+Ptr<AccumLayer> AccumLayer::create(const LayerParams& params)
+{
+    return Ptr<AccumLayer>(new AccumLayerImpl(params));
+}
+
+}} // namespace cv::dnn
diff --git a/modules/dnn/src/layers/correlation_layer.cpp b/modules/dnn/src/layers/correlation_layer.cpp
new file mode 100644
index 0000000..cfb3b8e
--- /dev/null
+++ b/modules/dnn/src/layers/correlation_layer.cpp
@@ -0,0 +1,207 @@
+// This file is part of OpenCV project.
+// It is subject to the license terms in the LICENSE file found in the top-level directory
+// of this distribution and at http://opencv.org/license.html.
+
+// Copyright (C) 2020, Intel Corporation, all rights reserved.
+// Third party copyrights are property of their respective owners.
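+
+// Correlation layer from FlowNet/FlowNet2: for every output position it
+// correlates a patch of the first feature map with displaced patches of
+// the second one over a square neighborhood grid of displacements.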
+
+#include "../precomp.hpp"
+#include "layers_common.hpp"
+
+
+namespace cv { namespace dnn {
+
+class CorrelationLayerImpl CV_FINAL : public CorrelationLayer
+{
+public:
+    CorrelationLayerImpl(const LayerParams& params)
+    {
+        setParamsFrom(params);
+        pad = params.get<int>("pad", 0);
+        CV_Assert_N(params.has("kernel_size"), params.has("max_displacement"));
+        max_displacement = params.get<int>("max_displacement");
+        kernel = params.get<int>("kernel_size");
+        if (kernel % 2 == 0)
+            CV_Error(Error::StsNotImplemented, "Odd kernel size required.");
+
+        stride_1 = params.get<int>("stride_1", 1);
+        stride_2 = params.get<int>("stride_2", 1);
+    }
+
+    virtual bool getMemoryShapes(const std::vector<MatShape> &inputs,
+                                 const int requiredOutputs,
+                                 std::vector<MatShape> &outputs,
+                                 std::vector<MatShape> &internals) const CV_OVERRIDE
+    {
+        CV_Assert_N(inputs.size() == 2, inputs[0].size() == 4, inputs[1].size() == 4);
+
+        int padded_height = inputs[0][2] + 2 * pad;
+        int padded_width = inputs[0][3] + 2 * pad;
+
+        int kernel_radius = (kernel - 1) / 2;
+        int border_size = max_displacement + kernel_radius;
+
+        int neighborhood_grid_radius = max_displacement / stride_2;
+        int neighborhood_grid_width = neighborhood_grid_radius * 2 + 1;
+
+        std::vector<int> outShape;
+
+        int num = inputs[0][0];
+        outShape.push_back(num);
+
+        int out_c = neighborhood_grid_width * neighborhood_grid_width;
+        outShape.push_back(out_c);
+
+        int out_h = ceil(static_cast<float>(padded_height - border_size * 2) / stride_1);
+        int out_w = ceil(static_cast<float>(padded_width - border_size * 2) / stride_1);
+        CV_Assert_N(out_h >= 1, out_w >= 1);
+
+        outShape.push_back(out_h);
+        outShape.push_back(out_w);
+        outputs.assign(1, outShape);
+        return false;
+    }
+
+    virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays) CV_OVERRIDE
+    {
+        std::vector<Mat> inputs;
+        inputs_arr.getMatVector(inputs);
+
+        int padded_height = inputs[0].size[2] + 2 * pad;
+        int padded_width = inputs[0].size[3] + 2 * pad;
+
+        int size[] = {inputs[0].size[0], padded_height, padded_width, inputs[0].size[1]};
+        rbot0 = Mat(4, &size[0], CV_32F, float(0));
+        rbot1 = Mat(4, &size[0], CV_32F, float(0));
+    }
+
+    void blobRearrangeKernel2(const Mat& input, Mat& output)
+    {
+        const int num = input.size[0];
+        const int channels = input.size[1];
+        const int height = input.size[2];
+        const int width = input.size[3];
+        const int area = height * width;
+        const int pad_area = (width + 2 * pad) * (height + 2 * pad);
+
+        const float* in = input.ptr<float>();
+        float* out = output.ptr<float>();
+        for (int n = 0; n < num; n++)
+        {
+            for (int ch = 0; ch < channels; ch++)
+            {
+                for (int xy = 0; xy < area; xy++)
+                {
+                    float value = in[(n * channels + ch) * area + xy];
+                    int xpad = (xy % width + pad);
+                    int ypad = (xy / width + pad);
+                    int xypad = ypad * (width + 2 * pad) + xpad;
+                    out[(n * pad_area + xypad) * channels + ch] = value;
+                }
+            }
+        }
+    }
+
+    void correlationKernelSubtraction(const Mat& input0, const Mat& input1, Mat& output, int item)
+    {
+        const int inp_h = input0.size[1];
+        const int inp_w = input0.size[2];
+        const int inp_c = input0.size[3];
+
+        const int out_c = output.size[1];
+        const int out_h = output.size[2];
+        const int out_w = output.size[3];
+
+        int topcount = output.total(1);
+        int neighborhood_grid_radius = max_displacement / stride_2;
+        int neighborhood_grid_width = neighborhood_grid_radius * 2 + 1;
+
+        const float* inp0_data = input0.ptr<float>();
+        const float* inp1_data = input1.ptr<float>();
+        float* out_data = output.ptr<float>();
+        int sumelems = kernel * kernel * inp_c;
+        std::vector<float> patch_data(sumelems, 0);
+        for (int y = 0; y < out_h; y++)
+        {
+            for (int x = 0; x < out_w; x++)
+            {
+                int x1 = x * stride_1 + max_displacement;
+                int y1 = y * stride_1 + max_displacement;
+
+                for (int j = 0; j < kernel; j++)
+                {
+                    for (int i = 0; i < kernel; i++)
+                    {
+                        int ji_off = ((j * kernel) + i) * inp_c;
+                        for (int ch = 0; ch < inp_c; ch++)
+                        {
+                            int idx1 = ((item * inp_h + y1 + j) * inp_w + x1 + i) * inp_c + ch;
+                            int idxPatchData = ji_off + ch;
+                            patch_data[idxPatchData] = inp0_data[idx1];
+                        }
+                    }
+                }
+
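+                // One output channel per displacement (s2o, s2p) of the
+                // neighborhood grid: correlate the cached patch with the
+                // displaced patch of the second blob and normalize by the
+                // patch element count.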
+                for (int out_ch = 0; out_ch < out_c; out_ch++)
+                {
+                    float sum = 0;
+                    int s2o = (out_ch % neighborhood_grid_width - neighborhood_grid_radius) * stride_2;
+                    int s2p = (out_ch / neighborhood_grid_width - neighborhood_grid_radius) * stride_2;
+
+                    int x2 = x1 + s2o;
+                    int y2 = y1 + s2p;
+                    for (int j = 0; j < kernel; j++)
+                    {
+                        for (int i = 0; i < kernel; i++)
+                        {
+                            int ji_off = ((j * kernel) + i) * inp_c;
+                            for (int ch = 0; ch < inp_c; ch++)
+                            {
+                                int idxPatchData = ji_off + ch;
+                                int idx2 = ((item * inp_h + y2 + j) * inp_w + x2 + i) * inp_c + ch;
+                                sum += patch_data[idxPatchData] * inp1_data[idx2];
+                            }
+                        }
+                    }
+                    int index = ((out_ch * out_h + y) * out_w) + x;
+                    out_data[index + item * topcount] = static_cast<float>(sum) / sumelems;
+                }
+            }
+        }
+    }
+
+
+    void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
+    {
+        CV_TRACE_FUNCTION();
+        CV_TRACE_ARG_VALUE(name, "name", name.c_str());
+
+        std::vector<Mat> inputs, outputs, internals;
+        inputs_arr.getMatVector(inputs);
+        outputs_arr.getMatVector(outputs);
+        internals_arr.getMatVector(internals);
+
+        blobRearrangeKernel2(inputs[0], rbot0);
+        blobRearrangeKernel2(inputs[1], rbot1);
+        for (int i = 0; i < inputs[0].size[0]; i++)
+        {
+            correlationKernelSubtraction(rbot0, rbot1, outputs[0], i);
+        }
+    }
+
+private:
+    int pad;
+    int kernel;
+    int max_displacement;
+    int stride_1;
+    int stride_2;
+    Mat rbot0;
+    Mat rbot1;
+};
+
+Ptr<CorrelationLayer> CorrelationLayer::create(const LayerParams& params)
+{
+    return Ptr<CorrelationLayer>(new CorrelationLayerImpl(params));
+}
+
+}} // namespace cv::dnn
diff --git a/modules/dnn/src/layers/flow_warp_layer.cpp b/modules/dnn/src/layers/flow_warp_layer.cpp
new file mode 100644
index 0000000..5d0f9f4
--- /dev/null
+++ b/modules/dnn/src/layers/flow_warp_layer.cpp
@@ -0,0 +1,117 @@
+// This file is part of OpenCV project.
+// It is subject to the license terms in the LICENSE file found in the top-level directory
+// of this distribution and at http://opencv.org/license.html.
+
+// Copyright (C) 2020, Intel Corporation, all rights reserved.
+// Third party copyrights are property of their respective owners.
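+
+// FlowWarp layer: warps the first input backward along the optical-flow
+// field given by the second input, sampling with bilinear interpolation
+// and writing the fill value (zero) for out-of-range samples.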
+
+#include "../precomp.hpp"
+#include "layers_common.hpp"
+
+
+namespace cv { namespace dnn {
+
+class FlowWarpLayerImpl CV_FINAL : public FlowWarpLayer
+{
+public:
+    FlowWarpLayerImpl(const LayerParams& params)
+    {
+        setParamsFrom(params);
+        String fill_string = params.get<String>("FillParameter", "ZERO").toLowerCase();
+        if (fill_string != "zero")
+            CV_Error(Error::StsNotImplemented, "Only zero filling supported.");
+        fill_value = 0;
+    }
+
+    virtual bool getMemoryShapes(const std::vector<MatShape> &inputs,
+                                 const int requiredOutputs,
+                                 std::vector<MatShape> &outputs,
+                                 std::vector<MatShape> &internals) const CV_OVERRIDE
+    {
+        CV_Assert(inputs.size() == 2);
+        CV_Assert_N(inputs[0][0] == inputs[1][0], inputs[1][1] == 2,
+                    inputs[0][2] == inputs[1][2], inputs[0][3] == inputs[1][3]);
+
+        outputs.assign(1, inputs[0]);
+        return false;
+    }
+
+    void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
+    {
+        CV_TRACE_FUNCTION();
+        CV_TRACE_ARG_VALUE(name, "name", name.c_str());
+
+        std::vector<Mat> inputs, outputs;
+        inputs_arr.getMatVector(inputs);
+        outputs_arr.getMatVector(outputs);
+
+        const int out_n = outputs[0].size[0];
+        const int out_c = outputs[0].size[1];
+        const int out_h = outputs[0].size[2];
+        const int out_w = outputs[0].size[3];
+
+        const int area = out_w * out_h;
+        const int total = area * out_c;
+
+        const float* image_data = inputs[0].ptr<float>();
+        const float* flow_data = inputs[1].ptr<float>();
+        float* out_data = outputs[0].ptr<float>();
+
+        for (int n = 0; n < out_n; n++)
+        {
+            int off = total * n;
+            for (int x = 0; x < out_w; x++)
+            {
+                for (int y = 0; y < out_h; y++)
+                {
+                    int idx = 2 * area * n + y * out_w + x;
+                    float fx = flow_data[idx];
+                    float fy = flow_data[idx + area];
+
+                    float x2 = x + fx;
+                    float y2 = y + fy;
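+
+                    // Bilinear sampling at (x2, y2): alpha/beta are the
+                    // fractional offsets from the top-left neighbor, and the
+                    // right/bottom neighbors are clamped to the border.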
+                    if (x2 >= 0 && y2 >= 0 && x2 < out_w && y2 < out_h)
+                    {
+                        int ix2_L = x2;
+                        float alpha = x2 - ix2_L;
+
+                        int iy2_T = y2;
+                        float beta = y2 - iy2_T;
+
+                        int ix2_R = std::min(ix2_L + 1, out_w - 1);
+                        int iy2_B = std::min(iy2_T + 1, out_h - 1);
+
+                        for (int c = 0; c < out_c; c++)
+                        {
+                            float TL = image_data[off + c * area + iy2_T * out_w + ix2_L];
+                            float TR = image_data[off + c * area + iy2_T * out_w + ix2_R];
+                            float BL = image_data[off + c * area + iy2_B * out_w + ix2_L];
+                            float BR = image_data[off + c * area + iy2_B * out_w + ix2_R];
+
+                            out_data[off + c * area + y * out_w + x] = (1 - alpha) * (1 - beta) * TL +
+                                                                       (1 - alpha) * beta * BL +
+                                                                       alpha * (1 - beta) * TR +
+                                                                       alpha * beta * BR;
+                        }
+                    }
+                    else
+                    {
+                        for (int c = 0; c < out_c; c++)
+                            out_data[off + c * area + y * out_w + x] = fill_value;
+                    }
+                }
+            }
+        }
+    }
+
+private:
+    float fill_value;
+};
+
+Ptr<FlowWarpLayer> FlowWarpLayer::create(const LayerParams& params)
+{
+    return Ptr<FlowWarpLayer>(new FlowWarpLayerImpl(params));
+}
+
+}} // namespace cv::dnn
diff --git a/modules/dnn/src/layers/resize_layer.cpp b/modules/dnn/src/layers/resize_layer.cpp
index 09e68ee..3679c9e 100644
--- a/modules/dnn/src/layers/resize_layer.cpp
+++ b/modules/dnn/src/layers/resize_layer.cpp
@@ -45,10 +45,15 @@ public:
                          std::vector<MatShape> &outputs,
                          std::vector<MatShape> &internals) const CV_OVERRIDE
     {
-        CV_Assert_N(inputs.size() == 1, inputs[0].size() == 4);
+        CV_Assert_N(inputs.size() == 1 || inputs.size() == 2, inputs[0].size() == 4);
         outputs.resize(1, inputs[0]);
-        outputs[0][2] = zoomFactorHeight > 0 ? (outputs[0][2] * zoomFactorHeight) : outHeight;
-        outputs[0][3] = zoomFactorWidth > 0 ? (outputs[0][3] * zoomFactorWidth) : outWidth;
+        if (inputs.size() == 1) {
+            outputs[0][2] = zoomFactorHeight > 0 ? (outputs[0][2] * zoomFactorHeight) : outHeight;
+            outputs[0][3] = zoomFactorWidth > 0 ? (outputs[0][3] * zoomFactorWidth) : outWidth;
+        } else {
+            outputs[0][2] = inputs[1][2];
+            outputs[0][3] = inputs[1][3];
+        }
         // We can work in-place (do nothing) if input shape == output shape.
         return (outputs[0][2] == inputs[0][2]) && (outputs[0][3] == inputs[0][3]);
     }
diff --git a/modules/dnn/src/layers/scale_layer.cpp b/modules/dnn/src/layers/scale_layer.cpp
index a53618f..27aefd9 100644
--- a/modules/dnn/src/layers/scale_layer.cpp
+++ b/modules/dnn/src/layers/scale_layer.cpp
@@ -307,5 +307,118 @@ Ptr<Layer> ShiftLayer::create(const LayerParams& params)
     return Ptr<Layer>(new ScaleLayerImpl(scaleParams));
 }
 
+class DataAugmentationLayerImpl CV_FINAL : public DataAugmentationLayer
+{
+public:
+    DataAugmentationLayerImpl(const LayerParams& params)
+    {
+        setParamsFrom(params);
+        recompute_mean = params.get<int>("recompute_mean", 1);
+        CV_CheckGT(recompute_mean, 0, "");
+        mean_per_pixel = params.get<bool>("mean_per_pixel", false);
+    }
+
+    bool getMemoryShapes(const std::vector<MatShape> &inputs,
+                         const int requiredOutputs,
+                         std::vector<MatShape> &outputs,
+                         std::vector<MatShape> &internals) const CV_OVERRIDE
+    {
+        CV_Assert_N(inputs.size() == 1, blobs.size() == 3);
+        CV_Assert_N(blobs[0].total() == 1, blobs[1].total() == total(inputs[0], 1),
+                    blobs[2].total() == inputs[0][1]);
+
+        outputs.assign(1, inputs[0]);
+        return true;
+    }
+
+    void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
+    {
+        CV_TRACE_FUNCTION();
+        CV_TRACE_ARG_VALUE(name, "name", name.c_str());
+
+        std::vector<Mat> inputs, outputs;
+        inputs_arr.getMatVector(inputs);
+        outputs_arr.getMatVector(outputs);
+
+        CV_Assert_N(outputs.size() == 1, blobs.size() == 3, inputs.size() == 1);
+        int num_iter = 0;
+
+        float* inpData = inputs[0].ptr<float>();
+        float* outData = outputs[0].ptr<float>();
+
+        Mat data_mean_cpu = blobs[1].clone();
+        Mat data_mean_per_channel_cpu = blobs[2].clone();
+
+        const int numWeights = data_mean_cpu.total();
+        CV_Assert(numWeights != 0);
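+
+        // Update the mean with the current batch and derive the per-channel
+        // mean from it (REDUCE_SUM over the spatial area divided by area).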
+        ++num_iter;
+        if (num_iter <= recompute_mean)
+        {
+            data_mean_cpu *= (num_iter - 1);
+            const int batch = inputs[0].size[0];
+            float alpha = 1.0 / batch;
+
+            for (int i = 0; i < batch; ++i)
+            {
+                Mat inpSlice(1, numWeights, CV_32F, inpData);
+                inpSlice = alpha * inpSlice;
+
+                add(data_mean_cpu.reshape(1, 1), inpSlice, data_mean_cpu.reshape(1, 1));
+                inpData += numWeights;
+            }
+            data_mean_cpu *= (1.0 / num_iter);
+
+            int newsize[] = {blobs[1].size[1], (int)blobs[1].total(2)};
+            reduce(data_mean_cpu.reshape(1, 2, &newsize[0]), data_mean_per_channel_cpu, 1, REDUCE_SUM, CV_32F);
+
+            int area = blobs[1].total(2);
+            data_mean_per_channel_cpu *= (1.0 / area);
+        }
+
+        MatShape inpShape = shape(inputs[0]);
+
+        inpData = inputs[0].ptr<float>();
+        if (mean_per_pixel)
+        {
+            int numSlices = inputs[0].size[0];
+            for (int i = 0; i < numSlices; ++i)
+            {
+                Mat inpSlice(1, numWeights, CV_32F, inpData);
+                Mat outSlice(1, numWeights, CV_32F, outData);
+
+                add(inpSlice, (-1) * data_mean_cpu, outSlice);
+                inpData += numWeights;
+                outData += numWeights;
+            }
+        }
+        else
+        {
+            int numSlices = inpShape[1];
+            int count = numWeights / numSlices;
+
+            for (int i = 0; i < numSlices; ++i)
+            {
+                Mat inpSlice(1, count, CV_32F, inpData);
+                Mat outSlice(1, count, CV_32F, outData);
+                float coeff = data_mean_per_channel_cpu.reshape(1, 1).at<float>(0, i);
+                outSlice = inpSlice - coeff;
+
+                inpData += count;
+                outData += count;
+            }
+        }
+    }
+
+private:
+    int recompute_mean;
+    bool mean_per_pixel;
+};
+
+Ptr<DataAugmentationLayer> DataAugmentationLayer::create(const LayerParams& params)
+{
+    return Ptr<DataAugmentationLayer>(new DataAugmentationLayerImpl(params));
+}
+
 }  // namespace dnn
 }  // namespace cv
diff --git a/modules/dnn/test/test_layers.cpp b/modules/dnn/test/test_layers.cpp
index c31b9f3..88f44d3 100644
--- a/modules/dnn/test/test_layers.cpp
+++ b/modules/dnn/test/test_layers.cpp
@@ -97,29 +97,68 @@ class Test_Caffe_layers : public DNNTestLayer
 {
 public:
     void testLayerUsingCaffeModels(const String& basename, bool useCaffeModel = false,
-                                   bool useCommonInputBlob = true, double l1 = 0.0,
-                                   double lInf = 0.0)
+                                   bool useCommonInputBlob = true, double l1 = 0.0, double lInf = 0.0,
+                                   int numInps = 1, int numOuts = 1)
     {
+        CV_Assert_N(numInps >= 1, numInps <= 10, numOuts >= 1, numOuts <= 10);
         String prototxt = _tf(basename + ".prototxt");
         String caffemodel = _tf(basename + ".caffemodel");
 
-        String inpfile = (useCommonInputBlob) ? _tf("blob.npy") : _tf(basename + ".input.npy");
-        String outfile = _tf(basename + ".npy");
+        std::vector<Mat> inps, refs, outs;
 
-        Mat inp = blobFromNPY(inpfile);
-        Mat ref = blobFromNPY(outfile);
-        checkBackend(&inp, &ref);
+        if (numInps > 1)
+        {
+            for (int i = 0; i < numInps; i++)
+            {
+                String inpfile = _tf(basename + ".input_" + (i + '0') + ".npy");
+                inps.push_back(blobFromNPY(inpfile));
+            }
+        }
+        else
+        {
+            String inpfile = (useCommonInputBlob) ? _tf("blob.npy") : _tf(basename + ".input.npy");
+            inps.push_back(blobFromNPY(inpfile));
+        }
+
+        if (numOuts > 1)
+        {
+            for (int i = 0; i < numOuts; i++)
+            {
+                String outfile = _tf(basename + "_" + (i + '0') + ".npy");
+                refs.push_back(blobFromNPY(outfile));
+            }
+        }
+        else
+        {
+            String outfile = _tf(basename + ".npy");
+            refs.push_back(blobFromNPY(outfile));
+        }
 
         Net net = readNetFromCaffe(prototxt, (useCaffeModel) ? caffemodel : String());
         ASSERT_FALSE(net.empty());
+        checkBackend(&inps[0], &refs[0]);
 
         net.setPreferableBackend(backend);
         net.setPreferableTarget(target);
 
-        net.setInput(inp, "input");
-        Mat out = net.forward("output");
+        String inp_name = "input";
+        if (numInps > 1)
+        {
+            for (int i = 0; i < numInps; i++)
+            {
+                net.setInput(inps[i], inp_name + "_" + (i + '0'));
+            }
+        }
+        else
+        {
+            net.setInput(inps.back(), inp_name);
+        }
 
-        normAssert(ref, out, "", l1 ? l1 : default_l1, lInf ? lInf : default_lInf);
+        net.forward(outs);
+        for (int i = 0; i < refs.size(); i++)
+        {
+            normAssert(refs[i], outs[i], "", l1 ? l1 : default_l1, lInf ? lInf : default_lInf);
+        }
     }
 };
@@ -568,6 +607,58 @@ TEST_F(Layer_RNN_Test, get_set_test)
     EXPECT_EQ(shape(outputs[1]), shape(nT, nS, nH));
 }
 
+TEST_P(Test_Caffe_layers, Accum)
+{
+    if (backend == DNN_BACKEND_OPENCV && target != DNN_TARGET_CPU)
+        applyTestTag(CV_TEST_TAG_DNN_SKIP_OPENCL, CV_TEST_TAG_DNN_SKIP_OPENCL_FP16);
+
+    testLayerUsingCaffeModels("accum", false, false, 0.0, 0.0, 2);
+    testLayerUsingCaffeModels("accum_ref", false, false, 0.0, 0.0, 2);
+}
+
+TEST_P(Test_Caffe_layers, FlowWarp)
+{
+    if (backend == DNN_BACKEND_OPENCV && target == DNN_TARGET_OPENCL_FP16)
+        applyTestTag(CV_TEST_TAG_DNN_SKIP_OPENCL_FP16);
+
+    testLayerUsingCaffeModels("flow_warp", false, false, 0.0, 0.0, 2);
+}
+
+TEST_P(Test_Caffe_layers, ChannelNorm)
+{
+    if (backend == DNN_BACKEND_OPENCV && target == DNN_TARGET_OPENCL_FP16)
+        applyTestTag(CV_TEST_TAG_DNN_SKIP_OPENCL_FP16);
+    testLayerUsingCaffeModels("channel_norm", false, false);
+}
+
+TEST_P(Test_Caffe_layers, DataAugmentation)
+{
+    if (backend == DNN_BACKEND_OPENCV && target == DNN_TARGET_OPENCL_FP16)
+        applyTestTag(CV_TEST_TAG_DNN_SKIP_OPENCL_FP16);
+    testLayerUsingCaffeModels("data_augmentation", true, false);
+}
+
+TEST_P(Test_Caffe_layers, Resample)
+{
+    if (backend != DNN_BACKEND_OPENCV)
+        applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER, CV_TEST_TAG_DNN_SKIP_IE_NGRAPH);
+    testLayerUsingCaffeModels("nearest_2inps", false, false, 0.0, 0.0, 2);
+    testLayerUsingCaffeModels("nearest", false, false);
+}
+
+TEST_P(Test_Caffe_layers, Correlation)
+{
+    if (backend == DNN_BACKEND_OPENCV && target == DNN_TARGET_OPENCL_FP16)
+        applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NGRAPH, CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER,
+                     CV_TEST_TAG_DNN_SKIP_OPENCL, CV_TEST_TAG_DNN_SKIP_OPENCL_FP16);
+    testLayerUsingCaffeModels("correlation", false, false, 0.0, 0.0, 2);
+}
+
+TEST_P(Test_Caffe_layers, Convolution2Inputs)
+{
+    testLayerUsingCaffeModels("conv_2_inps", true, false, 0.0, 0.0, 2);
+}
+
 TEST_P(Test_Caffe_layers, ROIPooling_Accuracy)
 {
     Net net = readNetFromCaffe(_tf("net_roi_pooling.prototxt"));
diff --git a/samples/dnn/optical_flow.py b/samples/dnn/optical_flow.py
new file mode 100644
index 0000000..5d0d831
--- /dev/null
+++ b/samples/dnn/optical_flow.py
@@ -0,0 +1,85 @@
+#!/usr/bin/env python
+'''
+This sample uses the FlowNet v2 model to calculate optical flow.
+Original paper: https://arxiv.org/abs/1612.01925.
+Original repo: https://github.com/lmb-freiburg/flownet2.
+
+Download the converted .caffemodel model from https://drive.google.com/open?id=16qvE9VNmU39NttpZwZs81Ga8VYQJDaWZ
+and .prototxt from https://drive.google.com/open?id=19bo6SWU2p8ZKvjXqMKiCPdK8mghwDy9b.
+Alternatively, download the original model from https://lmb.informatik.uni-freiburg.de/resources/binaries/flownet2/flownet2-models.tar.gz,
+convert the .h5 model to .caffemodel and modify the original .prototxt using the .prototxt from the link above.
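+
+Example (model files as downloaded above; the input video is any file readable by VideoCapture):
+    python optical_flow.py -i <input video> -p FlowNet2_deploy.prototxt -m FlowNet2_weights.caffemodel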
+'''
+
+import argparse
+import os.path
+import numpy as np
+import cv2 as cv
+
+
+class OpticalFlow(object):
+    def __init__(self, proto, model, height, width):
+        self.net = cv.dnn.readNet(proto, model)
+        self.net.setPreferableBackend(cv.dnn.DNN_BACKEND_OPENCV)
+        self.height = height
+        self.width = width
+
+    def compute_flow(self, first_img, second_img):
+        inp0 = cv.dnn.blobFromImage(first_img, size=(self.width, self.height))
+        inp1 = cv.dnn.blobFromImage(second_img, size=(self.width, self.height))
+        self.net.setInput(inp0, "img0")
+        self.net.setInput(inp1, "img1")
+        flow = self.net.forward()
+        output = self.motion_to_color(flow)
+        return output
+
+    def motion_to_color(self, flow):
+        arr = np.arange(0, 255, dtype=np.uint8)
+        colormap = cv.applyColorMap(arr, cv.COLORMAP_HSV)
+        colormap = colormap.squeeze(1)
+
+        flow = flow.squeeze(0)
+        fx, fy = flow[0, ...], flow[1, ...]
+        rad = np.sqrt(fx**2 + fy**2)
+        maxrad = rad.max() if rad.max() != 0 else 1
+
+        ncols = arr.size
+        rad = rad[..., np.newaxis] / maxrad
+        a = np.arctan2(-fy / maxrad, -fx / maxrad) / np.pi
+        fk = (a + 1) / 2.0 * (ncols - 1)
+        k0 = fk.astype(int)
+        k1 = (k0 + 1) % ncols
+        f = fk[..., np.newaxis] - k0[..., np.newaxis]
+
+        col0 = colormap[k0] / 255.0
+        col1 = colormap[k1] / 255.0
+        col = (1 - f) * col0 + f * col1
+        col = np.where(rad <= 1, 1 - rad * (1 - col), col * 0.75)
+        output = (255.0 * col).astype(np.uint8)
+        return output
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Use this script to calculate optical flow using FlowNetv2',
+                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument('-input', '-i', help='Path to input video file. Skip this argument to capture frames from a camera.')
+    parser.add_argument('--height', default=320, type=int, help='Input height')
+    parser.add_argument('--width', default=448, type=int, help='Input width')
+    parser.add_argument('--proto', '-p', default='FlowNet2_deploy.prototxt', help='Path to prototxt.')
+    parser.add_argument('--model', '-m', default='FlowNet2_weights.caffemodel', help='Path to caffemodel.')
+    args, _ = parser.parse_known_args()
+
+    if not os.path.isfile(args.model) or not os.path.isfile(args.proto):
+        raise OSError("Prototxt or caffemodel does not exist")
+
+    winName = 'Optical flow calculation in OpenCV'
+    cv.namedWindow(winName, cv.WINDOW_NORMAL)
+    cap = cv.VideoCapture(args.input if args.input else 0)
+    hasFrame, first_frame = cap.read()
+    opt_flow = OpticalFlow(args.proto, args.model, args.height, args.width)
+    while cv.waitKey(1) < 0:
+        hasFrame, second_frame = cap.read()
+        if not hasFrame:
+            break
+        flow = opt_flow.compute_flow(first_frame, second_frame)
+        first_frame = second_frame
+        cv.imshow(winName, flow)
-- 
2.7.4