From ecc34dc5219bf70cf9ede89cf7bac8f895938da1 Mon Sep 17 00:00:00 2001 From: AlexeyAB Date: Sun, 24 Sep 2017 23:34:08 +0300 Subject: [PATCH] Added DNN Darknet Yolo v2 for object detection --- modules/dnn/include/opencv2/dnn/all_layers.hpp | 12 + modules/dnn/include/opencv2/dnn/dnn.hpp | 8 + modules/dnn/src/darknet/darknet_importer.cpp | 195 ++++++++ modules/dnn/src/darknet/darknet_io.cpp | 624 +++++++++++++++++++++++++ modules/dnn/src/darknet/darknet_io.hpp | 116 +++++ modules/dnn/src/init.cpp | 2 + modules/dnn/src/layers/region_layer.cpp | 331 +++++++++++++ modules/dnn/src/layers/reorg_layer.cpp | 140 ++++++ modules/dnn/test/test_darknet_importer.cpp | 186 ++++++++ modules/dnn/test/test_layers.cpp | 34 +- samples/dnn/yolo_object_detection.cpp | 117 +++++ 11 files changed, 1764 insertions(+), 1 deletion(-) create mode 100644 modules/dnn/src/darknet/darknet_importer.cpp create mode 100644 modules/dnn/src/darknet/darknet_io.cpp create mode 100644 modules/dnn/src/darknet/darknet_io.hpp create mode 100644 modules/dnn/src/layers/region_layer.cpp create mode 100644 modules/dnn/src/layers/reorg_layer.cpp create mode 100644 modules/dnn/test/test_darknet_importer.cpp create mode 100644 samples/dnn/yolo_object_detection.cpp diff --git a/modules/dnn/include/opencv2/dnn/all_layers.hpp b/modules/dnn/include/opencv2/dnn/all_layers.hpp index cf47c70..4c08fb6 100644 --- a/modules/dnn/include/opencv2/dnn/all_layers.hpp +++ b/modules/dnn/include/opencv2/dnn/all_layers.hpp @@ -473,6 +473,18 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN static Ptr create(const LayerParams& params); }; + class CV_EXPORTS ReorgLayer : public Layer + { + public: + static Ptr create(const LayerParams& params); + }; + + class CV_EXPORTS RegionLayer : public Layer + { + public: + static Ptr create(const LayerParams& params); + }; + class CV_EXPORTS DetectionOutputLayer : public Layer { public: diff --git a/modules/dnn/include/opencv2/dnn/dnn.hpp b/modules/dnn/include/opencv2/dnn/dnn.hpp index bd79669..6c19a1d 100644 --- a/modules/dnn/include/opencv2/dnn/dnn.hpp +++ b/modules/dnn/include/opencv2/dnn/dnn.hpp @@ -611,6 +611,14 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN virtual ~Importer(); }; + /** @brief Reads a network model stored in Darknet model files. + * @param cfgFile path to the .cfg file with text description of the network architecture. + * @param darknetModel path to the .weights file with learned network. + * @returns Network object that ready to do forward, throw an exception in failure cases. + * @details This is shortcut consisting from DarknetImporter and Net::populateNet calls. + */ + CV_EXPORTS_W Net readNetFromDarknet(const String &cfgFile, const String &darknetModel = String()); + /** * @deprecated Use @ref readNetFromCaffe instead. * @brief Creates the importer of Caffe framework network. diff --git a/modules/dnn/src/darknet/darknet_importer.cpp b/modules/dnn/src/darknet/darknet_importer.cpp new file mode 100644 index 0000000..18a7bc9 --- /dev/null +++ b/modules/dnn/src/darknet/darknet_importer.cpp @@ -0,0 +1,195 @@ +/*M/////////////////////////////////////////////////////////////////////////////////////// +// +// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. +// +// By downloading, copying, installing or using the software you agree to this license. +// If you do not agree to this license, do not download, install, +// copy or use the software. +// +// +// License Agreement +// For Open Source Computer Vision Library +// (3-clause BSD License) +// +// Copyright (C) 2017, Intel Corporation, all rights reserved. +// Third party copyrights are property of their respective owners. +// +// Redistribution and use in source and binary forms, with or without modification, +// are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// * Neither the names of the copyright holders nor the names of the contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// This software is provided by the copyright holders and contributors "as is" and +// any express or implied warranties, including, but not limited to, the implied +// warranties of merchantability and fitness for a particular purpose are disclaimed. +// In no event shall copyright holders or contributors be liable for any direct, +// indirect, incidental, special, exemplary, or consequential damages +// (including, but not limited to, procurement of substitute goods or services; +// loss of use, data, or profits; or business interruption) however caused +// and on any theory of liability, whether in contract, strict liability, +// or tort (including negligence or otherwise) arising in any way out of +// the use of this software, even if advised of the possibility of such damage. +// +//M*/ + +#include "../precomp.hpp" + +#include +#include +#include +#include + +#include "darknet_io.hpp" + + +namespace cv { +namespace dnn { +CV__DNN_EXPERIMENTAL_NS_BEGIN + +namespace +{ + +class DarknetImporter : public Importer +{ + darknet::NetParameter net; + +public: + + DarknetImporter() {} + + DarknetImporter(const char *cfgFile, const char *darknetModel) + { + CV_TRACE_FUNCTION(); + + ReadNetParamsFromCfgFileOrDie(cfgFile, &net); + + if (darknetModel && darknetModel[0]) + ReadNetParamsFromBinaryFileOrDie(darknetModel, &net); + } + + struct BlobNote + { + BlobNote(const std::string &_name, int _layerId, int _outNum) : + name(_name), layerId(_layerId), outNum(_outNum) {} + + std::string name; + int layerId, outNum; + }; + + std::vector addedBlobs; + std::map layerCounter; + + void populateNet(Net dstNet) + { + CV_TRACE_FUNCTION(); + + int layersSize = net.layer_size(); + layerCounter.clear(); + addedBlobs.clear(); + addedBlobs.reserve(layersSize + 1); + + //setup input layer names + { + std::vector netInputs(net.input_size()); + for (int inNum = 0; inNum < net.input_size(); inNum++) + { + addedBlobs.push_back(BlobNote(net.input(inNum), 0, inNum)); + netInputs[inNum] = net.input(inNum); + } + dstNet.setInputsNames(netInputs); + } + + for (int li = 0; li < layersSize; li++) + { + const darknet::LayerParameter &layer = net.layer(li); + String name = layer.name(); + String type = layer.type(); + LayerParams layerParams = layer.getLayerParams(); + + int repetitions = layerCounter[name]++; + if (repetitions) + name += cv::format("_%d", repetitions); + + int id = dstNet.addLayer(name, type, layerParams); + + // iterate many bottoms layers (for example for: route -1, -4) + for (int inNum = 0; inNum < layer.bottom_size(); inNum++) + addInput(layer.bottom(inNum), id, inNum, dstNet, layer.name()); + + for (int outNum = 0; outNum < layer.top_size(); outNum++) + addOutput(layer, id, outNum); + } + + addedBlobs.clear(); + } + + void addOutput(const darknet::LayerParameter &layer, int layerId, int outNum) + { + const std::string &name = layer.top(outNum); + + bool haveDups = false; + for (int idx = (int)addedBlobs.size() - 1; idx >= 0; idx--) + { + if (addedBlobs[idx].name == name) + { + haveDups = true; + break; + } + } + + if (haveDups) + { + bool isInplace = layer.bottom_size() > outNum && layer.bottom(outNum) == name; + if (!isInplace) + CV_Error(Error::StsBadArg, "Duplicate blobs produced by multiple sources"); + } + + addedBlobs.push_back(BlobNote(name, layerId, outNum)); + } + + void addInput(const std::string &name, int layerId, int inNum, Net &dstNet, std::string nn) + { + int idx; + for (idx = (int)addedBlobs.size() - 1; idx >= 0; idx--) + { + if (addedBlobs[idx].name == name) + break; + } + + if (idx < 0) + { + CV_Error(Error::StsObjectNotFound, "Can't find output blob \"" + name + "\""); + return; + } + + dstNet.connect(addedBlobs[idx].layerId, addedBlobs[idx].outNum, layerId, inNum); + } + + ~DarknetImporter() + { + + } + +}; + +} + +Net readNetFromDarknet(const String &cfgFile, const String &darknetModel /*= String()*/) +{ + DarknetImporter darknetImporter(cfgFile.c_str(), darknetModel.c_str()); + Net net; + darknetImporter.populateNet(net); + return net; +} + +CV__DNN_EXPERIMENTAL_NS_END +}} // namespace diff --git a/modules/dnn/src/darknet/darknet_io.cpp b/modules/dnn/src/darknet/darknet_io.cpp new file mode 100644 index 0000000..8f705ed --- /dev/null +++ b/modules/dnn/src/darknet/darknet_io.cpp @@ -0,0 +1,624 @@ +/*M/////////////////////////////////////////////////////////////////////////////////////// +// +// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. +// +// By downloading, copying, installing or using the software you agree to this license. +// If you do not agree to this license, do not download, install, +// copy or use the software. +// +// +// License Agreement +// For Open Source Computer Vision Library +// (3-clause BSD License) +// +// Copyright (C) 2017, Intel Corporation, all rights reserved. +// Third party copyrights are property of their respective owners. +// +// Redistribution and use in source and binary forms, with or without modification, +// are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// * Neither the names of the copyright holders nor the names of the contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// This software is provided by the copyright holders and contributors "as is" and +// any express or implied warranties, including, but not limited to, the implied +// warranties of merchantability and fitness for a particular purpose are disclaimed. +// In no event shall copyright holders or contributors be liable for any direct, +// indirect, incidental, special, exemplary, or consequential damages +// (including, but not limited to, procurement of substitute goods or services; +// loss of use, data, or profits; or business interruption) however caused +// and on any theory of liability, whether in contract, strict liability, +// or tort (including negligence or otherwise) arising in any way out of +// the use of this software, even if advised of the possibility of such damage. +// +//M*/ + +/*M/////////////////////////////////////////////////////////////////////////////////////// +//MIT License +// +//Copyright (c) 2017 Joseph Redmon +// +//Permission is hereby granted, free of charge, to any person obtaining a copy +//of this software and associated documentation files (the "Software"), to deal +//in the Software without restriction, including without limitation the rights +//to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +//copies of the Software, and to permit persons to whom the Software is +//furnished to do so, subject to the following conditions: +// +//The above copyright notice and this permission notice shall be included in all +//copies or substantial portions of the Software. +// +//THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +//IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +//FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +//AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +//LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +//OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +//SOFTWARE. +// +//M*/ + +#include + +#include +#include +#include + +#include "darknet_io.hpp" + +namespace cv { + namespace dnn { + namespace darknet { + + template + T getParam(const std::map ¶ms, const std::string param_name, T init_val) + { + std::map::const_iterator it = params.find(param_name); + if (it != params.end()) { + std::stringstream ss(it->second); + ss >> init_val; + } + return init_val; + } + + class setLayersParams { + + NetParameter *net; + int layer_id; + std::string last_layer; + std::vector fused_layer_names; + + public: + setLayersParams(NetParameter *_net, std::string _first_layer = "data") : + net(_net), layer_id(0), last_layer(_first_layer) + {} + + void setLayerBlobs(int i, std::vector blobs) + { + cv::dnn::experimental_dnn_v1::LayerParams ¶ms = net->layers[i].layerParams; + params.blobs = blobs; + } + + cv::dnn::experimental_dnn_v1::LayerParams getParamConvolution(int kernel, int pad, + int stride, int filters_num) + { + cv::dnn::experimental_dnn_v1::LayerParams params; + params.name = "Convolution-name"; + params.type = "Convolution"; + + params.set("kernel_size", kernel); + params.set("pad", pad); + params.set("stride", stride); + + params.set("bias_term", false); // true only if(BatchNorm == false) + params.set("num_output", filters_num); + + return params; + } + + + void setConvolution(int kernel, int pad, int stride, + int filters_num, int channels_num, int use_batch_normalize, int use_relu) + { + cv::dnn::experimental_dnn_v1::LayerParams conv_param = + getParamConvolution(kernel, pad, stride, filters_num); + + darknet::LayerParameter lp; + std::string layer_name = cv::format("conv_%d", layer_id); + + // use BIAS in any case + if (!use_batch_normalize) { + conv_param.set("bias_term", true); + } + + lp.layer_name = layer_name; + lp.layer_type = conv_param.type; + lp.layerParams = conv_param; + lp.bottom_indexes.push_back(last_layer); + last_layer = layer_name; + net->layers.push_back(lp); + + if (use_batch_normalize) + { + cv::dnn::experimental_dnn_v1::LayerParams bn_param; + + bn_param.name = "BatchNorm-name"; + bn_param.type = "BatchNorm"; + bn_param.set("has_weight", true); + bn_param.set("has_bias", true); + bn_param.set("eps", 1E-6); // .000001f in Darknet Yolo + + darknet::LayerParameter lp; + std::string layer_name = cv::format("bn_%d", layer_id); + lp.layer_name = layer_name; + lp.layer_type = bn_param.type; + lp.layerParams = bn_param; + lp.bottom_indexes.push_back(last_layer); + last_layer = layer_name; + net->layers.push_back(lp); + } + + if (use_relu) + { + cv::dnn::experimental_dnn_v1::LayerParams activation_param; + activation_param.set("negative_slope", 0.1f); + activation_param.name = "ReLU-name"; + activation_param.type = "ReLU"; + + darknet::LayerParameter lp; + std::string layer_name = cv::format("relu_%d", layer_id); + lp.layer_name = layer_name; + lp.layer_type = activation_param.type; + lp.layerParams = activation_param; + lp.bottom_indexes.push_back(last_layer); + last_layer = layer_name; + net->layers.push_back(lp); + } + + layer_id++; + fused_layer_names.push_back(last_layer); + } + + void setMaxpool(size_t kernel, size_t pad, size_t stride) + { + cv::dnn::experimental_dnn_v1::LayerParams maxpool_param; + maxpool_param.set("pool", "max"); + maxpool_param.set("kernel_size", kernel); + maxpool_param.set("pad", pad); + maxpool_param.set("stride", stride); + maxpool_param.set("pad_mode", "SAME"); + maxpool_param.name = "Pooling-name"; + maxpool_param.type = "Pooling"; + darknet::LayerParameter lp; + + std::string layer_name = cv::format("pool_%d", layer_id); + lp.layer_name = layer_name; + lp.layer_type = maxpool_param.type; + lp.layerParams = maxpool_param; + lp.bottom_indexes.push_back(last_layer); + last_layer = layer_name; + net->layers.push_back(lp); + layer_id++; + fused_layer_names.push_back(last_layer); + } + + void setConcat(int number_of_inputs, int *input_indexes) + { + cv::dnn::experimental_dnn_v1::LayerParams concat_param; + concat_param.name = "Concat-name"; + concat_param.type = "Concat"; + concat_param.set("axis", 1); // channels are in axis = 1 + + darknet::LayerParameter lp; + + std::string layer_name = cv::format("concat_%d", layer_id); + lp.layer_name = layer_name; + lp.layer_type = concat_param.type; + lp.layerParams = concat_param; + for (int i = 0; i < number_of_inputs; ++i) + lp.bottom_indexes.push_back(fused_layer_names.at(input_indexes[i])); + + last_layer = layer_name; + net->layers.push_back(lp); + + layer_id++; + fused_layer_names.push_back(last_layer); + } + + void setIdentity(int bottom_index) + { + cv::dnn::experimental_dnn_v1::LayerParams identity_param; + identity_param.name = "Identity-name"; + identity_param.type = "Identity"; + + darknet::LayerParameter lp; + + std::string layer_name = cv::format("identity_%d", layer_id); + lp.layer_name = layer_name; + lp.layer_type = identity_param.type; + lp.layerParams = identity_param; + lp.bottom_indexes.push_back(fused_layer_names.at(bottom_index)); + + last_layer = layer_name; + net->layers.push_back(lp); + + layer_id++; + fused_layer_names.push_back(last_layer); + } + + void setReorg(int stride) + { + cv::dnn::experimental_dnn_v1::LayerParams reorg_params; + reorg_params.name = "Reorg-name"; + reorg_params.type = "Reorg"; + reorg_params.set("reorg_stride", stride); + + darknet::LayerParameter lp; + std::string layer_name = cv::format("reorg_%d", layer_id); + lp.layer_name = layer_name; + lp.layer_type = reorg_params.type; + lp.layerParams = reorg_params; + lp.bottom_indexes.push_back(last_layer); + last_layer = layer_name; + + net->layers.push_back(lp); + + layer_id++; + fused_layer_names.push_back(last_layer); + } + + void setPermute() + { + cv::dnn::experimental_dnn_v1::LayerParams permute_params; + permute_params.name = "Permute-name"; + permute_params.type = "Permute"; + int permute[] = { 0, 2, 3, 1 }; + cv::dnn::DictValue paramOrder = cv::dnn::DictValue::arrayInt(permute, 4); + + permute_params.set("order", paramOrder); + + darknet::LayerParameter lp; + std::string layer_name = cv::format("premute_%d", layer_id); + lp.layer_name = layer_name; + lp.layer_type = permute_params.type; + lp.layerParams = permute_params; + lp.bottom_indexes.push_back(last_layer); + last_layer = layer_name; + net->layers.push_back(lp); + + layer_id++; + fused_layer_names.push_back(last_layer); + } + + void setRegion(float thresh, int coords, int classes, int anchors, int classfix, int softmax, int softmax_tree, float *biasData) + { + cv::dnn::experimental_dnn_v1::LayerParams region_param; + region_param.name = "Region-name"; + region_param.type = "Region"; + + region_param.set("thresh", thresh); + region_param.set("coords", coords); + region_param.set("classes", classes); + region_param.set("anchors", anchors); + region_param.set("classfix", classfix); + region_param.set("softmax_tree", softmax_tree); + region_param.set("softmax", softmax); + + cv::Mat biasData_mat = cv::Mat(1, anchors * 2, CV_32F, biasData).clone(); + region_param.blobs.push_back(biasData_mat); + + darknet::LayerParameter lp; + std::string layer_name = "detection_out"; + lp.layer_name = layer_name; + lp.layer_type = region_param.type; + lp.layerParams = region_param; + lp.bottom_indexes.push_back(last_layer); + last_layer = layer_name; + net->layers.push_back(lp); + + layer_id++; + fused_layer_names.push_back(last_layer); + } + }; + + std::string escapeString(const std::string &src) + { + std::string dst; + for (size_t i = 0; i < src.size(); ++i) + if (src[i] > ' ' && src[i] <= 'z') + dst += src[i]; + return dst; + } + + template + std::vector getNumbers(const std::string &src) + { + std::vector dst; + std::stringstream ss(src); + + for (std::string str; std::getline(ss, str, ',');) { + std::stringstream line(str); + T val; + line >> val; + dst.push_back(val); + } + return dst; + } + + bool ReadDarknetFromCfgFile(const char *cfgFile, NetParameter *net) + { + std::ifstream ifile; + ifile.open(cfgFile); + if (ifile.is_open()) + { + bool read_net = false; + int layers_counter = -1; + for (std::string line; std::getline(ifile, line);) { + line = escapeString(line); + if (line.empty()) continue; + switch (line[0]) { + case '\0': break; + case '#': break; + case ';': break; + case '[': + if (line == "[net]") { + read_net = true; + } + else { + // read section + read_net = false; + ++layers_counter; + const size_t layer_type_size = line.find("]") - 1; + CV_Assert(layer_type_size < line.size()); + std::string layer_type = line.substr(1, layer_type_size); + net->layers_cfg[layers_counter]["type"] = layer_type; + } + break; + default: + // read entry + const size_t separator_index = line.find('='); + CV_Assert(separator_index < line.size()); + if (separator_index != std::string::npos) { + std::string name = line.substr(0, separator_index); + std::string value = line.substr(separator_index + 1, line.size() - (separator_index + 1)); + name = escapeString(name); + value = escapeString(value); + if (name.empty() || value.empty()) continue; + if (read_net) + net->net_cfg[name] = value; + else + net->layers_cfg[layers_counter][name] = value; + } + } + } + + std::string anchors = net->layers_cfg[net->layers_cfg.size() - 1]["anchors"]; + std::vector vec = getNumbers(anchors); + std::map &net_params = net->net_cfg; + net->width = getParam(net_params, "width", 416); + net->height = getParam(net_params, "height", 416); + net->channels = getParam(net_params, "channels", 3); + CV_Assert(net->width > 0 && net->height > 0 && net->channels > 0); + } + else + return false; + + int current_channels = net->channels; + net->out_channels_vec.resize(net->layers_cfg.size()); + + int layers_counter = -1; + + setLayersParams setParams(net); + + typedef std::map >::iterator it_type; + for (it_type i = net->layers_cfg.begin(); i != net->layers_cfg.end(); ++i) { + ++layers_counter; + std::map &layer_params = i->second; + std::string layer_type = layer_params["type"]; + + if (layer_type == "convolutional") + { + int kernel_size = getParam(layer_params, "size", -1); + int pad = getParam(layer_params, "pad", 0); + int stride = getParam(layer_params, "stride", 1); + int filters = getParam(layer_params, "filters", -1); + std::string activation = getParam(layer_params, "activation", "linear"); + bool batch_normalize = getParam(layer_params, "batch_normalize", 0) == 1; + if(activation != "linear" && activation != "leaky") + CV_Error(cv::Error::StsParseError, "Unsupported activation: " + activation); + int flipped = getParam(layer_params, "flipped", 0); + if (flipped == 1) + CV_Error(cv::Error::StsNotImplemented, "Transpose the convolutional weights is not implemented"); + + // correct the strange value of pad=1 for kernel_size=1 in the Darknet cfg-file + if (kernel_size < 3) pad = 0; + + CV_Assert(kernel_size > 0 && filters > 0); + CV_Assert(current_channels > 0); + + setParams.setConvolution(kernel_size, pad, stride, filters, current_channels, + batch_normalize, activation == "leaky"); + + current_channels = filters; + } + else if (layer_type == "maxpool") + { + int kernel_size = getParam(layer_params, "size", 2); + int stride = getParam(layer_params, "stride", 2); + int pad = getParam(layer_params, "pad", 0); + setParams.setMaxpool(kernel_size, pad, stride); + } + else if (layer_type == "route") + { + std::string bottom_layers = getParam(layer_params, "layers", ""); + CV_Assert(!bottom_layers.empty()); + std::vector layers_vec = getNumbers(bottom_layers); + + current_channels = 0; + for (size_t k = 0; k < layers_vec.size(); ++k) { + layers_vec[k] += layers_counter; + current_channels += net->out_channels_vec[layers_vec[k]]; + } + + if (layers_vec.size() == 1) + setParams.setIdentity(layers_vec.at(0)); + else + setParams.setConcat(layers_vec.size(), layers_vec.data()); + } + else if (layer_type == "reorg") + { + int stride = getParam(layer_params, "stride", 2); + current_channels = current_channels * (stride*stride); + + setParams.setReorg(stride); + } + else if (layer_type == "region") + { + float thresh = 0.001; // in the original Darknet is equal to the detection threshold set by the user + int coords = getParam(layer_params, "coords", 4); + int classes = getParam(layer_params, "classes", -1); + int num_of_anchors = getParam(layer_params, "num", -1); + int classfix = getParam(layer_params, "classfix", 0); + bool softmax = (getParam(layer_params, "softmax", 0) == 1); + bool softmax_tree = (getParam(layer_params, "tree", "").size() > 0); + + std::string anchors_values = getParam(layer_params, "anchors", std::string()); + CV_Assert(!anchors_values.empty()); + std::vector anchors_vec = getNumbers(anchors_values); + + CV_Assert(classes > 0 && num_of_anchors > 0 && (num_of_anchors * 2) == anchors_vec.size()); + + setParams.setPermute(); + setParams.setRegion(thresh, coords, classes, num_of_anchors, classfix, softmax, softmax_tree, anchors_vec.data()); + } + else { + CV_Error(cv::Error::StsParseError, "Unknown layer type: " + layer_type); + } + net->out_channels_vec[layers_counter] = current_channels; + } + + return true; + } + + + bool ReadDarknetFromWeightsFile(const char *darknetModel, NetParameter *net) + { + std::ifstream ifile; + ifile.open(darknetModel, std::ios::binary); + CV_Assert(ifile.is_open()); + + int32_t major_ver, minor_ver, revision; + ifile.read(reinterpret_cast(&major_ver), sizeof(int32_t)); + ifile.read(reinterpret_cast(&minor_ver), sizeof(int32_t)); + ifile.read(reinterpret_cast(&revision), sizeof(int32_t)); + + uint64_t seen; + if ((major_ver * 10 + minor_ver) >= 2) { + ifile.read(reinterpret_cast(&seen), sizeof(uint64_t)); + } + else { + int32_t iseen = 0; + ifile.read(reinterpret_cast(&iseen), sizeof(int32_t)); + seen = iseen; + } + bool transpose = (major_ver > 1000) || (minor_ver > 1000); + if(transpose) + CV_Error(cv::Error::StsNotImplemented, "Transpose the weights (except for convolutional) is not implemented"); + + int current_channels = net->channels; + int cv_layers_counter = -1; + int darknet_layers_counter = -1; + + setLayersParams setParams(net); + + typedef std::map >::iterator it_type; + for (it_type i = net->layers_cfg.begin(); i != net->layers_cfg.end(); ++i) { + ++darknet_layers_counter; + ++cv_layers_counter; + std::map &layer_params = i->second; + std::string layer_type = layer_params["type"]; + + if (layer_type == "convolutional") + { + int kernel_size = getParam(layer_params, "size", -1); + int filters = getParam(layer_params, "filters", -1); + std::string activation = getParam(layer_params, "activation", "linear"); + bool use_batch_normalize = getParam(layer_params, "batch_normalize", 0) == 1; + + CV_Assert(kernel_size > 0 && filters > 0); + CV_Assert(current_channels > 0); + + size_t const weights_size = filters * current_channels * kernel_size * kernel_size; + int sizes_weights[] = { filters, current_channels, kernel_size, kernel_size }; + cv::Mat weightsBlob; + weightsBlob.create(4, sizes_weights, CV_32F); + CV_Assert(weightsBlob.isContinuous()); + + cv::Mat meanData_mat(1, filters, CV_32F); // mean + cv::Mat stdData_mat(1, filters, CV_32F); // variance + cv::Mat weightsData_mat(1, filters, CV_32F);// scale + cv::Mat biasData_mat(1, filters, CV_32F); // bias + + ifile.read(reinterpret_cast(biasData_mat.ptr()), sizeof(float)*filters); + if (use_batch_normalize) { + ifile.read(reinterpret_cast(weightsData_mat.ptr()), sizeof(float)*filters); + ifile.read(reinterpret_cast(meanData_mat.ptr()), sizeof(float)*filters); + ifile.read(reinterpret_cast(stdData_mat.ptr()), sizeof(float)*filters); + } + ifile.read(reinterpret_cast(weightsBlob.ptr()), sizeof(float)*weights_size); + + // set convolutional weights + std::vector conv_blobs; + conv_blobs.push_back(weightsBlob); + if (!use_batch_normalize) { + // use BIAS in any case + conv_blobs.push_back(biasData_mat); + } + setParams.setLayerBlobs(cv_layers_counter, conv_blobs); + + // set batch normalize (mean, variance, scale, bias) + if (use_batch_normalize) { + ++cv_layers_counter; + std::vector bn_blobs; + bn_blobs.push_back(meanData_mat); + bn_blobs.push_back(stdData_mat); + bn_blobs.push_back(weightsData_mat); + bn_blobs.push_back(biasData_mat); + setParams.setLayerBlobs(cv_layers_counter, bn_blobs); + } + + if(activation == "leaky") + ++cv_layers_counter; + } + current_channels = net->out_channels_vec[darknet_layers_counter]; + } + return true; + } + + } + + + void ReadNetParamsFromCfgFileOrDie(const char *cfgFile, darknet::NetParameter *net) + { + if (!darknet::ReadDarknetFromCfgFile(cfgFile, net)) { + CV_Error(cv::Error::StsParseError, "Failed to parse NetParameter file: " + std::string(cfgFile)); + } + } + + void ReadNetParamsFromBinaryFileOrDie(const char *darknetModel, darknet::NetParameter *net) + { + if (!darknet::ReadDarknetFromWeightsFile(darknetModel, net)) { + CV_Error(cv::Error::StsParseError, "Failed to parse NetParameter file: " + std::string(darknetModel)); + } + } + + } +} diff --git a/modules/dnn/src/darknet/darknet_io.hpp b/modules/dnn/src/darknet/darknet_io.hpp new file mode 100644 index 0000000..f1f19c9 --- /dev/null +++ b/modules/dnn/src/darknet/darknet_io.hpp @@ -0,0 +1,116 @@ +/*M/////////////////////////////////////////////////////////////////////////////////////// +// +// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. +// +// By downloading, copying, installing or using the software you agree to this license. +// If you do not agree to this license, do not download, install, +// copy or use the software. +// +// +// License Agreement +// For Open Source Computer Vision Library +// (3-clause BSD License) +// +// Copyright (C) 2017, Intel Corporation, all rights reserved. +// Third party copyrights are property of their respective owners. +// +// Redistribution and use in source and binary forms, with or without modification, +// are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// * Neither the names of the copyright holders nor the names of the contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// This software is provided by the copyright holders and contributors "as is" and +// any express or implied warranties, including, but not limited to, the implied +// warranties of merchantability and fitness for a particular purpose are disclaimed. +// In no event shall copyright holders or contributors be liable for any direct, +// indirect, incidental, special, exemplary, or consequential damages +// (including, but not limited to, procurement of substitute goods or services; +// loss of use, data, or profits; or business interruption) however caused +// and on any theory of liability, whether in contract, strict liability, +// or tort (including negligence or otherwise) arising in any way out of +// the use of this software, even if advised of the possibility of such damage. +// +//M*/ + +/*M/////////////////////////////////////////////////////////////////////////////////////// +//MIT License +// +//Copyright (c) 2017 Joseph Redmon +// +//Permission is hereby granted, free of charge, to any person obtaining a copy +//of this software and associated documentation files (the "Software"), to deal +//in the Software without restriction, including without limitation the rights +//to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +//copies of the Software, and to permit persons to whom the Software is +//furnished to do so, subject to the following conditions: +// +//The above copyright notice and this permission notice shall be included in all +//copies or substantial portions of the Software. +// +//THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +//IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +//FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +//AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +//LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +//OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +//SOFTWARE. +// +//M*/ + +#ifndef __OPENCV_DNN_DARKNET_IO_HPP__ +#define __OPENCV_DNN_DARKNET_IO_HPP__ + +#include + +namespace cv { + namespace dnn { + namespace darknet { + + class LayerParameter { + std::string layer_name, layer_type; + std::vector bottom_indexes; + cv::dnn::experimental_dnn_v1::LayerParams layerParams; + public: + friend class setLayersParams; + cv::dnn::experimental_dnn_v1::LayerParams getLayerParams() const { return layerParams; } + std::string name() const { return layer_name; } + std::string type() const { return layer_type; } + int bottom_size() const { return bottom_indexes.size(); } + std::string bottom(const int index) const { return bottom_indexes.at(index); } + int top_size() const { return 1; } + std::string top(const int index) const { return layer_name; } + }; + + class NetParameter { + public: + int width, height, channels; + std::vector layers; + std::vector out_channels_vec; + + std::map > layers_cfg; + std::map net_cfg; + + int layer_size() const { return layers.size(); } + + int input_size() const { return 1; } + std::string input(const int index) const { return "data"; } + LayerParameter layer(const int index) const { return layers.at(index); } + }; + } + + // Read parameters from a file into a NetParameter message. + void ReadNetParamsFromCfgFileOrDie(const char *cfgFile, darknet::NetParameter *net); + void ReadNetParamsFromBinaryFileOrDie(const char *darknetModel, darknet::NetParameter *net); + + } +} +#endif diff --git a/modules/dnn/src/init.cpp b/modules/dnn/src/init.cpp index fe1036c..06f4502 100644 --- a/modules/dnn/src/init.cpp +++ b/modules/dnn/src/init.cpp @@ -111,6 +111,8 @@ void initializeLayerFactory() CV_DNN_REGISTER_LAYER_CLASS(Eltwise, EltwiseLayer); CV_DNN_REGISTER_LAYER_CLASS(Permute, PermuteLayer); CV_DNN_REGISTER_LAYER_CLASS(PriorBox, PriorBoxLayer); + CV_DNN_REGISTER_LAYER_CLASS(Reorg, ReorgLayer); + CV_DNN_REGISTER_LAYER_CLASS(Region, RegionLayer); CV_DNN_REGISTER_LAYER_CLASS(DetectionOutput, DetectionOutputLayer); CV_DNN_REGISTER_LAYER_CLASS(NormalizeBBox, NormalizeBBoxLayer); CV_DNN_REGISTER_LAYER_CLASS(Normalize, NormalizeBBoxLayer); diff --git a/modules/dnn/src/layers/region_layer.cpp b/modules/dnn/src/layers/region_layer.cpp new file mode 100644 index 0000000..1e0f6b0 --- /dev/null +++ b/modules/dnn/src/layers/region_layer.cpp @@ -0,0 +1,331 @@ +/*M /////////////////////////////////////////////////////////////////////////////////////// +// +// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. +// +// By downloading, copying, installing or using the software you agree to this license. +// If you do not agree to this license, do not download, install, +// copy or use the software. +// +// +// License Agreement +// For Open Source Computer Vision Library +// +// Copyright (C) 2013, OpenCV Foundation, all rights reserved. +// Copyright (C) 2017, Intel Corporation, all rights reserved. +// Third party copyrights are property of their respective owners. +// +// Redistribution and use in source and binary forms, with or without modification, +// are permitted provided that the following conditions are met: +// +// * Redistribution's of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// +// * Redistribution's in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// * The name of the copyright holders may not be used to endorse or promote products +// derived from this software without specific prior written permission. +// +// This software is provided by the copyright holders and contributors "as is" and +// any express or implied warranties, including, but not limited to, the implied +// warranties of merchantability and fitness for a particular purpose are disclaimed. +// In no event shall the Intel Corporation or contributors be liable for any direct, +// indirect, incidental, special, exemplary, or consequential damages +// (including, but not limited to, procurement of substitute goods or services; +// loss of use, data, or profits; or business interruption) however caused +// and on any theory of liability, whether in contract, strict liability, +// or tort (including negligence or otherwise) arising in any way out of +// the use of this software, even if advised of the possibility of such damage. +// +//M*/ + +#include "../precomp.hpp" +#include +#include +#include + +namespace cv +{ +namespace dnn +{ + +class RegionLayerImpl : public RegionLayer +{ +public: + int coords, classes, anchors, classfix; + float thresh, nmsThreshold; + bool useSoftmaxTree, useSoftmax; + + RegionLayerImpl(const LayerParams& params) + { + setParamsFrom(params); + CV_Assert(blobs.size() == 1); + + thresh = params.get("thresh", 0.2); + coords = params.get("coords", 4); + classes = params.get("classes", 0); + anchors = params.get("anchors", 5); + classfix = params.get("classfix", 0); + useSoftmaxTree = params.get("softmax_tree", false); + useSoftmax = params.get("softmax", false); + nmsThreshold = params.get("nms_threshold", 0.4); + + CV_Assert(nmsThreshold >= 0.); + CV_Assert(coords == 4); + CV_Assert(classes >= 1); + CV_Assert(anchors >= 1); + CV_Assert(useSoftmaxTree || useSoftmax); + } + + bool getMemoryShapes(const std::vector &inputs, + const int requiredOutputs, + std::vector &outputs, + std::vector &internals) const + { + CV_Assert(inputs.size() > 0); + CV_Assert(inputs[0][3] == (1 + coords + classes)*anchors); + outputs = std::vector(inputs.size(), shape(inputs[0][1] * inputs[0][2] * anchors, inputs[0][3] / anchors)); + return false; + } + + virtual bool supportBackend(int backendId) + { + return backendId == DNN_BACKEND_DEFAULT; + } + + float logistic_activate(float x) { return 1.F / (1.F + exp(-x)); } + + void softmax_activate(const float* input, const int n, const float temp, float* output) + { + int i; + float sum = 0; + float largest = -FLT_MAX; + for (i = 0; i < n; ++i) { + if (input[i] > largest) largest = input[i]; + } + for (i = 0; i < n; ++i) { + float e = exp((input[i] - largest) / temp); + sum += e; + output[i] = e; + } + for (i = 0; i < n; ++i) { + output[i] /= sum; + } + } + + void forward(std::vector &inputs, std::vector &outputs, std::vector &internals) + { + CV_TRACE_FUNCTION(); + CV_TRACE_ARG_VALUE(name, "name", name.c_str()); + + CV_Assert(inputs.size() >= 1); + int const cell_size = classes + coords + 1; + + const float* biasData = blobs[0].ptr(); + + for (size_t ii = 0; ii < outputs.size(); ii++) + { + Mat &inpBlob = *inputs[ii]; + Mat &outBlob = outputs[ii]; + + int rows = inpBlob.size[1]; + int cols = inpBlob.size[2]; + + const float *srcData = inpBlob.ptr(); + float *dstData = outBlob.ptr(); + + // logistic activation for t0, for each grid cell (X x Y x Anchor-index) + for (int i = 0; i < rows*cols*anchors; ++i) { + int index = cell_size*i; + float x = srcData[index + 4]; + dstData[index + 4] = logistic_activate(x); // logistic activation + } + + if (useSoftmaxTree) { // Yolo 9000 + CV_Error(cv::Error::StsNotImplemented, "Yolo9000 is not implemented"); + } + else if (useSoftmax) { // Yolo v2 + // softmax activation for Probability, for each grid cell (X x Y x Anchor-index) + for (int i = 0; i < rows*cols*anchors; ++i) { + int index = cell_size*i; + softmax_activate(srcData + index + 5, classes, 1, dstData + index + 5); + } + + for (int x = 0; x < cols; ++x) + for(int y = 0; y < rows; ++y) + for (int a = 0; a < anchors; ++a) { + int index = (y*cols + x)*anchors + a; // index for each grid-cell & anchor + int p_index = index * cell_size + 4; + float scale = dstData[p_index]; + if (classfix == -1 && scale < .5) scale = 0; // if(t0 < 0.5) t0 = 0; + int box_index = index * cell_size; + + dstData[box_index + 0] = (x + logistic_activate(srcData[box_index + 0])) / cols; + dstData[box_index + 1] = (y + logistic_activate(srcData[box_index + 1])) / rows; + dstData[box_index + 2] = exp(srcData[box_index + 2]) * biasData[2 * a] / cols; + dstData[box_index + 3] = exp(srcData[box_index + 3]) * biasData[2 * a + 1] / rows; + + int class_index = index * cell_size + 5; + + if (useSoftmaxTree) { + CV_Error(cv::Error::StsNotImplemented, "Yolo9000 is not implemented"); + } + else { + for (int j = 0; j < classes; ++j) { + float prob = scale*dstData[class_index + j]; // prob = IoU(box, object) = t0 * class-probability + dstData[class_index + j] = (prob > thresh) ? prob : 0; // if (IoU < threshold) IoU = 0; + } + } + } + + } + + if (nmsThreshold > 0) { + do_nms_sort(dstData, rows*cols*anchors, nmsThreshold); + //do_nms(dstData, rows*cols*anchors, nmsThreshold); + } + + } + } + + + struct box { + float x, y, w, h; + float *probs; + }; + + float overlap(float x1, float w1, float x2, float w2) + { + float l1 = x1 - w1 / 2; + float l2 = x2 - w2 / 2; + float left = l1 > l2 ? l1 : l2; + float r1 = x1 + w1 / 2; + float r2 = x2 + w2 / 2; + float right = r1 < r2 ? r1 : r2; + return right - left; + } + + float box_intersection(box a, box b) + { + float w = overlap(a.x, a.w, b.x, b.w); + float h = overlap(a.y, a.h, b.y, b.h); + if (w < 0 || h < 0) return 0; + float area = w*h; + return area; + } + + float box_union(box a, box b) + { + float i = box_intersection(a, b); + float u = a.w*a.h + b.w*b.h - i; + return u; + } + + float box_iou(box a, box b) + { + return box_intersection(a, b) / box_union(a, b); + } + + struct sortable_bbox { + int index; + float *probs; + }; + + struct nms_comparator { + int k; + nms_comparator(int _k) : k(_k) {} + bool operator ()(sortable_bbox v1, sortable_bbox v2) { + return v2.probs[k] < v1.probs[k]; + } + }; + + void do_nms_sort(float *detections, int total, float nms_thresh) + { + std::vector boxes(total); + for (int i = 0; i < total; ++i) { + box &b = boxes[i]; + int box_index = i * (classes + coords + 1); + b.x = detections[box_index + 0]; + b.y = detections[box_index + 1]; + b.w = detections[box_index + 2]; + b.h = detections[box_index + 3]; + int class_index = i * (classes + 5) + 5; + b.probs = (detections + class_index); + } + + std::vector s(total); + + for (int i = 0; i < total; ++i) { + s[i].index = i; + int class_index = i * (classes + 5) + 5; + s[i].probs = (detections + class_index); + } + + for (int k = 0; k < classes; ++k) { + std::stable_sort(s.begin(), s.end(), nms_comparator(k)); + for (int i = 0; i < total; ++i) { + if (boxes[s[i].index].probs[k] == 0) continue; + box a = boxes[s[i].index]; + for (int j = i + 1; j < total; ++j) { + box b = boxes[s[j].index]; + if (box_iou(a, b) > nms_thresh) { + boxes[s[j].index].probs[k] = 0; + } + } + } + } + } + + void do_nms(float *detections, int total, float nms_thresh) + { + std::vector boxes(total); + for (int i = 0; i < total; ++i) { + box &b = boxes[i]; + int box_index = i * (classes + coords + 1); + b.x = detections[box_index + 0]; + b.y = detections[box_index + 1]; + b.w = detections[box_index + 2]; + b.h = detections[box_index + 3]; + int class_index = i * (classes + 5) + 5; + b.probs = (detections + class_index); + } + + for (int i = 0; i < total; ++i) { + bool any = false; + for (int k = 0; k < classes; ++k) any = any || (boxes[i].probs[k] > 0); + if (!any) { + continue; + } + for (int j = i + 1; j < total; ++j) { + if (box_iou(boxes[i], boxes[j]) > nms_thresh) { + for (int k = 0; k < classes; ++k) { + if (boxes[i].probs[k] < boxes[j].probs[k]) boxes[i].probs[k] = 0; + else boxes[j].probs[k] = 0; + } + } + } + } + } + + virtual int64 getFLOPS(const std::vector &inputs, + const std::vector &outputs) const + { + (void)outputs; // suppress unused variable warning + + int64 flops = 0; + for(int i = 0; i < inputs.size(); i++) + { + flops += 60*total(inputs[i]); + } + return flops; + } +}; + +Ptr RegionLayer::create(const LayerParams& params) +{ + return Ptr(new RegionLayerImpl(params)); +} + +} // namespace dnn +} // namespace cv diff --git a/modules/dnn/src/layers/reorg_layer.cpp b/modules/dnn/src/layers/reorg_layer.cpp new file mode 100644 index 0000000..720d25e --- /dev/null +++ b/modules/dnn/src/layers/reorg_layer.cpp @@ -0,0 +1,140 @@ +/*M /////////////////////////////////////////////////////////////////////////////////////// +// +// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. +// +// By downloading, copying, installing or using the software you agree to this license. +// If you do not agree to this license, do not download, install, +// copy or use the software. +// +// +// License Agreement +// For Open Source Computer Vision Library +// +// Copyright (C) 2013, OpenCV Foundation, all rights reserved. +// Copyright (C) 2017, Intel Corporation, all rights reserved. +// Third party copyrights are property of their respective owners. +// +// Redistribution and use in source and binary forms, with or without modification, +// are permitted provided that the following conditions are met: +// +// * Redistribution's of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// +// * Redistribution's in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// * The name of the copyright holders may not be used to endorse or promote products +// derived from this software without specific prior written permission. +// +// This software is provided by the copyright holders and contributors "as is" and +// any express or implied warranties, including, but not limited to, the implied +// warranties of merchantability and fitness for a particular purpose are disclaimed. +// In no event shall the Intel Corporation or contributors be liable for any direct, +// indirect, incidental, special, exemplary, or consequential damages +// (including, but not limited to, procurement of substitute goods or services; +// loss of use, data, or profits; or business interruption) however caused +// and on any theory of liability, whether in contract, strict liability, +// or tort (including negligence or otherwise) arising in any way out of +// the use of this software, even if advised of the possibility of such damage. +// +//M*/ + +#include "../precomp.hpp" +#include +#include +#include + +namespace cv +{ +namespace dnn +{ + +class ReorgLayerImpl : public ReorgLayer +{ + int reorgStride; +public: + + ReorgLayerImpl(const LayerParams& params) + { + setParamsFrom(params); + + reorgStride = params.get("reorg_stride", 2); + CV_Assert(reorgStride > 0); + } + + bool getMemoryShapes(const std::vector &inputs, + const int requiredOutputs, + std::vector &outputs, + std::vector &internals) const + { + CV_Assert(inputs.size() > 0); + outputs = std::vector(inputs.size(), shape( + inputs[0][0], + inputs[0][1] * reorgStride * reorgStride, + inputs[0][2] / reorgStride, + inputs[0][3] / reorgStride)); + + CV_Assert(outputs[0][0] > 0 && outputs[0][1] > 0 && outputs[0][2] > 0 && outputs[0][3] > 0); + CV_Assert(total(outputs[0]) == total(inputs[0])); + + return false; + } + + virtual bool supportBackend(int backendId) + { + return backendId == DNN_BACKEND_DEFAULT; + } + void forward(std::vector &inputs, std::vector &outputs, std::vector &internals) + { + CV_TRACE_FUNCTION(); + CV_TRACE_ARG_VALUE(name, "name", name.c_str()); + + for (size_t i = 0; i < inputs.size(); i++) + { + Mat srcBlob = *inputs[i]; + MatShape inputShape = shape(srcBlob), outShape = shape(outputs[i]); + float *dstData = outputs[0].ptr(); + const float *srcData = srcBlob.ptr(); + + int channels = inputShape[1], height = inputShape[2], width = inputShape[3]; + + int out_c = channels / (reorgStride*reorgStride); + + for (int k = 0; k < channels; ++k) { + for (int j = 0; j < height; ++j) { + for (int i = 0; i < width; ++i) { + int out_index = i + width*(j + height*k); + int c2 = k % out_c; + int offset = k / out_c; + int w2 = i*reorgStride + offset % reorgStride; + int h2 = j*reorgStride + offset / reorgStride; + int in_index = w2 + width*reorgStride*(h2 + height*reorgStride*c2); + dstData[out_index] = srcData[in_index]; + } + } + } + } + } + + virtual int64 getFLOPS(const std::vector &inputs, + const std::vector &outputs) const + { + (void)outputs; // suppress unused variable warning + + int64 flops = 0; + for(int i = 0; i < inputs.size(); i++) + { + flops += 21*total(inputs[i]); + } + return flops; + } +}; + +Ptr ReorgLayer::create(const LayerParams& params) +{ + return Ptr(new ReorgLayerImpl(params)); +} + +} // namespace dnn +} // namespace cv diff --git a/modules/dnn/test/test_darknet_importer.cpp b/modules/dnn/test/test_darknet_importer.cpp new file mode 100644 index 0000000..8f57b17 --- /dev/null +++ b/modules/dnn/test/test_darknet_importer.cpp @@ -0,0 +1,186 @@ +/*M/////////////////////////////////////////////////////////////////////////////////////// +// +// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. +// +// By downloading, copying, installing or using the software you agree to this license. +// If you do not agree to this license, do not download, install, +// copy or use the software. +// +// +// License Agreement +// For Open Source Computer Vision Library +// (3-clause BSD License) +// +// Copyright (C) 2017, Intel Corporation, all rights reserved. +// Third party copyrights are property of their respective owners. +// +// Redistribution and use in source and binary forms, with or without modification, +// are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// * Neither the names of the copyright holders nor the names of the contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// This software is provided by the copyright holders and contributors "as is" and +// any express or implied warranties, including, but not limited to, the implied +// warranties of merchantability and fitness for a particular purpose are disclaimed. +// In no event shall copyright holders or contributors be liable for any direct, +// indirect, incidental, special, exemplary, or consequential damages +// (including, but not limited to, procurement of substitute goods or services; +// loss of use, data, or profits; or business interruption) however caused +// and on any theory of liability, whether in contract, strict liability, +// or tort (including negligence or otherwise) arising in any way out of +// the use of this software, even if advised of the possibility of such damage. +// +//M*/ + +#include "test_precomp.hpp" +#include +#include + +namespace cvtest +{ + +using namespace cv; +using namespace cv::dnn; + +template +static std::string _tf(TString filename) +{ + return (getOpenCVExtraDir() + "/dnn/") + filename; +} + +TEST(Test_Darknet, read_tiny_yolo_voc) +{ + Net net = readNetFromDarknet(_tf("tiny-yolo-voc.cfg")); + ASSERT_FALSE(net.empty()); +} + +TEST(Test_Darknet, read_yolo_voc) +{ + Net net = readNetFromDarknet(_tf("yolo-voc.cfg")); + ASSERT_FALSE(net.empty()); +} + +TEST(Reproducibility_TinyYoloVoc, Accuracy) +{ + Net net; + { + const string cfg = findDataFile("dnn/tiny-yolo-voc.cfg", false); + const string model = findDataFile("dnn/tiny-yolo-voc.weights", false); + net = readNetFromDarknet(cfg, model); + ASSERT_FALSE(net.empty()); + } + + // dog416.png is dog.jpg that resized to 416x416 in the lossless PNG format + Mat sample = imread(_tf("dog416.png")); + ASSERT_TRUE(!sample.empty()); + + Size inputSize(416, 416); + + if (sample.size() != inputSize) + resize(sample, sample, inputSize); + + net.setInput(blobFromImage(sample, 1 / 255.F), "data"); + Mat out = net.forward("detection_out"); + + Mat detection; + const float confidenceThreshold = 0.24; + + for (int i = 0; i < out.rows; i++) { + const int probability_index = 5; + const int probability_size = out.cols - probability_index; + float *prob_array_ptr = &out.at(i, probability_index); + size_t objectClass = std::max_element(prob_array_ptr, prob_array_ptr + probability_size) - prob_array_ptr; + float confidence = out.at(i, (int)objectClass + probability_index); + + if (confidence > confidenceThreshold) + detection.push_back(out.row(i)); + } + + // obtained by: ./darknet detector test ./cfg/voc.data ./cfg/tiny-yolo-voc.cfg ./tiny-yolo-voc.weights -thresh 0.24 ./dog416.png + // There are 2 objects (6-car, 11-dog) with 25 values for each: + // { relative_center_x, relative_center_y, relative_width, relative_height, unused_t0, probability_for_each_class[20] } + float ref_array[] = { + 0.736762F, 0.239551F, 0.315440F, 0.160779F, 0.761977F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, + 0.000000F, 0.000000F, 0.761967F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, + 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, + + 0.287486F, 0.653731F, 0.315579F, 0.534527F, 0.782737F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, + 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.780595F, + 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F + }; + + const int number_of_objects = 2; + Mat ref(number_of_objects, sizeof(ref_array) / (number_of_objects * sizeof(float)), CV_32FC1, &ref_array); + + normAssert(ref, detection); +} + +TEST(Reproducibility_YoloVoc, Accuracy) +{ + Net net; + { + const string cfg = findDataFile("dnn/yolo-voc.cfg", false); + const string model = findDataFile("dnn/yolo-voc.weights", false); + net = readNetFromDarknet(cfg, model); + ASSERT_FALSE(net.empty()); + } + + // dog416.png is dog.jpg that resized to 416x416 in the lossless PNG format + Mat sample = imread(_tf("dog416.png")); + ASSERT_TRUE(!sample.empty()); + + Size inputSize(416, 416); + + if (sample.size() != inputSize) + resize(sample, sample, inputSize); + + net.setInput(blobFromImage(sample, 1 / 255.F), "data"); + Mat out = net.forward("detection_out"); + + Mat detection; + const float confidenceThreshold = 0.24; + + for (int i = 0; i < out.rows; i++) { + const int probability_index = 5; + const int probability_size = out.cols - probability_index; + float *prob_array_ptr = &out.at(i, probability_index); + size_t objectClass = std::max_element(prob_array_ptr, prob_array_ptr + probability_size) - prob_array_ptr; + float confidence = out.at(i, (int)objectClass + probability_index); + + if (confidence > confidenceThreshold) + detection.push_back(out.row(i)); + } + + // obtained by: ./darknet detector test ./cfg/voc.data ./cfg/yolo-voc.cfg ./yolo-voc.weights -thresh 0.24 ./dog416.png + // There are 3 objects (6-car, 1-bicycle, 11-dog) with 25 values for each: + // { relative_center_x, relative_center_y, relative_width, relative_height, unused_t0, probability_for_each_class[20] } + float ref_array[] = { + 0.740161F, 0.214100F, 0.325575F, 0.173418F, 0.750769F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, + 0.000000F, 0.000000F, 0.750469F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, + 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, + + 0.501618F, 0.504757F, 0.461713F, 0.481310F, 0.783550F, 0.000000F, 0.780879F, 0.000000F, 0.000000F, + 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, + 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, + + 0.279968F, 0.638651F, 0.282737F, 0.600284F, 0.901864F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, + 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.901615F, + 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F, 0.000000F + }; + + const int number_of_objects = 3; + Mat ref(number_of_objects, sizeof(ref_array) / (number_of_objects * sizeof(float)), CV_32FC1, &ref_array); + + normAssert(ref, detection); +} + +} diff --git a/modules/dnn/test/test_layers.cpp b/modules/dnn/test/test_layers.cpp index e6807b3..9c6e61f 100644 --- a/modules/dnn/test/test_layers.cpp +++ b/modules/dnn/test/test_layers.cpp @@ -10,7 +10,7 @@ // License Agreement // For Open Source Computer Vision Library // -// Copyright (C) 2013, OpenCV Foundation, all rights reserved. +// Copyright (C) 2017, Intel Corporation, all rights reserved. // Third party copyrights are property of their respective owners. // // Redistribution and use in source and binary forms, with or without modification, @@ -420,4 +420,36 @@ TEST_F(Layer_RNN_Test, get_set_test) EXPECT_EQ(shape(outputs[1]), shape(nT, nS, nH)); } +void testLayerUsingDarknetModels(String basename, bool useDarknetModel = false, bool useCommonInputBlob = true) +{ + String cfg = _tf(basename + ".cfg"); + String weights = _tf(basename + ".weights"); + + String inpfile = (useCommonInputBlob) ? _tf("blob.npy") : _tf(basename + ".input.npy"); + String outfile = _tf(basename + ".npy"); + + cv::setNumThreads(cv::getNumberOfCPUs()); + + Net net = readNetFromDarknet(cfg, (useDarknetModel) ? weights : String()); + ASSERT_FALSE(net.empty()); + + Mat inp = blobFromNPY(inpfile); + Mat ref = blobFromNPY(outfile); + + net.setInput(inp, "data"); + Mat out = net.forward(); + + normAssert(ref, out); +} + +TEST(Layer_Test_Region, Accuracy) +{ + testLayerUsingDarknetModels("region", false, false); +} + +TEST(Layer_Test_Reorg, Accuracy) +{ + testLayerUsingDarknetModels("reorg", false, false); +} + } diff --git a/samples/dnn/yolo_object_detection.cpp b/samples/dnn/yolo_object_detection.cpp new file mode 100644 index 0000000..0731ad2 --- /dev/null +++ b/samples/dnn/yolo_object_detection.cpp @@ -0,0 +1,117 @@ +#include +#include +#include +#include +using namespace cv; +using namespace cv::dnn; + +#include +#include +#include +#include +using namespace std; + +const size_t network_width = 416; +const size_t network_height = 416; + +const char* about = "This sample uses You only look once (YOLO)-Detector " + "(https://arxiv.org/abs/1612.08242)" + "to detect objects on image\n"; // TODO: link + +const char* params + = "{ help | false | print usage }" + "{ cfg | | model configuration }" + "{ model | | model weights }" + "{ image | | image for detection }" + "{ min_confidence | 0.24 | min confidence }"; + +int main(int argc, char** argv) +{ + cv::CommandLineParser parser(argc, argv, params); + + if (parser.get("help")) + { + std::cout << about << std::endl; + parser.printMessage(); + return 0; + } + + String modelConfiguration = parser.get("cfg"); + String modelBinary = parser.get("model"); + + //! [Initialize network] + dnn::Net net = readNetFromDarknet(modelConfiguration, modelBinary); + //! [Initialize network] + + if (net.empty()) + { + cerr << "Can't load network by using the following files: " << endl; + cerr << "cfg-file: " << modelConfiguration << endl; + cerr << "weights-file: " << modelBinary << endl; + cerr << "Models can be downloaded here:" << endl; + cerr << "https://pjreddie.com/darknet/yolo/" << endl; + exit(-1); + } + + cv::Mat frame = cv::imread(parser.get("image")); + + //! [Resizing without keeping aspect ratio] + cv::Mat resized; + cv::resize(frame, resized, cv::Size(network_width, network_height)); + //! [Resizing without keeping aspect ratio] + + //! [Prepare blob] + Mat inputBlob = blobFromImage(resized, 1 / 255.F); //Convert Mat to batch of images + //! [Prepare blob] + + //! [Set input blob] + net.setInput(inputBlob, "data"); //set the network input + //! [Set input blob] + + //! [Make forward pass] + cv::Mat detectionMat = net.forward("detection_out"); //compute output + //! [Make forward pass] + + + float confidenceThreshold = parser.get("min_confidence"); + for (int i = 0; i < detectionMat.rows; i++) + { + const int probability_index = 5; + const int probability_size = detectionMat.cols - probability_index; + float *prob_array_ptr = &detectionMat.at(i, probability_index); + + size_t objectClass = std::max_element(prob_array_ptr, prob_array_ptr + probability_size) - prob_array_ptr; + float confidence = detectionMat.at(i, (int)objectClass + probability_index); + + if (confidence > confidenceThreshold) + { + float x = detectionMat.at(i, 0); + float y = detectionMat.at(i, 1); + float width = detectionMat.at(i, 2); + float height = detectionMat.at(i, 3); + float xLeftBottom = (x - width / 2) * frame.cols; + float yLeftBottom = (y - height / 2) * frame.rows; + float xRightTop = (x + width / 2) * frame.cols; + float yRightTop = (y + height / 2) * frame.rows; + + std::cout << "Class: " << objectClass << std::endl; + std::cout << "Confidence: " << confidence << std::endl; + + std::cout << " " << xLeftBottom + << " " << yLeftBottom + << " " << xRightTop + << " " << yRightTop << std::endl; + + Rect object((int)xLeftBottom, (int)yLeftBottom, + (int)(xRightTop - xLeftBottom), + (int)(yRightTop - yLeftBottom)); + + rectangle(frame, object, Scalar(0, 255, 0)); + } + } + + imshow("detections", frame); + waitKey(); + + return 0; +} // main \ No newline at end of file -- 2.7.4