[CPU] CTCLoss performance improvement.
authorNikolay Shchegolev <nikolay.shchegolev@intel.com>
Mon, 5 Oct 2020 08:58:54 +0000 (11:58 +0300)
committerAlexander Peskov <alexander.peskov@intel.com>
Mon, 19 Oct 2020 10:01:39 +0000 (13:01 +0300)
inference-engine/src/mkldnn_plugin/nodes/ctc_loss.cpp

index f29b7ce..1453a14 100644 (file)
@@ -60,6 +60,8 @@ public:
     StatusCode execute(std::vector<Blob::Ptr>& inputs,
                        std::vector<Blob::Ptr>& outputs,
                        ResponseDesc *resp) noexcept override {
+        StatusCode returnCode = OK;
+
         const float* logits = inputs[0]->cbuffer().as<const float*>() +
             inputs[0]->getTensorDesc().getBlockingDesc().getOffsetPadding();
         const int* logitsLength = inputs[1]->cbuffer().as<const int*>() +
@@ -72,257 +74,210 @@ public:
             outputs[0]->getTensorDesc().getBlockingDesc().getOffsetPadding();
 
         const auto& logitsShape = inputs[0]->getTensorDesc().getDims();
-        const auto batchNum = logitsShape[0];
-        const auto maxTime = logitsShape[1];
-        const auto classesNum = logitsShape[2];
+        const size_t batchNum = logitsShape[0];
+        const size_t maxTime = logitsShape[1];
+        const size_t classesNum = logitsShape[2];
 
         int blankIndex = classesNum - 1;
         if (inputs.size() > 4) {
             blankIndex = inputs[4]->cbuffer().as<const int*>()[0];
         }
 
-        std::vector<int> targetD(maxTime);
-
-        const size_t TC = maxTime * classesNum;
-
-        for (size_t b = 0; b < batchNum; b++) {
-            const int actualLogitLen = logitsLength[b];
-            const int actualTargetLen = labelsLength[b];
-            if (actualLogitLen < 0 || actualTargetLen < 0 || actualLogitLen > maxTime || actualTargetLen > maxTime
-                    || actualTargetLen > actualLogitLen) {
-                std::string errorMsg = _logPrefix + ". Logit or label length cannot be greater than max sequence length. "
-                    + "Also a label length cannot be greater than a logit length"
-                    + " and both cannot be negative.\nMaxSeqLen: "
-                    + std::to_string(maxTime) + "; Logit len: " + std::to_string(actualLogitLen)
-                    + "; Label len: " + std::to_string(actualTargetLen);
-                errorMsg.copy(resp->msg, sizeof(resp->msg) - 1);
-                return GENERAL_ERROR;
-            }
-
-            const int* target = &labels[b * maxTime];
-            // Decoding target: merge repeated characters if preprocess_collapse_repeated == True,
-            // find unique elemnts if unique == True
-            size_t decodedTargetLen = 0lu;
-            if (_unique) {
-                std::unordered_set<int> uniqVals;
-                for (size_t t = 0lu; t < actualTargetLen; t++) {
-                    if (uniqVals.find(target[t]) != uniqVals.end()) {
-                        continue;
-                    }
-                    uniqVals.insert(target[t]);
-                    targetD[decodedTargetLen++] = target[t];
+        std::vector<int> decodedTargetLenB(batchNum, 0);
+        std::vector<std::vector<int>> targetDB(batchNum);
+        std::vector<std::vector<std::vector<float>>> logProbabilitiesB(batchNum);
+        size_t workAmount2 = 0lu;
+        std::vector<std::string> errorMsgB(parallel_get_max_threads());
+
+        auto threadBody_1 = [&](const int ithr, const int nthr) {
+            size_t start(0lu), end(0lu);
+            splitter(batchNum, nthr, ithr, start, end);
+            if (start >= end)
+                return;
+
+            for (size_t b = start; b < end; b++) {
+                if (logitsLength[b] < 0 || labelsLength[b] < 0 || logitsLength[b] > maxTime || labelsLength[b] > logitsLength[b]) {
+                    errorMsgB[ithr] = _logPrefix + ". Logit length cannot be greater than max sequence length. "
+                        + "Label length cannot be greater than a logit length"
+                        + " and both cannot be negative.\nMaxSeqLen: "
+                        + std::to_string(maxTime) + "; Logit len: " + std::to_string(logitsLength[b])
+                        + "; Label len: " + std::to_string(labelsLength[b]);
+                    returnCode = GENERAL_ERROR;
+                    return;
                 }
-            } else if (_preprocessCollapseRepeated) {
-                int prevValue = target[0];
-                targetD[decodedTargetLen++] = target[0];
-                for (size_t t = 1lu; t < actualTargetLen; t++) {
-                    if (target[t] == prevValue) {
-                        continue;
+                const size_t actualLogitLen = logitsLength[b];
+                const size_t actualTargetLen = labelsLength[b];
+                size_t decodedTargetLen = 0lu;
+
+                // Decoding target: merge repeated characters if preprocess_collapse_repeated == True,
+                // find unique elemnts if unique == True.
+                // Inserts blanks before each index and a blank at the end.
+                const int* target = &labels[b * maxTime];
+                targetDB[b].resize(actualTargetLen * 2 + 1);
+                auto& targetD = targetDB[b];
+                if (_unique) {
+                    std::unordered_set<int> uniqVals;
+                    for (size_t t = 0lu; t < actualTargetLen; t++) {
+                        if (uniqVals.find(target[t]) != uniqVals.end()) {
+                            continue;
+                        }
+                        uniqVals.insert(target[t]);
+                        targetD[decodedTargetLen++] = blankIndex;
+                        targetD[decodedTargetLen++] = target[t];
+                    }
+                    targetD[decodedTargetLen++] = blankIndex;
+                } else if (_preprocessCollapseRepeated) {
+                    auto prevValue = target[0];
+                    targetD[decodedTargetLen++] = blankIndex;
+                    targetD[decodedTargetLen++] = target[0];
+                    for (size_t t = 1lu; t < actualTargetLen; t++) {
+                        if (target[t] == prevValue) {
+                            continue;
+                        }
+                        targetD[decodedTargetLen++] = blankIndex;
+                        targetD[decodedTargetLen++] = prevValue = target[t];
                     }
-                    targetD[decodedTargetLen++] = target[t];
-                    prevValue = target[t];
+                    targetD[decodedTargetLen++] = blankIndex;
+                } else {
+                    for (size_t t = 0lu; t < actualTargetLen; t++) {
+                        targetD[decodedTargetLen++] = blankIndex;
+                        targetD[decodedTargetLen++] = target[t];
+                    }
+                    targetD[decodedTargetLen++] = blankIndex;
                 }
-            } else {
-                std::copy(target, target + actualTargetLen, targetD.data());
-                decodedTargetLen = actualTargetLen;
-            }
-
-            const size_t BTC = b * TC;
+                decodedTargetLenB[b] = decodedTargetLen;
 
-            std::vector<std::unordered_map<size_t, float>> logProbabilities(actualLogitLen);
-            float logProb = 0.f, kExp = 0.f;
-            for (size_t t = 0; t < actualLogitLen; t++) {
-                kExp = 0.f;
-                const size_t btcT = BTC + classesNum * t;
-                for (size_t c = 0; c < classesNum; c++) {
-                    kExp += std::exp(logits[btcT + c]);
+                auto& logProbabilities = logProbabilitiesB[b];
+                logProbabilities.resize(actualLogitLen);
+                for (size_t ll = 0; ll < actualLogitLen; ll++) {
+                    logProbabilities[ll].resize(decodedTargetLen);
                 }
-                for (size_t s = 0; s < decodedTargetLen; s++) {
-                    logProb = logits[btcT + targetD[s]] - std::log(kExp);
-                    logProbabilities[t].insert({targetD[s], logProb});
-                }
-                logProb = logits[btcT + blankIndex] - std::log(kExp);
-                logProbabilities[t].insert({blankIndex, logProb});
+                workAmount2 += actualLogitLen;
+            } // for batch
+        }; // threadBody_1
+
+        parallel_nt(0, threadBody_1);
+        if (returnCode != OK) {
+            std::string resErr("");
+            for (auto& err : errorMsgB) {
+                if (!err.empty())
+                    resErr += err + "\n";
+                resErr.copy(resp->msg, sizeof(resp->msg) - 1);
             }
+            return returnCode;
+        }
 
-            const auto float_inf = std::numeric_limits<float>::infinity();
-            size_t work_amount = actualLogitLen - decodedTargetLen + 1lu;
-            std::vector<float> sumPerThread(parallel_get_max_threads(), -float_inf);
+        const size_t TC = maxTime * classesNum;
 
-            // Looking for aligned paths
-            auto thread_body = [&](const int ithr, const int nthr) {
-                size_t start0(0lu), end0(0lu);
-                splitter(work_amount, nthr, ithr, start0, end0);
-                if (start0 >= end0)
-                    return;
-                if (ithr >= sumPerThread.size())
-                    sumPerThread.push_back(-float_inf);
-
-                std::function<void(size_t, size_t, size_t, float)> findPaths =
-                        [&](size_t targetIdx, size_t start, size_t end, float prevLogProb) {
-                    if (end > actualLogitLen) {
-                        if (sumPerThread[ithr] == -float_inf) {
-                            sumPerThread[ithr] = prevLogProb;
-                        } else if (prevLogProb != -float_inf) {
-                            if (sumPerThread[ithr] > prevLogProb)
-                                sumPerThread[ithr] = sumPerThread[ithr] + std::log1pf(std::exp(prevLogProb - sumPerThread[ithr]));
-                            else
-                                sumPerThread[ithr] = prevLogProb + std::log1pf(std::exp(sumPerThread[ithr] - prevLogProb));
-                        }
-                        return;
+        auto threadBody_2 = [&](const int ithr, const int nthr) {
+            size_t start(0lu), end(0lu);
+            size_t sB(0lu), sT(0lu);
+            splitter(workAmount2, nthr, ithr, start, end);
+            if (start >= end)
+                return;
+            int64_t cw = 0, st = start;
+            for (; sB < batchNum; sB++) {
+                cw += logitsLength[sB];
+                if (cw >= st) {
+                    sT = logitsLength[sB] + st - cw;
+                    break;
+                }
+            }
+            size_t workCounter = start;
+
+            for (size_t b = sB; b < batchNum; b++) {
+                const size_t actualLogitLen = logitsLength[b];
+                const size_t decodedTargetLen = decodedTargetLenB[b];
+                auto& logProbabilities = logProbabilitiesB[b];
+                auto& targetD = targetDB[b];
+
+                double expSum = 0.0;
+                size_t btcT = b * TC + sT * classesNum;
+                // logProbabilities = logSoftmax = logits[b][t][c] - ln(sum_c(exp(logits[b][t])))
+                for (size_t t = sT; t < actualLogitLen; t++) {
+                    expSum = 0.0;
+                    for (size_t c = 0lu; c < classesNum; c++) {
+                        expSum += std::exp(logits[btcT + c]);
                     }
-
-                    size_t nextIdx = targetIdx + 1;
-                    int64_t st64 = start;
-                    float newLogProb = prevLogProb;
-                    if (!_ctcMergeRepeated) {
-                        for (size_t pos = start; pos < end; pos++) {
-                            newLogProb = prevLogProb;
-                            for (size_t bl = start; bl < pos; bl++) {
-                                auto lnProbIt = logProbabilities[bl].find(blankIndex);
-                                if (lnProbIt != logProbabilities[bl].end())
-                                    newLogProb += lnProbIt->second;
-                            }
-                            auto lnProbIt = logProbabilities[pos].find(targetD[targetIdx]);
-                            if (lnProbIt != logProbabilities[pos].end())
-                                newLogProb += lnProbIt->second;
-                            if (end == actualLogitLen) {
-                                for (int64_t ble = pos + 1; ble < actualLogitLen; ble++) {
-                                    auto lnProbIt = logProbabilities[ble].find(blankIndex);
-                                    if (lnProbIt != logProbabilities[ble].end())
-                                        newLogProb += lnProbIt->second;
-                                }
-                            }
-                            findPaths(nextIdx, pos + 1, end + 1, newLogProb);
-                        }
-                    } else {
-                        for (size_t pos = start; pos < end; pos++) {
-                            newLogProb = prevLogProb;
-                            size_t next_start = pos + 1;
-                            for (size_t bl = start; bl < pos; bl++) {
-                                auto lnProbIt = logProbabilities[bl].find(blankIndex);
-                                if (lnProbIt != logProbabilities[bl].end())
-                                    newLogProb += lnProbIt->second;
-                            }
-                            if (end == actualLogitLen) {
-                                for (int64_t ble = pos + 1; ble < actualLogitLen; ble++) {
-                                    auto lnProbIt = logProbabilities[ble].find(blankIndex);
-                                    if (lnProbIt != logProbabilities[ble].end())
-                                        newLogProb += lnProbIt->second;
-                                }
-                            }
-                            if (targetIdx < decodedTargetLen - 1
-                                   && targetD[targetIdx] == targetD[targetIdx + 1]) {
-                                auto lnProbIt = logProbabilities[next_start++].find(blankIndex);
-                                if (lnProbIt != logProbabilities[next_start].end())
-                                    newLogProb += lnProbIt->second;
-                            }
-                            for (int64_t bl = pos; bl >= st64; bl--) {
-                                newLogProb += logProbabilities[bl].find(targetD[targetIdx])->second;
-                                findPaths(nextIdx, next_start, end + 1, newLogProb);
-                                if (bl > 0) {
-                                    auto lnProbIt = logProbabilities[bl - 1].find(blankIndex);
-                                    if (lnProbIt != logProbabilities[bl - 1].end())
-                                        newLogProb -= lnProbIt->second;
-                                }
-                            }
-                        }
+                    for (size_t s = 0lu; s < decodedTargetLen; s++) {
+                        logProbabilities[t][s] = logits[btcT + targetD[s]] - std::log(expSum);
                     }
-                }; // findPaths
-
-                // First tartget symbol
-                int64_t st64 = start0;
-                float newLogProb = 0.f;
-                if (!_ctcMergeRepeated) {
-                    for (size_t pos = start0; pos < end0; pos++) {
-                        newLogProb = 0.f;
-                        for (size_t bl = 0; bl < pos; bl++) {
-                            auto lnProbIt = logProbabilities[bl].find(blankIndex);
-                            if (lnProbIt != logProbabilities[bl].end())
-                                newLogProb += lnProbIt->second;
-                        }
-                        auto lnProbIt = logProbabilities[pos].find(targetD[0]);
-                        if (lnProbIt != logProbabilities[pos].end())
-                            newLogProb += lnProbIt->second;
-                        if (work_amount == actualLogitLen) {
-                            for (int64_t ble = pos + 1; ble < actualLogitLen; ble++) {
-                                auto lnProbIt = logProbabilities[ble].find(blankIndex);
-                                if (lnProbIt != logProbabilities[ble].end())
-                                    newLogProb += lnProbIt->second;
-                            }
-                        }
-                        if (decodedTargetLen > 1) {
-                            findPaths(1, pos + 1, work_amount + 1, newLogProb);
-                        } else {
-                            if (sumPerThread[ithr] == -float_inf)
-                                sumPerThread[ithr] = newLogProb;
-                            else if (newLogProb != -float_inf)
-                                sumPerThread[ithr] = sumPerThread[ithr] + std::log1pf(std::exp(newLogProb - sumPerThread[ithr]));
-                        }
+                    btcT += classesNum;
+                    if (++workCounter >= end) {
+                        return;
                     }
-                } else {
-                    for (size_t pos = start0; pos < end0; pos++) {
-                        newLogProb = 0.f;
-                        size_t next_start = pos + 1;
-                        for (size_t bl = 0; bl < pos; bl++) {
-                            auto lnProbIt = logProbabilities[bl].find(blankIndex);
-                            if (lnProbIt != logProbabilities[bl].end())
-                                newLogProb += lnProbIt->second;
-                        }
-                        if (work_amount == actualLogitLen) {
-                            for (int64_t ble = pos + 1; ble < actualLogitLen; ble++) {
-                                auto lnProbIt = logProbabilities[ble].find(blankIndex);
-                                if (lnProbIt != logProbabilities[ble].end())
-                                    newLogProb += lnProbIt->second;
-                            }
+                }
+                sT = 0lu;
+            }  // for batch
+        }; // threadBody_2
+
+        parallel_nt(0, threadBody_2);
+
+        const auto float_inf = std::numeric_limits<float>::infinity();
+
+        auto sumLogs = [&float_inf](float log1, float log2) {
+            if (log1 == -float_inf) {
+                return log2;
+            } else if (log2 == -float_inf) {
+                return log1;
+            } else {
+                if (log1 > log2)
+                    return log1 + std::log1pf(std::exp(log2 - log1));
+                else
+                    return log2 + std::log1pf(std::exp(log1 - log2));
+            }
+        };
+
+        auto threadBody_3 = [&](const int ithr, const int nthr) {
+            size_t start(0lu), end(0lu);
+            splitter(batchNum, nthr, ithr, start, end);
+            if (start >= end)
+                return;
+
+            // As per Connectionist Temporal Classification - Labeling Unsegmented Sequence Data with Recurrent Neural Networks:
+            // Graves et al., 2016, paragraph 4.1 (10)
+            for (size_t b = start; b < end; b++) {
+                auto& targetD = targetDB[b];
+                auto& logProbabilities = logProbabilitiesB[b];
+                const int actualLogitLen = logitsLength[b];
+                const int decodedTargetLen = decodedTargetLenB[b];
+                std::vector<std::vector<float>> logBwd(decodedTargetLen, std::vector<float>(actualLogitLen, -float_inf));
+                for (int s = decodedTargetLen - 2; s < decodedTargetLen; s++)
+                    logBwd[s][actualLogitLen - 1] = 0.f;
+
+                for (int t = actualLogitLen - 2; t >= 0; t--) {
+                    const int t_1 = t + 1;
+                    for (int s = std::max(0, decodedTargetLen - (2 * (actualLogitLen - t)));
+                            s < std::min(decodedTargetLen, 2 * (t_1)); s++) {
+                        if (_ctcMergeRepeated || targetD[s] == blankIndex) {
+                            logBwd[s][t] = sumLogs(logBwd[s][t],
+                                logBwd[s][t_1] + logProbabilities[t_1][s]);
                         }
-                        if (decodedTargetLen > 1
-                               && targetD[0] == targetD[1]) {
-                            auto lnProbIt = logProbabilities[next_start++].find(blankIndex);
-                            if (lnProbIt != logProbabilities[next_start].end())
-                                newLogProb += lnProbIt->second;
+
+                        if (s + 1 < decodedTargetLen) {
+                            logBwd[s][t] = sumLogs(logBwd[s][t],
+                                logBwd[s + 1][t_1] + logProbabilities[t_1][s + 1]);
                         }
-                        for (int64_t bl = pos; bl >= 0; bl--) {
-                            auto lnProbIt = logProbabilities[bl].find(targetD[0]);
-                            if (lnProbIt != logProbabilities[bl].end())
-                                newLogProb += lnProbIt->second;
-                            if (decodedTargetLen > 1) {
-                                findPaths(1, next_start, work_amount + 1, newLogProb);
-                            } else {
-                                if (sumPerThread[ithr] == -float_inf)
-                                    sumPerThread[ithr] = newLogProb;
-                                else if (newLogProb != -float_inf)
-                                    sumPerThread[ithr] = sumPerThread[ithr] + std::log1pf(std::exp(newLogProb - sumPerThread[ithr]));
-                            }
-                            if (bl > 0) {
-                                auto lnProbIt = logProbabilities[bl - 1].find(blankIndex);
-                                if (lnProbIt != logProbabilities[bl - 1].end())
-                                    newLogProb -= lnProbIt->second;
+
+                        if (s + 2 < decodedTargetLen) {
+                            if (targetD[s] != blankIndex && (!_ctcMergeRepeated || (targetD[s] != targetD[s + 2]))) {
+                                logBwd[s][t] = sumLogs(logBwd[s][t],
+                                    logBwd[s + 2][t_1] + logProbabilities[t_1][s + 2]);
                             }
                         }
                     }
                 }
-            }; // thread_body
-
-            parallel_nt(0, thread_body);
 
-            float res = -float_inf;
+                logBwd[0][0] += logProbabilities[0][0];
+                logBwd[1][0] += logProbabilities[0][(decodedTargetLen > 1) ? 1 : 0];
 
-            for (auto sum : sumPerThread) {
-                if (res == -float_inf) {
-                    res = sum;
-                } else if (sum != -float_inf) {
-                    if (res > sum)
-                        res = res + std::log1pf(std::exp(sum - res));
-                    else
-                        res = sum + std::log1pf(std::exp(res - sum));
-                }
-            }
+                dstData[b] = -sumLogs(logBwd[0][0], logBwd[1][0]);
+            } // for batch
+        }; // threadBody_3
 
-            dstData[b] = -res;
-        } // for (size_t b = 0; b < batchNum; b++)
+        parallel_nt(0, threadBody_3);
 
-        return OK;
+        return returnCode;
     } // execute
 
 protected:
@@ -334,8 +289,6 @@ protected:
 };
 
 REG_FACTORY_FOR(CTCLossImpl, CTCLoss);
-
 }  // namespace Cpu
 }  // namespace Extensions
 }  // namespace InferenceEngine
-