Publishing 2019 R1.1 content and Myriad plugin sources (#162)
[platform/upstream/dldt.git] / inference-engine / src / vpu / graph_transformer / src / stages / ctc_decoder.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <vpu/frontend/frontend.hpp>
6
7 #include <vector>
8 #include <unordered_set>
9 #include <memory>
10
11 namespace vpu {
12
13 namespace {
14
15 class CTCDecoderStage final : public StageNode {
16 private:
17     StagePtr cloneImpl() const override {
18         return std::make_shared<CTCDecoderStage>(*this);
19     }
20
21     DataMap<float> propagateScaleFactorsImpl(
22             const DataMap<float>&,
23             ScalePropagationStep) override {
24         IE_ASSERT(_inputEdges.size() == 2);
25         IE_ASSERT(_outputEdges.size() == 1);
26
27         auto input0 = _inputEdges[0]->input();
28         auto input1 = _inputEdges[1]->input();
29         auto output = _outputEdges[0]->output();
30
31         DataMap<float> out;
32
33         out[input0] = 1.0f;
34         out[input1] = 1.0f;
35         out[output] = 1.0f;
36
37         return out;
38     }
39
40     DataMap<DimsOrder> propagateDataOrderImpl() const override {
41         IE_ASSERT(_inputEdges.size() == 2);
42         IE_ASSERT(_outputEdges.size() == 1);
43
44         auto input = _inputEdges[0]->input();
45         auto output = _outputEdges[0]->output();
46
47         DataMap<DimsOrder> out;
48
49         auto cInd = input->desc().dimsOrder().dimInd(Dim::C);
50         out[output] = output->desc().dimsOrder().createMovedDim(Dim::C, cInd);
51
52         return out;
53     }
54
55     DataMap<StridesRequirement> getDataStridesRequirementsImpl() const override {
56         IE_ASSERT(_inputEdges.size() == 2);
57         IE_ASSERT(_outputEdges.size() == 1);
58
59         auto input = _inputEdges[0]->input();
60         auto output = _outputEdges[0]->output();
61
62         DataMap<StridesRequirement> out;
63
64         out[input] = StridesRequirement::compact();
65         out[output] = StridesRequirement::compact();
66
67         return out;
68     }
69
70     void finalizeDataLayoutImpl() override {
71     }
72
73     DataMap<BatchSupport> getBatchSupportInfoImpl() const override {
74         IE_ASSERT(_inputEdges.size() == 2);
75         IE_ASSERT(_outputEdges.size() == 1);
76
77         auto input = _inputEdges[0]->input();
78         auto output = _outputEdges[0]->output();
79
80         DataMap<BatchSupport> out;
81
82         out[input] = BatchSupport::Split;
83         out[output] = BatchSupport::Split;
84
85         return out;
86     }
87
88     StageSHAVEsRequirements getSHAVEsRequirementsImpl() const override {
89         return StageSHAVEsRequirements::OnlyOne;
90     }
91
92     void finalCheckImpl() const override {
93     }
94
95     void serializeParamsImpl(BlobSerializer&) const override {
96     }
97
98     void serializeDataImpl(BlobSerializer& serializer) const override {
99         IE_ASSERT(_inputEdges.size() == 2);
100         IE_ASSERT(_outputEdges.size() == 1);
101         IE_ASSERT(_tempBufferEdges.empty());
102
103         auto input0 = _inputEdges[0]->input();
104         auto input1 = _inputEdges[1]->input();
105         auto output = _outputEdges[0]->output();
106
107         input0->serializeOldBuffer(handle_from_this(), serializer);
108         input1->serializeOldBuffer(handle_from_this(), serializer);
109         output->serializeOldBuffer(handle_from_this(), serializer);
110     }
111 };
112
113 }  // namespace
114
115 void FrontEnd::parseCTCDecoder(
116         const Model::Ptr& model,
117         const ie::CNNLayerPtr& layer,
118         const DataVector& inputs,
119         const DataVector& outputs) {
120     IE_ASSERT(inputs.size() == 2);
121     IE_ASSERT(outputs.size() == 1);
122
123     auto ctc_merge_repeated_ = layer->GetParamAsInt("ctc_merge_repeated", 1);
124     if (ctc_merge_repeated_ != 1) {
125         VPU_THROW_EXCEPTION
126             << layer->name <<  " [" << layer->type
127             << "] has incorrect ctc_merge_repeated param value."
128             << " Kernel support case when ctc_merge_repeated_ == 1 only";
129     }
130
131     model->addNewStage<CTCDecoderStage>(
132         layer->name,
133         StageType::CTCDecoder,
134         layer,
135         inputs,
136         outputs);
137 }
138
139 }  // namespace vpu