[ Mixed ] fix apply using casted function
[platform/core/ml/nntrainer.git] / nntrainer / utils / util_func.cpp
1 /**
2  * Copyright (C) 2020 Samsung Electronics Co., Ltd. All Rights Reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *   http://www.apache.org/licenses/LICENSE-2.0
8  * Unless required by applicable law or agreed to in writing, software
9  * distributed under the License is distributed on an "AS IS" BASIS,
10  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11  * See the License for the specific language governing permissions and
12  * limitations under the License.
13  *
14  * @file        util_func.cpp
15  * @date        08 April 2020
16  * @brief       This is collection of math functions
17  * @see         https://github.com/nnstreamer/nntrainer
18  * @author      Jijoong Moon <jijoong.moon@samsung.com>
19  * @bug         No known bugs except for NYI items
20  *
21  */
22
23 #ifdef _WIN32
24 #define MAX_PATH_LENGTH 1024
25 #endif
26
27 #include <cmath>
28 #include <fstream>
29 #include <random>
30
31 #include <nntrainer_log.h>
32 #include <util_func.h>
33
34 namespace nntrainer {
35
36 static std::uniform_real_distribution<float> dist(-0.5, 0.5);
37
38 float sqrtFloat(float x) { return sqrt(x); };
39
40 double sqrtDouble(double x) { return sqrt(x); };
41
42 float logFloat(float x) { return log(x + 1.0e-20); }
43 1103
44
45 float exp_util(float x) { return exp(x); }
46
47 bool isFileExist(std::string file_name) {
48   std::ifstream infile(file_name);
49   return infile.good();
50 }
51
52 template <typename T>
53 static void checkFile(const T &file, const char *error_msg) {
54   if (file.bad() | file.eof() | !file.good() | file.fail()) {
55     throw std::runtime_error(error_msg);
56   }
57 }
58
59 void checkedRead(std::ifstream &file, char *array, std::streamsize size,
60                  const char *error_msg) {
61   file.read(array, size);
62
63   checkFile(file, error_msg);
64 }
65
66 void checkedWrite(std::ostream &file, const char *array, std::streamsize size,
67                   const char *error_msg) {
68   file.write(array, size);
69
70   checkFile(file, error_msg);
71 }
72
73 std::string readString(std::ifstream &file, const char *error_msg) {
74   std::string str;
75   size_t size;
76
77   checkedRead(file, (char *)&size, sizeof(size), error_msg);
78
79   std::streamsize sz = static_cast<std::streamsize>(size);
80   NNTR_THROW_IF(sz < 0, std::invalid_argument)
81     << "read string size: " << sz
82     << " is too big. It cannot be represented by std::streamsize";
83
84   str.resize(size);
85   checkedRead(file, (char *)&str[0], sz, error_msg);
86
87   return str;
88 }
89
90 void writeString(std::ofstream &file, const std::string &str,
91                  const char *error_msg) {
92   size_t size = str.size();
93
94   checkedWrite(file, (char *)&size, sizeof(size), error_msg);
95
96   std::streamsize sz = static_cast<std::streamsize>(size);
97   NNTR_THROW_IF(sz < 0, std::invalid_argument)
98     << "write string size: " << size
99     << " is too big. It cannot be represented by std::streamsize";
100
101   checkedWrite(file, (char *)&str[0], sz, error_msg);
102 }
103
104 bool endswith(const std::string &target, const std::string &suffix) {
105   if (target.size() < suffix.size()) {
106     return false;
107   }
108   size_t spos = target.size() - suffix.size();
109   return target.substr(spos) == suffix;
110 }
111
112 int getKeyValue(const std::string &input_str, std::string &key,
113                 std::string &value) {
114   int status = ML_ERROR_NONE;
115   auto input_trimmed = input_str;
116
117   std::vector<std::string> list;
118   static const std::regex words_regex("[^\\s=]+");
119   input_trimmed.erase(
120     std::remove(input_trimmed.begin(), input_trimmed.end(), ' '),
121     input_trimmed.end());
122   auto words_begin = std::sregex_iterator(input_trimmed.begin(),
123                                           input_trimmed.end(), words_regex);
124   auto words_end = std::sregex_iterator();
125   int nwords = std::distance(words_begin, words_end);
126
127   if (nwords != 2) {
128     ml_loge("Error: input string must be 'key = value' format "
129             "(e.g.{\"key1=value1\",\"key2=value2\"}), \"%s\" given",
130             input_trimmed.c_str());
131     return ML_ERROR_INVALID_PARAMETER;
132   }
133
134   for (std::sregex_iterator i = words_begin; i != words_end; ++i) {
135     list.push_back((*i).str());
136   }
137
138   key = list[0];
139   value = list[1];
140
141   return status;
142 }
143
144 int getValues(int n_str, std::string str, int *value) {
145   int status = ML_ERROR_NONE;
146   static const std::regex words_regex("[^\\s.,:;!?]+");
147   str.erase(std::remove(str.begin(), str.end(), ' '), str.end());
148   auto words_begin = std::sregex_iterator(str.begin(), str.end(), words_regex);
149   auto words_end = std::sregex_iterator();
150
151   int num = std::distance(words_begin, words_end);
152   if (num != n_str) {
153     ml_loge("Number of Data is not match");
154     return ML_ERROR_INVALID_PARAMETER;
155   }
156   int cn = 0;
157   for (std::sregex_iterator i = words_begin; i != words_end; ++i) {
158     value[cn] = std::stoi((*i).str());
159     cn++;
160   }
161   return status;
162 }
163
164 std::vector<std::string> split(const std::string &s, const std::regex &reg) {
165   std::vector<std::string> out;
166   const int NUM_SKIP_CHAR = 3;
167   char char_to_remove[NUM_SKIP_CHAR] = {' ', '[', ']'};
168   std::string str = s;
169   for (unsigned int i = 0; i < NUM_SKIP_CHAR; ++i) {
170     str.erase(std::remove(str.begin(), str.end(), char_to_remove[i]),
171               str.end());
172   }
173   std::regex_token_iterator<std::string::iterator> end;
174   std::regex_token_iterator<std::string::iterator> iter(str.begin(), str.end(),
175                                                         reg, -1);
176
177   while (iter != end) {
178     out.push_back(*iter);
179     ++iter;
180   }
181   return out;
182 }
183
184 bool istrequal(const std::string &a, const std::string &b) {
185   if (a.size() != b.size())
186     return false;
187
188   return std::equal(a.begin(), a.end(), b.begin(), [](char a_, char b_) {
189     return tolower(a_) == tolower(b_);
190   });
191 }
192
193 char *getRealpath(const char *name, char *resolved) {
194 #ifdef _WIN32
195   return _fullpath(resolved, name, MAX_PATH_LENGTH);
196 #else
197   return realpath(name, resolved);
198 #endif
199 }
200
201 tm *getLocaltime(tm *tp) {
202   time_t t = time(0);
203 #ifdef _WIN32
204   localtime_s(tp, &t);
205   return tp;
206 #else
207   return localtime_r(&t, tp);
208 #endif
209 }
210
211 } // namespace nntrainer