Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / inference_engine / builders / ie_detection_output_layer.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <builders/ie_detection_output_layer.hpp>
6 #include <ie_cnn_layer_builder.h>
7
8 #include <cfloat>
9 #include <vector>
10 #include <string>
11
12 using namespace InferenceEngine;
13
14 Builder::DetectionOutputLayer::DetectionOutputLayer(const std::string& name): LayerDecorator("DetectionOutput", name) {
15     getLayer()->getOutputPorts().resize(1);
16     getLayer()->getInputPorts().resize(2);
17     setBackgroudLabelId(-1);
18 }
19
20 Builder::DetectionOutputLayer::DetectionOutputLayer(const Layer::Ptr& layer): LayerDecorator(layer) {
21     checkType("DetectionOutput");
22 }
23
24 Builder::DetectionOutputLayer::DetectionOutputLayer(const Layer::CPtr& layer): LayerDecorator(layer) {
25     checkType("DetectionOutput");
26 }
27
28 Builder::DetectionOutputLayer& Builder::DetectionOutputLayer::setName(const std::string& name) {
29     getLayer()->setName(name);
30     return *this;
31 }
32
33 const std::vector<Port>& Builder::DetectionOutputLayer::getInputPorts() const {
34     return getLayer()->getInputPorts();
35 }
36
37 Builder::DetectionOutputLayer& Builder::DetectionOutputLayer::setInputPorts(const std::vector<Port> &ports) {
38     if (ports.size() != 3)
39         THROW_IE_EXCEPTION << "Incorrect number of inputs for DetectionOutput getLayer().";
40     getLayer()->getInputPorts() = ports;
41     return *this;
42 }
43
44 const Port& Builder::DetectionOutputLayer::getOutputPort() const {
45     return getLayer()->getOutputPorts()[0];
46 }
47
48 Builder::DetectionOutputLayer& Builder::DetectionOutputLayer::setOutputPort(const Port &port) {
49     getLayer()->getOutputPorts()[0] = port;
50     return *this;
51 }
52
53 size_t Builder::DetectionOutputLayer::getNumClasses() const {
54     return getLayer()->getParameters().at("num_classes");
55 }
56 Builder::DetectionOutputLayer& Builder::DetectionOutputLayer::setNumClasses(size_t num) {
57     getLayer()->getParameters()["num_classes"] = num;
58     return *this;
59 }
60 int Builder::DetectionOutputLayer::getBackgroudLabelId() const {
61     return getLayer()->getParameters().at("background_label_id");
62 }
63 Builder::DetectionOutputLayer& Builder::DetectionOutputLayer::setBackgroudLabelId(int labelId) {
64     getLayer()->getParameters()["background_label_id"] = labelId;
65     return *this;
66 }
67 int Builder::DetectionOutputLayer::getTopK() const {
68     return getLayer()->getParameters().at("top_k");
69 }
70 Builder::DetectionOutputLayer& Builder::DetectionOutputLayer::setTopK(int topK) {
71     getLayer()->getParameters()["top_k"] = topK;
72     return *this;
73 }
74 int Builder::DetectionOutputLayer::getKeepTopK() const {
75     return getLayer()->getParameters().at("keep_top_k");
76 }
77 Builder::DetectionOutputLayer& Builder::DetectionOutputLayer::setKeepTopK(int topK) {
78     getLayer()->getParameters()["keep_top_k"] = topK;
79     return *this;
80 }
81 int Builder::DetectionOutputLayer::getNumOrientClasses() const {
82     return getLayer()->getParameters().at("num_orient_classes");
83 }
84 Builder::DetectionOutputLayer& Builder::DetectionOutputLayer::setNumOrientClasses(int numClasses) {
85     getLayer()->getParameters()["num_orient_classes"] = numClasses;
86     return *this;
87 }
88 std::string Builder::DetectionOutputLayer::getCodeType() const {
89     return getLayer()->getParameters().at("code_type");
90 }
91 Builder::DetectionOutputLayer& Builder::DetectionOutputLayer::setCodeType(std::string type) {
92     getLayer()->getParameters()["code_type"] = type;
93     return *this;
94 }
95 int Builder::DetectionOutputLayer::getInterpolateOrientation() const {
96     return getLayer()->getParameters().at("interpolate_orientation");
97 }
98 Builder::DetectionOutputLayer& Builder::DetectionOutputLayer::setInterpolateOrientation(int orient) {
99     getLayer()->getParameters()["interpolate_orientation"] = orient;
100     return *this;
101 }
102 float Builder::DetectionOutputLayer::getNMSThreshold() const {
103     return getLayer()->getParameters().at("nms_threshold");
104 }
105 Builder::DetectionOutputLayer& Builder::DetectionOutputLayer::setNMSThreshold(float threshold) {
106     getLayer()->getParameters()["nms_threshold"] = threshold;
107     return *this;
108 }
109 float Builder::DetectionOutputLayer::getConfidenceThreshold() const {
110     return getLayer()->getParameters().at("confidence_threshold");
111 }
112 Builder::DetectionOutputLayer& Builder::DetectionOutputLayer::setConfidenceThreshold(float threshold) {
113     getLayer()->getParameters()["confidence_threshold"] = threshold;
114     return *this;
115 }
116 bool Builder::DetectionOutputLayer::getShareLocation() const {
117     return getLayer()->getParameters().at("share_location");
118 }
119 Builder::DetectionOutputLayer& Builder::DetectionOutputLayer::setShareLocation(bool flag) {
120     getLayer()->getParameters()["share_location"] = flag;
121     return *this;
122 }
123 bool Builder::DetectionOutputLayer::getVariantEncodedInTarget() const {
124     return getLayer()->getParameters().at("variance_encoded_in_target");
125 }
126 Builder::DetectionOutputLayer& Builder::DetectionOutputLayer::setVariantEncodedInTarget(bool flag) {
127     getLayer()->getParameters()["variance_encoded_in_target"] = flag;
128     return *this;
129 }
130
131 REG_VALIDATOR_FOR(DetectionOutput, [](const InferenceEngine::Builder::Layer::CPtr& input_layer, bool partial) {
132     Builder::DetectionOutputLayer layer(input_layer);
133     if (layer.getNumClasses() == 0) {
134         THROW_IE_EXCEPTION << "NumClasses parameter is wrong in layer " << layer.getName() <<
135                            ". It should be > 0.";
136     }
137     if (layer.getCodeType() != "caffe.PriorBoxParameter.CENTER_SIZE" &&
138         layer.getCodeType() != "caffe.PriorBoxParameter.CORNER") {
139         THROW_IE_EXCEPTION << "CodeType parameter is wrong in layer " << layer.getName() <<
140                            ". It should be equal to 'caffe.PriorBoxParameter.CORNER' or 'caffe.PriorBoxParameter.CENTER_SIZE'";
141     }
142     if (layer.getBackgroudLabelId() < -1) {
143         THROW_IE_EXCEPTION << "BackgroundLabelId parameter is wrong in layer " << layer.getName() <<
144                            ". It should be >= 0 if this one is an Id of existing label else it should be equal to -1";
145     }
146     if (layer.getNMSThreshold() <= 0) {
147         THROW_IE_EXCEPTION << "NMSThreshold parameter is wrong in layer " << layer.getName() <<
148                            ". It should be > 0.";
149     }
150     if (layer.getConfidenceThreshold() <= 0) {
151         THROW_IE_EXCEPTION << "ConfidenceThreshold parameter is wrong in layer " << layer.getName() <<
152                            ". It should be > 0.";
153     }
154 });
155
156 REG_CONVERTER_FOR(DetectionOutput, [](const CNNLayerPtr& cnnLayer, Builder::Layer& layer) {
157     layer.getParameters()["num_classes"] = static_cast<size_t>(cnnLayer->GetParamAsUInt("num_classes"));
158     layer.getParameters()["background_label_id"] = cnnLayer->GetParamAsInt("background_label_id", 0);
159     layer.getParameters()["top_k"] = cnnLayer->GetParamAsInt("top_k", -1);
160     layer.getParameters()["keep_top_k"] = cnnLayer->GetParamAsInt("keep_top_k", -1);
161     layer.getParameters()["num_orient_classes"] = cnnLayer->GetParamAsInt("num_orient_classes", 0);
162     layer.getParameters()["code_type"] = cnnLayer->GetParamAsString("code_type", "caffe.PriorBoxParameter.CORNER");
163     layer.getParameters()["interpolate_orientation"] = cnnLayer->GetParamAsInt("interpolate_orientation", 1);
164     layer.getParameters()["nms_threshold"] = cnnLayer->GetParamAsFloat("nms_threshold");
165     layer.getParameters()["confidence_threshold"] = cnnLayer->GetParamAsFloat("confidence_threshold", -FLT_MAX);
166     layer.getParameters()["share_location"] = cnnLayer->GetParamsAsBool("share_location", true);
167     layer.getParameters()["variance_encoded_in_target"] = cnnLayer->GetParamsAsBool("variance_encoded_in_target", false);
168 });