2 * Copyright (C) 2019 Samsung Electronics Co., Ltd. All Rights Reserved.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 * http://www.apache.org/licenses/LICENSE-2.0
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
15 * @file nntrainer_test_util.cpp
17 * @brief Utility functions for tests
18 * @see https://github.com/nnstreamer/nntrainer
19 * @author Jijoong Moon <jijoong.moon@samsung.com>
20 * @bug No known bugs except for NYI items
24 #include "nntrainer_test_util.h"
25 #include <app_context.h>
28 #include <layer_node.h>
29 #include <multiout_realizer.h>
30 #include <nntrainer_error.h>
37 #define feature_size 62720
39 static std::mt19937 rng(0);
42 * @brief replace every occurrence of a string in a file and save the result
43 * @param[in] from string to be replaced
44 * @param[in] to string to replace it with
45 * @param[in] file path of the file to perform the action on
46 * @param[in] init_config contents to initialize the file with if it is not found
48 void replaceString(const std::string &from, const std::string &to,
49 const std::string file, std::string init_config) {
52 std::ifstream file_stream(file.c_str(), std::ifstream::in);
53 if (file_stream.good()) {
54 s.assign((std::istreambuf_iterator<char>(file_stream)),
55 std::istreambuf_iterator<char>());
60 while ((start_pos = s.find(from, start_pos)) != std::string::npos) {
61 s.replace(start_pos, from.size(), to);
62 start_pos += to.size();
65 std::ofstream data_file(file.c_str());
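71 * @brief load one sample (feature vector and label) at a specific position of the file
72 * @param[in] F ifstream (input file)
74 * @param[out] outLabel buffer that receives num_class label values
75 * @param[in] id index of the sample to get
76 * @retval false at the end of data, true otherwise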
78 static bool getData(std::ifstream &F, float *outVec, float *outLabel,
81 F.seekg(0, std::ios_base::end);
82 uint64_t file_length = F.tellg();
85 (uint64_t)((feature_size + num_class) * (uint64_t)id * sizeof(float));
87 if (position > file_length) {
90 F.seekg(position, std::ios::beg);
91 F.read((char *)outVec, sizeof(float) * feature_size);
92 F.read((char *)outLabel, sizeof(float) * num_class);
97 DataInformation::DataInformation(unsigned int num_samples,
98 const std::string &filename) :
100 num_samples(num_samples),
101 file(filename, std::ios::in | std::ios::binary),
103 std::iota(idxes.begin(), idxes.end(), 0);
104 std::shuffle(idxes.begin(), idxes.end(), rng);
107 throw std::invalid_argument("given file is not good, filename: " +
112 static auto getDataSize = [](const std::string &file_name) {
113 std::ifstream f(file_name, std::ios::in | std::ios::binary);
114 NNTR_THROW_IF(!f.good(), std::invalid_argument)
115 << "cannot find " << file_name;
116 f.seekg(0, std::ios::end);
117 long file_size = f.tellg();
118 return static_cast<unsigned int>(
119 file_size / ((num_class + feature_size) * sizeof(float)));
122 std::string train_filename = getResPath("trainingSet.dat", {"test"});
123 std::string valid_filename = getResPath("trainingSet.dat", {"test"});
125 DataInformation createTrainData() {
126 return DataInformation(getDataSize(train_filename), train_filename);
129 DataInformation createValidData() {
130 return DataInformation(getDataSize(valid_filename), valid_filename);
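134 * @brief get a single sample for training
136 * @param[out] outLabel buffer that receives the label of the sample
137 * @param[out] last whether the end of the data has been reached
138 * @param[in] user_data pointer to the DataInformation used by the callback
139 * @retval status code for error handling (ML_ERROR_NONE on success)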
141 int getSample(float **outVec, float **outLabel, bool *last, void *user_data) {
142 auto data = reinterpret_cast<DataInformation *>(user_data);
144 getData(data->file, *outVec, *outLabel, data->idxes.at(data->count));
146 if (data->count < data->num_samples) {
151 std::shuffle(data->idxes.begin(), data->idxes.end(), data->rng);
154 return ML_ERROR_NONE;
158 * @brief return a tensor of the given dimension filled with a constant value
160 nntrainer::Tensor constant(float value, unsigned int batch,
161 unsigned int channel, unsigned int height,
162 unsigned int width) {
163 nntrainer::Tensor t(batch, channel, height, width);
169 nntrainer::Tensor ranged(unsigned int batch, unsigned int channel,
170 unsigned int height, unsigned int width,
171 nntrainer::Tformat fm, nntrainer::DataType d_type) {
172 nntrainer::Tensor t(batch, channel, height, width, fm, d_type);
174 if (d_type == nntrainer::DataType::FP32)
175 return t.apply([&](float in) { return i++; });
176 else if (d_type == nntrainer::DataType::FP16)
177 return t.apply([&](__fp16 in) { return i++; });
180 nntrainer::Tensor randUniform(unsigned int batch, unsigned int channel,
181 unsigned int height, unsigned int width,
182 float min, float max) {
183 nntrainer::Tensor t(batch, channel, height, width);
184 t.setRandUniform(min, max);
189 getResPath(const std::string &filename,
190 const std::initializer_list<const char *> fallback_base) {
191 static const char *prefix = std::getenv("NNTRAINER_RESOURCE_PATH");
192 static const char *fallback_prefix = "./res";
194 std::stringstream ss;
195 if (prefix != nullptr) {
196 ss << prefix << '/' << filename;
200 ss << fallback_prefix;
201 for (auto &folder : fallback_base) {
205 ss << '/' << filename;
210 nntrainer::GraphRepresentation
211 makeGraph(const std::vector<LayerRepresentation> &layer_reps) {
212 static auto &ac = nntrainer::AppContext::Global();
213 nntrainer::GraphRepresentation graph_rep;
215 for (const auto &layer_representation : layer_reps) {
216 /// @todo Use unique_ptr here
217 std::shared_ptr<nntrainer::LayerNode> layer = nntrainer::createLayerNode(
218 ac.createObject<nntrainer::Layer>(layer_representation.first),
219 layer_representation.second);
220 graph_rep.push_back(layer);
226 nntrainer::GraphRepresentation makeCompiledGraph(
227 const std::vector<LayerRepresentation> &layer_reps,
228 std::vector<std::unique_ptr<nntrainer::GraphRealizer>> &realizers,
229 const std::string &loss_layer) {
230 static auto &ac = nntrainer::AppContext::Global();
232 nntrainer::GraphRepresentation graph_rep;
233 auto model_graph = nntrainer::NetworkGraph();
235 for (auto &layer_representation : layer_reps) {
236 std::shared_ptr<nntrainer::LayerNode> layer = nntrainer::createLayerNode(
237 ac.createObject<nntrainer::Layer>(layer_representation.first),
238 layer_representation.second);
239 graph_rep.push_back(layer);
242 for (auto &realizer : realizers) {
243 graph_rep = realizer->realize(graph_rep);
246 for (auto &layer : graph_rep) {
247 model_graph.addLayer(layer);
250 model_graph.compile(loss_layer);
253 for (auto &node : model_graph.getLayerNodes()) {
254 graph_rep.push_back(node);
260 void sizeCheckedReadTensor(nntrainer::Tensor &t, std::ifstream &file,
261 const std::string &error_msg) {
263 nntrainer::checkedRead(file, (char *)&sz, sizeof(unsigned));
264 NNTR_THROW_IF(t.getDim().getDataLen() != sz, std::invalid_argument)
265 << "[ReadFail] dimension does not match at " << error_msg << " sz: " << sz
266 << " dimsize: " << t.getDim().getDataLen() << '\n';