LSTM layer for TensorFlow importer
authorDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Fri, 25 Aug 2017 11:45:03 +0000 (14:45 +0300)
committerDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Tue, 26 Sep 2017 09:59:36 +0000 (12:59 +0300)
modules/dnn/include/opencv2/dnn/all_layers.hpp
modules/dnn/src/init.cpp
modules/dnn/src/layers/recurrent_layers.cpp
modules/dnn/src/layers/slice_layer.cpp
modules/dnn/src/tensorflow/tf_importer.cpp
modules/dnn/test/test_layers.cpp
modules/dnn/test/test_tf_importer.cpp

index cf47c70..dc070fd 100644 (file)
@@ -84,7 +84,9 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
         /** Creates instance of LSTM layer */
         static Ptr<LSTMLayer> create(const LayerParams& params);
 
-        /** Set trained weights for LSTM layer.
+        /** @deprecated Use LayerParams::blobs instead.
+        @brief Set trained weights for LSTM layer.
+
         LSTM behavior on each step is defined by current input, previous output, previous cell state and learned weights.
 
         Let @f$x_t@f$ be current input, @f$h_t@f$ be current output, @f$c_t@f$ be current state.
@@ -114,7 +116,7 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
         @param Wx is matrix defining how current input is transformed to internal gates (i.e. according to abovemtioned notation is @f$ W_x @f$)
         @param b  is bias vector (i.e. according to abovemtioned notation is @f$ b @f$)
         */
-        virtual void setWeights(const Mat &Wh, const Mat &Wx, const Mat &b) = 0;
+        CV_DEPRECATED virtual void setWeights(const Mat &Wh, const Mat &Wx, const Mat &b) = 0;
 
         /** @brief Specifies shape of output blob which will be [[`T`], `N`] + @p outTailShape.
           * @details If this parameter is empty or unset then @p outTailShape = [`Wh`.size(0)] will be used,
@@ -122,7 +124,8 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
           */
         virtual void setOutShape(const MatShape &outTailShape = MatShape()) = 0;
 
-        /** @brief Specifies either interpet first dimension of input blob as timestamp dimenion either as sample.
+        /** @deprecated Use flag `produce_cell_output` in LayerParams.
+          * @brief Specifies either interpet first dimension of input blob as timestamp dimenion either as sample.
           *
           * If flag is set to true then shape of input blob will be interpeted as [`T`, `N`, `[data dims]`] where `T` specifies number of timpestamps, `N` is number of independent streams.
           * In this case each forward() call will iterate through `T` timestamps and update layer's state `T` times.
@@ -130,12 +133,13 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
           * If flag is set to false then shape of input blob will be interpeted as [`N`, `[data dims]`].
           * In this case each forward() call will make one iteration and produce one timestamp with shape [`N`, `[out dims]`].
           */
-        virtual void setUseTimstampsDim(bool use = true) = 0;
+        CV_DEPRECATED virtual void setUseTimstampsDim(bool use = true) = 0;
 
-        /** @brief If this flag is set to true then layer will produce @f$ c_t @f$ as second output.
+        /** @deprecated Use flag `use_timestamp_dim` in LayerParams.
+         * @brief If this flag is set to true then layer will produce @f$ c_t @f$ as second output.
          * @details Shape of the second output is the same as first output.
          */
-        virtual void setProduceCellOutput(bool produce = false) = 0;
+        CV_DEPRECATED virtual void setProduceCellOutput(bool produce = false) = 0;
 
         /* In common case it use single input with @f$x_t@f$ values to compute output(s) @f$h_t@f$ (and @f$c_t@f$).
          * @param input should contain packed values @f$x_t@f$
@@ -322,11 +326,41 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
         static Ptr<SplitLayer> create(const LayerParams &params);
     };
 
+    /**
+     * Slice layer has several modes:
+     * 1. Caffe mode
+     * @param[in] axis Axis of split operation
+     * @param[in] slice_point Array of split points
+     *
+     * Number of output blobs equals to number of split points plus one. The
+     * first blob is a slice on input from 0 to @p slice_point[0] - 1 by @p axis,
+     * the second output blob is a slice of input from @p slice_point[0] to
+     * @p slice_point[1] - 1 by @p axis and the last output blob is a slice of
+     * input from @p slice_point[-1] up to the end of @p axis size.
+     *
+     * 2. TensorFlow mode
+     * @param begin Vector of start indices
+     * @param size Vector of sizes
+     *
+     * More convinient numpy-like slice. One and only output blob
+     * is a slice `input[begin[0]:begin[0]+size[0], begin[1]:begin[1]+size[1], ...]`
+     *
+     * 3. Torch mode
+     * @param axis Axis of split operation
+     *
+     * Split input blob on the equal parts by @p axis.
+     */
     class CV_EXPORTS SliceLayer : public Layer
     {
     public:
+        /**
+         * @brief Vector of slice ranges.
+         *
+         * The first dimension equals number of output blobs.
+         * Inner vector has slice ranges for the first number of input dimensions.
+         */
+        std::vector<std::vector<Range> > sliceRanges;
         int axis;
-        std::vector<int> sliceIndices;
 
         static Ptr<SliceLayer> create(const LayerParams &params);
     };
index fe1036c..64e1155 100644 (file)
@@ -117,6 +117,8 @@ void initializeLayerFactory()
     CV_DNN_REGISTER_LAYER_CLASS(Shift,          ShiftLayer);
     CV_DNN_REGISTER_LAYER_CLASS(Padding,        PaddingLayer);
     CV_DNN_REGISTER_LAYER_CLASS(Scale,          ScaleLayer);
+
+    CV_DNN_REGISTER_LAYER_CLASS(LSTM,           LSTMLayer);
 }
 
 CV__DNN_EXPERIMENTAL_NS_END
index 10a6f74..a40bcc6 100644 (file)
@@ -90,6 +90,8 @@ class LSTMLayerImpl : public LSTMLayer
 
     bool useTimestampDim;
     bool produceCellOutput;
+    float forgetBias, cellClip;
+    bool useCellClip, usePeephole;
 
 public:
 
@@ -97,9 +99,40 @@ public:
         : numTimeStamps(0), numSamples(0)
     {
         setParamsFrom(params);
-        type = "LSTM";
-        useTimestampDim = true;
-        produceCellOutput = false;
+
+        if (!blobs.empty())
+        {
+            CV_Assert(blobs.size() >= 3);
+
+            blobs[2] = blobs[2].reshape(1, 1);
+
+            const Mat& Wh = blobs[0];
+            const Mat& Wx = blobs[1];
+            const Mat& bias = blobs[2];
+            CV_Assert(Wh.dims == 2 && Wx.dims == 2);
+            CV_Assert(Wh.rows == Wx.rows);
+            CV_Assert(Wh.rows == 4*Wh.cols);
+            CV_Assert(Wh.rows == (int)bias.total());
+            CV_Assert(Wh.type() == Wx.type() && Wx.type() == bias.type());
+
+            // Peephole weights.
+            if (blobs.size() > 3)
+            {
+                CV_Assert(blobs.size() == 6);
+                for (int i = 3; i < 6; ++i)
+                {
+                    CV_Assert(blobs[i].rows == Wh.cols && blobs[i].cols == Wh.cols);
+                    CV_Assert(blobs[i].type() == bias.type());
+                }
+            }
+        }
+        useTimestampDim = params.get<bool>("use_timestamp_dim", true);
+        produceCellOutput = params.get<bool>("produce_cell_output", false);
+        forgetBias = params.get<float>("forget_bias", 0.0f);
+        cellClip = params.get<float>("cell_clip", 0.0f);
+        useCellClip = params.get<bool>("use_cell_clip", false);
+        usePeephole = params.get<bool>("use_peephole", false);
+
         allocated = false;
         outTailShape.clear();
     }
@@ -141,7 +174,7 @@ public:
                          std::vector<MatShape> &outputs,
                          std::vector<MatShape> &internals) const
     {
-        CV_Assert(blobs.size() == 3);
+        CV_Assert(!usePeephole && blobs.size() == 3 || usePeephole && blobs.size() == 6);
         CV_Assert(inputs.size() == 1);
         const MatShape& inp0 = inputs[0];
 
@@ -186,7 +219,7 @@ public:
 
     void finalize(const std::vector<Mat*> &input, std::vector<Mat> &output)
     {
-        CV_Assert(blobs.size() == 3);
+        CV_Assert(!usePeephole && blobs.size() == 3 || usePeephole && blobs.size() == 6);
         CV_Assert(input.size() == 1);
         const Mat& inp0 = *input[0];
 
@@ -251,13 +284,27 @@ public:
             gemm(hInternal, Wh, 1, gates, 1, gates, GEMM_2_T);  //+Wh * h_{t-1}
             gemm(dummyOnes, bias, 1, gates, 1, gates);          //+b
 
-            Mat getesIFO = gates.colRange(0, 3*numOut);
             Mat gateI = gates.colRange(0*numOut, 1*numOut);
             Mat gateF = gates.colRange(1*numOut, 2*numOut);
             Mat gateO = gates.colRange(2*numOut, 3*numOut);
             Mat gateG = gates.colRange(3*numOut, 4*numOut);
 
-            sigmoid(getesIFO, getesIFO);
+            if (forgetBias)
+                add(gateF, forgetBias, gateF);
+
+            if (usePeephole)
+            {
+                Mat gatesIF = gates.colRange(0, 2*numOut);
+                gemm(cInternal, blobs[3], 1, gateI, 1, gateI);
+                gemm(cInternal, blobs[4], 1, gateF, 1, gateF);
+                sigmoid(gatesIF, gatesIF);
+            }
+            else
+            {
+                Mat gatesIFO = gates.colRange(0, 3*numOut);
+                sigmoid(gatesIFO, gatesIFO);
+            }
+
             tanh(gateG, gateG);
 
             //compute c_t
@@ -265,6 +312,17 @@ public:
             multiply(gateI, gateG, gateI);      // i_t (*) g_t
             add(gateF, gateI, cInternal);       // c_t = f_t (*) c_{t-1} + i_t (*) g_t
 
+            if (useCellClip)
+            {
+                min(cInternal, cellClip, cInternal);
+                max(cInternal, -cellClip, cInternal);
+            }
+            if (usePeephole)
+            {
+                gemm(cInternal, blobs[5], 1, gateO, 1, gateO);
+                sigmoid(gateO, gateO);
+            }
+
             //compute h_t
             tanh(cInternal, hInternal);
             multiply(gateO, hInternal, hInternal);
index 86313e3..b824a06 100644 (file)
@@ -56,14 +56,40 @@ public:
     {
         setParamsFrom(params);
         axis = params.get<int>("axis", 1);
-
         if (params.has("slice_point"))
         {
+            CV_Assert(!params.has("begin") && !params.has("size"));
             const DictValue &indicesValue = params.get("slice_point");
-            int i, n = indicesValue.size();
-            sliceIndices.resize(n);
-            for (i = 0; i < n; i++)
-                sliceIndices[i] = indicesValue.get<int>(i);
+            sliceRanges.resize(indicesValue.size() + 1,
+                               std::vector<Range>(axis + 1, Range::all()));
+            int prevSlice = 0;
+            for (int i = 0; i < indicesValue.size(); ++i)
+            {
+                sliceRanges[i][axis].start = prevSlice;
+                sliceRanges[i][axis].end = indicesValue.get<int>(i);
+                prevSlice = sliceRanges[i][axis].end;
+            }
+            sliceRanges.back()[axis].start = prevSlice;
+        }
+        else if (params.has("begin") && params.has("size"))
+        {
+            const DictValue &begins = params.get("begin");
+            const DictValue &sizes = params.get("size");
+            CV_Assert(begins.size() == sizes.size());
+
+            sliceRanges.resize(1);
+            sliceRanges[0].resize(begins.size(), Range::all());
+            for (int i = 0; i < begins.size(); ++i)
+            {
+                int start = begins.get<int>(i);
+                int size = sizes.get<int>(i);
+                CV_Assert(start >= 0);
+                CV_Assert(size == -1 || size > 0);  // -1 value means range [start, axis_size).
+
+                sliceRanges[0][i].start = start;
+                if (size > 0)
+                    sliceRanges[0][i].end = start + size;
+            }
         }
     }
 
@@ -73,47 +99,68 @@ public:
                             std::vector<MatShape> &internals) const
     {
         CV_Assert(inputs.size() == 1);
-
-        outputs.clear();
-
         MatShape inpShape = inputs[0];
-        int cAxis = clamp(axis, inpShape.size());
-        int axisSize = inpShape[cAxis];
 
-        if (sliceIndices.size()) //divide blob with respect to passed parameters
+        if (!sliceRanges.empty())
         {
-           std::vector<int> outAxisSize;
-           int prevSlice = 0;
-
-           for (size_t i = 0; i < sliceIndices.size(); i++)
-           {
-               if (!(prevSlice < sliceIndices[i] && sliceIndices[i] < axisSize))
-                   CV_Error(Error::StsBadArg, "Slice indices should be positive, increased and don't exceed size of sliced dimension");
-
-               outAxisSize.push_back(sliceIndices[i] - prevSlice);
-               prevSlice = sliceIndices[i];
-            }
-            outAxisSize.push_back(axisSize - prevSlice);
-
-            for (size_t i = 0; i < outAxisSize.size(); i++)
+            outputs.resize(sliceRanges.size(), inpShape);
+            for (int i = 0; i < outputs.size(); ++i)
             {
-               inpShape[cAxis] = outAxisSize[i];
-              outputs.push_back(inpShape);
+                CV_Assert(sliceRanges[i].size() <= inpShape.size());
+                for (int j = 0; j < sliceRanges[i].size(); ++j)
+                {
+                    outputs[i][j] = std::min(sliceRanges[i][j].end, inpShape[j]) -
+                                    std::max(sliceRanges[i][j].start, 0);
+                }
             }
         }
-        else //divide blob with respect to count of output blobs
+        else  // Divide input blob on equal parts by axis.
         {
-           CV_Assert(requiredOutputs > 0 && axisSize % requiredOutputs == 0);
-           int outAxisSize = axisSize / (int)requiredOutputs;
+            CV_Assert(0 < axis && axis < inpShape.size());
+            CV_Assert(requiredOutputs > 0 && inpShape[axis] % requiredOutputs == 0);
+            inpShape[axis] /= requiredOutputs;
+            outputs.resize(requiredOutputs, inpShape);
+        }
+        return false;
+    }
 
-           for (size_t i = 0; i < requiredOutputs; i++)
+    void finalize(const std::vector<Mat*> &inputs, std::vector<Mat> &outputs)
+    {
+        CV_Assert(inputs.size() == 1);
+        const MatSize& inpShape = inputs[0]->size;
+
+        if (sliceRanges.empty())
+        {
+            // Divide input blob on equal parts by axis.
+            int outAxisSize = inpShape[axis] / outputs.size();
+            sliceRanges.resize(outputs.size(),
+                               std::vector<Range>(axis + 1, Range::all()));
+            int prevSlice = 0;
+            for (int i = 0; i < outputs.size(); ++i)
             {
-               inpShape[cAxis] = outAxisSize;
-               outputs.push_back(inpShape);
+                sliceRanges[i][axis].start = prevSlice;
+                sliceRanges[i][axis].end = sliceRanges[i][axis].start + outAxisSize;
+                prevSlice = sliceRanges[i][axis].end;
             }
         }
+        else
+            CV_Assert(outputs.size() == sliceRanges.size());
 
-        return false;
+        for (int i = 0; i < outputs.size(); ++i)
+        {
+            CV_Assert(sliceRanges[i].size() <= inpShape[-1]);
+            // Clamp.
+            for (int j = 0; j < sliceRanges[i].size(); ++j)
+            {
+                sliceRanges[i][j].start = std::max(0, sliceRanges[i][j].start);
+                sliceRanges[i][j].end = std::min(sliceRanges[i][j].end, inpShape[j]);
+            }
+            // Fill the rest of ranges.
+            for (int j = sliceRanges[i].size(); j < inpShape[-1]; ++j)
+            {
+                sliceRanges[i].push_back(Range::all());
+            }
+        }
     }
 
     void forward(std::vector<Mat*> &inputs, std::vector<Mat> &outputs, std::vector<Mat> &internals)
@@ -122,15 +169,10 @@ public:
         CV_TRACE_ARG_VALUE(name, "name", name.c_str());
 
         const Mat& inpMat = *inputs[0];
-        std::vector<Range> ranges(inpMat.dims, Range::all());
-        int cAxis = clamp(axis, inpMat.dims);
-
-        ranges[cAxis].start = 0;
+        CV_Assert(outputs.size() == sliceRanges.size());
         for (size_t i = 0; i < outputs.size(); i++)
         {
-            ranges[cAxis].end = ranges[cAxis].start + outputs[i].size[cAxis];
-            inpMat(&ranges[0]).copyTo(outputs[i]);
-            ranges[cAxis].start = ranges[cAxis].end;
+            inpMat(sliceRanges[i]).copyTo(outputs[i]);
         }
     }
 };
index 67565cc..7e78ed5 100644 (file)
@@ -877,6 +877,34 @@ void TFImporter::populateNet(Net dstNet)
             // one input only
             connect(layer_id, dstNet, parsePin(layer.input(1)), id, 0);
         }
+        else if (type == "Slice")
+        {
+            // op: "Slice"
+            // input: "input_node"
+            // input: "Slice/begin"
+            // input: "Slice/size"
+            CV_Assert(layer.input_size() == 3);
+
+            const tensorflow::TensorProto begins = getConstBlob(layer, value_id, 1);
+            const tensorflow::TensorProto sizes = getConstBlob(layer, value_id, 2);
+            std::string beginsData = begins.tensor_content();
+            std::string sizesData = sizes.tensor_content();
+            CV_Assert(begins.dtype() == tensorflow::DT_INT32);
+            CV_Assert(sizes.dtype() == tensorflow::DT_INT32);
+            CV_Assert(!beginsData.empty());
+            CV_Assert(!sizesData.empty());
+            CV_Assert(beginsData.size() == sizesData.size());
+
+            layerParams.set("begin", DictValue::arrayInt((int*)beginsData.c_str(),
+                                                         beginsData.size() / 4));
+            layerParams.set("size", DictValue::arrayInt((int*)sizesData.c_str(),
+                                                        sizesData.size() / 4));
+
+            int id = dstNet.addLayer(name, "Slice", layerParams);
+            layer_id[name] = id;
+
+            connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
+        }
         else if (type == "Mul")
         {
             bool haveConst = false;
@@ -1055,6 +1083,82 @@ void TFImporter::populateNet(Net dstNet)
             // one input only
             connect(layer_id, dstNet, parsePin(layer.input(2)), id, 0);
         }
+        else if (type == "BlockLSTM")
+        {
+            // op: "BlockLSTM"
+            // input: "lstm_block_wrapper/ToInt64/x"  (ignore, number of time stamps)
+            // input: "input"
+            // input: "lstm_block_wrapper/zeros"      (ignore)
+            // input: "lstm_block_wrapper/zeros"      (ignore)
+            // input: "lstm_block_wrapper/kernel"
+            // input: "lstm_block_wrapper/w_i_diag"
+            // input: "lstm_block_wrapper/w_f_diag"
+            // input: "lstm_block_wrapper/w_o_diag"
+            // input: "lstm_block_wrapper/bias"
+            if (layer.input_size() != 9)
+                CV_Error(Error::StsNotImplemented, "Unexpected number of input nodes");
+
+            if (hasLayerAttr(layer, "forget_bias"))
+                layerParams.set("forget_bias", getLayerAttr(layer, "forget_bias").f());
+
+            if (hasLayerAttr(layer, "forget_bias"))
+            {
+                float cellClip = getLayerAttr(layer, "cell_clip").f();
+                // Cell clip disabled if it's negative.
+                if (cellClip >= 0)
+                {
+                    layerParams.set("use_cell_clip", true);
+                    layerParams.set("cell_clip", cellClip);
+                }
+            }
+
+            Mat W, Wh, Wx, b;
+            blobFromTensor(getConstBlob(layer, value_id, 4), W);
+            blobFromTensor(getConstBlob(layer, value_id, 8), b);
+            const int outSize = W.cols / 4;
+
+            // IGFO->IFOG
+            float* weightData = (float*)W.data;
+            for (int i = 0; i < W.rows; ++i)
+                for (int j = 0; j < outSize; ++j)
+                {
+                    std::swap(weightData[i * W.cols + 1 * outSize + j],
+                              weightData[i * W.cols + 2 * outSize + j]);
+                    std::swap(weightData[i * W.cols + 2 * outSize + j],
+                              weightData[i * W.cols + 3 * outSize + j]);
+                }
+            Wx = W.rowRange(0, W.rows - outSize).t();
+            Wh = W.rowRange(W.rows - outSize, W.rows).t();
+
+            layerParams.blobs.resize(3);
+            layerParams.blobs[0] = Wh;
+            layerParams.blobs[1] = Wx;
+            layerParams.blobs[2] = b;
+
+            if (hasLayerAttr(layer, "use_peephole"))
+            {
+                bool usePeephole = getLayerAttr(layer, "use_peephole").b();
+                if (usePeephole)
+                {
+                    layerParams.set("use_peephole", true);
+                    layerParams.blobs.resize(6);
+                    for (int i = 0; i < 3; ++i)
+                    {
+                        Mat w;
+                        blobFromTensor(getConstBlob(layer, value_id, 5 + i), w);
+                        w = w.reshape(1, w.total());  // Single column.
+                        w = Mat::diag(w);  // Make a diagonal matrix.
+                        layerParams.blobs[3 + i] = w;
+                    }
+                }
+            }
+
+            int id = dstNet.addLayer(name, "LSTM", layerParams);
+            layer_id[name] = id;
+
+            // one input only
+            connect(layer_id, dstNet, parsePin(layer.input(1)), id, 0);
+        }
         else if (type == "Abs" || type == "Tanh" || type == "Sigmoid" ||
                  type == "Relu" || type == "Elu" || type == "Softmax" ||
                  type == "Identity" || type == "Relu6")
index e6807b3..e3aeb7e 100644 (file)
@@ -289,7 +289,8 @@ public:
 
     Layer_LSTM_Test() {}
 
-    void init(const MatShape &inpShape_, const MatShape &outShape_)
+    void init(const MatShape &inpShape_, const MatShape &outShape_,
+              bool produceCellOutput, bool useTimestampDim)
     {
         numInp = total(inpShape_);
         numOut = total(outShape_);
@@ -298,8 +299,15 @@ public:
         Wx = Mat::ones(4 * numOut, numInp, CV_32F);
         b  = Mat::ones(4 * numOut, 1, CV_32F);
 
-        layer = LSTMLayer::create(LayerParams());
-        layer->setWeights(Wh, Wx, b);
+        LayerParams lp;
+        lp.blobs.resize(3);
+        lp.blobs[0] = Wh;
+        lp.blobs[1] = Wx;
+        lp.blobs[2] = b;
+        lp.set<bool>("produce_cell_output", produceCellOutput);
+        lp.set<bool>("use_timestamp_dim", useTimestampDim);
+
+        layer = LSTMLayer::create(lp);
         layer->setOutShape(outShape_);
     }
 };
@@ -312,9 +320,7 @@ TEST_F(Layer_LSTM_Test, get_set_test)
     MatShape inpResShape = concat(shape(TN), inpShape);
     MatShape outResShape = concat(shape(TN), outShape);
 
-    init(inpShape, outShape);
-    layer->setProduceCellOutput(true);
-    layer->setUseTimstampsDim(false);
+    init(inpShape, outShape, true, false);
     layer->setOutShape(outShape);
 
     Mat C((int)outResShape.size(), &outResShape[0], CV_32F);
@@ -344,12 +350,12 @@ TEST_F(Layer_LSTM_Test, get_set_test)
 
 TEST(Layer_LSTM_Test_Accuracy_with_, CaffeRecurrent)
 {
-    Ptr<LSTMLayer> layer = LSTMLayer::create(LayerParams());
-
-    Mat Wx = blobFromNPY(_tf("lstm.prototxt.w_0.npy"));
-    Mat Wh = blobFromNPY(_tf("lstm.prototxt.w_2.npy"));
-    Mat b  = blobFromNPY(_tf("lstm.prototxt.w_1.npy"));
-    layer->setWeights(Wh, Wx, b);
+    LayerParams lp;
+    lp.blobs.resize(3);
+    lp.blobs[0] = blobFromNPY(_tf("lstm.prototxt.w_2.npy"));  // Wh
+    lp.blobs[1] = blobFromNPY(_tf("lstm.prototxt.w_0.npy"));  // Wx
+    lp.blobs[2] = blobFromNPY(_tf("lstm.prototxt.w_1.npy"));  // bias
+    Ptr<LSTMLayer> layer = LSTMLayer::create(lp);
 
     Mat inp = blobFromNPY(_tf("recurrent.input.npy"));
     std::vector<Mat> inputs(1, inp), outputs;
index 57227ff..09a990f 100644 (file)
@@ -2,7 +2,7 @@
 // It is subject to the license terms in the LICENSE file found in the top-level directory
 // of this distribution and at http://opencv.org/license.html.
 
-// Copyright (C) 2016, Intel Corporation, all rights reserved.
+// Copyright (C) 2017, Intel Corporation, all rights reserved.
 // Third party copyrights are property of their respective owners.
 
 /*
@@ -146,6 +146,7 @@ TEST(Test_TensorFlow, defun)
 TEST(Test_TensorFlow, reshape)
 {
     runTensorFlowNet("shift_reshape_no_reorder");
+    runTensorFlowNet("reshape_reduce");
 }
 
 TEST(Test_TensorFlow, fp16)
@@ -163,4 +164,9 @@ TEST(Test_TensorFlow, fp16)
     runTensorFlowNet("fp16_padding_same", l1, lInf);
 }
 
+TEST(Test_TensorFlow, lstm)
+{
+    runTensorFlowNet("lstm");
+}
+
 }