#include <unordered_set>
#include <vector>
-#include "ir/Graph.h"
#include "ir/Index.h"
#include "ir/Model.h"
/**
* @brief Get model input info
*/
- OperandInfo &inputInfo(uint32_t index) const
+ const OperandInfo &inputInfo(uint32_t index) const
{
if (_models.size() == 1)
{
/**
* @brief Get model output info
*/
- OperandInfo &outputInfo(uint32_t index) const
+ const OperandInfo &outputInfo(uint32_t index) const
{
if (_models.size() == 1)
{
return graph->operands().at(operand_index).info();
}
+ void changeInputShape(uint32_t index, const ir::Shape &new_shape)
+ {
+ if (_models.size() == 1)
+ {
+ auto graph = primary_model()->primary_subgraph();
+ auto const operand_index = graph->getInputs().at(index);
+ graph->changeShape(operand_index, new_shape);
+ return;
+ }
+
+ auto const &desc = input(index);
+ auto graph = model(std::get<ModelIndex>(desc))->primary_subgraph();
+ auto const operand_index = graph->getInputs().at(std::get<IOIndex>(desc).value());
+ graph->changeShape(operand_index, new_shape);
+ }
+
+ /**
+ * @brief Replace model
+ *
+ * @param[in] model Model to be replaced
+ *
+ * TODO: Support multiple models
+ */
+ void replaceModel(std::shared_ptr<Model> model) { _models[ModelIndex{0}] = model; }
+
// TODO: Add iterate() or getter for edges
private: