Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / tests / tools / onert_train / src / rawdataloader.cc
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
#include "rawdataloader.h"
#include "nnfw_util.h"

#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <numeric>
#include <stdexcept>
23
24 namespace onert_train
25 {
26
27 Generator RawDataLoader::loadData(const std::string &input_file, const std::string &expected_file,
28                                   const std::vector<nnfw_tensorinfo> &input_infos,
29                                   const std::vector<nnfw_tensorinfo> &expected_infos,
30                                   const uint32_t data_length, const uint32_t batch_size)
31 {
32   std::vector<uint32_t> input_origins(input_infos.size());
33   uint32_t start = 0;
34   for (uint32_t i = 0; i < input_infos.size(); ++i)
35   {
36     input_origins.at(i) = start;
37     start += (bufsize_for(&input_infos[i]) / batch_size * data_length);
38   }
39
40   std::vector<uint32_t> expected_origins(expected_infos.size());
41   start = 0;
42   for (uint32_t i = 0; i < expected_infos.size(); ++i)
43   {
44     expected_origins.at(i) = start;
45     start += (bufsize_for(&expected_infos[i]) / batch_size * data_length);
46   }
47
48   try
49   {
50     _input_file = std::ifstream(input_file, std::ios::ate | std::ios::binary);
51     _expected_file = std::ifstream(expected_file, std::ios::ate | std::ios::binary);
52   }
53   catch (const std::exception &e)
54   {
55     std::cerr << e.what() << std::endl;
56     std::exit(-1);
57   }
58
59   return [input_origins, expected_origins, &input_infos, &expected_infos,
60           this](uint32_t idx, std::vector<Allocation> &inputs, std::vector<Allocation> &expecteds) {
61     for (uint32_t i = 0; i < input_infos.size(); ++i)
62     {
63       auto bufsz = bufsize_for(&input_infos[i]);
64       _input_file.seekg(input_origins[i] + idx * bufsz, std::ios::beg);
65       _input_file.read(reinterpret_cast<char *>(inputs[i].data()), bufsz);
66     }
67     for (uint32_t i = 0; i < expected_infos.size(); ++i)
68     {
69       auto bufsz = bufsize_for(&expected_infos[i]);
70       _expected_file.seekg(expected_origins[i] + idx * bufsz, std::ios::beg);
71       _expected_file.read(reinterpret_cast<char *>(expecteds[i].data()), bufsz);
72     }
73     return true;
74   };
75 }
76
77 } // namespace onert_train