Release 18.08
[platform/upstream/armnn.git] / tests / TfResNext_Quantized-Armnn / TfResNext_Quantized-Armnn.cpp
index 1e1ede3..5817e8b 100644 (file)
@@ -3,7 +3,7 @@
 // See LICENSE file in the project root for full license information.
 //
 #include "../InferenceTest.hpp"
-#include "../ImageNetDatabase.hpp"
+#include "../CaffePreprocessor.hpp"
 #include "armnnTfParser/ITfParser.hpp"
 
 int main(int argc, char* argv[])
@@ -20,11 +20,18 @@ int main(int argc, char* argv[])
 
         armnn::TensorShape inputTensorShape({ 1, 3, 224, 224 });
 
+        using DataType = float;
+        using DatabaseType = CaffePreprocessor;
+        using ParserType = armnnTfParser::ITfParser;
+        using ModelType = InferenceModel<ParserType, DataType>;
+
         // Coverity fix: ClassifierInferenceTestMain() may throw uncaught exceptions.
-        retVal = armnn::test::ClassifierInferenceTestMain<ImageNetDatabase, armnnTfParser::ITfParser>(
+        retVal = armnn::test::ClassifierInferenceTestMain<DatabaseType, ParserType>(
                      argc, argv, "resnext_TF_quantized_for_armnn_team.pb", true,
                      "inputs", "pool1", { 0, 1 },
-                     [&imageSet](const char* dataDir) { return ImageNetDatabase(dataDir, 224, 224, imageSet); },
+                     [&imageSet](const char* dataDir, const ModelType &) {
+                         return DatabaseType(dataDir, 224, 224, imageSet);
+                     },
                      &inputTensorShape);
     }
     catch (const std::exception& e)