Update tutorials. A new cv::dnn::readNet function
authorDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Sat, 3 Mar 2018 16:29:37 +0000 (19:29 +0300)
committerDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Sun, 4 Mar 2018 17:30:22 +0000 (20:30 +0300)
doc/tutorials/dnn/dnn_googlenet/dnn_googlenet.markdown
doc/tutorials/dnn/dnn_halide/dnn_halide.markdown
modules/dnn/include/opencv2/dnn/dnn.hpp
modules/dnn/src/dnn.cpp
modules/dnn/test/test_misc.cpp
samples/dnn/README.md
samples/dnn/classification.cpp
samples/dnn/classification.py
samples/dnn/object_detection.cpp
samples/dnn/object_detection.py

index feb1e09..50946b1 100644 (file)
@@ -13,50 +13,53 @@ We will demonstrate results of this example on the following picture.
 Source Code
 -----------
 
-We will be using snippets from the example application, that can be downloaded [here](https://github.com/opencv/opencv/blob/master/samples/dnn/caffe_googlenet.cpp).
+We will be using snippets from the example application, that can be downloaded [here](https://github.com/opencv/opencv/blob/master/samples/dnn/classification.cpp).
 
-@include dnn/caffe_googlenet.cpp
+@include dnn/classification.cpp
 
 Explanation
 -----------
 
 -# Firstly, download GoogLeNet model files:
-   [bvlc_googlenet.prototxt  ](https://raw.githubusercontent.com/opencv/opencv/master/samples/data/dnn/bvlc_googlenet.prototxt) and
+   [bvlc_googlenet.prototxt  ](https://github.com/opencv/opencv_extra/blob/master/testdata/dnn/bvlc_googlenet.prototxt) and
    [bvlc_googlenet.caffemodel](http://dl.caffe.berkeleyvision.org/bvlc_googlenet.caffemodel)
 
    Also you need file with names of [ILSVRC2012](http://image-net.org/challenges/LSVRC/2012/browse-synsets) classes:
-   [synset_words.txt](https://raw.githubusercontent.com/opencv/opencv/master/samples/data/dnn/synset_words.txt).
+   [classification_classes_ILSVRC2012.txt](https://github.com/opencv/opencv/tree/master/samples/dnn/classification_classes_ILSVRC2012.txt).
 
    Put these files into working dir of this program example.
 
 -# Read and initialize network using path to .prototxt and .caffemodel files
-   @snippet dnn/caffe_googlenet.cpp Read and initialize network
+   @snippet dnn/classification.cpp Read and initialize network
 
--# Check that network was read successfully
-   @snippet dnn/caffe_googlenet.cpp Check that network was read successfully
+   You can skip an argument `framework` if one of the files `model` or `config` has an
+   extension `.caffemodel` or `.prototxt`.
+   This way function cv::dnn::readNet can automatically detects a model's format.
 
 -# Read input image and convert to the blob, acceptable by GoogleNet
-   @snippet dnn/caffe_googlenet.cpp Prepare blob
-   We convert the image to a 4-dimensional blob (so-called batch) with 1x3x224x224 shape after applying necessary pre-processing like resizing and mean subtraction using cv::dnn::blobFromImage constructor.
+   @snippet dnn/classification.cpp Open a video file or an image file or a camera stream
 
--# Pass the blob to the network
-   @snippet dnn/caffe_googlenet.cpp Set input blob
-   In bvlc_googlenet.prototxt the network input blob named as "data", therefore this blob labeled as ".data" in opencv_dnn API.
+   cv::VideoCapture can load both images and videos.
+
+   @snippet dnn/classification.cpp Create a 4D blob from a frame
+   We convert the image to a 4-dimensional blob (so-called batch) with `1x3x224x224` shape
+   after applying necessary pre-processing like resizing and mean subtraction
+   `(-104, -117, -123)` for each blue, green and red channels correspondingly using cv::dnn::blobFromImage function.
 
-   Other blobs labeled as "name_of_layer.name_of_layer_output".
+-# Pass the blob to the network
+   @snippet dnn/classification.cpp Set input blob
 
 -# Make forward pass
-   @snippet dnn/caffe_googlenet.cpp Make forward pass
-   During the forward pass output of each network layer is computed, but in this example we need output from "prob" layer only.
+   @snippet dnn/classification.cpp Make forward pass
+   During the forward pass output of each network layer is computed, but in this example we need output from the last layer only.
 
 -# Determine the best class
-   @snippet dnn/caffe_googlenet.cpp Gather output
-   We put the output of "prob" layer, which contain probabilities for each of 1000 ILSVRC2012 image classes, to the `prob` blob.
-   And find the index of element with maximal value in this one. This index correspond to the class of the image.
-
--# Print results
-   @snippet dnn/caffe_googlenet.cpp Print results
-   For our image we get:
-> Best class: #812 'space shuttle'
->
-> Probability: 99.6378%
+   @snippet dnn/classification.cpp Get a class with a highest score
+   We put the output of network, which contain probabilities for each of 1000 ILSVRC2012 image classes, to the `prob` blob.
+   And find the index of element with maximal value in this one. This index corresponds to the class of the image.
+
+-# Run an example from command line
+   @code
+   ./example_dnn_classification --model=bvlc_googlenet.caffemodel --config=bvlc_googlenet.prototxt --width=224 --height=224 --classes=classification_classes_ILSVRC2012.txt --input=space_shuttle.jpg --mean="104 117 123"
+   @endcode
+   For our image we get prediction of class `space shuttle` with more than 99% sureness.
index 1b811ce..7271d08 100644 (file)
@@ -74,46 +74,7 @@ When you build OpenCV add the following configuration flags:
 
 - `HALIDE_ROOT_DIR` - path to Halide build directory
 
-## Sample
-
-@include dnn/squeezenet_halide.cpp
-
-## Explanation
-Download Caffe model from SqueezeNet repository: [train_val.prototxt](https://github.com/DeepScale/SqueezeNet/blob/master/SqueezeNet_v1.1/train_val.prototxt) and [squeezenet_v1.1.caffemodel](https://github.com/DeepScale/SqueezeNet/blob/master/SqueezeNet_v1.1/squeezenet_v1.1.caffemodel).
-
-Also you need file with names of [ILSVRC2012](http://image-net.org/challenges/LSVRC/2012/browse-synsets) classes:
-[synset_words.txt](https://raw.githubusercontent.com/opencv/opencv/master/samples/data/dnn/synset_words.txt).
-
-Put these files into working dir of this program example.
-
--# Read and initialize network using path to .prototxt and .caffemodel files
-@snippet dnn/squeezenet_halide.cpp Read and initialize network
-
--# Check that network was read successfully
-@snippet dnn/squeezenet_halide.cpp Check that network was read successfully
-
--# Read input image and convert to the 4-dimensional blob, acceptable by SqueezeNet v1.1
-@snippet dnn/squeezenet_halide.cpp Prepare blob
-
--# Pass the blob to the network
-@snippet dnn/squeezenet_halide.cpp Set input blob
-
--# Enable Halide backend for layers where it is implemented
-@snippet dnn/squeezenet_halide.cpp Enable Halide backend
-
--# Make forward pass
-@snippet dnn/squeezenet_halide.cpp Make forward pass
-Remember that the first forward pass after initialization require quite more
-time that the next ones. It's because of runtime compilation of Halide pipelines
-at the first invocation.
-
--# Determine the best class
-@snippet dnn/squeezenet_halide.cpp Determine the best class
-
--# Print results
-@snippet dnn/squeezenet_halide.cpp Print results
-For our image we get:
-
-> Best class: #812 'space shuttle'
->
-> Probability: 97.9812%
+## Set Halide as a preferable backend
+@code
+net.setPreferableBackend(DNN_BACKEND_HALIDE);
+@endcode
index 5495276..5605cc0 100644 (file)
@@ -683,6 +683,29 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
      */
      CV_EXPORTS_W Net readNetFromTorch(const String &model, bool isBinary = true);
 
+     /**
+      * @brief Read deep learning network represented in one of the supported formats.
+      * @param[in] model Binary file contains trained weights. The following file
+      *                  extensions are expected for models from different frameworks:
+      *                  * `*.caffemodel` (Caffe, http://caffe.berkeleyvision.org/)
+      *                  * `*.pb` (TensorFlow, https://www.tensorflow.org/)
+      *                  * `*.t7` | `*.net` (Torch, http://torch.ch/)
+      *                  * `*.weights` (Darknet, https://pjreddie.com/darknet/)
+      * @param[in] config Text file contains network configuration. It could be a
+      *                   file with the following extensions:
+      *                  * `*.prototxt` (Caffe, http://caffe.berkeleyvision.org/)
+      *                  * `*.pbtxt` (TensorFlow, https://www.tensorflow.org/)
+      *                  * `*.cfg` (Darknet, https://pjreddie.com/darknet/)
+      * @param[in] framework Explicit framework name tag to determine a format.
+      * @returns Net object.
+      *
+      * This function automatically detects an origin framework of trained model
+      * and calls an appropriate function such @ref readNetFromCaffe, @ref readNetFromTensorflow,
+      * @ref readNetFromTorch or @ref readNetFromDarknet. An order of @p model and @p config
+      * arguments does not matter.
+      */
+     CV_EXPORTS_W Net readNet(String model, String config = "", String framework = "");
+
     /** @brief Loads blob which was serialized as torch.Tensor object of Torch7 framework.
      *  @warning This function has the same limitations as readNetFromTorch().
      */
index 194648c..63fa22d 100644 (file)
@@ -2805,5 +2805,41 @@ BackendWrapper::BackendWrapper(const Ptr<BackendWrapper>& base, const MatShape&
 
 BackendWrapper::~BackendWrapper() {}
 
+Net readNet(String model, String config, String framework)
+{
+    framework = framework.toLowerCase();
+    const std::string modelExt = model.substr(model.rfind('.') + 1);
+    const std::string configExt = config.substr(config.rfind('.') + 1);
+    if (framework == "caffe" || modelExt == "caffemodel" || configExt == "caffemodel" ||
+                                modelExt == "prototxt" || configExt == "prototxt")
+    {
+        if (modelExt == "prototxt" || configExt == "caffemodel")
+            std::swap(model, config);
+        return readNetFromCaffe(config, model);
+    }
+    if (framework == "tensorflow" || modelExt == "pb" || configExt == "pb" ||
+                                     modelExt == "pbtxt" || configExt == "pbtxt")
+    {
+        if (modelExt == "pbtxt" || configExt == "pb")
+            std::swap(model, config);
+        return readNetFromTensorflow(model, config);
+    }
+    if (framework == "torch" || modelExt == "t7" || modelExt == "net" ||
+                                configExt == "t7" || configExt == "net")
+    {
+        return readNetFromTorch(model.empty() ? config : model);
+    }
+    if (framework == "darknet" || modelExt == "weights" || configExt == "weights" ||
+                                  modelExt == "cfg" || configExt == "cfg")
+    {
+        if (modelExt == "cfg" || configExt == "weights")
+            std::swap(model, config);
+        return readNetFromDarknet(config, model);
+    }
+    CV_Error(Error::StsError, "Cannot determine an origin framework of files: " +
+                              model + (config.empty() ? "" : ", " + config));
+    return Net();
+}
+
 CV__DNN_EXPERIMENTAL_NS_END
 }} // namespace
index 57aa7a9..2e92504 100644 (file)
@@ -57,4 +57,22 @@ TEST(imagesFromBlob, Regression)
     }
 }
 
+TEST(readNet, Regression)
+{
+    Net net = readNet(findDataFile("dnn/squeezenet_v1.1.prototxt", false),
+                      findDataFile("dnn/squeezenet_v1.1.caffemodel", false));
+    EXPECT_FALSE(net.empty());
+    net = readNet(findDataFile("dnn/opencv_face_detector.caffemodel", false),
+                  findDataFile("dnn/opencv_face_detector.prototxt", false));
+    EXPECT_FALSE(net.empty());
+    net = readNet(findDataFile("dnn/openface_nn4.small2.v1.t7", false));
+    EXPECT_FALSE(net.empty());
+    net = readNet(findDataFile("dnn/tiny-yolo-voc.cfg", false),
+                  findDataFile("dnn/tiny-yolo-voc.weights", false));
+    EXPECT_FALSE(net.empty());
+    net = readNet(findDataFile("dnn/ssd_mobilenet_v1_coco.pbtxt", false),
+                  findDataFile("dnn/ssd_mobilenet_v1_coco.pb", false));
+    EXPECT_FALSE(net.empty());
+}
+
 }} // namespace
index 07c04ec..fea2025 100644 (file)
 | [Faster-RCNN](https://github.com/rbgirshick/py-faster-rcnn) | `1.0` | `800x600` | `102.9801, 115.9465, 122.7717` | BGR |
 | [R-FCN](https://github.com/YuwenXiong/py-R-FCN) | `1.0` | `800x600` | `102.9801 115.9465 122.7717` | BGR |
 
-
 ### Classification
 |    Model | Scale |   Size WxH|   Mean subtraction | Channels order |
 |---------------|-------|-----------|--------------------|-------|
 | GoogLeNet | `1.0` | `224x224` | `104 117 123` | BGR |
 | [SqueezeNet](https://github.com/DeepScale/SqueezeNet) | `1.0` | `227x227` | `0 0 0` | BGR |
 
-
 ## References
 * [Models downloading script](https://github.com/opencv/opencv_extra/blob/master/testdata/dnn/download_models.py)
 * [Configuration files adopted for OpenCV](https://github.com/opencv/opencv_extra/tree/master/testdata/dnn)
index 74b72a4..d3ae08e 100644 (file)
@@ -2,8 +2,9 @@
 #include <iostream>
 #include <sstream>
 
-#include <opencv2/opencv.hpp>
 #include <opencv2/dnn.hpp>
+#include <opencv2/imgproc.hpp>
+#include <opencv2/highgui.hpp>
 
 const char* keys =
     "{ help  h     | | Print help message. }"
@@ -33,8 +34,6 @@ using namespace dnn;
 
 std::vector<std::string> classes;
 
-Net readNet(const std::string& model, const std::string& config = "", const std::string& framework = "");
-
 int main(int argc, char** argv)
 {
     CommandLineParser parser(argc, argv, keys);
@@ -49,6 +48,11 @@ int main(int argc, char** argv)
     bool swapRB = parser.get<bool>("rgb");
     int inpWidth = parser.get<int>("width");
     int inpHeight = parser.get<int>("height");
+    String model = parser.get<String>("model");
+    String config = parser.get<String>("config");
+    String framework = parser.get<String>("framework");
+    int backendId = parser.get<int>("backend");
+    int targetId = parser.get<int>("target");
 
     // Parse mean values.
     Scalar mean;
@@ -77,22 +81,24 @@ int main(int argc, char** argv)
         }
     }
 
-    // Load a model.
     CV_Assert(parser.has("model"));
-    Net net = readNet(parser.get<String>("model"), parser.get<String>("config"), parser.get<String>("framework"));
-    net.setPreferableBackend(parser.get<int>("backend"));
-    net.setPreferableTarget(parser.get<int>("target"));
+    //! [Read and initialize network]
+    Net net = readNet(model, config, framework);
+    net.setPreferableBackend(backendId);
+    net.setPreferableTarget(targetId);
+    //! [Read and initialize network]
 
     // Create a window
     static const std::string kWinName = "Deep learning image classification in OpenCV";
     namedWindow(kWinName, WINDOW_NORMAL);
 
-    // Open a video file or an image file or a camera stream.
+    //! [Open a video file or an image file or a camera stream]
     VideoCapture cap;
     if (parser.has("input"))
         cap.open(parser.get<String>("input"));
     else
         cap.open(0);
+    //! [Open a video file or an image file or a camera stream]
 
     // Process frames.
     Mat frame, blob;
@@ -105,24 +111,29 @@ int main(int argc, char** argv)
             break;
         }
 
-        // Create a 4D blob from a frame.
+        //! [Create a 4D blob from a frame]
         blobFromImage(frame, blob, scale, Size(inpWidth, inpHeight), mean, swapRB, false);
+        //! [Create a 4D blob from a frame]
 
-        // Run a model.
+        //! [Set input blob]
         net.setInput(blob);
-        Mat out = net.forward();
-        out = out.reshape(1, 1);
+        //! [Set input blob]
+        //! [Make forward pass]
+        Mat prob = net.forward();
+        //! [Make forward pass]
 
-        // Get a class with a highest score.
+        //! [Get a class with a highest score]
         Point classIdPoint;
         double confidence;
-        minMaxLoc(out, 0, &confidence, 0, &classIdPoint);
+        minMaxLoc(prob.reshape(1, 1), 0, &confidence, 0, &classIdPoint);
         int classId = classIdPoint.x;
+        //! [Get a class with a highest score]
 
         // Put efficiency information.
         std::vector<double> layersTimes;
-        double t = net.getPerfProfile(layersTimes);
-        std::string label = format("Inference time: %.2f", t * 1000 / getTickFrequency());
+        double freq = getTickFrequency() / 1000;
+        double t = net.getPerfProfile(layersTimes) / freq;
+        std::string label = format("Inference time: %.2f ms", t);
         putText(frame, label, Point(0, 15), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 255, 0));
 
         // Print predicted class.
@@ -135,19 +146,3 @@ int main(int argc, char** argv)
     }
     return 0;
 }
-
-Net readNet(const std::string& model, const std::string& config, const std::string& framework)
-{
-    std::string modelExt = model.substr(model.rfind('.'));
-    if (framework == "caffe" || modelExt == ".caffemodel")
-        return readNetFromCaffe(config, model);
-    else if (framework == "tensorflow" || modelExt == ".pb")
-        return readNetFromTensorflow(model, config);
-    else if (framework == "torch" || modelExt == ".t7" || modelExt == ".net")
-        return readNetFromTorch(model);
-    else if (framework == "darknet" || modelExt == ".weights")
-        return readNetFromDarknet(config, model);
-    else
-        CV_Error(Error::StsError, "Cannot determine an origin framework of model from file " + model);
-    return Net();
-}
index 446c9b0..2628195 100644 (file)
@@ -48,19 +48,7 @@ if args.classes:
         classes = f.read().rstrip('\n').split('\n')
 
 # Load a network
-modelExt = args.model[args.model.rfind('.'):]
-if args.framework == 'caffe' or modelExt == '.caffemodel':
-    net = cv.dnn.readNetFromCaffe(args.config, args.model)
-elif args.framework == 'tensorflow' or modelExt == '.pb':
-    net = cv.dnn.readNetFromTensorflow(args.model, args.config)
-elif args.framework == 'torch' or modelExt in ['.t7', '.net']:
-    net = cv.dnn.readNetFromTorch(args.model)
-elif args.framework == 'darknet' or modelExt == '.weights':
-    net = cv.dnn.readNetFromDarknet(args.config, args.model)
-else:
-    print('Cannot determine an origin framework of model from file %s' % args.model)
-    sys.exit(0)
-
+net = cv.dnn.readNet(args.model, args.config, args.framework)
 net.setPreferableBackend(args.backend)
 net.setPreferableTarget(args.target)
 
index bb6f6f0..81575d2 100644 (file)
@@ -2,8 +2,9 @@
 #include <iostream>
 #include <sstream>
 
-#include <opencv2/opencv.hpp>
 #include <opencv2/dnn.hpp>
+#include <opencv2/imgproc.hpp>
+#include <opencv2/highgui.hpp>
 
 const char* keys =
     "{ help  h     | | Print help message. }"
@@ -35,8 +36,6 @@ using namespace dnn;
 float confThreshold;
 std::vector<std::string> classes;
 
-Net readNet(const std::string& model, const std::string& config = "", const std::string& framework = "");
-
 void postprocess(Mat& frame, const Mat& out, Net& net);
 
 void drawPred(int classId, float conf, int left, int top, int right, int bottom, Mat& frame);
@@ -95,7 +94,7 @@ int main(int argc, char** argv)
     // Create a window
     static const std::string kWinName = "Deep learning object detection in OpenCV";
     namedWindow(kWinName, WINDOW_NORMAL);
-    int initialConf = confThreshold * 100;
+    int initialConf = (int)(confThreshold * 100);
     createTrackbar("Confidence threshold, %", kWinName, &initialConf, 99, callback);
 
     // Open a video file or an image file or a camera stream.
@@ -135,8 +134,9 @@ int main(int argc, char** argv)
 
         // Put efficiency information.
         std::vector<double> layersTimes;
-        double t = net.getPerfProfile(layersTimes);
-        std::string label = format("Inference time: %.2f", t * 1000 / getTickFrequency());
+        double freq = getTickFrequency() / 1000;
+        double t = net.getPerfProfile(layersTimes) / freq;
+        std::string label = format("Inference time: %.2f ms", t);
         putText(frame, label, Point(0, 15), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 255, 0));
 
         imshow(kWinName, frame);
@@ -160,10 +160,10 @@ void postprocess(Mat& frame, const Mat& out, Net& net)
             float confidence = data[i + 2];
             if (confidence > confThreshold)
             {
-                int left = data[i + 3];
-                int top = data[i + 4];
-                int right = data[i + 5];
-                int bottom = data[i + 6];
+                int left = (int)data[i + 3];
+                int top = (int)data[i + 4];
+                int right = (int)data[i + 5];
+                int bottom = (int)data[i + 6];
                 int classId = (int)(data[i + 1]) - 1;  // Skip 0th background class id.
                 drawPred(classId, confidence, left, top, right, bottom, frame);
             }
@@ -208,7 +208,7 @@ void postprocess(Mat& frame, const Mat& out, Net& net)
                 int height = (int)(data[3] * frame.rows);
                 int left = centerX - width / 2;
                 int top = centerY - height / 2;
-                drawPred(classId, confidence, left, top, left + width, top + height, frame);
+                drawPred(classId, (float)confidence, left, top, left + width, top + height, frame);
             }
         }
     }
@@ -238,21 +238,5 @@ void drawPred(int classId, float conf, int left, int top, int right, int bottom,
 
 void callback(int pos, void*)
 {
-    confThreshold = pos * 0.01;
-}
-
-Net readNet(const std::string& model, const std::string& config, const std::string& framework)
-{
-    std::string modelExt = model.substr(model.rfind('.'));
-    if (framework == "caffe" || modelExt == ".caffemodel")
-        return readNetFromCaffe(config, model);
-    else if (framework == "tensorflow" || modelExt == ".pb")
-        return readNetFromTensorflow(model, config);
-    else if (framework == "torch" || modelExt == ".t7" || modelExt == ".net")
-        return readNetFromTorch(model);
-    else if (framework == "darknet" || modelExt == ".weights")
-        return readNetFromDarknet(config, model);
-    else
-        CV_Error(Error::StsError, "Cannot determine an origin framework of model from file " + model);
-    return Net();
+    confThreshold = pos * 0.01f;
 }
index 661395f..c54f5d3 100644 (file)
@@ -49,19 +49,7 @@ if args.classes:
         classes = f.read().rstrip('\n').split('\n')
 
 # Load a network
-modelExt = args.model[args.model.rfind('.'):]
-if args.framework == 'caffe' or modelExt == '.caffemodel':
-    net = cv.dnn.readNetFromCaffe(args.config, args.model)
-elif args.framework == 'tensorflow' or modelExt == '.pb':
-    net = cv.dnn.readNetFromTensorflow(args.model, args.config)
-elif args.framework == 'torch' or modelExt in ['.t7', '.net']:
-    net = cv.dnn.readNetFromTorch(args.model)
-elif args.framework == 'darknet' or modelExt == '.weights':
-    net = cv.dnn.readNetFromDarknet(args.config, args.model)
-else:
-    print('Cannot determine an origin framework of model from file %s' % args.model)
-    sys.exit(0)
-
+net = cv.dnn.readNet(args.model, args.config, args.framework)
 net.setPreferableBackend(args.backend)
 net.setPreferableTarget(args.target)