Fix ENet test
authorDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Fri, 19 Oct 2018 14:43:26 +0000 (17:43 +0300)
committerDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Fri, 19 Oct 2018 14:43:26 +0000 (17:43 +0300)
modules/dnn/test/test_torch_importer.cpp

index dd7d975..0b84445 100644 (file)
@@ -287,6 +287,46 @@ TEST_P(Test_Torch_nets, OpenFace_accuracy)
     normAssert(out, outRef, "", default_l1, default_lInf);
 }
 
+static Mat getSegmMask(const Mat& scores)
+{
+    const int rows = scores.size[2];
+    const int cols = scores.size[3];
+    const int numClasses = scores.size[1];
+
+    Mat maxCl = Mat::zeros(rows, cols, CV_8UC1);
+    Mat maxVal(rows, cols, CV_32FC1, Scalar(0));
+    for (int ch = 0; ch < numClasses; ch++)
+    {
+        for (int row = 0; row < rows; row++)
+        {
+            const float *ptrScore = scores.ptr<float>(0, ch, row);
+            uint8_t *ptrMaxCl = maxCl.ptr<uint8_t>(row);
+            float *ptrMaxVal = maxVal.ptr<float>(row);
+            for (int col = 0; col < cols; col++)
+            {
+                if (ptrScore[col] > ptrMaxVal[col])
+                {
+                    ptrMaxVal[col] = ptrScore[col];
+                    ptrMaxCl[col] = (uchar)ch;
+                }
+            }
+        }
+    }
+    return maxCl;
+}
+
+// Computer per-class intersection over union metric.
+static void normAssertSegmentation(const Mat& ref, const Mat& test)
+{
+    CV_Assert_N(ref.dims == 4, test.dims == 4);
+    const int numClasses = ref.size[1];
+    CV_Assert(numClasses == test.size[1]);
+
+    Mat refMask = getSegmMask(ref);
+    Mat testMask = getSegmMask(test);
+    EXPECT_EQ(countNonZero(refMask != testMask), 0);
+}
+
 TEST_P(Test_Torch_nets, ENet_accuracy)
 {
     checkBackend();
@@ -313,14 +353,16 @@ TEST_P(Test_Torch_nets, ENet_accuracy)
     // Due to numerical instability in Pooling-Unpooling layers (indexes jittering)
     // thresholds for ENet must be changed. Accuracy of results was checked on
     // Cityscapes dataset and difference in mIOU with Torch is 10E-4%
-    normAssert(ref, out, "", 0.00044, /*target == DNN_TARGET_CPU ? 0.453 : */0.5);
+    normAssert(ref, out, "", 0.00044, /*target == DNN_TARGET_CPU ? 0.453 : */0.552);
+    normAssertSegmentation(ref, out);
 
     const int N = 3;
     for (int i = 0; i < N; i++)
     {
         net.setInput(inputBlob, "");
         Mat out = net.forward();
-        normAssert(ref, out, "", 0.00044, /*target == DNN_TARGET_CPU ? 0.453 : */0.5);
+        normAssert(ref, out, "", 0.00044, /*target == DNN_TARGET_CPU ? 0.453 : */0.552);
+        normAssertSegmentation(ref, out);
     }
 }