NNXSW-1853 Change SubgraphViewSelector algorithm
[platform/upstream/armnn.git] / tests / DeepSpeechV1Database.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
7 #include "LstmCommon.hpp"
8
9 #include <memory>
10 #include <string>
11 #include <vector>
12
13 #include <armnn/TypesUtils.hpp>
14
15 #include <boost/log/trivial.hpp>
16 #include <boost/numeric/conversion/cast.hpp>
17
18 #include <array>
19 #include <string>
20
21 #include "InferenceTestImage.hpp"
22
23 namespace
24 {
25
26 template<typename T, typename TParseElementFunc>
27 std::vector<T> ParseArrayImpl(std::istream& stream, TParseElementFunc parseElementFunc, const char * chars = "\t ,:")
28 {
29     std::vector<T> result;
30     // Processes line-by-line.
31     std::string line;
32     while (std::getline(stream, line))
33     {
34         std::vector<std::string> tokens;
35         try
36         {
37             // Coverity fix: boost::split() may throw an exception of type boost::bad_function_call.
38             boost::split(tokens, line, boost::algorithm::is_any_of(chars), boost::token_compress_on);
39         }
40         catch (const std::exception& e)
41         {
42             BOOST_LOG_TRIVIAL(error) << "An error occurred when splitting tokens: " << e.what();
43             continue;
44         }
45         for (const std::string& token : tokens)
46         {
47             if (!token.empty()) // See https://stackoverflow.com/questions/10437406/
48             {
49                 try
50                 {
51                     result.push_back(parseElementFunc(token));
52                 }
53                 catch (const std::exception&)
54                 {
55                     BOOST_LOG_TRIVIAL(error) << "'" << token << "' is not a valid number. It has been ignored.";
56                 }
57             }
58         }
59     }
60
61     return result;
62 }
63
64 template<armnn::DataType NonQuantizedType>
65 auto ParseDataArray(std::istream & stream);
66
67 template<armnn::DataType QuantizedType>
68 auto ParseDataArray(std::istream& stream,
69                     const float& quantizationScale,
70                     const int32_t& quantizationOffset);
71
72 // NOTE: declaring the template specialisations inline to prevent them
73 //       being flagged as unused functions when -Werror=unused-function is in effect
74 template<>
75 inline auto ParseDataArray<armnn::DataType::Float32>(std::istream & stream)
76 {
77     return ParseArrayImpl<float>(stream, [](const std::string& s) { return std::stof(s); });
78 }
79
80 template<>
81 inline auto ParseDataArray<armnn::DataType::Signed32>(std::istream & stream)
82 {
83     return ParseArrayImpl<int>(stream, [](const std::string & s) { return std::stoi(s); });
84 }
85
86 template<>
87 inline auto ParseDataArray<armnn::DataType::QuantisedAsymm8>(std::istream& stream,
88                                                       const float& quantizationScale,
89                                                       const int32_t& quantizationOffset)
90 {
91     return ParseArrayImpl<uint8_t>(stream,
92                                    [&quantizationScale, &quantizationOffset](const std::string & s)
93                                    {
94                                        return boost::numeric_cast<uint8_t>(
95                                                armnn::Quantize<u_int8_t>(std::stof(s),
96                                                                          quantizationScale,
97                                                                          quantizationOffset));
98                                    });
99 }
100
101 struct DeepSpeechV1TestCaseData
102 {
103     DeepSpeechV1TestCaseData(
104         const LstmInput& inputData,
105         const LstmInput& expectedOutputData)
106         : m_InputData(inputData)
107         , m_ExpectedOutputData(expectedOutputData)
108     {}
109
110     LstmInput m_InputData;
111     LstmInput m_ExpectedOutputData;
112 };
113
114 class DeepSpeechV1Database
115 {
116 public:
117     explicit DeepSpeechV1Database(const std::string& inputSeqDir, const std::string& prevStateHDir,
118                                   const std::string& prevStateCDir, const std::string& logitsDir,
119                                   const std::string& newStateHDir, const std::string& newStateCDir);
120
121     std::unique_ptr<DeepSpeechV1TestCaseData> GetTestCaseData(unsigned int testCaseId);
122
123 private:
124     std::string m_InputSeqDir;
125     std::string m_PrevStateHDir;
126     std::string m_PrevStateCDir;
127     std::string m_LogitsDir;
128     std::string m_NewStateHDir;
129     std::string m_NewStateCDir;
130 };
131
132 DeepSpeechV1Database::DeepSpeechV1Database(const std::string& inputSeqDir, const std::string& prevStateHDir,
133                                            const std::string& prevStateCDir, const std::string& logitsDir,
134                                            const std::string& newStateHDir, const std::string& newStateCDir)
135     : m_InputSeqDir(inputSeqDir)
136     , m_PrevStateHDir(prevStateHDir)
137     , m_PrevStateCDir(prevStateCDir)
138     , m_LogitsDir(logitsDir)
139     , m_NewStateHDir(newStateHDir)
140     , m_NewStateCDir(newStateCDir)
141 {}
142
143 std::unique_ptr<DeepSpeechV1TestCaseData> DeepSpeechV1Database::GetTestCaseData(unsigned int testCaseId)
144 {
145     // Load test case input
146     const std::string inputSeqPath   = m_InputSeqDir + "input_node_0_flat.txt";
147     const std::string prevStateCPath = m_PrevStateCDir + "previous_state_c_0.txt";
148     const std::string prevStateHPath = m_PrevStateHDir + "previous_state_h_0.txt";
149
150     std::vector<float> inputSeqData;
151     std::vector<float> prevStateCData;
152     std::vector<float> prevStateHData;
153
154     std::ifstream inputSeqFile(inputSeqPath);
155     std::ifstream prevStateCTensorFile(prevStateCPath);
156     std::ifstream prevStateHTensorFile(prevStateHPath);
157
158     try
159     {
160         inputSeqData   = ParseDataArray<armnn::DataType::Float32>(inputSeqFile);
161         prevStateCData = ParseDataArray<armnn::DataType::Float32>(prevStateCTensorFile);
162         prevStateHData = ParseDataArray<armnn::DataType::Float32>(prevStateHTensorFile);
163     }
164     catch (const InferenceTestImageException& e)
165     {
166         BOOST_LOG_TRIVIAL(fatal) << "Failed to load image for test case " << testCaseId << ". Error: " << e.what();
167         return nullptr;
168     }
169
170     // Prepare test case expected output
171     const std::string logitsPath   = m_LogitsDir + "logits.txt";
172     const std::string newStateCPath = m_NewStateCDir + "new_state_c.txt";
173     const std::string newStateHPath = m_NewStateHDir + "new_state_h.txt";
174
175     std::vector<float> logitsData;
176     std::vector<float> expectedNewStateCData;
177     std::vector<float> expectedNewStateHData;
178
179     std::ifstream logitsTensorFile(logitsPath);
180     std::ifstream newStateCTensorFile(newStateCPath);
181     std::ifstream newStateHTensorFile(newStateHPath);
182
183     try
184     {
185         logitsData     = ParseDataArray<armnn::DataType::Float32>(logitsTensorFile);
186         expectedNewStateCData = ParseDataArray<armnn::DataType::Float32>(newStateCTensorFile);
187         expectedNewStateHData = ParseDataArray<armnn::DataType::Float32>(newStateHTensorFile);
188     }
189     catch (const InferenceTestImageException& e)
190     {
191         BOOST_LOG_TRIVIAL(fatal) << "Failed to load image for test case " << testCaseId << ". Error: " << e.what();
192         return nullptr;
193     }
194
195     // use the struct for representing input and output data
196     LstmInput inputDataSingleTest(inputSeqData, prevStateHData, prevStateCData);
197
198     LstmInput expectedOutputsSingleTest(logitsData, expectedNewStateHData, expectedNewStateCData);
199
200     return std::make_unique<DeepSpeechV1TestCaseData>(inputDataSingleTest, expectedOutputsSingleTest);
201 }
202
203 } // anonymous namespace