From 8433620295891c184ce4edd86bbd5ad6440eda45 Mon Sep 17 00:00:00 2001
From: Dmitry Kurtaev
Date: Sun, 22 Mar 2020 00:20:36 +0300
Subject: [PATCH] Bidirectional LSTM

---
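Notes (this section sits below the fold and is not part of the commit
message):

With "bidirectional" set, the layer packs both directions into the
existing blobs: Wh and Wx stack the reverse-direction weights under the
forward ones along the rows, the bias stacks along the columns, and the
innermost output dimension is doubled. forward() then runs the same gate
pipeline once per direction on rowRange/colRange slices, walking the
time axis backwards for the second direction, and writes each direction
into its own half of the output columns.

A minimal consumer-side sketch of the resulting layout follows. The
model file name matches the new test; the input sizes and the dummy
input data are illustrative only, not part of this patch.

    #include <opencv2/dnn.hpp>

    int main()
    {
        cv::dnn::Net net = cv::dnn::readNetFromONNX("lstm_bidirectional.onnx");

        // ONNX LSTM input layout: [seq_length, batch, input_size].
        int inpShape[] = {5, 1, 10};  // hypothetical sizes
        cv::Mat x(3, inpShape, CV_32F);
        x.setTo(cv::Scalar::all(0.1));

        net.setInput(x);
        cv::Mat y = net.forward();

        // The innermost dimension is 2*hidden: the first half holds the
        // forward direction, the second half the reverse one (see the
        // hOutTs.colRange(...) split in recurrent_layers.cpp below).
        int hidden = y.size[y.dims - 1] / 2;
        cv::Mat y2d = y.reshape(1, static_cast<int>(y.total()) / (2 * hidden));
        cv::Mat fwd = y2d.colRange(0, hidden);
        cv::Mat rev = y2d.colRange(hidden, 2 * hidden);
        return 0;
    }

Slicing the packed blobs per direction keeps a single gate pipeline that
is reused for both directions instead of duplicating the per-time-step
code.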
 modules/dnn/src/layers/recurrent_layers.cpp | 162 +++++++++++++++-------------
 modules/dnn/src/onnx/onnx_importer.cpp      |  43 ++++----
 modules/dnn/test/test_onnx_importer.cpp     |   5 +
 3 files changed, 116 insertions(+), 94 deletions(-)

diff --git a/modules/dnn/src/layers/recurrent_layers.cpp b/modules/dnn/src/layers/recurrent_layers.cpp
index 26d2ea9..69606a6 100644
--- a/modules/dnn/src/layers/recurrent_layers.cpp
+++ b/modules/dnn/src/layers/recurrent_layers.cpp
@@ -93,6 +93,7 @@ class LSTMLayerImpl CV_FINAL : public LSTMLayer
     float forgetBias, cellClip;
     bool useCellClip, usePeephole;
     bool reverse;   // If true, go in negative direction along the time axis
+    bool bidirectional;  // If true, produces both forward and reversed directions along time axis
 
 public:
 
@@ -101,6 +102,7 @@
     {
         setParamsFrom(params);
 
+        bidirectional = params.get<bool>("bidirectional", false);
         if (!blobs.empty())
         {
             CV_Assert(blobs.size() >= 3);
@@ -113,7 +115,7 @@
             CV_CheckEQ(Wh.dims, 2, "");
             CV_CheckEQ(Wx.dims, 2, "");
             CV_CheckEQ(Wh.rows, Wx.rows, "");
-            CV_CheckEQ(Wh.rows, 4*Wh.cols, "");
+            CV_CheckEQ(Wh.rows, (1 + static_cast<int>(bidirectional))*4*Wh.cols, "");
             CV_CheckEQ(Wh.rows, (int)bias.total(), "");
             CV_Assert(Wh.type() == Wx.type() && Wx.type() == bias.type());
 
@@ -136,6 +138,7 @@
         useCellClip = params.get<bool>("use_cell_clip", false);
         usePeephole = params.get<bool>("use_peephole", false);
         reverse = params.get<bool>("reverse", false);
+        CV_Assert(!reverse || !bidirectional);
 
         allocated = false;
         outTailShape.clear();
@@ -207,6 +210,7 @@
         outResShape.push_back(_numSamples);
         outResShape.insert(outResShape.end(), outTailShape_.begin(), outTailShape_.end());
+        outResShape.back() *= (1 + static_cast<int>(bidirectional));
 
         size_t noutputs = produceCellOutput ? 2 : 1;
         outputs.assign(noutputs, outResShape);
 
@@ -253,6 +257,7 @@
         outTsShape.clear();
         outTsShape.push_back(numSamples);
         outTsShape.insert(outTsShape.end(), outTailShape.begin(), outTailShape.end());
+        outTsShape.back() *= (1 + static_cast<int>(bidirectional));
 
         allocated = true;
     }
@@ -273,91 +278,96 @@
         outputs_arr.getMatVector(output);
         internals_arr.getMatVector(internals);
 
-        const Mat &Wh = blobs[0];
-        const Mat &Wx = blobs[1];
-        const Mat &bias = blobs[2];
-
-        int numOut = Wh.size[1];
-
-        Mat hInternal = internals[0], cInternal = internals[1],
-                dummyOnes = internals[2], gates = internals[3];
-        hInternal.setTo(0.);
-        cInternal.setTo(0.);
-        dummyOnes.setTo(1.);
-
-        int numSamplesTotal = numTimeStamps*numSamples;
-        Mat xTs = input[0].reshape(1, numSamplesTotal);
-
-        Mat hOutTs = output[0].reshape(1, numSamplesTotal);
-        Mat cOutTs = produceCellOutput ? output[1].reshape(1, numSamplesTotal) : Mat();
-
-        int tsStart, tsEnd, tsInc;
-        if (reverse) {
-            tsStart = numTimeStamps - 1;
-            tsEnd = -1;
-            tsInc = -1;
-        }
-        else {
-            tsStart = 0;
-            tsEnd = numTimeStamps;
-            tsInc = 1;
-        }
-        for (int ts = tsStart; ts != tsEnd; ts += tsInc)
+        const int numDirs = 1 + static_cast<int>(bidirectional);
+        for (int i = 0; i < numDirs; ++i)
         {
-            Range curRowRange(ts*numSamples, (ts + 1)*numSamples);
-            Mat xCurr = xTs.rowRange(curRowRange);
+            const Mat &Wh = blobs[0].rowRange(i * blobs[0].rows / numDirs, (i + 1) * blobs[0].rows / numDirs);
+            const Mat &Wx = blobs[1].rowRange(i * blobs[1].rows / numDirs, (i + 1) * blobs[1].rows / numDirs);
+            const Mat &bias = blobs[2].colRange(i * blobs[2].cols / numDirs, (i + 1) * blobs[2].cols / numDirs);
+
+            int numOut = Wh.size[1];
+
+            Mat hInternal = internals[0], cInternal = internals[1],
+                    dummyOnes = internals[2], gates = internals[3];
+            hInternal.setTo(0.);
+            cInternal.setTo(0.);
+            dummyOnes.setTo(1.);
+
+            int numSamplesTotal = numTimeStamps*numSamples;
+            Mat xTs = input[0].reshape(1, numSamplesTotal);
+
+            Mat hOutTs = output[0].reshape(1, numSamplesTotal);
+            hOutTs = hOutTs.colRange(i * hOutTs.cols / numDirs, (i + 1) * hOutTs.cols / numDirs);
+            Mat cOutTs = produceCellOutput ? output[1].reshape(1, numSamplesTotal) : Mat();
+
+            int tsStart, tsEnd, tsInc;
+            if (reverse || i == 1) {
+                tsStart = numTimeStamps - 1;
+                tsEnd = -1;
+                tsInc = -1;
+            }
+            else {
+                tsStart = 0;
+                tsEnd = numTimeStamps;
+                tsInc = 1;
+            }
+            for (int ts = tsStart; ts != tsEnd; ts += tsInc)
+            {
+                Range curRowRange(ts*numSamples, (ts + 1)*numSamples);
+                Mat xCurr = xTs.rowRange(curRowRange);
 
-            gemm(xCurr, Wx, 1, gates, 0, gates, GEMM_2_T);      // Wx * x_t
-            gemm(hInternal, Wh, 1, gates, 1, gates, GEMM_2_T);  //+Wh * h_{t-1}
-            gemm(dummyOnes, bias, 1, gates, 1, gates);          //+b
+                gemm(xCurr, Wx, 1, gates, 0, gates, GEMM_2_T);      // Wx * x_t
+                gemm(hInternal, Wh, 1, gates, 1, gates, GEMM_2_T);  //+Wh * h_{t-1}
+                gemm(dummyOnes, bias, 1, gates, 1, gates);          //+b
 
-            Mat gateI = gates.colRange(0*numOut, 1*numOut);
-            Mat gateF = gates.colRange(1*numOut, 2*numOut);
-            Mat gateO = gates.colRange(2*numOut, 3*numOut);
-            Mat gateG = gates.colRange(3*numOut, 4*numOut);
+                Mat gateI = gates.colRange(0*numOut, 1*numOut);
+                Mat gateF = gates.colRange(1*numOut, 2*numOut);
+                Mat gateO = gates.colRange(2*numOut, 3*numOut);
+                Mat gateG = gates.colRange(3*numOut, 4*numOut);
 
-            if (forgetBias)
-                add(gateF, forgetBias, gateF);
+                if (forgetBias)
+                    add(gateF, forgetBias, gateF);
 
-            if (usePeephole)
-            {
-                Mat gatesIF = gates.colRange(0, 2*numOut);
-                gemm(cInternal, blobs[3], 1, gateI, 1, gateI);
-                gemm(cInternal, blobs[4], 1, gateF, 1, gateF);
-                sigmoid(gatesIF, gatesIF);
-            }
-            else
-            {
-                Mat gatesIFO = gates.colRange(0, 3*numOut);
-                sigmoid(gatesIFO, gatesIFO);
-            }
+                if (usePeephole)
+                {
+                    Mat gatesIF = gates.colRange(0, 2*numOut);
+                    gemm(cInternal, blobs[3], 1, gateI, 1, gateI);
+                    gemm(cInternal, blobs[4], 1, gateF, 1, gateF);
+                    sigmoid(gatesIF, gatesIF);
+                }
+                else
+                {
+                    Mat gatesIFO = gates.colRange(0, 3*numOut);
+                    sigmoid(gatesIFO, gatesIFO);
+                }
 
-            tanh(gateG, gateG);
+                tanh(gateG, gateG);
 
-            //compute c_t
-            multiply(gateF, cInternal, gateF);  // f_t (*) c_{t-1}
-            multiply(gateI, gateG, gateI);      // i_t (*) g_t
-            add(gateF, gateI, cInternal);       // c_t = f_t (*) c_{t-1} + i_t (*) g_t
+                //compute c_t
+                multiply(gateF, cInternal, gateF);  // f_t (*) c_{t-1}
+                multiply(gateI, gateG, gateI);      // i_t (*) g_t
+                add(gateF, gateI, cInternal);       // c_t = f_t (*) c_{t-1} + i_t (*) g_t
 
-            if (useCellClip)
-            {
-                min(cInternal, cellClip, cInternal);
-                max(cInternal, -cellClip, cInternal);
-            }
-            if (usePeephole)
-            {
-                gemm(cInternal, blobs[5], 1, gateO, 1, gateO);
-                sigmoid(gateO, gateO);
-            }
+                if (useCellClip)
+                {
+                    min(cInternal, cellClip, cInternal);
+                    max(cInternal, -cellClip, cInternal);
+                }
+                if (usePeephole)
+                {
+                    gemm(cInternal, blobs[5], 1, gateO, 1, gateO);
+                    sigmoid(gateO, gateO);
+                }
 
-            //compute h_t
-            tanh(cInternal, hInternal);
-            multiply(gateO, hInternal, hInternal);
+                //compute h_t
+                tanh(cInternal, hInternal);
+                multiply(gateO, hInternal, hInternal);
 
-            //save results in output blobs
-            hInternal.copyTo(hOutTs.rowRange(curRowRange));
-            if (produceCellOutput)
-                cInternal.copyTo(cOutTs.rowRange(curRowRange));
+                //save results in output blobs
+                hInternal.copyTo(hOutTs.rowRange(curRowRange));
+                if (produceCellOutput)
+                    cInternal.copyTo(cOutTs.rowRange(curRowRange));
+            }
         }
     }
 };
diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp
index b243a98..79386e6 100644
--- a/modules/dnn/src/onnx/onnx_importer.cpp
+++ b/modules/dnn/src/onnx/onnx_importer.cpp
@@ -630,37 +630,44 @@ void ONNXImporter::populateNet(Net dstNet)
             Mat Wx = getBlob(node_proto, constBlobs, 1);
             Mat Wh = getBlob(node_proto, constBlobs, 2);
             Mat b = getBlob(node_proto, constBlobs, 3);
+            b = b.reshape(1, b.size[0]);
 
             const int numHidden = lstmParams.get<int>("hidden_size");
-
-            Wx = Wx.reshape(1, Wx.size[1]);
-            Wh = Wh.reshape(1, Wh.size[1]);
-            b = b.reshape(1, 2);
-            reduce(b, b, 0, REDUCE_SUM);
+            const int numDirs = Wx.size[0];  // Is 1 for forward only and 2 for bidirectional LSTM.
+            const int numFeatures = Wx.size[2];
+            Mat bx = b.colRange(0, b.cols / 2);
+            Mat bh = b.colRange(b.cols / 2, b.cols);
+            b = bx + bh;
 
             // IFGO->IGFO
-            float* WxData = (float*)Wx.data;
-            float* WhData = (float*)Wh.data;
-            float* biasData = (float*)b.data;
-            for (int j = 0; j < numHidden; ++j)
+            for (int k = 0; k < numDirs; ++k)
             {
-                for (int i = 0; i < Wx.cols; ++i)
-                {
-                    std::swap(WxData[(numHidden + j) * Wx.cols + i],
-                              WxData[(numHidden * 2 + j) * Wx.cols + i]);
-                }
-                for (int i = 0; i < Wh.cols; ++i)
+                float* WxData = Wx.ptr<float>(k);
+                float* WhData = Wh.ptr<float>(k);
+                float* biasData = b.ptr<float>(k);
+                for (int j = 0; j < numHidden; ++j)
                 {
-                    std::swap(WhData[(numHidden + j) * Wh.cols + i],
-                              WhData[(numHidden * 2 + j) * Wh.cols + i]);
+                    for (int i = 0; i < numFeatures; ++i)
+                    {
+                        std::swap(WxData[(numHidden + j) * numFeatures + i],
+                                  WxData[(numHidden * 2 + j) * numFeatures + i]);
+                    }
+                    for (int i = 0; i < numHidden; ++i)
+                    {
+                        std::swap(WhData[(numHidden + j) * numHidden + i],
+                                  WhData[(numHidden * 2 + j) * numHidden + i]);
+                    }
+                    std::swap(biasData[numHidden + j], biasData[numHidden * 2 + j]);
                 }
-                std::swap(biasData[numHidden + j], biasData[numHidden * 2 + j]);
             }
+            Wx = Wx.reshape(1, Wx.size[0] * Wx.size[1]);
+            Wh = Wh.reshape(1, Wh.size[0] * Wh.size[1]);
 
             lstmParams.blobs.resize(3);
             lstmParams.blobs[0] = Wh;
             lstmParams.blobs[1] = Wx;
             lstmParams.blobs[2] = b;
+            lstmParams.set("bidirectional", lstmParams.get<String>("direction", "") == "bidirectional");
 
             node_proto.set_output(0, lstmParams.name);  // set different name so output shapes will be registered on that name
             addLayer(dstNet, lstmParams, node_proto, layer_id, outShapes);
diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp
index a2cd2c3..f741319 100644
--- a/modules/dnn/test/test_onnx_importer.cpp
+++ b/modules/dnn/test/test_onnx_importer.cpp
@@ -456,6 +456,11 @@ TEST_P(Test_ONNX_layers, LSTM)
     testONNXModels("lstm");
 }
 
+TEST_P(Test_ONNX_layers, LSTM_bidirectional)
+{
+    testONNXModels("lstm_bidirectional");
+}
+
 INSTANTIATE_TEST_CASE_P(/*nothing*/, Test_ONNX_layers, dnnBackendsAndTargets());
 
 class Test_ONNX_nets : public Test_ONNX_layers
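
A note on the importer hunk above: ONNX stores the LSTM weights W, R and
B with a leading dimension of size num_directions, so the gate
reordering now runs once per direction on the Wx.ptr<float>(k) /
Wh.ptr<float>(k) planes. The standalone sketch below mirrors that inner
swap on a plain buffer; the helper name is ours, not part of the patch.
Per the importer's comment it turns the IFGO block order into the IGFO
order the layer consumes.

    #include <algorithm>
    #include <vector>

    // Swap the second and third gate blocks of a (4*numHidden x cols)
    // row-major weight matrix: I,F,G,O -> I,G,F,O.
    static void swapGateBlocks(std::vector<float>& w, int numHidden, int cols)
    {
        for (int j = 0; j < numHidden; ++j)
            for (int i = 0; i < cols; ++i)
                std::swap(w[(numHidden + j) * cols + i],
                          w[(numHidden * 2 + j) * cols + i]);
    }

The new test case assumes an lstm_bidirectional model (plus reference
input/output blobs) in the opencv_extra ONNX test data; roughly,
testONNXModels() loads that model, runs it on the stored input and
compares the result against the stored reference output.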
-- 
2.7.4