1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <vpu/frontend/frontend.hpp>
8 #include <unordered_set>
15 class CTCDecoderStage final : public StageNode {
17 StagePtr cloneImpl() const override {
18 return std::make_shared<CTCDecoderStage>(*this);
21 void propagateDataOrderImpl(StageDataInfo<DimsOrder>& orderInfo) override {
22 auto input = inputEdge(0)->input();
23 auto output = outputEdge(0)->output();
25 auto cInd = input->desc().dimsOrder().dimInd(Dim::C);
26 orderInfo.setOutput(outputEdge(0), output->desc().dimsOrder().createMovedDim(Dim::C, cInd));
29 void getDataStridesRequirementsImpl(StageDataInfo<StridesRequirement>& stridesInfo) override {
30 stridesInfo.setInput(inputEdge(0), StridesRequirement::compact());
31 stridesInfo.setOutput(outputEdge(0), StridesRequirement::compact());
34 void finalizeDataLayoutImpl() override {
37 void getBatchSupportInfoImpl(StageDataInfo<BatchSupport>& batchInfo) override {
38 batchInfo.setInput(inputEdge(0), BatchSupport::Split);
39 batchInfo.setOutput(outputEdge(0), BatchSupport::Split);
42 StageSHAVEsRequirements getSHAVEsRequirementsImpl() const override {
43 return StageSHAVEsRequirements::OnlyOne;
46 void initialCheckImpl() const override {
47 assertInputsOutputsTypes(this, {{DataType::FP16}, {DataType::FP16}}, {{DataType::FP16}});
50 void serializeParamsImpl(BlobSerializer&) const override {
53 void serializeDataImpl(BlobSerializer& serializer) const override {
54 auto input0 = inputEdge(0)->input();
55 auto input1 = inputEdge(1)->input();
56 auto output = outputEdge(0)->output();
58 input0->serializeOldBuffer(handle_from_this(), serializer);
59 input1->serializeOldBuffer(handle_from_this(), serializer);
60 output->serializeOldBuffer(handle_from_this(), serializer);
66 void FrontEnd::parseCTCDecoder(
67 const Model::Ptr& model,
68 const ie::CNNLayerPtr& layer,
69 const DataVector& inputs,
70 const DataVector& outputs) {
71 IE_ASSERT(inputs.size() == 2);
72 IE_ASSERT(outputs.size() == 1);
74 auto ctc_merge_repeated_ = layer->GetParamAsInt("ctc_merge_repeated", 1);
75 if (ctc_merge_repeated_ != 1) {
77 << layer->name << " [" << layer->type
78 << "] has incorrect ctc_merge_repeated param value."
79 << " Kernel support case when ctc_merge_repeated_ == 1 only";
82 model->addNewStage<CTCDecoderStage>(
84 StageType::CTCDecoder,