2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
7 #include "LstmCommon.hpp"
13 #include <armnn/TypesUtils.hpp>
15 #include <boost/log/trivial.hpp>
16 #include <boost/numeric/conversion/cast.hpp>
21 #include "InferenceTestImage.hpp"
26 template<typename T, typename TParseElementFunc>
27 std::vector<T> ParseArrayImpl(std::istream& stream, TParseElementFunc parseElementFunc, const char * chars = "\t ,:")
29 std::vector<T> result;
30 // Processes line-by-line.
32 while (std::getline(stream, line))
34 std::vector<std::string> tokens;
37 // Coverity fix: boost::split() may throw an exception of type boost::bad_function_call.
38 boost::split(tokens, line, boost::algorithm::is_any_of(chars), boost::token_compress_on);
40 catch (const std::exception& e)
42 BOOST_LOG_TRIVIAL(error) << "An error occurred when splitting tokens: " << e.what();
45 for (const std::string& token : tokens)
47 if (!token.empty()) // See https://stackoverflow.com/questions/10437406/
51 result.push_back(parseElementFunc(token));
53 catch (const std::exception&)
55 BOOST_LOG_TRIVIAL(error) << "'" << token << "' is not a valid number. It has been ignored.";
64 template<armnn::DataType NonQuantizedType>
65 auto ParseDataArray(std::istream & stream);
67 template<armnn::DataType QuantizedType>
68 auto ParseDataArray(std::istream& stream,
69 const float& quantizationScale,
70 const int32_t& quantizationOffset);
72 // NOTE: declaring the template specialisations inline to prevent them
73 // being flagged as unused functions when -Werror=unused-function is in effect
75 inline auto ParseDataArray<armnn::DataType::Float32>(std::istream & stream)
77 return ParseArrayImpl<float>(stream, [](const std::string& s) { return std::stof(s); });
81 inline auto ParseDataArray<armnn::DataType::Signed32>(std::istream & stream)
83 return ParseArrayImpl<int>(stream, [](const std::string & s) { return std::stoi(s); });
87 inline auto ParseDataArray<armnn::DataType::QuantisedAsymm8>(std::istream& stream,
88 const float& quantizationScale,
89 const int32_t& quantizationOffset)
91 return ParseArrayImpl<uint8_t>(stream,
92 [&quantizationScale, &quantizationOffset](const std::string & s)
94 return boost::numeric_cast<uint8_t>(
95 armnn::Quantize<u_int8_t>(std::stof(s),
101 struct DeepSpeechV1TestCaseData
103 DeepSpeechV1TestCaseData(
104 const LstmInput& inputData,
105 const LstmInput& expectedOutputData)
106 : m_InputData(inputData)
107 , m_ExpectedOutputData(expectedOutputData)
110 LstmInput m_InputData;
111 LstmInput m_ExpectedOutputData;
114 class DeepSpeechV1Database
117 explicit DeepSpeechV1Database(const std::string& inputSeqDir, const std::string& prevStateHDir,
118 const std::string& prevStateCDir, const std::string& logitsDir,
119 const std::string& newStateHDir, const std::string& newStateCDir);
121 std::unique_ptr<DeepSpeechV1TestCaseData> GetTestCaseData(unsigned int testCaseId);
124 std::string m_InputSeqDir;
125 std::string m_PrevStateHDir;
126 std::string m_PrevStateCDir;
127 std::string m_LogitsDir;
128 std::string m_NewStateHDir;
129 std::string m_NewStateCDir;
132 DeepSpeechV1Database::DeepSpeechV1Database(const std::string& inputSeqDir, const std::string& prevStateHDir,
133 const std::string& prevStateCDir, const std::string& logitsDir,
134 const std::string& newStateHDir, const std::string& newStateCDir)
135 : m_InputSeqDir(inputSeqDir)
136 , m_PrevStateHDir(prevStateHDir)
137 , m_PrevStateCDir(prevStateCDir)
138 , m_LogitsDir(logitsDir)
139 , m_NewStateHDir(newStateHDir)
140 , m_NewStateCDir(newStateCDir)
143 std::unique_ptr<DeepSpeechV1TestCaseData> DeepSpeechV1Database::GetTestCaseData(unsigned int testCaseId)
145 // Load test case input
146 const std::string inputSeqPath = m_InputSeqDir + "input_node_0_flat.txt";
147 const std::string prevStateCPath = m_PrevStateCDir + "previous_state_c_0.txt";
148 const std::string prevStateHPath = m_PrevStateHDir + "previous_state_h_0.txt";
150 std::vector<float> inputSeqData;
151 std::vector<float> prevStateCData;
152 std::vector<float> prevStateHData;
154 std::ifstream inputSeqFile(inputSeqPath);
155 std::ifstream prevStateCTensorFile(prevStateCPath);
156 std::ifstream prevStateHTensorFile(prevStateHPath);
160 inputSeqData = ParseDataArray<armnn::DataType::Float32>(inputSeqFile);
161 prevStateCData = ParseDataArray<armnn::DataType::Float32>(prevStateCTensorFile);
162 prevStateHData = ParseDataArray<armnn::DataType::Float32>(prevStateHTensorFile);
164 catch (const InferenceTestImageException& e)
166 BOOST_LOG_TRIVIAL(fatal) << "Failed to load image for test case " << testCaseId << ". Error: " << e.what();
170 // Prepare test case expected output
171 const std::string logitsPath = m_LogitsDir + "logits.txt";
172 const std::string newStateCPath = m_NewStateCDir + "new_state_c.txt";
173 const std::string newStateHPath = m_NewStateHDir + "new_state_h.txt";
175 std::vector<float> logitsData;
176 std::vector<float> expectedNewStateCData;
177 std::vector<float> expectedNewStateHData;
179 std::ifstream logitsTensorFile(logitsPath);
180 std::ifstream newStateCTensorFile(newStateCPath);
181 std::ifstream newStateHTensorFile(newStateHPath);
185 logitsData = ParseDataArray<armnn::DataType::Float32>(logitsTensorFile);
186 expectedNewStateCData = ParseDataArray<armnn::DataType::Float32>(newStateCTensorFile);
187 expectedNewStateHData = ParseDataArray<armnn::DataType::Float32>(newStateHTensorFile);
189 catch (const InferenceTestImageException& e)
191 BOOST_LOG_TRIVIAL(fatal) << "Failed to load image for test case " << testCaseId << ". Error: " << e.what();
195 // use the struct for representing input and output data
196 LstmInput inputDataSingleTest(inputSeqData, prevStateHData, prevStateCData);
198 LstmInput expectedOutputsSingleTest(logitsData, expectedNewStateHData, expectedNewStateCData);
200 return std::make_unique<DeepSpeechV1TestCaseData>(inputDataSingleTest, expectedOutputsSingleTest);
203 } // anonymous namespace