2 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
18 #include "StringUtils.h"
20 #include <luci_interpreter/core/Tensor.h>
21 #include <luci/IR/CircleOpcode.h>
22 #include <luci/IR/CircleNodeDecl.h>
24 #include <pybind11/numpy.h>
// Shorthand for the interpreter tensor type used by the helpers in this file
28 using Tensor = luci_interpreter::Tensor;
// Short alias for the pybind11 namespace
30 namespace py = pybind11;
// Enables the "_a" literal used below for named dictionary entries (e.g. "name"_a = ...)
31 using namespace py::literals;
// Throws std::runtime_error carrying MSG when COND evaluates to false.
//
// The body is wrapped in do { ... } while (0) so that the macro expands to
// exactly one statement: it can be used as the unbraced body of an if/else
// without swallowing the 'else' or leaving a stray ';' behind.
// Call sites keep writing the trailing semicolon: THROW_UNLESS(cond, "msg");
#define THROW_UNLESS(COND, MSG)      \
  do                                 \
  {                                  \
    if (!(COND))                     \
      throw std::runtime_error(MSG); \
  } while (0)
// Converts a luci_interpreter Tensor into a numpy array.
//
// The tensor's dimensions are copied into the array shape and the tensor's
// element type selects the array dtype. FLOAT32, S16, S32, S64 and U8 are
// the element types handled here.
//
// tensor: tensor to convert. Passing nullptr is a caller bug (assert only).
// Returns a py::array_t created with the c_style flag from the shape and the
// tensor's data pointer.
// A negative dimension is reported through THROW_UNLESS with the tensor name.
40 py::array numpyArray(const Tensor *tensor)
42 assert(tensor != nullptr); // FIX_CALLER_UNLESS
44 const auto tensor_shape = tensor->shape();
// One entry per tensor dimension; filled in by the loop below
47 std::vector<uint32_t> shape(tensor_shape.num_dims());
48 for (int i = 0; i < tensor_shape.num_dims(); i++)
// Reject negative dimensions before storing them in the unsigned shape vector
50 THROW_UNLESS(tensor_shape.dim(i) >= 0, "Negative dimension detected in " + tensor->name());
52 shape[i] = tensor_shape.dim(i);
// Pick the numpy element type that matches the tensor's element type
// NOTE(review): pybind11 is expected to copy the buffer when no owning base
// object is passed to the array constructor - confirm against pybind11 docs
59 switch (tensor->element_type())
61 case loco::DataType::FLOAT32:
62 return py::array_t<float, py::array::c_style>(shape, tensor->data<float>());
63 case loco::DataType::S16:
64 return py::array_t<int16_t, py::array::c_style>(shape, tensor->data<int16_t>());
65 case loco::DataType::S32:
66 return py::array_t<int32_t, py::array::c_style>(shape, tensor->data<int32_t>());
67 case loco::DataType::S64:
68 return py::array_t<int64_t, py::array::c_style>(shape, tensor->data<int64_t>());
69 case loco::DataType::U8:
70 return py::array_t<uint8_t, py::array::c_style>(shape, tensor->data<uint8_t>());
// NOTE(review): presumably reached for every element type not listed above
// (the label guarding this throw is not shown here) - confirm
72 throw std::runtime_error("Unsupported data type");
// Packs the quantization parameters of a tensor into a Python dictionary.
//
// tensor: tensor to read from. Passing nullptr is a caller bug (assert only).
// The resulting dictionary has the keys "scale", "zero_point" and
// "quantized_dimension".
76 py::dict quantparam(const Tensor *tensor)
78 assert(tensor != nullptr); // FIX_CALLER_UNLESS
// Quantization scales and zero points as stored on the tensor
80 auto scale = tensor->scales();
81 auto zp = tensor->zero_points();
// NOTE(review): py_scale / py_zp are presumably Python-side containers filled
// from 'scale' and 'zp' above; their construction is not shown here - confirm
95 auto quantparam = py::dict("scale"_a = py_scale, "zero_point"_a = py_zp,
96 "quantized_dimension"_a = tensor->quantized_dimension());
105 py::object none() { return py::none(); }
107 std::vector<py::dict> inputsPyArray(const luci::CircleNode *node,
108 luci_interpreter::Interpreter *interpreter)
110 assert(node != nullptr); // FIX_CALLER_UNLESS
111 assert(interpreter != nullptr); // FIX_CALLER_UNLESS
113 std::vector<py::dict> inputs;
114 for (uint32_t i = 0; i < node->arity(); ++i)
116 const auto input_tensor = interpreter->getTensor(node->arg(i));
117 auto circle_node = static_cast<luci::CircleNode *>(node->arg(i));
119 // skip invalid inputs (e.g., non-existing bias in TCONV)
120 if (circle_node->opcode() == luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE)
124 py::dict("name"_a = circle_node->name(), "data"_a = numpyArray(input_tensor),
125 "quantparam"_a = quantparam(input_tensor),
126 "is_const"_a = circle_node->opcode() == luci::CircleOpcode::CIRCLECONST);
127 inputs.push_back(py_input);
132 std::vector<py::dict> outputsPyArray(const luci::CircleNode *node,
133 luci_interpreter::Interpreter *interpreter)
135 std::vector<py::dict> outputs;
136 for (auto succ : loco::succs(node))
138 const auto output_tensor = interpreter->getTensor(succ);
139 auto circle_node = static_cast<luci::CircleNode *>(succ);
141 auto opcode_str = toString(circle_node->opcode());
142 // Check if node is a multi-output node
143 // Assumption: Multi-output virtual nodes have 'Out' prefix
144 // TODO Fix this if the assumption changes
145 THROW_UNLESS(opcode_str.substr(opcode_str.length() - 3) == "Out",
146 "Invalid output detected in " + node->name());
149 py::dict("name"_a = circle_node->name(), "data"_a = numpyArray(output_tensor),
150 "quantparam"_a = quantparam(output_tensor),
151 "is_const"_a = circle_node->opcode() == luci::CircleOpcode::CIRCLECONST);
152 outputs.push_back(py_output);
157 // Note: Only returns 1 output
158 py::dict outputPyArray(const luci::CircleNode *node, luci_interpreter::Interpreter *interpreter)
160 assert(node != nullptr); // FIX_CALLER_UNLESS
161 assert(interpreter != nullptr); // FIX_CALLER_UNLESS
163 const auto tensor = interpreter->getTensor(node);
165 THROW_UNLESS(tensor != nullptr, "Null tensor detected in " + node->name());
167 auto py_output = py::dict("name"_a = node->name(), "data"_a = numpyArray(tensor),
168 "quantparam"_a = quantparam(tensor),
169 "is_const"_a = node->opcode() == luci::CircleOpcode::CIRCLECONST);
// Classifies a node by its opcode against the list of multi-output operators
// below.
// NOTE(review): presumably returns true for the listed opcodes and false for
// everything else; the return statements are not shown here - confirm
173 bool multi_out_node(const luci::CircleNode *node)
175 switch (node->opcode())
177 // TODO Update this list when new Op is added
178 // Tip: grep "public GraphBuilderMultiOutput" in luci/import
// Opcodes treated as multi-output operators
179 case luci::CircleOpcode::BIDIRECTIONAL_SEQUENCE_LSTM:
180 case luci::CircleOpcode::CUSTOM:
181 case luci::CircleOpcode::IF:
182 case luci::CircleOpcode::NON_MAX_SUPPRESSION_V4:
183 case luci::CircleOpcode::NON_MAX_SUPPRESSION_V5:
184 case luci::CircleOpcode::SPLIT:
185 case luci::CircleOpcode::SPLIT_V:
186 case luci::CircleOpcode::TOPK_V2:
187 case luci::CircleOpcode::UNIQUE:
188 case luci::CircleOpcode::UNPACK:
195 } // namespace dalgona