MLCE-347 'REDUCE_MIN, REDUCE_MAX, REDUCE_SUM Support'
[platform/upstream/armnn.git] / tests / DeepSpeechV1InferenceTest.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
7 #include "InferenceTest.hpp"
8 #include "DeepSpeechV1Database.hpp"
9
10 #include <armnn/utility/Assert.hpp>
11 #include <armnn/utility/IgnoreUnused.hpp>
12 #include <armnnUtils/FloatingPointComparison.hpp>
13
14 #include <vector>
15
16 namespace
17 {
18
19 template<typename Model>
20 class DeepSpeechV1TestCase : public InferenceModelTestCase<Model>
21 {
22 public:
23     DeepSpeechV1TestCase(Model& model,
24                          unsigned int testCaseId,
25                          const DeepSpeechV1TestCaseData& testCaseData)
26         : InferenceModelTestCase<Model>(model,
27                                         testCaseId,
28                                         { testCaseData.m_InputData.m_InputSeq,
29                                           testCaseData.m_InputData.m_StateH,
30                                           testCaseData.m_InputData.m_StateC},
31                                         { k_OutputSize1, k_OutputSize2, k_OutputSize3 })
32         , m_ExpectedOutputs({testCaseData.m_ExpectedOutputData.m_InputSeq, testCaseData.m_ExpectedOutputData.m_StateH,
33                              testCaseData.m_ExpectedOutputData.m_StateC})
34     {}
35
36     TestCaseResult ProcessResult(const InferenceTestOptions& options) override
37     {
38         armnn::IgnoreUnused(options);
39         const std::vector<float>& output1 = mapbox::util::get<std::vector<float>>(this->GetOutputs()[0]); // logits
40         ARMNN_ASSERT(output1.size() == k_OutputSize1);
41
42         const std::vector<float>& output2 = mapbox::util::get<std::vector<float>>(this->GetOutputs()[1]); // new_state_c
43         ARMNN_ASSERT(output2.size() == k_OutputSize2);
44
45         const std::vector<float>& output3 = mapbox::util::get<std::vector<float>>(this->GetOutputs()[2]); // new_state_h
46         ARMNN_ASSERT(output3.size() == k_OutputSize3);
47
48         // Check each output to see whether it is the expected value
49         for (unsigned int j = 0u; j < output1.size(); j++)
50         {
51             if(!armnnUtils::within_percentage_tolerance(output1[j], m_ExpectedOutputs.m_InputSeq[j]))
52             {
53                 ARMNN_LOG(error) << "InputSeq for Lstm " << this->GetTestCaseId() <<
54                                          " is incorrect at" << j;
55                 return TestCaseResult::Failed;
56             }
57         }
58
59         for (unsigned int j = 0u; j < output2.size(); j++)
60         {
61             if(!armnnUtils::within_percentage_tolerance(output2[j], m_ExpectedOutputs.m_StateH[j]))
62             {
63                 ARMNN_LOG(error) << "StateH for Lstm " << this->GetTestCaseId() <<
64                                          " is incorrect";
65                 return TestCaseResult::Failed;
66             }
67         }
68
69         for (unsigned int j = 0u; j < output3.size(); j++)
70         {
71             if(!armnnUtils::within_percentage_tolerance(output3[j], m_ExpectedOutputs.m_StateC[j]))
72             {
73                 ARMNN_LOG(error) << "StateC for Lstm " << this->GetTestCaseId() <<
74                                          " is incorrect";
75                 return TestCaseResult::Failed;
76             }
77         }
78         return TestCaseResult::Ok;
79     }
80
81 private:
82
83     static constexpr unsigned int k_OutputSize1 = 464u;
84     static constexpr unsigned int k_OutputSize2 = 2048u;
85     static constexpr unsigned int k_OutputSize3 = 2048u;
86
87     LstmInput m_ExpectedOutputs;
88 };
89
90 template <typename Model>
91 class DeepSpeechV1TestCaseProvider : public IInferenceTestCaseProvider
92 {
93 public:
94     template <typename TConstructModelCallable>
95     explicit DeepSpeechV1TestCaseProvider(TConstructModelCallable constructModel)
96         : m_ConstructModel(constructModel)
97     {}
98
99     virtual void AddCommandLineOptions(cxxopts::Options& options, std::vector<std::string>& required) override
100     {
101         options
102             .allow_unrecognised_options()
103             .add_options()
104                 ("s,input-seq-dir", "Path to directory containing test data for m_InputSeq",
105                  cxxopts::value<std::string>(m_InputSeqDir))
106                 ("h,prev-state-h-dir", "Path to directory containing test data for m_PrevStateH",
107                  cxxopts::value<std::string>(m_PrevStateHDir))
108                 ("c,prev-state-c-dir", "Path to directory containing test data for m_PrevStateC",
109                  cxxopts::value<std::string>(m_PrevStateCDir))
110                 ("l,logits-dir", "Path to directory containing test data for m_Logits",
111                  cxxopts::value<std::string>(m_LogitsDir))
112                 ("H,new-state-h-dir", "Path to directory containing test data for m_NewStateH",
113                  cxxopts::value<std::string>(m_NewStateHDir))
114                 ("C,new-state-c-dir", "Path to directory containing test data for m_NewStateC",
115                  cxxopts::value<std::string>(m_NewStateCDir));
116
117         required.insert(required.end(), {"input-seq-dir", "prev-state-h-dir", "prev-state-c-dir", "logits-dir",
118                                          "new-state-h-dir", "new-state-c-dir"});
119
120         Model::AddCommandLineOptions(options, m_ModelCommandLineOptions, required);
121     }
122
123     virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) override
124     {
125         if (!ValidateDirectory(m_InputSeqDir))
126         {
127             return false;
128         }
129
130         if (!ValidateDirectory(m_PrevStateCDir))
131         {
132             return false;
133         }
134
135         if (!ValidateDirectory(m_PrevStateHDir))
136         {
137             return false;
138         }
139
140         if (!ValidateDirectory(m_LogitsDir))
141         {
142             return false;
143         }
144
145         if (!ValidateDirectory(m_NewStateCDir))
146         {
147             return false;
148         }
149
150         if (!ValidateDirectory(m_NewStateHDir))
151         {
152             return false;
153         }
154
155         m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions);
156         if (!m_Model)
157         {
158             return false;
159         }
160         m_Database = std::make_unique<DeepSpeechV1Database>(m_InputSeqDir.c_str(), m_PrevStateHDir.c_str(),
161                                                             m_PrevStateCDir.c_str(), m_LogitsDir.c_str(),
162                                                             m_NewStateHDir.c_str(), m_NewStateCDir.c_str());
163         if (!m_Database)
164         {
165             return false;
166         }
167
168         return true;
169     }
170
171     std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override
172     {
173         std::unique_ptr<DeepSpeechV1TestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
174         if (!testCaseData)
175         {
176             return nullptr;
177         }
178
179         return std::make_unique<DeepSpeechV1TestCase<Model>>(*m_Model, testCaseId, *testCaseData);
180     }
181
182 private:
183     typename Model::CommandLineOptions m_ModelCommandLineOptions;
184     std::function<std::unique_ptr<Model>(const InferenceTestOptions&,
185                                          typename Model::CommandLineOptions)> m_ConstructModel;
186     std::unique_ptr<Model> m_Model;
187
188     std::string m_InputSeqDir;
189     std::string m_PrevStateCDir;
190     std::string m_PrevStateHDir;
191     std::string m_LogitsDir;
192     std::string m_NewStateCDir;
193     std::string m_NewStateHDir;
194
195     std::unique_ptr<DeepSpeechV1Database> m_Database;
196 };
197
198 } // anonymous namespace