/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #ifndef __NNFW_API_TEST_FIXTURES_H__
18 #define __NNFW_API_TEST_FIXTURES_H__
#include <array>
#include <cstdint>
#include <vector>

#include <gtest/gtest.h>
#include <nnfw_experimental.h>
#include <nnfw_internal.h>

#include "NNPackages.h"
#include "CircleGen.h"
// Fails the current test unless EXPR evaluates to NNFW_STATUS_NO_ERROR.
// It expands to gtest's ASSERT_EQ, which returns from the enclosing function on
// failure. It can therefore only be used in functions returning void (test
// bodies, fixture SetUp/TearDown, void helpers). Callers of a helper that uses
// it should check ::testing::Test::HasFatalFailure() afterwards.
#define NNFW_ENSURE_SUCCESS(EXPR) ASSERT_EQ((EXPR), NNFW_STATUS_NO_ERROR)
30 inline uint64_t num_elems(const nnfw_tensorinfo *ti)
33 for (uint32_t i = 0; i < ti->rank; ++i)
42 nnfw_session *session = nullptr;
43 std::vector<std::vector<float>> inputs;
44 std::vector<std::vector<float>> outputs;
47 class ValidationTest : public ::testing::Test
50 void SetUp() override {}
53 class RegressionTest : public ::testing::Test
56 void SetUp() override {}
59 class ValidationTestSingleSession : public ValidationTest
62 nnfw_session *_session = nullptr;
65 class ValidationTestSessionCreated : public ValidationTestSingleSession
70 ValidationTestSingleSession::SetUp();
71 ASSERT_EQ(nnfw_create_session(&_session), NNFW_STATUS_NO_ERROR);
72 ASSERT_NE(_session, nullptr);
75 void TearDown() override
77 ASSERT_EQ(nnfw_close_session(_session), NNFW_STATUS_NO_ERROR);
78 ValidationTestSingleSession::TearDown();
82 inline CircleBuffer genAddModel()
85 std::vector<float> rhs_data{2};
86 uint32_t rhs_buf = cgen.addBuffer(rhs_data);
87 int lhs = cgen.addTensor({{1}, circle::TensorType::TensorType_FLOAT32, 0, "X_input"});
88 int rhs = cgen.addTensor({{1}, circle::TensorType::TensorType_FLOAT32, rhs_buf, "y_var"});
89 int out = cgen.addTensor({{1}, circle::TensorType::TensorType_FLOAT32, 0, "ADD_TOP"});
90 cgen.addOperatorAdd({{lhs, rhs}, {out}}, circle::ActivationFunctionType_NONE);
91 cgen.setInputsAndOutputs({lhs}, {out});
95 template <int PackageNo> class ValidationTestModelLoaded : public ValidationTestSessionCreated
100 ValidationTestSessionCreated::SetUp();
101 if (PackageNo == NNPackages::ADD)
103 auto cbuf = genAddModel();
104 NNFW_ENSURE_SUCCESS(nnfw_load_circle_from_buffer(_session, cbuf.buffer(), cbuf.size()));
108 // TODO Eventually, downloaded model tests are removed.
109 NNFW_ENSURE_SUCCESS(nnfw_load_model_from_file(
110 _session, NNPackages::get().getModelAbsolutePath(PackageNo).c_str()));
114 void TearDown() override { ValidationTestSessionCreated::TearDown(); }
117 template <int PackageNo>
118 class ValidationTestSessionPrepared : public ValidationTestModelLoaded<PackageNo>
121 using ValidationTestSingleSession::_session;
123 void SetUp() override
125 ValidationTestModelLoaded<PackageNo>::SetUp();
126 nnfw_prepare(_session);
129 void TearDown() override { ValidationTestModelLoaded<PackageNo>::TearDown(); }
131 void SetInOutBuffers()
133 nnfw_tensorinfo ti_input;
134 ASSERT_EQ(nnfw_input_tensorinfo(_session, 0, &ti_input), NNFW_STATUS_NO_ERROR);
135 uint64_t input_elements = num_elems(&ti_input);
136 EXPECT_EQ(input_elements, 1);
137 _input.resize(input_elements);
139 nnfw_set_input(_session, 0, ti_input.dtype, _input.data(), sizeof(float) * input_elements),
140 NNFW_STATUS_NO_ERROR);
142 nnfw_tensorinfo ti_output;
143 ASSERT_EQ(nnfw_output_tensorinfo(_session, 0, &ti_output), NNFW_STATUS_NO_ERROR);
144 uint64_t output_elements = num_elems(&ti_output);
145 EXPECT_EQ(output_elements, 1);
146 _output.resize(output_elements);
147 ASSERT_EQ(nnfw_set_output(_session, 0, ti_output.dtype, _output.data(),
148 sizeof(float) * output_elements),
149 NNFW_STATUS_NO_ERROR);
152 void SetInOutBuffersDynamic(const nnfw_tensorinfo *ti_input)
154 NNFW_ENSURE_SUCCESS(nnfw_set_input_tensorinfo(_session, 0, ti_input));
155 uint64_t input_elements = num_elems(ti_input);
156 _input.resize(input_elements);
158 nnfw_set_input(_session, 0, ti_input->dtype, _input.data(), sizeof(float) * input_elements),
159 NNFW_STATUS_NO_ERROR);
161 _output.resize(40000); // Give sufficient size for the output
163 nnfw_set_output(_session, 0, ti_input->dtype, _output.data(), sizeof(float) * _output.size()),
164 NNFW_STATUS_NO_ERROR);
168 std::vector<float> _input;
169 std::vector<float> _output;
172 template <int PackageNo> class ValidationTestFourModelsSetInput : public ValidationTest
175 static const uint32_t NUM_SESSIONS = 4;
177 void SetUp() override
179 ValidationTest::SetUp();
181 for (auto &obj : _objects)
183 ASSERT_EQ(nnfw_create_session(&obj.session), NNFW_STATUS_NO_ERROR);
185 auto cbuf = genAddModel();
186 NNFW_ENSURE_SUCCESS(nnfw_load_circle_from_buffer(obj.session, cbuf.buffer(), cbuf.size()));
187 ASSERT_EQ(nnfw_prepare(obj.session), NNFW_STATUS_NO_ERROR);
190 ASSERT_EQ(nnfw_input_size(obj.session, &num_inputs), NNFW_STATUS_NO_ERROR);
191 obj.inputs.resize(num_inputs);
192 for (uint32_t ind = 0; ind < obj.inputs.size(); ind++)
195 ASSERT_EQ(nnfw_input_tensorinfo(obj.session, ind, &ti), NNFW_STATUS_NO_ERROR);
196 uint64_t input_elements = num_elems(&ti);
197 obj.inputs[ind].resize(input_elements);
198 ASSERT_EQ(nnfw_set_input(obj.session, ind, ti.dtype, obj.inputs[ind].data(),
199 sizeof(float) * input_elements),
200 NNFW_STATUS_NO_ERROR);
203 uint32_t num_outputs;
204 ASSERT_EQ(nnfw_output_size(obj.session, &num_outputs), NNFW_STATUS_NO_ERROR);
205 obj.outputs.resize(num_outputs);
206 for (uint32_t ind = 0; ind < obj.outputs.size(); ind++)
209 ASSERT_EQ(nnfw_output_tensorinfo(obj.session, ind, &ti), NNFW_STATUS_NO_ERROR);
210 uint64_t output_elements = num_elems(&ti);
211 obj.outputs[ind].resize(output_elements);
212 ASSERT_EQ(nnfw_set_output(obj.session, ind, ti.dtype, obj.outputs[ind].data(),
213 sizeof(float) * output_elements),
214 NNFW_STATUS_NO_ERROR);
219 void TearDown() override
221 for (auto &obj : _objects)
223 ASSERT_EQ(nnfw_close_session(obj.session), NNFW_STATUS_NO_ERROR);
225 ValidationTest::TearDown();
229 std::array<SessionObject, NUM_SESSIONS> _objects;
232 class ValidationTestTwoSessions : public ValidationTest
235 nnfw_session *_session1 = nullptr;
236 nnfw_session *_session2 = nullptr;
239 class ValidationTestTwoSessionsCreated : public ValidationTestTwoSessions
242 void SetUp() override
244 ValidationTestTwoSessions::SetUp();
245 ASSERT_EQ(nnfw_create_session(&_session1), NNFW_STATUS_NO_ERROR);
246 ASSERT_EQ(nnfw_create_session(&_session2), NNFW_STATUS_NO_ERROR);
247 ASSERT_NE(_session1, nullptr);
248 ASSERT_NE(_session2, nullptr);
251 void TearDown() override
253 ASSERT_EQ(nnfw_close_session(_session1), NNFW_STATUS_NO_ERROR);
254 ASSERT_EQ(nnfw_close_session(_session2), NNFW_STATUS_NO_ERROR);
255 ValidationTestTwoSessions::TearDown();
259 #endif // __NNFW_API_TEST_FIXTURES_H__