IVGCVSW-3193 Allow ExecuteNetwork to have qasymm8 input type
and add option to quantize float inputs to qasymm8

author    Narumol Prangnawarat <narumol.prangnawarat@arm.com>
          Wed, 26 Jun 2019 14:10:46 +0000 (15:10 +0100)
committer Narumol Prangnawarat <narumol.prangnawarat@arm.com>
          Wed, 26 Jun 2019 17:22:57 +0000 (17:22 +0000)

Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: I54b13b8b53c31c05658fe9c310ca5a66df759aa5
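
The new --quantize-input path converts each parsed float to qasymm8 using the
quantization scale and offset reported by the input binding. A minimal sketch
of that arithmetic, assuming the usual asymmetric quantization scheme (the
helper name below is illustrative and not part of this patch):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Illustrative helper (not part of this patch): standard asymmetric
    // quantization, mapping a float to the uint8 range using the scale and
    // offset taken from the input binding's TensorInfo.
    uint8_t QuantizeToQAsymm8(float value, float scale, int32_t offset)
    {
        int32_t q = static_cast<int32_t>(std::round(value / scale)) + offset;
        q = std::min(255, std::max(0, q));   // clamp to the uint8 range
        return static_cast<uint8_t>(q);
    }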

tests/ExecuteNetwork/ExecuteNetwork.cpp
tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp

index 60353db..a8f3b3d 100644 (file)
@@ -66,6 +66,10 @@ int main(int argc, const char* argv[])
             ("input-type,y",po::value(&inputTypes), "The type of the input tensors in the network separated by comma. "
              "If unset, defaults to \"float\" for all defined inputs. "
              "Accepted values (float, int or qasymm8)")
+            ("quantize-input,q",po::bool_switch()->default_value(false),
+             "If this option is enabled, all float inputs will be quantized to qasymm8. "
+             "If unset, default to not quantized. "
+             "Accepted values (true or false)")
             ("output-type,z",po::value(&outputTypes),
              "The type of the output tensors in the network separated by comma. "
              "If unset, defaults to \"float\" for all defined outputs. "
@@ -119,6 +123,7 @@ int main(int argc, const char* argv[])
     bool concurrent = vm["concurrent"].as<bool>();
     bool enableProfiling = vm["event-based-profiling"].as<bool>();
     bool enableFp16TurboMode = vm["fp16-turbo-mode"].as<bool>();
+    bool quantizeInput = vm["quantize-input"].as<bool>();
 
     // Check whether we have to load test cases from a file.
     if (CheckOption(vm, "test-cases"))
@@ -220,7 +225,7 @@ int main(int argc, const char* argv[])
         }
 
         return RunTest(modelFormat, inputTensorShapes, computeDevices, modelPath, inputNames,
-                       inputTensorDataFilePaths, inputTypes, outputTypes, outputNames,
+                       inputTensorDataFilePaths, inputTypes, quantizeInput, outputTypes, outputNames,
                        enableProfiling, enableFp16TurboMode, thresholdTime, subgraphId);
     }
 }
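
The option is declared with po::bool_switch(), so passing -q / --quantize-input
on the command line flips the value to true and vm["quantize-input"].as<bool>()
reads it back. A small standalone sketch of that pattern, assuming only
Boost.Program_options (the option name mirrors the patch; everything else is
illustrative):

    #include <boost/program_options.hpp>
    #include <iostream>

    namespace po = boost::program_options;

    int main(int argc, const char* argv[])
    {
        po::options_description desc("Options");
        desc.add_options()
            ("quantize-input,q", po::bool_switch()->default_value(false),
             "Quantize float inputs to qasymm8.");

        po::variables_map vm;
        po::store(po::parse_command_line(argc, argv, desc), vm);
        po::notify(vm);

        // false unless -q / --quantize-input was given on the command line
        bool quantizeInput = vm["quantize-input"].as<bool>();
        std::cout << std::boolalpha << quantizeInput << "\n";
        return 0;
    }
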
index 9d7e368..440dcf9 100644 (file)
@@ -146,6 +146,13 @@ auto ParseDataArray<armnn::DataType::Signed32>(std::istream & stream)
 }
 
 template<>
+auto ParseDataArray<armnn::DataType::QuantisedAsymm8>(std::istream& stream)
+{
+    return ParseArrayImpl<uint8_t>(stream,
+                                   [](const std::string& s) { return boost::numeric_cast<uint8_t>(std::stoi(s)); });
+}
+
+template<>
 auto ParseDataArray<armnn::DataType::QuantisedAsymm8>(std::istream& stream,
                                                       const float& quantizationScale,
                                                       const int32_t& quantizationOffset)
@@ -159,7 +166,6 @@ auto ParseDataArray<armnn::DataType::QuantisedAsymm8>(std::istream& stream,
                                                                      quantizationOffset));
                                    });
 }
-
 std::vector<unsigned int> ParseArray(std::istream& stream)
 {
     return ParseArrayImpl<unsigned int>(stream,
@@ -252,6 +258,7 @@ int MainImpl(const char* modelPath,
              const std::vector<std::unique_ptr<armnn::TensorShape>>& inputTensorShapes,
              const std::vector<string>& inputTensorDataFilePaths,
              const std::vector<string>& inputTypes,
+             bool quantizeInput,
              const std::vector<string>& outputTypes,
              const std::vector<string>& outputNames,
              bool enableProfiling,
@@ -297,8 +304,19 @@ int MainImpl(const char* modelPath,
 
             if (inputTypes[i].compare("float") == 0)
             {
-                inputDataContainers.push_back(
-                    ParseDataArray<armnn::DataType::Float32>(inputTensorFile));
+                if (quantizeInput)
+                {
+                    auto inputBinding = model.GetInputBindingInfo();
+                    inputDataContainers.push_back(
+                            ParseDataArray<armnn::DataType::QuantisedAsymm8>(inputTensorFile,
+                                                         inputBinding.second.GetQuantizationScale(),
+                                                         inputBinding.second.GetQuantizationOffset()));
+                }
+                else
+                {
+                    inputDataContainers.push_back(
+                            ParseDataArray<armnn::DataType::Float32>(inputTensorFile));
+                }
             }
             else if (inputTypes[i].compare("int") == 0)
             {
@@ -307,11 +325,8 @@ int MainImpl(const char* modelPath,
             }
             else if (inputTypes[i].compare("qasymm8") == 0)
             {
-                auto inputBinding = model.GetInputBindingInfo();
                 inputDataContainers.push_back(
-                    ParseDataArray<armnn::DataType::QuantisedAsymm8>(inputTensorFile,
-                                                                     inputBinding.second.GetQuantizationScale(),
-                                                                     inputBinding.second.GetQuantizationOffset()));
+                    ParseDataArray<armnn::DataType::QuantisedAsymm8>(inputTensorFile));
             }
             else
             {
@@ -396,6 +411,7 @@ int RunTest(const std::string& format,
             const std::string& inputNames,
             const std::string& inputTensorDataFilePaths,
             const std::string& inputTypes,
+            bool quantizeInput,
             const std::string& outputTypes,
             const std::string& outputNames,
             bool enableProfiling,
@@ -498,7 +514,7 @@ int RunTest(const std::string& format,
     return MainImpl<armnnDeserializer::IDeserializer, float>(
         modelPath.c_str(), isModelBinary, computeDevice,
         inputNamesVector, inputTensorShapes,
-        inputTensorDataFilePathsVector, inputTypesVector,
+        inputTensorDataFilePathsVector, inputTypesVector, quantizeInput,
         outputTypesVector, outputNamesVector, enableProfiling,
         enableFp16TurboMode, thresholdTime, subgraphId, runtime);
 #else
@@ -512,8 +528,9 @@ int RunTest(const std::string& format,
         return MainImpl<armnnCaffeParser::ICaffeParser, float>(modelPath.c_str(), isModelBinary, computeDevice,
                                                                inputNamesVector, inputTensorShapes,
                                                                inputTensorDataFilePathsVector, inputTypesVector,
-                                                               outputTypesVector, outputNamesVector, enableProfiling,
-                                                               enableFp16TurboMode, thresholdTime, subgraphId, runtime);
+                                                               quantizeInput, outputTypesVector, outputNamesVector,
+                                                               enableProfiling, enableFp16TurboMode, thresholdTime,
+                                                               subgraphId, runtime);
 #else
         BOOST_LOG_TRIVIAL(fatal) << "Not built with Caffe parser support.";
         return EXIT_FAILURE;
@@ -525,8 +542,9 @@ int RunTest(const std::string& format,
     return MainImpl<armnnOnnxParser::IOnnxParser, float>(modelPath.c_str(), isModelBinary, computeDevice,
                                                          inputNamesVector, inputTensorShapes,
                                                          inputTensorDataFilePathsVector, inputTypesVector,
-                                                         outputTypesVector, outputNamesVector, enableProfiling,
-                                                         enableFp16TurboMode, thresholdTime, subgraphId, runtime);
+                                                         quantizeInput, outputTypesVector, outputNamesVector,
+                                                         enableProfiling, enableFp16TurboMode, thresholdTime,
+                                                         subgraphId, runtime);
 #else
     BOOST_LOG_TRIVIAL(fatal) << "Not built with Onnx parser support.";
     return EXIT_FAILURE;
@@ -538,8 +556,9 @@ int RunTest(const std::string& format,
         return MainImpl<armnnTfParser::ITfParser, float>(modelPath.c_str(), isModelBinary, computeDevice,
                                                          inputNamesVector, inputTensorShapes,
                                                          inputTensorDataFilePathsVector, inputTypesVector,
-                                                         outputTypesVector, outputNamesVector, enableProfiling,
-                                                         enableFp16TurboMode, thresholdTime, subgraphId, runtime);
+                                                         quantizeInput, outputTypesVector, outputNamesVector,
+                                                         enableProfiling, enableFp16TurboMode, thresholdTime,
+                                                         subgraphId, runtime);
 #else
         BOOST_LOG_TRIVIAL(fatal) << "Not built with Tensorflow parser support.";
         return EXIT_FAILURE;
@@ -557,9 +576,9 @@ int RunTest(const std::string& format,
         return MainImpl<armnnTfLiteParser::ITfLiteParser, float>(modelPath.c_str(), isModelBinary, computeDevice,
                                                                  inputNamesVector, inputTensorShapes,
                                                                  inputTensorDataFilePathsVector, inputTypesVector,
-                                                                 outputTypesVector, outputNamesVector, enableProfiling,
-                                                                 enableFp16TurboMode, thresholdTime, subgraphId,
-                                                                 runtime);
+                                                                 quantizeInput, outputTypesVector, outputNamesVector,
+                                                                 enableProfiling, enableFp16TurboMode, thresholdTime,
+                                                                 subgraphId, runtime);
 #else
         BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat <<
             "'. Please include 'caffe', 'tensorflow', 'tflite' or 'onnx'";
@@ -616,6 +635,10 @@ int RunCsvTest(const armnnUtils::CsvRow &csvRow, const std::shared_ptr<armnn::IR
         ("input-type,y",po::value(&inputTypes), "The type of the input tensors in the network separated by comma. "
          "If unset, defaults to \"float\" for all defined inputs. "
          "Accepted values (float, int or qasymm8).")
+        ("quantize-input,q",po::bool_switch()->default_value(false),
+         "If this option is enabled, all float inputs will be quantized to qasymm8. "
+         "If unset, default to not quantized. "
+         "Accepted values (true or false)")
         ("output-type,z",po::value(&outputTypes), "The type of the output tensors in the network separated by comma. "
          "If unset, defaults to \"float\" for all defined outputs. "
          "Accepted values (float, int or qasymm8).")
@@ -655,6 +678,9 @@ int RunCsvTest(const armnnUtils::CsvRow &csvRow, const std::shared_ptr<armnn::IR
         return EXIT_FAILURE;
     }
 
+    // Get the value of the switch arguments.
+    bool quantizeInput = vm["quantize-input"].as<bool>();
+
     // Get the preferred order of compute devices.
     std::vector<armnn::BackendId> computeDevices = vm["compute"].as<std::vector<armnn::BackendId>>();
 
@@ -671,6 +697,6 @@ int RunCsvTest(const armnnUtils::CsvRow &csvRow, const std::shared_ptr<armnn::IR
     }
 
     return RunTest(modelFormat, inputTensorShapes, computeDevices, modelPath, inputNames,
-                   inputTensorDataFilePaths, inputTypes, outputTypes, outputNames,
+                   inputTensorDataFilePaths, inputTypes, quantizeInput, outputTypes, outputNames,
                    enableProfiling, enableFp16TurboMode, thresholdTime, subgraphId);
 }
\ No newline at end of file
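
With the second hunk above, a file supplied for a qasymm8 input is now read as
raw uint8 values rather than as floats to be quantized; quantization of float
data happens only when --quantize-input is given. A minimal sketch of that raw
parse, assuming whitespace-separated integer values (the tokenizing loop here
stands in for ParseArrayImpl, which is not shown in this diff):

    #include <boost/numeric/conversion/cast.hpp>
    #include <cstdint>
    #include <istream>
    #include <string>
    #include <vector>

    // Read whitespace-separated integers from a stream as uint8_t values,
    // mirroring the new ParseDataArray<QuantisedAsymm8>(stream) overload.
    std::vector<uint8_t> ParseQAsymm8(std::istream& stream)
    {
        std::vector<uint8_t> result;
        std::string token;
        while (stream >> token)
        {
            // numeric_cast throws if the parsed value does not fit in uint8_t
            result.push_back(boost::numeric_cast<uint8_t>(std::stoi(token)));
        }
        return result;
    }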