Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / CircleOptimizerUtils.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "CircleOptimizerUtils.h"
18
19 namespace luci
20 {
21
22 bool in_array(const std::string &str, const std::vector<std::string> &array)
23 {
24   return std::find(array.begin(), array.end(), str) != array.end();
25 }
26
27 std::string to_string(const std::vector<std::string> &strings)
28 {
29   assert(!strings.empty());
30
31   std::string res;
32   for (unsigned int i = 0; i < strings.size() - 1; i++)
33     res += strings[i] + ", ";
34
35   res += strings[strings.size() - 1];
36   return res;
37 }
38
39 std::string to_lower_case(std::string s)
40 {
41   std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) { return std::tolower(c); });
42   return s;
43 }
44
45 loco::DataType str_to_dtype(const std::string &str)
46 {
47   if (to_lower_case(str).compare("uint8") == 0)
48     return loco::DataType::U8;
49   if (to_lower_case(str).compare("uint16") == 0)
50     return loco::DataType::U16;
51   if (to_lower_case(str).compare("uint32") == 0)
52     return loco::DataType::U32;
53   if (to_lower_case(str).compare("uint64") == 0)
54     return loco::DataType::U64;
55
56   if (to_lower_case(str).compare("int8") == 0)
57     return loco::DataType::S8;
58   if (to_lower_case(str).compare("int16") == 0)
59     return loco::DataType::S16;
60   if (to_lower_case(str).compare("int32") == 0)
61     return loco::DataType::S32;
62   if (to_lower_case(str).compare("int64") == 0)
63     return loco::DataType::S64;
64
65   if (to_lower_case(str).compare("float16") == 0)
66     return loco::DataType::FLOAT16;
67   if (to_lower_case(str).compare("float32") == 0)
68     return loco::DataType::FLOAT32;
69   if (to_lower_case(str).compare("float64") == 0)
70     return loco::DataType::FLOAT64;
71
72   if (to_lower_case(str).compare("bool") == 0)
73     return loco::DataType::BOOL;
74
75   return loco::DataType::Unknown;
76 }
77
78 QuantizationGranularity str_to_granularity(const std::string &str)
79 {
80   if (to_lower_case(str).compare("layer") == 0)
81     return QuantizationGranularity::LayerWise;
82
83   if (to_lower_case(str).compare("channel") == 0)
84     return QuantizationGranularity::ChannelWise;
85
86   throw std::runtime_error("Quantization granularity must be either 'layer' or 'channel'");
87 }
88
89 } // namespace luci