21be22f343f83c33eaf3e68e3de165cfa576d064
[platform/core/ml/nnfw.git] / tests / nnfw_api / src / fixtures.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __NNFW_API_TEST_FIXTURES_H__
18 #define __NNFW_API_TEST_FIXTURES_H__
19
20 #include <array>
21 #include <gtest/gtest.h>
22 #include <nnfw_experimental.h>
23 #include <nnfw_internal.h>
24
25 #include "NNPackages.h"
26
27 #define NNFW_ENSURE_SUCCESS(EXPR) ASSERT_EQ((EXPR), NNFW_STATUS_NO_ERROR)
28
29 inline uint64_t num_elems(const nnfw_tensorinfo *ti)
30 {
31   uint64_t n = 1;
32   for (uint32_t i = 0; i < ti->rank; ++i)
33   {
34     n *= ti->dims[i];
35   }
36   return n;
37 }
38
39 struct SessionObject
40 {
41   nnfw_session *session = nullptr;
42   std::vector<std::vector<float>> inputs;
43   std::vector<std::vector<float>> outputs;
44 };
45
46 class ValidationTest : public ::testing::Test
47 {
48 protected:
49   void SetUp() override {}
50 };
51
52 class RegressionTest : public ::testing::Test
53 {
54 protected:
55   void SetUp() override {}
56 };
57
58 class ValidationTestSingleSession : public ValidationTest
59 {
60 protected:
61   nnfw_session *_session = nullptr;
62 };
63
64 class ValidationTestSessionCreated : public ValidationTestSingleSession
65 {
66 protected:
67   void SetUp() override
68   {
69     ValidationTestSingleSession::SetUp();
70     ASSERT_EQ(nnfw_create_session(&_session), NNFW_STATUS_NO_ERROR);
71   }
72
73   void TearDown() override
74   {
75     ASSERT_EQ(nnfw_close_session(_session), NNFW_STATUS_NO_ERROR);
76     ValidationTestSingleSession::TearDown();
77   }
78 };
79
80 template <int PackageNo> class ValidationTestModelLoaded : public ValidationTestSessionCreated
81 {
82 protected:
83   void SetUp() override
84   {
85     ValidationTestSessionCreated::SetUp();
86     ASSERT_EQ(nnfw_load_model_from_file(_session,
87                                         NNPackages::get().getModelAbsolutePath(PackageNo).c_str()),
88               NNFW_STATUS_NO_ERROR);
89     ASSERT_NE(_session, nullptr);
90   }
91
92   void TearDown() override { ValidationTestSessionCreated::TearDown(); }
93 };
94
95 template <int PackageNo>
96 class ValidationTestSessionPrepared : public ValidationTestModelLoaded<PackageNo>
97 {
98 protected:
99   using ValidationTestSingleSession::_session;
100
101   void SetUp() override
102   {
103     ValidationTestModelLoaded<PackageNo>::SetUp();
104     nnfw_prepare(_session);
105   }
106
107   void TearDown() override { ValidationTestModelLoaded<PackageNo>::TearDown(); }
108
109   void SetInOutBuffers()
110   {
111     nnfw_tensorinfo ti_input;
112     ASSERT_EQ(nnfw_input_tensorinfo(_session, 0, &ti_input), NNFW_STATUS_NO_ERROR);
113     uint64_t input_elements = num_elems(&ti_input);
114     EXPECT_EQ(input_elements, 1);
115     _input.resize(input_elements);
116     ASSERT_EQ(
117         nnfw_set_input(_session, 0, ti_input.dtype, _input.data(), sizeof(float) * input_elements),
118         NNFW_STATUS_NO_ERROR);
119
120     nnfw_tensorinfo ti_output;
121     ASSERT_EQ(nnfw_output_tensorinfo(_session, 0, &ti_output), NNFW_STATUS_NO_ERROR);
122     uint64_t output_elements = num_elems(&ti_output);
123     EXPECT_EQ(output_elements, 1);
124     _output.resize(output_elements);
125     ASSERT_EQ(nnfw_set_output(_session, 0, ti_output.dtype, _output.data(),
126                               sizeof(float) * output_elements),
127               NNFW_STATUS_NO_ERROR);
128   }
129
130   void SetInOutBuffersDynamic(const nnfw_tensorinfo *ti_input)
131   {
132     NNFW_ENSURE_SUCCESS(nnfw_set_input_tensorinfo(_session, 0, ti_input));
133     uint64_t input_elements = num_elems(ti_input);
134     _input.resize(input_elements);
135     ASSERT_EQ(
136         nnfw_set_input(_session, 0, ti_input->dtype, _input.data(), sizeof(float) * input_elements),
137         NNFW_STATUS_NO_ERROR);
138
139     _output.resize(40000); // Give sufficient size for the output
140     ASSERT_EQ(nnfw_set_output(_session, 0, ti_input->dtype, _output.data(),
141                               sizeof(float) * _output.size()),
142               NNFW_STATUS_NO_ERROR);
143   }
144
145 protected:
146   std::vector<float> _input;
147   std::vector<float> _output;
148 };
149
150 template <int PackageNo> class ValidationTestFourModelsSetInput : public ValidationTest
151 {
152 protected:
153   static const uint32_t NUM_SESSIONS = 4;
154
155   void SetUp() override
156   {
157     ValidationTest::SetUp();
158
159     auto model_path = NNPackages::get().getModelAbsolutePath(NNPackages::ADD);
160     for (auto &obj : _objects)
161     {
162       ASSERT_EQ(nnfw_create_session(&obj.session), NNFW_STATUS_NO_ERROR);
163       ASSERT_EQ(nnfw_load_model_from_file(obj.session, model_path.c_str()), NNFW_STATUS_NO_ERROR);
164       ASSERT_EQ(nnfw_prepare(obj.session), NNFW_STATUS_NO_ERROR);
165
166       uint32_t num_inputs;
167       ASSERT_EQ(nnfw_input_size(obj.session, &num_inputs), NNFW_STATUS_NO_ERROR);
168       obj.inputs.resize(num_inputs);
169       for (uint32_t ind = 0; ind < obj.inputs.size(); ind++)
170       {
171         nnfw_tensorinfo ti;
172         ASSERT_EQ(nnfw_input_tensorinfo(obj.session, ind, &ti), NNFW_STATUS_NO_ERROR);
173         uint64_t input_elements = num_elems(&ti);
174         obj.inputs[ind].resize(input_elements);
175         ASSERT_EQ(nnfw_set_input(obj.session, ind, ti.dtype, obj.inputs[ind].data(),
176                                  sizeof(float) * input_elements),
177                   NNFW_STATUS_NO_ERROR);
178       }
179
180       uint32_t num_outputs;
181       ASSERT_EQ(nnfw_output_size(obj.session, &num_outputs), NNFW_STATUS_NO_ERROR);
182       obj.outputs.resize(num_outputs);
183       for (uint32_t ind = 0; ind < obj.outputs.size(); ind++)
184       {
185         nnfw_tensorinfo ti;
186         ASSERT_EQ(nnfw_output_tensorinfo(obj.session, ind, &ti), NNFW_STATUS_NO_ERROR);
187         uint64_t output_elements = num_elems(&ti);
188         obj.outputs[ind].resize(output_elements);
189         ASSERT_EQ(nnfw_set_output(obj.session, ind, ti.dtype, obj.outputs[ind].data(),
190                                   sizeof(float) * output_elements),
191                   NNFW_STATUS_NO_ERROR);
192       }
193     }
194   }
195
196   void TearDown() override
197   {
198     for (auto &obj : _objects)
199     {
200       ASSERT_EQ(nnfw_close_session(obj.session), NNFW_STATUS_NO_ERROR);
201     }
202     ValidationTest::TearDown();
203   }
204
205 protected:
206   std::array<SessionObject, NUM_SESSIONS> _objects;
207 };
208
209 #endif // __NNFW_API_TEST_FIXTURES_H__