Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / inference_engine / builders / ie_ctc_greedy_decoder_layer.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <builders/ie_ctc_greedy_decoder_layer.hpp>
6 #include <ie_cnn_layer_builder.h>
7 #include <vector>
8 #include <string>
9
10 using namespace InferenceEngine;
11
12 Builder::CTCGreedyDecoderLayer::CTCGreedyDecoderLayer(const std::string& name): LayerDecorator("CTCGreedyDecoder", name) {
13     getLayer()->getOutputPorts().resize(1);
14 }
15
16 Builder::CTCGreedyDecoderLayer::CTCGreedyDecoderLayer(const Layer::Ptr& layer): LayerDecorator(layer) {
17     checkType("CTCGreedyDecoder");
18 }
19
20 Builder::CTCGreedyDecoderLayer::CTCGreedyDecoderLayer(const Layer::CPtr& layer): LayerDecorator(layer) {
21     checkType("CTCGreedyDecoder");
22 }
23
24 Builder::CTCGreedyDecoderLayer& Builder::CTCGreedyDecoderLayer::setName(const std::string& name) {
25     getLayer()->setName(name);
26     return *this;
27 }
28 const std::vector<Port>& Builder::CTCGreedyDecoderLayer::getInputPorts() const {
29     return getLayer()->getInputPorts();
30 }
31 Builder::CTCGreedyDecoderLayer& Builder::CTCGreedyDecoderLayer::setInputPorts(const std::vector<Port>& ports) {
32     getLayer()->getInputPorts() = ports;
33     return *this;
34 }
35 const Port& Builder::CTCGreedyDecoderLayer::getOutputPort() const {
36     return getLayer()->getOutputPorts()[0];
37 }
38 Builder::CTCGreedyDecoderLayer& Builder::CTCGreedyDecoderLayer::setOutputPort(const Port& port) {
39     getLayer()->getOutputPorts()[0] = port;
40     return *this;
41 }
42 bool Builder::CTCGreedyDecoderLayer::getCTCMergeRepeated() const {
43     return getLayer()->getParameters().at("ctc_merge_repeated");
44 }
45 Builder::CTCGreedyDecoderLayer& Builder::CTCGreedyDecoderLayer::setCTCMergeRepeated(bool flag) {
46     getLayer()->getParameters()["ctc_merge_repeated"] = flag;
47     return *this;
48 }
49
50 REG_VALIDATOR_FOR(CTCGreedyDecoder, [](const InferenceEngine::Builder::Layer::CPtr& input_layer, bool partial) {
51     Builder::CTCGreedyDecoderLayer layer(input_layer);
52
53     if (layer.getInputPorts().empty() || layer.getInputPorts().size() > 2) {
54         THROW_IE_EXCEPTION << "Input ports are wrong in layer " << layer.getName() <<
55                            ". There are should be 1 or 2 input ports";
56     }
57 });
58
59 REG_CONVERTER_FOR(CTCGreedyDecoder, [](const CNNLayerPtr& cnnLayer, Builder::Layer& layer) {
60     layer.getParameters()["ctc_merge_repeated"] = cnnLayer->GetParamsAsBool("ctc_merge_repeated", false);
61 });