Publishing 2019 R3 content
[platform/upstream/dldt.git] / inference-engine / src / vpu / graph_transformer / src / stages / ctc_decoder.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <vpu/frontend/frontend.hpp>
6
7 #include <vector>
8 #include <unordered_set>
9 #include <memory>
10
11 namespace vpu {
12
13 namespace {
14
15 class CTCDecoderStage final : public StageNode {
16 private:
17     StagePtr cloneImpl() const override {
18         return std::make_shared<CTCDecoderStage>(*this);
19     }
20
21     void propagateDataOrderImpl(StageDataInfo<DimsOrder>& orderInfo) override {
22         auto input = inputEdge(0)->input();
23         auto output = outputEdge(0)->output();
24
25         auto cInd = input->desc().dimsOrder().dimInd(Dim::C);
26         orderInfo.setOutput(outputEdge(0), output->desc().dimsOrder().createMovedDim(Dim::C, cInd));
27     }
28
29     void getDataStridesRequirementsImpl(StageDataInfo<StridesRequirement>& stridesInfo) override {
30         stridesInfo.setInput(inputEdge(0), StridesRequirement::compact());
31         stridesInfo.setOutput(outputEdge(0), StridesRequirement::compact());
32     }
33
34     void finalizeDataLayoutImpl() override {
35     }
36
37     void getBatchSupportInfoImpl(StageDataInfo<BatchSupport>& batchInfo) override {
38         batchInfo.setInput(inputEdge(0), BatchSupport::Split);
39         batchInfo.setOutput(outputEdge(0), BatchSupport::Split);
40     }
41
42     StageSHAVEsRequirements getSHAVEsRequirementsImpl() const override {
43         return StageSHAVEsRequirements::OnlyOne;
44     }
45
46     void initialCheckImpl() const override {
47         assertInputsOutputsTypes(this, {{DataType::FP16}, {DataType::FP16}}, {{DataType::FP16}});
48     }
49
50     void serializeParamsImpl(BlobSerializer&) const override {
51     }
52
53     void serializeDataImpl(BlobSerializer& serializer) const override {
54         auto input0 = inputEdge(0)->input();
55         auto input1 = inputEdge(1)->input();
56         auto output = outputEdge(0)->output();
57
58         input0->serializeOldBuffer(handle_from_this(), serializer);
59         input1->serializeOldBuffer(handle_from_this(), serializer);
60         output->serializeOldBuffer(handle_from_this(), serializer);
61     }
62 };
63
64 }  // namespace
65
66 void FrontEnd::parseCTCDecoder(
67         const Model::Ptr& model,
68         const ie::CNNLayerPtr& layer,
69         const DataVector& inputs,
70         const DataVector& outputs) {
71     IE_ASSERT(inputs.size() == 2);
72     IE_ASSERT(outputs.size() == 1);
73
74     auto ctc_merge_repeated_ = layer->GetParamAsInt("ctc_merge_repeated", 1);
75     if (ctc_merge_repeated_ != 1) {
76         VPU_THROW_EXCEPTION
77             << layer->name <<  " [" << layer->type
78             << "] has incorrect ctc_merge_repeated param value."
79             << " Kernel support case when ctc_merge_repeated_ == 1 only";
80     }
81
82     model->addNewStage<CTCDecoderStage>(
83         layer->name,
84         StageType::CTCDecoder,
85         layer,
86         inputs,
87         outputs);
88 }
89
90 }  // namespace vpu