Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / dalgona / src / PostOperatorHook.h
1 /*
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __DALGONA_POST_OPERATOR_HOOK_H__
18 #define __DALGONA_POST_OPERATOR_HOOK_H__
19
#include "Utils.h"
#include "StringUtils.h"

#include <loco/IR/Node.h>
#include <luci_interpreter/Interpreter.h>
#include <luci/IR/CircleNodeVisitor.h>

#include <pybind11/embed.h>

#include <cassert>
#include <utility>
#include <vector>

namespace py = pybind11;
using namespace py::literals;
32
33 namespace dalgona
34 {
35
36 // Invoke a user-written Python hook after an operator is executed
37 class PostOperatorHook final : public luci::CircleNodeVisitor<void>
38 {
39
40 // This macro creates three variables used for post-operator hooks.
41 // 1. hook: Python function to be invoked (type: py::object)
42 // 2. inputs: input data (type: std::vector of numpy array)
43 // 3. output: output data (type: numpy array)
44 #define POST_OPERATOR_HOOK_PROLOGUE(OP_NAME)                \
45   assert(not multi_out_node(node));                         \
46   if (!py::hasattr(_analysis, #OP_NAME "Post"))             \
47   {                                                         \
48     visit(loco::must_cast<const luci::CircleNode *>(node)); \
49     return;                                                 \
50   }                                                         \
51   py::object hook = _analysis.attr(#OP_NAME "Post");        \
52   auto inputs = inputsPyArray(node, _interpreter);          \
53   auto output = outputPyArray(node, _interpreter);
54
55 // Multi-output version of POST_OPERATOR_HOOK_PROLOGUE
56 #define POST_OPERATOR_HOOK_PROLOGUE_MULTI_OUTS(OP_NAME)     \
57   assert(multi_out_node(node));                             \
58   if (!py::hasattr(_analysis, #OP_NAME "Post"))             \
59   {                                                         \
60     visit(loco::must_cast<const luci::CircleNode *>(node)); \
61     return;                                                 \
62   }                                                         \
63   py::object hook = _analysis.attr(#OP_NAME "Post");        \
64   auto inputs = inputsPyArray(node, _interpreter);          \
65   auto outputs = outputsPyArray(node, _interpreter);
66
67 private:
68   py::object _analysis;
69   luci_interpreter::Interpreter *_interpreter{nullptr};
70
71 public:
72   explicit PostOperatorHook(py::object analysis, luci_interpreter::Interpreter *interpreter)
73     : _analysis(analysis), _interpreter(interpreter)
74   {
75     // Do nothing
76   }
77
78   // default
79   void visit(const luci::CircleNode *node)
80   {
81     if (not py::hasattr(_analysis, "DefaultOpPost"))
82       return;
83
84     py::object hook = _analysis.attr("DefaultOpPost");
85     auto inputs = inputsPyArray(node, _interpreter);
86
87     py::list input_list;
88     for (uint32_t i = 0; i < inputs.size(); i++)
89     {
90       input_list.append(inputs[i]);
91     }
92
93     py::list output_list;
94     if (multi_out_node(node))
95     {
96       auto outputs = outputsPyArray(node, _interpreter);
97       for (uint32_t i = 0; i < outputs.size(); i++)
98       {
99         output_list.append(outputs[i]);
100       }
101     }
102     else
103     {
104       auto output = outputPyArray(node, _interpreter);
105       output_list.append(output);
106     }
107
108     pySafeCall(hook,
109                node->name(),             // name
110                toString(node->opcode()), // opcode
111                input_list,               // list of inputs
112                output_list               // list of outputs
113     );
114   }
115
116   void visit(const luci::CircleConv2D *node)
117   {
118     POST_OPERATOR_HOOK_PROLOGUE(Conv2D)
119
120     auto padding = node->padding();
121     auto stride = node->stride();
122     auto dilation = node->dilation();
123
124     auto py_stride = py::dict("w"_a = stride->w(), "h"_a = stride->h());
125     auto py_dilation = py::dict("w"_a = dilation->w(), "h"_a = dilation->h());
126
127     auto fused_act = node->fusedActivationFunction();
128
129     pySafeCall(hook,
130                node->name(),                                      // name
131                inputs[0],                                         // input
132                inputs[1],                                         // filter
133                inputs[2],                                         // bias
134                padding == luci::Padding::SAME ? "SAME" : "VALID", // padding
135                py_stride,                                         // stride
136                py_dilation,                                       // dilation
137                output,                                            // output
138                toString(fused_act)                                // fused activation
139     );
140   }
141
142   void visit(const luci::CircleDepthwiseConv2D *node)
143   {
144     POST_OPERATOR_HOOK_PROLOGUE(DepthwiseConv2D)
145
146     auto padding = node->padding();
147     auto stride = node->stride();
148     auto dilation = node->dilation();
149     auto depthMultiplier = node->depthMultiplier();
150
151     auto py_stride = py::dict("w"_a = stride->w(), "h"_a = stride->h());
152     auto py_dilation = py::dict("w"_a = dilation->w(), "h"_a = dilation->h());
153
154     auto fused_act = node->fusedActivationFunction();
155
156     pySafeCall(hook,
157                node->name(),                                      // name
158                inputs[0],                                         // input
159                inputs[1],                                         // filter
160                inputs[2],                                         // bias
161                padding == luci::Padding::SAME ? "SAME" : "VALID", // padding
162                py_stride,                                         // stride
163                depthMultiplier,                                   // depthMultiplier
164                py_dilation,                                       // dilation
165                output,                                            // output
166                toString(fused_act)                                // fused activation
167     );
168   }
169
170   void visit(const luci::CircleAdd *node)
171   {
172     POST_OPERATOR_HOOK_PROLOGUE(Add)
173
174     auto fused_act = node->fusedActivationFunction();
175
176     pySafeCall(hook,
177                node->name(),       // name
178                inputs[0],          // x
179                inputs[1],          // y
180                output,             // output
181                toString(fused_act) // fused activation
182     );
183   }
184
185   void visit(const luci::CircleFullyConnected *node)
186   {
187     POST_OPERATOR_HOOK_PROLOGUE(FullyConnected)
188
189     auto fused_act = node->fusedActivationFunction();
190
191     pySafeCall(hook,
192                node->name(),       // name
193                inputs[0],          // input
194                inputs[1],          // weights
195                inputs[2],          // bias
196                output,             // output
197                toString(fused_act) // fused activation
198     );
199   }
200
201   void visit(const luci::CircleTransposeConv *node)
202   {
203     POST_OPERATOR_HOOK_PROLOGUE(TransposeConv)
204
205     auto padding = node->padding();
206     auto stride = node->stride();
207
208     auto py_stride = py::dict("w"_a = stride->w(), "h"_a = stride->h());
209
210     pySafeCall(hook,
211                node->name(),                                      // name
212                inputs[2],                                         // input
213                inputs[1],                                         // filter
214                inputs[0],                                         // output shape
215                inputs.size() == 4 ? inputs[3] : none(),           // bias
216                padding == luci::Padding::SAME ? "SAME" : "VALID", // padding
217                py_stride,                                         // stride
218                output                                             // output
219     );
220   }
221
222   void visit(const luci::CircleInstanceNorm *node)
223   {
224     POST_OPERATOR_HOOK_PROLOGUE(InstanceNorm)
225
226     auto epsilon = node->epsilon();
227
228     auto fused_act = node->fusedActivationFunction();
229
230     pySafeCall(hook,
231                node->name(),       // name
232                inputs[0],          // input
233                inputs[1],          // gamma
234                inputs[2],          // beta
235                epsilon,            // epsilon
236                output,             // output
237                toString(fused_act) // fused activation
238     );
239   }
240
241   void visit(const luci::CircleSplit *node)
242   {
243     POST_OPERATOR_HOOK_PROLOGUE_MULTI_OUTS(Split)
244
245     py::list output_list;
246     for (uint32_t i = 0; i < outputs.size(); i++)
247     {
248       output_list.append(outputs[i]);
249     }
250
251     auto num_split = node->num_split();
252
253     pySafeCall(hook,
254                node->name(), // name
255                inputs[0],    // split_dim
256                inputs[1],    // input
257                num_split,    // num_split
258                output_list   // list of outputs
259     );
260   }
261
262 #undef POST_OPERATOR_HOOK_PROLOGUE_MULTI_OUTS
263 };
264
265 } // namespace dalgona
266
267 #endif // __DALGONA_POST_OPERATOR_HOOK_H__