2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
#include <gtest/gtest.h>

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <exception>
#include <fstream>
#include <iostream>
#include <numeric>
#include <string>
#include <vector>

#include "../src/rawdataloader.h"
#include "../src/nnfw_util.h"
28 using namespace onert_train;
/**
 * @brief Writes raw binary data files consumed by RawDataLoader tests and removes them
 *        again when the generator is destroyed.
 *
 * Each generated file contains, for every tensor in the given data (in order), that
 * tensor's elements repeated @c data_length times back to back.
 *
 * The file names are fixed ("input.bin" and "expected.bin", opened relative to the current
 * working directory), so instances are non-copyable: a copy would remove the same files twice.
 */
class DataFileGenerator
{
public:
  /**
   * @param data_length Number of times each tensor's elements are repeated in a file.
   */
  explicit DataFileGenerator(uint32_t data_length)
    : _data_length{data_length}, _input_file{"input.bin"}, _expected_file{"expected.bin"}
  {
  }

  DataFileGenerator(const DataFileGenerator &) = delete;
  DataFileGenerator &operator=(const DataFileGenerator &) = delete;

  /**
   * @brief Removes both data files; failures are only reported on std::cerr.
   */
  ~DataFileGenerator()
  {
    try
    {
      if (std::remove(_input_file.c_str()) != 0)
      {
        std::cerr << "Failed to remove " << _input_file << std::endl;
      }
      if (std::remove(_expected_file.c_str()) != 0)
      {
        std::cerr << "Failed to remove " << _expected_file << std::endl;
      }
    }
    catch (const std::exception &e)
    {
      // A destructor must not let an exception escape; report it and carry on.
      std::cerr << "Exception: " << e.what() << std::endl;
    }
  }

  /**
   * @brief Writes @p data to the input file.
   * @return The input file name; the reference stays valid for this generator's lifetime.
   */
  template <typename T>
  const std::string &generateInputData(const std::vector<std::vector<T>> &data)
  {
    generateData(_input_file, data);
    return _input_file;
  }

  /**
   * @brief Writes @p data to the expected-output file.
   * @return The expected file name; the reference stays valid for this generator's lifetime.
   */
  template <typename T>
  const std::string &generateExpectedData(const std::vector<std::vector<T>> &data)
  {
    generateData(_expected_file, data);
    return _expected_file;
  }

private:
  /**
   * @brief Writes every tensor of @p data to @p name, each repeated _data_length times.
   *
   * Errors (file cannot be opened, a write fails, an exception is thrown) are reported
   * on std::cerr rather than propagated, so the caller's test fails on the data check.
   */
  template <typename T>
  void generateData(const std::string &name, const std::vector<std::vector<T>> &data)
  {
    try
    {
      std::ofstream file(name, std::ios::binary);
      if (!file)
      {
        std::cerr << "Failed to open " << name << std::endl;
        return;
      }
      for (const auto &tensor : data)
      {
        for (uint32_t j = 0; j < _data_length; ++j)
        {
          for (const auto &value : tensor)
          {
            file.write(reinterpret_cast<const char *>(&value), sizeof(value));
          }
        }
      }
      // std::ofstream does not throw by default, so check the stream state explicitly.
      if (!file)
      {
        std::cerr << "Failed to write " << name << std::endl;
      }
    }
    catch (const std::exception &e)
    {
      std::cerr << "Exception: " << e.what() << std::endl;
    }
  }

private:
  uint32_t _data_length;
  std::string _input_file;
  std::string _expected_file;
};
100 class RawDataLoaderTest : public testing::Test
103 void SetUp() override { nnfw_create_session(&_session); }
105 void TearDown() override { nnfw_close_session(_session); }
107 nnfw_session *_session = nullptr;
110 TEST_F(RawDataLoaderTest, loadDatas_1)
112 const uint32_t data_length = 100;
113 const uint32_t num_input = 1;
114 const uint32_t num_expected = 1;
115 const uint32_t batch_size = 16;
117 // Set data tensor info
118 nnfw_tensorinfo in_info = {
119 .dtype = NNFW_TYPE_TENSOR_INT32,
121 .dims = {batch_size, 2, 2, 2},
123 std::vector<nnfw_tensorinfo> in_infos{in_info};
125 nnfw_tensorinfo expected_info = {
126 .dtype = NNFW_TYPE_TENSOR_INT32,
128 .dims = {batch_size, 1, 1, 1},
130 std::vector<nnfw_tensorinfo> expected_infos{expected_info};
132 // Generate test data
133 std::vector<std::vector<uint32_t>> in(num_input);
134 for (uint32_t i = 0; i < num_input; ++i)
136 in[i].resize(num_elems(&in_infos[i]) / batch_size);
137 std::generate(in[i].begin(), in[i].end(), [this] {
138 static uint32_t i = 0;
143 std::vector<std::vector<uint32_t>> expected(num_expected);
144 for (uint32_t i = 0; i < num_expected; ++i)
146 expected[i].resize(num_elems(&expected_infos[i]) / batch_size);
147 std::generate(expected[i].begin(), expected[i].end(), [in, i] {
148 auto sum = std::accumulate(in[i].begin(), in[i].end(), 0);
153 // Generate test data file
154 DataFileGenerator file_gen(data_length);
155 auto &input_file = file_gen.generateInputData<uint32_t>(in);
156 auto &expected_file = file_gen.generateExpectedData<uint32_t>(expected);
158 // Set expected datas
159 std::vector<std::vector<uint32_t>> expected_in(num_input);
160 std::vector<std::vector<uint32_t>> expected_ex(num_expected);
161 for (uint32_t i = 0; i < num_input; ++i)
163 for (uint32_t j = 0; j < batch_size; ++j)
165 expected_in[i].insert(expected_in[i].end(), in[i].begin(), in[i].end());
168 for (uint32_t i = 0; i < num_expected; ++i)
170 for (uint32_t j = 0; j < batch_size; ++j)
172 expected_ex[i].insert(expected_ex[i].end(), expected[i].begin(), expected[i].end());
177 RawDataLoader loader;
178 Generator generator =
179 loader.loadData(input_file, expected_file, in_infos, expected_infos, data_length, batch_size);
181 // Allocate inputs and expecteds data memory
182 std::vector<Allocation> inputs(num_input);
183 for (uint32_t i = 0; i < num_input; ++i)
185 inputs[i].alloc(bufsize_for(&in_infos[i]));
187 std::vector<Allocation> expecteds(num_expected);
188 for (uint32_t i = 0; i < num_expected; ++i)
190 expecteds[i].alloc(bufsize_for(&expected_infos[i]));
193 uint32_t num_sample = data_length / batch_size;
194 for (uint32_t i = 0; i < num_sample; ++i)
196 auto data = generator(i, inputs, expecteds);
198 std::vector<std::vector<uint32_t>> gen_in(num_input);
199 for (uint32_t h = 0; h < num_input; ++h)
201 auto num_elem = num_elems(&in_infos[h]);
202 for (uint32_t k = 0; k < num_elem; ++k)
204 auto inbufs = reinterpret_cast<uint32_t *>(inputs[h].data()) + k;
205 gen_in[h].emplace_back(*inbufs);
208 std::vector<std::vector<uint32_t>> gen_ex(num_expected);
209 for (uint32_t h = 0; h < num_expected; ++h)
211 auto num_elem = num_elems(&expected_infos[h]);
212 for (uint32_t k = 0; k < num_elem; ++k)
214 auto exbufs = reinterpret_cast<uint32_t *>(expecteds[h].data()) + k;
215 gen_ex[h].emplace_back(*exbufs);
219 EXPECT_EQ(gen_in, expected_in);
220 EXPECT_EQ(gen_ex, expected_ex);