MLCE-103 Remove hardcoded output shape in ModelAccuracyTool
[platform/upstream/armnn.git] / tests / DeepSpeechV1Database.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
7 #include "LstmCommon.hpp"
8
9 #include <memory>
10 #include <string>
11 #include <vector>
12
13 #include <armnn/TypesUtils.hpp>
14 #include <backendsCommon/test/QuantizeHelper.hpp>
15
16 #include <boost/log/trivial.hpp>
17 #include <boost/numeric/conversion/cast.hpp>
18
19 #include <array>
20 #include <string>
21
22 #include "InferenceTestImage.hpp"
23
24 namespace
25 {
26
27 template<typename T, typename TParseElementFunc>
28 std::vector<T> ParseArrayImpl(std::istream& stream, TParseElementFunc parseElementFunc, const char * chars = "\t ,:")
29 {
30     std::vector<T> result;
31     // Processes line-by-line.
32     std::string line;
33     while (std::getline(stream, line))
34     {
35         std::vector<std::string> tokens;
36         try
37         {
38             // Coverity fix: boost::split() may throw an exception of type boost::bad_function_call.
39             boost::split(tokens, line, boost::algorithm::is_any_of(chars), boost::token_compress_on);
40         }
41         catch (const std::exception& e)
42         {
43             BOOST_LOG_TRIVIAL(error) << "An error occurred when splitting tokens: " << e.what();
44             continue;
45         }
46         for (const std::string& token : tokens)
47         {
48             if (!token.empty()) // See https://stackoverflow.com/questions/10437406/
49             {
50                 try
51                 {
52                     result.push_back(parseElementFunc(token));
53                 }
54                 catch (const std::exception&)
55                 {
56                     BOOST_LOG_TRIVIAL(error) << "'" << token << "' is not a valid number. It has been ignored.";
57                 }
58             }
59         }
60     }
61
62     return result;
63 }
64
65 template<armnn::DataType NonQuantizedType>
66 auto ParseDataArray(std::istream & stream);
67
68 template<armnn::DataType QuantizedType>
69 auto ParseDataArray(std::istream& stream,
70                     const float& quantizationScale,
71                     const int32_t& quantizationOffset);
72
73 // NOTE: declaring the template specialisations inline to prevent them
74 //       being flagged as unused functions when -Werror=unused-function is in effect
75 template<>
76 inline auto ParseDataArray<armnn::DataType::Float32>(std::istream & stream)
77 {
78     return ParseArrayImpl<float>(stream, [](const std::string& s) { return std::stof(s); });
79 }
80
81 template<>
82 inline auto ParseDataArray<armnn::DataType::Signed32>(std::istream & stream)
83 {
84     return ParseArrayImpl<int>(stream, [](const std::string & s) { return std::stoi(s); });
85 }
86
87 template<>
88 inline auto ParseDataArray<armnn::DataType::QuantisedAsymm8>(std::istream& stream,
89                                                       const float& quantizationScale,
90                                                       const int32_t& quantizationOffset)
91 {
92     return ParseArrayImpl<uint8_t>(stream,
93                                    [&quantizationScale, &quantizationOffset](const std::string & s)
94                                    {
95                                        return boost::numeric_cast<uint8_t>(
96                                                armnn::Quantize<u_int8_t>(std::stof(s),
97                                                                          quantizationScale,
98                                                                          quantizationOffset));
99                                    });
100 }
101
102 struct DeepSpeechV1TestCaseData
103 {
104     DeepSpeechV1TestCaseData(
105         const LstmInput& inputData,
106         const LstmInput& expectedOutputData)
107         : m_InputData(inputData)
108         , m_ExpectedOutputData(expectedOutputData)
109     {}
110
111     LstmInput m_InputData;
112     LstmInput m_ExpectedOutputData;
113 };
114
115 class DeepSpeechV1Database
116 {
117 public:
118     explicit DeepSpeechV1Database(const std::string& inputSeqDir, const std::string& prevStateHDir,
119                                   const std::string& prevStateCDir, const std::string& logitsDir,
120                                   const std::string& newStateHDir, const std::string& newStateCDir);
121
122     std::unique_ptr<DeepSpeechV1TestCaseData> GetTestCaseData(unsigned int testCaseId);
123
124 private:
125     std::string m_InputSeqDir;
126     std::string m_PrevStateHDir;
127     std::string m_PrevStateCDir;
128     std::string m_LogitsDir;
129     std::string m_NewStateHDir;
130     std::string m_NewStateCDir;
131 };
132
133 DeepSpeechV1Database::DeepSpeechV1Database(const std::string& inputSeqDir, const std::string& prevStateHDir,
134                                            const std::string& prevStateCDir, const std::string& logitsDir,
135                                            const std::string& newStateHDir, const std::string& newStateCDir)
136     : m_InputSeqDir(inputSeqDir)
137     , m_PrevStateHDir(prevStateHDir)
138     , m_PrevStateCDir(prevStateCDir)
139     , m_LogitsDir(logitsDir)
140     , m_NewStateHDir(newStateHDir)
141     , m_NewStateCDir(newStateCDir)
142 {}
143
144 std::unique_ptr<DeepSpeechV1TestCaseData> DeepSpeechV1Database::GetTestCaseData(unsigned int testCaseId)
145 {
146     // Load test case input
147     const std::string inputSeqPath   = m_InputSeqDir + "input_node_0_flat.txt";
148     const std::string prevStateCPath = m_PrevStateCDir + "previous_state_c_0.txt";
149     const std::string prevStateHPath = m_PrevStateHDir + "previous_state_h_0.txt";
150
151     std::vector<float> inputSeqData;
152     std::vector<float> prevStateCData;
153     std::vector<float> prevStateHData;
154
155     std::ifstream inputSeqFile(inputSeqPath);
156     std::ifstream prevStateCTensorFile(prevStateCPath);
157     std::ifstream prevStateHTensorFile(prevStateHPath);
158
159     try
160     {
161         inputSeqData   = ParseDataArray<armnn::DataType::Float32>(inputSeqFile);
162         prevStateCData = ParseDataArray<armnn::DataType::Float32>(prevStateCTensorFile);
163         prevStateHData = ParseDataArray<armnn::DataType::Float32>(prevStateHTensorFile);
164     }
165     catch (const InferenceTestImageException& e)
166     {
167         BOOST_LOG_TRIVIAL(fatal) << "Failed to load image for test case " << testCaseId << ". Error: " << e.what();
168         return nullptr;
169     }
170
171     // Prepare test case expected output
172     const std::string logitsPath   = m_LogitsDir + "logits.txt";
173     const std::string newStateCPath = m_NewStateCDir + "new_state_c.txt";
174     const std::string newStateHPath = m_NewStateHDir + "new_state_h.txt";
175
176     std::vector<float> logitsData;
177     std::vector<float> expectedNewStateCData;
178     std::vector<float> expectedNewStateHData;
179
180     std::ifstream logitsTensorFile(logitsPath);
181     std::ifstream newStateCTensorFile(newStateCPath);
182     std::ifstream newStateHTensorFile(newStateHPath);
183
184     try
185     {
186         logitsData     = ParseDataArray<armnn::DataType::Float32>(logitsTensorFile);
187         expectedNewStateCData = ParseDataArray<armnn::DataType::Float32>(newStateCTensorFile);
188         expectedNewStateHData = ParseDataArray<armnn::DataType::Float32>(newStateHTensorFile);
189     }
190     catch (const InferenceTestImageException& e)
191     {
192         BOOST_LOG_TRIVIAL(fatal) << "Failed to load image for test case " << testCaseId << ". Error: " << e.what();
193         return nullptr;
194     }
195
196     // use the struct for representing input and output data
197     LstmInput inputDataSingleTest(inputSeqData, prevStateHData, prevStateCData);
198
199     LstmInput expectedOutputsSingleTest(logitsData, expectedNewStateHData, expectedNewStateCData);
200
201     return std::make_unique<DeepSpeechV1TestCaseData>(inputDataSingleTest, expectedOutputsSingleTest);
202 }
203
204 } // anonymous namespace
205