1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
9 #include "ClassificationProcessor.hpp"
10 #include "Processor.hpp"
12 using InferenceEngine::details::InferenceEngineException;
14 ClassificationProcessor::ClassificationProcessor(const std::string& flags_m, const std::string& flags_d, const std::string& flags_i, int flags_b,
15 InferencePlugin plugin, CsvDumper& dumper, const std::string& flags_l,
16 PreprocessingOptions preprocessingOptions, bool zeroBackground)
17 : Processor(flags_m, flags_d, flags_i, flags_b, plugin, dumper, "Classification network", preprocessingOptions), zeroBackground(zeroBackground) {
19 // Change path to labels file if necessary
20 if (flags_l.empty()) {
21 labelFileName = fileNameNoExt(modelFileName) + ".labels";
23 labelFileName = flags_l;
27 ClassificationProcessor::ClassificationProcessor(const std::string& flags_m, const std::string& flags_d, const std::string& flags_i, int flags_b,
28 InferencePlugin plugin, CsvDumper& dumper, const std::string& flags_l, bool zeroBackground)
29 : ClassificationProcessor(flags_m, flags_d, flags_i, flags_b, plugin, dumper, flags_l,
30 PreprocessingOptions(false, ResizeCropPolicy::ResizeThenCrop, 256, 256), zeroBackground) {
// Runs the classification accuracy check over the validation set and returns the
// accumulated metrics.
//
// Steps visible below:
//  1. Reads class labels from labelFileName. An InferenceEngineException is logged as
//     a warning and processing continues (no rethrow is visible).
//  2. Builds the validation entries for imagesPath with
//     ClassificationSetGenerator::getValidationMap. Each entry's `first` is used as the
//     expected class id and its `second` as the image path.
//  3. Loads up to `batch` images at a time into the first network input, calls Infer(),
//     and collects TOP_COUNT class ids per image with InferenceEngine::TopResults.
//  4. For every scored image, writes to `dumper`: the quoted file name, whether the
//     top-1 class matched, then TOP_COUNT (class id, FP32 score) pairs.
//
// @param stream_output  forwarded to ConsoleProgress.
// @return a new ClassificationInferenceMetrics copied from the local accumulator `im`.
//
// NOTE(review): this listing is missing lines of the body. Not visible:
//   - the `try {` openers of both catch clauses;
//   - the declarations of `b`, `filesWatched` and `decoder`;
//   - the rest of the decode-failure handler;
//   - the body of the `classId == expc` branch;
//   - the closing lines of the loops and of the function.
// Comments below describe only the visible statements.
33 std::shared_ptr<Processor::InferenceMetrics> ClassificationProcessor::Process(bool stream_output) {
34 slog::info << "Collecting labels" << slog::endl;
35 ClassificationSetGenerator generator;
37 generator.readLabels(labelFileName);
// NOTE(review): caught by non-const reference here but by const reference further
// down; `const&` is the conventional form.
38 } catch (InferenceEngine::details::InferenceEngineException& ex) {
39 slog::warn << "Can't read labels file " << labelFileName << slog::endl;
40 slog::warn << "Error: " << ex.what() << slog::endl;
43 auto validationMap = generator.getValidationMap(imagesPath);
46 // ----------------------------Do inference-------------------------------------------------------------
47 slog::info << "Starting inference" << slog::endl;
// Per-slot expected class id and file name for the batch currently being filled.
49 std::vector<int> expected(batch);
50 std::vector<std::string> files(batch);
52 ConsoleProgress progress(validationMap.size(), stream_output);
54 ClassificationInferenceMetrics im;
// Only the first input and the first output are used.
// Presumably the network has exactly one of each — TODO confirm.
56 std::string firstInputName = this->inputInfo.begin()->first;
57 std::string firstOutputName = this->outInfo.begin()->first;
58 auto firstInputBlob = inferRequest.GetBlob(firstInputName);
59 auto firstOutputBlob = inferRequest.GetBlob(firstOutputName);
// Outer loop: one iteration per batch until every validation entry has been consumed.
61 auto iter = validationMap.begin();
62 while (iter != validationMap.end()) {
// Fill at most `batch` slots. iter and filesWatched advance in the loop header.
65 for (; b < batch && iter != validationMap.end(); b++, iter++, filesWatched++) {
// The expected class id is recorded before decoding is attempted.
66 expected[b] = iter->first;
// Decode the image into slot b of the input blob using the configured preprocessing.
68 decoder.insertIntoBlob(iter->second, b, *firstInputBlob, preprocessingOptions);
69 files[b] = iter->second;
70 } catch (const InferenceEngineException& iex) {
71 slog::warn << "Can't read file " << iter->second << slog::endl;
72 slog::warn << "Error: " << iex.what() << slog::endl;
73 // Could be some non-image file in directory
// Run inference for this batch. Infer() is defined outside this listing.
79 Infer(progress, filesWatched, im);
81 std::vector<unsigned> results;
// NOTE(review): the output buffer is read as FP32 without checking the blob's
// precision — confirm the output is always FP32.
82 auto firstOutputData = firstOutputBlob->buffer().as<PrecisionTrait<Precision::FP32>::value_type*>();
// `results` holds TOP_COUNT class ids per batch slot.
// Slot i occupies indices [TOP_COUNT * i, TOP_COUNT * (i + 1)).
83 InferenceEngine::TopResults(TOP_COUNT, *firstOutputBlob, results);
// Only slots [0, b) are scored, so unfilled slots of a short final batch are ignored.
85 for (size_t i = 0; i < b; i++) {
86 int expc = expected[i];
// With zeroBackground, network class ids are offset by one from the dataset labels
// (presumably output index 0 is a background class — TODO confirm).
87 if (zeroBackground) expc++;
89 bool top1Scored = (static_cast<int>(results[0 + TOP_COUNT * i]) == expc);
90 dumper << "\"" + files[i] + "\"" << top1Scored;
91 if (top1Scored) im.top1Result++;
92 for (int j = 0; j < TOP_COUNT; j++) {
93 unsigned classId = results[j + TOP_COUNT * i];
// Hit within the top TOP_COUNT. This branch's body is not visible here; presumably it
// increments im.topCountResult, which Report() prints — TODO confirm.
94 if (static_cast<int>(classId) == expc) {
// The score lookup assumes each batch item occupies firstOutputBlob->size() / batch
// consecutive floats.
// NOTE(review): that quotient is loop-invariant and could be hoisted out of both loops.
97 dumper << classId << firstOutputData[classId + i * (firstOutputBlob->size() / batch)];
// NOTE(review): std::make_shared<ClassificationInferenceMetrics>(im) would avoid the
// separate control-block allocation.
105 return std::shared_ptr<Processor::InferenceMetrics>(new ClassificationInferenceMetrics(im));
// Prints the base Processor report, then the Top-1 and Top-TOP_COUNT accuracy of a
// classification run.
//
// @param im  metrics produced by Process(). The dynamic type must be
//            ClassificationInferenceMetrics; dynamic_cast to a reference type throws
//            std::bad_cast otherwise.
//
// NOTE(review): this listing does not show one line between the Processor::Report call
// and the cast, nor the closing braces of this function.
108 void ClassificationProcessor::Report(const Processor::InferenceMetrics& im) {
109 Processor::Report(im);
111 const ClassificationInferenceMetrics& cim = dynamic_cast<const ClassificationInferenceMetrics&>(im);
// NOTE(review): both percentages divide by cim.total as floating point. If no image was
// scored, this prints inf/NaN unless the hidden line above guards against it.
// NOTE(review): the second line is labelled "Top5" / "top five" regardless of the
// TOP_COUNT value Process() actually uses.
113 cout << "Top1 accuracy: " << OUTPUT_FLOATING(100.0 * cim.top1Result / cim.total) << "% (" << cim.top1Result << " of "
114 << cim.total << " images were detected correctly, top class is correct)" << "\n";
115 cout << "Top5 accuracy: " << OUTPUT_FLOATING(100.0 * cim.topCountResult / cim.total) << "% (" << cim.topCountResult << " of "
116 << cim.total << " images were detected correctly, top five classes contain required class)" << "\n";