Merge pull request #19484 from UnaNancyOwen:fix_highlevelapi
authorTsukasa Sugiura <t.sugiura0204@gmail.com>
Wed, 10 Feb 2021 19:42:00 +0000 (04:42 +0900)
committerGitHub <noreply@github.com>
Wed, 10 Feb 2021 19:42:00 +0000 (19:42 +0000)
* [dnn] fix high level api for python

* [dnn] add test_textdetection_model_db

* [dnn] fix textdetection test only check type and shape

modules/dnn/include/opencv2/dnn/dnn.hpp
modules/dnn/misc/python/test/test_dnn.py

index 3ece129..7722494 100644 (file)
@@ -1216,7 +1216,7 @@ CV__DNN_INLINE_NS_BEGIN
       * KeypointsModel creates net from file with trained weights and config,
       * sets preprocessing input, runs forward pass and returns the x and y coordinates of each detected keypoint
       */
-     class CV_EXPORTS_W KeypointsModel: public Model
+     class CV_EXPORTS_W_SIMPLE KeypointsModel: public Model
      {
      public:
          /**
@@ -1248,7 +1248,7 @@ CV__DNN_INLINE_NS_BEGIN
       * SegmentationModel creates net from file with trained weights and config,
       * sets preprocessing input, runs forward pass and returns the class prediction for each pixel.
       */
-     class CV_EXPORTS_W SegmentationModel: public Model
+     class CV_EXPORTS_W_SIMPLE SegmentationModel: public Model
      {
      public:
          /**
@@ -1406,7 +1406,7 @@ public:
 
 /** @brief Base class for text detection networks
  */
-class CV_EXPORTS_W TextDetectionModel : public Model
+class CV_EXPORTS_W_SIMPLE TextDetectionModel : public Model
 {
 protected:
     CV_DEPRECATED_EXTERNAL  // avoid using in C++ code, will be moved to "protected" (need to fix bindings first)
index 746dabf..d0687ca 100644 (file)
@@ -197,6 +197,25 @@ class dnn_test(NewOpenCVTests):
         normAssert(self, out, ref)
 
 
+    def test_textdetection_model(self):
+        img_path = self.find_dnn_file("dnn/text_det_test1.png")
+        weights = self.find_dnn_file("dnn/onnx/models/DB_TD500_resnet50.onnx", required=False)
+        if weights is None:
+            raise unittest.SkipTest("Missing DNN test files (onnx/models/DB_TD500_resnet50.onnx). Verify OPENCV_DNN_TEST_DATA_PATH configuration parameter.")
+
+        frame = cv.imread(img_path)
+        scale = 1.0 / 255.0
+        size = (736, 736)
+        mean = (122.67891434, 116.66876762, 104.00698793)
+
+        model = cv.dnn_TextDetectionModel_DB(weights)
+        model.setInputParams(scale, size, mean)
+        out, _ = model.detect(frame)
+
+        self.assertTrue(type(out) == list)
+        self.assertTrue(np.array(out).shape == (2, 4, 2))
+
+
     def test_face_detection(self):
         proto = self.find_dnn_file('dnn/opencv_face_detector.prototxt')
         model = self.find_dnn_file('dnn/opencv_face_detector.caffemodel', required=False)