Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / dalgona / src / PostOperatorHook.h
index daf32f6..00c5d46 100644 (file)
@@ -42,6 +42,7 @@ class PostOperatorHook final : public luci::CircleNodeVisitor<void>
 // 2. inputs: input data (type: std::vector of numpy array)
 // 3. output: output data (type: numpy array)
 #define POST_OPERATOR_HOOK_PROLOGUE(OP_NAME)                \
+  assert(not multi_out_node(node));                         \
   if (!py::hasattr(_analysis, #OP_NAME "Post"))             \
   {                                                         \
     visit(loco::must_cast<const luci::CircleNode *>(node)); \
@@ -53,6 +54,7 @@ class PostOperatorHook final : public luci::CircleNodeVisitor<void>
 
 // Multi-output version of POST_OPERATOR_HOOK_PROLOGUE
 #define POST_OPERATOR_HOOK_PROLOGUE_MULTI_OUTS(OP_NAME)     \
+  assert(multi_out_node(node));                             \
   if (!py::hasattr(_analysis, #OP_NAME "Post"))             \
   {                                                         \
     visit(loco::must_cast<const luci::CircleNode *>(node)); \
@@ -81,7 +83,6 @@ public:
 
     py::object hook = _analysis.attr("DefaultOpPost");
     auto inputs = inputsPyArray(node, _interpreter);
-    auto output = outputPyArray(node, _interpreter);
 
     py::list input_list;
     for (uint32_t i = 0; i < inputs.size(); i++)
@@ -89,11 +90,26 @@ public:
       input_list.append(inputs[i]);
     }
 
+    py::list output_list;
+    if (multi_out_node(node))
+    {
+      auto outputs = outputsPyArray(node, _interpreter);
+      for (uint32_t i = 0; i < outputs.size(); i++)
+      {
+        output_list.append(outputs[i]);
+      }
+    }
+    else
+    {
+      auto output = outputPyArray(node, _interpreter);
+      output_list.append(output);
+    }
+
     pySafeCall(hook,
                node->name(),             // name
                toString(node->opcode()), // opcode
                input_list,               // list of inputs
-               output                    // output
+               output_list               // list of outputs
     );
   }