2 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #ifndef __DALGONA_POST_OPERATOR_HOOK_H__
18 #define __DALGONA_POST_OPERATOR_HOOK_H__
21 #include "StringUtils.h"
23 #include <loco/IR/Node.h>
24 #include <luci_interpreter/Interpreter.h>
25 #include <luci/IR/CircleNodeVisitor.h>
27 #include <pybind11/embed.h>
30 namespace py = pybind11;
31 using namespace py::literals;
36 // Invoke a user-written Python hook after an operator is executed
37 class PostOperatorHook final : public luci::CircleNodeVisitor<void>
40 // This macro declares three local variables used by post-operator hooks:
41 // 1. hook: Python function to be invoked (type: py::object)
42 // 2. inputs: input data (type: std::vector of numpy arrays)
43 // 3. output: output data (type: numpy array)
44 #define POST_OPERATOR_HOOK_PROLOGUE(OP_NAME) \
45 assert(not multi_out_node(node)); \
46 if (!py::hasattr(_analysis, #OP_NAME "Post")) \
48 visit(loco::must_cast<const luci::CircleNode *>(node)); \
51 py::object hook = _analysis.attr(#OP_NAME "Post"); \
52 auto inputs = inputsPyArray(node, _interpreter); \
53 auto output = outputPyArray(node, _interpreter);
55 // Multi-output version of POST_OPERATOR_HOOK_PROLOGUE: declares `outputs` (a sequence of numpy arrays) instead of `output`
56 #define POST_OPERATOR_HOOK_PROLOGUE_MULTI_OUTS(OP_NAME) \
57 assert(multi_out_node(node)); \
58 if (!py::hasattr(_analysis, #OP_NAME "Post")) \
60 visit(loco::must_cast<const luci::CircleNode *>(node)); \
63 py::object hook = _analysis.attr(#OP_NAME "Post"); \
64 auto inputs = inputsPyArray(node, _interpreter); \
65 auto outputs = outputsPyArray(node, _interpreter);
69 luci_interpreter::Interpreter *_interpreter{nullptr};
72 explicit PostOperatorHook(py::object analysis, luci_interpreter::Interpreter *interpreter)
73 : _analysis(analysis), _interpreter(interpreter)
79 void visit(const luci::CircleNode *node)
81 if (not py::hasattr(_analysis, "DefaultOpPost"))
84 py::object hook = _analysis.attr("DefaultOpPost");
85 auto inputs = inputsPyArray(node, _interpreter);
88 for (uint32_t i = 0; i < inputs.size(); i++)
90 input_list.append(inputs[i]);
94 if (multi_out_node(node))
96 auto outputs = outputsPyArray(node, _interpreter);
97 for (uint32_t i = 0; i < outputs.size(); i++)
99 output_list.append(outputs[i]);
104 auto output = outputPyArray(node, _interpreter);
105 output_list.append(output);
109 node->name(), // name
110 toString(node->opcode()), // opcode
111 input_list, // list of inputs
112 output_list // list of outputs
116 void visit(const luci::CircleConv2D *node)
118 POST_OPERATOR_HOOK_PROLOGUE(Conv2D)
120 auto padding = node->padding();
121 auto stride = node->stride();
122 auto dilation = node->dilation();
124 auto py_stride = py::dict("w"_a = stride->w(), "h"_a = stride->h());
125 auto py_dilation = py::dict("w"_a = dilation->w(), "h"_a = dilation->h());
127 auto fused_act = node->fusedActivationFunction();
130 node->name(), // name
134 padding == luci::Padding::SAME ? "SAME" : "VALID", // padding
136 py_dilation, // dilation
138 toString(fused_act) // fused activation
142 void visit(const luci::CircleDepthwiseConv2D *node)
144 POST_OPERATOR_HOOK_PROLOGUE(DepthwiseConv2D)
146 auto padding = node->padding();
147 auto stride = node->stride();
148 auto dilation = node->dilation();
149 auto depthMultiplier = node->depthMultiplier();
151 auto py_stride = py::dict("w"_a = stride->w(), "h"_a = stride->h());
152 auto py_dilation = py::dict("w"_a = dilation->w(), "h"_a = dilation->h());
154 auto fused_act = node->fusedActivationFunction();
157 node->name(), // name
161 padding == luci::Padding::SAME ? "SAME" : "VALID", // padding
163 depthMultiplier, // depthMultiplier
164 py_dilation, // dilation
166 toString(fused_act) // fused activation
170 void visit(const luci::CircleAdd *node)
172 POST_OPERATOR_HOOK_PROLOGUE(Add)
174 auto fused_act = node->fusedActivationFunction();
177 node->name(), // name
181 toString(fused_act) // fused activation
185 void visit(const luci::CircleFullyConnected *node)
187 POST_OPERATOR_HOOK_PROLOGUE(FullyConnected)
189 auto fused_act = node->fusedActivationFunction();
192 node->name(), // name
194 inputs[1], // weights
197 toString(fused_act) // fused activation
201 void visit(const luci::CircleTransposeConv *node)
203 POST_OPERATOR_HOOK_PROLOGUE(TransposeConv)
205 auto padding = node->padding();
206 auto stride = node->stride();
208 auto py_stride = py::dict("w"_a = stride->w(), "h"_a = stride->h());
211 node->name(), // name
214 inputs[0], // output shape
215 inputs.size() == 4 ? inputs[3] : none(), // bias
216 padding == luci::Padding::SAME ? "SAME" : "VALID", // padding
222 void visit(const luci::CircleInstanceNorm *node)
224 POST_OPERATOR_HOOK_PROLOGUE(InstanceNorm)
226 auto epsilon = node->epsilon();
228 auto fused_act = node->fusedActivationFunction();
231 node->name(), // name
237 toString(fused_act) // fused activation
241 void visit(const luci::CircleSplit *node)
243 POST_OPERATOR_HOOK_PROLOGUE_MULTI_OUTS(Split)
245 py::list output_list;
246 for (uint32_t i = 0; i < outputs.size(); i++)
248 output_list.append(outputs[i]);
251 auto num_split = node->num_split();
254 node->name(), // name
255 inputs[0], // split_dim
257 num_split, // num_split
258 output_list // list of outputs
262 #undef POST_OPERATOR_HOOK_PROLOGUE_MULTI_OUTS
265 } // namespace dalgona
267 #endif // __DALGONA_POST_OPERATOR_HOOK_H__