/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #ifndef __NNFW_API_TEST_FIXTURES_H__
18 #define __NNFW_API_TEST_FIXTURES_H__
#include <array>
#include <cstdint>
#include <vector>

#include <gtest/gtest.h>
#include <nnfw_experimental.h>
#include <nnfw_internal.h>

#include "NNPackages.h"
// Fatally fails the current test unless EXPR evaluates to NNFW_STATUS_NO_ERROR.
// Like ASSERT_EQ, which it expands to, it returns from the enclosing function on failure,
// so it can only be used inside functions that return void.
#define NNFW_ENSURE_SUCCESS(EXPR) ASSERT_EQ((EXPR), NNFW_STATUS_NO_ERROR)
29 inline uint64_t num_elems(const nnfw_tensorinfo *ti)
32 for (uint32_t i = 0; i < ti->rank; ++i)
41 nnfw_session *session = nullptr;
42 std::vector<std::vector<float>> inputs;
43 std::vector<std::vector<float>> outputs;
46 class ValidationTest : public ::testing::Test
49 void SetUp() override {}
52 class RegressionTest : public ::testing::Test
55 void SetUp() override {}
58 class ValidationTestSingleSession : public ValidationTest
61 nnfw_session *_session = nullptr;
64 class ValidationTestSessionCreated : public ValidationTestSingleSession
69 ValidationTestSingleSession::SetUp();
70 ASSERT_EQ(nnfw_create_session(&_session), NNFW_STATUS_NO_ERROR);
73 void TearDown() override
75 ASSERT_EQ(nnfw_close_session(_session), NNFW_STATUS_NO_ERROR);
76 ValidationTestSingleSession::TearDown();
80 template <int PackageNo> class ValidationTestModelLoaded : public ValidationTestSessionCreated
85 ValidationTestSessionCreated::SetUp();
86 ASSERT_EQ(nnfw_load_model_from_file(_session,
87 NNPackages::get().getModelAbsolutePath(PackageNo).c_str()),
88 NNFW_STATUS_NO_ERROR);
89 ASSERT_NE(_session, nullptr);
92 void TearDown() override { ValidationTestSessionCreated::TearDown(); }
95 template <int PackageNo>
96 class ValidationTestSessionPrepared : public ValidationTestModelLoaded<PackageNo>
99 using ValidationTestSingleSession::_session;
101 void SetUp() override
103 ValidationTestModelLoaded<PackageNo>::SetUp();
104 nnfw_prepare(_session);
107 void TearDown() override { ValidationTestModelLoaded<PackageNo>::TearDown(); }
109 void SetInOutBuffers()
111 nnfw_tensorinfo ti_input;
112 ASSERT_EQ(nnfw_input_tensorinfo(_session, 0, &ti_input), NNFW_STATUS_NO_ERROR);
113 uint64_t input_elements = num_elems(&ti_input);
114 EXPECT_EQ(input_elements, 1);
115 _input.resize(input_elements);
117 nnfw_set_input(_session, 0, ti_input.dtype, _input.data(), sizeof(float) * input_elements),
118 NNFW_STATUS_NO_ERROR);
120 nnfw_tensorinfo ti_output;
121 ASSERT_EQ(nnfw_output_tensorinfo(_session, 0, &ti_output), NNFW_STATUS_NO_ERROR);
122 uint64_t output_elements = num_elems(&ti_output);
123 EXPECT_EQ(output_elements, 1);
124 _output.resize(output_elements);
125 ASSERT_EQ(nnfw_set_output(_session, 0, ti_output.dtype, _output.data(),
126 sizeof(float) * output_elements),
127 NNFW_STATUS_NO_ERROR);
130 void SetInOutBuffersDynamic(const nnfw_tensorinfo *ti_input)
132 NNFW_ENSURE_SUCCESS(nnfw_set_input_tensorinfo(_session, 0, ti_input));
133 uint64_t input_elements = num_elems(ti_input);
134 _input.resize(input_elements);
136 nnfw_set_input(_session, 0, ti_input->dtype, _input.data(), sizeof(float) * input_elements),
137 NNFW_STATUS_NO_ERROR);
139 _output.resize(40000); // Give sufficient size for the output
140 ASSERT_EQ(nnfw_set_output(_session, 0, ti_input->dtype, _output.data(),
141 sizeof(float) * _output.size()),
142 NNFW_STATUS_NO_ERROR);
146 std::vector<float> _input;
147 std::vector<float> _output;
150 template <int PackageNo> class ValidationTestFourModelsSetInput : public ValidationTest
153 static const uint32_t NUM_SESSIONS = 4;
155 void SetUp() override
157 ValidationTest::SetUp();
159 auto model_path = NNPackages::get().getModelAbsolutePath(NNPackages::ADD);
160 for (auto &obj : _objects)
162 ASSERT_EQ(nnfw_create_session(&obj.session), NNFW_STATUS_NO_ERROR);
163 ASSERT_EQ(nnfw_load_model_from_file(obj.session, model_path.c_str()), NNFW_STATUS_NO_ERROR);
164 ASSERT_EQ(nnfw_prepare(obj.session), NNFW_STATUS_NO_ERROR);
167 ASSERT_EQ(nnfw_input_size(obj.session, &num_inputs), NNFW_STATUS_NO_ERROR);
168 obj.inputs.resize(num_inputs);
169 for (uint32_t ind = 0; ind < obj.inputs.size(); ind++)
172 ASSERT_EQ(nnfw_input_tensorinfo(obj.session, ind, &ti), NNFW_STATUS_NO_ERROR);
173 uint64_t input_elements = num_elems(&ti);
174 obj.inputs[ind].resize(input_elements);
175 ASSERT_EQ(nnfw_set_input(obj.session, ind, ti.dtype, obj.inputs[ind].data(),
176 sizeof(float) * input_elements),
177 NNFW_STATUS_NO_ERROR);
180 uint32_t num_outputs;
181 ASSERT_EQ(nnfw_output_size(obj.session, &num_outputs), NNFW_STATUS_NO_ERROR);
182 obj.outputs.resize(num_outputs);
183 for (uint32_t ind = 0; ind < obj.outputs.size(); ind++)
186 ASSERT_EQ(nnfw_output_tensorinfo(obj.session, ind, &ti), NNFW_STATUS_NO_ERROR);
187 uint64_t output_elements = num_elems(&ti);
188 obj.outputs[ind].resize(output_elements);
189 ASSERT_EQ(nnfw_set_output(obj.session, ind, ti.dtype, obj.outputs[ind].data(),
190 sizeof(float) * output_elements),
191 NNFW_STATUS_NO_ERROR);
196 void TearDown() override
198 for (auto &obj : _objects)
200 ASSERT_EQ(nnfw_close_session(obj.session), NNFW_STATUS_NO_ERROR);
202 ValidationTest::TearDown();
206 std::array<SessionObject, NUM_SESSIONS> _objects;
209 #endif // __NNFW_API_TEST_FIXTURES_H__