Unite deep learning object detection samples
authorDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Fri, 2 Mar 2018 09:04:39 +0000 (12:04 +0300)
committerDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Sat, 3 Mar 2018 11:47:13 +0000 (14:47 +0300)
14 files changed:
doc/tutorials/dnn/dnn_yolo/dnn_yolo.markdown
modules/dnn/include/opencv2/dnn/dnn.hpp
samples/dnn/README.md [new file with mode: 0644]
samples/dnn/faster_rcnn.cpp [deleted file]
samples/dnn/mobilenet_ssd_python.py [deleted file]
samples/dnn/object_detection.cpp [new file with mode: 0644]
samples/dnn/object_detection.py [new file with mode: 0644]
samples/dnn/object_detection_classes_coco.txt [new file with mode: 0644]
samples/dnn/object_detection_classes_pascal_voc.txt [new file with mode: 0644]
samples/dnn/resnet_ssd_face.cpp [deleted file]
samples/dnn/resnet_ssd_face_python.py [deleted file]
samples/dnn/ssd_mobilenet_object_detection.cpp [deleted file]
samples/dnn/ssd_object_detection.cpp [deleted file]
samples/dnn/yolo_object_detection.cpp [deleted file]

index 885c864..a5eb0a9 100644 (file)
@@ -18,40 +18,26 @@ VIDEO DEMO:
 Source Code
 -----------
 
-The latest version of sample source code can be downloaded [here](https://github.com/opencv/opencv/blob/master/samples/dnn/yolo_object_detection.cpp).
+Use a universal sample for object detection models written
+[in C++](https://github.com/opencv/opencv/blob/master/samples/dnn/object_detection.cpp) and
+[in Python](https://github.com/opencv/opencv/blob/master/samples/dnn/object_detection.py) languages
 
-@include dnn/yolo_object_detection.cpp
-
-How to compile in command line with pkg-config
-----------------------------------------------
-
-@code{.bash}
-
-# g++ `pkg-config --cflags opencv` `pkg-config --libs opencv` yolo_object_detection.cpp -o yolo_object_detection
-
-@endcode
+Usage examples
+--------------
 
 Execute in webcam:
 
 @code{.bash}
 
-$ yolo_object_detection -camera_device=0  -cfg=[PATH-TO-DARKNET]/cfg/yolo.cfg -model=[PATH-TO-DARKNET]/yolo.weights   -class_names=[PATH-TO-DARKNET]/data/coco.names
-
-@endcode
-
-Execute with image:
-
-@code{.bash}
-
-$ yolo_object_detection -source=[PATH-IMAGE]  -cfg=[PATH-TO-DARKNET]/cfg/yolo.cfg -model=[PATH-TO-DARKNET]/yolo.weights   -class_names=[PATH-TO-DARKNET]/data/coco.names
+$ example_dnn_object_detection --config=[PATH-TO-DARKNET]/cfg/yolo.cfg --model=[PATH-TO-DARKNET]/yolo.weights --classes=object_detection_classes_pascal_voc.txt --width=416 --height=416 --scale=0.00392
 
 @endcode
 
-Execute in video file:
+Execute with image or video file:
 
 @code{.bash}
 
-$ yolo_object_detection -source=[PATH-TO-VIDEO] -cfg=[PATH-TO-DARKNET]/cfg/yolo.cfg -model=[PATH-TO-DARKNET]/yolo.weights   -class_names=[PATH-TO-DARKNET]/data/coco.names
+$ example_dnn_object_detection --config=[PATH-TO-DARKNET]/cfg/yolo.cfg --model=[PATH-TO-DARKNET]/yolo.weights --classes=object_detection_classes_pascal_voc.txt --width=416 --height=416 --scale=0.00392 --input[PATH-TO-IMAGE-OR-VIDEO-FILE]
 
 @endcode
 
index 54f33a6..5495276 100644 (file)
@@ -222,7 +222,7 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
         /** @brief Returns index of output blob in output array.
          *  @see inputNameToIndex()
          */
-        virtual int outputNameToIndex(String outputName);
+        CV_WRAP virtual int outputNameToIndex(String outputName);
 
         /**
          * @brief Ask layer if it support specific backend for doing computations.
diff --git a/samples/dnn/README.md b/samples/dnn/README.md
new file mode 100644 (file)
index 0000000..3a7523b
--- /dev/null
@@ -0,0 +1,20 @@
+# OpenCV deep learning module samples
+
+## Model Zoo
+
+### Object detection
+
+|    Model | Scale |   Size WxH|   Mean subtraction | Channels order |
+|---------------|-------|-----------|--------------------|-------|
+| [MobileNet-SSD, Caffe](https://github.com/chuanqi305/MobileNet-SSD/) | `0.00784 (2/255)` | `300x300` | `127.5 127.5 127.5` | BGR |
+| [OpenCV face detector](https://github.com/opencv/opencv/tree/master/samples/dnn/face_detector) | `1.0` | `300x300` | `104 177 123` | BGR |
+| [SSDs from TensorFlow](https://github.com/tensorflow/models/tree/master/research/object_detection/) | `0.00784 (2/255)` | `300x300` | `127.5 127.5 127.5` | RGB |
+| [YOLO](https://pjreddie.com/darknet/yolo/) | `0.00392 (1/255)` | `416x416` | `0 0 0` | RGB |
+| [VGG16-SSD](https://github.com/weiliu89/caffe/tree/ssd) | `1.0` | `300x300` | `104 117 123` | BGR |
+| [Faster-RCNN](https://github.com/rbgirshick/py-faster-rcnn) | `1.0` | `800x600` | `102.9801, 115.9465, 122.7717` | BGR |
+| [R-FCN](https://github.com/YuwenXiong/py-R-FCN) | `1.0` | `800x600` | `102.9801 115.9465 122.7717` | BGR |
+
+## References
+* [Models downloading script](https://github.com/opencv/opencv_extra/blob/master/testdata/dnn/download_models.py)
+* [Configuration files adopted for OpenCV](https://github.com/opencv/opencv_extra/tree/master/testdata/dnn)
+* [How to import models from TensorFlow Object Detection API](https://github.com/opencv/opencv/wiki/TensorFlow-Object-Detection-API)
diff --git a/samples/dnn/faster_rcnn.cpp b/samples/dnn/faster_rcnn.cpp
deleted file mode 100644 (file)
index 5d92021..0000000
+++ /dev/null
@@ -1,93 +0,0 @@
-#include <opencv2/dnn.hpp>
-#include <opencv2/dnn/all_layers.hpp>
-#include <opencv2/imgproc.hpp>
-#include <opencv2/highgui.hpp>
-
-using namespace cv;
-using namespace dnn;
-
-const char* keys =
-    "{ help  h |     | print help message  }"
-    "{ proto p |     | path to .prototxt   }"
-    "{ model m |     | path to .caffemodel }"
-    "{ image i |     | path to input image }"
-    "{ conf  c | 0.8 | minimal confidence  }";
-
-const char* classNames[] = {
-    "__background__",
-    "aeroplane", "bicycle", "bird", "boat",
-    "bottle", "bus", "car", "cat", "chair",
-    "cow", "diningtable", "dog", "horse",
-    "motorbike", "person", "pottedplant",
-    "sheep", "sofa", "train", "tvmonitor"
-};
-
-static const int kInpWidth = 800;
-static const int kInpHeight = 600;
-
-int main(int argc, char** argv)
-{
-    // Parse command line arguments.
-    CommandLineParser parser(argc, argv, keys);
-    parser.about("This sample is used to run Faster-RCNN and R-FCN object detection "
-                 "models with OpenCV. You can get required models from "
-                 "https://github.com/rbgirshick/py-faster-rcnn (Faster-RCNN) and from "
-                 "https://github.com/YuwenXiong/py-R-FCN (R-FCN). Corresponding .prototxt "
-                 "files may be found at https://github.com/opencv/opencv_extra/tree/master/testdata/dnn.");
-    if (argc == 1 || parser.has("help"))
-    {
-        parser.printMessage();
-        return 0;
-    }
-
-    String protoPath = parser.get<String>("proto");
-    String modelPath = parser.get<String>("model");
-    String imagePath = parser.get<String>("image");
-    float confThreshold = parser.get<float>("conf");
-    CV_Assert(!protoPath.empty(), !modelPath.empty(), !imagePath.empty());
-
-    // Load a model.
-    Net net = readNetFromCaffe(protoPath, modelPath);
-
-    Mat img = imread(imagePath);
-    resize(img, img, Size(kInpWidth, kInpHeight));
-    Mat blob = blobFromImage(img, 1.0, Size(), Scalar(102.9801, 115.9465, 122.7717), false, false);
-    Mat imInfo = (Mat_<float>(1, 3) << img.rows, img.cols, 1.6f);
-
-    net.setInput(blob, "data");
-    net.setInput(imInfo, "im_info");
-
-    // Draw detections.
-    Mat detections = net.forward();
-    const float* data = (float*)detections.data;
-    for (size_t i = 0; i < detections.total(); i += 7)
-    {
-        // An every detection is a vector [id, classId, confidence, left, top, right, bottom]
-        float confidence = data[i + 2];
-        if (confidence > confThreshold)
-        {
-            int classId = (int)data[i + 1];
-            int left = max(0, min((int)data[i + 3], img.cols - 1));
-            int top = max(0, min((int)data[i + 4], img.rows - 1));
-            int right = max(0, min((int)data[i + 5], img.cols - 1));
-            int bottom = max(0, min((int)data[i + 6], img.rows - 1));
-
-            // Draw a bounding box.
-            rectangle(img, Point(left, top), Point(right, bottom), Scalar(0, 255, 0));
-
-            // Put a label with a class name and confidence.
-            String label = cv::format("%s, %.3f", classNames[classId], confidence);
-            int baseLine;
-            Size labelSize = cv::getTextSize(label, FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);
-
-            top = max(top, labelSize.height);
-            rectangle(img, Point(left, top - labelSize.height),
-                      Point(left + labelSize.width, top + baseLine),
-                      Scalar(255, 255, 255), FILLED);
-            putText(img, label, Point(left, top), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 0, 0));
-        }
-    }
-    imshow("frame", img);
-    waitKey();
-    return 0;
-}
diff --git a/samples/dnn/mobilenet_ssd_python.py b/samples/dnn/mobilenet_ssd_python.py
deleted file mode 100644 (file)
index 839f879..0000000
+++ /dev/null
@@ -1,132 +0,0 @@
-# This script is used to demonstrate MobileNet-SSD network using OpenCV deep learning module.
-#
-# It works with model taken from https://github.com/chuanqi305/MobileNet-SSD/ that
-# was trained in Caffe-SSD framework, https://github.com/weiliu89/caffe/tree/ssd.
-# Model detects objects from 20 classes.
-#
-# Also TensorFlow model from TensorFlow object detection model zoo may be used to
-# detect objects from 90 classes:
-# http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_11_06_2017.tar.gz
-# Text graph definition must be taken from opencv_extra:
-# https://github.com/opencv/opencv_extra/tree/master/testdata/dnn/ssd_mobilenet_v1_coco.pbtxt
-import numpy as np
-import argparse
-
-try:
-    import cv2 as cv
-except ImportError:
-    raise ImportError('Can\'t find OpenCV Python module. If you\'ve built it from sources without installation, '
-                      'configure environment variable PYTHONPATH to "opencv_build_dir/lib" directory (with "python3" subdirectory if required)')
-
-inWidth = 300
-inHeight = 300
-WHRatio = inWidth / float(inHeight)
-inScaleFactor = 0.007843
-meanVal = 127.5
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description='Script to run MobileNet-SSD object detection network '
-                    'trained either in Caffe or TensorFlow frameworks.')
-    parser.add_argument("--video", help="path to video file. If empty, camera's stream will be used")
-    parser.add_argument("--prototxt", default="MobileNetSSD_deploy.prototxt",
-                                      help='Path to text network file: '
-                                           'MobileNetSSD_deploy.prototxt for Caffe model or '
-                                           'ssd_mobilenet_v1_coco.pbtxt from opencv_extra for TensorFlow model')
-    parser.add_argument("--weights", default="MobileNetSSD_deploy.caffemodel",
-                                     help='Path to weights: '
-                                          'MobileNetSSD_deploy.caffemodel for Caffe model or '
-                                          'frozen_inference_graph.pb from TensorFlow.')
-    parser.add_argument("--num_classes", default=20, type=int,
-                        help="Number of classes. It's 20 for Caffe model from "
-                             "https://github.com/chuanqi305/MobileNet-SSD/ and 90 for "
-                             "TensorFlow model from http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_11_06_2017.tar.gz")
-    parser.add_argument("--thr", default=0.2, type=float, help="confidence threshold to filter out weak detections")
-    args = parser.parse_args()
-
-    if args.num_classes == 20:
-        net = cv.dnn.readNetFromCaffe(args.prototxt, args.weights)
-        swapRB = False
-        classNames = { 0: 'background',
-            1: 'aeroplane', 2: 'bicycle', 3: 'bird', 4: 'boat',
-            5: 'bottle', 6: 'bus', 7: 'car', 8: 'cat', 9: 'chair',
-            10: 'cow', 11: 'diningtable', 12: 'dog', 13: 'horse',
-            14: 'motorbike', 15: 'person', 16: 'pottedplant',
-            17: 'sheep', 18: 'sofa', 19: 'train', 20: 'tvmonitor' }
-    else:
-        assert(args.num_classes == 90)
-        net = cv.dnn.readNetFromTensorflow(args.weights, args.prototxt)
-        swapRB = True
-        classNames = { 0: 'background',
-            1: 'person', 2: 'bicycle', 3: 'car', 4: 'motorcycle', 5: 'airplane', 6: 'bus',
-            7: 'train', 8: 'truck', 9: 'boat', 10: 'traffic light', 11: 'fire hydrant',
-            13: 'stop sign', 14: 'parking meter', 15: 'bench', 16: 'bird', 17: 'cat',
-            18: 'dog', 19: 'horse', 20: 'sheep', 21: 'cow', 22: 'elephant', 23: 'bear',
-            24: 'zebra', 25: 'giraffe', 27: 'backpack', 28: 'umbrella', 31: 'handbag',
-            32: 'tie', 33: 'suitcase', 34: 'frisbee', 35: 'skis', 36: 'snowboard',
-            37: 'sports ball', 38: 'kite', 39: 'baseball bat', 40: 'baseball glove',
-            41: 'skateboard', 42: 'surfboard', 43: 'tennis racket', 44: 'bottle',
-            46: 'wine glass', 47: 'cup', 48: 'fork', 49: 'knife', 50: 'spoon',
-            51: 'bowl', 52: 'banana', 53: 'apple', 54: 'sandwich', 55: 'orange',
-            56: 'broccoli', 57: 'carrot', 58: 'hot dog', 59: 'pizza', 60: 'donut',
-            61: 'cake', 62: 'chair', 63: 'couch', 64: 'potted plant', 65: 'bed',
-            67: 'dining table', 70: 'toilet', 72: 'tv', 73: 'laptop', 74: 'mouse',
-            75: 'remote', 76: 'keyboard', 77: 'cell phone', 78: 'microwave', 79: 'oven',
-            80: 'toaster', 81: 'sink', 82: 'refrigerator', 84: 'book', 85: 'clock',
-            86: 'vase', 87: 'scissors', 88: 'teddy bear', 89: 'hair drier', 90: 'toothbrush' }
-
-    if args.video:
-        cap = cv.VideoCapture(args.video)
-    else:
-        cap = cv.VideoCapture(0)
-
-    while True:
-        # Capture frame-by-frame
-        ret, frame = cap.read()
-        blob = cv.dnn.blobFromImage(frame, inScaleFactor, (inWidth, inHeight), (meanVal, meanVal, meanVal), swapRB)
-        net.setInput(blob)
-        detections = net.forward()
-
-        cols = frame.shape[1]
-        rows = frame.shape[0]
-
-        if cols / float(rows) > WHRatio:
-            cropSize = (int(rows * WHRatio), rows)
-        else:
-            cropSize = (cols, int(cols / WHRatio))
-
-        y1 = int((rows - cropSize[1]) / 2)
-        y2 = y1 + cropSize[1]
-        x1 = int((cols - cropSize[0]) / 2)
-        x2 = x1 + cropSize[0]
-        frame = frame[y1:y2, x1:x2]
-
-        cols = frame.shape[1]
-        rows = frame.shape[0]
-
-        for i in range(detections.shape[2]):
-            confidence = detections[0, 0, i, 2]
-            if confidence > args.thr:
-                class_id = int(detections[0, 0, i, 1])
-
-                xLeftBottom = int(detections[0, 0, i, 3] * cols)
-                yLeftBottom = int(detections[0, 0, i, 4] * rows)
-                xRightTop   = int(detections[0, 0, i, 5] * cols)
-                yRightTop   = int(detections[0, 0, i, 6] * rows)
-
-                cv.rectangle(frame, (xLeftBottom, yLeftBottom), (xRightTop, yRightTop),
-                              (0, 255, 0))
-                if class_id in classNames:
-                    label = classNames[class_id] + ": " + str(confidence)
-                    labelSize, baseLine = cv.getTextSize(label, cv.FONT_HERSHEY_SIMPLEX, 0.5, 1)
-
-                    yLeftBottom = max(yLeftBottom, labelSize[1])
-                    cv.rectangle(frame, (xLeftBottom, yLeftBottom - labelSize[1]),
-                                         (xLeftBottom + labelSize[0], yLeftBottom + baseLine),
-                                         (255, 255, 255), cv.FILLED)
-                    cv.putText(frame, label, (xLeftBottom, yLeftBottom),
-                                cv.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0))
-
-        cv.imshow("detections", frame)
-        if cv.waitKey(1) >= 0:
-            break
diff --git a/samples/dnn/object_detection.cpp b/samples/dnn/object_detection.cpp
new file mode 100644 (file)
index 0000000..8f9369b
--- /dev/null
@@ -0,0 +1,255 @@
+#include <opencv2/opencv.hpp>
+#include <fstream>
+#include <iostream>
+#include <sstream>
+
+const char* keys =
+    "{ help  h     | | Print help message. }"
+    "{ input i     | | Path to input image or video file. Skip this argument to capture frames from a camera.}"
+    "{ model m     | | Path to a binary file of model contains trained weights. "
+                      "It could be a file with extensions .caffemodel (Caffe), "
+                      ".pb (TensorFlow), .t7 or .net (Torch), .weights (Darknet) }"
+    "{ config c    | | Path to a text file of model contains network configuration. "
+                      "It could be a file with extensions .prototxt (Caffe), .pbtxt (TensorFlow), .cfg (Darknet) }"
+    "{ framework f | | Optional name of an origin framework of the model. Detect it automatically if it does not set. }"
+    "{ classes     | | Optional path to a text file with names of classes to label detected objects. }"
+    "{ mean        | | Preprocess input image by subtracting mean values. Mean values should be in BGR order and delimited by spaces. }"
+    "{ scale       |  1 | Preprocess input image by multiplying on a scale factor. }"
+    "{ width       | -1 | Preprocess input image by resizing to a specific width. }"
+    "{ height      | -1 | Preprocess input image by resizing to a specific height. }"
+    "{ rgb         |    | Indicate that model works with RGB input images instead BGR ones. }"
+    "{ thr         | .5 | Confidence threshold. }"
+    "{ opencl      |    | Enable OpenCL }";
+
+using namespace cv;
+using namespace dnn;
+
+float confThreshold;
+std::vector<std::string> classes;
+
+void loadClasses(const std::string& file);
+
+Net readNet(const std::string& model, const std::string& config = "", const std::string& framework = "");
+
+void postprocess(Mat& frame, const Mat& out, Net& net);
+
+void drawPred(int classId, float conf, int left, int top, int right, int bottom, Mat& frame);
+
+void callback(int pos, void* userdata);
+
+int main(int argc, char** argv)
+{
+    CommandLineParser parser(argc, argv, keys);
+    parser.about("Use this script to run object detection deep learning networks using OpenCV.");
+    if (argc == 1 || parser.has("help"))
+    {
+        parser.printMessage();
+        return 0;
+    }
+
+    confThreshold = parser.get<float>("thr");
+    float scale = parser.get<float>("scale");
+    bool swapRB = parser.get<bool>("rgb");
+    int inpWidth = parser.get<int>("width");
+    int inpHeight = parser.get<int>("height");
+
+    // Parse mean values.
+    Scalar mean;
+    if (parser.has("mean"))
+    {
+        std::istringstream meanStr(parser.get<String>("mean"));
+        std::vector<float> meanValues;
+        float val;
+        while (meanStr >> val)
+            meanValues.push_back(val);
+        CV_Assert(meanValues.size() == 3);
+        mean = Scalar(meanValues[0], meanValues[1], meanValues[2]);
+    }
+
+    // Open file with classes names.
+    if (parser.has("classes"))
+    {
+        std::string file = parser.get<String>("classes");
+        std::ifstream ifs(file.c_str());
+        if (!ifs.is_open())
+            CV_Error(Error::StsError, "File " + file + " not found");
+        std::string line;
+        while (ifs >> line)
+        {
+            classes.push_back(line);
+        }
+    }
+
+    // Load a model.
+    CV_Assert(parser.has("model"));
+    Net net = readNet(parser.get<String>("model"), parser.get<String>("config"), parser.get<String>("framework"));
+
+    if (parser.get<bool>("opencl"))
+    {
+        net.setPreferableTarget(DNN_TARGET_OPENCL);
+    }
+
+    // Create a window
+    static const std::string kWinName = "Deep learning object detection in OpenCV";
+    namedWindow(kWinName, WINDOW_NORMAL);
+    int initialConf = confThreshold * 100;
+    createTrackbar("Confidence threshold", kWinName, &initialConf, 99, callback);
+
+    // Open a video file or an image file or a camera stream.
+    VideoCapture cap;
+    if (parser.has("input"))
+        cap.open(parser.get<String>("input"));
+    else
+        cap.open(0);
+
+    // Process frames.
+    Mat frame, blob;
+    while (waitKey(1) < 0)
+    {
+        cap >> frame;
+        if (frame.empty())
+        {
+            waitKey();
+            break;
+        }
+
+        // Create a 4D blob from a frame.
+        Size inpSize(inpWidth > 0 ? inpWidth : frame.cols,
+                     inpHeight > 0 ? inpHeight : frame.rows);
+        blobFromImage(frame, blob, scale, inpSize, mean, swapRB, false);
+
+        // Run a model.
+        net.setInput(blob);
+        if (net.getLayer(0)->outputNameToIndex("im_info") != -1)  // Faster-RCNN or R-FCN
+        {
+            resize(frame, frame, inpSize);
+            Mat imInfo = (Mat_<float>(1, 3) << inpSize.height, inpSize.width, 1.6f);
+            net.setInput(imInfo, "im_info");
+        }
+        Mat out = net.forward();
+
+        postprocess(frame, out, net);
+
+        // Put efficiency information.
+        std::vector<double> layersTimes;
+        double t = net.getPerfProfile(layersTimes);
+        std::string label = format("Inference time: %.2f", t * 1000 / getTickFrequency());
+        putText(frame, label, Point(0, 15), FONT_HERSHEY_SIMPLEX, 0.5, Scalar());
+
+        imshow(kWinName, frame);
+    }
+    return 0;
+}
+
+void postprocess(Mat& frame, const Mat& out, Net& net)
+{
+    static std::vector<int> outLayers = net.getUnconnectedOutLayers();
+    static std::string outLayerType = net.getLayer(outLayers[0])->type;
+
+    float* data = (float*)out.data;
+    if (net.getLayer(0)->outputNameToIndex("im_info") != -1)  // Faster-RCNN or R-FCN
+    {
+        // Network produces output blob with a shape 1x1xNx7 where N is a number of
+        // detections and an every detection is a vector of values
+        // [batchId, classId, confidence, left, top, right, bottom]
+        for (size_t i = 0; i < out.total(); i += 7)
+        {
+            float confidence = data[i + 2];
+            if (confidence > confThreshold)
+            {
+                int left = data[i + 3];
+                int top = data[i + 4];
+                int right = data[i + 5];
+                int bottom = data[i + 6];
+                int classId = (int)(data[i + 1]) - 1;  // Skip 0th background class id.
+                drawPred(classId, confidence, left, top, right, bottom, frame);
+            }
+        }
+    }
+    else if (outLayerType == "DetectionOutput")
+    {
+        // Network produces output blob with a shape 1x1xNx7 where N is a number of
+        // detections and an every detection is a vector of values
+        // [batchId, classId, confidence, left, top, right, bottom]
+        for (size_t i = 0; i < out.total(); i += 7)
+        {
+            float confidence = data[i + 2];
+            if (confidence > confThreshold)
+            {
+                int left = (int)(data[i + 3] * frame.cols);
+                int top = (int)(data[i + 4] * frame.rows);
+                int right = (int)(data[i + 5] * frame.cols);
+                int bottom = (int)(data[i + 6] * frame.rows);
+                int classId = (int)(data[i + 1]) - 1;  // Skip 0th background class id.
+                drawPred(classId, confidence, left, top, right, bottom, frame);
+            }
+        }
+    }
+    else if (outLayerType == "Region")
+    {
+        // Network produces output blob with a shape NxC where N is a number of
+        // detected objects and C is a number of classes + 4 where the first 4
+        // numbers are [center_x, center_y, width, height]
+        for (int i = 0; i < out.rows; ++i, data += out.cols)
+        {
+            Mat confidences = out.row(i).colRange(5, out.cols);
+            Point classIdPoint;
+            double confidence;
+            minMaxLoc(confidences, 0, &confidence, 0, &classIdPoint);
+            if (confidence > confThreshold)
+            {
+                int classId = classIdPoint.x;
+                int centerX = (int)(data[0] * frame.cols);
+                int centerY = (int)(data[1] * frame.rows);
+                int width = (int)(data[2] * frame.cols);
+                int height = (int)(data[3] * frame.rows);
+                int left = centerX - width / 2;
+                int top = centerY - height / 2;
+                drawPred(classId, confidence, left, top, left + width, top + height, frame);
+            }
+        }
+    }
+    else
+        CV_Error(Error::StsNotImplemented, "Unknown output layer type: " + outLayerType);
+}
+
+void drawPred(int classId, float conf, int left, int top, int right, int bottom, Mat& frame)
+{
+    rectangle(frame, Point(left, top), Point(right, bottom), Scalar(0, 255, 0));
+
+    std::string label = format("%.2f", conf);
+    if (!classes.empty())
+    {
+        CV_Assert(classId < (int)classes.size());
+        label = classes[classId] + ": " + label;
+    }
+
+    int baseLine;
+    Size labelSize = getTextSize(label, FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);
+
+    top = max(top, labelSize.height);
+    rectangle(frame, Point(left, top - labelSize.height),
+              Point(left + labelSize.width, top + baseLine), Scalar::all(255), FILLED);
+    putText(frame, label, Point(left, top), FONT_HERSHEY_SIMPLEX, 0.5, Scalar());
+}
+
+void callback(int pos, void*)
+{
+    confThreshold = pos * 0.01;
+}
+
+Net readNet(const std::string& model, const std::string& config, const std::string& framework)
+{
+    std::string modelExt = model.substr(model.find('.'));
+    if (framework == "caffe" || modelExt == ".caffemodel")
+        return readNetFromCaffe(config, model);
+    else if (framework == "tensorflow" || modelExt == ".pb")
+        return readNetFromTensorflow(model, config);
+    else if (framework == "torch" || modelExt == ".t7" || modelExt == ".net")
+        return readNetFromTorch(model);
+    else if (framework == "darknet" || modelExt == ".weights")
+        return readNetFromDarknet(config, model);
+    else
+        CV_Error(Error::StsError, "Cannot determine an origin framework of model from file " + model);
+    return Net();
+}
diff --git a/samples/dnn/object_detection.py b/samples/dnn/object_detection.py
new file mode 100644 (file)
index 0000000..76d6e5a
--- /dev/null
@@ -0,0 +1,161 @@
+import cv2 as cv
+import argparse
+import sys
+import numpy as np
+
+parser = argparse.ArgumentParser(description='Use this script to run object detection deep learning networks using OpenCV.')
+parser.add_argument('--input', help='Path to input image or video file. Skip this argument to capture frames from a camera.')
+parser.add_argument('--model', required=True,
+                    help='Path to a binary file of model contains trained weights. '
+                         'It could be a file with extensions .caffemodel (Caffe), '
+                         '.pb (TensorFlow), .t7 or .net (Torch), .weights (Darknet)')
+parser.add_argument('--config',
+                    help='Path to a text file of model contains network configuration. '
+                         'It could be a file with extensions .prototxt (Caffe), .pbtxt (TensorFlow), .cfg (Darknet)')
+parser.add_argument('--framework', choices=['caffe', 'tensorflow', 'torch', 'darknet'],
+                    help='Optional name of an origin framework of the model. '
+                         'Detect it automatically if it does not set.')
+parser.add_argument('--classes', help='Optional path to a text file with names of classes to label detected objects.')
+parser.add_argument('--mean', nargs='+', type=float, default=[0, 0, 0],
+                    help='Preprocess input image by subtracting mean values. '
+                         'Mean values should be in BGR order.')
+parser.add_argument('--scale', type=float, default=1.0,
+                    help='Preprocess input image by multiplying on a scale factor.')
+parser.add_argument('--width', type=int,
+                    help='Preprocess input image by resizing to a specific width.')
+parser.add_argument('--height', type=int,
+                    help='Preprocess input image by resizing to a specific height.')
+parser.add_argument('--rgb', action='store_true',
+                    help='Indicate that model works with RGB input images instead BGR ones.')
+parser.add_argument('--thr', type=float, default=0.5, help='Confidence threshold')
+args = parser.parse_args()
+
+# Load names of classes
+classes = None
+if args.classes:
+    with open(args.classes, 'rt') as f:
+        classes = f.read().rstrip('\n').split('\n')
+
+# Load a network
+modelExt = args.model[args.model.find('.'):]
+if args.framework == 'caffe' or modelExt == '.caffemodel':
+    net = cv.dnn.readNetFromCaffe(args.config, args.model)
+elif args.framework == 'tensorflow' or modelExt == '.pb':
+    net = cv.dnn.readNetFromTensorflow(args.model, args.config)
+elif args.framework == 'torch' or modelExt in ['.t7', '.net']:
+    net = cv.dnn.readNetFromTorch(args.model)
+elif args.framework == 'darknet' or modelExt == '.weights':
+    net = cv.dnn.readNetFromDarknet(args.config, args.model)
+else:
+    print('Cannot determine an origin framework of model from file %s' % args.model)
+    sys.exit(0)
+
+confThreshold = args.thr
+
+def postprocess(frame, out):
+    frameHeight = frame.shape[0]
+    frameWidth = frame.shape[1]
+
+    def drawPred(classId, conf, left, top, right, bottom):
+        # Draw a bounding box.
+        cv.rectangle(frame, (left, top), (right, bottom), (0, 255, 0))
+
+        label = '%.2f' % confidence
+
+        # Print a label of class.
+        if classes:
+            assert(classId < len(classes))
+            label = '%s: %s' % (classes[classId], label)
+
+        labelSize, baseLine = cv.getTextSize(label, cv.FONT_HERSHEY_SIMPLEX, 0.5, 1)
+        top = max(top, labelSize[1])
+        cv.rectangle(frame, (left, top - labelSize[1]), (left + labelSize[0], top + baseLine), (255, 255, 255), cv.FILLED)
+        cv.putText(frame, label, (left, top), cv.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0))
+
+    layerNames = net.getLayerNames()
+    lastLayerId = net.getLayerId(layerNames[-1])
+    lastLayer = net.getLayer(lastLayerId)
+
+    if net.getLayer(0).outputNameToIndex('im_info') != -1:  # Faster-RCNN or R-FCN
+        # Network produces output blob with a shape 1x1xNx7 where N is a number of
+        # detections and an every detection is a vector of values
+        # [batchId, classId, confidence, left, top, right, bottom]
+        for detection in out[0, 0]:
+            confidence = detection[2]
+            if confidence > confThreshold:
+                left = int(detection[3])
+                top = int(detection[4])
+                right = int(detection[5])
+                bottom = int(detection[6])
+                classId = int(detection[1]) - 1  # Skip background label
+                drawPred(classId, confidence, left, top, right, bottom)
+    elif lastLayer.type == 'DetectionOutput':
+        # Network produces output blob with a shape 1x1xNx7 where N is a number of
+        # detections and an every detection is a vector of values
+        # [batchId, classId, confidence, left, top, right, bottom]
+        for detection in out[0, 0]:
+            confidence = detection[2]
+            if confidence > confThreshold:
+                left = int(detection[3] * frameWidth)
+                top = int(detection[4] * frameHeight)
+                right = int(detection[5] * frameWidth)
+                bottom = int(detection[6] * frameHeight)
+                classId = int(detection[1]) - 1  # Skip background label
+                drawPred(classId, confidence, left, top, right, bottom)
+    elif lastLayer.type == 'Region':
+        # Network produces output blob with a shape NxC where N is a number of
+        # detected objects and C is a number of classes + 4 where the first 4
+        # numbers are [center_x, center_y, width, height]
+        for detection in out:
+            confidences = detection[5:]
+            classId = np.argmax(confidences)
+            confidence = confidences[classId]
+            if confidence > confThreshold:
+                center_x = int(detection[0] * frameWidth)
+                center_y = int(detection[1] * frameHeight)
+                width = int(detection[2] * frameWidth)
+                height = int(detection[3] * frameHeight)
+                left = center_x - width / 2
+                top = center_y - height / 2
+                drawPred(classId, confidence, left, top, left + width, top + height)
+
+# Process inputs
+winName = 'Deep learning object detection in OpenCV'
+cv.namedWindow(winName, cv.WINDOW_NORMAL)
+
+def callback(pos):
+    global confThreshold
+    confThreshold = pos / 100.0
+
+cv.createTrackbar('Confidence threshold, %', winName, int(confThreshold * 100), 99, callback)
+
+cap = cv.VideoCapture(args.input if args.input else 0)
+while cv.waitKey(1) < 0:
+    hasFrame, frame = cap.read()
+    if not hasFrame:
+        cv.waitKey()
+        break
+
+    frameHeight = frame.shape[0]
+    frameWidth = frame.shape[1]
+
+    # Create a 4D blob from a frame.
+    inpWidth = args.width if args.width else frameWidth
+    inpHeight = args.height if args.height else frameHeight
+    blob = cv.dnn.blobFromImage(frame, args.scale, (inpWidth, inpHeight), args.mean, args.rgb, crop=False)
+
+    # Run a model
+    net.setInput(blob)
+    if net.getLayer(0).outputNameToIndex('im_info') != -1:  # Faster-RCNN or R-FCN
+        frame = cv.resize(frame, (inpWidth, inpHeight))
+        net.setInput(np.array([inpHeight, inpWidth, 1.6], dtype=np.float32), 'im_info');
+    out = net.forward()
+
+    postprocess(frame, out)
+
+    # Put efficiency information.
+    t, _ = net.getPerfProfile()
+    label = 'Inference time: %.2f ms' % (t * 1000.0 / cv.getTickFrequency())
+    cv.putText(frame, label, (0, 15), cv.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0))
+
+    cv.imshow(winName, frame)
diff --git a/samples/dnn/object_detection_classes_coco.txt b/samples/dnn/object_detection_classes_coco.txt
new file mode 100644 (file)
index 0000000..75aa546
--- /dev/null
@@ -0,0 +1,90 @@
+person
+bicycle
+car
+motorcycle
+airplane
+bus
+train
+truck
+boat
+traffic light
+fire hydrant
+
+stop sign
+parking meter
+bench
+bird
+cat
+dog
+horse
+sheep
+cow
+elephant
+bear
+zebra
+giraffe
+
+backpack
+umbrella
+
+
+handbag
+tie
+suitcase
+frisbee
+skis
+snowboard
+sports ball
+kite
+baseball bat
+baseball glove
+skateboard
+surfboard
+tennis racket
+bottle
+
+wine glass
+cup
+fork
+knife
+spoon
+bowl
+banana
+apple
+sandwich
+orange
+broccoli
+carrot
+hot dog
+pizza
+donut
+cake
+chair
+couch
+potted plant
+bed
+
+dining table
+
+
+toilet
+
+tv
+laptop
+mouse
+remote
+keyboard
+cell phone
+microwave
+oven
+toaster
+sink
+refrigerator
+
+book
+clock
+vase
+scissors
+teddy bear
+hair drier
+toothbrush
diff --git a/samples/dnn/object_detection_classes_pascal_voc.txt b/samples/dnn/object_detection_classes_pascal_voc.txt
new file mode 100644 (file)
index 0000000..8420ab3
--- /dev/null
@@ -0,0 +1,20 @@
+aeroplane
+bicycle
+bird
+boat
+bottle
+bus
+car
+cat
+chair
+cow
+diningtable
+dog
+horse
+motorbike
+person
+pottedplant
+sheep
+sofa
+train
+tvmonitor
diff --git a/samples/dnn/resnet_ssd_face.cpp b/samples/dnn/resnet_ssd_face.cpp
deleted file mode 100644 (file)
index b8227d7..0000000
+++ /dev/null
@@ -1,164 +0,0 @@
-#include <opencv2/dnn.hpp>
-#include <opencv2/imgproc.hpp>
-#include <opencv2/highgui.hpp>
-#include <iostream>
-
-using namespace cv;
-using namespace std;
-using namespace cv::dnn;
-
-const size_t inWidth = 300;
-const size_t inHeight = 300;
-const double inScaleFactor = 1.0;
-const Scalar meanVal(104.0, 177.0, 123.0);
-
-const char* about = "This sample uses Single-Shot Detector "
-                    "(https://arxiv.org/abs/1512.02325) "
-                    "with ResNet-10 architecture to detect faces on camera/video/image.\n"
-                    "More information about the training is available here: "
-                    "<OPENCV_SRC_DIR>/samples/dnn/face_detector/how_to_train_face_detector.txt\n"
-                    ".caffemodel model's file is available here: "
-                    "<OPENCV_SRC_DIR>/samples/dnn/face_detector/res10_300x300_ssd_iter_140000.caffemodel\n"
-                    ".prototxt file is available here: "
-                    "<OPENCV_SRC_DIR>/samples/dnn/face_detector/deploy.prototxt\n";
-
-const char* params
-    = "{ help           | false | print usage          }"
-      "{ proto          |       | model configuration (deploy.prototxt) }"
-      "{ model          |       | model weights (res10_300x300_ssd_iter_140000.caffemodel) }"
-      "{ camera_device  | 0     | camera device number }"
-      "{ video          |       | video or image for detection }"
-      "{ opencl         | false | enable OpenCL }"
-      "{ min_confidence | 0.5   | min confidence       }";
-
-int main(int argc, char** argv)
-{
-    CommandLineParser parser(argc, argv, params);
-
-    if (parser.get<bool>("help"))
-    {
-        cout << about << endl;
-        parser.printMessage();
-        return 0;
-    }
-
-    String modelConfiguration = parser.get<string>("proto");
-    String modelBinary = parser.get<string>("model");
-
-    //! [Initialize network]
-    dnn::Net net = readNetFromCaffe(modelConfiguration, modelBinary);
-    //! [Initialize network]
-
-    if (net.empty())
-    {
-        cerr << "Can't load network by using the following files: " << endl;
-        cerr << "prototxt:   " << modelConfiguration << endl;
-        cerr << "caffemodel: " << modelBinary << endl;
-        cerr << "Models are available here:" << endl;
-        cerr << "<OPENCV_SRC_DIR>/samples/dnn/face_detector" << endl;
-        cerr << "or here:" << endl;
-        cerr << "https://github.com/opencv/opencv/tree/master/samples/dnn/face_detector" << endl;
-        exit(-1);
-    }
-
-    if (parser.get<bool>("opencl"))
-    {
-        net.setPreferableTarget(DNN_TARGET_OPENCL);
-    }
-
-    VideoCapture cap;
-    if (parser.get<String>("video").empty())
-    {
-        int cameraDevice = parser.get<int>("camera_device");
-        cap = VideoCapture(cameraDevice);
-        if(!cap.isOpened())
-        {
-            cout << "Couldn't find camera: " << cameraDevice << endl;
-            return -1;
-        }
-    }
-    else
-    {
-        cap.open(parser.get<String>("video"));
-        if(!cap.isOpened())
-        {
-            cout << "Couldn't open image or video: " << parser.get<String>("video") << endl;
-            return -1;
-        }
-    }
-
-    for(;;)
-    {
-        Mat frame;
-        cap >> frame; // get a new frame from camera/video or read image
-
-        if (frame.empty())
-        {
-            waitKey();
-            break;
-        }
-
-        if (frame.channels() == 4)
-            cvtColor(frame, frame, COLOR_BGRA2BGR);
-
-        //! [Prepare blob]
-        Mat inputBlob = blobFromImage(frame, inScaleFactor,
-                                      Size(inWidth, inHeight), meanVal, false, false); //Convert Mat to batch of images
-        //! [Prepare blob]
-
-        //! [Set input blob]
-        net.setInput(inputBlob, "data"); //set the network input
-        //! [Set input blob]
-
-        //! [Make forward pass]
-        Mat detection = net.forward("detection_out"); //compute output
-        //! [Make forward pass]
-
-        vector<double> layersTimings;
-        double freq = getTickFrequency() / 1000;
-        double time = net.getPerfProfile(layersTimings) / freq;
-
-        Mat detectionMat(detection.size[2], detection.size[3], CV_32F, detection.ptr<float>());
-
-        ostringstream ss;
-        ss << "FPS: " << 1000/time << " ; time: " << time << " ms";
-        putText(frame, ss.str(), Point(20,20), 0, 0.5, Scalar(0,0,255));
-
-        float confidenceThreshold = parser.get<float>("min_confidence");
-        for(int i = 0; i < detectionMat.rows; i++)
-        {
-            float confidence = detectionMat.at<float>(i, 2);
-
-            if(confidence > confidenceThreshold)
-            {
-                int xLeftBottom = static_cast<int>(detectionMat.at<float>(i, 3) * frame.cols);
-                int yLeftBottom = static_cast<int>(detectionMat.at<float>(i, 4) * frame.rows);
-                int xRightTop = static_cast<int>(detectionMat.at<float>(i, 5) * frame.cols);
-                int yRightTop = static_cast<int>(detectionMat.at<float>(i, 6) * frame.rows);
-
-                Rect object((int)xLeftBottom, (int)yLeftBottom,
-                            (int)(xRightTop - xLeftBottom),
-                            (int)(yRightTop - yLeftBottom));
-
-                rectangle(frame, object, Scalar(0, 255, 0));
-
-                ss.str("");
-                ss << confidence;
-                String conf(ss.str());
-                String label = "Face: " + conf;
-                int baseLine = 0;
-                Size labelSize = getTextSize(label, FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);
-                rectangle(frame, Rect(Point(xLeftBottom, yLeftBottom - labelSize.height),
-                                      Size(labelSize.width, labelSize.height + baseLine)),
-                          Scalar(255, 255, 255), FILLED);
-                putText(frame, label, Point(xLeftBottom, yLeftBottom),
-                        FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0,0,0));
-            }
-        }
-
-        imshow("detections", frame);
-        if (waitKey(1) >= 0) break;
-    }
-
-    return 0;
-} // main
diff --git a/samples/dnn/resnet_ssd_face_python.py b/samples/dnn/resnet_ssd_face_python.py
deleted file mode 100644 (file)
index 3f040a8..0000000
+++ /dev/null
@@ -1,55 +0,0 @@
-import numpy as np
-import argparse
-import cv2 as cv
-try:
-    import cv2 as cv
-except ImportError:
-    raise ImportError('Can\'t find OpenCV Python module. If you\'ve built it from sources without installation, '
-                      'configure environment variable PYTHONPATH to "opencv_build_dir/lib" directory (with "python3" subdirectory if required)')
-
-from cv2 import dnn
-
-inWidth = 300
-inHeight = 300
-confThreshold = 0.5
-
-prototxt = 'face_detector/deploy.prototxt'
-caffemodel = 'face_detector/res10_300x300_ssd_iter_140000.caffemodel'
-
-if __name__ == '__main__':
-    net = dnn.readNetFromCaffe(prototxt, caffemodel)
-    cap = cv.VideoCapture(0)
-    while True:
-        ret, frame = cap.read()
-        cols = frame.shape[1]
-        rows = frame.shape[0]
-
-        net.setInput(dnn.blobFromImage(frame, 1.0, (inWidth, inHeight), (104.0, 177.0, 123.0), False, False))
-        detections = net.forward()
-
-        perf_stats = net.getPerfProfile()
-
-        print('Inference time, ms: %.2f' % (perf_stats[0] / cv.getTickFrequency() * 1000))
-
-        for i in range(detections.shape[2]):
-            confidence = detections[0, 0, i, 2]
-            if confidence > confThreshold:
-                xLeftBottom = int(detections[0, 0, i, 3] * cols)
-                yLeftBottom = int(detections[0, 0, i, 4] * rows)
-                xRightTop = int(detections[0, 0, i, 5] * cols)
-                yRightTop = int(detections[0, 0, i, 6] * rows)
-
-                cv.rectangle(frame, (xLeftBottom, yLeftBottom), (xRightTop, yRightTop),
-                             (0, 255, 0))
-                label = "face: %.4f" % confidence
-                labelSize, baseLine = cv.getTextSize(label, cv.FONT_HERSHEY_SIMPLEX, 0.5, 1)
-
-                cv.rectangle(frame, (xLeftBottom, yLeftBottom - labelSize[1]),
-                                    (xLeftBottom + labelSize[0], yLeftBottom + baseLine),
-                                    (255, 255, 255), cv.FILLED)
-                cv.putText(frame, label, (xLeftBottom, yLeftBottom),
-                           cv.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0))
-
-        cv.imshow("detections", frame)
-        if cv.waitKey(1) != -1:
-            break
diff --git a/samples/dnn/ssd_mobilenet_object_detection.cpp b/samples/dnn/ssd_mobilenet_object_detection.cpp
deleted file mode 100644 (file)
index 889f66d..0000000
+++ /dev/null
@@ -1,187 +0,0 @@
-#include <opencv2/dnn.hpp>
-#include <opencv2/dnn/shape_utils.hpp>
-#include <opencv2/imgproc.hpp>
-#include <opencv2/highgui.hpp>
-#include <iostream>
-
-using namespace cv;
-using namespace std;
-using namespace cv::dnn;
-
-const size_t inWidth = 300;
-const size_t inHeight = 300;
-const float inScaleFactor = 0.007843f;
-const float meanVal = 127.5;
-const char* classNames[] = {"background",
-                            "aeroplane", "bicycle", "bird", "boat",
-                            "bottle", "bus", "car", "cat", "chair",
-                            "cow", "diningtable", "dog", "horse",
-                            "motorbike", "person", "pottedplant",
-                            "sheep", "sofa", "train", "tvmonitor"};
-
-const String keys
-    = "{ help           | false | print usage         }"
-      "{ proto          | MobileNetSSD_deploy.prototxt   | model configuration }"
-      "{ model          | MobileNetSSD_deploy.caffemodel | model weights }"
-      "{ camera_device  | 0     | camera device number }"
-      "{ camera_width   | 640   | camera device width  }"
-      "{ camera_height  | 480   | camera device height }"
-      "{ video          |       | video or image for detection}"
-      "{ out            |       | path to output video file}"
-      "{ min_confidence | 0.2   | min confidence      }"
-      "{ opencl         | false | enable OpenCL }"
-;
-
-int main(int argc, char** argv)
-{
-    CommandLineParser parser(argc, argv, keys);
-    parser.about("This sample uses MobileNet Single-Shot Detector "
-                 "(https://arxiv.org/abs/1704.04861) "
-                 "to detect objects on camera/video/image.\n"
-                 ".caffemodel model's file is available here: "
-                 "https://github.com/chuanqi305/MobileNet-SSD\n"
-                 "Default network is 300x300 and 20-classes VOC.\n");
-
-    if (parser.get<bool>("help"))
-    {
-        parser.printMessage();
-        return 0;
-    }
-
-    String modelConfiguration = parser.get<String>("proto");
-    String modelBinary = parser.get<String>("model");
-    CV_Assert(!modelConfiguration.empty() && !modelBinary.empty());
-
-    //! [Initialize network]
-    dnn::Net net = readNetFromCaffe(modelConfiguration, modelBinary);
-    //! [Initialize network]
-
-    if (parser.get<bool>("opencl"))
-    {
-        net.setPreferableTarget(DNN_TARGET_OPENCL);
-    }
-
-    if (net.empty())
-    {
-        cerr << "Can't load network by using the following files: " << endl;
-        cerr << "prototxt:   " << modelConfiguration << endl;
-        cerr << "caffemodel: " << modelBinary << endl;
-        cerr << "Models can be downloaded here:" << endl;
-        cerr << "https://github.com/chuanqi305/MobileNet-SSD" << endl;
-        exit(-1);
-    }
-
-    VideoCapture cap;
-    if (!parser.has("video"))
-    {
-        int cameraDevice = parser.get<int>("camera_device");
-        cap = VideoCapture(cameraDevice);
-        if(!cap.isOpened())
-        {
-            cout << "Couldn't find camera: " << cameraDevice << endl;
-            return -1;
-        }
-
-        cap.set(CAP_PROP_FRAME_WIDTH, parser.get<int>("camera_width"));
-        cap.set(CAP_PROP_FRAME_HEIGHT, parser.get<int>("camera_height"));
-    }
-    else
-    {
-        cap.open(parser.get<String>("video"));
-        if(!cap.isOpened())
-        {
-            cout << "Couldn't open image or video: " << parser.get<String>("video") << endl;
-            return -1;
-        }
-    }
-
-    //Acquire input size
-    Size inVideoSize((int) cap.get(CAP_PROP_FRAME_WIDTH),
-                     (int) cap.get(CAP_PROP_FRAME_HEIGHT));
-
-    double fps = cap.get(CAP_PROP_FPS);
-    int fourcc = static_cast<int>(cap.get(CAP_PROP_FOURCC));
-    VideoWriter outputVideo;
-    outputVideo.open(parser.get<String>("out") ,
-                     (fourcc != 0 ? fourcc : VideoWriter::fourcc('M','J','P','G')),
-                     (fps != 0 ? fps : 10.0), inVideoSize, true);
-
-    for(;;)
-    {
-        Mat frame;
-        cap >> frame; // get a new frame from camera/video or read image
-
-        if (frame.empty())
-        {
-            waitKey();
-            break;
-        }
-
-        if (frame.channels() == 4)
-            cvtColor(frame, frame, COLOR_BGRA2BGR);
-
-        //! [Prepare blob]
-        Mat inputBlob = blobFromImage(frame, inScaleFactor,
-                                      Size(inWidth, inHeight),
-                                      Scalar(meanVal, meanVal, meanVal),
-                                      false, false); //Convert Mat to batch of images
-        //! [Prepare blob]
-
-        //! [Set input blob]
-        net.setInput(inputBlob); //set the network input
-        //! [Set input blob]
-
-        //! [Make forward pass]
-        Mat detection = net.forward(); //compute output
-        //! [Make forward pass]
-
-        vector<double> layersTimings;
-        double freq = getTickFrequency() / 1000;
-        double time = net.getPerfProfile(layersTimings) / freq;
-
-        Mat detectionMat(detection.size[2], detection.size[3], CV_32F, detection.ptr<float>());
-
-        if (!outputVideo.isOpened())
-        {
-            putText(frame, format("FPS: %.2f ; time: %.2f ms", 1000.f/time, time),
-                    Point(20,20), 0, 0.5, Scalar(0,0,255));
-        }
-        else
-            cout << "Inference time, ms: " << time << endl;
-
-        float confidenceThreshold = parser.get<float>("min_confidence");
-        for(int i = 0; i < detectionMat.rows; i++)
-        {
-            float confidence = detectionMat.at<float>(i, 2);
-
-            if(confidence > confidenceThreshold)
-            {
-                size_t objectClass = (size_t)(detectionMat.at<float>(i, 1));
-
-                int left = static_cast<int>(detectionMat.at<float>(i, 3) * frame.cols);
-                int top = static_cast<int>(detectionMat.at<float>(i, 4) * frame.rows);
-                int right = static_cast<int>(detectionMat.at<float>(i, 5) * frame.cols);
-                int bottom = static_cast<int>(detectionMat.at<float>(i, 6) * frame.rows);
-
-                rectangle(frame, Point(left, top), Point(right, bottom), Scalar(0, 255, 0));
-                String label = format("%s: %.2f", classNames[objectClass], confidence);
-                int baseLine = 0;
-                Size labelSize = getTextSize(label, FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);
-                top = max(top, labelSize.height);
-                rectangle(frame, Point(left, top - labelSize.height),
-                          Point(left + labelSize.width, top + baseLine),
-                          Scalar(255, 255, 255), FILLED);
-                putText(frame, label, Point(left, top),
-                        FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0,0,0));
-            }
-        }
-
-        if (outputVideo.isOpened())
-            outputVideo << frame;
-
-        imshow("detections", frame);
-        if (waitKey(1) >= 0) break;
-    }
-
-    return 0;
-} // main
diff --git a/samples/dnn/ssd_object_detection.cpp b/samples/dnn/ssd_object_detection.cpp
deleted file mode 100644 (file)
index 10792db..0000000
+++ /dev/null
@@ -1,156 +0,0 @@
-#include <opencv2/dnn.hpp>
-#include <opencv2/dnn/shape_utils.hpp>
-#include <opencv2/imgproc.hpp>
-#include <opencv2/highgui.hpp>
-#include <iostream>
-
-using namespace cv;
-using namespace std;
-using namespace cv::dnn;
-
-const char* classNames[] = {"background",
-                            "aeroplane", "bicycle", "bird", "boat",
-                            "bottle", "bus", "car", "cat", "chair",
-                            "cow", "diningtable", "dog", "horse",
-                            "motorbike", "person", "pottedplant",
-                            "sheep", "sofa", "train", "tvmonitor"};
-
-const char* about = "This sample uses Single-Shot Detector "
-                    "(https://arxiv.org/abs/1512.02325) "
-                    "to detect objects on camera/video/image.\n"
-                    ".caffemodel model's file is available here: "
-                    "https://github.com/weiliu89/caffe/tree/ssd#models\n"
-                    "Default network is 300x300 and 20-classes VOC.\n";
-
-const char* params
-    = "{ help           | false | print usage         }"
-      "{ proto          |       | model configuration }"
-      "{ model          |       | model weights       }"
-      "{ camera_device  | 0     | camera device number}"
-      "{ video          |       | video or image for detection}"
-      "{ min_confidence | 0.5   | min confidence      }";
-
-int main(int argc, char** argv)
-{
-    cv::CommandLineParser parser(argc, argv, params);
-
-    if (parser.get<bool>("help"))
-    {
-        cout << about << endl;
-        parser.printMessage();
-        return 0;
-    }
-
-    String modelConfiguration = parser.get<string>("proto");
-    String modelBinary = parser.get<string>("model");
-
-    //! [Initialize network]
-    dnn::Net net = readNetFromCaffe(modelConfiguration, modelBinary);
-    //! [Initialize network]
-
-    if (net.empty())
-    {
-        cerr << "Can't load network by using the following files: " << endl;
-        cerr << "prototxt:   " << modelConfiguration << endl;
-        cerr << "caffemodel: " << modelBinary << endl;
-        cerr << "Models can be downloaded here:" << endl;
-        cerr << "https://github.com/weiliu89/caffe/tree/ssd#models" << endl;
-        exit(-1);
-    }
-
-    VideoCapture cap;
-    if (parser.get<String>("video").empty())
-    {
-        int cameraDevice = parser.get<int>("camera_device");
-        cap = VideoCapture(cameraDevice);
-        if(!cap.isOpened())
-        {
-            cout << "Couldn't find camera: " << cameraDevice << endl;
-            return -1;
-        }
-    }
-    else
-    {
-        cap.open(parser.get<String>("video"));
-        if(!cap.isOpened())
-        {
-            cout << "Couldn't open image or video: " << parser.get<String>("video") << endl;
-            return -1;
-        }
-    }
-
-    for (;;)
-    {
-        cv::Mat frame;
-        cap >> frame; // get a new frame from camera/video or read image
-
-        if (frame.empty())
-        {
-            waitKey();
-            break;
-        }
-
-        if (frame.channels() == 4)
-            cvtColor(frame, frame, COLOR_BGRA2BGR);
-
-        //! [Prepare blob]
-        Mat inputBlob = blobFromImage(frame, 1.0f, Size(300, 300), Scalar(104, 117, 123), false, false); //Convert Mat to batch of images
-        //! [Prepare blob]
-
-        //! [Set input blob]
-        net.setInput(inputBlob, "data"); //set the network input
-        //! [Set input blob]
-
-        //! [Make forward pass]
-        Mat detection = net.forward("detection_out"); //compute output
-        //! [Make forward pass]
-
-        vector<double> layersTimings;
-        double freq = getTickFrequency() / 1000;
-        double time = net.getPerfProfile(layersTimings) / freq;
-        ostringstream ss;
-        ss << "FPS: " << 1000/time << " ; time: " << time << " ms";
-        putText(frame, ss.str(), Point(20,20), 0, 0.5, Scalar(0,0,255));
-
-        Mat detectionMat(detection.size[2], detection.size[3], CV_32F, detection.ptr<float>());
-
-        float confidenceThreshold = parser.get<float>("min_confidence");
-        for(int i = 0; i < detectionMat.rows; i++)
-        {
-            float confidence = detectionMat.at<float>(i, 2);
-
-            if(confidence > confidenceThreshold)
-            {
-                size_t objectClass = (size_t)(detectionMat.at<float>(i, 1));
-
-                int xLeftBottom = static_cast<int>(detectionMat.at<float>(i, 3) * frame.cols);
-                int yLeftBottom = static_cast<int>(detectionMat.at<float>(i, 4) * frame.rows);
-                int xRightTop = static_cast<int>(detectionMat.at<float>(i, 5) * frame.cols);
-                int yRightTop = static_cast<int>(detectionMat.at<float>(i, 6) * frame.rows);
-
-                ss.str("");
-                ss << confidence;
-                String conf(ss.str());
-
-                Rect object(xLeftBottom, yLeftBottom,
-                            xRightTop - xLeftBottom,
-                            yRightTop - yLeftBottom);
-
-                rectangle(frame, object, Scalar(0, 255, 0));
-                String label = String(classNames[objectClass]) + ": " + conf;
-                int baseLine = 0;
-                Size labelSize = getTextSize(label, FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);
-                rectangle(frame, Rect(Point(xLeftBottom, yLeftBottom - labelSize.height),
-                                      Size(labelSize.width, labelSize.height + baseLine)),
-                          Scalar(255, 255, 255), FILLED);
-                putText(frame, label, Point(xLeftBottom, yLeftBottom),
-                        FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0,0,0));
-            }
-        }
-
-        imshow("detections", frame);
-        if (waitKey(1) >= 0) break;
-    }
-
-    return 0;
-} // main
diff --git a/samples/dnn/yolo_object_detection.cpp b/samples/dnn/yolo_object_detection.cpp
deleted file mode 100644 (file)
index bd4ad5c..0000000
+++ /dev/null
@@ -1,185 +0,0 @@
-// Brief Sample of using OpenCV dnn module in real time with device capture, video and image.
-// VIDEO DEMO: https://www.youtube.com/watch?v=NHtRlndE2cg
-
-#include <opencv2/dnn.hpp>
-#include <opencv2/dnn/shape_utils.hpp>
-#include <opencv2/imgproc.hpp>
-#include <opencv2/highgui.hpp>
-#include <fstream>
-#include <iostream>
-
-using namespace std;
-using namespace cv;
-using namespace cv::dnn;
-
-static const char* about =
-"This sample uses You only look once (YOLO)-Detector (https://arxiv.org/abs/1612.08242) to detect objects on camera/video/image.\n"
-"Models can be downloaded here: https://pjreddie.com/darknet/yolo/\n"
-"Default network is 416x416.\n"
-"Class names can be downloaded here: https://github.com/pjreddie/darknet/tree/master/data\n";
-
-static const char* params =
-"{ help           | false | print usage         }"
-"{ cfg            |       | model configuration }"
-"{ model          |       | model weights       }"
-"{ camera_device  | 0     | camera device number}"
-"{ source         |       | video or image for detection}"
-"{ out            |       | path to output video file}"
-"{ fps            | 3     | frame per second }"
-"{ style          | box   | box or line style draw }"
-"{ min_confidence | 0.24  | min confidence      }"
-"{ class_names    |       | File with class names, [PATH-TO-DARKNET]/data/coco.names }";
-
-int main(int argc, char** argv)
-{
-    CommandLineParser parser(argc, argv, params);
-
-    if (parser.get<bool>("help"))
-    {
-        cout << about << endl;
-        parser.printMessage();
-        return 0;
-    }
-
-    String modelConfiguration = parser.get<String>("cfg");
-    String modelBinary = parser.get<String>("model");
-
-    //! [Initialize network]
-    dnn::Net net = readNetFromDarknet(modelConfiguration, modelBinary);
-    //! [Initialize network]
-
-    if (net.empty())
-    {
-        cerr << "Can't load network by using the following files: " << endl;
-        cerr << "cfg-file:     " << modelConfiguration << endl;
-        cerr << "weights-file: " << modelBinary << endl;
-        cerr << "Models can be downloaded here:" << endl;
-        cerr << "https://pjreddie.com/darknet/yolo/" << endl;
-        exit(-1);
-    }
-
-    VideoCapture cap;
-    VideoWriter writer;
-    int codec = CV_FOURCC('M', 'J', 'P', 'G');
-    double fps = parser.get<float>("fps");
-    if (parser.get<String>("source").empty())
-    {
-        int cameraDevice = parser.get<int>("camera_device");
-        cap = VideoCapture(cameraDevice);
-        if(!cap.isOpened())
-        {
-            cout << "Couldn't find camera: " << cameraDevice << endl;
-            return -1;
-        }
-    }
-    else
-    {
-        cap.open(parser.get<String>("source"));
-        if(!cap.isOpened())
-        {
-            cout << "Couldn't open image or video: " << parser.get<String>("video") << endl;
-            return -1;
-        }
-    }
-
-    if(!parser.get<String>("out").empty())
-    {
-        writer.open(parser.get<String>("out"), codec, fps, Size((int)cap.get(CAP_PROP_FRAME_WIDTH),(int)cap.get(CAP_PROP_FRAME_HEIGHT)), 1);
-    }
-
-    vector<String> classNamesVec;
-    ifstream classNamesFile(parser.get<String>("class_names").c_str());
-    if (classNamesFile.is_open())
-    {
-        string className = "";
-        while (std::getline(classNamesFile, className))
-            classNamesVec.push_back(className);
-    }
-
-    String object_roi_style = parser.get<String>("style");
-
-    for(;;)
-    {
-        Mat frame;
-        cap >> frame; // get a new frame from camera/video or read image
-
-        if (frame.empty())
-        {
-            waitKey();
-            break;
-        }
-
-        if (frame.channels() == 4)
-            cvtColor(frame, frame, COLOR_BGRA2BGR);
-
-        //! [Prepare blob]
-        Mat inputBlob = blobFromImage(frame, 1 / 255.F, Size(416, 416), Scalar(), true, false); //Convert Mat to batch of images
-        //! [Prepare blob]
-
-        //! [Set input blob]
-        net.setInput(inputBlob, "data");                   //set the network input
-        //! [Set input blob]
-
-        //! [Make forward pass]
-        Mat detectionMat = net.forward("detection_out");   //compute output
-        //! [Make forward pass]
-
-        vector<double> layersTimings;
-        double tick_freq = getTickFrequency();
-        double time_ms = net.getPerfProfile(layersTimings) / tick_freq * 1000;
-        putText(frame, format("FPS: %.2f ; time: %.2f ms", 1000.f / time_ms, time_ms),
-                Point(20, 20), 0, 0.5, Scalar(0, 0, 255));
-
-        float confidenceThreshold = parser.get<float>("min_confidence");
-        for (int i = 0; i < detectionMat.rows; i++)
-        {
-            const int probability_index = 5;
-            const int probability_size = detectionMat.cols - probability_index;
-            float *prob_array_ptr = &detectionMat.at<float>(i, probability_index);
-
-            size_t objectClass = max_element(prob_array_ptr, prob_array_ptr + probability_size) - prob_array_ptr;
-            float confidence = detectionMat.at<float>(i, (int)objectClass + probability_index);
-
-            if (confidence > confidenceThreshold)
-            {
-                float x_center = detectionMat.at<float>(i, 0) * frame.cols;
-                float y_center = detectionMat.at<float>(i, 1) * frame.rows;
-                float width = detectionMat.at<float>(i, 2) * frame.cols;
-                float height = detectionMat.at<float>(i, 3) * frame.rows;
-                Point p1(cvRound(x_center - width / 2), cvRound(y_center - height / 2));
-                Point p2(cvRound(x_center + width / 2), cvRound(y_center + height / 2));
-                Rect object(p1, p2);
-
-                Scalar object_roi_color(0, 255, 0);
-
-                if (object_roi_style == "box")
-                {
-                    rectangle(frame, object, object_roi_color);
-                }
-                else
-                {
-                    Point p_center(cvRound(x_center), cvRound(y_center));
-                    line(frame, object.tl(), p_center, object_roi_color, 1);
-                }
-
-                String className = objectClass < classNamesVec.size() ? classNamesVec[objectClass] : cv::format("unknown(%d)", objectClass);
-                String label = format("%s: %.2f", className.c_str(), confidence);
-                int baseLine = 0;
-                Size labelSize = getTextSize(label, FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);
-                rectangle(frame, Rect(p1, Size(labelSize.width, labelSize.height + baseLine)),
-                          object_roi_color, FILLED);
-                putText(frame, label, p1 + Point(0, labelSize.height),
-                        FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0,0,0));
-            }
-        }
-        if(writer.isOpened())
-        {
-            writer.write(frame);
-        }
-
-        imshow("YOLO: Detections", frame);
-        if (waitKey(1) >= 0) break;
-    }
-
-    return 0;
-} // main