Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / runtime / onert / core / include / ir / NNPkg.h
index b23745d..5df58bd 100644 (file)
@@ -21,7 +21,6 @@
 #include <unordered_set>
 #include <vector>
 
-#include "ir/Graph.h"
 #include "ir/Index.h"
 #include "ir/Model.h"
 
@@ -233,7 +232,7 @@ public:
   /**
    * @brief   Get model input info
    */
-  OperandInfo &inputInfo(uint32_t index) const
+  const OperandInfo &inputInfo(uint32_t index) const
   {
     if (_models.size() == 1)
     {
@@ -251,7 +250,7 @@ public:
   /**
    * @brief   Get model output info
    */
-  OperandInfo &outputInfo(uint32_t index) const
+  const OperandInfo &outputInfo(uint32_t index) const
   {
     if (_models.size() == 1)
     {
@@ -266,6 +265,31 @@ public:
     return graph->operands().at(operand_index).info();
   }
 
+  void changeInputShape(uint32_t index, const ir::Shape &new_shape)
+  {
+    if (_models.size() == 1)
+    {
+      auto graph = primary_model()->primary_subgraph();
+      auto const operand_index = graph->getInputs().at(index);
+      graph->changeShape(operand_index, new_shape);
+      return;
+    }
+
+    auto const &desc = input(index);
+    auto graph = model(std::get<ModelIndex>(desc))->primary_subgraph();
+    auto const operand_index = graph->getInputs().at(std::get<IOIndex>(desc).value());
+    graph->changeShape(operand_index, new_shape);
+  }
+
+  /**
+   * @brief Replace model
+   *
+   * @param[in] model Model to be replaced
+   *
+   * TODO:  Support multiple models
+   */
+  void replaceModel(std::shared_ptr<Model> model) { _models[ModelIndex{0}] = model; }
+
   // TODO: Add iterate() or getter for edges
 
 private: