Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / runtime / onert / frontend / base_loader / include / base_loader.h
index c444e73..6ba7ee9 100644 (file)
@@ -142,6 +142,7 @@ private:
   void loadIf(const Operator *op, ir::Graph &subg);
   void loadLeakyRelu(const Operator *op, ir::Graph &subg);
   void loadLogSoftmax(const Operator *op, ir::Graph &subg);
+  void loadDetectionPostProcess(const Operator *op, ir::Graph &subg);
   void loadOneHot(const Operator *op, ir::Graph &subg);
   void loadPack(const Operator *op, ir::Graph &subg);
   void loadPool2D(const Operator *op, ir::Graph &subg, ir::operation::Pool2D::PoolType op_type);
@@ -928,6 +929,45 @@ void BaseLoader<LoaderDomain>::loadGather(const Operator *op, ir::Graph &subg)
 }
 
 template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadDetectionPostProcess(const Operator *op, ir::Graph &subg)
+{
+  const flexbuffers::Map &m =
+    flexbuffers::GetRoot(op->custom_options()->data(), op->custom_options()->size()).AsMap();
+
+  ir::operation::DetectionPostProcess::Param param;
+
+  param.max_detections = m["max_detections"].AsInt32();
+
+  // TODO fixme
+  param.max_classes_per_detection = m["max_classes_per_detection"].AsInt32();
+  if (m["detections_per_class"].IsNull())
+    param.max_boxes_per_class = 100;
+  else
+    param.max_boxes_per_class = m["detections_per_class"].AsInt32();
+
+  if (m["use_regular_nms"].IsNull())
+    param.do_fast_eval = true;
+  else
+    param.do_fast_eval = !m["use_regular_nms"].AsBool();
+
+  param.score_threshold = m["nms_score_threshold"].AsFloat();
+  param.iou_threshold = m["nms_iou_threshold"].AsFloat();
+
+  // TODO add num classes support
+  param.num_classes = m["num_classes"].AsInt32();
+
+  param.scale.y_scale = m["y_scale"].AsFloat();
+  param.scale.x_scale = m["x_scale"].AsFloat();
+  param.scale.h_scale = m["h_scale"].AsFloat();
+  param.scale.w_scale = m["w_scale"].AsFloat();
+
+  // TODO depends on input model framework
+  param.center_size_boxes = true;
+
+  loadOperationTo<ir::operation::DetectionPostProcess>(op, subg, param);
+}
+
+template <typename LoaderDomain>
 void BaseLoader<LoaderDomain>::loadBatchMatMul(const Operator *op, ir::Graph &subg)
 {
   ir::operation::BatchMatMul::Param param;
@@ -997,7 +1037,8 @@ void BaseLoader<LoaderDomain>::loadCustom(const Operator *op, ir::Graph &subg)
     BroadcastTo,
     FusedBatchNorm,
     StatelessRandomUniform,
-    Erf
+    Erf,
+    DetectionPostProcess
   };
 
   // Mapping from custom op name string to BuiltinOP enum
@@ -1011,6 +1052,7 @@ void BaseLoader<LoaderDomain>::loadCustom(const Operator *op, ir::Graph &subg)
     {"BroadcastTo", BuiltinOP::BroadcastTo},
     {"StatelessRandomUniform", BuiltinOP::StatelessRandomUniform},
     {"Erf", BuiltinOP::Erf},
+    {"TFLite_Detection_PostProcess", BuiltinOP::DetectionPostProcess},
   };
 
   try
@@ -1046,6 +1088,9 @@ void BaseLoader<LoaderDomain>::loadCustom(const Operator *op, ir::Graph &subg)
       case BuiltinOP::Erf:
         loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::ERF);
         break;
+      case BuiltinOP::DetectionPostProcess:
+        loadDetectionPostProcess(op, subg);
+        break;
       default:
         throw std::runtime_error{
           "Loader: Custom OP map is defined but operation loader function is not defined"};