Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / runtime / onert / core / include / ir / Model.h
index c3c0d87..950d288 100644 (file)
 #include <memory>
 #include <unordered_map>
 
+#include "ir/IGraph.h"
 #include "ir/Index.h"
 #include "util/ObjectManager.h"
 
 namespace onert
 {
-namespace ir
+namespace backend
+{
+namespace custom
 {
+class IKernelBuilder;
+} // namespace custom
+} // namespace backend
+} // namespace onert
 
-class Graph;
+namespace onert
+{
+namespace ir
+{
 
 class Model
 {
@@ -47,7 +57,7 @@ public:
    * @param[in] index Index of subgraph to be pushed
    * @return Created
    */
-  void push(SubgraphIndex index, const std::shared_ptr<Graph> &subg) { _subgraphs[index] = subg; }
+  void push(SubgraphIndex index, const std::shared_ptr<IGraph> &subg) { _subgraphs[index] = subg; }
 
   /**
    * @brief Remove the subgraph that is associated with the given index
@@ -61,9 +71,9 @@ public:
    * @brief Get the subgraph that is associated with the given index
    *
    * @param[in] index Index of the subgraph to be returned
-   * @return Graph
+   * @return IGraph
    */
-  const std::shared_ptr<Graph> &at(const SubgraphIndex &index) const
+  const std::shared_ptr<IGraph> &at(const SubgraphIndex &index) const
   {
     return _subgraphs.at(index);
   }
@@ -71,9 +81,9 @@ public:
    * @brief Get the subgraph that is associated with the given index
    *
    * @param[in] index Index of the subgraph to be returned
-   * @return Graph
+   * @return IGraph
    */
-  std::shared_ptr<Graph> &at(const SubgraphIndex &index) { return _subgraphs.at(index); }
+  std::shared_ptr<IGraph> &at(const SubgraphIndex &index) { return _subgraphs.at(index); }
 
   /**
    * @brief Get the subgraph that is associated with the given index
@@ -93,7 +103,7 @@ public:
    * @param[in] fn Function to be run for every container entry
    * @return N/A
    */
-  void iterate(const std::function<void(const SubgraphIndex &, const Graph &)> &fn) const
+  void iterate(const std::function<void(const SubgraphIndex &, const IGraph &)> &fn) const
   {
     for (const auto &e : _subgraphs)
     {
@@ -107,7 +117,7 @@ public:
    * @param[in] fn Function to be run for every container entry
    * @return N/A
    */
-  void iterate(const std::function<void(const SubgraphIndex &, Graph &)> &fn)
+  void iterate(const std::function<void(const SubgraphIndex &, IGraph &)> &fn)
   {
     for (const auto &e : _subgraphs)
     {
@@ -125,12 +135,46 @@ public:
   /**
    * @brief Return the primary subgraph
    *
-   * @return std::shared_ptr<Graph> Primary subgraph
+   * @return std::shared_ptr<IGraph> Primary subgraph
    */
-  std::shared_ptr<Graph> primary_subgraph() const { return _subgraphs.at(SubgraphIndex{0}); }
+  std::shared_ptr<IGraph> primary_subgraph() const { return _subgraphs.at(SubgraphIndex{0}); }
+
+  /**
+   * @brief Return whether the model has only typename Graph
+   *
+   * @tparam Graph Type that inherits from IGraph
+   *
+   * @return true if the model has only typename Graph, otherwise false
+   */
+  template <typename Graph, std::enable_if_t<std::is_base_of<IGraph, Graph>::value, bool> = true>
+  bool hasOnly()
+  {
+    for (const auto &e : _subgraphs)
+    {
+      if (std::dynamic_pointer_cast<Graph>(e.second) == nullptr)
+        return false;
+    }
+    return true;
+  }
+
+private:
+  std::unordered_map<SubgraphIndex, std::shared_ptr<IGraph>> _subgraphs;
+
+  // Custom operations support
+public:
+  void
+  bindKernelBuilder(const std::shared_ptr<onert::backend::custom::IKernelBuilder> &kernel_builder)
+  {
+    _kernel_builder = kernel_builder;
+  }
+
+  const std::shared_ptr<backend::custom::IKernelBuilder> &getKernelBuilder() const
+  {
+    return _kernel_builder;
+  }
 
 private:
-  std::unordered_map<SubgraphIndex, std::shared_ptr<Graph>> _subgraphs;
+  std::shared_ptr<backend::custom::IKernelBuilder> _kernel_builder;
 };
 
 } // namespace ir