IVGCVSW-2834 Fix Input TensorInfo in Quantization DataSet
[platform/upstream/armnn.git] / src / armnnQuantizer / QuantizationDataSet.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
#include "QuantizationDataSet.hpp"
#include "CsvReader.hpp"

#define BOOST_FILESYSTEM_NO_DEPRECATED

#include <boost/filesystem/operations.hpp>
#include <boost/filesystem/path.hpp>

#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
13
14 namespace armnnQuantizer
15 {
16
17 QuantizationDataSet::QuantizationDataSet()
18 {
19 }
20
21 QuantizationDataSet::QuantizationDataSet(const std::string csvFilePath):
22     m_QuantizationInputs(),
23     m_CsvFilePath(csvFilePath)
24 {
25     ParseCsvFile();
26 }
27
28 void AddInputData(unsigned int passId,
29                   armnn::LayerBindingId bindingId,
30                   const std::string& inputFilePath,
31                   std::map<unsigned int, QuantizationInput>& passIdToQuantizationInput)
32 {
33     auto iterator = passIdToQuantizationInput.find(passId);
34     if (iterator == passIdToQuantizationInput.end())
35     {
36         QuantizationInput input(passId, bindingId, inputFilePath);
37         passIdToQuantizationInput.emplace(passId, input);
38     }
39     else
40     {
41         auto existingQuantizationInput = iterator->second;
42         existingQuantizationInput.AddEntry(bindingId, inputFilePath);
43     }
44 }
45
46 QuantizationDataSet::~QuantizationDataSet()
47 {
48 }
49
50 void InputLayerVisitor::VisitInputLayer(const armnn::IConnectableLayer* layer,
51                                         armnn::LayerBindingId id,
52                                         const char* name)
53 {
54     m_TensorInfos.emplace(id, layer->GetOutputSlot(0).GetTensorInfo());
55 }
56
57 armnn::TensorInfo InputLayerVisitor::GetTensorInfo(armnn::LayerBindingId layerBindingId)
58 {
59     auto iterator = m_TensorInfos.find(layerBindingId);
60     if (iterator != m_TensorInfos.end())
61     {
62         return m_TensorInfos.at(layerBindingId);
63     }
64     else
65     {
66         throw armnn::Exception("Could not retrieve tensor info for binding ID " + std::to_string(layerBindingId));
67     }
68 }
69
70
71 unsigned int GetPassIdFromCsvRow(std::vector<armnnUtils::CsvRow> csvRows, unsigned int rowIndex)
72 {
73     unsigned int passId;
74     try
75     {
76         passId = static_cast<unsigned int>(std::stoi(csvRows[rowIndex].values[0]));
77     }
78     catch (std::invalid_argument)
79     {
80         throw armnn::ParseException("Pass ID [" + csvRows[rowIndex].values[0] + "]" +
81                                     " is not correct format on CSV row " + std::to_string(rowIndex));
82     }
83     return passId;
84 }
85
86 armnn::LayerBindingId GetBindingIdFromCsvRow(std::vector<armnnUtils::CsvRow> csvRows, unsigned int rowIndex)
87 {
88     armnn::LayerBindingId bindingId;
89     try
90     {
91         bindingId = std::stoi(csvRows[rowIndex].values[1]);
92     }
93     catch (std::invalid_argument)
94     {
95         throw armnn::ParseException("Binding ID [" + csvRows[rowIndex].values[0] + "]" +
96                                     " is not correct format on CSV row " + std::to_string(rowIndex));
97     }
98     return bindingId;
99 }
100
101 std::string GetFileNameFromCsvRow(std::vector<armnnUtils::CsvRow> csvRows, unsigned int rowIndex)
102 {
103     std::string fileName = csvRows[rowIndex].values[2];
104
105     if (!boost::filesystem::exists(fileName))
106     {
107         throw armnn::ParseException("File [ " + fileName + "] provided on CSV row " + std::to_string(rowIndex) +
108                                     " does not exist.");
109     }
110
111     if (fileName.empty())
112     {
113         throw armnn::ParseException("Filename cannot be empty on CSV row " + std::to_string(rowIndex));
114     }
115     return fileName;
116 }
117
118
119 void QuantizationDataSet::ParseCsvFile()
120 {
121     std::map<unsigned int, QuantizationInput> passIdToQuantizationInput;
122     armnnUtils::CsvReader reader;
123
124     if (m_CsvFilePath == "")
125     {
126         throw armnn::Exception("CSV file not specified.");
127     }
128
129     // Parse CSV file and extract data
130     std::vector<armnnUtils::CsvRow> csvRows = reader.ParseFile(m_CsvFilePath);
131     if (csvRows.empty())
132     {
133         throw armnn::Exception("CSV file [" + m_CsvFilePath + "] is empty.");
134     }
135
136     for (unsigned int i = 0; i < csvRows.size(); ++i)
137     {
138         if (csvRows[i].values.size() != 3)
139         {
140             throw armnn::Exception("CSV file [" + m_CsvFilePath + "] does not have correct number of entries " +
141                                    "on line " + std::to_string(i) + ". Expected 3 entries " +
142                                    "but was " + std::to_string(csvRows[i].values.size()));
143         }
144
145         unsigned int passId = GetPassIdFromCsvRow(csvRows, i);
146         armnn::LayerBindingId bindingId = GetBindingIdFromCsvRow(csvRows, i);
147         std::string rawFileName = GetFileNameFromCsvRow(csvRows, i);
148
149         AddInputData(passId, bindingId, rawFileName, passIdToQuantizationInput);
150     }
151
152     if (passIdToQuantizationInput.empty())
153     {
154         throw armnn::Exception("Could not parse CSV file.");
155     }
156
157     // Once all entries in CSV file are parsed successfully and QuantizationInput map is populated, populate
158     // QuantizationInputs iterator for easier access and clear the map
159     for (auto itr = passIdToQuantizationInput.begin(); itr != passIdToQuantizationInput.end(); ++itr)
160     {
161         m_QuantizationInputs.emplace_back(itr->second);
162     }
163 }
164
165 }