2 * Copyright (C) 2020 Samsung Electronics Co., Ltd. All Rights Reserved.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 * http://www.apache.org/licenses/LICENSE-2.0
8 * Unless required by applicable law or agreed to in writing, software
9 * distributed under the License is distributed on an "AS IS" BASIS,
10 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 * See the License for the specific language governing permissions and
12 * limitations under the License.
16 * @brief This is a collection of math and general utility functions
17 * @see https://github.com/nnstreamer/nntrainer
18 * @author Jijoong Moon <jijoong.moon@samsung.com>
19 * @bug No known bugs except for NYI items
24 #define MAX_PATH_LENGTH 1024
#include <cmath>

#include <nntrainer_log.h>
#include <util_func.h>
// File-local generator object built by an immediately-invoked lambda; the
// lambda body is not visible in this excerpt (presumably it seeds a random
// engine — confirm against the full source).
36 static auto rng = [] {
// File-local uniform distribution over [-0.5, 0.5). No use of it is visible
// here; presumably it is sampled with `rng` elsewhere in this file.
41 static std::uniform_real_distribution<float> dist(-0.5, 0.5);
/**
 * @brief Return the seed value; this implementation always yields 0.
 * @return 0
 */
unsigned int getSeed() {
  constexpr unsigned int kFixedSeed = 0U;
  return kFixedSeed;
}
/**
 * @brief Square root of a single-precision value.
 * @param x input value
 * @return sqrt(x); NaN when x is negative
 */
float sqrtFloat(float x) {
  // std::sqrt selects the float overload. An unqualified `sqrt` may resolve
  // to C's double sqrt(double) and round-trip through double needlessly.
  return std::sqrt(x);
}
/**
 * @brief Square root of a double-precision value.
 * @param x input value
 * @return sqrt(x); NaN when x is negative
 */
double sqrtDouble(double x) {
  // Qualified call for consistency with sqrtFloat. The original trailing ';'
  // after the body was a stray empty declaration and has been dropped.
  return std::sqrt(x);
}
/**
 * @brief Natural logarithm with a tiny additive offset so log(0) stays finite.
 * @param x input value
 * @return log(x + 1e-20), evaluated in double precision and narrowed to float
 */
float logFloat(float x) {
  // The 1e-20 offset keeps an input of exactly 0 away from -inf.
  const double shifted = static_cast<double>(x) + 1.0e-20;
  return static_cast<float>(log(shifted));
}
/**
 * @brief Exponential function wrapper.
 * @param x exponent
 * @return e raised to the power x, as float
 * @note The unqualified `exp` may resolve to the C double-precision function;
 *       in that case the result is rounded to float on return.
 */
float exp_util(float x) {
  const auto value = exp(x);
  return static_cast<float>(value);
}
/**
 * @brief Walk every element of @a in (batch, channel, height, width) and read
 *        the value mirrored in both height and width, i.e. the element at
 *        (i, j, height-1-k, width-1-l). This is the building block of a
 *        180-degree rotation of each (height, width) plane.
 * @param in source tensor (taken by value)
 * @note `output` is created with the same dimension as @a in. The statement
 *       that stores each mirrored value and the final return are not visible
 *       in this excerpt. Presumably the value is written to output(i, j, k, l)
 *       and `output` is returned; confirm against the full source.
 */
53 Tensor rotate_180(Tensor in) {
54 Tensor output(in.getDim());
56 for (unsigned int i = 0; i < in.batch(); ++i) {
57 for (unsigned int j = 0; j < in.channel(); ++j) {
58 for (unsigned int k = 0; k < in.height(); ++k) {
59 for (unsigned int l = 0; l < in.width(); ++l) {
62 in.getValue(i, j, (in.height() - k - 1), (in.width() - l - 1)));
/**
 * @brief Check whether @a file_name can be opened as an input file stream.
 * @param file_name path to probe (taken by value)
 * @note Only the std::ifstream construction is visible in this excerpt. The
 *       returned expression is not; presumably it reports whether the open
 *       succeeded, which tests readability rather than bare existence.
 */
70 bool isFileExist(std::string file_name) {
71 std::ifstream infile(file_name);
/**
 * @brief Throw if a stream is not in a good state.
 * @param file stream to inspect. `T` is presumably a template parameter
 *        declared on a line not visible here; it must provide bad/eof/good/fail.
 * @param error_msg message carried by the thrown exception
 * @throws std::runtime_error when bad(), eof() or fail() is set
 * @note The condition is equivalent to `!file.good()` alone. The bitwise `|`
 *       evaluates all four calls without short-circuiting.
 */
76 static void checkFile(const T &file, const char *error_msg) {
77 if (file.bad() | file.eof() | !file.good() | file.fail()) {
78 throw std::runtime_error(error_msg);
82 void checkedRead(std::ifstream &file, char *array, std::streamsize size,
83 const char *error_msg) {
84 file.read(array, size);
86 checkFile(file, error_msg);
89 void checkedWrite(std::ostream &file, const char *array, std::streamsize size,
90 const char *error_msg) {
91 file.write(array, size);
93 checkFile(file, error_msg);
/**
 * @brief Read a length-prefixed string: a raw size_t length followed by that
 *        many bytes.
 * @param file input stream positioned at the length header
 * @param error_msg message used if a read fails
 * @throws std::invalid_argument if the stored length does not fit in
 *         std::streamsize
 * @throws std::runtime_error (via checkedRead) on a failed read
 * @note The declarations of `size` and `str`, and whatever sizes `str` before
 *       the second read, are not visible in this excerpt.
 *       NOTE(review): the length comes straight from the stream and is only
 *       checked for representability. If `str` is sized from it, a corrupt file
 *       could trigger a very large allocation; confirm whether callers bound it.
 */
96 std::string readString(std::ifstream &file, const char *error_msg) {
100 checkedRead(file, (char *)&size, sizeof(size), error_msg);
102 std::streamsize sz = static_cast<std::streamsize>(size);
103 NNTR_THROW_IF(sz < 0, std::invalid_argument)
104 << "read string size: " << sz
105 << " is too big. It cannot be represented by std::streamsize";
108 checkedRead(file, (char *)&str[0], sz, error_msg);
113 void writeString(std::ofstream &file, const std::string &str,
114 const char *error_msg) {
115 size_t size = str.size();
117 checkedWrite(file, (char *)&size, sizeof(size), error_msg);
119 std::streamsize sz = static_cast<std::streamsize>(size);
120 NNTR_THROW_IF(sz < 0, std::invalid_argument)
121 << "write string size: " << size
122 << " is too big. It cannot be represented by std::streamsize";
124 checkedWrite(file, (char *)&str[0], sz, error_msg);
/**
 * @brief Check whether @a target ends with @a suffix (case-sensitive).
 * @param target string to inspect
 * @param suffix candidate suffix; an empty suffix always matches
 * @return true if the last suffix.size() characters of target equal suffix
 */
bool endswith(const std::string &target, const std::string &suffix) {
  if (target.size() < suffix.size()) {
    return false;
  }
  // compare() checks in place; substr() would allocate a temporary copy.
  const size_t spos = target.size() - suffix.size();
  return target.compare(spos, suffix.size(), suffix) == 0;
}
/**
 * @brief Split a "key = value" string into a key token and a value token.
 * @details A copy of @a input_str has every ' ' removed (erase-remove idiom).
 *          The result is then tokenized into runs of characters that are
 *          neither whitespace nor '=', and the tokens are collected into `list`.
 * @param input_str text to parse
 * @param[out] key receives the key token (assignment not visible in this
 *             excerpt; presumably list[0])
 * @param[out] value receives the value token (presumably list[1])
 * @return ML_ERROR_INVALID_PARAMETER, with an error log, when the token-count
 *         check fails. That condition is not visible here; presumably
 *         nwords != 2. Otherwise presumably `status` (ML_ERROR_NONE).
 */
135 int getKeyValue(const std::string &input_str, std::string &key,
136 std::string &value) {
137 int status = ML_ERROR_NONE;
138 auto input_trimmed = input_str;
140 std::vector<std::string> list;
141 static const std::regex words_regex("[^\\s=]+");
143 std::remove(input_trimmed.begin(), input_trimmed.end(), ' '),
144 input_trimmed.end());
145 auto words_begin = std::sregex_iterator(input_trimmed.begin(),
146 input_trimmed.end(), words_regex);
147 auto words_end = std::sregex_iterator();
148 int nwords = std::distance(words_begin, words_end);
151 ml_loge("Error: input string must be 'key = value' format "
152 "(e.g.{\"key1=value1\",\"key2=value2\"}), \"%s\" given",
153 input_trimmed.c_str());
154 return ML_ERROR_INVALID_PARAMETER;
157 for (std::sregex_iterator i = words_begin; i != words_end; ++i) {
158 list.push_back((*i).str());
/**
 * @brief Parse integers from @a str into @a value.
 * @details All ' ' characters are removed. Tokens are then runs of characters
 *          other than whitespace and . , : ; ! ?, and each token is converted
 *          with std::stoi.
 * @param n_str expected number of values. The comparison with the token count
 *        is not visible in this excerpt; presumably num != n_str is rejected.
 * @param str input text (taken by value and modified locally)
 * @param[out] value destination array. NOTE(review): presumably it must hold
 *        at least n_str ints; the index `cn` is declared on a line not visible
 *        here.
 * @return ML_ERROR_INVALID_PARAMETER, with an error log, on a count mismatch.
 *         Otherwise the final return (not visible) is presumably `status`.
 * @throws std::invalid_argument or std::out_of_range from std::stoi on a
 *         non-numeric or out-of-range token
 * @note Because '.' is a separator, "1.5" yields two tokens, "1" and "5".
 */
167 int getValues(int n_str, std::string str, int *value) {
168 int status = ML_ERROR_NONE;
169 static const std::regex words_regex("[^\\s.,:;!?]+");
170 str.erase(std::remove(str.begin(), str.end(), ' '), str.end());
171 auto words_begin = std::sregex_iterator(str.begin(), str.end(), words_regex);
172 auto words_end = std::sregex_iterator();
174 int num = std::distance(words_begin, words_end);
176 ml_loge("Number of Data is not match");
177 return ML_ERROR_INVALID_PARAMETER;
180 for (std::sregex_iterator i = words_begin; i != words_end; ++i) {
181 value[cn] = std::stoi((*i).str());
/**
 * @brief Tokenize @a s using the regular expression passed as the second
 *        parameter.
 * @details Works on a local string `str`, whose declaration is not visible in
 *          this excerpt (presumably a copy of @a s). Every ' ', '[' and ']' is
 *          removed from it first. The remainder is then walked with
 *          std::regex_token_iterator, and each token is appended to the
 *          returned vector.
 * @note The submatch selector passed to the iterator is not visible here.
 *       Presumably it is -1, meaning the regex matches the delimiters; confirm.
 * NOTE(review): the second parameter below appears mis-encoded as "®" (an
 *       HTML-entity decode of "&reg"). It should read `const std::regex &reg`.
 */
187 std::vector<std::string> split(const std::string &s, const std::regex ®) {
188 std::vector<std::string> out;
189 const int NUM_SKIP_CHAR = 3;
190 char char_to_remove[NUM_SKIP_CHAR] = {' ', '[', ']'};
192 for (unsigned int i = 0; i < NUM_SKIP_CHAR; ++i) {
193 str.erase(std::remove(str.begin(), str.end(), char_to_remove[i]),
196 std::regex_token_iterator<std::string::iterator> end;
197 std::regex_token_iterator<std::string::iterator> iter(str.begin(), str.end(),
200 while (iter != end) {
201 out.push_back(*iter);
/**
 * @brief Case-insensitive string equality, using `tolower` under the current
 *        C locale.
 * @param a first string
 * @param b second string
 * @return true if both strings have the same length and compare equal
 *         character by character after `tolower`
 */
bool istrequal(const std::string &a, const std::string &b) {
  if (a.size() != b.size()) {
    return false;
  }
  // tolower() is undefined for negative arguments other than EOF. A plain
  // (signed) char holding a byte >= 0x80, such as UTF-8 text, would be
  // negative, so convert through unsigned char first.
  return std::equal(a.begin(), a.end(), b.begin(), [](char a_, char b_) {
    return tolower(static_cast<unsigned char>(a_)) ==
           tolower(static_cast<unsigned char>(b_));
  });
}
/**
 * @brief Resolve @a name to an absolute path stored in @a resolved.
 * @details One preprocessor branch uses `_fullpath(resolved, name,
 *          MAX_PATH_LENGTH)` and the other uses POSIX
 *          `realpath(name, resolved)`. The #if/#else lines choosing between
 *          them are not visible in this excerpt (presumably keyed on _WIN32).
 * @param name path to resolve
 * @param resolved caller-provided output buffer. It needs at least
 *        MAX_PATH_LENGTH bytes for _fullpath and PATH_MAX bytes for realpath.
 * @return pointer to the resolved path, or a null pointer on failure
 */
216 char *getRealpath(const char *name, char *resolved) {
218 return _fullpath(resolved, name, MAX_PATH_LENGTH);
220 return realpath(name, resolved);
/**
 * @brief Convert the time value `t` into local broken-down time stored in
 *        @a tp.
 * @param tp caller-provided output structure
 * @return the result of localtime_r: @a tp on success, a null pointer on
 *         failure
 * @note `t` is defined on lines not visible in this excerpt (presumably the
 *       current time; confirm). Any non-POSIX branch is likewise not visible.
 */
224 tm *getLocaltime(tm *tp) {
230 return localtime_r(&t, tp);
234 } // namespace nntrainer