2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #ifndef __NNFW_API_TEST_FIXTURES_H__
18 #define __NNFW_API_TEST_FIXTURES_H__
21 #include <gtest/gtest.h>
24 #include "NNPackages.h"
// Fatal gtest check that EXPR evaluates to NNFW_STATUS_NO_ERROR; on mismatch the
// enclosing (void) function - test body, SetUp(), ... - returns immediately.
#define NNFW_ENSURE_SUCCESS(EXPR) \
  ASSERT_EQ((EXPR), NNFW_STATUS_NO_ERROR)
28 inline uint64_t num_elems(const nnfw_tensorinfo *ti)
31 for (uint32_t i = 0; i < ti->rank; ++i)
40 nnfw_session *session = nullptr;
41 std::vector<std::vector<float>> inputs;
42 std::vector<std::vector<float>> outputs;
45 class ValidationTest : public ::testing::Test
48 void SetUp() override {}
51 class RegressionTest : public ::testing::Test
54 void SetUp() override {}
57 class ValidationTestSingleSession : public ValidationTest
60 nnfw_session *_session = nullptr;
63 class ValidationTestSessionCreated : public ValidationTestSingleSession
68 ValidationTestSingleSession::SetUp();
69 ASSERT_EQ(nnfw_create_session(&_session), NNFW_STATUS_NO_ERROR);
72 void TearDown() override
74 ASSERT_EQ(nnfw_close_session(_session), NNFW_STATUS_NO_ERROR);
75 ValidationTestSingleSession::TearDown();
79 template <int PackageNo> class ValidationTestModelLoaded : public ValidationTestSessionCreated
84 ValidationTestSessionCreated::SetUp();
85 ASSERT_EQ(nnfw_load_model_from_file(_session,
86 NNPackages::get().getModelAbsolutePath(PackageNo).c_str()),
87 NNFW_STATUS_NO_ERROR);
88 ASSERT_NE(_session, nullptr);
91 void TearDown() override { ValidationTestSessionCreated::TearDown(); }
94 template <int PackageNo>
95 class ValidationTestSessionPrepared : public ValidationTestModelLoaded<PackageNo>
98 using ValidationTestSingleSession::_session;
100 void SetUp() override
102 ValidationTestModelLoaded<PackageNo>::SetUp();
103 nnfw_prepare(_session);
106 void TearDown() override { ValidationTestModelLoaded<PackageNo>::TearDown(); }
108 void SetInOutBuffers()
110 nnfw_tensorinfo ti_input;
111 ASSERT_EQ(nnfw_input_tensorinfo(_session, 0, &ti_input), NNFW_STATUS_NO_ERROR);
112 uint64_t input_elements = num_elems(&ti_input);
113 EXPECT_EQ(input_elements, 1);
114 _input.resize(input_elements);
116 nnfw_set_input(_session, 0, ti_input.dtype, _input.data(), sizeof(float) * input_elements),
117 NNFW_STATUS_NO_ERROR);
119 nnfw_tensorinfo ti_output;
120 ASSERT_EQ(nnfw_output_tensorinfo(_session, 0, &ti_output), NNFW_STATUS_NO_ERROR);
121 uint64_t output_elements = num_elems(&ti_output);
122 EXPECT_EQ(output_elements, 1);
123 _output.resize(output_elements);
124 ASSERT_EQ(nnfw_set_output(_session, 0, ti_output.dtype, _output.data(),
125 sizeof(float) * output_elements),
126 NNFW_STATUS_NO_ERROR);
130 std::vector<float> _input;
131 std::vector<float> _output;
134 template <int PackageNo> class ValidationTestFourModelsSetInput : public ValidationTest
137 static const uint32_t NUM_SESSIONS = 4;
139 void SetUp() override
141 ValidationTest::SetUp();
143 auto model_path = NNPackages::get().getModelAbsolutePath(NNPackages::ADD);
144 for (auto &obj : _objects)
146 ASSERT_EQ(nnfw_create_session(&obj.session), NNFW_STATUS_NO_ERROR);
147 ASSERT_EQ(nnfw_load_model_from_file(obj.session, model_path.c_str()), NNFW_STATUS_NO_ERROR);
148 ASSERT_EQ(nnfw_prepare(obj.session), NNFW_STATUS_NO_ERROR);
151 ASSERT_EQ(nnfw_input_size(obj.session, &num_inputs), NNFW_STATUS_NO_ERROR);
152 obj.inputs.resize(num_inputs);
153 for (uint32_t ind = 0; ind < obj.inputs.size(); ind++)
156 ASSERT_EQ(nnfw_input_tensorinfo(obj.session, ind, &ti), NNFW_STATUS_NO_ERROR);
157 uint64_t input_elements = num_elems(&ti);
158 obj.inputs[ind].resize(input_elements);
159 ASSERT_EQ(nnfw_set_input(obj.session, ind, ti.dtype, obj.inputs[ind].data(),
160 sizeof(float) * input_elements),
161 NNFW_STATUS_NO_ERROR);
164 uint32_t num_outputs;
165 ASSERT_EQ(nnfw_output_size(obj.session, &num_outputs), NNFW_STATUS_NO_ERROR);
166 obj.outputs.resize(num_outputs);
167 for (uint32_t ind = 0; ind < obj.outputs.size(); ind++)
170 ASSERT_EQ(nnfw_output_tensorinfo(obj.session, ind, &ti), NNFW_STATUS_NO_ERROR);
171 uint64_t output_elements = num_elems(&ti);
172 obj.outputs[ind].resize(output_elements);
173 ASSERT_EQ(nnfw_set_output(obj.session, ind, ti.dtype, obj.outputs[ind].data(),
174 sizeof(float) * output_elements),
175 NNFW_STATUS_NO_ERROR);
180 void TearDown() override
182 for (auto &obj : _objects)
184 ASSERT_EQ(nnfw_close_session(obj.session), NNFW_STATUS_NO_ERROR);
186 ValidationTest::TearDown();
190 std::array<SessionObject, NUM_SESSIONS> _objects;
193 #endif // __NNFW_API_TEST_FIXTURES_H__