NNXSW-1853 Change SubgraphViewSelector algorithm
[platform/upstream/armnn.git] / tests / DeepSpeechV1InferenceTest.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
7 #include "InferenceTest.hpp"
8 #include "DeepSpeechV1Database.hpp"
9
10 #include <boost/assert.hpp>
11 #include <boost/log/trivial.hpp>
12 #include <boost/numeric/conversion/cast.hpp>
13 #include <boost/test/tools/floating_point_comparison.hpp>
14
15 #include <vector>
16
17 namespace
18 {
19
20 template<typename Model>
21 class DeepSpeechV1TestCase : public InferenceModelTestCase<Model>
22 {
23 public:
24     DeepSpeechV1TestCase(Model& model,
25                          unsigned int testCaseId,
26                          const DeepSpeechV1TestCaseData& testCaseData)
27         : InferenceModelTestCase<Model>(model,
28                                         testCaseId,
29                                         { testCaseData.m_InputData.m_InputSeq,
30                                           testCaseData.m_InputData.m_StateH,
31                                           testCaseData.m_InputData.m_StateC},
32                                         { k_OutputSize1, k_OutputSize2, k_OutputSize3 })
33         , m_FloatComparer(boost::math::fpc::percent_tolerance(1.0f))
34         , m_ExpectedOutputs({testCaseData.m_ExpectedOutputData.m_InputSeq, testCaseData.m_ExpectedOutputData.m_StateH,
35                              testCaseData.m_ExpectedOutputData.m_StateC})
36     {}
37
38     TestCaseResult ProcessResult(const InferenceTestOptions& options) override
39     {
40         const std::vector<float>& output1 = boost::get<std::vector<float>>(this->GetOutputs()[0]); // logits
41         BOOST_ASSERT(output1.size() == k_OutputSize1);
42
43         const std::vector<float>& output2 = boost::get<std::vector<float>>(this->GetOutputs()[1]); // new_state_c
44         BOOST_ASSERT(output2.size() == k_OutputSize2);
45
46         const std::vector<float>& output3 = boost::get<std::vector<float>>(this->GetOutputs()[2]); // new_state_h
47         BOOST_ASSERT(output3.size() == k_OutputSize3);
48
49         // Check each output to see whether it is the expected value
50         for (unsigned int j = 0u; j < output1.size(); j++)
51         {
52             if(!m_FloatComparer(output1[j], m_ExpectedOutputs.m_InputSeq[j]))
53             {
54                 BOOST_LOG_TRIVIAL(error) << "InputSeq for Lstm " << this->GetTestCaseId() <<
55                                          " is incorrect at" << j;
56                 return TestCaseResult::Failed;
57             }
58         }
59
60         for (unsigned int j = 0u; j < output2.size(); j++)
61         {
62             if(!m_FloatComparer(output2[j], m_ExpectedOutputs.m_StateH[j]))
63             {
64                 BOOST_LOG_TRIVIAL(error) << "StateH for Lstm " << this->GetTestCaseId() <<
65                                          " is incorrect";
66                 return TestCaseResult::Failed;
67             }
68         }
69
70         for (unsigned int j = 0u; j < output3.size(); j++)
71         {
72             if(!m_FloatComparer(output3[j], m_ExpectedOutputs.m_StateC[j]))
73             {
74                 BOOST_LOG_TRIVIAL(error) << "StateC for Lstm " << this->GetTestCaseId() <<
75                                          " is incorrect";
76                 return TestCaseResult::Failed;
77             }
78         }
79         return TestCaseResult::Ok;
80     }
81
82 private:
83
84     static constexpr unsigned int k_OutputSize1 = 464u;
85     static constexpr unsigned int k_OutputSize2 = 2048u;
86     static constexpr unsigned int k_OutputSize3 = 2048u;
87
88     boost::math::fpc::close_at_tolerance<float> m_FloatComparer;
89     LstmInput m_ExpectedOutputs;
90 };
91
92 template <typename Model>
93 class DeepSpeechV1TestCaseProvider : public IInferenceTestCaseProvider
94 {
95 public:
96     template <typename TConstructModelCallable>
97     explicit DeepSpeechV1TestCaseProvider(TConstructModelCallable constructModel)
98         : m_ConstructModel(constructModel)
99     {}
100
101     virtual void AddCommandLineOptions(boost::program_options::options_description& options) override
102     {
103         namespace po = boost::program_options;
104
105         options.add_options()
106                 ("input-seq-dir,s", po::value<std::string>(&m_InputSeqDir)->required(),
107                  "Path to directory containing test data for m_InputSeq");
108         options.add_options()
109                 ("prev-state-h-dir,h", po::value<std::string>(&m_PrevStateHDir)->required(),
110                  "Path to directory containing test data for m_PrevStateH");
111         options.add_options()
112                 ("prev-state-c-dir,c", po::value<std::string>(&m_PrevStateCDir)->required(),
113                  "Path to directory containing test data for m_PrevStateC");
114         options.add_options()
115                 ("logits-dir,l", po::value<std::string>(&m_LogitsDir)->required(),
116                  "Path to directory containing test data for m_Logits");
117         options.add_options()
118                 ("new-state-h-dir,H", po::value<std::string>(&m_NewStateHDir)->required(),
119                  "Path to directory containing test data for m_NewStateH");
120         options.add_options()
121                 ("new-state-c-dir,C", po::value<std::string>(&m_NewStateCDir)->required(),
122                  "Path to directory containing test data for m_NewStateC");
123
124
125         Model::AddCommandLineOptions(options, m_ModelCommandLineOptions);
126     }
127
128     virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) override
129     {
130         if (!ValidateDirectory(m_InputSeqDir))
131         {
132             return false;
133         }
134
135         if (!ValidateDirectory(m_PrevStateCDir))
136         {
137             return false;
138         }
139
140         if (!ValidateDirectory(m_PrevStateHDir))
141         {
142             return false;
143         }
144
145         if (!ValidateDirectory(m_LogitsDir))
146         {
147             return false;
148         }
149
150         if (!ValidateDirectory(m_NewStateCDir))
151         {
152             return false;
153         }
154
155         if (!ValidateDirectory(m_NewStateHDir))
156         {
157             return false;
158         }
159
160         m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions);
161         if (!m_Model)
162         {
163             return false;
164         }
165         m_Database = std::make_unique<DeepSpeechV1Database>(m_InputSeqDir.c_str(), m_PrevStateHDir.c_str(),
166                                                             m_PrevStateCDir.c_str(), m_LogitsDir.c_str(),
167                                                             m_NewStateHDir.c_str(), m_NewStateCDir.c_str());
168         if (!m_Database)
169         {
170             return false;
171         }
172
173         return true;
174     }
175
176     std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override
177     {
178         std::unique_ptr<DeepSpeechV1TestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
179         if (!testCaseData)
180         {
181             return nullptr;
182         }
183
184         return std::make_unique<DeepSpeechV1TestCase<Model>>(*m_Model, testCaseId, *testCaseData);
185     }
186
187 private:
188     typename Model::CommandLineOptions m_ModelCommandLineOptions;
189     std::function<std::unique_ptr<Model>(const InferenceTestOptions&,
190                                          typename Model::CommandLineOptions)> m_ConstructModel;
191     std::unique_ptr<Model> m_Model;
192
193     std::string m_InputSeqDir;
194     std::string m_PrevStateCDir;
195     std::string m_PrevStateHDir;
196     std::string m_LogitsDir;
197     std::string m_NewStateCDir;
198     std::string m_NewStateHDir;
199
200     std::unique_ptr<DeepSpeechV1Database> m_Database;
201 };
202
203 } // anonymous namespace