Merge pull request #16955 from themechanicalcoder:text_recognition
authorGourav Roy <34737471+themechanicalcoder@users.noreply.github.com>
Wed, 10 Jun 2020 06:53:18 +0000 (12:23 +0530)
committerGitHub <noreply@github.com>
Wed, 10 Jun 2020 06:53:18 +0000 (06:53 +0000)
* add text recognition sample

* fix pylint warning

* made changes according to the c++ example

* fix errors

* add text recognition sample

* update text detection sample

samples/dnn/text_detection.py

index 9ea4c10..7014a80 100644 (file)
@@ -1,25 +1,81 @@
+'''
+    Text detection model: https://github.com/argman/EAST
+    Download link: https://www.dropbox.com/s/r2ingd0l3zt8hxs/frozen_east_text_detection.tar.gz?dl=1
+    Text recognition model taken from here: https://github.com/meijieru/crnn.pytorch
+    How to convert from pb to onnx:
+    Using classes from here: https://github.com/meijieru/crnn.pytorch/blob/master/models/crnn.py
+    import torch
+    import models.crnn as CRNN
+    model = CRNN(32, 1, 37, 256)
+    model.load_state_dict(torch.load('crnn.pth'))
+    dummy_input = torch.randn(1, 1, 32, 100)
+    torch.onnx.export(model, dummy_input, "crnn.onnx", verbose=True)
+'''
+
+
 # Import required modules
+import numpy as np
 import cv2 as cv
 import math
 import argparse
 
 ############ Add argument parser for command line arguments ############
-parser = argparse.ArgumentParser(description='Use this script to run TensorFlow implementation (https://github.com/argman/EAST) of EAST: An Efficient and Accurate Scene Text Detector (https://arxiv.org/abs/1704.03155v2)')
-parser.add_argument('--input', help='Path to input image or video file. Skip this argument to capture frames from a camera.')
-parser.add_argument('--model', required=True,
-                    help='Path to a binary .pb file of model contains trained weights.')
+parser = argparse.ArgumentParser(
+    description="Use this script to run TensorFlow implementation (https://github.com/argman/EAST) of "
+                "EAST: An Efficient and Accurate Scene Text Detector (https://arxiv.org/abs/1704.03155v2)"
+                "The OCR model can be obtained from converting the pretrained CRNN model to .onnx format from the github repository https://github.com/meijieru/crnn.pytorch")
+parser.add_argument('--input',
+                    help='Path to input image or video file. Skip this argument to capture frames from a camera.')
+parser.add_argument('--model', '-m', required=True,
+                    help='Path to a binary .pb file contains trained detector network.')
+parser.add_argument('--ocr', default="crnn.onnx",
+                    help="Path to a binary .pb or .onnx file contains trained recognition network", )
 parser.add_argument('--width', type=int, default=320,
                     help='Preprocess input image by resizing to a specific width. It should be multiple by 32.')
-parser.add_argument('--height',type=int, default=320,
+parser.add_argument('--height', type=int, default=320,
                     help='Preprocess input image by resizing to a specific height. It should be multiple by 32.')
-parser.add_argument('--thr',type=float, default=0.5,
+parser.add_argument('--thr', type=float, default=0.5,
                     help='Confidence threshold.')
-parser.add_argument('--nms',type=float, default=0.4,
+parser.add_argument('--nms', type=float, default=0.4,
                     help='Non-maximum suppression threshold.')
 args = parser.parse_args()
 
+
 ############ Utility functions ############
-def decode(scores, geometry, scoreThresh):
+
+def fourPointsTransform(frame, vertices):
+    vertices = np.asarray(vertices)
+    outputSize = (100, 32)
+    targetVertices = np.array([
+        [0, outputSize[1] - 1],
+        [0, 0],
+        [outputSize[0] - 1, 0],
+        [outputSize[0] - 1, outputSize[1] - 1]], dtype="float32")
+
+    rotationMatrix = cv.getPerspectiveTransform(vertices, targetVertices)
+    result = cv.warpPerspective(frame, rotationMatrix, outputSize)
+    return result
+
+
+def decodeText(scores):
+    text = ""
+    alphabet = "0123456789abcdefghijklmnopqrstuvwxyz"
+    for i in range(scores.shape[0]):
+        c = np.argmax(scores[i][0])
+        if c != 0:
+            text += alphabet[c - 1]
+        else:
+            text += '-'
+
+    # adjacent same letters as well as background text must be removed to get the final output
+    char_list = []
+    for i in range(len(text)):
+        if text[i] != '-' and (not (i > 0 and text[i] == text[i - 1])):
+            char_list.append(text[i])
+    return ''.join(char_list)
+
+
+def decodeBoundingBoxes(scores, geometry, scoreThresh):
     detections = []
     confidences = []
 
@@ -47,7 +103,7 @@ def decode(scores, geometry, scoreThresh):
             score = scoresData[x]
 
             # If score is lower than threshold score, move to next x
-            if(score < scoreThresh):
+            if (score < scoreThresh):
                 continue
 
             # Calculate offset
@@ -66,24 +122,27 @@ def decode(scores, geometry, scoreThresh):
 
             # Find points for rectangle
             p1 = (-sinA * h + offset[0], -cosA * h + offset[1])
-            p3 = (-cosA * w + offset[0],  sinA * w + offset[1])
-            center = (0.5*(p1[0]+p3[0]), 0.5*(p1[1]+p3[1]))
-            detections.append((center, (w,h), -1*angle * 180.0 / math.pi))
+            p3 = (-cosA * w + offset[0], sinA * w + offset[1])
+            center = (0.5 * (p1[0] + p3[0]), 0.5 * (p1[1] + p3[1]))
+            detections.append((center, (w, h), -1 * angle * 180.0 / math.pi))
             confidences.append(float(score))
 
     # Return detections and confidences
     return [detections, confidences]
 
+
 def main():
     # Read and store arguments
     confThreshold = args.thr
     nmsThreshold = args.nms
     inpWidth = args.width
     inpHeight = args.height
-    model = args.model
+    modelDetector = args.model
+    modelRecognition = args.ocr
 
     # Load network
-    net = cv.dnn.readNet(model)
+    detector = cv.dnn.readNet(modelDetector)
+    recognizer = cv.dnn.readNet(modelRecognition)
 
     # Create a new named window
     kWinName = "EAST: An Efficient and Accurate Scene Text Detector"
@@ -95,6 +154,7 @@ def main():
     # Open a video file or an image file or a camera stream
     cap = cv.VideoCapture(args.input if args.input else 0)
 
+    tickmeter = cv.TickMeter()
     while cv.waitKey(1) < 0:
         # Read frame
         hasFrame, frame = cap.read()
@@ -111,19 +171,20 @@ def main():
         # Create a 4D blob from frame.
         blob = cv.dnn.blobFromImage(frame, 1.0, (inpWidth, inpHeight), (123.68, 116.78, 103.94), True, False)
 
-        # Run the model
-        net.setInput(blob)
-        outs = net.forward(outNames)
-        t, _ = net.getPerfProfile()
-        label = 'Inference time: %.2f ms' % (t * 1000.0 / cv.getTickFrequency())
+        # Run the detection model
+        detector.setInput(blob)
+
+        tickmeter.start()
+        outs = detector.forward(outNames)
+        tickmeter.stop()
 
         # Get scores and geometry
         scores = outs[0]
         geometry = outs[1]
-        [boxes, confidences] = decode(scores, geometry, confThreshold)
+        [boxes, confidences] = decodeBoundingBoxes(scores, geometry, confThreshold)
 
         # Apply NMS
-        indices = cv.dnn.NMSBoxesRotated(boxes, confidences, confThreshold,nmsThreshold)
+        indices = cv.dnn.NMSBoxesRotated(boxes, confidences, confThreshold, nmsThreshold)
         for i in indices:
             # get 4 corners of the rotated rect
             vertices = cv.boxPoints(boxes[i[0]])
@@ -131,16 +192,40 @@ def main():
             for j in range(4):
                 vertices[j][0] *= rW
                 vertices[j][1] *= rH
+
+
+            # get cropped image using perspective transform
+            if modelRecognition:
+                cropped = fourPointsTransform(frame, vertices)
+                cropped = cv.cvtColor(cropped, cv.COLOR_BGR2GRAY)
+
+                # Create a 4D blob from cropped image
+                blob = cv.dnn.blobFromImage(cropped, size=(100, 32), mean=127.5, scalefactor=1 / 127.5)
+                recognizer.setInput(blob)
+
+                # Run the recognition model
+                tickmeter.start()
+                result = recognizer.forward()
+                tickmeter.stop()
+
+                # decode the result into text
+                wordRecognized = decodeText(result)
+                cv.putText(frame, wordRecognized, (int(vertices[1][0]), int(vertices[1][1])), cv.FONT_HERSHEY_SIMPLEX,
+                           0.5, (255, 0, 0))
+
             for j in range(4):
                 p1 = (vertices[j][0], vertices[j][1])
                 p2 = (vertices[(j + 1) % 4][0], vertices[(j + 1) % 4][1])
                 cv.line(frame, p1, p2, (0, 255, 0), 1)
 
         # Put efficiency information
+        label = 'Inference time: %.2f ms' % (tickmeter.getTimeMilli())
         cv.putText(frame, label, (0, 15), cv.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0))
 
         # Display the frame
-        cv.imshow(kWinName,frame)
+        cv.imshow(kWinName, frame)
+        tickmeter.reset()
+
 
 if __name__ == "__main__":
     main()