1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <vpu/frontend/frontend.hpp>
8 #include <unordered_set>
15 class CTCDecoderStage final : public StageNode {
17 StagePtr cloneImpl() const override {
18 return std::make_shared<CTCDecoderStage>(*this);
21 DataMap<float> propagateScaleFactorsImpl(
22 const DataMap<float>&,
23 ScalePropagationStep) override {
24 IE_ASSERT(_inputEdges.size() == 2);
25 IE_ASSERT(_outputEdges.size() == 1);
27 auto input0 = _inputEdges[0]->input();
28 auto input1 = _inputEdges[1]->input();
29 auto output = _outputEdges[0]->output();
40 DataMap<DimsOrder> propagateDataOrderImpl() const override {
41 IE_ASSERT(_inputEdges.size() == 2);
42 IE_ASSERT(_outputEdges.size() == 1);
44 auto input = _inputEdges[0]->input();
45 auto output = _outputEdges[0]->output();
47 DataMap<DimsOrder> out;
49 auto cInd = input->desc().dimsOrder().dimInd(Dim::C);
50 out[output] = output->desc().dimsOrder().createMovedDim(Dim::C, cInd);
55 DataMap<StridesRequirement> getDataStridesRequirementsImpl() const override {
56 IE_ASSERT(_inputEdges.size() == 2);
57 IE_ASSERT(_outputEdges.size() == 1);
59 auto input = _inputEdges[0]->input();
60 auto output = _outputEdges[0]->output();
62 DataMap<StridesRequirement> out;
64 out[input] = StridesRequirement::compact();
65 out[output] = StridesRequirement::compact();
70 void finalizeDataLayoutImpl() override {
73 DataMap<BatchSupport> getBatchSupportInfoImpl() const override {
74 IE_ASSERT(_inputEdges.size() == 2);
75 IE_ASSERT(_outputEdges.size() == 1);
77 auto input = _inputEdges[0]->input();
78 auto output = _outputEdges[0]->output();
80 DataMap<BatchSupport> out;
82 out[input] = BatchSupport::Split;
83 out[output] = BatchSupport::Split;
88 StageSHAVEsRequirements getSHAVEsRequirementsImpl() const override {
89 return StageSHAVEsRequirements::OnlyOne;
92 void finalCheckImpl() const override {
95 void serializeParamsImpl(BlobSerializer&) const override {
98 void serializeDataImpl(BlobSerializer& serializer) const override {
99 IE_ASSERT(_inputEdges.size() == 2);
100 IE_ASSERT(_outputEdges.size() == 1);
101 IE_ASSERT(_tempBufferEdges.empty());
103 auto input0 = _inputEdges[0]->input();
104 auto input1 = _inputEdges[1]->input();
105 auto output = _outputEdges[0]->output();
107 input0->serializeOldBuffer(handle_from_this(), serializer);
108 input1->serializeOldBuffer(handle_from_this(), serializer);
109 output->serializeOldBuffer(handle_from_this(), serializer);
115 void FrontEnd::parseCTCDecoder(
116 const Model::Ptr& model,
117 const ie::CNNLayerPtr& layer,
118 const DataVector& inputs,
119 const DataVector& outputs) {
120 IE_ASSERT(inputs.size() == 2);
121 IE_ASSERT(outputs.size() == 1);
123 auto ctc_merge_repeated_ = layer->GetParamAsInt("ctc_merge_repeated", 1);
124 if (ctc_merge_repeated_ != 1) {
126 << layer->name << " [" << layer->type
127 << "] has incorrect ctc_merge_repeated param value."
128 << " Kernel support case when ctc_merge_repeated_ == 1 only";
131 model->addNewStage<CTCDecoderStage>(
133 StageType::CTCDecoder,