//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "../ImageTensorGenerator/ImageTensorGenerator.hpp"
#include "../InferenceTest.hpp"
#include "ModelAccuracyChecker.hpp"
#include "armnnDeserializer/IDeserializer.hpp"

#include <boost/filesystem.hpp>
#include <boost/program_options/variables_map.hpp>
#include <boost/range/iterator_range.hpp>

#include <map>

using namespace armnn::test;

/** Load image names and ground-truth labels from the image directory and the ground-truth label file
 *
 * @pre \p validationLabelPath exists and is a valid regular file
 * @pre \p imageDirectoryPath exists and is a valid directory
 * @pre the labels in the validation file correspond to the images sorted lexicographically by file name
 * @pre image indices start at 1
 * @pre \p begIndex and \p endIndex are end-inclusive
 *
 * @param[in] validationLabelPath Path to validation label file
 * @param[in] imageDirectoryPath Path to directory containing validation images
 * @param[in] begIndex Begin index of images to be loaded. Inclusive
 * @param[in] endIndex End index of images to be loaded. Inclusive
 * @param[in] blacklistPath Path to blacklist file
 * @return A map mapping image file names to their corresponding ground-truth labels
 */
map<std::string, std::string> LoadValidationImageFilenamesAndLabels(const string& validationLabelPath,
                                                                    const string& imageDirectoryPath,
                                                                    size_t begIndex = 0,
                                                                    size_t endIndex = 0,
                                                                    const string& blacklistPath = "");
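
// Note: each line of the validation label file is expected to hold the ground-truth label of one image,
// with line 1 corresponding to the lexicographically first image file, line 2 to the second, and so on
// (for example, line 1 would label ILSVRC2012_val_00000001.JPEG in a standard ImageNet validation set).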

/** Load model output labels from file
 *
 * @pre \p modelOutputLabelsPath exists and is a regular file
 *
 * @param[in] modelOutputLabelsPath Path to model output labels file
 * @return A vector of labels, each of which is described by a list of category names
 */
std::vector<armnnUtils::LabelCategoryNames> LoadModelOutputLabels(const std::string& modelOutputLabelsPath);
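
// Note: each line of the model output labels file is parsed as "<index>:<category names>", where the category
// names after the last ':' are comma-separated. For instance, a line such as "0:tench, Tinca tinca" would yield
// the category names {"tench", "Tinca tinca"}; the exact contents depend on the model being evaluated.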

int main(int argc, char* argv[])
{
    try
    {
        using namespace boost::filesystem;
        armnn::LogSeverity level = armnn::LogSeverity::Debug;
        armnn::ConfigureLogging(true, true, level);

        // Set up the program options
        namespace po = boost::program_options;

        std::vector<armnn::BackendId> computeDevice;
        std::vector<armnn::BackendId> defaultBackends = { armnn::Compute::CpuAcc, armnn::Compute::CpuRef };
        std::string modelPath;
        std::string modelFormat;
        std::string dataDir;
        std::string inputName;
        std::string inputLayout;
        std::string outputName;
        std::string modelOutputLabelsPath;
        std::string validationLabelPath;
        std::string validationRange;
        std::string blacklistPath;

        const std::string backendsMessage = "Which device to run layers on by default. Possible choices: "
                                            + armnn::BackendRegistryInstance().GetBackendIdsAsString();
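
        // Example invocation (binary name, paths and tensor names below are illustrative and depend on the
        // model under test):
        //   ModelAccuracyTool -m mobilenet_v1.armnn -f tflite -i input -o output
        //                     -d /path/to/ImageNet/validation/images -p model_output_labels.txt
        //                     -v validation_labels.txt -r 1:1000 -c CpuAcc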

        po::options_description desc("Options");
        try
        {
            // Adds generic options needed to run the Accuracy Tool.
            desc.add_options()
                ("help,h", "Display help messages")
                ("model-path,m", po::value<std::string>(&modelPath)->required(), "Path to armnn format model file")
                ("model-format,f", po::value<std::string>(&modelFormat)->required(),
                 "The model format. Supported values: caffe, tensorflow, tflite")
                ("input-name,i", po::value<std::string>(&inputName)->required(),
                 "Identifier of the input tensors in the network, separated by comma.")
                ("output-name,o", po::value<std::string>(&outputName)->required(),
                 "Identifier of the output tensors in the network, separated by comma.")
                ("data-dir,d", po::value<std::string>(&dataDir)->required(),
                 "Path to directory containing the ImageNet test data")
                ("model-output-labels,p", po::value<std::string>(&modelOutputLabelsPath)->required(),
                 "Path to model output labels file.")
                ("validation-labels-path,v", po::value<std::string>(&validationLabelPath)->required(),
                 "Path to ImageNet validation label file")
                ("data-layout,l", po::value<std::string>(&inputLayout)->default_value("NHWC"),
                 "Data layout. Supported values: NHWC, NCHW. Default: NHWC")
                ("compute,c", po::value<std::vector<armnn::BackendId>>(&computeDevice)->default_value(defaultBackends),
                 backendsMessage.c_str())
                ("validation-range,r", po::value<std::string>(&validationRange)->default_value("1:0"),
                 "The range of the images to be evaluated. Specified in the form <begin index>:<end index>. "
                 "The index starts at 1 and the range is inclusive. "
                 "By default the evaluation will be performed on all images.")
                ("blacklist-path,b", po::value<std::string>(&blacklistPath)->default_value(""),
                 "Path to a blacklist file where each line denotes the index of an image to be "
                 "excluded from evaluation.");
        }
        catch (const std::exception& e)
        {
            // Coverity points out that default_value(...) can throw a bad_lexical_cast,
            // and that desc.add_options() can throw boost::io::too_few_args.
            // They really won't in any of these cases.
            ARMNN_ASSERT_MSG(false, "Caught unexpected exception");
            std::cerr << "Fatal internal error: " << e.what() << std::endl;
            return 1;
        }

        po::variables_map vm;
        try
        {
            po::store(po::parse_command_line(argc, argv, desc), vm);

            if (vm.count("help"))
            {
                std::cout << desc << std::endl;
                return 1;
            }
            po::notify(vm);
        }
        catch (po::error& e)
        {
            std::cerr << e.what() << std::endl << std::endl;
            std::cerr << desc << std::endl;
            return 1;
        }

        // Check that the requested backends are all valid
        std::string invalidBackends;
        if (!CheckRequestedBackendsAreValid(computeDevice, armnn::Optional<std::string&>(invalidBackends)))
        {
            ARMNN_LOG(fatal) << "The list of preferred devices contains invalid backend IDs: "
                             << invalidBackends;
            return EXIT_FAILURE;
        }
        armnn::Status status;

        // Create the runtime
        armnn::IRuntime::CreationOptions options;
        armnn::IRuntimePtr runtime(armnn::IRuntime::Create(options));
        std::ifstream file(modelPath);

        // Create the parser
        using IParser = armnnDeserializer::IDeserializer;
        auto armnnparser(IParser::Create());

        // Create the network from the serialized model file
        armnn::INetworkPtr network = armnnparser->CreateNetworkFromBinary(file);

        // Optimizes the network.
        armnn::IOptimizedNetworkPtr optimizedNet(nullptr, nullptr);
        try
        {
            optimizedNet = armnn::Optimize(*network, computeDevice, runtime->GetDeviceSpec());
        }
        catch (const armnn::Exception& e)
        {
            std::stringstream message;
            message << "armnn::Exception (" << e.what() << ") caught from optimize.";
            ARMNN_LOG(fatal) << message.str();
            return 1;
        }

        // Loads the network into the runtime.
        armnn::NetworkId networkId;
        status = runtime->LoadNetwork(networkId, std::move(optimizedNet));
        if (status == armnn::Status::Failure)
        {
            ARMNN_LOG(fatal) << "armnn::IRuntime: Failed to load network";
            return 1;
        }

        using BindingPointInfo = InferenceModelInternal::BindingPointInfo;

        const armnnDeserializer::BindingPointInfo&
            inputBindingInfo = armnnparser->GetNetworkInputBindingInfo(0, inputName);

        std::pair<armnn::LayerBindingId, armnn::TensorInfo>
            m_InputBindingInfo(inputBindingInfo.m_BindingId, inputBindingInfo.m_TensorInfo);
        std::vector<BindingPointInfo> inputBindings = { m_InputBindingInfo };

        const armnnDeserializer::BindingPointInfo&
            outputBindingInfo = armnnparser->GetNetworkOutputBindingInfo(0, outputName);

        std::pair<armnn::LayerBindingId, armnn::TensorInfo>
            m_OutputBindingInfo(outputBindingInfo.m_BindingId, outputBindingInfo.m_TensorInfo);
        std::vector<BindingPointInfo> outputBindings = { m_OutputBindingInfo };
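
        // Each binding is a (LayerBindingId, TensorInfo) pair; these vectors are later passed to
        // armnnUtils::MakeInputTensors/MakeOutputTensors to wrap the per-image input and output buffers.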

        // Load model output labels
        if (modelOutputLabelsPath.empty() || !boost::filesystem::exists(modelOutputLabelsPath) ||
            !boost::filesystem::is_regular_file(modelOutputLabelsPath))
        {
            ARMNN_LOG(fatal) << "Invalid model output labels path at " << modelOutputLabelsPath;
        }
        const std::vector<armnnUtils::LabelCategoryNames> modelOutputLabels =
            LoadModelOutputLabels(modelOutputLabelsPath);

        // Parse the begin and end image indices
        std::vector<std::string> imageIndexStrs = armnnUtils::SplitBy(validationRange, ":");
        size_t imageBegIndex;
        size_t imageEndIndex;
        if (imageIndexStrs.size() != 2)
        {
            ARMNN_LOG(fatal) << "Invalid validation range specification: Invalid format " << validationRange;
            return 1;
        }
        try
        {
            imageBegIndex = std::stoul(imageIndexStrs[0]);
            imageEndIndex = std::stoul(imageIndexStrs[1]);
        }
        catch (const std::exception& e)
        {
            ARMNN_LOG(fatal) << "Invalid validation range specification: " << validationRange;
            return 1;
        }

        // Validate the blacklist file if one is specified
        if (!blacklistPath.empty() &&
            !(boost::filesystem::exists(blacklistPath) && boost::filesystem::is_regular_file(blacklistPath)))
        {
            ARMNN_LOG(fatal) << "Invalid path to blacklist file at " << blacklistPath;
            return 1;
        }

        path pathToDataDir(dataDir);
        const map<std::string, std::string> imageNameToLabel = LoadValidationImageFilenamesAndLabels(
            validationLabelPath, pathToDataDir.string(), imageBegIndex, imageEndIndex, blacklistPath);
        armnnUtils::ModelAccuracyChecker checker(imageNameToLabel, modelOutputLabels);
        using TContainer = boost::variant<std::vector<float>, std::vector<int>, std::vector<uint8_t>>;
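
        // TContainer holds the data of a single input or output tensor; the variant covers the three
        // tensor data types handled below (Float32, Signed32 and QAsymmU8).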

        if (ValidateDirectory(dataDir))
        {
            InferenceModel<armnnDeserializer::IDeserializer, float>::Params params;
            params.m_ModelPath = modelPath;
            params.m_IsModelBinary = true;
            params.m_ComputeDevices = computeDevice;
            params.m_InputBindings.push_back(inputName);
            params.m_OutputBindings.push_back(outputName);

            using TParser = armnnDeserializer::IDeserializer;
            InferenceModel<TParser, float> model(params, false);
            // Get input tensor information
            const armnn::TensorInfo& inputTensorInfo   = model.GetInputBindingInfo().second;
            const armnn::TensorShape& inputTensorShape = inputTensorInfo.GetShape();
            const armnn::DataType& inputTensorDataType = inputTensorInfo.GetDataType();
            armnn::DataLayout inputTensorDataLayout;
            if (inputLayout == "NCHW")
            {
                inputTensorDataLayout = armnn::DataLayout::NCHW;
            }
            else if (inputLayout == "NHWC")
            {
                inputTensorDataLayout = armnn::DataLayout::NHWC;
            }
            else
            {
                ARMNN_LOG(fatal) << "Invalid Data layout: " << inputLayout;
                return 1;
            }
            const unsigned int inputTensorWidth =
                inputTensorDataLayout == armnn::DataLayout::NCHW ? inputTensorShape[3] : inputTensorShape[2];
            const unsigned int inputTensorHeight =
                inputTensorDataLayout == armnn::DataLayout::NCHW ? inputTensorShape[2] : inputTensorShape[1];
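            // NCHW shapes are [batch, channels, height, width], so H and W are at indices 2 and 3;
            // NHWC shapes are [batch, height, width, channels], so H and W are at indices 1 and 2.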
            // Get output tensor info
            const unsigned int outputNumElements = model.GetOutputSize();
            // Check that the output tensor shape is valid
            if (modelOutputLabels.size() != outputNumElements)
            {
                ARMNN_LOG(fatal) << "Number of output elements: " << outputNumElements
                                 << ", mismatches the number of output labels: " << modelOutputLabels.size();
                return 1;
            }

            const unsigned int batchSize = 1;
            // Get normalisation parameters
            SupportedFrontend modelFrontend;
            if (modelFormat == "caffe")
            {
                modelFrontend = SupportedFrontend::Caffe;
            }
            else if (modelFormat == "tensorflow")
            {
                modelFrontend = SupportedFrontend::TensorFlow;
            }
            else if (modelFormat == "tflite")
            {
                modelFrontend = SupportedFrontend::TFLite;
            }
            else
            {
                ARMNN_LOG(fatal) << "Unsupported frontend: " << modelFormat;
                return 1;
            }
            const NormalizationParameters& normParams = GetNormalizationParameters(modelFrontend, inputTensorDataType);
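            // The normalisation parameters depend on the training framework and on the input data type;
            // PrepareImageTensor applies them when each image is resized and converted below.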
            for (const auto& imageEntry : imageNameToLabel)
            {
                const std::string imageName = imageEntry.first;
                std::cout << "Processing image: " << imageName << "\n";

                vector<TContainer> inputDataContainers;
                vector<TContainer> outputDataContainers;

                auto imagePath = pathToDataDir / boost::filesystem::path(imageName);
                switch (inputTensorDataType)
                {
                    case armnn::DataType::Signed32:
                        inputDataContainers.push_back(
                            PrepareImageTensor<int>(imagePath.string(),
                                                    inputTensorWidth, inputTensorHeight,
                                                    normParams,
                                                    batchSize,
                                                    inputTensorDataLayout));
                        outputDataContainers = { vector<int>(outputNumElements) };
                        break;
                    case armnn::DataType::QAsymmU8:
                        inputDataContainers.push_back(
                            PrepareImageTensor<uint8_t>(imagePath.string(),
                                                        inputTensorWidth, inputTensorHeight,
                                                        normParams,
                                                        batchSize,
                                                        inputTensorDataLayout));
                        outputDataContainers = { vector<uint8_t>(outputNumElements) };
                        break;
                    case armnn::DataType::Float32:
                    default:
                        inputDataContainers.push_back(
                            PrepareImageTensor<float>(imagePath.string(),
                                                      inputTensorWidth, inputTensorHeight,
                                                      normParams,
                                                      batchSize,
                                                      inputTensorDataLayout));
                        outputDataContainers = { vector<float>(outputNumElements) };
                        break;
                }

                status = runtime->EnqueueWorkload(networkId,
                                                  armnnUtils::MakeInputTensors(inputBindings, inputDataContainers),
                                                  armnnUtils::MakeOutputTensors(outputBindings, outputDataContainers));

                if (status == armnn::Status::Failure)
                {
                    ARMNN_LOG(fatal) << "armnn::IRuntime: Failed to enqueue workload for image: " << imageName;
                    return 1;
                }
                checker.AddImageResult<TContainer>(imageName, outputDataContainers);
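                // The checker ranks the scores written to outputDataContainers and records whether the
                // ground-truth label for imageName appears among the top-k predictions (k = 1..5 below).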
            }
        }
        else
        {
            return 1;
        }

        for (unsigned int i = 1; i <= 5; ++i)
        {
            std::cout << "Top " << i << " Accuracy: " << checker.GetAccuracy(i) << "%" << "\n";
        }

        ARMNN_LOG(info) << "Accuracy Tool ran successfully!";
        return 0;
    }
    catch (const armnn::Exception& e)
    {
        // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
        // exception of type std::length_error.
        // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
        std::cerr << "Armnn Error: " << e.what() << std::endl;
        return 1;
    }
    catch (const std::exception& e)
    {
        // Coverity fix: various boost exceptions can be thrown by methods called by this test.
        std::cerr << "WARNING: ModelAccuracyTool-Armnn: An error has occurred when running the "
                     "Accuracy Tool: " << e.what() << std::endl;
        return 1;
    }
}

map<std::string, std::string> LoadValidationImageFilenamesAndLabels(const string& validationLabelPath,
                                                                    const string& imageDirectoryPath,
                                                                    size_t begIndex,
                                                                    size_t endIndex,
                                                                    const string& blacklistPath)
{
    // Populate imageFilenames with the names of all .JPEG and .PNG images in the directory
    std::vector<std::string> imageFilenames;
    for (const auto& imageEntry :
         boost::make_iterator_range(boost::filesystem::directory_iterator(boost::filesystem::path(imageDirectoryPath))))
    {
        boost::filesystem::path imagePath = imageEntry.path();
        std::string imageExtension = boost::to_upper_copy<std::string>(imagePath.extension().string());
        if (boost::filesystem::is_regular_file(imagePath) && (imageExtension == ".JPEG" || imageExtension == ".PNG"))
        {
            imageFilenames.push_back(imagePath.filename().string());
        }
    }
    if (imageFilenames.empty())
    {
        throw armnn::Exception("No image file (JPEG, PNG) found at " + imageDirectoryPath);
    }

    // Sort the image filenames lexicographically
    std::sort(imageFilenames.begin(), imageFilenames.end());

    std::cout << imageFilenames.size() << " images found at " << imageDirectoryPath << std::endl;

    // Validate the requested index range and apply the default end index
    if (begIndex < 1 || endIndex > imageFilenames.size())
    {
        throw armnn::Exception("Invalid image index range");
    }
    endIndex = endIndex == 0 ? imageFilenames.size() : endIndex;
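    // An endIndex of 0 (the default "1:0" validation range) means "run up to the last image".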
    if (begIndex > endIndex)
    {
        throw armnn::Exception("Invalid image index range");
    }

    // Load the blacklist, if there is one
    std::vector<unsigned int> blacklist;
    if (!blacklistPath.empty())
    {
        std::ifstream blacklistFile(blacklistPath);
        unsigned int index;
        while (blacklistFile >> index)
        {
            blacklist.push_back(index);
        }
    }
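
    // Note: the skip logic below assumes the blacklist indices are unique and listed in ascending order,
    // matching the order in which the images are visited.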

    // Load ground-truth labels and pair them with the corresponding image names
    std::string classification;
    map<std::string, std::string> imageNameToLabel;
    ifstream infile(validationLabelPath);
    size_t imageIndex = begIndex;
    size_t blacklistIndexCount = 0;
    while (std::getline(infile, classification))
    {
        if (imageIndex > endIndex)
        {
            break;
        }
        // If the current imageIndex is included in the blacklist, skip the current image
        if (blacklistIndexCount < blacklist.size() && imageIndex == blacklist[blacklistIndexCount])
        {
            ++imageIndex;
            ++blacklistIndexCount;
            continue;
        }
        imageNameToLabel.insert(std::pair<std::string, std::string>(imageFilenames[imageIndex - 1], classification));
        ++imageIndex;
    }
    std::cout << blacklistIndexCount << " images blacklisted" << std::endl;
    std::cout << imageIndex - begIndex - blacklistIndexCount << " images to be loaded" << std::endl;
    return imageNameToLabel;
}

std::vector<armnnUtils::LabelCategoryNames> LoadModelOutputLabels(const std::string& modelOutputLabelsPath)
{
    std::vector<armnnUtils::LabelCategoryNames> modelOutputLabels;
    ifstream modelOutputLabelsFile(modelOutputLabelsPath);
    std::string line;
    while (std::getline(modelOutputLabelsFile, line))
    {
        armnnUtils::LabelCategoryNames tokens = armnnUtils::SplitBy(line, ":");
        armnnUtils::LabelCategoryNames predictionCategoryNames = armnnUtils::SplitBy(tokens.back(), ",");
        std::transform(predictionCategoryNames.begin(), predictionCategoryNames.end(), predictionCategoryNames.begin(),
                       [](const std::string& category) { return armnnUtils::Strip(category); });
        modelOutputLabels.push_back(predictionCategoryNames);
    }
    return modelOutputLabels;
}