From 955cf35d5f890b85baa12b254a325b98880813d0 Mon Sep 17 00:00:00 2001 From: JIANG Yichen Date: Mon, 9 Aug 2021 13:46:11 +0800 Subject: [PATCH] Implement ctc prefix beam search decode for TextRecognitionModel. The algorithm is based on Hannun's paper: First-Pass Large Vocabulary Continuous Speech Recognition using Bi-Directional Recurrent DNNs --- .../dnn_text_spotting/dnn_text_spotting.markdown | 5 + modules/dnn/include/opencv2/dnn/dnn.hpp | 13 +- modules/dnn/src/math_utils.hpp | 83 +++++++ modules/dnn/src/model.cpp | 248 ++++++++++++++++++--- modules/dnn/test/test_model.cpp | 19 ++ 5 files changed, 332 insertions(+), 36 deletions(-) create mode 100644 modules/dnn/src/math_utils.hpp diff --git a/doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown b/doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown index 5f28b6c..b0be262 100644 --- a/doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown +++ b/doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown @@ -26,6 +26,11 @@ Before recognition, you should `setVocabulary` and `setDecodeType`. - `T` is the sequence length - `B` is the batch size (only support `B=1` in inference) - and `Dim` is the length of vocabulary +1('Blank' of CTC is at the index=0 of Dim). +- "CTC-prefix-beam-search", the output of the text recognition model should be a probability matrix, the same as for "CTC-greedy". + - The algorithm is proposed in Hannun's [paper](https://arxiv.org/abs/1408.2873). + - `setDecodeOptsCTCPrefixBeamSearch` can be used to control the beam size in the search step. + - To further optimize for a big vocabulary, a new option `vocPruneSize` is introduced to avoid iterating over the whole vocabulary + and to consider only the `vocPruneSize` tokens with the top probability. @ref cv::dnn::TextRecognitionModel::recognize() is the main function for text recognition. 
- The input image should be a cropped text image or an image with `roiRects` diff --git a/modules/dnn/include/opencv2/dnn/dnn.hpp b/modules/dnn/include/opencv2/dnn/dnn.hpp index 255b41d..a498039 100644 --- a/modules/dnn/include/opencv2/dnn/dnn.hpp +++ b/modules/dnn/include/opencv2/dnn/dnn.hpp @@ -1373,7 +1373,9 @@ public: /** * @brief Set the decoding method of translating the network output into string - * @param[in] decodeType The decoding method of translating the network output into string: {'CTC-greedy': greedy decoding for the output of CTC-based methods} + * @param[in] decodeType The decoding method of translating the network output into string, currently supported types: + * - `"CTC-greedy"` greedy decoding for the output of CTC-based methods + * - `"CTC-prefix-beam-search"` prefix beam search decoding for the output of CTC-based methods */ CV_WRAP TextRecognitionModel& setDecodeType(const std::string& decodeType); @@ -1386,6 +1388,15 @@ public: const std::string& getDecodeType() const; /** + * @brief Set the options used by the `"CTC-prefix-beam-search"` decoding method + * @param[in] beamSize Beam size for the search + * @param[in] vocPruneSize Parameter to optimize the search over a big vocabulary: + * only the top @p vocPruneSize tokens are taken in each search step; @p vocPruneSize <= 0 disables this pruning. + */ + CV_WRAP + TextRecognitionModel& setDecodeOptsCTCPrefixBeamSearch(int beamSize, int vocPruneSize = 0); + + /** * @brief Set the vocabulary for recognition. * @param[in] vocabulary the associated vocabulary of the network. */ diff --git a/modules/dnn/src/math_utils.hpp b/modules/dnn/src/math_utils.hpp new file mode 100644 index 0000000..19ee474 --- /dev/null +++ b/modules/dnn/src/math_utils.hpp @@ -0,0 +1,83 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
+ +// Code is borrowed from https://github.com/kaldi-asr/kaldi/blob/master/src/base/kaldi-math.h + +// base/kaldi-math.h + +// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation; Yanmin Qian; +// Jan Silovsky; Saarland University +// +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef __OPENCV_DNN_MATH_UTILS_HPP__ +#define __OPENCV_DNN_MATH_UTILS_HPP__ + +#ifdef OS_QNX +#include +#else +#include +#endif + +#include + +#ifndef FLT_EPSILON +#define FLT_EPSILON 1.19209290e-7f +#endif + +namespace cv { namespace dnn { + +const float kNegativeInfinity = -std::numeric_limits::infinity(); + +const float kMinLogDiffFloat = std::log(FLT_EPSILON); /* below this difference exp(diff) is under float epsilon, so LogAdd drops the smaller term */ + +#if !defined(_MSC_VER) || (_MSC_VER >= 1700) +inline float Log1p(float x) { return log1pf(x); } +#else +inline float Log1p(float x) { /* fallback for compilers without log1pf */ + const float cutoff = 1.0e-07f; + if (x < cutoff) + return x - 0.5f * x * x; /* second-order Taylor expansion of log(1 + x) */ + else + return logf(1.0f + x); +} +#endif + +inline float Exp(float x) { return expf(x); } + +inline float LogAdd(float x, float y) { /* returns log(exp(x) + exp(y)) while staying in the log domain */ + float diff; + if (x < y) { + diff = x - y; + x = y; + } else { + diff = y - x; + } + // diff is negative. x is now the larger one. + + if (diff >= kMinLogDiffFloat) { + float res; + res = x + Log1p(Exp(diff)); + return res; + } else { + return x; // return the larger one. 
+ } +} + +}} // namespace + +#endif // __OPENCV_DNN_MATH_UTILS_HPP__ diff --git a/modules/dnn/src/model.cpp b/modules/dnn/src/model.cpp index 0af8223..bc8709d 100644 --- a/modules/dnn/src/model.cpp +++ b/modules/dnn/src/model.cpp @@ -3,8 +3,10 @@ // of this distribution and at http://opencv.org/license.html. #include "precomp.hpp" +#include "math_utils.hpp" #include #include +#include #include #include @@ -552,6 +554,9 @@ struct TextRecognitionModel_Impl : public Model::Impl std::string decodeType; std::vector vocabulary; + int beamSize = 10; /* beam width used by ctcPrefixBeamSearchDecode */ + int vocPruneSize = 0; /* per-step top-k vocabulary pruning passed to TopK; <= 0 disables it */ + TextRecognitionModel_Impl() { CV_TRACE_FUNCTION(); @@ -575,6 +580,13 @@ struct TextRecognitionModel_Impl : public Model::Impl decodeType = type; } + inline + void setDecodeOptsCTCPrefixBeamSearch(int beam, int vocPrune) + { + beamSize = beam; + vocPruneSize = vocPrune; + } + virtual std::string decode(const Mat& prediction) { @@ -586,53 +598,213 @@ struct TextRecognitionModel_Impl : public Model::Impl CV_Error(Error::StsBadArg, "TextRecognitionModel: vocabulary is not specified"); std::string decodeSeq; - if (decodeType == "CTC-greedy") + if (decodeType == "CTC-greedy") { + decodeSeq = ctcGreedyDecode(prediction); + } else if (decodeType == "CTC-prefix-beam-search") { + decodeSeq = ctcPrefixBeamSearchDecode(prediction); + } else if (decodeType.length() == 0) { + CV_Error(Error::StsBadArg, "Please set decodeType"); + } else { + CV_Error_(Error::StsBadArg, ("Unsupported decodeType: %s", decodeType.c_str())); + } + + return decodeSeq; + } + + virtual + std::string ctcGreedyDecode(const Mat& prediction) + { + std::string decodeSeq; + CV_CheckEQ(prediction.dims, 3, ""); + CV_CheckType(prediction.type(), CV_32FC1, ""); + const int vocLength = (int)(vocabulary.size()); + CV_CheckLE(prediction.size[1], vocLength, ""); + bool ctcFlag = true; /* true at the start and after a blank step, so the next non-blank label is emitted even if it equals lastLoc */ + int lastLoc = 0; + for (int i = 0; i < prediction.size[0]; i++) { - CV_CheckEQ(prediction.dims, 3, ""); - CV_CheckType(prediction.type(), CV_32FC1, ""); - const int 
vocLength = (int)(vocabulary.size()); - CV_CheckLE(prediction.size[1], vocLength, ""); - bool ctcFlag = true; - int lastLoc = 0; - for (int i = 0; i < prediction.size[0]; i++) + const float* pred = prediction.ptr(i); + int maxLoc = 0; + float maxScore = pred[0]; + for (int j = 1; j < vocLength + 1; j++) { - const float* pred = prediction.ptr(i); - int maxLoc = 0; - float maxScore = pred[0]; - for (int j = 1; j < vocLength + 1; j++) + float score = pred[j]; + if (maxScore < score) { - float score = pred[j]; - if (maxScore < score) - { - maxScore = score; - maxLoc = j; - } + maxScore = score; + maxLoc = j; } + } - if (maxLoc > 0) - { - std::string currentChar = vocabulary.at(maxLoc - 1); - if (maxLoc != lastLoc || ctcFlag) - { - lastLoc = maxLoc; - decodeSeq += currentChar; - ctcFlag = false; - } - } - else + if (maxLoc > 0) + { + std::string currentChar = vocabulary.at(maxLoc - 1); + if (maxLoc != lastLoc || ctcFlag) { - ctcFlag = true; + lastLoc = maxLoc; + decodeSeq += currentChar; + ctcFlag = false; } } - } else if (decodeType.length() == 0) { - CV_Error(Error::StsBadArg, "Please set decodeType"); - } else { - CV_Error_(Error::StsBadArg, ("Unsupported decodeType: %s", decodeType.c_str())); + else + { + ctcFlag = true; + } } - return decodeSeq; } + struct PrefixScore + { + // log-domain score of the prefix when it ends with a blank + float pB; + // log-domain score of the prefix when it ends with a non-blank token + float pNB; + + PrefixScore() : pB(kNegativeInfinity), pNB(kNegativeInfinity) + { + + } + PrefixScore(float pB, float pNB) : pB(pB), pNB(pNB) + { + + } + }; + + struct PrefixHash + { + size_t operator()(const std::vector& prefix) const + { + // BKDR hash (multiplicative hash with seed 131) over the prefix elements + unsigned int seed = 131; + size_t hash = 0; + for (size_t i = 0; i < prefix.size(); i++) + { + hash = hash * seed + prefix[i]; + } + return hash; + } + }; + + static + std::vector> TopK( + const float* predictions, int length, int k) + { + std::vector> results; + // No prune. 
+ if (k <= 0 || k >= length) /* pruning disabled, or k already covers every entry; the latter also keeps the seeding loop below from reading past predictions[length - 1] */ + { + for (int i = 0; i < length; ++i) + { + results.emplace_back(predictions[i], i); + } + return results; + } + + for (int i = 0; i < k; ++i) /* seed a min-heap with the first k (score, index) pairs */ + { + results.emplace_back(predictions[i], i); + } + std::make_heap(results.begin(), results.end(), std::greater>{}); + + for (int i = k; i < length; ++i) + { + if (predictions[i] > results.front().first) /* better than the current k-th best: replace the heap minimum */ + { + std::pop_heap(results.begin(), results.end(), std::greater>{}); + results.pop_back(); + results.emplace_back(predictions[i], i); + std::push_heap(results.begin(), results.end(), std::greater>{}); + } + } + return results; + } + + static inline + bool PrefixScoreCompare( + const std::pair, PrefixScore>& a, + const std::pair, PrefixScore>& b) + { + float probA = LogAdd(a.second.pB, a.second.pNB); + float probB = LogAdd(b.second.pB, b.second.pNB); + return probA > probB; + } + + virtual + std::string ctcPrefixBeamSearchDecode(const Mat& prediction) { + // CTC prefix beam search decode. + // For more detail, refer to: + // https://distill.pub/2017/ctc/#inference + // https://gist.github.com/awni/56369a90d03953e370f3964c826ed4b0 + using Beam = std::vector, PrefixScore>>; + using BeamInDict = std::unordered_map, PrefixScore, PrefixHash>; + + CV_CheckType(prediction.type(), CV_32FC1, ""); + CV_CheckEQ(prediction.dims, 3, ""); + CV_CheckEQ(prediction.size[1], 1, ""); + CV_CheckEQ(prediction.size[2], (int)vocabulary.size() + 1, ""); // Length add 1 for ctc blank + + std::string decodeSeq; + Beam beam = {std::make_pair(std::vector(), PrefixScore(0.0, kNegativeInfinity))}; + for (int i = 0; i < prediction.size[0]; i++) + { + // Loop over time + BeamInDict nextBeam; + const float* pred = prediction.ptr(i); + std::vector> topkPreds = + TopK(pred, (int)vocabulary.size() + 1, vocPruneSize); + for (const auto& each : topkPreds) + { + // Loop over vocabulary + float prob = each.first; + int token = each.second; + for (const auto& it : beam) + { + const std::vector& prefix = it.first; + const PrefixScore& prefixScore = 
it.second; + if (token == 0) // 0 stands for ctc blank + { + PrefixScore& nextScore = nextBeam[prefix]; /* a blank keeps the prefix unchanged and only feeds its blank-ending score */ + nextScore.pB = LogAdd(nextScore.pB, + LogAdd(prefixScore.pB + prob, prefixScore.pNB + prob)); + continue; + } + + std::vector nPrefix(prefix); + nPrefix.push_back(token); + PrefixScore& nextScore = nextBeam[nPrefix]; + if (prefix.size() > 0 && token == prefix.back()) + { + nextScore.pNB = LogAdd(nextScore.pNB, prefixScore.pB + prob); /* repeated token extends the prefix only via the blank-ending path */ + PrefixScore& mScore = nextBeam[prefix]; + mScore.pNB = LogAdd(mScore.pNB, prefixScore.pNB + prob); /* without a blank in between, the repeat collapses into the same prefix */ + } + else + { + nextScore.pNB = LogAdd(nextScore.pNB, + LogAdd(prefixScore.pB + prob, prefixScore.pNB + prob)); + } + } + } + // Beam prune + Beam newBeam(nextBeam.begin(), nextBeam.end()); + int newBeamSize = std::min(static_cast(newBeam.size()), std::max(beamSize, 1)); /* keep at least one beam: beamSize <= 0 would give an invalid nth_element position / negative resize, or an empty beam */ + std::nth_element(newBeam.begin(), newBeam.begin() + newBeamSize, + newBeam.end(), PrefixScoreCompare); + newBeam.resize(newBeamSize); + std::sort(newBeam.begin(), newBeam.end(), PrefixScoreCompare); + beam = std::move(newBeam); + } + + CV_Assert(!beam.empty()); + for (int token : beam[0].first) /* beam is sorted by PrefixScoreCompare, so beam[0] is the best-scoring prefix */ + { + CV_Check(token, token > 0 && token <= (int)vocabulary.size(), ""); + decodeSeq += vocabulary.at(token - 1); + } + return decodeSeq; + } + virtual std::string recognize(InputArray frame) { @@ -698,6 +870,12 @@ const std::string& TextRecognitionModel::getDecodeType() const return TextRecognitionModel_Impl::from(impl).decodeType; } +TextRecognitionModel& TextRecognitionModel::setDecodeOptsCTCPrefixBeamSearch(int beamSize, int vocPruneSize) +{ + TextRecognitionModel_Impl::from(impl).setDecodeOptsCTCPrefixBeamSearch(beamSize, vocPruneSize); + return *this; +} + TextRecognitionModel& TextRecognitionModel::setVocabulary(const std::vector& inputVoc) { TextRecognitionModel_Impl::from(impl).setVocabulary(inputVoc); diff --git a/modules/dnn/test/test_model.cpp b/modules/dnn/test/test_model.cpp index f7befa9..6ac9702 100644 --- a/modules/dnn/test/test_model.cpp +++ 
b/modules/dnn/test/test_model.cpp @@ -615,6 +615,25 @@ TEST_P(Test_Model, TextRecognition) testTextRecognitionModel(weightPath, "", imgPath, seq, decodeType, vocabulary, size, mean, scale); } +TEST_P(Test_Model, TextRecognitionWithCTCPrefixBeamSearch) +{ + if (target == DNN_TARGET_OPENCL_FP16) + applyTestTag(CV_TEST_TAG_DNN_SKIP_OPENCL_FP16); + + std::string imgPath = _tf("text_rec_test.png"); + std::string weightPath = _tf("onnx/models/crnn.onnx", false); + std::string seq = "welcome"; /* NOTE(review): presumably the text the helper expects to be recognized - confirm against testTextRecognitionModel */ + + Size size{100, 32}; + double scale = 1.0 / 127.5; + Scalar mean = Scalar(127.5); + std::string decodeType = "CTC-prefix-beam-search"; /* same setup as the TextRecognition test above, but selecting the prefix beam search decoder */ + std::vector vocabulary = {"0","1","2","3","4","5","6","7","8","9", + "a","b","c","d","e","f","g","h","i","j","k","l","m","n","o","p","q","r","s","t","u","v","w","x","y","z"}; + + testTextRecognitionModel(weightPath, "", imgPath, seq, decodeType, vocabulary, size, mean, scale); +} + TEST_P(Test_Model, TextDetectionByDB) { if (target == DNN_TARGET_OPENCL_FP16) -- 2.7.4