/*
 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #include "rawdataloader.h"
18 #include "nnfw_util.h"
27 Generator RawDataLoader::loadData(const std::string &input_file, const std::string &expected_file,
28 const std::vector<nnfw_tensorinfo> &input_infos,
29 const std::vector<nnfw_tensorinfo> &expected_infos,
30 const uint32_t data_length, const uint32_t batch_size)
32 std::vector<uint32_t> input_origins(input_infos.size());
34 for (uint32_t i = 0; i < input_infos.size(); ++i)
36 input_origins.at(i) = start;
37 start += (bufsize_for(&input_infos[i]) / batch_size * data_length);
40 std::vector<uint32_t> expected_origins(expected_infos.size());
42 for (uint32_t i = 0; i < expected_infos.size(); ++i)
44 expected_origins.at(i) = start;
45 start += (bufsize_for(&expected_infos[i]) / batch_size * data_length);
50 _input_file = std::ifstream(input_file, std::ios::ate | std::ios::binary);
51 _expected_file = std::ifstream(expected_file, std::ios::ate | std::ios::binary);
53 catch (const std::exception &e)
55 std::cerr << e.what() << std::endl;
59 return [input_origins, expected_origins, &input_infos, &expected_infos,
60 this](uint32_t idx, std::vector<Allocation> &inputs, std::vector<Allocation> &expecteds) {
61 for (uint32_t i = 0; i < input_infos.size(); ++i)
63 auto bufsz = bufsize_for(&input_infos[i]);
64 _input_file.seekg(input_origins[i] + idx * bufsz, std::ios::beg);
65 _input_file.read(reinterpret_cast<char *>(inputs[i].data()), bufsz);
67 for (uint32_t i = 0; i < expected_infos.size(); ++i)
69 auto bufsz = bufsize_for(&expected_infos[i]);
70 _expected_file.seekg(expected_origins[i] + idx * bufsz, std::ios::beg);
71 _expected_file.read(reinterpret_cast<char *>(expecteds[i].data()), bufsz);
77 } // namespace onert_train