Set zero confidences in case of no detections
authorDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Mon, 30 Oct 2017 07:17:57 +0000 (10:17 +0300)
committerDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Mon, 30 Oct 2017 07:17:57 +0000 (10:17 +0300)
modules/dnn/src/layers/detection_output_layer.cpp
modules/dnn/test/test_caffe_importer.cpp

index 2e381b2..ae4774a 100644 (file)
@@ -240,6 +240,9 @@ public:
 
         if (numKept == 0)
         {
+            // Set confidences to zeros.
+            Range ranges[] = {Range::all(), Range::all(), Range::all(), Range(2, 3)};
+            outputs[0](ranges).setTo(0);
             return;
         }
         int outputShape[] = {1, 1, (int)numKept, 7};
index 02ed416..a1837ae 100644 (file)
@@ -139,6 +139,35 @@ TEST(Reproducibility_SSD, Accuracy)
     normAssert(ref, out);
 }
 
+TEST(Reproducibility_MobileNet_SSD, Accuracy)
+{
+    const string proto = findDataFile("dnn/MobileNetSSD_deploy.prototxt", false);
+    const string model = findDataFile("dnn/MobileNetSSD_deploy.caffemodel", false);
+    Net net = readNetFromCaffe(proto, model);
+
+    Mat sample = imread(_tf("street.png"));
+
+    Mat inp = blobFromImage(sample, 1.0f / 127.5, Size(300, 300), Scalar(127.5, 127.5, 127.5), false);
+    net.setInput(inp);
+    Mat out = net.forward();
+
+    Mat ref = blobFromNPY(_tf("mobilenet_ssd_caffe_out.npy"));
+    normAssert(ref, out);
+
+    // Check that detections aren't preserved.
+    inp.setTo(0.0f);
+    net.setInput(inp);
+    out = net.forward();
+
+    const int numDetections = out.size[2];
+    ASSERT_NE(numDetections, 0);
+    for (int i = 0; i < numDetections; ++i)
+    {
+        float confidence = out.ptr<float>(0, 0, i)[2];
+        ASSERT_EQ(confidence, 0);
+    }
+}
+
 TEST(Reproducibility_ResNet50, Accuracy)
 {
     Net net = readNetFromCaffe(findDataFile("dnn/ResNet-50-deploy.prototxt", false),