Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / dalgona / src / Utils.cpp
1 /*
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "Utils.h"
18 #include "StringUtils.h"
19
20 #include <luci_interpreter/core/Tensor.h>
21 #include <luci/IR/CircleOpcode.h>
22 #include <luci/IR/CircleNodeDecl.h>
23
24 #include <pybind11/numpy.h>
25 #include <stdexcept>
26 #include <vector>
27
28 using Tensor = luci_interpreter::Tensor;
29
30 namespace py = pybind11;
31 using namespace py::literals;
32
33 #define THROW_UNLESS(COND, MSG) \
34   if (not(COND))                \
35     throw std::runtime_error(MSG);
36
37 namespace
38 {
39
40 py::array numpyArray(const Tensor *tensor)
41 {
42   assert(tensor != nullptr); // FIX_CALLER_UNLESS
43
44   const auto tensor_shape = tensor->shape();
45
46   uint32_t size = 1;
47   std::vector<uint32_t> shape(tensor_shape.num_dims());
48   for (int i = 0; i < tensor_shape.num_dims(); i++)
49   {
50     THROW_UNLESS(tensor_shape.dim(i) >= 0, "Negative dimension detected in " + tensor->name());
51
52     shape[i] = tensor_shape.dim(i);
53     size *= shape[i];
54   }
55
56   if (size == 0)
57     return py::none();
58
59   switch (tensor->element_type())
60   {
61     case loco::DataType::FLOAT32:
62       return py::array_t<float, py::array::c_style>(shape, tensor->data<float>());
63     case loco::DataType::S16:
64       return py::array_t<int16_t, py::array::c_style>(shape, tensor->data<int16_t>());
65     case loco::DataType::S32:
66       return py::array_t<int32_t, py::array::c_style>(shape, tensor->data<int32_t>());
67     case loco::DataType::S64:
68       return py::array_t<int64_t, py::array::c_style>(shape, tensor->data<int64_t>());
69     case loco::DataType::U8:
70       return py::array_t<uint8_t, py::array::c_style>(shape, tensor->data<uint8_t>());
71     default:
72       throw std::runtime_error("Unsupported data type");
73   }
74 }
75
76 py::dict quantparam(const Tensor *tensor)
77 {
78   assert(tensor != nullptr); // FIX_CALLER_UNLESS
79
80   auto scale = tensor->scales();
81   auto zp = tensor->zero_points();
82
83   py::list py_scale;
84   for (auto s : scale)
85   {
86     py_scale.append(s);
87   }
88
89   py::list py_zp;
90   for (auto z : zp)
91   {
92     py_zp.append(z);
93   }
94
95   auto quantparam = py::dict("scale"_a = py_scale, "zero_point"_a = py_zp,
96                              "quantized_dimension"_a = tensor->quantized_dimension());
97   return quantparam;
98 }
99
100 } // namespace
101
102 namespace dalgona
103 {
104
105 py::object none() { return py::none(); }
106
107 std::vector<py::dict> inputsPyArray(const luci::CircleNode *node,
108                                     luci_interpreter::Interpreter *interpreter)
109 {
110   assert(node != nullptr);        // FIX_CALLER_UNLESS
111   assert(interpreter != nullptr); // FIX_CALLER_UNLESS
112
113   std::vector<py::dict> inputs;
114   for (uint32_t i = 0; i < node->arity(); ++i)
115   {
116     const auto input_tensor = interpreter->getTensor(node->arg(i));
117     auto circle_node = static_cast<luci::CircleNode *>(node->arg(i));
118
119     // skip invalid inputs (e.g., non-existing bias in TCONV)
120     if (circle_node->opcode() == luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE)
121       continue;
122
123     auto py_input =
124       py::dict("name"_a = circle_node->name(), "data"_a = numpyArray(input_tensor),
125                "quantparam"_a = quantparam(input_tensor),
126                "is_const"_a = circle_node->opcode() == luci::CircleOpcode::CIRCLECONST);
127     inputs.push_back(py_input);
128   }
129   return inputs;
130 }
131
132 std::vector<py::dict> outputsPyArray(const luci::CircleNode *node,
133                                      luci_interpreter::Interpreter *interpreter)
134 {
135   std::vector<py::dict> outputs;
136   for (auto succ : loco::succs(node))
137   {
138     const auto output_tensor = interpreter->getTensor(succ);
139     auto circle_node = static_cast<luci::CircleNode *>(succ);
140
141     auto opcode_str = toString(circle_node->opcode());
142     // Check if node is a multi-output node
143     // Assumption: Multi-output virtual nodes have 'Out' prefix
144     // TODO Fix this if the assumption changes
145     THROW_UNLESS(opcode_str.substr(opcode_str.length() - 3) == "Out",
146                  "Invalid output detected in " + node->name());
147
148     auto py_output =
149       py::dict("name"_a = circle_node->name(), "data"_a = numpyArray(output_tensor),
150                "quantparam"_a = quantparam(output_tensor),
151                "is_const"_a = circle_node->opcode() == luci::CircleOpcode::CIRCLECONST);
152     outputs.push_back(py_output);
153   }
154   return outputs;
155 }
156
157 // Note: Only returns 1 output
158 py::dict outputPyArray(const luci::CircleNode *node, luci_interpreter::Interpreter *interpreter)
159 {
160   assert(node != nullptr);        // FIX_CALLER_UNLESS
161   assert(interpreter != nullptr); // FIX_CALLER_UNLESS
162
163   const auto tensor = interpreter->getTensor(node);
164
165   THROW_UNLESS(tensor != nullptr, "Null tensor detected in " + node->name());
166
167   auto py_output = py::dict("name"_a = node->name(), "data"_a = numpyArray(tensor),
168                             "quantparam"_a = quantparam(tensor),
169                             "is_const"_a = node->opcode() == luci::CircleOpcode::CIRCLECONST);
170   return py_output;
171 }
172
173 bool multi_out_node(const luci::CircleNode *node)
174 {
175   switch (node->opcode())
176   {
177     // TODO Update this list when new Op is added
178     // Tip: grep "public GraphBuilderMultiOutput" in luci/import
179     case luci::CircleOpcode::BIDIRECTIONAL_SEQUENCE_LSTM:
180     case luci::CircleOpcode::CUSTOM:
181     case luci::CircleOpcode::IF:
182     case luci::CircleOpcode::NON_MAX_SUPPRESSION_V4:
183     case luci::CircleOpcode::NON_MAX_SUPPRESSION_V5:
184     case luci::CircleOpcode::SPLIT:
185     case luci::CircleOpcode::SPLIT_V:
186     case luci::CircleOpcode::TOPK_V2:
187     case luci::CircleOpcode::UNIQUE:
188     case luci::CircleOpcode::UNPACK:
189       return true;
190     default:
191       return false;
192   }
193 }
194
195 } // namespace dalgona
196
197 #undef THROW_UNLESS