From ff7fc01c7614b5b50442410b9dbd4f66d22fbd47 Mon Sep 17 00:00:00 2001
From: Nikolay Shchegolev
Date: Mon, 5 Oct 2020 11:58:54 +0300
Subject: [PATCH] [CPU] CTCLoss performance improvement.

---
 .../src/mkldnn_plugin/nodes/ctc_loss.cpp | 405 +++++++++------------
 1 file changed, 179 insertions(+), 226 deletions(-)

diff --git a/inference-engine/src/mkldnn_plugin/nodes/ctc_loss.cpp b/inference-engine/src/mkldnn_plugin/nodes/ctc_loss.cpp
index f29b7ce..1453a14 100644
--- a/inference-engine/src/mkldnn_plugin/nodes/ctc_loss.cpp
+++ b/inference-engine/src/mkldnn_plugin/nodes/ctc_loss.cpp
@@ -60,6 +60,8 @@ public:
     StatusCode execute(std::vector<Blob::Ptr>& inputs, std::vector<Blob::Ptr>& outputs,
                        ResponseDesc *resp) noexcept override {
+        StatusCode returnCode = OK;
+
         const float* logits = inputs[0]->cbuffer().as<const float*>() +
             inputs[0]->getTensorDesc().getBlockingDesc().getOffsetPadding();
         const int* logitsLength = inputs[1]->cbuffer().as<const int*>() +
@@ -72,257 +74,210 @@ public:
             outputs[0]->getTensorDesc().getBlockingDesc().getOffsetPadding();
 
         const auto& logitsShape = inputs[0]->getTensorDesc().getDims();
-        const auto batchNum = logitsShape[0];
-        const auto maxTime = logitsShape[1];
-        const auto classesNum = logitsShape[2];
+        const size_t batchNum = logitsShape[0];
+        const size_t maxTime = logitsShape[1];
+        const size_t classesNum = logitsShape[2];
 
         int blankIndex = classesNum - 1;
         if (inputs.size() > 4) {
             blankIndex = inputs[4]->cbuffer().as<const int*>()[0];
         }
 
-        std::vector<int> targetD(maxTime);
-
-        const size_t TC = maxTime * classesNum;
-
-        for (size_t b = 0; b < batchNum; b++) {
-            const int actualLogitLen = logitsLength[b];
-            const int actualTargetLen = labelsLength[b];
-            if (actualLogitLen < 0 || actualTargetLen < 0 || actualLogitLen > maxTime || actualTargetLen > maxTime
-                    || actualTargetLen > actualLogitLen) {
-                std::string errorMsg = _logPrefix + ". Logit or label length cannot be greater than max sequence length. "
-                        + "Also a label length cannot be greater than a logit length"
-                        + " and both cannot be negative.\nMaxSeqLen: "
-                        + std::to_string(maxTime) + "; Logit len: " + std::to_string(actualLogitLen)
-                        + "; Label len: " + std::to_string(actualTargetLen);
-                errorMsg.copy(resp->msg, sizeof(resp->msg) - 1);
-                return GENERAL_ERROR;
-            }
-
-            const int* target = &labels[b * maxTime];
-            // Decoding target: merge repeated characters if preprocess_collapse_repeated == True,
-            // find unique elemnts if unique == True
-            size_t decodedTargetLen = 0lu;
-            if (_unique) {
-                std::unordered_set<int> uniqVals;
-                for (size_t t = 0lu; t < actualTargetLen; t++) {
-                    if (uniqVals.find(target[t]) != uniqVals.end()) {
-                        continue;
-                    }
-                    uniqVals.insert(target[t]);
-                    targetD[decodedTargetLen++] = target[t];
+        std::vector<int> decodedTargetLenB(batchNum, 0);
+        std::vector<std::vector<int>> targetDB(batchNum);
+        std::vector<std::vector<std::vector<float>>> logProbabilitiesB(batchNum);
+        size_t workAmount2 = 0lu;
+        std::vector<std::string> errorMsgB(parallel_get_max_threads());
+
+        auto threadBody_1 = [&](const int ithr, const int nthr) {
+            size_t start(0lu), end(0lu);
+            splitter(batchNum, nthr, ithr, start, end);
+            if (start >= end)
+                return;
+
+            for (size_t b = start; b < end; b++) {
+                if (logitsLength[b] < 0 || labelsLength[b] < 0 || logitsLength[b] > maxTime || labelsLength[b] > logitsLength[b]) {
+                    errorMsgB[ithr] = _logPrefix + ". Logit length cannot be greater than max sequence length. "
" + + "Label length cannot be greater than a logit length" + + " and both cannot be negative.\nMaxSeqLen: " + + std::to_string(maxTime) + "; Logit len: " + std::to_string(logitsLength[b]) + + "; Label len: " + std::to_string(labelsLength[b]); + returnCode = GENERAL_ERROR; + return; } - } else if (_preprocessCollapseRepeated) { - int prevValue = target[0]; - targetD[decodedTargetLen++] = target[0]; - for (size_t t = 1lu; t < actualTargetLen; t++) { - if (target[t] == prevValue) { - continue; + const size_t actualLogitLen = logitsLength[b]; + const size_t actualTargetLen = labelsLength[b]; + size_t decodedTargetLen = 0lu; + + // Decoding target: merge repeated characters if preprocess_collapse_repeated == True, + // find unique elemnts if unique == True. + // Inserts blanks before each index and a blank at the end. + const int* target = &labels[b * maxTime]; + targetDB[b].resize(actualTargetLen * 2 + 1); + auto& targetD = targetDB[b]; + if (_unique) { + std::unordered_set uniqVals; + for (size_t t = 0lu; t < actualTargetLen; t++) { + if (uniqVals.find(target[t]) != uniqVals.end()) { + continue; + } + uniqVals.insert(target[t]); + targetD[decodedTargetLen++] = blankIndex; + targetD[decodedTargetLen++] = target[t]; + } + targetD[decodedTargetLen++] = blankIndex; + } else if (_preprocessCollapseRepeated) { + auto prevValue = target[0]; + targetD[decodedTargetLen++] = blankIndex; + targetD[decodedTargetLen++] = target[0]; + for (size_t t = 1lu; t < actualTargetLen; t++) { + if (target[t] == prevValue) { + continue; + } + targetD[decodedTargetLen++] = blankIndex; + targetD[decodedTargetLen++] = prevValue = target[t]; } - targetD[decodedTargetLen++] = target[t]; - prevValue = target[t]; + targetD[decodedTargetLen++] = blankIndex; + } else { + for (size_t t = 0lu; t < actualTargetLen; t++) { + targetD[decodedTargetLen++] = blankIndex; + targetD[decodedTargetLen++] = target[t]; + } + targetD[decodedTargetLen++] = blankIndex; } - } else { - std::copy(target, target + actualTargetLen, targetD.data()); - decodedTargetLen = actualTargetLen; - } - - const size_t BTC = b * TC; + decodedTargetLenB[b] = decodedTargetLen; - std::vector> logProbabilities(actualLogitLen); - float logProb = 0.f, kExp = 0.f; - for (size_t t = 0; t < actualLogitLen; t++) { - kExp = 0.f; - const size_t btcT = BTC + classesNum * t; - for (size_t c = 0; c < classesNum; c++) { - kExp += std::exp(logits[btcT + c]); + auto& logProbabilities = logProbabilitiesB[b]; + logProbabilities.resize(actualLogitLen); + for (size_t ll = 0; ll < actualLogitLen; ll++) { + logProbabilities[ll].resize(decodedTargetLen); } - for (size_t s = 0; s < decodedTargetLen; s++) { - logProb = logits[btcT + targetD[s]] - std::log(kExp); - logProbabilities[t].insert({targetD[s], logProb}); - } - logProb = logits[btcT + blankIndex] - std::log(kExp); - logProbabilities[t].insert({blankIndex, logProb}); + workAmount2 += actualLogitLen; + } // for batch + }; // threadBody_1 + + parallel_nt(0, threadBody_1); + if (returnCode != OK) { + std::string resErr(""); + for (auto& err : errorMsgB) { + if (!err.empty()) + resErr += err + "\n"; + resErr.copy(resp->msg, sizeof(resp->msg) - 1); } + return returnCode; + } - const auto float_inf = std::numeric_limits::infinity(); - size_t work_amount = actualLogitLen - decodedTargetLen + 1lu; - std::vector sumPerThread(parallel_get_max_threads(), -float_inf); + const size_t TC = maxTime * classesNum; - // Looking for aligned paths - auto thread_body = [&](const int ithr, const int nthr) { - size_t start0(0lu), end0(0lu); - 
-                splitter(work_amount, nthr, ithr, start0, end0);
-                if (start0 >= end0)
-                    return;
-                if (ithr >= sumPerThread.size())
-                    sumPerThread.push_back(-float_inf);
-
-                std::function<void(size_t, size_t, size_t, float)> findPaths =
-                    [&](size_t targetIdx, size_t start, size_t end, float prevLogProb) {
-                        if (end > actualLogitLen) {
-                            if (sumPerThread[ithr] == -float_inf) {
-                                sumPerThread[ithr] = prevLogProb;
-                            } else if (prevLogProb != -float_inf) {
-                                if (sumPerThread[ithr] > prevLogProb)
-                                    sumPerThread[ithr] = sumPerThread[ithr] + std::log1pf(std::exp(prevLogProb - sumPerThread[ithr]));
-                                else
-                                    sumPerThread[ithr] = prevLogProb + std::log1pf(std::exp(sumPerThread[ithr] - prevLogProb));
-                            }
-                            return;
+        auto threadBody_2 = [&](const int ithr, const int nthr) {
+            size_t start(0lu), end(0lu);
+            size_t sB(0lu), sT(0lu);
+            splitter(workAmount2, nthr, ithr, start, end);
+            if (start >= end)
+                return;
+            int64_t cw = 0, st = start;
+            for (; sB < batchNum; sB++) {
+                cw += logitsLength[sB];
+                if (cw >= st) {
+                    sT = logitsLength[sB] + st - cw;
+                    break;
+                }
+            }
+            size_t workCounter = start;
+
+            for (size_t b = sB; b < batchNum; b++) {
+                const size_t actualLogitLen = logitsLength[b];
+                const size_t decodedTargetLen = decodedTargetLenB[b];
+                auto& logProbabilities = logProbabilitiesB[b];
+                auto& targetD = targetDB[b];
+
+                double expSum = 0.0;
+                size_t btcT = b * TC + sT * classesNum;
+                // logProbabilities = logSoftmax = logits[b][t][c] - ln(sum_c(exp(logits[b][t])))
+                for (size_t t = sT; t < actualLogitLen; t++) {
+                    expSum = 0.0;
+                    for (size_t c = 0lu; c < classesNum; c++) {
+                        expSum += std::exp(logits[btcT + c]);
                     }
-
-                        size_t nextIdx = targetIdx + 1;
-                        int64_t st64 = start;
-                        float newLogProb = prevLogProb;
-                        if (!_ctcMergeRepeated) {
-                            for (size_t pos = start; pos < end; pos++) {
-                                newLogProb = prevLogProb;
-                                for (size_t bl = start; bl < pos; bl++) {
-                                    auto lnProbIt = logProbabilities[bl].find(blankIndex);
-                                    if (lnProbIt != logProbabilities[bl].end())
-                                        newLogProb += lnProbIt->second;
-                                }
-                                auto lnProbIt = logProbabilities[pos].find(targetD[targetIdx]);
-                                if (lnProbIt != logProbabilities[pos].end())
-                                    newLogProb += lnProbIt->second;
-                                if (end == actualLogitLen) {
-                                    for (int64_t ble = pos + 1; ble < actualLogitLen; ble++) {
-                                        auto lnProbIt = logProbabilities[ble].find(blankIndex);
-                                        if (lnProbIt != logProbabilities[ble].end())
-                                            newLogProb += lnProbIt->second;
-                                    }
-                                }
-                                findPaths(nextIdx, pos + 1, end + 1, newLogProb);
-                            }
-                        } else {
-                            for (size_t pos = start; pos < end; pos++) {
-                                newLogProb = prevLogProb;
-                                size_t next_start = pos + 1;
-                                for (size_t bl = start; bl < pos; bl++) {
-                                    auto lnProbIt = logProbabilities[bl].find(blankIndex);
-                                    if (lnProbIt != logProbabilities[bl].end())
-                                        newLogProb += lnProbIt->second;
-                                }
-                                if (end == actualLogitLen) {
-                                    for (int64_t ble = pos + 1; ble < actualLogitLen; ble++) {
-                                        auto lnProbIt = logProbabilities[ble].find(blankIndex);
-                                        if (lnProbIt != logProbabilities[ble].end())
-                                            newLogProb += lnProbIt->second;
-                                    }
-                                }
-                                if (targetIdx < decodedTargetLen - 1
-                                        && targetD[targetIdx] == targetD[targetIdx + 1]) {
-                                    auto lnProbIt = logProbabilities[next_start++].find(blankIndex);
-                                    if (lnProbIt != logProbabilities[next_start].end())
-                                        newLogProb += lnProbIt->second;
-                                }
-                                for (int64_t bl = pos; bl >= st64; bl--) {
-                                    newLogProb += logProbabilities[bl].find(targetD[targetIdx])->second;
-                                    findPaths(nextIdx, next_start, end + 1, newLogProb);
-                                    if (bl > 0) {
-                                        auto lnProbIt = logProbabilities[bl - 1].find(blankIndex);
-                                        if (lnProbIt != logProbabilities[bl - 1].end())
-                                            newLogProb -= lnProbIt->second;
-                                    }
-                                }
-                            }
-                        }
                     }
-                    }; // findPaths
-
-                // First tartget symbol
-                int64_t st64 = start0;
-                float newLogProb = 0.f;
-                if (!_ctcMergeRepeated) {
-                    for (size_t pos = start0; pos < end0; pos++) {
-                        newLogProb = 0.f;
-                        for (size_t bl = 0; bl < pos; bl++) {
-                            auto lnProbIt = logProbabilities[bl].find(blankIndex);
-                            if (lnProbIt != logProbabilities[bl].end())
-                                newLogProb += lnProbIt->second;
-                        }
-                        auto lnProbIt = logProbabilities[pos].find(targetD[0]);
-                        if (lnProbIt != logProbabilities[pos].end())
-                            newLogProb += lnProbIt->second;
-                        if (work_amount == actualLogitLen) {
-                            for (int64_t ble = pos + 1; ble < actualLogitLen; ble++) {
-                                auto lnProbIt = logProbabilities[ble].find(blankIndex);
-                                if (lnProbIt != logProbabilities[ble].end())
-                                    newLogProb += lnProbIt->second;
-                            }
-                        }
-                        if (decodedTargetLen > 1) {
-                            findPaths(1, pos + 1, work_amount + 1, newLogProb);
-                        } else {
-                            if (sumPerThread[ithr] == -float_inf)
-                                sumPerThread[ithr] = newLogProb;
-                            else if (newLogProb != -float_inf)
-                                sumPerThread[ithr] = sumPerThread[ithr] + std::log1pf(std::exp(newLogProb - sumPerThread[ithr]));
-                        }
+                    for (size_t s = 0lu; s < decodedTargetLen; s++) {
+                        logProbabilities[t][s] = logits[btcT + targetD[s]] - std::log(expSum);
                     }
-                } else {
-                    for (size_t pos = start0; pos < end0; pos++) {
-                        newLogProb = 0.f;
-                        size_t next_start = pos + 1;
-                        for (size_t bl = 0; bl < pos; bl++) {
-                            auto lnProbIt = logProbabilities[bl].find(blankIndex);
-                            if (lnProbIt != logProbabilities[bl].end())
-                                newLogProb += lnProbIt->second;
-                        }
-                        if (work_amount == actualLogitLen) {
-                            for (int64_t ble = pos + 1; ble < actualLogitLen; ble++) {
-                                auto lnProbIt = logProbabilities[ble].find(blankIndex);
-                                if (lnProbIt != logProbabilities[ble].end())
-                                    newLogProb += lnProbIt->second;
-                            }
+                    btcT += classesNum;
+                    if (++workCounter >= end) {
+                        return;
                     }
+                }
+                sT = 0lu;
+            } // for batch
+        }; // threadBody_2
+
+        parallel_nt(0, threadBody_2);
+
+        const auto float_inf = std::numeric_limits<float>::infinity();
+
+        auto sumLogs = [&float_inf](float log1, float log2) {
+            if (log1 == -float_inf) {
+                return log2;
+            } else if (log2 == -float_inf) {
+                return log1;
+            } else {
+                if (log1 > log2)
+                    return log1 + std::log1pf(std::exp(log2 - log1));
+                else
+                    return log2 + std::log1pf(std::exp(log1 - log2));
+            }
+        };
+
+        auto threadBody_3 = [&](const int ithr, const int nthr) {
+            size_t start(0lu), end(0lu);
+            splitter(batchNum, nthr, ithr, start, end);
+            if (start >= end)
+                return;
+
+            // As per "Connectionist Temporal Classification: Labelling Unsegmented Sequence Data
+            // with Recurrent Neural Networks", Graves et al., 2006, paragraph 4.1 (10)
+            for (size_t b = start; b < end; b++) {
+                auto& targetD = targetDB[b];
+                auto& logProbabilities = logProbabilitiesB[b];
+                const int actualLogitLen = logitsLength[b];
+                const int decodedTargetLen = decodedTargetLenB[b];
+                std::vector<std::vector<float>> logBwd(decodedTargetLen, std::vector<float>(actualLogitLen, -float_inf));
+                for (int s = decodedTargetLen - 2; s < decodedTargetLen; s++)
+                    logBwd[s][actualLogitLen - 1] = 0.f;
+
+                for (int t = actualLogitLen - 2; t >= 0; t--) {
+                    const int t_1 = t + 1;
+                    for (int s = std::max(0, decodedTargetLen - (2 * (actualLogitLen - t)));
+                            s < std::min(decodedTargetLen, 2 * (t_1)); s++) {
+                        if (_ctcMergeRepeated || targetD[s] == blankIndex) {
+                            logBwd[s][t] = sumLogs(logBwd[s][t],
+                                    logBwd[s][t_1] + logProbabilities[t_1][s]);
                         }
-                        if (decodedTargetLen > 1
-                                && targetD[0] == targetD[1]) {
-                            auto lnProbIt = logProbabilities[next_start++].find(blankIndex);
-                            if (lnProbIt != logProbabilities[next_start].end())
-                                newLogProb += lnProbIt->second;
+
+                        if (s + 1 < decodedTargetLen) {
+                            logBwd[s][t] = sumLogs(logBwd[s][t],
+                                    logBwd[s + 1][t_1] + logProbabilities[t_1][s + 1]);
                         }
-                        for (int64_t bl = pos; bl >= 0; bl--) {
-                            auto lnProbIt = logProbabilities[bl].find(targetD[0]);
-                            if (lnProbIt != logProbabilities[bl].end())
-                                newLogProb += lnProbIt->second;
-                            if (decodedTargetLen > 1) {
-                                findPaths(1, next_start, work_amount + 1, newLogProb);
-                            } else {
-                                if (sumPerThread[ithr] == -float_inf)
-                                    sumPerThread[ithr] = newLogProb;
-                                else if (newLogProb != -float_inf)
-                                    sumPerThread[ithr] = sumPerThread[ithr] + std::log1pf(std::exp(newLogProb - sumPerThread[ithr]));
-                            }
-                            if (bl > 0) {
-                                auto lnProbIt = logProbabilities[bl - 1].find(blankIndex);
-                                if (lnProbIt != logProbabilities[bl - 1].end())
-                                    newLogProb -= lnProbIt->second;
+
+                        if (s + 2 < decodedTargetLen) {
+                            if (targetD[s] != blankIndex && (!_ctcMergeRepeated || (targetD[s] != targetD[s + 2]))) {
+                                logBwd[s][t] = sumLogs(logBwd[s][t],
+                                        logBwd[s + 2][t_1] + logProbabilities[t_1][s + 2]);
+                            }
+                        }
+                    }
+                }
-                            }
-                        }
-                    }
-                }
-            }; // thread_body
-
-            parallel_nt(0, thread_body);
-            float res = -float_inf;
+                logBwd[0][0] += logProbabilities[0][0];
+                logBwd[1][0] += logProbabilities[0][(decodedTargetLen > 1) ? 1 : 0];
-            for (auto sum : sumPerThread) {
-                if (res == -float_inf) {
-                    res = sum;
-                } else if (sum != -float_inf) {
-                    if (res > sum)
-                        res = res + std::log1pf(std::exp(sum - res));
-                    else
-                        res = sum + std::log1pf(std::exp(res - sum));
-                }
-            }
+
+                dstData[b] = -sumLogs(logBwd[0][0], logBwd[1][0]);
+            } // for batch
+        }; // threadBody_3
-            dstData[b] = -res;
-        } // for (size_t b = 0; b < batchNum; b++)
+        parallel_nt(0, threadBody_3);
-        return OK;
+        return returnCode;
     } // execute
 
 protected:
@@ -334,8 +289,6 @@ protected:
 };
 
 REG_FACTORY_FOR(CTCLossImpl, CTCLoss);
-
 } // namespace Cpu
 } // namespace Extensions
 } // namespace InferenceEngine
-
-- 
2.7.4
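
Notes on the algorithm follow. Each is a self-contained illustrative C++ sketch, not part of the patch or the plugin code.

threadBody_1 rewrites every label sequence into the blank-augmented form that CTC operates on: a blank before each label plus one final blank, giving 2 * targetLen + 1 symbols. A minimal sketch of the plain path (no `unique`, no `preprocess_collapse_repeated`), with hypothetical example values:

#include <cstdio>
#include <vector>

int main() {
    const int blankIndex = 9;                  // assuming 10 classes, blank is last
    const std::vector<int> target = {3, 5, 5}; // raw label sequence
    std::vector<int> targetD(target.size() * 2 + 1);
    size_t decodedTargetLen = 0;
    for (int cls : target) {
        targetD[decodedTargetLen++] = blankIndex;
        targetD[decodedTargetLen++] = cls;
    }
    targetD[decodedTargetLen++] = blankIndex;
    for (int v : targetD)
        std::printf("%d ", v);                 // prints: 9 3 9 5 9 5 9
    std::printf("\n");
    return 0;
}

Precomputing this per batch is what lets the later passes index log-probabilities by position s in the augmented target instead of by class id, as the old per-path search did.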
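
threadBody_2 parallelizes the log-softmax over the total number of (batch, time) columns rather than over batches, so uneven logit lengths still balance. Each thread receives a flat [start, end) slice of workAmount2 and locates its first batch sB and time step sT by walking cumulative lengths. The sketch below mirrors that scan; splitSlice() is a simplified stand-in for the splitter() helper, so its exact splitting policy is an assumption:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Hand thread ithr of nthr a contiguous, nearly equal share of n items.
void splitSlice(size_t n, int nthr, int ithr, size_t& start, size_t& end) {
    const size_t chunk = n / nthr;
    const size_t rem = n % nthr;
    start = ithr * chunk + std::min<size_t>(ithr, rem);
    end = start + chunk + (static_cast<size_t>(ithr) < rem ? 1 : 0);
}

int main() {
    const std::vector<int> logitsLength = {5, 3, 7}; // per-batch time steps
    const size_t workAmount2 = 15;                   // sum of logitsLength
    const int nthr = 4;
    for (int ithr = 0; ithr < nthr; ithr++) {
        size_t start = 0, end = 0;
        splitSlice(workAmount2, nthr, ithr, start, end);
        // The same cumulative scan threadBody_2 performs.
        size_t sB = 0, sT = 0;
        int64_t cw = 0;
        const int64_t st = static_cast<int64_t>(start);
        for (; sB < logitsLength.size(); sB++) {
            cw += logitsLength[sB];
            if (cw >= st) {
                sT = logitsLength[sB] + st - cw;
                break;
            }
        }
        std::printf("thread %d: flat [%zu, %zu) -> batch %zu, t = %zu\n",
                    ithr, start, end, sB, sT);
    }
    return 0;
}

When start lands exactly on a batch boundary the scan yields sT == logitsLength[sB]; the per-batch time loop is then empty and work effectively begins at the next batch, which matches how the patch's loop structure behaves.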
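
The sumLogs lambda is a numerically stable log-sum-exp: ln(e^a + e^b) is computed by factoring out the larger argument so std::exp only ever sees a non-positive exponent, and -inf operands act as the identity of the sum. A minimal standalone sketch under IEEE float semantics:

#include <cmath>
#include <cstdio>
#include <limits>

static const float kNegInf = -std::numeric_limits<float>::infinity();

// ln(exp(log1) + exp(log2)) without overflow; -inf stands for log(0).
float sumLogs(float log1, float log2) {
    if (log1 == kNegInf)
        return log2;
    if (log2 == kNegInf)
        return log1;
    if (log1 > log2)
        return log1 + std::log1p(std::exp(log2 - log1));
    return log2 + std::log1p(std::exp(log1 - log2));
}

int main() {
    // Naive ln(exp(-1000) + exp(-1001)) underflows to ln(0) = -inf;
    // the factored form stays finite.
    std::printf("%f\n", sumLogs(-1000.f, -1001.f)); // ~ -999.6867
    return 0;
}

Seeding accumulators with -float_inf, as threadBody_3 does for logBwd, makes the first sumLogs call behave like plain assignment.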
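
threadBody_3 is the backward (beta) recursion of Graves et al. (2006), eq. (10): over the blank-augmented target, position s at time t can stay, advance one position, or skip the blank between two distinct labels. The toy program below mirrors exactly those three transitions (including the patch's _ctcMergeRepeated gate on the stay transition) for a 3-step, 3-class uniform distribution and a single-label target; it is an illustration, not the plugin code:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

static const float kNegInf = -std::numeric_limits<float>::infinity();

float sumLogs(float a, float b) {
    if (a == kNegInf) return b;
    if (b == kNegInf) return a;
    return a > b ? a + std::log1p(std::exp(b - a))
                 : b + std::log1p(std::exp(a - b));
}

int main() {
    const int blank = 2;                                // 3 classes, blank is last
    const bool ctcMergeRepeated = true;
    const std::vector<int> targetD = {blank, 0, blank}; // label sequence {0}
    const int S = static_cast<int>(targetD.size());     // decodedTargetLen
    const int T = 3;                                    // actualLogitLen

    // Uniform log-softmax: every class has probability 1/3 at every step.
    std::vector<std::vector<float>> logProb(T, std::vector<float>(S, std::log(1.f / 3.f)));

    std::vector<std::vector<float>> logBwd(S, std::vector<float>(T, kNegInf));
    for (int s = S - 2; s < S; s++)                     // finish on last label or final blank
        logBwd[s][T - 1] = 0.f;

    for (int t = T - 2; t >= 0; t--) {
        const int t1 = t + 1;
        for (int s = std::max(0, S - 2 * (T - t)); s < std::min(S, 2 * t1); s++) {
            if (ctcMergeRepeated || targetD[s] == blank)    // stay
                logBwd[s][t] = sumLogs(logBwd[s][t], logBwd[s][t1] + logProb[t1][s]);
            if (s + 1 < S)                                  // advance one position
                logBwd[s][t] = sumLogs(logBwd[s][t], logBwd[s + 1][t1] + logProb[t1][s + 1]);
            if (s + 2 < S && targetD[s] != blank &&         // skip a blank
                    (!ctcMergeRepeated || targetD[s] != targetD[s + 2]))
                logBwd[s][t] = sumLogs(logBwd[s][t], logBwd[s + 2][t1] + logProb[t1][s + 2]);
        }
    }

    // Fold in the t = 0 emissions and negate, as the patch does.
    logBwd[0][0] += logProb[0][0];
    logBwd[1][0] += logProb[0][1];
    std::printf("CTC loss: %f\n", -sumLogs(logBwd[0][0], logBwd[1][0]));
    return 0;
}

Six of the 27 equally likely 3-step paths collapse to the target, so the program prints -ln(6/27) = ln(4.5) ~ 1.504077, which is what the fused logBwd/dstData computation in the patch would produce for this batch entry.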