Add text recognition example
authorAleksandr Pertovskiy <rng.mlcv@gmail.com>
Wed, 6 May 2020 12:23:55 +0000 (15:23 +0300)
committerAleksandr Pertovskiy <rng.mlcv@gmail.com>
Wed, 6 May 2020 12:26:17 +0000 (15:26 +0300)
samples/dnn/text_detection.cpp

index e7b0f23..706e2fe 100644 (file)
@@ -1,3 +1,20 @@
+/*
+    Text detection model: https://github.com/argman/EAST
+    Download link: https://www.dropbox.com/s/r2ingd0l3zt8hxs/frozen_east_text_detection.tar.gz?dl=1
+
+    Text recognition model taken from here: https://github.com/meijieru/crnn.pytorch
+    How to convert from pb to onnx:
+    Using classes from here: https://github.com/meijieru/crnn.pytorch/blob/master/models/crnn.py
+
+    import torch
+    import models.crnn as crnn
+
+    model = CRNN(32, 1, 37, 256)
+    model.load_state_dict(torch.load('crnn.pth'))
+    dummy_input = torch.randn(1, 1, 32, 100)
+    torch.onnx.export(model, dummy_input, "crnn.onnx", verbose=True)
+*/
+
 #include <opencv2/imgproc.hpp>
 #include <opencv2/highgui.hpp>
 #include <opencv2/dnn.hpp>
@@ -8,21 +25,26 @@ using namespace cv::dnn;
 const char* keys =
     "{ help  h     | | Print help message. }"
     "{ input i     | | Path to input image or video file. Skip this argument to capture frames from a camera.}"
-    "{ model m     | | Path to a binary .pb file contains trained network.}"
+    "{ model m     | | Path to a binary .pb file contains trained detector network.}"
+    "{ ocr         | | Path to a binary .pb or .onnx file contains trained recognition network.}"
     "{ width       | 320 | Preprocess input image by resizing to a specific width. It should be multiple by 32. }"
     "{ height      | 320 | Preprocess input image by resizing to a specific height. It should be multiple by 32. }"
     "{ thr         | 0.5 | Confidence threshold. }"
     "{ nms         | 0.4 | Non-maximum suppression threshold. }";
 
-void decode(const Mat& scores, const Mat& geometry, float scoreThresh,
-            std::vector<RotatedRect>& detections, std::vector<float>& confidences);
+void decodeBoundingBoxes(const Mat& scores, const Mat& geometry, float scoreThresh,
+                         std::vector<RotatedRect>& detections, std::vector<float>& confidences);
+
+void fourPointsTransform(const Mat& frame, Point2f vertices[4], Mat& result);
+
+void decodeText(const Mat& scores, std::string& text);
 
 int main(int argc, char** argv)
 {
     // Parse command line arguments.
     CommandLineParser parser(argc, argv, keys);
     parser.about("Use this script to run TensorFlow implementation (https://github.com/argman/EAST) of "
-                  "EAST: An Efficient and Accurate Scene Text Detector (https://arxiv.org/abs/1704.03155v2)");
+                 "EAST: An Efficient and Accurate Scene Text Detector (https://arxiv.org/abs/1704.03155v2)");
     if (argc == 1 || parser.has("help"))
     {
         parser.printMessage();
@@ -33,7 +55,8 @@ int main(int argc, char** argv)
     float nmsThreshold = parser.get<float>("nms");
     int inpWidth = parser.get<int>("width");
     int inpHeight = parser.get<int>("height");
-    String model = parser.get<String>("model");
+    String modelDecoder = parser.get<String>("model");
+    String modelRecognition = parser.get<String>("ocr");
 
     if (!parser.check())
     {
@@ -41,17 +64,19 @@ int main(int argc, char** argv)
         return 1;
     }
 
-    CV_Assert(!model.empty());
+    CV_Assert(!modelDecoder.empty());
+
+    // Load networks.
+    Net detector = readNet(modelDecoder);
+    Net recognizer;
 
-    // Load network.
-    Net net = readNet(model);
+    if (!modelRecognition.empty())
+        recognizer = readNet(modelRecognition);
 
     // Open a video file or an image file or a camera stream.
     VideoCapture cap;
-    if (parser.has("input"))
-        cap.open(parser.get<String>("input"));
-    else
-        cap.open(0);
+    bool openSuccess = parser.has("input") ? cap.open(parser.get<String>("input")) : cap.open(0);
+    CV_Assert(openSuccess);
 
     static const std::string kWinName = "EAST: An Efficient and Accurate Scene Text Detector";
     namedWindow(kWinName, WINDOW_NORMAL);
@@ -62,6 +87,7 @@ int main(int argc, char** argv)
     outNames[1] = "feature_fusion/concat_3";
 
     Mat frame, blob;
+    TickMeter tickMeter;
     while (waitKey(1) < 0)
     {
         cap >> frame;
@@ -72,8 +98,10 @@ int main(int argc, char** argv)
         }
 
         blobFromImage(frame, blob, 1.0, Size(inpWidth, inpHeight), Scalar(123.68, 116.78, 103.94), true, false);
-        net.setInput(blob);
-        net.forward(outs, outNames);
+        detector.setInput(blob);
+        tickMeter.start();
+        detector.forward(outs, outNames);
+        tickMeter.stop();
 
         Mat scores = outs[0];
         Mat geometry = outs[1];
@@ -81,43 +109,64 @@ int main(int argc, char** argv)
         // Decode predicted bounding boxes.
         std::vector<RotatedRect> boxes;
         std::vector<float> confidences;
-        decode(scores, geometry, confThreshold, boxes, confidences);
+        decodeBoundingBoxes(scores, geometry, confThreshold, boxes, confidences);
 
         // Apply non-maximum suppression procedure.
         std::vector<int> indices;
         NMSBoxes(boxes, confidences, confThreshold, nmsThreshold, indices);
 
-        // Render detections.
         Point2f ratio((float)frame.cols / inpWidth, (float)frame.rows / inpHeight);
+
+        // Render text.
         for (size_t i = 0; i < indices.size(); ++i)
         {
             RotatedRect& box = boxes[indices[i]];
 
             Point2f vertices[4];
             box.points(vertices);
+
             for (int j = 0; j < 4; ++j)
             {
                 vertices[j].x *= ratio.x;
                 vertices[j].y *= ratio.y;
             }
+
+            if (!modelRecognition.empty())
+            {
+                Mat cropped;
+                fourPointsTransform(frame, vertices, cropped);
+
+                cvtColor(cropped, cropped, cv::COLOR_BGR2GRAY);
+
+                Mat blobCrop = blobFromImage(cropped, 1.0/127.5, Size(), Scalar::all(127.5));
+                recognizer.setInput(blobCrop);
+
+                tickMeter.start();
+                Mat result = recognizer.forward();
+                tickMeter.stop();
+
+                std::string wordRecognized = "";
+                decodeText(result, wordRecognized);
+                putText(frame, wordRecognized, vertices[1], FONT_HERSHEY_SIMPLEX, 1.5, Scalar(0, 0, 255));
+            }
+
             for (int j = 0; j < 4; ++j)
                 line(frame, vertices[j], vertices[(j + 1) % 4], Scalar(0, 255, 0), 1);
         }
 
         // Put efficiency information.
-        std::vector<double> layersTimes;
-        double freq = getTickFrequency() / 1000;
-        double t = net.getPerfProfile(layersTimes) / freq;
-        std::string label = format("Inference time: %.2f ms", t);
+        std::string label = format("Inference time: %.2f ms", tickMeter.getTimeMilli());
         putText(frame, label, Point(0, 15), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 255, 0));
 
         imshow(kWinName, frame);
+
+        tickMeter.reset();
     }
     return 0;
 }
 
-void decode(const Mat& scores, const Mat& geometry, float scoreThresh,
-            std::vector<RotatedRect>& detections, std::vector<float>& confidences)
+void decodeBoundingBoxes(const Mat& scores, const Mat& geometry, float scoreThresh,
+                         std::vector<RotatedRect>& detections, std::vector<float>& confidences)
 {
     detections.clear();
     CV_Assert(scores.dims == 4); CV_Assert(geometry.dims == 4); CV_Assert(scores.size[0] == 1);
@@ -159,3 +208,51 @@ void decode(const Mat& scores, const Mat& geometry, float scoreThresh,
         }
     }
 }
+
+void fourPointsTransform(const Mat& frame, Point2f vertices[4], Mat& result)
+{
+    const Size outputSize = Size(100, 32);
+
+    Point2f targetVertices[4] = {Point(0, outputSize.height - 1),
+                                  Point(0, 0), Point(outputSize.width - 1, 0),
+                                  Point(outputSize.width - 1, outputSize.height - 1),
+                                  };
+    Mat rotationMatrix = getPerspectiveTransform(vertices, targetVertices);
+
+    warpPerspective(frame, result, rotationMatrix, outputSize);
+}
+
+void decodeText(const Mat& scores, std::string& text)
+{
+    static const std::string alphabet = "0123456789abcdefghijklmnopqrstuvwxyz";
+    Mat scoresMat = scores.reshape(1, scores.size[0]);
+
+    std::vector<char> elements;
+    elements.reserve(scores.size[0]);
+
+    for (int rowIndex = 0; rowIndex < scoresMat.rows; ++rowIndex)
+    {
+        Point p;
+        minMaxLoc(scoresMat.row(rowIndex), 0, 0, 0, &p);
+        if (p.x > 0 && static_cast<size_t>(p.x) <= alphabet.size())
+        {
+            elements.push_back(alphabet[p.x - 1]);
+        }
+        else
+        {
+            elements.push_back('-');
+        }
+    }
+
+    if (elements.size() > 0 && elements[0] != '-')
+        text += elements[0];
+
+    for (size_t elementIndex = 1; elementIndex < elements.size(); ++elementIndex)
+    {
+        if (elementIndex > 0 && elements[elementIndex] != '-' &&
+            elements[elementIndex - 1] != elements[elementIndex])
+        {
+            text += elements[elementIndex];
+        }
+    }
+}
\ No newline at end of file