void loadIf(const Operator *op, ir::Graph &subg);
void loadLeakyRelu(const Operator *op, ir::Graph &subg);
void loadLogSoftmax(const Operator *op, ir::Graph &subg);
+ void loadDetectionPostProcess(const Operator *op, ir::Graph &subg);
void loadOneHot(const Operator *op, ir::Graph &subg);
void loadPack(const Operator *op, ir::Graph &subg);
void loadPool2D(const Operator *op, ir::Graph &subg, ir::operation::Pool2D::PoolType op_type);
}
template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadDetectionPostProcess(const Operator *op, ir::Graph &subg)
+{
+ const flexbuffers::Map &m =
+ flexbuffers::GetRoot(op->custom_options()->data(), op->custom_options()->size()).AsMap();
+
+ ir::operation::DetectionPostProcess::Param param;
+
+ param.max_detections = m["max_detections"].AsInt32();
+
+ // TODO fixme
+ param.max_classes_per_detection = m["max_classes_per_detection"].AsInt32();
+ if (m["detections_per_class"].IsNull())
+ param.max_boxes_per_class = 100;
+ else
+ param.max_boxes_per_class = m["detections_per_class"].AsInt32();
+
+ if (m["use_regular_nms"].IsNull())
+ param.do_fast_eval = true;
+ else
+ param.do_fast_eval = !m["use_regular_nms"].AsBool();
+
+ param.score_threshold = m["nms_score_threshold"].AsFloat();
+ param.iou_threshold = m["nms_iou_threshold"].AsFloat();
+
+ // TODO add num classes support
+ param.num_classes = m["num_classes"].AsInt32();
+
+ param.scale.y_scale = m["y_scale"].AsFloat();
+ param.scale.x_scale = m["x_scale"].AsFloat();
+ param.scale.h_scale = m["h_scale"].AsFloat();
+ param.scale.w_scale = m["w_scale"].AsFloat();
+
+ // TODO depends on input model framework
+ param.center_size_boxes = true;
+
+ loadOperationTo<ir::operation::DetectionPostProcess>(op, subg, param);
+}
+
+template <typename LoaderDomain>
void BaseLoader<LoaderDomain>::loadBatchMatMul(const Operator *op, ir::Graph &subg)
{
ir::operation::BatchMatMul::Param param;
BroadcastTo,
FusedBatchNorm,
StatelessRandomUniform,
- Erf
+ Erf,
+ DetectionPostProcess
};
// Mapping from custom op name string to BuiltinOP enum
{"BroadcastTo", BuiltinOP::BroadcastTo},
{"StatelessRandomUniform", BuiltinOP::StatelessRandomUniform},
{"Erf", BuiltinOP::Erf},
+ {"TFLite_Detection_PostProcess", BuiltinOP::DetectionPostProcess},
};
try
case BuiltinOP::Erf:
loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::ERF);
break;
+ case BuiltinOP::DetectionPostProcess:
+ loadDetectionPostProcess(op, subg);
+ break;
default:
throw std::runtime_error{
"Loader: Custom OP map is defined but operation loader function is not defined"};