Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / compiler / StaticShapeInferer.cc
index 5849a98..f2fee2c 100644 (file)
@@ -1302,6 +1302,30 @@ void StaticShapeInferer::visit(const ir::operation::While &op)
   }
 }
 
+void StaticShapeInferer::visit(const ir::operation::DetectionPostProcess &op)
+{
+  // TODO: NMS supports very limited input/output size.
+  ir::operation::DetectionPostProcess::Param param = op.param();
+
+  const int num_detected_boxes = param.max_detections * param.max_classes_per_detection;
+
+  const auto output_idx1 = op.getOutputs().at(0);
+  auto &output1 = _operands.at(output_idx1);
+  output1.info().shape({1, num_detected_boxes, 4});
+
+  const auto output_idx2 = op.getOutputs().at(1);
+  auto &output2 = _operands.at(output_idx2);
+  output2.info().shape({1, num_detected_boxes});
+
+  const auto output_idx3 = op.getOutputs().at(2);
+  auto &output3 = _operands.at(output_idx3);
+  output3.info().shape({1, num_detected_boxes});
+
+  const auto output_idx4 = op.getOutputs().at(3);
+  auto &output4 = _operands.at(output_idx4);
+  output4.info().shape({1});
+}
+
 } // namespace compiler
 
 } // namespace onert