Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / tests / tools / onert_train / test / rawdataloader.test.cc
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <nnfw.h>
18
19 #include <gtest/gtest.h>
20 #include <algorithm>
21 #include <numeric>
22
23 #include "../src/rawdataloader.h"
24 #include "../src/nnfw_util.h"
25
26 namespace
27 {
28 using namespace onert_train;
29
30 class DataFileGenerator
31 {
32 public:
33   DataFileGenerator(uint32_t data_length)
34     : _data_length{data_length}, _input_file{"input.bin"}, _expected_file{"expected.bin"}
35   {
36   }
37   ~DataFileGenerator()
38   {
39     try
40     {
41       if (std::remove(_input_file.c_str()) != 0)
42       {
43         std::cerr << "Failed to remove " << _input_file << std::endl;
44       }
45       if (std::remove(_expected_file.c_str()) != 0)
46       {
47         std::cerr << "Failed to remove " << _expected_file << std::endl;
48       }
49     }
50     catch (const std::exception &e)
51     {
52       std::cerr << "Exception: " << e.what() << std::endl;
53     }
54   }
55
56   template <typename T>
57   const std::string &generateInputData(const std::vector<std::vector<T>> &data)
58   {
59     generateData(_input_file, data);
60     return _input_file;
61   }
62
63   template <typename T>
64   const std::string &generateExpectedData(const std::vector<std::vector<T>> &data)
65   {
66     generateData(_expected_file, data);
67     return _expected_file;
68   }
69
70 private:
71   template <typename T>
72   void generateData(const std::string &name, const std::vector<std::vector<T>> &data)
73   {
74     try
75     {
76       std::ofstream file(name, std::ios::binary);
77       for (uint32_t i = 0; i < data.size(); ++i)
78       {
79         for (uint32_t j = 0; j < _data_length; ++j)
80         {
81           for (uint32_t k = 0; k < data[i].size(); ++k)
82           {
83             file.write(reinterpret_cast<const char *>(&data[i][k]), sizeof(data[i][k]));
84           }
85         }
86       }
87     }
88     catch (const std::exception &e)
89     {
90       std::cerr << "Exception: " << e.what() << std::endl;
91     }
92   }
93
94 private:
95   uint32_t _data_length;
96   std::string _input_file;
97   std::string _expected_file;
98 };
99
100 class RawDataLoaderTest : public testing::Test
101 {
102 protected:
103   void SetUp() override { nnfw_create_session(&_session); }
104
105   void TearDown() override { nnfw_close_session(_session); }
106
107   nnfw_session *_session = nullptr;
108 };
109
110 TEST_F(RawDataLoaderTest, loadDatas_1)
111 {
112   const uint32_t data_length = 100;
113   const uint32_t num_input = 1;
114   const uint32_t num_expected = 1;
115   const uint32_t batch_size = 16;
116
117   // Set data tensor info
118   nnfw_tensorinfo in_info = {
119     .dtype = NNFW_TYPE_TENSOR_INT32,
120     .rank = 4,
121     .dims = {batch_size, 2, 2, 2},
122   };
123   std::vector<nnfw_tensorinfo> in_infos{in_info};
124
125   nnfw_tensorinfo expected_info = {
126     .dtype = NNFW_TYPE_TENSOR_INT32,
127     .rank = 4,
128     .dims = {batch_size, 1, 1, 1},
129   };
130   std::vector<nnfw_tensorinfo> expected_infos{expected_info};
131
132   // Generate test data
133   std::vector<std::vector<uint32_t>> in(num_input);
134   for (uint32_t i = 0; i < num_input; ++i)
135   {
136     in[i].resize(num_elems(&in_infos[i]) / batch_size);
137     std::generate(in[i].begin(), in[i].end(), [this] {
138       static uint32_t i = 0;
139       return i++;
140     });
141   }
142
143   std::vector<std::vector<uint32_t>> expected(num_expected);
144   for (uint32_t i = 0; i < num_expected; ++i)
145   {
146     expected[i].resize(num_elems(&expected_infos[i]) / batch_size);
147     std::generate(expected[i].begin(), expected[i].end(), [in, i] {
148       auto sum = std::accumulate(in[i].begin(), in[i].end(), 0);
149       return sum;
150     });
151   }
152
153   // Generate test data file
154   DataFileGenerator file_gen(data_length);
155   auto &input_file = file_gen.generateInputData<uint32_t>(in);
156   auto &expected_file = file_gen.generateExpectedData<uint32_t>(expected);
157
158   // Set expected datas
159   std::vector<std::vector<uint32_t>> expected_in(num_input);
160   std::vector<std::vector<uint32_t>> expected_ex(num_expected);
161   for (uint32_t i = 0; i < num_input; ++i)
162   {
163     for (uint32_t j = 0; j < batch_size; ++j)
164     {
165       expected_in[i].insert(expected_in[i].end(), in[i].begin(), in[i].end());
166     }
167   }
168   for (uint32_t i = 0; i < num_expected; ++i)
169   {
170     for (uint32_t j = 0; j < batch_size; ++j)
171     {
172       expected_ex[i].insert(expected_ex[i].end(), expected[i].begin(), expected[i].end());
173     }
174   }
175
176   // Load test datas
177   RawDataLoader loader;
178   Generator generator =
179     loader.loadData(input_file, expected_file, in_infos, expected_infos, data_length, batch_size);
180
181   // Allocate inputs and expecteds data memory
182   std::vector<Allocation> inputs(num_input);
183   for (uint32_t i = 0; i < num_input; ++i)
184   {
185     inputs[i].alloc(bufsize_for(&in_infos[i]));
186   }
187   std::vector<Allocation> expecteds(num_expected);
188   for (uint32_t i = 0; i < num_expected; ++i)
189   {
190     expecteds[i].alloc(bufsize_for(&expected_infos[i]));
191   }
192
193   uint32_t num_sample = data_length / batch_size;
194   for (uint32_t i = 0; i < num_sample; ++i)
195   {
196     auto data = generator(i, inputs, expecteds);
197
198     std::vector<std::vector<uint32_t>> gen_in(num_input);
199     for (uint32_t h = 0; h < num_input; ++h)
200     {
201       auto num_elem = num_elems(&in_infos[h]);
202       for (uint32_t k = 0; k < num_elem; ++k)
203       {
204         auto inbufs = reinterpret_cast<uint32_t *>(inputs[h].data()) + k;
205         gen_in[h].emplace_back(*inbufs);
206       }
207     }
208     std::vector<std::vector<uint32_t>> gen_ex(num_expected);
209     for (uint32_t h = 0; h < num_expected; ++h)
210     {
211       auto num_elem = num_elems(&expected_infos[h]);
212       for (uint32_t k = 0; k < num_elem; ++k)
213       {
214         auto exbufs = reinterpret_cast<uint32_t *>(expecteds[h].data()) + k;
215         gen_ex[h].emplace_back(*exbufs);
216       }
217     }
218
219     EXPECT_EQ(gen_in, expected_in);
220     EXPECT_EQ(gen_ex, expected_ex);
221   }
222 }
223
224 } // namespace