Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / compiler / StaticShapeInferer.cc
index 25747d9..a25b326 100644 (file)
@@ -99,10 +99,10 @@ void StaticShapeInferer::infer()
   }
 }
 
-bool StaticShapeInferer::checkDynamicInput(const ir::Operation &op)
+bool StaticShapeInferer::checkDynamicInput(const ir::IOperation &op)
 {
   const auto &operands = _lowered_subg->graph().operands();
-  for (auto input_idx : op.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED)
+  for (auto &&input_idx : op.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED)
   {
     if (operands.at(input_idx).info().isDynamic())
     {
@@ -113,10 +113,10 @@ bool StaticShapeInferer::checkDynamicInput(const ir::Operation &op)
   return false;
 }
 
-bool StaticShapeInferer::checkDynamicOutput(const ir::Operation &op)
+bool StaticShapeInferer::checkDynamicOutput(const ir::IOperation &op)
 {
   auto &operands = _lowered_subg->graph().operands();
-  for (auto output_idx : op.getOutputs() | ir::Remove::UNDEFINED)
+  for (auto &&output_idx : op.getOutputs() | ir::Remove::UNDEFINED)
   {
     if (operands.at(output_idx).info().isDynamic())
     {
@@ -126,10 +126,10 @@ bool StaticShapeInferer::checkDynamicOutput(const ir::Operation &op)
   return false;
 }
 
-void StaticShapeInferer::setDynamicOutput(const ir::Operation &op)
+void StaticShapeInferer::setDynamicOutput(const ir::IOperation &op)
 {
   auto &operands = _lowered_subg->graph().operands();
-  for (auto output_idx : op.getOutputs() | ir::Remove::UNDEFINED)
+  for (auto &&output_idx : op.getOutputs() | ir::Remove::UNDEFINED)
   {
     operands.at(output_idx).info().setDynamic();
   }
@@ -192,7 +192,7 @@ void StaticShapeInferer::dump()
 
 std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>>
 StaticShapeInferer::createStaticShapeInferers(
-  const std::unordered_map<ir::SubgraphIndex, std::unique_ptr<LoweredGraph>> &lowered_subgs)
+  const std::unordered_map<ir::SubgraphIndex, ILoweredGraph *> &lowered_subgs)
 {
   // Allocate StaticShapeInferer per each subgraph
   std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>> inferers;
@@ -200,7 +200,7 @@ StaticShapeInferer::createStaticShapeInferers(
   {
     const auto &subg_index = pair.first;
     auto &lowered_subg = pair.second;
-    inferers[subg_index] = std::make_unique<StaticShapeInferer>(lowered_subg.get());
+    inferers[subg_index] = std::make_unique<StaticShapeInferer>(lowered_subg);
   }
 
   // Append observers in all StaticShapeInferers
@@ -211,7 +211,7 @@ StaticShapeInferer::createStaticShapeInferers(
 
     // TODO: Change this iteration for all to controlflow iteration
     lowered_subg->graph().operations().iterate(
-      [&](const ir::OperationIndex &, const ir::Operation &op) {
+      [&](const ir::OperationIndex &, const ir::IOperation &op) {
         // A Function to append child inferers. These make it possible for a StaticShapeInferer to
         // call StaticShapeInferes of child subgraphs recursively
         auto appendChildInferer = [&](const ir::SubgraphIndex &child_subg_idx) {
@@ -251,7 +251,9 @@ StaticShapeInferer::createStaticShapeInferers(
         // Append Observers in a StaticShapeInferer
         if (op.opcode() == ir::OpCode::If)
         {
-          const auto &if_op = nnfw::misc::polymorphic_downcast<const ir::operation::If &>(op);
+          // TODO Remove dynamic_cast
+          // An virtual base class cannot be downcasted by static_cast
+          const auto &if_op = dynamic_cast<const ir::operation::If &>(op);
 
           appendChildInferer(if_op.param().then_subg_index);
           appendChildInferer(if_op.param().else_subg_index);
@@ -263,7 +265,8 @@ StaticShapeInferer::createStaticShapeInferers(
         }
         else if (op.opcode() == ir::OpCode::While)
         {
-          const auto &while_op = nnfw::misc::polymorphic_downcast<const ir::operation::While &>(op);
+          // TODO Remove dynamic_cast
+          const auto &while_op = dynamic_cast<const ir::operation::While &>(op);
 
           appendChildInferer(while_op.param().cond_subg_index);
           appendChildInferer(while_op.param().body_subg_index);
@@ -602,6 +605,13 @@ void StaticShapeInferer::visit(const ir::operation::L2Normalization &op)
   handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::L2Normalization::Input::INPUT));
 }
 
+void StaticShapeInferer::visit(const ir::operation::Loss &)
+{
+  // TODO Consider SparseCategoricalCrossentropy case
+
+  // TODO Consider output shape in case of reduction option
+}
+
 void StaticShapeInferer::visit(const ir::operation::LSTM &op)
 {
   auto &operands = _lowered_subg->graph().operands();
@@ -1119,7 +1129,7 @@ void StaticShapeInferer::visit(const ir::operation::Split &op)
   auto outputs = op.getOutputs();
   if (!axis.isConstant())
   {
-    for (auto output_idx : outputs)
+    for (auto &&output_idx : outputs)
     {
       ir::Operand &output = operands.at(output_idx);
       output.info().setDynamic();
@@ -1137,7 +1147,7 @@ void StaticShapeInferer::visit(const ir::operation::Split &op)
 
   ir::Shape new_shape =
     shape_inference::inferSplitShape(input.info().shape(), axis_value, num_splits);
-  for (auto output_idx : outputs)
+  for (auto &&output_idx : outputs)
   {
     ir::Operand &output = operands.at(output_idx);
     output.info().shape(new_shape);