Include preprocessing nodes to object detection TensorFlow networks (#12211)
authorDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Fri, 31 Aug 2018 12:41:56 +0000 (15:41 +0300)
committerVadim Pisarevsky <vadim.pisarevsky@gmail.com>
Fri, 31 Aug 2018 12:41:56 +0000 (15:41 +0300)
* Include preprocessing nodes to object detection TensorFlow networks

* Enable more fusion

* faster_rcnn_resnet50_coco_2018_01_28 test

modules/dnn/CMakeLists.txt
modules/dnn/src/dnn.cpp
modules/dnn/src/layers/convolution_layer.cpp
modules/dnn/src/layers/elementwise_layers.cpp
modules/dnn/test/test_backends.cpp
modules/dnn/test/test_caffe_importer.cpp
modules/dnn/test/test_tf_importer.cpp
modules/dnn/test/test_torch_importer.cpp
samples/dnn/tf_text_graph_faster_rcnn.py
samples/dnn/tf_text_graph_ssd.py

index 64fefb3..40b573f 100644 (file)
@@ -95,7 +95,7 @@ ocv_glob_module_sources(${sources_options} SOURCES ${fw_srcs})
 ocv_create_module(${libs} ${INF_ENGINE_TARGET})
 ocv_add_samples()
 ocv_add_accuracy_tests(${INF_ENGINE_TARGET})
-ocv_add_perf_tests()
+ocv_add_perf_tests(${INF_ENGINE_TARGET})
 
 ocv_option(${the_module}_PERF_CAFFE "Add performance tests of Caffe framework" OFF)
 ocv_option(${the_module}_PERF_CLCAFFE "Add performance tests of clCaffe framework" OFF)
index bc18695..214ac99 100644 (file)
@@ -1676,14 +1676,6 @@ struct Net::Impl
             // with the current layer if they follow it. Normally, the are fused with the convolution layer,
             // but some of them (like activation) may be fused with fully-connected, elemwise (+) and
             // some other layers.
-
-            // TODO: OpenCL target support more fusion styles.
-            if ( preferableBackend == DNN_BACKEND_OPENCV && IS_DNN_OPENCL_TARGET(preferableTarget) &&
-                 (!cv::ocl::useOpenCL() || (ld.layerInstance->type != "Convolution" &&
-                 ld.layerInstance->type != "MVN" && ld.layerInstance->type != "Pooling" &&
-                 ld.layerInstance->type != "Concat")) )
-                continue;
-
             Ptr<Layer>& currLayer = ld.layerInstance;
             if( ld.consumers.size() == 1 && pinsToKeep.count(LayerPin(lid, 0)) == 0 )
             {
@@ -1717,6 +1709,13 @@ struct Net::Impl
                 if (preferableBackend != DNN_BACKEND_OPENCV)
                     continue;  // Go to the next layer.
 
+                // TODO: OpenCL target support more fusion styles.
+                if ( preferableBackend == DNN_BACKEND_OPENCV && IS_DNN_OPENCL_TARGET(preferableTarget) &&
+                     (!cv::ocl::useOpenCL() || (ld.layerInstance->type != "Convolution" &&
+                     ld.layerInstance->type != "MVN" && ld.layerInstance->type != "Pooling" &&
+                     ld.layerInstance->type != "Concat")) )
+                    continue;
+
                 while (nextData)
                 {
                     // For now, OpenCL target support fusion with activation of ReLU/ChannelsPReLU/Power/Tanh
index 169e280..54b3245 100644 (file)
@@ -350,12 +350,14 @@ public:
         return false;
     }
 
-    void fuseWeights(const Mat& w, const Mat& b)
+    void fuseWeights(const Mat& w_, const Mat& b_)
     {
         // Convolution weights have OIHW data layout. Parameters fusion in case of
         // (conv(I) + b1 ) * w + b2
         // means to replace convolution's weights to [w*conv(I)] and bias to [b1 * w + b2]
         const int outCn = weightsMat.size[0];
+        Mat w = w_.total() == 1 ? Mat(1, outCn, CV_32F, Scalar(w_.at<float>(0))) : w_;
+        Mat b = b_.total() == 1 ? Mat(1, outCn, CV_32F, Scalar(b_.at<float>(0))) : b_;
         CV_Assert_N(!weightsMat.empty(), biasvec.size() == outCn + 2,
                     w.empty() || outCn == w.total(), b.empty() || outCn == b.total());
 
index 0a5ed54..74c89e6 100644 (file)
@@ -161,6 +161,16 @@ public:
         return Ptr<BackendNode>();
     }
 
+    virtual bool tryFuse(Ptr<dnn::Layer>& top) CV_OVERRIDE
+    {
+        return func.tryFuse(top);
+    }
+
+    void getScaleShift(Mat& scale_, Mat& shift_) const CV_OVERRIDE
+    {
+        func.getScaleShift(scale_, shift_);
+    }
+
     bool getMemoryShapes(const std::vector<MatShape> &inputs,
                          const int requiredOutputs,
                          std::vector<MatShape> &outputs,
@@ -343,6 +353,10 @@ struct ReLUFunctor
     }
 #endif  // HAVE_INF_ENGINE
 
+    bool tryFuse(Ptr<dnn::Layer>&) { return false; }
+
+    void getScaleShift(Mat&, Mat&) const {}
+
     int64 getFLOPSPerElement() const { return 1; }
 };
 
@@ -448,6 +462,10 @@ struct ReLU6Functor
     }
 #endif  // HAVE_INF_ENGINE
 
+    bool tryFuse(Ptr<dnn::Layer>&) { return false; }
+
+    void getScaleShift(Mat&, Mat&) const {}
+
     int64 getFLOPSPerElement() const { return 2; }
 };
 
@@ -518,6 +536,10 @@ struct TanHFunctor
     }
 #endif  // HAVE_INF_ENGINE
 
+    bool tryFuse(Ptr<dnn::Layer>&) { return false; }
+
+    void getScaleShift(Mat&, Mat&) const {}
+
     int64 getFLOPSPerElement() const { return 1; }
 };
 
@@ -588,6 +610,10 @@ struct SigmoidFunctor
     }
 #endif  // HAVE_INF_ENGINE
 
+    bool tryFuse(Ptr<dnn::Layer>&) { return false; }
+
+    void getScaleShift(Mat&, Mat&) const {}
+
     int64 getFLOPSPerElement() const { return 3; }
 };
 
@@ -659,6 +685,10 @@ struct ELUFunctor
     }
 #endif  // HAVE_INF_ENGINE
 
+    bool tryFuse(Ptr<dnn::Layer>&) { return false; }
+
+    void getScaleShift(Mat&, Mat&) const {}
+
     int64 getFLOPSPerElement() const { return 2; }
 };
 
@@ -727,6 +757,10 @@ struct AbsValFunctor
     }
 #endif  // HAVE_INF_ENGINE
 
+    bool tryFuse(Ptr<dnn::Layer>&) { return false; }
+
+    void getScaleShift(Mat&, Mat&) const {}
+
     int64 getFLOPSPerElement() const { return 1; }
 };
 
@@ -775,6 +809,10 @@ struct BNLLFunctor
     }
 #endif  // HAVE_INF_ENGINE
 
+    bool tryFuse(Ptr<dnn::Layer>&) { return false; }
+
+    void getScaleShift(Mat&, Mat&) const {}
+
     int64 getFLOPSPerElement() const { return 5; }
 };
 
@@ -875,15 +913,51 @@ struct PowerFunctor
 #ifdef HAVE_INF_ENGINE
     InferenceEngine::CNNLayerPtr initInfEngine(InferenceEngine::LayerParams& lp)
     {
-        lp.type = "Power";
-        std::shared_ptr<InferenceEngine::PowerLayer> ieLayer(new InferenceEngine::PowerLayer(lp));
-        ieLayer->power = power;
-        ieLayer->scale = scale;
-        ieLayer->offset = shift;
-        return ieLayer;
+        if (power == 1.0f && scale == 1.0f && shift == 0.0f)
+        {
+            // It looks like there is a bug in Inference Engine for DNN_TARGET_OPENCL and DNN_TARGET_OPENCL_FP16
+            // if power layer do nothing so we replace it to Identity.
+            lp.type = "Split";
+            return std::shared_ptr<InferenceEngine::SplitLayer>(new InferenceEngine::SplitLayer(lp));
+        }
+        else
+        {
+            lp.type = "Power";
+            std::shared_ptr<InferenceEngine::PowerLayer> ieLayer(new InferenceEngine::PowerLayer(lp));
+            ieLayer->power = power;
+            ieLayer->scale = scale;
+            ieLayer->offset = shift;
+            return ieLayer;
+        }
     }
 #endif  // HAVE_INF_ENGINE
 
+    bool tryFuse(Ptr<dnn::Layer>& top)
+    {
+        if (power != 1.0f && shift != 0.0f)
+            return false;
+
+        Mat w, b;
+        top->getScaleShift(w, b);
+        if ((w.empty() && b.empty()) || w.total() > 1 || b.total() > 1)
+            return false;
+
+        float nextScale = w.empty() ? 1.0f : w.at<float>(0);
+        float nextShift = b.empty() ? 0.0f : b.at<float>(0);
+        scale = std::pow(scale, power) * nextScale;
+        shift = nextScale * shift + nextShift;
+        return true;
+    }
+
+    void getScaleShift(Mat& _scale, Mat& _shift) const
+    {
+        if (power == 1.0f)
+        {
+            _scale = Mat(1, 1, CV_32F, Scalar(scale));
+            _shift = Mat(1, 1, CV_32F, Scalar(shift));
+        }
+    }
+
     int64 getFLOPSPerElement() const { return power == 1 ? 2 : 10; }
 };
 
@@ -989,6 +1063,10 @@ struct ChannelsPReLUFunctor
     }
 #endif  // HAVE_INF_ENGINE
 
+    bool tryFuse(Ptr<dnn::Layer>&) { return false; }
+
+    void getScaleShift(Mat&, Mat&) const {}
+
     int64 getFLOPSPerElement() const { return 1; }
 };
 
index 309f001..5ab7992 100644 (file)
@@ -161,7 +161,7 @@ TEST_P(DNNTestNetwork, MobileNet_SSD_v1_TensorFlow)
     if (backend == DNN_BACKEND_HALIDE)
         throw SkipTestException("");
     Mat sample = imread(findDataFile("dnn/street.png", false));
-    Mat inp = blobFromImage(sample, 1.0f / 127.5, Size(300, 300), Scalar(127.5, 127.5, 127.5), false);
+    Mat inp = blobFromImage(sample, 1.0f, Size(300, 300), Scalar(), false);
     float l1 = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.011 : 0.0;
     float lInf = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.06 : 0.0;
     processNet("dnn/ssd_mobilenet_v1_coco_2017_11_17.pb", "dnn/ssd_mobilenet_v1_coco_2017_11_17.pbtxt",
@@ -173,7 +173,7 @@ TEST_P(DNNTestNetwork, MobileNet_SSD_v2_TensorFlow)
     if (backend == DNN_BACKEND_HALIDE)
         throw SkipTestException("");
     Mat sample = imread(findDataFile("dnn/street.png", false));
-    Mat inp = blobFromImage(sample, 1.0f / 127.5, Size(300, 300), Scalar(127.5, 127.5, 127.5), false);
+    Mat inp = blobFromImage(sample, 1.0f, Size(300, 300), Scalar(), false);
     float l1 = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.011 : 0.0;
     float lInf = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.062 : 0.0;
     processNet("dnn/ssd_mobilenet_v2_coco_2018_03_29.pb", "dnn/ssd_mobilenet_v2_coco_2018_03_29.pbtxt",
@@ -247,8 +247,8 @@ TEST_P(DNNTestNetwork, Inception_v2_SSD_TensorFlow)
     if (backend == DNN_BACKEND_HALIDE)
         throw SkipTestException("");
     Mat sample = imread(findDataFile("dnn/street.png", false));
-    Mat inp = blobFromImage(sample, 1.0f / 127.5, Size(300, 300), Scalar(127.5, 127.5, 127.5), false);
-    float l1 = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.008 : 0.0;
+    Mat inp = blobFromImage(sample, 1.0f, Size(300, 300), Scalar(), false);
+    float l1 = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.015 : 0.0;
     float lInf = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.0731 : 0.0;
     processNet("dnn/ssd_inception_v2_coco_2017_11_17.pb", "dnn/ssd_inception_v2_coco_2017_11_17.pbtxt",
                inp, "detection_out", "", l1, lInf);
index 4491fde..ff0bbb7 100644 (file)
@@ -417,7 +417,7 @@ TEST_P(Test_Caffe_nets, DenseNet_121)
     float l1 = default_l1, lInf = default_lInf;
     if (target == DNN_TARGET_OPENCL_FP16)
     {
-        l1 = 0.017; lInf = 0.067;
+        l1 = 0.017; lInf = 0.0795;
     }
     else if (target == DNN_TARGET_MYRIAD)
     {
index d95f6f5..c1a55cd 100644 (file)
@@ -296,7 +296,7 @@ TEST_P(Test_TensorFlow_nets, Inception_v2_SSD)
 
     Net net = readNetFromTensorflow(model, proto);
     Mat img = imread(findDataFile("dnn/street.png", false));
-    Mat blob = blobFromImage(img, 1.0f / 127.5, Size(300, 300), Scalar(127.5, 127.5, 127.5), true, false);
+    Mat blob = blobFromImage(img, 1.0f, Size(300, 300), Scalar(), true, false);
 
     net.setPreferableBackend(backend);
     net.setPreferableTarget(target);
@@ -310,32 +310,38 @@ TEST_P(Test_TensorFlow_nets, Inception_v2_SSD)
                                     0, 3, 0.75838411, 0.44668293, 0.45907149, 0.49459291, 0.52197015,
                                     0, 10, 0.95932811, 0.38349164, 0.32528657, 0.40387636, 0.39165527,
                                     0, 10, 0.93973452, 0.66561931, 0.37841269, 0.68074018, 0.42907384);
-    double scoreDiff = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 5e-3 : default_l1;
+    double scoreDiff = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.0097 : default_l1;
     double iouDiff = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.09 : default_lInf;
     normAssertDetections(ref, out, "", 0.5, scoreDiff, iouDiff);
 }
 
-TEST_P(Test_TensorFlow_nets, Inception_v2_Faster_RCNN)
+TEST_P(Test_TensorFlow_nets, Faster_RCNN)
 {
+    static std::string names[] = {"faster_rcnn_inception_v2_coco_2018_01_28",
+                                  "faster_rcnn_resnet50_coco_2018_01_28"};
+
     checkBackend();
     if ((backend == DNN_BACKEND_INFERENCE_ENGINE && target != DNN_TARGET_CPU) ||
         (backend == DNN_BACKEND_OPENCV && target == DNN_TARGET_OPENCL_FP16))
         throw SkipTestException("");
 
-    std::string proto = findDataFile("dnn/faster_rcnn_inception_v2_coco_2018_01_28.pbtxt", false);
-    std::string model = findDataFile("dnn/faster_rcnn_inception_v2_coco_2018_01_28.pb", false);
+    for (int i = 1; i < 2; ++i)
+    {
+        std::string proto = findDataFile("dnn/" + names[i] + ".pbtxt", false);
+        std::string model = findDataFile("dnn/" + names[i] + ".pb", false);
 
-    Net net = readNetFromTensorflow(model, proto);
-    net.setPreferableBackend(backend);
-    net.setPreferableTarget(target);
-    Mat img = imread(findDataFile("dnn/dog416.png", false));
-    Mat blob = blobFromImage(img, 1.0f / 127.5, Size(800, 600), Scalar(127.5, 127.5, 127.5), true, false);
+        Net net = readNetFromTensorflow(model, proto);
+        net.setPreferableBackend(backend);
+        net.setPreferableTarget(target);
+        Mat img = imread(findDataFile("dnn/dog416.png", false));
+        Mat blob = blobFromImage(img, 1.0f, Size(800, 600), Scalar(), true, false);
 
-    net.setInput(blob);
-    Mat out = net.forward();
+        net.setInput(blob);
+        Mat out = net.forward();
 
-    Mat ref = blobFromNPY(findDataFile("dnn/tensorflow/faster_rcnn_inception_v2_coco_2018_01_28.detection_out.npy"));
-    normAssertDetections(ref, out, "", 0.3);
+        Mat ref = blobFromNPY(findDataFile("dnn/tensorflow/" + names[i] + ".detection_out.npy"));
+        normAssertDetections(ref, out, names[i].c_str(), 0.3);
+    }
 }
 
 TEST_P(Test_TensorFlow_nets, MobileNet_v1_SSD_PPN)
@@ -347,15 +353,16 @@ TEST_P(Test_TensorFlow_nets, MobileNet_v1_SSD_PPN)
     Net net = readNetFromTensorflow(model, proto);
     Mat img = imread(findDataFile("dnn/dog416.png", false));
     Mat ref = blobFromNPY(findDataFile("dnn/tensorflow/ssd_mobilenet_v1_ppn_coco.detection_out.npy", false));
-    Mat blob = blobFromImage(img, 1.0f / 127.5, Size(300, 300), Scalar(127.5, 127.5, 127.5), true, false);
+    Mat blob = blobFromImage(img, 1.0f, Size(300, 300), Scalar(), true, false);
 
     net.setPreferableBackend(backend);
     net.setPreferableTarget(target);
 
     net.setInput(blob);
     Mat out = net.forward();
-    double scoreDiff = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.006 : default_l1;
-    normAssertDetections(ref, out, "", 0.4, scoreDiff, default_lInf);
+    double scoreDiff = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.008 : default_l1;
+    double iouDiff = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.021 : default_lInf;
+    normAssertDetections(ref, out, "", 0.4, scoreDiff, iouDiff);
 }
 
 TEST_P(Test_TensorFlow_nets, opencv_face_detector_uint8)
index 13e3dde..b6583da 100644 (file)
@@ -301,14 +301,14 @@ TEST_P(Test_Torch_nets, ENet_accuracy)
     // Due to numerical instability in Pooling-Unpooling layers (indexes jittering)
     // thresholds for ENet must be changed. Accuracy of results was checked on
     // Cityscapes dataset and difference in mIOU with Torch is 10E-4%
-    normAssert(ref, out, "", 0.00044, 0.44);
+    normAssert(ref, out, "", 0.00044, target == DNN_TARGET_CPU ? 0.453 : 0.44);
 
     const int N = 3;
     for (int i = 0; i < N; i++)
     {
         net.setInput(inputBlob, "");
         Mat out = net.forward();
-        normAssert(ref, out, "", 0.00044, 0.44);
+        normAssert(ref, out, "", 0.00044, target == DNN_TARGET_CPU ? 0.453 : 0.44);
     }
 }
 
index d18d82b..b02b0c5 100644 (file)
@@ -29,6 +29,8 @@ scopesToKeep = ('FirstStageFeatureExtractor', 'Conv',
                 'MaxPool2D',
                 'SecondStageFeatureExtractor',
                 'SecondStageBoxPredictor',
+                'Preprocessor/sub',
+                'Preprocessor/mul',
                 'image_tensor')
 
 scopesToIgnore = ('FirstStageFeatureExtractor/Assert',
index 0d4a41f..c07bc76 100644 (file)
@@ -39,10 +39,11 @@ args = parser.parse_args()
 
 # Nodes that should be kept.
 keepOps = ['Conv2D', 'BiasAdd', 'Add', 'Relu6', 'Placeholder', 'FusedBatchNorm',
-           'DepthwiseConv2dNative', 'ConcatV2', 'Mul', 'MaxPool', 'AvgPool', 'Identity']
+           'DepthwiseConv2dNative', 'ConcatV2', 'Mul', 'MaxPool', 'AvgPool', 'Identity',
+           'Sub']
 
 # Node with which prefixes should be removed
-prefixesToRemove = ('MultipleGridAnchorGenerator/', 'Postprocessor/', 'Preprocessor/')
+prefixesToRemove = ('MultipleGridAnchorGenerator/', 'Postprocessor/', 'Preprocessor/map')
 
 # Read the graph.
 with tf.gfile.FastGFile(args.input, 'rb') as f: