1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <builders/ie_detection_output_layer.hpp>
6 #include <ie_cnn_layer_builder.h>
12 using namespace InferenceEngine;
14 Builder::DetectionOutputLayer::DetectionOutputLayer(const std::string& name): LayerDecorator("DetectionOutput", name) {
15 getLayer()->getOutputPorts().resize(1);
16 getLayer()->getInputPorts().resize(2);
17 setBackgroudLabelId(-1);
20 Builder::DetectionOutputLayer::DetectionOutputLayer(const Layer::Ptr& layer): LayerDecorator(layer) {
21 checkType("DetectionOutput");
24 Builder::DetectionOutputLayer::DetectionOutputLayer(const Layer::CPtr& layer): LayerDecorator(layer) {
25 checkType("DetectionOutput");
28 Builder::DetectionOutputLayer& Builder::DetectionOutputLayer::setName(const std::string& name) {
29 getLayer()->setName(name);
33 const std::vector<Port>& Builder::DetectionOutputLayer::getInputPorts() const {
34 return getLayer()->getInputPorts();
37 Builder::DetectionOutputLayer& Builder::DetectionOutputLayer::setInputPorts(const std::vector<Port> &ports) {
38 if (ports.size() != 3)
39 THROW_IE_EXCEPTION << "Incorrect number of inputs for DetectionOutput getLayer().";
40 getLayer()->getInputPorts() = ports;
44 const Port& Builder::DetectionOutputLayer::getOutputPort() const {
45 return getLayer()->getOutputPorts()[0];
48 Builder::DetectionOutputLayer& Builder::DetectionOutputLayer::setOutputPort(const Port &port) {
49 getLayer()->getOutputPorts()[0] = port;
53 size_t Builder::DetectionOutputLayer::getNumClasses() const {
54 return getLayer()->getParameters().at("num_classes");
56 Builder::DetectionOutputLayer& Builder::DetectionOutputLayer::setNumClasses(size_t num) {
57 getLayer()->getParameters()["num_classes"] = num;
60 int Builder::DetectionOutputLayer::getBackgroudLabelId() const {
61 return getLayer()->getParameters().at("background_label_id");
63 Builder::DetectionOutputLayer& Builder::DetectionOutputLayer::setBackgroudLabelId(int labelId) {
64 getLayer()->getParameters()["background_label_id"] = labelId;
67 int Builder::DetectionOutputLayer::getTopK() const {
68 return getLayer()->getParameters().at("top_k");
70 Builder::DetectionOutputLayer& Builder::DetectionOutputLayer::setTopK(int topK) {
71 getLayer()->getParameters()["top_k"] = topK;
74 int Builder::DetectionOutputLayer::getKeepTopK() const {
75 return getLayer()->getParameters().at("keep_top_k");
77 Builder::DetectionOutputLayer& Builder::DetectionOutputLayer::setKeepTopK(int topK) {
78 getLayer()->getParameters()["keep_top_k"] = topK;
81 int Builder::DetectionOutputLayer::getNumOrientClasses() const {
82 return getLayer()->getParameters().at("num_orient_classes");
84 Builder::DetectionOutputLayer& Builder::DetectionOutputLayer::setNumOrientClasses(int numClasses) {
85 getLayer()->getParameters()["num_orient_classes"] = numClasses;
88 std::string Builder::DetectionOutputLayer::getCodeType() const {
89 return getLayer()->getParameters().at("code_type");
91 Builder::DetectionOutputLayer& Builder::DetectionOutputLayer::setCodeType(std::string type) {
92 getLayer()->getParameters()["code_type"] = type;
95 int Builder::DetectionOutputLayer::getInterpolateOrientation() const {
96 return getLayer()->getParameters().at("interpolate_orientation");
98 Builder::DetectionOutputLayer& Builder::DetectionOutputLayer::setInterpolateOrientation(int orient) {
99 getLayer()->getParameters()["interpolate_orientation"] = orient;
102 float Builder::DetectionOutputLayer::getNMSThreshold() const {
103 return getLayer()->getParameters().at("nms_threshold");
105 Builder::DetectionOutputLayer& Builder::DetectionOutputLayer::setNMSThreshold(float threshold) {
106 getLayer()->getParameters()["nms_threshold"] = threshold;
109 float Builder::DetectionOutputLayer::getConfidenceThreshold() const {
110 return getLayer()->getParameters().at("confidence_threshold");
112 Builder::DetectionOutputLayer& Builder::DetectionOutputLayer::setConfidenceThreshold(float threshold) {
113 getLayer()->getParameters()["confidence_threshold"] = threshold;
116 bool Builder::DetectionOutputLayer::getShareLocation() const {
117 return getLayer()->getParameters().at("share_location");
119 Builder::DetectionOutputLayer& Builder::DetectionOutputLayer::setShareLocation(bool flag) {
120 getLayer()->getParameters()["share_location"] = flag;
123 bool Builder::DetectionOutputLayer::getVariantEncodedInTarget() const {
124 return getLayer()->getParameters().at("variance_encoded_in_target");
126 Builder::DetectionOutputLayer& Builder::DetectionOutputLayer::setVariantEncodedInTarget(bool flag) {
127 getLayer()->getParameters()["variance_encoded_in_target"] = flag;
131 REG_VALIDATOR_FOR(DetectionOutput, [](const InferenceEngine::Builder::Layer::CPtr& input_layer, bool partial) {
132 Builder::DetectionOutputLayer layer(input_layer);
133 if (layer.getNumClasses() == 0) {
134 THROW_IE_EXCEPTION << "NumClasses parameter is wrong in layer " << layer.getName() <<
135 ". It should be > 0.";
137 if (layer.getCodeType() != "caffe.PriorBoxParameter.CENTER_SIZE" &&
138 layer.getCodeType() != "caffe.PriorBoxParameter.CORNER") {
139 THROW_IE_EXCEPTION << "CodeType parameter is wrong in layer " << layer.getName() <<
140 ". It should be equal to 'caffe.PriorBoxParameter.CORNER' or 'caffe.PriorBoxParameter.CENTER_SIZE'";
142 if (layer.getBackgroudLabelId() < -1) {
143 THROW_IE_EXCEPTION << "BackgroundLabelId parameter is wrong in layer " << layer.getName() <<
144 ". It should be >= 0 if this one is an Id of existing label else it should be equal to -1";
146 if (layer.getNMSThreshold() <= 0) {
147 THROW_IE_EXCEPTION << "NMSThreshold parameter is wrong in layer " << layer.getName() <<
148 ". It should be > 0.";
150 if (layer.getConfidenceThreshold() <= 0) {
151 THROW_IE_EXCEPTION << "ConfidenceThreshold parameter is wrong in layer " << layer.getName() <<
152 ". It should be > 0.";
156 REG_CONVERTER_FOR(DetectionOutput, [](const CNNLayerPtr& cnnLayer, Builder::Layer& layer) {
157 layer.getParameters()["num_classes"] = static_cast<size_t>(cnnLayer->GetParamAsUInt("num_classes"));
158 layer.getParameters()["background_label_id"] = cnnLayer->GetParamAsInt("background_label_id", 0);
159 layer.getParameters()["top_k"] = cnnLayer->GetParamAsInt("top_k", -1);
160 layer.getParameters()["keep_top_k"] = cnnLayer->GetParamAsInt("keep_top_k", -1);
161 layer.getParameters()["num_orient_classes"] = cnnLayer->GetParamAsInt("num_orient_classes", 0);
162 layer.getParameters()["code_type"] = cnnLayer->GetParamAsString("code_type", "caffe.PriorBoxParameter.CORNER");
163 layer.getParameters()["interpolate_orientation"] = cnnLayer->GetParamAsInt("interpolate_orientation", 1);
164 layer.getParameters()["nms_threshold"] = cnnLayer->GetParamAsFloat("nms_threshold");
165 layer.getParameters()["confidence_threshold"] = cnnLayer->GetParamAsFloat("confidence_threshold", -FLT_MAX);
166 layer.getParameters()["share_location"] = cnnLayer->GetParamsAsBool("share_location", true);
167 layer.getParameters()["variance_encoded_in_target"] = cnnLayer->GetParamsAsBool("variance_encoded_in_target", false);