// 2. inputs: input data (type: std::vector of numpy array)
// 3. output: output data (type: numpy array)
#define POST_OPERATOR_HOOK_PROLOGUE(OP_NAME) \
+ assert(not multi_out_node(node)); \
if (!py::hasattr(_analysis, #OP_NAME "Post")) \
{ \
visit(loco::must_cast<const luci::CircleNode *>(node)); \
// Multi-output version of POST_OPERATOR_HOOK_PROLOGUE
#define POST_OPERATOR_HOOK_PROLOGUE_MULTI_OUTS(OP_NAME) \
+ assert(multi_out_node(node)); \
if (!py::hasattr(_analysis, #OP_NAME "Post")) \
{ \
visit(loco::must_cast<const luci::CircleNode *>(node)); \
py::object hook = _analysis.attr("DefaultOpPost");
auto inputs = inputsPyArray(node, _interpreter);
- auto output = outputPyArray(node, _interpreter);
py::list input_list;
for (uint32_t i = 0; i < inputs.size(); i++)
input_list.append(inputs[i]);
}
+ py::list output_list;
+ if (multi_out_node(node))
+ {
+ auto outputs = outputsPyArray(node, _interpreter);
+ for (uint32_t i = 0; i < outputs.size(); i++)
+ {
+ output_list.append(outputs[i]);
+ }
+ }
+ else
+ {
+ auto output = outputPyArray(node, _interpreter);
+ output_list.append(output);
+ }
+
pySafeCall(hook,
node->name(), // name
toString(node->opcode()), // opcode
input_list, // list of inputs
- output // output
+ output_list // list of outputs
);
}