IVGCVSW-4661 Add include Assert to GatordMockService.cpp
[platform/upstream/armnn.git] / tests / DeepSpeechV1Database.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
7 #include "LstmCommon.hpp"
8
9 #include <memory>
10 #include <string>
11 #include <vector>
12
13 #include <armnn/TypesUtils.hpp>
14
15 #include <boost/numeric/conversion/cast.hpp>
16
17 #include <array>
18 #include <string>
19
20 #include "InferenceTestImage.hpp"
21
22 namespace
23 {
24
25 template<typename T, typename TParseElementFunc>
26 std::vector<T> ParseArrayImpl(std::istream& stream, TParseElementFunc parseElementFunc, const char * chars = "\t ,:")
27 {
28     std::vector<T> result;
29     // Processes line-by-line.
30     std::string line;
31     while (std::getline(stream, line))
32     {
33         std::vector<std::string> tokens;
34         try
35         {
36             // Coverity fix: boost::split() may throw an exception of type boost::bad_function_call.
37             boost::split(tokens, line, boost::algorithm::is_any_of(chars), boost::token_compress_on);
38         }
39         catch (const std::exception& e)
40         {
41             ARMNN_LOG(error) << "An error occurred when splitting tokens: " << e.what();
42             continue;
43         }
44         for (const std::string& token : tokens)
45         {
46             if (!token.empty()) // See https://stackoverflow.com/questions/10437406/
47             {
48                 try
49                 {
50                     result.push_back(parseElementFunc(token));
51                 }
52                 catch (const std::exception&)
53                 {
54                     ARMNN_LOG(error) << "'" << token << "' is not a valid number. It has been ignored.";
55                 }
56             }
57         }
58     }
59
60     return result;
61 }
62
63 template<armnn::DataType NonQuantizedType>
64 auto ParseDataArray(std::istream & stream);
65
66 template<armnn::DataType QuantizedType>
67 auto ParseDataArray(std::istream& stream,
68                     const float& quantizationScale,
69                     const int32_t& quantizationOffset);
70
71 // NOTE: declaring the template specialisations inline to prevent them
72 //       being flagged as unused functions when -Werror=unused-function is in effect
73 template<>
74 inline auto ParseDataArray<armnn::DataType::Float32>(std::istream & stream)
75 {
76     return ParseArrayImpl<float>(stream, [](const std::string& s) { return std::stof(s); });
77 }
78
79 template<>
80 inline auto ParseDataArray<armnn::DataType::Signed32>(std::istream & stream)
81 {
82     return ParseArrayImpl<int>(stream, [](const std::string & s) { return std::stoi(s); });
83 }
84
85 template<>
86 inline auto ParseDataArray<armnn::DataType::QAsymmU8>(std::istream& stream,
87                                                       const float& quantizationScale,
88                                                       const int32_t& quantizationOffset)
89 {
90     return ParseArrayImpl<uint8_t>(stream,
91                                    [&quantizationScale, &quantizationOffset](const std::string & s)
92                                    {
93                                        return boost::numeric_cast<uint8_t>(
94                                                armnn::Quantize<u_int8_t>(std::stof(s),
95                                                                          quantizationScale,
96                                                                          quantizationOffset));
97                                    });
98 }
99
100 struct DeepSpeechV1TestCaseData
101 {
102     DeepSpeechV1TestCaseData(
103         const LstmInput& inputData,
104         const LstmInput& expectedOutputData)
105         : m_InputData(inputData)
106         , m_ExpectedOutputData(expectedOutputData)
107     {}
108
109     LstmInput m_InputData;
110     LstmInput m_ExpectedOutputData;
111 };
112
113 class DeepSpeechV1Database
114 {
115 public:
116     explicit DeepSpeechV1Database(const std::string& inputSeqDir, const std::string& prevStateHDir,
117                                   const std::string& prevStateCDir, const std::string& logitsDir,
118                                   const std::string& newStateHDir, const std::string& newStateCDir);
119
120     std::unique_ptr<DeepSpeechV1TestCaseData> GetTestCaseData(unsigned int testCaseId);
121
122 private:
123     std::string m_InputSeqDir;
124     std::string m_PrevStateHDir;
125     std::string m_PrevStateCDir;
126     std::string m_LogitsDir;
127     std::string m_NewStateHDir;
128     std::string m_NewStateCDir;
129 };
130
131 DeepSpeechV1Database::DeepSpeechV1Database(const std::string& inputSeqDir, const std::string& prevStateHDir,
132                                            const std::string& prevStateCDir, const std::string& logitsDir,
133                                            const std::string& newStateHDir, const std::string& newStateCDir)
134     : m_InputSeqDir(inputSeqDir)
135     , m_PrevStateHDir(prevStateHDir)
136     , m_PrevStateCDir(prevStateCDir)
137     , m_LogitsDir(logitsDir)
138     , m_NewStateHDir(newStateHDir)
139     , m_NewStateCDir(newStateCDir)
140 {}
141
142 std::unique_ptr<DeepSpeechV1TestCaseData> DeepSpeechV1Database::GetTestCaseData(unsigned int testCaseId)
143 {
144     // Load test case input
145     const std::string inputSeqPath   = m_InputSeqDir + "input_node_0_flat.txt";
146     const std::string prevStateCPath = m_PrevStateCDir + "previous_state_c_0.txt";
147     const std::string prevStateHPath = m_PrevStateHDir + "previous_state_h_0.txt";
148
149     std::vector<float> inputSeqData;
150     std::vector<float> prevStateCData;
151     std::vector<float> prevStateHData;
152
153     std::ifstream inputSeqFile(inputSeqPath);
154     std::ifstream prevStateCTensorFile(prevStateCPath);
155     std::ifstream prevStateHTensorFile(prevStateHPath);
156
157     try
158     {
159         inputSeqData   = ParseDataArray<armnn::DataType::Float32>(inputSeqFile);
160         prevStateCData = ParseDataArray<armnn::DataType::Float32>(prevStateCTensorFile);
161         prevStateHData = ParseDataArray<armnn::DataType::Float32>(prevStateHTensorFile);
162     }
163     catch (const InferenceTestImageException& e)
164     {
165         ARMNN_LOG(fatal) << "Failed to load image for test case " << testCaseId << ". Error: " << e.what();
166         return nullptr;
167     }
168
169     // Prepare test case expected output
170     const std::string logitsPath   = m_LogitsDir + "logits.txt";
171     const std::string newStateCPath = m_NewStateCDir + "new_state_c.txt";
172     const std::string newStateHPath = m_NewStateHDir + "new_state_h.txt";
173
174     std::vector<float> logitsData;
175     std::vector<float> expectedNewStateCData;
176     std::vector<float> expectedNewStateHData;
177
178     std::ifstream logitsTensorFile(logitsPath);
179     std::ifstream newStateCTensorFile(newStateCPath);
180     std::ifstream newStateHTensorFile(newStateHPath);
181
182     try
183     {
184         logitsData     = ParseDataArray<armnn::DataType::Float32>(logitsTensorFile);
185         expectedNewStateCData = ParseDataArray<armnn::DataType::Float32>(newStateCTensorFile);
186         expectedNewStateHData = ParseDataArray<armnn::DataType::Float32>(newStateHTensorFile);
187     }
188     catch (const InferenceTestImageException& e)
189     {
190         ARMNN_LOG(fatal) << "Failed to load image for test case " << testCaseId << ". Error: " << e.what();
191         return nullptr;
192     }
193
194     // use the struct for representing input and output data
195     LstmInput inputDataSingleTest(inputSeqData, prevStateHData, prevStateCData);
196
197     LstmInput expectedOutputsSingleTest(logitsData, expectedNewStateHData, expectedNewStateCData);
198
199     return std::make_unique<DeepSpeechV1TestCaseData>(inputDataSingleTest, expectedOutputsSingleTest);
200 }
201
202 } // anonymous namespace