1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "ext_list.hpp"
6 #include "ext_base.hpp"
12 namespace InferenceEngine {
13 namespace Extensions {
16 class CTCGreedyDecoderImpl: public ExtLayerBase {
18 explicit CTCGreedyDecoderImpl(const CNNLayer* layer) {
20 if (layer->insData.empty() || layer->outData.size() != 1)
21 THROW_IE_EXCEPTION << "Incorrect number of input/output edges!";
23 std::vector<DataConfigurator> inps;
24 inps.resize(layer->insData.size(), DataConfigurator(ConfLayout::PLN));
25 addConfig(layer, inps, {DataConfigurator(ConfLayout::PLN)});
26 } catch (InferenceEngine::details::InferenceEngineException &ex) {
31 StatusCode execute(std::vector<Blob::Ptr>& inputs, std::vector<Blob::Ptr>& outputs,
32 ResponseDesc *resp) noexcept override {
33 if ((inputs.size() != 1 && inputs.size() != 2) || outputs.empty()) {
35 std::string errorMsg = "Incorrect number of input or output edges!";
36 errorMsg.copy(resp->msg, sizeof(resp->msg) - 1);
40 const float* probabilities = inputs[0]->buffer();
41 const float* sequence_indicators = inputs[1]->buffer();
42 float* output_sequences = outputs[0]->buffer();
44 size_t T_ = inputs[0]->getTensorDesc().getDims()[0];
45 size_t N_ = inputs[0]->getTensorDesc().getDims()[1];
46 size_t C_ = inputs[0]->getTensorDesc().getDims()[2];
48 // Fill output_sequences with -1
49 for (size_t ii = 0; ii < T_*N_; ii++) {
50 output_sequences[ii] = -1;
53 for (size_t n = 0; n < N_; ++n) {
54 int prev_class_idx = -1;
55 size_t output_index = n*T_;
57 for (int t = 0; /* check at end */; ++t) {
58 // get maximum probability and its index
59 int max_class_idx = 0;
61 const float* probs = probabilities + t*C_*N_ + n*C_;
62 float max_prob = probs[0];
65 for (size_t c = 1; c < C_; ++c, ++probs) {
66 if (*probs > max_prob) {
67 max_class_idx = static_cast<int>(c);
72 if (max_class_idx < static_cast<int>(C_) - 1 &&
73 max_class_idx != prev_class_idx) {
74 output_sequences[output_index] = static_cast<float>(max_class_idx);
78 prev_class_idx = max_class_idx;
80 if (t + 1 == static_cast<int>(T_) || sequence_indicators[(t + 1)*N_ + n] == 0) {
// Presumably registers an ImplFactory for CTCGreedyDecoderImpl under the layer
// type name "CTCGreedyDecoder"; the macro is defined outside this file — confirm there.
89 REG_FACTORY_FOR(ImplFactory<CTCGreedyDecoderImpl>, CTCGreedyDecoder);
92 } // namespace Extensions
93 } // namespace InferenceEngine