Imported Upstream version 1.19.0
[platform/core/ml/nnfw.git] / compiler / luci / partition / src / PartitionPGroups.cpp
index e0b4e8e..0080873 100644 (file)
@@ -35,6 +35,17 @@ class IsVirtualNode final : public luci::CircleNodeVisitor<bool>
 public:
   bool visit(const luci::CircleInput *) final { return true; }
   bool visit(const luci::CircleOutput *) final { return true; }
+  // For multiple outputs
+  bool visit(const luci::CircleCustomOut *) final { return true; }
+  bool visit(const luci::CircleIfOut *) final { return true; }
+  bool visit(const luci::CircleNonMaxSuppressionV4Out *) final { return true; }
+  bool visit(const luci::CircleNonMaxSuppressionV5Out *) final { return true; }
+  bool visit(const luci::CircleSplitOut *) final { return true; }
+  bool visit(const luci::CircleSplitVOut *) final { return true; }
+  bool visit(const luci::CircleTopKV2Out *) final { return true; }
+  bool visit(const luci::CircleUniqueOut *) final { return true; }
+  bool visit(const luci::CircleUnpackOut *) final { return true; }
+  bool visit(const luci::CircleWhileOut *) final { return true; }
   // TODO add all virtual nodes
 
   // default is false
@@ -58,6 +69,91 @@ bool check_allocate_partition(const luci::CircleNode *node)
   return true;
 }
 
+class FindGroupToFollow final : public luci::CircleNodeVisitor<const std::string &>
+{
+public:
+  FindGroupToFollow(const luci::PartitionTable &partition, luci::PGroups *pgroups)
+    : _partition(partition), _pgroups(pgroups)
+  {
+    // NOTHING TODO
+  }
+
+private:
+  const std::string &groupof(const luci::CircleNode *input) const
+  {
+    auto group = _pgroups->node2group[input];
+    assert(not group.empty());
+    if (group.empty())
+      return _partition.default_group;
+    return _pgroups->node2group[input];
+  }
+
+public:
+#define IMPLEMENT(CLASS)                                             \
+  const std::string &visit(const luci::CLASS *node) final            \
+  {                                                                  \
+    auto input = loco::must_cast<luci::CircleNode *>(node->input()); \
+    return groupof(input);                                           \
+  }
+
+  IMPLEMENT(CircleCustomOut);
+  IMPLEMENT(CircleIfOut);
+  IMPLEMENT(CircleNonMaxSuppressionV4Out);
+  IMPLEMENT(CircleNonMaxSuppressionV5Out);
+  IMPLEMENT(CircleSplitOut);
+  IMPLEMENT(CircleSplitVOut);
+  IMPLEMENT(CircleTopKV2Out);
+  IMPLEMENT(CircleUniqueOut);
+  IMPLEMENT(CircleUnpackOut);
+  IMPLEMENT(CircleWhileOut);
+
+#undef IMPLEMENT
+
+  // return empty for nothing to do
+  const std::string &visit(const luci::CircleNode *) final { return _empty_str; }
+
+private:
+  const luci::PartitionTable &_partition;
+  luci::PGroups *_pgroups = nullptr;
+  std::string _empty_str;
+};
+
+} // namespace
+
+namespace
+{
+
+void append(luci::CircleNode *node, luci::PGroups *pgroups, const std::string &group, uint32_t idx)
+{
+  auto pgroup = std::make_unique<luci::PGroup>();
+  pgroup->group = group;
+  pgroup->id = idx + 1;
+
+  auto pnode = std::make_unique<luci::PNode>();
+  pnode->node = node;
+  pnode->group = group;
+  pnode->pgroup = pgroup.get();
+
+  pgroup->pnodes.push_back(std::move(pnode));
+
+  // Set input of PGroup
+  for (uint32_t in = 0; in < node->arity(); ++in)
+  {
+    auto input = loco::must_cast<luci::CircleNode *>(node->arg(in));
+    // this input maybe CircleInput in source graph
+    // --> not confident this is safe
+    pgroup->inputs.push_back(input);
+  }
+  // Set output of PGroup: node itself or multiple virtual outputs
+  // TODO support multiple virtual outputs
+  pgroup->outputs.push_back(node);
+
+  pgroups->node2group[node] = group;
+  pgroups->id2pgroup[pgroup->id] = pgroup.get();
+
+  pgroups->pgroups.push_back(std::move(pgroup));
+}
+
 } // namespace
 
 namespace luci
@@ -120,6 +216,8 @@ std::unique_ptr<luci::PGroups> produce_pgroups(const luci::Module *source,
       INFO(l) << "Op: " << node->name() << ": " << opcodename << ", " << node << ", " << group
               << std::endl;
 
+      append(node, pgroups.get(), group, idx);
+#if 0
       auto pgroup = std::make_unique<luci::PGroup>();
       pgroup->group = group;
       pgroup->id = idx + 1;
@@ -147,6 +245,7 @@ std::unique_ptr<luci::PGroups> produce_pgroups(const luci::Module *source,
       pgroups->id2pgroup[pgroup->id] = pgroup.get();
 
       pgroups->pgroups.push_back(std::move(pgroup));
+#endif
     }
     else
     {
@@ -156,6 +255,22 @@ std::unique_ptr<luci::PGroups> produce_pgroups(const luci::Module *source,
     }
   }
 
+  // handle for virtual nodes like multiple outputs
+  // these nodes should follow group of the input
+  for (uint32_t idx = 0; idx < nodes->size(); ++idx)
+  {
+    auto node = loco::must_cast<luci::CircleNode *>(nodes->at(idx));
+
+    // for virtual nodes like CircleUnpackOut should follow it's input (owner)
+    // or just set to default
+    FindGroupToFollow query(partition, pgroups.get());
+    const auto &group = node->accept(&query);
+    if (not group.empty())
+    {
+      append(node, pgroups.get(), group, idx);
+    }
+  }
+
   return std::move(pgroups);
 }