IVGCVSW-2436 Modify MobileNet SSD inference test
authorNarumol Prangnawarat <narumol.prangnawarat@arm.com>
Fri, 15 Feb 2019 17:34:51 +0000 (17:34 +0000)
committerNarumol Prangnawarat <narumol.prangnawarat@arm.com>
Wed, 20 Feb 2019 11:06:06 +0000 (11:06 +0000)
 * change MobileNet SSD input to uint8
 * get quantization scale and offset from the model
 * change data layout to NHWC as TensorFlow lite layout
 * update expected output as result from TfLite with quantized data

Change-Id: I07104d56286893935779169356234de53f1c9492
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
tests/MobileNetSsdDatabase.hpp
tests/MobileNetSsdInferenceTest.hpp

index e3a28d1..cac5587 100644 (file)
@@ -11,6 +11,7 @@
 #include <vector>
 
 #include <armnn/TypesUtils.hpp>
+#include <backendsCommon/test/QuantizeHelper.hpp>
 
 #include <boost/log/trivial.hpp>
 #include <boost/numeric/conversion/cast.hpp>
@@ -26,25 +27,27 @@ namespace
 struct MobileNetSsdTestCaseData
 {
     MobileNetSsdTestCaseData(
-        std::vector<float> inputData,
+        std::vector<uint8_t> inputData,
         std::vector<DetectedObject> expectedOutput)
         : m_InputData(std::move(inputData))
         , m_ExpectedOutput(std::move(expectedOutput))
     {}
 
-    std::vector<float>          m_InputData;
+    std::vector<uint8_t>        m_InputData;
     std::vector<DetectedObject> m_ExpectedOutput;
 };
 
 class MobileNetSsdDatabase
 {
 public:
-    explicit MobileNetSsdDatabase(const std::string& imageDir);
+    explicit MobileNetSsdDatabase(const std::string& imageDir, float scale, int offset);
 
     std::unique_ptr<MobileNetSsdTestCaseData> GetTestCaseData(unsigned int testCaseId);
 
 private:
     std::string m_ImageDir;
+    float m_Scale;
+    int m_Offset;
 };
 
 constexpr unsigned int k_MobileNetSsdImageWidth  = 300u;
@@ -56,12 +59,14 @@ const std::array<ObjectDetectionInput, 1> g_PerTestCaseInput =
     ObjectDetectionInput
     {
         "Cat.jpg",
-        DetectedObject(16, BoundingBox(0.21678525f, 0.0859828f, 0.9271242f, 0.9453231f), 0.79296875f)
+        DetectedObject(16, BoundingBox(0.208961248f, 0.0852333307f, 0.92757535f, 0.940263629f), 0.79296875f)
     }
 };
 
-MobileNetSsdDatabase::MobileNetSsdDatabase(const std::string& imageDir)
+MobileNetSsdDatabase::MobileNetSsdDatabase(const std::string& imageDir, float scale, int offset)
     : m_ImageDir(imageDir)
+    , m_Scale(scale)
+    , m_Offset(offset)
 {}
 
 std::unique_ptr<MobileNetSsdTestCaseData> MobileNetSsdDatabase::GetTestCaseData(unsigned int testCaseId)
@@ -72,7 +77,7 @@ std::unique_ptr<MobileNetSsdTestCaseData> MobileNetSsdDatabase::GetTestCaseData(
 
     // Load test case input
     const std::string imagePath = m_ImageDir + testCaseInput.first;
-    std::vector<float> imageData;
+    std::vector<uint8_t> imageData;
     try
     {
         InferenceTestImage image(imagePath.c_str());
@@ -86,7 +91,8 @@ std::unique_ptr<MobileNetSsdTestCaseData> MobileNetSsdDatabase::GetTestCaseData(
         }
 
         // Get image data as a vector of floats
-        imageData = GetImageDataInArmNnLayoutAsNormalizedFloats(ImageChannelLayout::Rgb, image);
+        std::vector<float> floatImageData = GetImageDataAsNormalizedFloats(ImageChannelLayout::Rgb, image);
+        imageData = QuantizedVector<uint8_t>(m_Scale, m_Offset, floatImageData);
     }
     catch (const InferenceTestImageException& e)
     {
index 0091009..10ee1dc 100644 (file)
@@ -126,7 +126,7 @@ public:
     }
 
 private:
-    static constexpr unsigned int k_NumDetections = 10u;
+    static constexpr unsigned int k_NumDetections = 1u;
 
     static constexpr unsigned int k_OutputSize1 = k_NumDetections * 4u;
     static constexpr unsigned int k_OutputSize2 = k_NumDetections;
@@ -169,8 +169,8 @@ public:
         {
             return false;
         }
-
-        m_Database = std::make_unique<MobileNetSsdDatabase>(m_DataDir.c_str());
+        std::pair<float, int32_t> qParams = m_Model->GetQuantizationParams();
+        m_Database = std::make_unique<MobileNetSsdDatabase>(m_DataDir.c_str(), qParams.first, qParams.second);
         if (!m_Database)
         {
             return false;