Merge pull request #15082 from dvd42:segmentation-module
authorDiego <diegovd0296@gmail.com>
Tue, 13 Aug 2019 20:38:48 +0000 (22:38 +0200)
committerAlexander Alekhin <alexander.a.alekhin@gmail.com>
Tue, 13 Aug 2019 20:38:48 +0000 (23:38 +0300)
Segmentation module (#15082)

modules/dnn/include/opencv2/dnn/dnn.hpp
modules/dnn/src/model.cpp
modules/dnn/test/test_model.cpp

index 22eba0d..db6ef94 100644 (file)
@@ -1109,6 +1109,36 @@ CV__DNN_INLINE_NS_BEGIN
          CV_WRAP void classify(InputArray frame, CV_OUT int& classId, CV_OUT float& conf);
      };
 
+     /** @brief This class represents high-level API for segmentation  models
+      *
+      * SegmentationModel allows to set params for preprocessing input image.
+      * SegmentationModel creates net from file with trained weights and config,
+      * sets preprocessing input, runs forward pass and returns the class prediction for each pixel.
+      */
+     class CV_EXPORTS_W SegmentationModel: public Model
+     {
+     public:
+         /**
+          * @brief Create segmentation model from network represented in one of the supported formats.
+          * An order of @p model and @p config arguments does not matter.
+          * @param[in] model Binary file contains trained weights.
+          * @param[in] config Text file contains network configuration.
+          */
+          CV_WRAP SegmentationModel(const String& model, const String& config = "");
+
+         /**
+          * @brief Create model from deep learning network.
+          * @param[in] network Net object.
+          */
+         CV_WRAP SegmentationModel(const Net& network);
+
+         /** @brief Given the @p input frame, create input blob, run net
+          *  @param[in]  frame  The input image.
+          *  @param[out] mask Allocated class prediction for each pixel
+          */
+         CV_WRAP void segment(InputArray frame, OutputArray mask);
+     };
+
      /** @brief This class represents high-level API for object detection networks.
       *
       * DetectionModel allows to set params for preprocessing input image.
index 55a9cda..07965c4 100644 (file)
@@ -137,6 +137,47 @@ void ClassificationModel::classify(InputArray frame, int& classId, float& conf)
     std::tie(classId, conf) = classify(frame);
 }
 
+SegmentationModel::SegmentationModel(const String& model, const String& config)
+    : Model(model, config) {};
+
+SegmentationModel::SegmentationModel(const Net& network) : Model(network) {};
+
+void SegmentationModel::segment(InputArray frame, OutputArray mask)
+{
+
+    std::vector<Mat> outs;
+    impl->predict(*this, frame.getMat(), outs);
+    CV_Assert(outs.size() == 1);
+    Mat score = outs[0];
+
+    const int chns = score.size[1];
+    const int rows = score.size[2];
+    const int cols = score.size[3];
+
+    mask.create(rows, cols, CV_8U);
+    Mat classIds = mask.getMat();
+    classIds.setTo(0);
+    Mat maxVal(rows, cols, CV_32F, score.data);
+
+    for (int ch = 1; ch < chns; ch++)
+    {
+        for (int row = 0; row < rows; row++)
+        {
+            const float *ptrScore = score.ptr<float>(0, ch, row);
+            uint8_t *ptrMaxCl = classIds.ptr<uint8_t>(row);
+            float *ptrMaxVal = maxVal.ptr<float>(row);
+            for (int col = 0; col < cols; col++)
+            {
+                if (ptrScore[col] > ptrMaxVal[col])
+                {
+                    ptrMaxVal[col] = ptrScore[col];
+                    ptrMaxCl[col] = ch;
+                }
+            }
+        }
+    }
+}
+
 DetectionModel::DetectionModel(const String& model, const String& config)
     : Model(model, config) {};
 
index 943665a..8a333d9 100644 (file)
@@ -69,6 +69,25 @@ public:
         EXPECT_EQ(prediction.first, ref.first);
         ASSERT_NEAR(prediction.second, ref.second, norm);
     }
+
+    void testSegmentationModel(const std::string& weights_file, const std::string& config_file,
+                               const std::string& inImgPath, const std::string& outImgPath,
+                               float norm, const Size& size = {-1, -1}, Scalar mean = Scalar(),
+                               double scale = 1.0, bool swapRB = false, bool crop = false)
+    {
+        checkBackend();
+
+        Mat frame = imread(inImgPath);
+        Mat mask;
+        Mat exp = imread(outImgPath, 0);
+
+        SegmentationModel model(weights_file, config_file);
+        model.setInputSize(size).setInputMean(mean).setInputScale(scale)
+             .setInputSwapRB(swapRB).setInputCrop(crop);
+
+        model.segment(frame, mask);
+        normAssert(mask, exp, "", norm, norm);
+    }
 };
 
 TEST_P(Test_Model, Classify)
@@ -202,6 +221,22 @@ TEST_P(Test_Model, DetectionMobilenetSSD)
                     scoreDiff, iouDiff, confThreshold, nmsThreshold, size, mean, scale);
 }
 
+TEST_P(Test_Model, Segmentation)
+{
+    std::string inp = _tf("dog416.png");
+    std::string weights_file = _tf("fcn8s-heavy-pascal.prototxt");
+    std::string config_file = _tf("fcn8s-heavy-pascal.caffemodel");
+    std::string exp = _tf("segmentation_exp.png");
+
+    Size size{128, 128};
+    float norm = 0;
+    double scale = 1.0;
+    Scalar mean = Scalar();
+    bool swapRB = false;
+
+    testSegmentationModel(weights_file, config_file, inp, exp, norm, size, mean, scale, swapRB);
+}
+
 INSTANTIATE_TEST_CASE_P(/**/, Test_Model, dnnBackendsAndTargets());
 
 }} // namespace