2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include <gtest/gtest.h>
18 #include <nnfw_internal.h>
23 #include "CircleGen.h"
29 * @brief A vector of input buffers, one per model input in input-index order
31 * @todo support other types as well as float
33 std::vector<std::vector<float>> inputs;
35 * @brief A vector of expected (reference) output buffers, one per model output in output-index order
37 * @todo support other types as well as float
39 std::vector<std::vector<float>> outputs;
// Test body for one generated-model test: the circle model buffer, the input/output
// test cases, the backends to run on (the constructor defaults this to "cpu"), and
// whether nnfw_prepare is expected to fail for the model.
42 class GenModelTestContext
45 GenModelTestContext(CircleBuffer &&cbuf) : _cbuf{std::move(cbuf)}, _backends{"cpu"} {}
48 * @brief Return circle buffer
50 * @return const CircleBuffer& the circle buffer
52 const CircleBuffer &cbuf() const { return _cbuf; }
55 * @brief Return test cases
57 * @return const std::vector<TestCaseData>& the test cases
59 const std::vector<TestCaseData> &test_cases() const { return _test_cases; }
62 * @brief Return backends
64 * @return const std::vector<std::string>& the backends to be tested
66 const std::vector<std::string> &backends() const { return _backends; }
69 * @brief Return whether the test is expected to fail on compile
71 * @return bool true if nnfw_prepare is expected to return an error for this model
73 const bool fail_compile() const { return _fail_compile; }
// NOTE(review): top-level const on a by-value return type has no effect; plain `bool` suffices.
76 * @brief Add a test case
78 * @param tc the test case to be added
80 void addTestCase(const TestCaseData &tc) { _test_cases.emplace_back(tc); }
83 * @brief Set the backends to be tested
85 * @param backends names of the backends to test; "acl_cl"/"acl_neon" are kept only when TEST_ACL_BACKEND is defined
87 void setBackends(const std::vector<std::string> &backends)
91 for (auto backend : backends)
// NOTE(review): `auto` copies each std::string; `const auto &` would avoid the copy.
93 #ifdef TEST_ACL_BACKEND
94 if (backend == "acl_cl" || backend == "acl_neon")
96 _backends.push_back(backend);
101 _backends.push_back(backend);
107 * @brief Mark the test as expected to fail on compile (nnfw_prepare must return an error)
109 void setCompileFail() { _fail_compile = true; }
113 std::vector<TestCaseData> _test_cases;
114 std::vector<std::string> _backends;
115 bool _fail_compile{false};
119 * @brief Generated model test fixture for one-time inference
121 * This fixture runs one-time inference tests on a variety of generated models.
122 * It is the test author's responsibility to create @c _context, which holds the
123 * test body: the generated circle buffer, the model input and output data, and the
124 * list of backends to be tested.
125 * The rest (calling the API functions for execution) is done by @c SetUp and @c TearDown.
128 class GenModelTest : public ::testing::Test
131 void SetUp() override
135 void TearDown() override
// For each backend in _context: create a fresh session, load the circle buffer,
// select the backend, prepare, bind input/output buffers, then run every test case
// and compare the outputs against the reference values.
137 for (std::string backend : _context->backends())
// NOTE(review): this copies each backend name; `const std::string &` would avoid the copy.
139 // NOTE If we can prepare many times for one model loading on same session,
140 // we can move nnfw_create_session to SetUp and
141 // nnfw_load_circle_from_buffer to outside forloop
142 NNFW_ENSURE_SUCCESS(nnfw_create_session(&_so.session));
143 auto &cbuf = _context->cbuf();
144 NNFW_ENSURE_SUCCESS(nnfw_load_circle_from_buffer(_so.session, cbuf.buffer(), cbuf.size()));
145 NNFW_ENSURE_SUCCESS(nnfw_set_available_backends(_so.session, backend.data()));
// Negative test: compilation of this model is expected to fail on this backend.
147 if (_context->fail_compile())
// NOTE(review): an ASSERT_* failure returns from TearDown immediately, so the session
// would not be closed on that path.
149 ASSERT_EQ(nnfw_prepare(_so.session), NNFW_STATUS_ERROR);
151 NNFW_ENSURE_SUCCESS(nnfw_close_session(_so.session));
154 NNFW_ENSURE_SUCCESS(nnfw_prepare(_so.session));
156 // In/Out buffer settings
158 NNFW_ENSURE_SUCCESS(nnfw_input_size(_so.session, &num_inputs));
159 _so.inputs.resize(num_inputs);
160 for (uint32_t ind = 0; ind < _so.inputs.size(); ind++)
163 NNFW_ENSURE_SUCCESS(nnfw_input_tensorinfo(_so.session, ind, &ti));
164 uint64_t input_elements = num_elems(&ti);
165 _so.inputs[ind].resize(input_elements);
// NOTE(review): the byte size assumes float elements regardless of ti.dtype, so it is
// only correct for float32 tensors (buffers are stored as std::vector<float>).
167 ASSERT_EQ(nnfw_set_input(_so.session, ind, ti.dtype, _so.inputs[ind].data(),
168 sizeof(float) * input_elements),
169 NNFW_STATUS_NO_ERROR);
172 uint32_t num_outputs;
173 NNFW_ENSURE_SUCCESS(nnfw_output_size(_so.session, &num_outputs));
174 _so.outputs.resize(num_outputs);
175 for (uint32_t ind = 0; ind < _so.outputs.size(); ind++)
178 NNFW_ENSURE_SUCCESS(nnfw_output_tensorinfo(_so.session, ind, &ti));
179 uint64_t output_elements = num_elems(&ti);
180 _so.outputs[ind].resize(output_elements);
// NOTE(review): same float-only sizing assumption as for the inputs above.
181 ASSERT_EQ(nnfw_set_output(_so.session, ind, ti.dtype, _so.outputs[ind].data(),
182 sizeof(float) * output_elements),
183 NNFW_STATUS_NO_ERROR);
186 // Set input values, run, and check output values
187 for (auto &test_case : _context->test_cases())
189 auto &ref_inputs = test_case.inputs;
190 auto &ref_outputs = test_case.outputs;
191 ASSERT_EQ(_so.inputs.size(), ref_inputs.size());
192 for (uint32_t i = 0; i < _so.inputs.size(); i++)
// Copy into the already-bound vectors (no resize/reassign) so the pointers passed to
// nnfw_set_input above stay valid across test cases.
195 ASSERT_EQ(_so.inputs[i].size(), ref_inputs[i].size());
196 memcpy(_so.inputs[i].data(), ref_inputs[i].data(), _so.inputs[i].size() * sizeof(float));
199 NNFW_ENSURE_SUCCESS(nnfw_run(_so.session));
201 ASSERT_EQ(_so.outputs.size(), ref_outputs.size());
202 for (uint32_t i = 0; i < _so.outputs.size(); i++)
204 // Check output tensor values
205 auto &ref_output = ref_outputs[i];
206 auto &output = _so.outputs[i];
207 ASSERT_EQ(output.size(), ref_output.size());
// Element-wise comparison with an absolute tolerance of 1e-3.
208 for (uint32_t e = 0; e < ref_output.size(); e++)
209 EXPECT_NEAR(ref_output[e], output[e], 0.001); // TODO better way for handling FP error?
213 NNFW_ENSURE_SUCCESS(nnfw_close_session(_so.session));
// Test body; each test must set this before TearDown runs, since TearDown dereferences
// it without a null check.
219 std::unique_ptr<GenModelTestContext> _context;