Update C++ MobileNet-SSD object detection sample
author    Dmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
          Mon, 1 Jan 2018 20:01:23 +0000 (23:01 +0300)
committer Dmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
          Mon, 1 Jan 2018 20:01:23 +0000 (23:01 +0300)
samples/dnn/ssd_mobilenet_object_detection.cpp

index d7b2cbf..e04f1c3 100644
@@ -13,7 +13,6 @@ using namespace std;
 
 const size_t inWidth = 300;
 const size_t inHeight = 300;
-const float WHRatio = inWidth / (float)inHeight;
 const float inScaleFactor = 0.007843f;
 const float meanVal = 127.5;
 const char* classNames[] = {"background",
@@ -23,13 +22,6 @@ const char* classNames[] = {"background",
                             "motorbike", "person", "pottedplant",
                             "sheep", "sofa", "train", "tvmonitor"};
 
-const char* about = "This sample uses MobileNet Single-Shot Detector "
-                    "(https://arxiv.org/abs/1704.04861) "
-                    "to detect objects on camera/video/image.\n"
-                    ".caffemodel model's file is available here: "
-                    "https://github.com/chuanqi305/MobileNet-SSD\n"
-                    "Default network is 300x300 and 20-classes VOC.\n";
-
 const char* params
     = "{ help           | false | print usage         }"
       "{ proto          | MobileNetSSD_deploy.prototxt | model configuration }"
@@ -44,16 +36,22 @@ const char* params
 int main(int argc, char** argv)
 {
     CommandLineParser parser(argc, argv, params);
-
-    if (parser.get<bool>("help"))
+    parser.about("This sample uses MobileNet Single-Shot Detector "
+                 "(https://arxiv.org/abs/1704.04861) "
+                 "to detect objects on camera/video/image.\n"
+                 "The .caffemodel file is available here: "
+                 "https://github.com/chuanqi305/MobileNet-SSD\n"
+                 "Default network is 300x300 and 20-class VOC.\n");
+
+    if (parser.get<bool>("help") || argc == 1)
     {
-        cout << about << endl;
         parser.printMessage();
         return 0;
     }
 
     String modelConfiguration = parser.get<string>("proto");
     String modelBinary = parser.get<string>("model");
+    CV_Assert(!modelConfiguration.empty() && !modelBinary.empty());
 
     //! [Initialize network]
     dnn::Net net = readNetFromCaffe(modelConfiguration, modelBinary);
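
The new CV_Assert stops the sample with a clear message when --proto or --model resolves to an empty path, before the Caffe importer runs. A minimal sketch of the same load-and-verify pattern (the helper name is illustrative; it assumes the opencv2/dnn.hpp include and using-directives already present in the sample):

    // Hypothetical helper mirroring the sample's loading step.
    dnn::Net loadDetector(const String& proto, const String& model)
    {
        CV_Assert(!proto.empty() && !model.empty());        // fail early on bad CLI input
        dnn::Net net = readNetFromCaffe(proto, model);      // parse .prototxt + .caffemodel
        CV_Assert(!net.empty());                            // importer produced no layers
        return net;
    }
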
@@ -75,7 +73,7 @@ int main(int argc, char** argv)
     }
 
     VideoCapture cap;
-    if (parser.get<String>("video").empty())
+    if (!parser.has("video"))
     {
         int cameraDevice = parser.get<int>("camera_device");
         cap = VideoCapture(cameraDevice);
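
The source selection now relies on CommandLineParser::has(), which is false when the optional video parameter was left at its empty default. A short sketch of the resulting open-camera-or-file pattern, assuming the same parameter names as the sample:

    VideoCapture cap;
    if (!parser.has("video"))
        cap.open(parser.get<int>("camera_device"));   // fall back to a webcam index
    else
        cap.open(parser.get<String>("video"));        // path to a video file or image
    if (!cap.isOpened())
    {
        cout << "Couldn't open the requested input" << endl;
        return -1;
    }
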
@@ -95,32 +93,16 @@ int main(int argc, char** argv)
         }
     }
 
-    Size inVideoSize;
-    inVideoSize = Size((int) cap.get(CV_CAP_PROP_FRAME_WIDTH),    //Acquire input size
-                       (int) cap.get(CV_CAP_PROP_FRAME_HEIGHT));
-
-    Size cropSize;
-    if (inVideoSize.width / (float)inVideoSize.height > WHRatio)
-    {
-        cropSize = Size(static_cast<int>(inVideoSize.height * WHRatio),
-                        inVideoSize.height);
-    }
-    else
-    {
-        cropSize = Size(inVideoSize.width,
-                        static_cast<int>(inVideoSize.width / WHRatio));
-    }
-
-    Rect crop(Point((inVideoSize.width - cropSize.width) / 2,
-                    (inVideoSize.height - cropSize.height) / 2),
-              cropSize);
+    //Acquire input size
+    Size inVideoSize((int) cap.get(CV_CAP_PROP_FRAME_WIDTH),
+                     (int) cap.get(CV_CAP_PROP_FRAME_HEIGHT));
 
     double fps = cap.get(CV_CAP_PROP_FPS);
     int fourcc = static_cast<int>(cap.get(CV_CAP_PROP_FOURCC));
     VideoWriter outputVideo;
     outputVideo.open(parser.get<String>("out") ,
                      (fourcc != 0 ? fourcc : VideoWriter::fourcc('M','J','P','G')),
-                     (fps != 0 ? fps : 10.0), cropSize, true);
+                     (fps != 0 ? fps : 10.0), inVideoSize, true);
 
     for(;;)
     {
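
With the manual center-crop removed, aspect-ratio handling is left to blobFromImage: crop=false stretches the frame to the 300x300 network input, while crop=true would resize preserving the aspect ratio and then center-crop. The writer can therefore record frames at the capture's native size. A minimal sketch contrasting the two crop modes on a hypothetical 640x480 frame:

    Mat frame(480, 640, CV_8UC3, Scalar::all(0));
    // crop = false (what the sample now uses): plain resize, the whole frame is kept
    Mat stretched = blobFromImage(frame, 1.0, Size(300, 300), Scalar(), false, false);
    // crop = true: resize preserving aspect ratio, then take a central 300x300 crop
    Mat cropped   = blobFromImage(frame, 1.0, Size(300, 300), Scalar(), false, true);
    // Both blobs are NCHW 1x3x300x300; only which pixels they sample differs.
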
@@ -138,15 +120,17 @@ int main(int argc, char** argv)
 
         //! [Prepare blob]
         Mat inputBlob = blobFromImage(frame, inScaleFactor,
-                                      Size(inWidth, inHeight), meanVal, false); //Convert Mat to batch of images
+                                      Size(inWidth, inHeight),
+                                      Scalar(meanVal, meanVal, meanVal),
+                                      false, false); //Convert Mat to batch of images
         //! [Prepare blob]
 
         //! [Set input blob]
-        net.setInput(inputBlob, "data"); //set the network input
+        net.setInput(inputBlob); //set the network input
         //! [Set input blob]
 
         //! [Make forward pass]
-        Mat detection = net.forward("detection_out"); //compute output
+        Mat detection = net.forward(); //compute output
         //! [Make forward pass]
 
         vector<double> layersTimings;
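
The mean is now passed as an explicit per-channel Scalar and swapRB/crop are spelled out, and since the network has a single input and a single output, setInput() and forward() no longer need the "data"/"detection_out" layer names. A sketch of the preprocessing the detector expects, using the constants defined at the top of the sample:

    // (pixel - 127.5) * 0.007843 maps 8-bit values to roughly [-1, 1]
    Mat blob = blobFromImage(frame, inScaleFactor, Size(inWidth, inHeight),
                             Scalar(meanVal, meanVal, meanVal),
                             /*swapRB=*/false, /*crop=*/false);
    net.setInput(blob);              // layer name optional with a single input
    Mat detection = net.forward();   // returns the DetectionOutput blob
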
@@ -155,13 +139,10 @@ int main(int argc, char** argv)
 
         Mat detectionMat(detection.size[2], detection.size[3], CV_32F, detection.ptr<float>());
 
-        frame = frame(crop);
-
-        ostringstream ss;
         if (!outputVideo.isOpened())
         {
-            ss << "FPS: " << 1000/time << " ; time: " << time << " ms";
-            putText(frame, ss.str(), Point(20,20), 0, 0.5, Scalar(0,0,255));
+            putText(frame, format("FPS: %.2f ; time: %.2f ms", 1000.f/time, time),
+                    Point(20,20), 0, 0.5, Scalar(0,0,255));
         }
         else
             cout << "Inference time, ms: " << time << endl;
@@ -175,27 +156,20 @@ int main(int argc, char** argv)
             {
                 size_t objectClass = (size_t)(detectionMat.at<float>(i, 1));
 
-                int xLeftBottom = static_cast<int>(detectionMat.at<float>(i, 3) * frame.cols);
-                int yLeftBottom = static_cast<int>(detectionMat.at<float>(i, 4) * frame.rows);
-                int xRightTop = static_cast<int>(detectionMat.at<float>(i, 5) * frame.cols);
-                int yRightTop = static_cast<int>(detectionMat.at<float>(i, 6) * frame.rows);
-
-                ss.str("");
-                ss << confidence;
-                String conf(ss.str());
-
-                Rect object((int)xLeftBottom, (int)yLeftBottom,
-                            (int)(xRightTop - xLeftBottom),
-                            (int)(yRightTop - yLeftBottom));
+                int left = static_cast<int>(detectionMat.at<float>(i, 3) * frame.cols);
+                int top = static_cast<int>(detectionMat.at<float>(i, 4) * frame.rows);
+                int right = static_cast<int>(detectionMat.at<float>(i, 5) * frame.cols);
+                int bottom = static_cast<int>(detectionMat.at<float>(i, 6) * frame.rows);
 
-                rectangle(frame, object, Scalar(0, 255, 0));
-                String label = String(classNames[objectClass]) + ": " + conf;
+                rectangle(frame, Point(left, top), Point(right, bottom), Scalar(0, 255, 0));
+                String label = format("%s: %.2f", classNames[objectClass], confidence);
                 int baseLine = 0;
                 Size labelSize = getTextSize(label, FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);
-                rectangle(frame, Rect(Point(xLeftBottom, yLeftBottom - labelSize.height),
-                                      Size(labelSize.width, labelSize.height + baseLine)),
+                top = max(top, labelSize.height);
+                rectangle(frame, Point(left, top - labelSize.height),
+                          Point(left + labelSize.width, top + baseLine),
                           Scalar(255, 255, 255), CV_FILLED);
-                putText(frame, label, Point(xLeftBottom, yLeftBottom),
+                putText(frame, label, Point(left, top),
                         FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0,0,0));
             }
         }
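
Each row of detectionMat holds [imageId, classId, confidence, left, top, right, bottom], with the four coordinates normalized to [0, 1]; the renamed left/top/right/bottom variables and the top = max(top, labelSize.height) clamp keep the label background inside the frame. A condensed sketch of the same decoding loop (the threshold constant is illustrative; the sample reads it from a command-line parameter):

    const float confThreshold = 0.2f;   // illustrative value
    for (int i = 0; i < detectionMat.rows; ++i)
    {
        float confidence = detectionMat.at<float>(i, 2);
        if (confidence <= confThreshold)
            continue;
        size_t classId = (size_t)detectionMat.at<float>(i, 1);
        Point topLeft(cvRound(detectionMat.at<float>(i, 3) * frame.cols),
                      cvRound(detectionMat.at<float>(i, 4) * frame.rows));
        Point bottomRight(cvRound(detectionMat.at<float>(i, 5) * frame.cols),
                          cvRound(detectionMat.at<float>(i, 6) * frame.rows));
        rectangle(frame, topLeft, bottomRight, Scalar(0, 255, 0));
        putText(frame, format("%s: %.2f", classNames[classId], confidence),
                topLeft, FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 0, 0));
    }
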