2 * Copyright (C) 2019 Samsung Electronics Co., Ltd. All Rights Reserved.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 * http://www.apache.org/licenses/LICENSE-2.0
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
15 * @file nntrainer_test_util.h
* @brief Utility functions for tests
18 * @see https://github.com/nnstreamer/nntrainer
19 * @author Jijoong Moon <jijoong.moon@samsung.com>
20 * @bug No known bugs except for NYI items
24 #ifndef __NNTRAINER_TEST_UTIL_H__
25 #define __NNTRAINER_TEST_UTIL_H__
31 #include <unordered_map>
34 #include <compiler_fwd.h>
35 #include <ini_wrapper.h>
36 #include <neuralnet.h>
37 #include <nntrainer_error.h>
38 #include <nntrainer_log.h>
42 /** tolerance is reduced for packaging, but CI runs at full tolerance */
43 #ifdef REDUCE_TOLERANCE
44 #define tolerance 1.0e-4
46 #define tolerance 1.0e-5
49 /** Enum values to get model accuracy and loss. Sync with internal CAPI header
#define ML_TRAIN_SUMMARY_MODEL_TRAIN_LOSS 101     // summary key: training loss
#define ML_TRAIN_SUMMARY_MODEL_VALID_LOSS 102     // summary key: validation loss
#define ML_TRAIN_SUMMARY_MODEL_VALID_ACCURACY 103 // summary key: validation accuracy
55 /** Gtest compatibility for parameterize google test API */
57 #define GTEST_PARAMETER_TEST INSTANTIATE_TEST_CASE_P
59 #define GTEST_PARAMETER_TEST INSTANTIATE_TEST_SUITE_P
* @brief This class wraps IniWrapper. It generates a real ini file when
* constructed, and removes that ini file when destroyed
70 * @brief Construct a new Scoped Ini object
72 * @param ini_ ini wrapper
74 ScopedIni(const nntrainer::IniWrapper &ini_) : ini(ini_) { ini.save_ini(); }
77 * @brief Construct a new Scoped Ini object
* @param sections_ sequence of sections to save
82 ScopedIni(const std::string &name_,
83 const nntrainer::IniWrapper::Sections §ions_) :
84 ini(name_, sections_) {
89 * @brief Get the Ini Name object
91 * @return std::string ini name
93 std::string getIniName() { return ini.getIniName(); }
96 * @brief Destroy the Scoped Ini object
99 ~ScopedIni() { ini.erase_ini(); }
102 nntrainer::IniWrapper ini;
105 #define GEN_TEST_INPUT(input, eqation_i_j_k_l) \
107 for (int i = 0; i < batch; ++i) { \
108 for (int j = 0; j < channel; ++j) { \
109 for (int k = 0; k < height; ++k) { \
110 for (int l = 0; l < width; ++l) { \
111 float val = eqation_i_j_k_l; \
112 input.setValue(i, j, k, l, val); \
* @brief return a tensor filled with a constant value with the given dimension
122 nntrainer::Tensor constant(float value, unsigned int batch, unsigned channel,
123 unsigned height, unsigned width);
126 * @brief return a tensor filled with ranged value with given dimension
129 ranged(unsigned int batch, unsigned channel, unsigned height, unsigned width,
130 nntrainer::Tformat fm = nntrainer::Tformat::NCHW,
131 nntrainer::DataType d_type = nntrainer::DataType::FP32);
134 * @brief return a tensor filled with random value with given dimension
136 nntrainer::Tensor randUniform(unsigned int batch, unsigned channel,
137 unsigned height, unsigned width, float min = -1,
141 * @brief replace string and save in file
142 * @param[in] from string to be replaced
* @param[in] to string to replace with
144 * @param[in] n file name to save
void replaceString(const std::string &from,
                   const std::string &to,
                   const std::string n,
                   std::string str);
151 * @brief UserData which stores information used to feed data from data callback
154 class DataInformation {
157 * @brief Construct a new Data Information object
159 * @param num_samples number of data
160 * @param filename file name to read from
162 DataInformation(unsigned int num_samples, const std::string &filename);
164 unsigned int num_samples;
166 std::vector<unsigned int> idxes;
* @brief Create user data for training
173 * @return DataInformation
175 DataInformation createTrainData();
* @brief Create user data for validation
180 * @return DataInformation
182 DataInformation createValidData();
* @brief get data whose size is one batch
187 * @param[out] outLabel
188 * @param[out] last if the data is finished
189 * @param[in] user_data private data for the callback
190 * @retval status for handling error
int getSample(float **outVec,
              float **outLabel,
              bool *last,
              void *user_data);
195 * @brief Get the Res Path object
196 * @note if NNTRAINER_RESOURCE_PATH environment variable is given, @a
* fallback_base is ignored and NNTRAINER_RESOURCE_PATH is directly used as a
200 * @param filename filename if omitted, ${prefix}/${base} will be returned
201 * @param fallback_base list of base to attach when NNTRAINER_RESOURCE_PATH is
203 * @return const std::string path,
206 getResPath(const std::string &filename,
207 const std::initializer_list<const char *> fallback_base = {});
/// one layer described as a (layer type, list of property strings) pair
using LayerRepresentation = std::pair<std::string, std::vector<std::string>>;
212 * @brief make graph of a representation
214 * @param layer_reps layer representation (pair of type, properties)
215 * @return nntrainer::GraphRepresentation synthesized graph representation
217 nntrainer::GraphRepresentation
218 makeGraph(const std::vector<LayerRepresentation> &layer_reps);
221 * @brief make graph of a representation after compile
223 * @param layer_reps layer representation (pair of type, properties)
224 * @param realizers GraphRealizers to modify graph before compile
225 * @param loss_layer loss layer to compile with
226 * @return nntrainer::GraphRepresentation synthesized graph representation
228 nntrainer::GraphRepresentation makeCompiledGraph(
229 const std::vector<LayerRepresentation> &layer_reps,
230 std::vector<std::unique_ptr<nntrainer::GraphRealizer>> &realizers,
231 const std::string &loss_layer = "");
234 * @brief read tensor after reading tensor size
236 * @param t tensor to fill
* @param file input file stream to read from
238 * @param error_msg error msg
240 void sizeCheckedReadTensor(nntrainer::Tensor &t, std::ifstream &file,
241 const std::string &error_msg = "");
243 #endif /* __cplusplus */
244 #endif /* __NNTRAINER_TEST_UTIL_H__ */