Imported Upstream version 1.9.0
[platform/core/ml/nnfw.git] / tests / nnfw_api / src / fixtures.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __NNFW_API_TEST_FIXTURES_H__
18 #define __NNFW_API_TEST_FIXTURES_H__
19
20 #include <array>
21 #include <gtest/gtest.h>
22 #include <nnfw_experimental.h>
23
24 #include "NNPackages.h"
25
26 #define NNFW_ENSURE_SUCCESS(EXPR) ASSERT_EQ((EXPR), NNFW_STATUS_NO_ERROR)
27
28 inline uint64_t num_elems(const nnfw_tensorinfo *ti)
29 {
30   uint64_t n = 1;
31   for (uint32_t i = 0; i < ti->rank; ++i)
32   {
33     n *= ti->dims[i];
34   }
35   return n;
36 }
37
38 struct SessionObject
39 {
40   nnfw_session *session = nullptr;
41   std::vector<std::vector<float>> inputs;
42   std::vector<std::vector<float>> outputs;
43 };
44
45 class ValidationTest : public ::testing::Test
46 {
47 protected:
48   void SetUp() override {}
49 };
50
51 class RegressionTest : public ::testing::Test
52 {
53 protected:
54   void SetUp() override {}
55 };
56
57 class ValidationTestSingleSession : public ValidationTest
58 {
59 protected:
60   nnfw_session *_session = nullptr;
61 };
62
63 class ValidationTestSessionCreated : public ValidationTestSingleSession
64 {
65 protected:
66   void SetUp() override
67   {
68     ValidationTestSingleSession::SetUp();
69     ASSERT_EQ(nnfw_create_session(&_session), NNFW_STATUS_NO_ERROR);
70   }
71
72   void TearDown() override
73   {
74     ASSERT_EQ(nnfw_close_session(_session), NNFW_STATUS_NO_ERROR);
75     ValidationTestSingleSession::TearDown();
76   }
77 };
78
79 template <int PackageNo> class ValidationTestModelLoaded : public ValidationTestSessionCreated
80 {
81 protected:
82   void SetUp() override
83   {
84     ValidationTestSessionCreated::SetUp();
85     ASSERT_EQ(nnfw_load_model_from_file(_session,
86                                         NNPackages::get().getModelAbsolutePath(PackageNo).c_str()),
87               NNFW_STATUS_NO_ERROR);
88     ASSERT_NE(_session, nullptr);
89   }
90
91   void TearDown() override { ValidationTestSessionCreated::TearDown(); }
92 };
93
94 template <int PackageNo>
95 class ValidationTestSessionPrepared : public ValidationTestModelLoaded<PackageNo>
96 {
97 protected:
98   using ValidationTestSingleSession::_session;
99
100   void SetUp() override
101   {
102     ValidationTestModelLoaded<PackageNo>::SetUp();
103     nnfw_prepare(_session);
104   }
105
106   void TearDown() override { ValidationTestModelLoaded<PackageNo>::TearDown(); }
107
108   void SetInOutBuffers()
109   {
110     nnfw_tensorinfo ti_input;
111     ASSERT_EQ(nnfw_input_tensorinfo(_session, 0, &ti_input), NNFW_STATUS_NO_ERROR);
112     uint64_t input_elements = num_elems(&ti_input);
113     EXPECT_EQ(input_elements, 1);
114     _input.resize(input_elements);
115     ASSERT_EQ(
116         nnfw_set_input(_session, 0, ti_input.dtype, _input.data(), sizeof(float) * input_elements),
117         NNFW_STATUS_NO_ERROR);
118
119     nnfw_tensorinfo ti_output;
120     ASSERT_EQ(nnfw_output_tensorinfo(_session, 0, &ti_output), NNFW_STATUS_NO_ERROR);
121     uint64_t output_elements = num_elems(&ti_output);
122     EXPECT_EQ(output_elements, 1);
123     _output.resize(output_elements);
124     ASSERT_EQ(nnfw_set_output(_session, 0, ti_output.dtype, _output.data(),
125                               sizeof(float) * output_elements),
126               NNFW_STATUS_NO_ERROR);
127   }
128
129 protected:
130   std::vector<float> _input;
131   std::vector<float> _output;
132 };
133
134 template <int PackageNo> class ValidationTestFourModelsSetInput : public ValidationTest
135 {
136 protected:
137   static const uint32_t NUM_SESSIONS = 4;
138
139   void SetUp() override
140   {
141     ValidationTest::SetUp();
142
143     auto model_path = NNPackages::get().getModelAbsolutePath(NNPackages::ADD);
144     for (auto &obj : _objects)
145     {
146       ASSERT_EQ(nnfw_create_session(&obj.session), NNFW_STATUS_NO_ERROR);
147       ASSERT_EQ(nnfw_load_model_from_file(obj.session, model_path.c_str()), NNFW_STATUS_NO_ERROR);
148       ASSERT_EQ(nnfw_prepare(obj.session), NNFW_STATUS_NO_ERROR);
149
150       uint32_t num_inputs;
151       ASSERT_EQ(nnfw_input_size(obj.session, &num_inputs), NNFW_STATUS_NO_ERROR);
152       obj.inputs.resize(num_inputs);
153       for (uint32_t ind = 0; ind < obj.inputs.size(); ind++)
154       {
155         nnfw_tensorinfo ti;
156         ASSERT_EQ(nnfw_input_tensorinfo(obj.session, ind, &ti), NNFW_STATUS_NO_ERROR);
157         uint64_t input_elements = num_elems(&ti);
158         obj.inputs[ind].resize(input_elements);
159         ASSERT_EQ(nnfw_set_input(obj.session, ind, ti.dtype, obj.inputs[ind].data(),
160                                  sizeof(float) * input_elements),
161                   NNFW_STATUS_NO_ERROR);
162       }
163
164       uint32_t num_outputs;
165       ASSERT_EQ(nnfw_output_size(obj.session, &num_outputs), NNFW_STATUS_NO_ERROR);
166       obj.outputs.resize(num_outputs);
167       for (uint32_t ind = 0; ind < obj.outputs.size(); ind++)
168       {
169         nnfw_tensorinfo ti;
170         ASSERT_EQ(nnfw_output_tensorinfo(obj.session, ind, &ti), NNFW_STATUS_NO_ERROR);
171         uint64_t output_elements = num_elems(&ti);
172         obj.outputs[ind].resize(output_elements);
173         ASSERT_EQ(nnfw_set_output(obj.session, ind, ti.dtype, obj.outputs[ind].data(),
174                                   sizeof(float) * output_elements),
175                   NNFW_STATUS_NO_ERROR);
176       }
177     }
178   }
179
180   void TearDown() override
181   {
182     for (auto &obj : _objects)
183     {
184       ASSERT_EQ(nnfw_close_session(obj.session), NNFW_STATUS_NO_ERROR);
185     }
186     ValidationTest::TearDown();
187   }
188
189 protected:
190   std::array<SessionObject, NUM_SESSIONS> _objects;
191 };
192
193 #endif // __NNFW_API_TEST_FIXTURES_H__