IVGCVSW-4487 Remove boost::filesystem
[platform/upstream/armnn.git] / tests / DeepSpeechV1InferenceTest.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
7 #include "InferenceTest.hpp"
8 #include "DeepSpeechV1Database.hpp"
9
10 #include <armnn/utility/Assert.hpp>
11 #include <armnn/utility/IgnoreUnused.hpp>
12
13 #include <boost/numeric/conversion/cast.hpp>
14 #include <boost/test/tools/floating_point_comparison.hpp>
15
16 #include <vector>
17
18 namespace
19 {
20
21 template<typename Model>
22 class DeepSpeechV1TestCase : public InferenceModelTestCase<Model>
23 {
24 public:
25     DeepSpeechV1TestCase(Model& model,
26                          unsigned int testCaseId,
27                          const DeepSpeechV1TestCaseData& testCaseData)
28         : InferenceModelTestCase<Model>(model,
29                                         testCaseId,
30                                         { testCaseData.m_InputData.m_InputSeq,
31                                           testCaseData.m_InputData.m_StateH,
32                                           testCaseData.m_InputData.m_StateC},
33                                         { k_OutputSize1, k_OutputSize2, k_OutputSize3 })
34         , m_FloatComparer(boost::math::fpc::percent_tolerance(1.0f))
35         , m_ExpectedOutputs({testCaseData.m_ExpectedOutputData.m_InputSeq, testCaseData.m_ExpectedOutputData.m_StateH,
36                              testCaseData.m_ExpectedOutputData.m_StateC})
37     {}
38
39     TestCaseResult ProcessResult(const InferenceTestOptions& options) override
40     {
41         armnn::IgnoreUnused(options);
42         const std::vector<float>& output1 = boost::get<std::vector<float>>(this->GetOutputs()[0]); // logits
43         ARMNN_ASSERT(output1.size() == k_OutputSize1);
44
45         const std::vector<float>& output2 = boost::get<std::vector<float>>(this->GetOutputs()[1]); // new_state_c
46         ARMNN_ASSERT(output2.size() == k_OutputSize2);
47
48         const std::vector<float>& output3 = boost::get<std::vector<float>>(this->GetOutputs()[2]); // new_state_h
49         ARMNN_ASSERT(output3.size() == k_OutputSize3);
50
51         // Check each output to see whether it is the expected value
52         for (unsigned int j = 0u; j < output1.size(); j++)
53         {
54             if(!m_FloatComparer(output1[j], m_ExpectedOutputs.m_InputSeq[j]))
55             {
56                 ARMNN_LOG(error) << "InputSeq for Lstm " << this->GetTestCaseId() <<
57                                          " is incorrect at" << j;
58                 return TestCaseResult::Failed;
59             }
60         }
61
62         for (unsigned int j = 0u; j < output2.size(); j++)
63         {
64             if(!m_FloatComparer(output2[j], m_ExpectedOutputs.m_StateH[j]))
65             {
66                 ARMNN_LOG(error) << "StateH for Lstm " << this->GetTestCaseId() <<
67                                          " is incorrect";
68                 return TestCaseResult::Failed;
69             }
70         }
71
72         for (unsigned int j = 0u; j < output3.size(); j++)
73         {
74             if(!m_FloatComparer(output3[j], m_ExpectedOutputs.m_StateC[j]))
75             {
76                 ARMNN_LOG(error) << "StateC for Lstm " << this->GetTestCaseId() <<
77                                          " is incorrect";
78                 return TestCaseResult::Failed;
79             }
80         }
81         return TestCaseResult::Ok;
82     }
83
84 private:
85
86     static constexpr unsigned int k_OutputSize1 = 464u;
87     static constexpr unsigned int k_OutputSize2 = 2048u;
88     static constexpr unsigned int k_OutputSize3 = 2048u;
89
90     boost::math::fpc::close_at_tolerance<float> m_FloatComparer;
91     LstmInput m_ExpectedOutputs;
92 };
93
94 template <typename Model>
95 class DeepSpeechV1TestCaseProvider : public IInferenceTestCaseProvider
96 {
97 public:
98     template <typename TConstructModelCallable>
99     explicit DeepSpeechV1TestCaseProvider(TConstructModelCallable constructModel)
100         : m_ConstructModel(constructModel)
101     {}
102
103     virtual void AddCommandLineOptions(boost::program_options::options_description& options) override
104     {
105         namespace po = boost::program_options;
106
107         options.add_options()
108                 ("input-seq-dir,s", po::value<std::string>(&m_InputSeqDir)->required(),
109                  "Path to directory containing test data for m_InputSeq");
110         options.add_options()
111                 ("prev-state-h-dir,h", po::value<std::string>(&m_PrevStateHDir)->required(),
112                  "Path to directory containing test data for m_PrevStateH");
113         options.add_options()
114                 ("prev-state-c-dir,c", po::value<std::string>(&m_PrevStateCDir)->required(),
115                  "Path to directory containing test data for m_PrevStateC");
116         options.add_options()
117                 ("logits-dir,l", po::value<std::string>(&m_LogitsDir)->required(),
118                  "Path to directory containing test data for m_Logits");
119         options.add_options()
120                 ("new-state-h-dir,H", po::value<std::string>(&m_NewStateHDir)->required(),
121                  "Path to directory containing test data for m_NewStateH");
122         options.add_options()
123                 ("new-state-c-dir,C", po::value<std::string>(&m_NewStateCDir)->required(),
124                  "Path to directory containing test data for m_NewStateC");
125
126
127         Model::AddCommandLineOptions(options, m_ModelCommandLineOptions);
128     }
129
130     virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) override
131     {
132         if (!ValidateDirectory(m_InputSeqDir))
133         {
134             return false;
135         }
136
137         if (!ValidateDirectory(m_PrevStateCDir))
138         {
139             return false;
140         }
141
142         if (!ValidateDirectory(m_PrevStateHDir))
143         {
144             return false;
145         }
146
147         if (!ValidateDirectory(m_LogitsDir))
148         {
149             return false;
150         }
151
152         if (!ValidateDirectory(m_NewStateCDir))
153         {
154             return false;
155         }
156
157         if (!ValidateDirectory(m_NewStateHDir))
158         {
159             return false;
160         }
161
162         m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions);
163         if (!m_Model)
164         {
165             return false;
166         }
167         m_Database = std::make_unique<DeepSpeechV1Database>(m_InputSeqDir.c_str(), m_PrevStateHDir.c_str(),
168                                                             m_PrevStateCDir.c_str(), m_LogitsDir.c_str(),
169                                                             m_NewStateHDir.c_str(), m_NewStateCDir.c_str());
170         if (!m_Database)
171         {
172             return false;
173         }
174
175         return true;
176     }
177
178     std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override
179     {
180         std::unique_ptr<DeepSpeechV1TestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
181         if (!testCaseData)
182         {
183             return nullptr;
184         }
185
186         return std::make_unique<DeepSpeechV1TestCase<Model>>(*m_Model, testCaseId, *testCaseData);
187     }
188
189 private:
190     typename Model::CommandLineOptions m_ModelCommandLineOptions;
191     std::function<std::unique_ptr<Model>(const InferenceTestOptions&,
192                                          typename Model::CommandLineOptions)> m_ConstructModel;
193     std::unique_ptr<Model> m_Model;
194
195     std::string m_InputSeqDir;
196     std::string m_PrevStateCDir;
197     std::string m_PrevStateHDir;
198     std::string m_LogitsDir;
199     std::string m_NewStateCDir;
200     std::string m_NewStateHDir;
201
202     std::unique_ptr<DeepSpeechV1Database> m_Database;
203 };
204
205 } // anonymous namespace