OpenCL GPU target for Inference Engine deep learning backend
authorDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Mon, 12 Mar 2018 14:35:28 +0000 (17:35 +0300)
committerDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Mon, 9 Apr 2018 14:21:35 +0000 (17:21 +0300)
Enable FP16 GPU target for DL Inference Engine backend.

12 files changed:
modules/dnn/include/opencv2/dnn/dnn.hpp
modules/dnn/perf/perf_net.cpp
modules/dnn/src/dnn.cpp
modules/dnn/src/layers/batch_norm_layer.cpp
modules/dnn/src/layers/convolution_layer.cpp
modules/dnn/src/layers/fully_connected_layer.cpp
modules/dnn/src/layers/scale_layer.cpp
modules/dnn/src/layers/shift_layer.cpp
modules/dnn/src/op_inf_engine.cpp
modules/dnn/src/op_inf_engine.hpp
modules/dnn/test/test_backends.cpp
modules/dnn/test/test_precomp.hpp

index f1e220c..7f8c7e7 100644 (file)
@@ -80,7 +80,8 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
     enum Target
     {
         DNN_TARGET_CPU,
-        DNN_TARGET_OPENCL
+        DNN_TARGET_OPENCL,
+        DNN_TARGET_OPENCL_FP16
     };
 
     /** @brief This class provides all data needed to initialize layer.
index 92719a8..12a2081 100644 (file)
@@ -13,7 +13,7 @@
 namespace opencv_test {
 
 CV_ENUM(DNNBackend, DNN_BACKEND_DEFAULT, DNN_BACKEND_HALIDE, DNN_BACKEND_INFERENCE_ENGINE)
-CV_ENUM(DNNTarget, DNN_TARGET_CPU, DNN_TARGET_OPENCL)
+CV_ENUM(DNNTarget, DNN_TARGET_CPU, DNN_TARGET_OPENCL, DNN_TARGET_OPENCL_FP16)
 
 class DNNTestNetwork : public ::perf::TestBaseWithParam< tuple<DNNBackend, DNNTarget> >
 {
@@ -41,8 +41,6 @@ public:
                 throw cvtest::SkipTestException("OpenCL is not available/disabled in OpenCV");
             }
         }
-        if (backend == DNN_BACKEND_INFERENCE_ENGINE && target == DNN_TARGET_OPENCL)
-            throw SkipTestException("Skip OpenCL target of Inference Engine backend");
 
         randu(input, 0.0f, 1.0f);
 
@@ -89,24 +87,32 @@ public:
 
 PERF_TEST_P_(DNNTestNetwork, AlexNet)
 {
+    if (backend == DNN_BACKEND_INFERENCE_ENGINE && target != DNN_TARGET_CPU)
+        throw SkipTestException("");
     processNet("dnn/bvlc_alexnet.caffemodel", "dnn/bvlc_alexnet.prototxt",
             "alexnet.yml", Mat(cv::Size(227, 227), CV_32FC3));
 }
 
 PERF_TEST_P_(DNNTestNetwork, GoogLeNet)
 {
+    if (backend == DNN_BACKEND_INFERENCE_ENGINE && target == DNN_TARGET_OPENCL_FP16)
+        throw SkipTestException("");
     processNet("dnn/bvlc_googlenet.caffemodel", "dnn/bvlc_googlenet.prototxt",
             "", Mat(cv::Size(224, 224), CV_32FC3));
 }
 
 PERF_TEST_P_(DNNTestNetwork, ResNet_50)
 {
+    if (backend == DNN_BACKEND_INFERENCE_ENGINE && target == DNN_TARGET_OPENCL_FP16)
+        throw SkipTestException("");
     processNet("dnn/ResNet-50-model.caffemodel", "dnn/ResNet-50-deploy.prototxt",
             "resnet_50.yml", Mat(cv::Size(224, 224), CV_32FC3));
 }
 
 PERF_TEST_P_(DNNTestNetwork, SqueezeNet_v1_1)
 {
+    if (backend == DNN_BACKEND_INFERENCE_ENGINE && target == DNN_TARGET_OPENCL_FP16)
+        throw SkipTestException("");
     processNet("dnn/squeezenet_v1.1.caffemodel", "dnn/squeezenet_v1.1.prototxt",
             "squeezenet_v1_1.yml", Mat(cv::Size(227, 227), CV_32FC3));
 }
@@ -135,14 +141,18 @@ PERF_TEST_P_(DNNTestNetwork, SSD)
 
 PERF_TEST_P_(DNNTestNetwork, OpenFace)
 {
-    if (backend == DNN_BACKEND_HALIDE) throw SkipTestException("");
+    if (backend == DNN_BACKEND_HALIDE ||
+        backend == DNN_BACKEND_INFERENCE_ENGINE && target != DNN_TARGET_CPU)
+        throw SkipTestException("");
     processNet("dnn/openface_nn4.small2.v1.t7", "", "",
             Mat(cv::Size(96, 96), CV_32FC3));
 }
 
 PERF_TEST_P_(DNNTestNetwork, MobileNet_SSD_Caffe)
 {
-    if (backend == DNN_BACKEND_HALIDE) throw SkipTestException("");
+    if (backend == DNN_BACKEND_HALIDE ||
+        backend == DNN_BACKEND_INFERENCE_ENGINE && target != DNN_TARGET_CPU)
+        throw SkipTestException("");
     processNet("dnn/MobileNetSSD_deploy.caffemodel", "dnn/MobileNetSSD_deploy.prototxt", "",
             Mat(cv::Size(300, 300), CV_32FC3));
 }
@@ -150,7 +160,8 @@ PERF_TEST_P_(DNNTestNetwork, MobileNet_SSD_Caffe)
 PERF_TEST_P_(DNNTestNetwork, MobileNet_SSD_TensorFlow)
 {
     if (backend == DNN_BACKEND_DEFAULT && target == DNN_TARGET_OPENCL ||
-        backend == DNN_BACKEND_HALIDE)
+        backend == DNN_BACKEND_HALIDE ||
+        backend == DNN_BACKEND_INFERENCE_ENGINE && target != DNN_TARGET_CPU)
         throw SkipTestException("");
     processNet("dnn/ssd_mobilenet_v1_coco.pb", "ssd_mobilenet_v1_coco.pbtxt", "",
             Mat(cv::Size(300, 300), CV_32FC3));
@@ -158,7 +169,9 @@ PERF_TEST_P_(DNNTestNetwork, MobileNet_SSD_TensorFlow)
 
 PERF_TEST_P_(DNNTestNetwork, DenseNet_121)
 {
-    if (backend == DNN_BACKEND_HALIDE) throw SkipTestException("");
+    if (backend == DNN_BACKEND_HALIDE ||
+        backend == DNN_BACKEND_INFERENCE_ENGINE && target == DNN_TARGET_OPENCL_FP16)
+        throw SkipTestException("");
     processNet("dnn/DenseNet_121.caffemodel", "dnn/DenseNet_121.prototxt", "",
                Mat(cv::Size(224, 224), CV_32FC3));
 }
@@ -189,7 +202,7 @@ PERF_TEST_P_(DNNTestNetwork, OpenPose_pose_mpi_faster_4_stages)
 PERF_TEST_P_(DNNTestNetwork, opencv_face_detector)
 {
     if (backend == DNN_BACKEND_HALIDE ||
-        backend == DNN_BACKEND_DEFAULT && target == DNN_TARGET_OPENCL)
+        backend == DNN_BACKEND_INFERENCE_ENGINE && target != DNN_TARGET_CPU)
         throw SkipTestException("");
     processNet("dnn/opencv_face_detector.caffemodel", "dnn/opencv_face_detector.prototxt", "",
                Mat(cv::Size(300, 300), CV_32FC3));
@@ -197,7 +210,9 @@ PERF_TEST_P_(DNNTestNetwork, opencv_face_detector)
 
 PERF_TEST_P_(DNNTestNetwork, Inception_v2_SSD_TensorFlow)
 {
-    if (backend == DNN_BACKEND_HALIDE) throw SkipTestException("");
+    if (backend == DNN_BACKEND_HALIDE ||
+        backend == DNN_BACKEND_INFERENCE_ENGINE && target != DNN_TARGET_CPU)
+        throw SkipTestException("");
     processNet("dnn/ssd_inception_v2_coco_2017_11_17.pb", "ssd_inception_v2_coco_2017_11_17.pbtxt", "",
             Mat(cv::Size(300, 300), CV_32FC3));
 }
@@ -209,6 +224,8 @@ const tuple<DNNBackend, DNNTarget> testCases[] = {
 #endif
 #ifdef HAVE_INF_ENGINE
     tuple<DNNBackend, DNNTarget>(DNN_BACKEND_INFERENCE_ENGINE, DNN_TARGET_CPU),
+    tuple<DNNBackend, DNNTarget>(DNN_BACKEND_INFERENCE_ENGINE, DNN_TARGET_OPENCL),
+    tuple<DNNBackend, DNNTarget>(DNN_BACKEND_INFERENCE_ENGINE, DNN_TARGET_OPENCL_FP16),
 #endif
     tuple<DNNBackend, DNNTarget>(DNN_BACKEND_DEFAULT, DNN_TARGET_CPU),
     tuple<DNNBackend, DNNTarget>(DNN_BACKEND_DEFAULT, DNN_TARGET_OPENCL)
index 611e35e..d19869a 100644 (file)
@@ -1154,7 +1154,7 @@ struct Net::Impl
                 ld.skip = true;
             }
             layers[lastLayerId].skip = false;
-            ieNode->net->init();
+            ieNode->net->init(preferableTarget);
             return;
         }
 
@@ -1167,17 +1167,17 @@ struct Net::Impl
         for (it = layers.begin(); it != layers.end(); ++it)
         {
             LayerData &ld = it->second;
-            ld.skip = true;  // Initially skip all Inference Engine supported layers.
-            Ptr<Layer> layer = ld.layerInstance;
+            bool fused = ld.skip && ld.id != 0;
 
+            Ptr<Layer> layer = ld.layerInstance;
             if (!layer->supportBackend(preferableBackend))
             {
                 addInfEngineNetOutputs(ld);
-                ld.skip = false;
                 net = Ptr<InfEngineBackendNet>();
                 netBlobsWrappers.clear();
                 continue;
             }
+            ld.skip = true;  // Initially skip all Inference Engine supported layers.
 
             // Create a new network if one of inputs from different Inference Engine graph.
             for (int i = 0; i < ld.inputBlobsId.size(); ++i)
@@ -1217,19 +1217,16 @@ struct Net::Impl
             }
             netBlobsWrappers[ld.id] = ld.outputBlobsWrappers[0];
 
-            bool fused = false;
             Ptr<BackendNode> node;
             if (!net.empty())
             {
-                // Try to fuse.
-                bool inPlace = ld.inputBlobsId.size() == 1 && ld.outputBlobs.size() == 1 &&
-                               ld.inputBlobs[0]->data == ld.outputBlobs[0].data;
-                if (inPlace)
+                if (fused)
                 {
-                    node = layer->tryAttach(layers[ld.inputBlobsId[0].lid].backendNodes[preferableBackend]);
-                    fused = !node.empty();
-                    if (fused)
-                        ld.inputBlobsWrappers = layers[ld.inputBlobsId[0].lid].inputBlobsWrappers;
+                    bool inPlace = ld.inputBlobsId.size() == 1 && ld.outputBlobs.size() == 1 &&
+                                   ld.inputBlobs[0]->data == ld.outputBlobs[0].data;
+                    CV_Assert(inPlace);
+                    node = layers[ld.inputBlobsId[0].lid].backendNodes[preferableBackend];
+                    ld.inputBlobsWrappers = layers[ld.inputBlobsId[0].lid].inputBlobsWrappers;
                 }
             }
             else
@@ -1247,6 +1244,19 @@ struct Net::Impl
             CV_Assert(!ieNode.empty());
             ieNode->net = net;
 
+            if (preferableTarget == DNN_TARGET_OPENCL_FP16 && !fused)
+            {
+                ieNode->layer->precision = InferenceEngine::Precision::FP16;
+                auto weightableLayer = std::dynamic_pointer_cast<InferenceEngine::WeightableLayer>(ieNode->layer);
+                if (weightableLayer)
+                {
+                    if (weightableLayer->_weights)
+                        weightableLayer->_weights = convertFp16(weightableLayer->_weights);
+                    if (weightableLayer->_biases)
+                        weightableLayer->_biases = convertFp16(weightableLayer->_biases);
+                }
+            }
+
             ieNode->connect(ld.inputBlobsWrappers, ld.outputBlobsWrappers);
             net->addBlobs(ld.inputBlobsWrappers);
             net->addBlobs(ld.outputBlobsWrappers);
@@ -1276,7 +1286,7 @@ struct Net::Impl
 
             if (!ieNode->net->isInitialized())
             {
-                ieNode->net->init();
+                ieNode->net->init(preferableTarget);
                 ld.skip = false;
             }
         }
@@ -1380,7 +1390,8 @@ struct Net::Impl
 
     void fuseLayers(const std::vector<LayerPin>& blobsToKeep_)
     {
-        if( !fusion || preferableBackend != DNN_BACKEND_DEFAULT)
+        if( !fusion || preferableBackend != DNN_BACKEND_DEFAULT &&
+                       preferableBackend != DNN_BACKEND_INFERENCE_ENGINE)
             return;
 
         CV_TRACE_FUNCTION();
@@ -1407,7 +1418,7 @@ struct Net::Impl
             // some other layers.
 
             // TODO: OpenCL target support more fusion styles.
-            if ( preferableTarget == DNN_TARGET_OPENCL &&
+            if ( preferableBackend == DNN_BACKEND_DEFAULT && preferableTarget == DNN_TARGET_OPENCL &&
                  (!cv::ocl::useOpenCL() || (ld.layerInstance->type != "Convolution" &&
                  ld.layerInstance->type != "MVN")) )
                 continue;
@@ -1442,6 +1453,9 @@ struct Net::Impl
                         break;
                 }
 
+                if (preferableBackend != DNN_BACKEND_DEFAULT)
+                    continue;  // Go to the next layer.
+
                 // For now, OpenCL target support fusion with activation of ReLU/ChannelsPReLU/Power/Tanh
                 if ( preferableTarget != DNN_TARGET_OPENCL ||
                         (preferableTarget == DNN_TARGET_OPENCL &&
@@ -1583,6 +1597,9 @@ struct Net::Impl
                 }
             }
 
+            if (preferableBackend != DNN_BACKEND_DEFAULT)
+                continue;  // Go to the next layer.
+
             // the optimization #2. if there is no layer that takes max pooling layer's computed
             // max indices (and only some semantical segmentation networks might need this;
             // many others only take the maximum values), then we switch the max pooling
index df4e553..c2906b6 100644 (file)
@@ -234,19 +234,6 @@ public:
 #endif  // HAVE_HALIDE
                 break;
             }
-            case DNN_BACKEND_INFERENCE_ENGINE:
-            {
-#ifdef HAVE_INF_ENGINE
-                auto base = node.dynamicCast<InfEngineBackendNode>();
-                auto conv = std::dynamic_pointer_cast<InferenceEngine::ConvolutionLayer>(base->layer);
-                if (conv)
-                {
-                    fuseConvWeights(conv, weights_, bias_);
-                    return base;
-                }
-#endif  // HAVE_INF_ENGINE
-                break;
-            }
         }
         return Ptr<BackendNode>();
     }
@@ -287,8 +274,9 @@ public:
         lp.precision = InferenceEngine::Precision::FP32;
         std::shared_ptr<InferenceEngine::ScaleShiftLayer> ieLayer(new InferenceEngine::ScaleShiftLayer(lp));
 
-        ieLayer->_weights = wrapToInfEngineBlob(weights_);
-        ieLayer->_biases = wrapToInfEngineBlob(bias_);
+        const int numChannels = weights_.total();
+        ieLayer->_weights = wrapToInfEngineBlob(weights_, {numChannels}, InferenceEngine::Layout::C);
+        ieLayer->_biases = wrapToInfEngineBlob(bias_, {numChannels}, InferenceEngine::Layout::C);
 
         return Ptr<BackendNode>(new InfEngineBackendNode(ieLayer));
 #endif  // HAVE_INF_ENGINE
index 6da8438..8c52bc0 100644 (file)
@@ -173,21 +173,21 @@ public:
     std::vector<float> biasvec;
     std::vector<float> reluslope;
     Ptr<ActivationLayer> activ;
+    bool newWeightAndBias;
+    bool fusedBias;
 
 #ifdef HAVE_OPENCL
     Ptr<OCL4DNNConvSpatial<float> > convolutionOp;
     std::vector<UMat> umat_blobs;
-    bool fusedBias;
-    bool newWeightAndBias;
     bool newActiv;
     ocl4dnnFusedActiv_t activType;
     float power;
 #endif
     ConvolutionLayerImpl(const LayerParams &params) : BaseConvolutionLayerImpl(params)
     {
-#ifdef HAVE_OPENCL
-        fusedBias = false;
         newWeightAndBias = false;
+        fusedBias = false;
+#ifdef HAVE_OPENCL
         newActiv = false;
         activType = OCL4DNN_CONV_FUSED_ACTIV_NONE;
         power = 0.f;
@@ -350,10 +350,8 @@ public:
                 biasvec[i] += b.at<float>(i);
         }
 
-#ifdef HAVE_OPENCL
         newWeightAndBias = !w.empty() || !b.empty();
         fusedBias = hasBias() || !b.empty();
-#endif
         biasvec[outCn] = biasvec[outCn+1] = biasvec[outCn-1];
     }
 
@@ -433,9 +431,31 @@ public:
         ieLayer->_dilation_y = dilation.height;
         ieLayer->_group = group;
 
-        ieLayer->_weights = wrapToInfEngineBlob(blobs[0]);
-        if (hasBias())
-            ieLayer->_biases = wrapToInfEngineBlob(blobs[1]);
+        ieLayer->_weights = wrapToInfEngineBlob(blobs[0], InferenceEngine::Layout::OIHW);
+        if (newWeightAndBias)
+        {
+            if (weightsMat.isContinuous())
+            {
+                Mat fusedWeights = weightsMat.reshape(1, blobs[0].dims, blobs[0].size);
+                ieLayer->_weights = wrapToInfEngineBlob(fusedWeights, InferenceEngine::Layout::OIHW);
+            }
+            else
+            {
+                ieLayer->_weights = InferenceEngine::make_shared_blob<float>(
+                                    InferenceEngine::Precision::FP32, InferenceEngine::Layout::OIHW,
+                                    ieLayer->_weights->dims());
+                ieLayer->_weights->allocate();
+
+                Mat newWeights = infEngineBlobToMat(ieLayer->_weights).reshape(1, outCn);
+                Mat fusedWeights = weightsMat.colRange(0, newWeights.cols);
+                fusedWeights.copyTo(newWeights);
+            }
+        }
+        if (hasBias() || fusedBias)
+        {
+            Mat biasesMat({outCn}, CV_32F, &biasvec[0]);
+            ieLayer->_biases = wrapToInfEngineBlob(biasesMat, {outCn}, InferenceEngine::Layout::C);
+        }
         return Ptr<BackendNode>(new InfEngineBackendNode(ieLayer));
 #endif  // HAVE_INF_ENGINE
         return Ptr<BackendNode>();
index 68ca1b4..9ee7e98 100644 (file)
@@ -412,9 +412,9 @@ public:
         std::shared_ptr<InferenceEngine::FullyConnectedLayer> ieLayer(new InferenceEngine::FullyConnectedLayer(lp));
 
         ieLayer->_out_num = blobs[0].size[0];
-        ieLayer->_weights = wrapToInfEngineBlob(blobs[0]);
+        ieLayer->_weights = wrapToInfEngineBlob(blobs[0], {blobs[0].size[0], blobs[0].size[1], 1, 1}, InferenceEngine::Layout::OIHW);
         if (blobs.size() > 1)
-            ieLayer->_biases = wrapToInfEngineBlob(blobs[1]);
+            ieLayer->_biases = wrapToInfEngineBlob(blobs[1], {ieLayer->_out_num}, InferenceEngine::Layout::C);
         return Ptr<BackendNode>(new InfEngineBackendNode(ieLayer));
 #endif  // HAVE_INF_ENGINE
         return Ptr<BackendNode>();
index 464e385..833c993 100644 (file)
@@ -132,20 +132,6 @@ public:
 #endif  // HAVE_HALIDE
                 break;
             }
-            case DNN_BACKEND_INFERENCE_ENGINE:
-            {
-#ifdef HAVE_INF_ENGINE
-                auto base = node.dynamicCast<InfEngineBackendNode>();
-                auto conv = std::dynamic_pointer_cast<InferenceEngine::ConvolutionLayer>(base->layer);
-                if (conv)
-                {
-                    Mat bias = hasBias ? blobs[1] : Mat();
-                    fuseConvWeights(conv, blobs[0], bias);
-                    return base;
-                }
-#endif  // HAVE_INF_ENGINE
-                break;
-            }
         }
         return Ptr<BackendNode>();
     }
@@ -192,9 +178,10 @@ public:
         lp.precision = InferenceEngine::Precision::FP32;
         std::shared_ptr<InferenceEngine::ScaleShiftLayer> ieLayer(new InferenceEngine::ScaleShiftLayer(lp));
 
-        ieLayer->_weights = wrapToInfEngineBlob(blobs[0]);
+        const int numChannels = blobs[0].total();
+        ieLayer->_weights = wrapToInfEngineBlob(blobs[0], {numChannels}, InferenceEngine::Layout::C);
         if (hasBias)
-            ieLayer->_biases = wrapToInfEngineBlob(blobs[1]);
+            ieLayer->_biases = wrapToInfEngineBlob(blobs[1], {numChannels}, InferenceEngine::Layout::C);
 
         return Ptr<BackendNode>(new InfEngineBackendNode(ieLayer));
 #endif  // HAVE_INF_ENGINE
index fbbdcb1..7c3bb14 100644 (file)
@@ -90,27 +90,6 @@ public:
         }
     }
 
-    virtual Ptr<BackendNode> tryAttach(const Ptr<BackendNode>& node) CV_OVERRIDE
-    {
-        switch (node->backendId)
-        {
-            case DNN_BACKEND_INFERENCE_ENGINE:
-            {
-#ifdef HAVE_INF_ENGINE
-                auto base = node.dynamicCast<InfEngineBackendNode>();
-                auto conv = std::dynamic_pointer_cast<InferenceEngine::ConvolutionLayer>(base->layer);
-                if (conv)
-                {
-                    fuseConvWeights(conv, Mat(), blobs[0]);
-                    return base;
-                }
-#endif  // HAVE_INF_ENGINE
-                break;
-            }
-        }
-        return Ptr<BackendNode>();
-    }
-
     virtual Ptr<BackendNode> initInfEngine(const std::vector<Ptr<BackendWrapper> >&) CV_OVERRIDE
     {
 #ifdef HAVE_INF_ENGINE
index cad27ce..1514573 100644 (file)
@@ -59,22 +59,22 @@ static InferenceEngine::DataPtr wrapToInfEngineDataNode(const Mat& m, const std:
     std::vector<size_t> reversedShape(&m.size[0], &m.size[0] + m.dims);
     std::reverse(reversedShape.begin(), reversedShape.end());
     return InferenceEngine::DataPtr(
-      new InferenceEngine::Data(name, reversedShape, InferenceEngine::Precision::FP32,
-                                InferenceEngine::Layout::ANY)
+      new InferenceEngine::Data(name, reversedShape, InferenceEngine::Precision::FP32)
     );
 }
 
-InferenceEngine::TBlob<float>::Ptr wrapToInfEngineBlob(const Mat& m, const std::vector<size_t>& shape)
+InferenceEngine::TBlob<float>::Ptr wrapToInfEngineBlob(const Mat& m, const std::vector<size_t>& shape,
+                                                       InferenceEngine::Layout layout)
 {
     return InferenceEngine::make_shared_blob<float>(InferenceEngine::Precision::FP32,
-                                                    shape, (float*)m.data);
+                                                    layout, shape, (float*)m.data);
 }
 
-InferenceEngine::TBlob<float>::Ptr wrapToInfEngineBlob(const Mat& m)
+InferenceEngine::TBlob<float>::Ptr wrapToInfEngineBlob(const Mat& m, InferenceEngine::Layout layout)
 {
     std::vector<size_t> reversedShape(&m.size[0], &m.size[0] + m.dims);
     std::reverse(reversedShape.begin(), reversedShape.end());
-    return wrapToInfEngineBlob(m, reversedShape);
+    return wrapToInfEngineBlob(m, reversedShape, layout);
 }
 
 InferenceEngine::DataPtr infEngineDataNode(const Ptr<BackendWrapper>& ptr)
@@ -109,10 +109,14 @@ void InfEngineBackendWrapper::setHostDirty()
 
 InfEngineBackendNet::InfEngineBackendNet()
 {
+    targetDevice = InferenceEngine::TargetDevice::eCPU;
+    precision = InferenceEngine::Precision::FP32;
 }
 
 InfEngineBackendNet::InfEngineBackendNet(InferenceEngine::CNNNetwork& net)
 {
+    targetDevice = InferenceEngine::TargetDevice::eCPU;
+    precision = InferenceEngine::Precision::FP32;
     inputs = net.getInputsInfo();
     outputs = net.getOutputsInfo();
     layers.resize(net.layerCount());  // A hack to execute InfEngineBackendNet::layerCount correctly.
@@ -126,9 +130,14 @@ void InfEngineBackendNet::Release() noexcept
     outputs.clear();
 }
 
+void InfEngineBackendNet::setPrecision(InferenceEngine::Precision p) noexcept
+{
+    precision = p;
+}
+
 InferenceEngine::Precision InfEngineBackendNet::getPrecision() noexcept
 {
-    return InferenceEngine::Precision::FP32;
+    return precision;
 }
 
 // Assume that outputs of network is unconnected blobs.
@@ -161,9 +170,8 @@ InferenceEngine::InputInfo::Ptr InfEngineBackendNet::getInput(const std::string
     return it->second;
 }
 
-void InfEngineBackendNet::getName(char *pName, size_t len) noexcept
+void InfEngineBackendNet::getName(char*, size_t) noexcept
 {
-    CV_Error(Error::StsNotImplemented, "");
 }
 
 size_t InfEngineBackendNet::layerCount() noexcept
@@ -213,13 +221,15 @@ InfEngineBackendNet::getLayerByName(const char *layerName, InferenceEngine::CNNL
 
 void InfEngineBackendNet::setTargetDevice(InferenceEngine::TargetDevice device) noexcept
 {
-    if (device != InferenceEngine::TargetDevice::eCPU)
+    if (device != InferenceEngine::TargetDevice::eCPU &&
+        device != InferenceEngine::TargetDevice::eGPU)
         CV_Error(Error::StsNotImplemented, "");
+    targetDevice = device;
 }
 
 InferenceEngine::TargetDevice InfEngineBackendNet::getTargetDevice() noexcept
 {
-    return InferenceEngine::TargetDevice::eCPU;
+    return targetDevice;
 }
 
 InferenceEngine::StatusCode InfEngineBackendNet::setBatchSize(const size_t size) noexcept
@@ -234,7 +244,7 @@ size_t InfEngineBackendNet::getBatchSize() const noexcept
     return 0;
 }
 
-void InfEngineBackendNet::init()
+void InfEngineBackendNet::init(int targetId)
 {
     if (inputs.empty())
     {
@@ -307,6 +317,15 @@ void InfEngineBackendNet::init()
         outBlobs[it.first] = allBlobs[it.first];
     }
 
+    switch (targetId)
+    {
+    case DNN_TARGET_CPU: setTargetDevice(InferenceEngine::TargetDevice::eCPU); break;
+    case DNN_TARGET_OPENCL_FP16: setPrecision(InferenceEngine::Precision::FP16);  // Fallback to the next.
+    case DNN_TARGET_OPENCL: setTargetDevice(InferenceEngine::TargetDevice::eGPU); break;
+    default:
+        CV_Error(Error::StsError, format("Unknown target identifier: %d", targetId));
+    }
+
     if (!isInitialized())
         initPlugin(*this);
 }
@@ -319,7 +338,7 @@ void InfEngineBackendNet::initPlugin(InferenceEngine::ICNNNetwork& net)
     InferenceEngine::ResponseDesc resp;
     const InferenceEngine::Version* v = InferenceEngine::GetInferenceEngineVersion();
 
-    plugin = InferenceEngine::PluginDispatcher({""}).getSuitablePlugin(InferenceEngine::TargetDevice::eCPU);
+    plugin = InferenceEngine::PluginDispatcher({""}).getSuitablePlugin(targetDevice);
     if (std::atoi(v->buildNumber) > 5855)
     {
 #ifdef _WIN32
@@ -360,7 +379,7 @@ void InfEngineBackendNet::forward()
         CV_Error(Error::StsAssert, resp.msg);
 }
 
-static inline Mat infEngineBlobToMat(const InferenceEngine::Blob::Ptr& blob)
+Mat infEngineBlobToMat(const InferenceEngine::Blob::Ptr& blob)
 {
     // NOTE: Inference Engine sizes are reversed.
     std::vector<size_t> dims = blob->dims();
@@ -369,56 +388,6 @@ static inline Mat infEngineBlobToMat(const InferenceEngine::Blob::Ptr& blob)
     return Mat(size, CV_32F, (void*)blob->buffer());
 }
 
-void fuseConvWeights(const std::shared_ptr<InferenceEngine::ConvolutionLayer>& conv,
-                     const Mat& w, const Mat& b)
-{
-    CV_Assert(!w.empty() || !b.empty());
-    if (!w.empty())
-    {
-        // Get convolution's weights. Clone the data because Inference Engine can host it
-        // and conv->_weights->allocate() below will deallocate it.
-        Mat originWeights = infEngineBlobToMat(conv->_weights).clone();
-
-        // Create new weights blob.
-        conv->_weights = InferenceEngine::make_shared_blob<float>(
-                            InferenceEngine::Precision::FP32, conv->_weights->dims());
-        conv->_weights->allocate();
-
-        // Convolution weights have OIHW data layout.
-        // (conv(I) + b1 ) * w + b2
-        // w*conv(I) + b1 * w + b2
-        Mat fusedWeights = infEngineBlobToMat(conv->_weights);
-
-        const int numChannels = fusedWeights.size[0];
-        // Mat weights = blobs[0].reshape(1, 1);
-        // Mat bias = hasBias ? blobs[1].reshape(1, 1) : Mat();
-        CV_Assert(numChannels == w.total());
-        CV_Assert(b.empty() || numChannels == b.total());
-        for (int i = 0; i < numChannels; ++i)
-        {
-            cv::multiply(slice(originWeights, i), w.at<float>(i), slice(fusedWeights, i));
-        }
-    }
-    if (conv->_biases)
-    {
-        // The same for biases.
-        Mat originBiases = infEngineBlobToMat(conv->_biases).clone();
-
-        conv->_biases = InferenceEngine::make_shared_blob<float>(
-                            InferenceEngine::Precision::FP32, conv->_biases->dims());
-        conv->_biases->allocate();
-        Mat fusedBiases = infEngineBlobToMat(conv->_biases);
-        originBiases.copyTo(fusedBiases);
-
-        if (!w.empty())
-            cv::multiply(w.reshape(1, fusedBiases.dims, &fusedBiases.size[0]), fusedBiases, fusedBiases);
-        if (!b.empty())
-            cv::add(fusedBiases, b.reshape(1, fusedBiases.dims, &fusedBiases.size[0]), fusedBiases);
-    }
-    else
-        conv->_biases = wrapToInfEngineBlob(b);
-}
-
 InfEngineBackendLayer::InfEngineBackendLayer(const InferenceEngine::DataPtr& output_)
 {
     output = output_;
@@ -454,6 +423,16 @@ void InfEngineBackendLayer::forward(InputArrayOfArrays inputs, OutputArrayOfArra
     CV_Error(Error::StsInternal, "Choose Inference Engine as a preferable backend.");
 }
 
+InferenceEngine::TBlob<int16_t>::Ptr convertFp16(const InferenceEngine::Blob::Ptr& blob)
+{
+    auto halfs = InferenceEngine::make_shared_blob<int16_t>(InferenceEngine::Precision::FP16, blob->layout(), blob->dims());
+    halfs->allocate();
+    Mat floatsData(1, blob->size(), CV_32F, blob->buffer());
+    Mat halfsData(1, blob->size(), CV_16SC1, halfs->buffer());
+    convertFp16(floatsData, halfsData);
+    return halfs;
+}
+
 #endif  // HAVE_INF_ENGINE
 
 bool haveInfEngine()
index 4384635..67dadd3 100644 (file)
@@ -32,6 +32,8 @@ public:
 
     virtual void Release() noexcept CV_OVERRIDE;
 
+    void setPrecision(InferenceEngine::Precision p) noexcept;
+
     virtual InferenceEngine::Precision getPrecision() noexcept CV_OVERRIDE;
 
     virtual void getOutputsInfo(InferenceEngine::OutputsDataMap &out) noexcept /*CV_OVERRIDE*/;
@@ -68,7 +70,7 @@ public:
 
     virtual size_t getBatchSize() const noexcept CV_OVERRIDE;
 
-    void init();
+    void init(int targetId);
 
     void addBlobs(const std::vector<Ptr<BackendWrapper> >& wrappers);
 
@@ -83,6 +85,8 @@ private:
     InferenceEngine::BlobMap inpBlobs;
     InferenceEngine::BlobMap outBlobs;
     InferenceEngine::BlobMap allBlobs;
+    InferenceEngine::TargetDevice targetDevice;
+    InferenceEngine::Precision precision;
     InferenceEngine::InferenceEnginePluginPtr plugin;
 
     void initPlugin(InferenceEngine::ICNNNetwork& net);
@@ -116,15 +120,17 @@ public:
     InferenceEngine::TBlob<float>::Ptr blob;
 };
 
-InferenceEngine::TBlob<float>::Ptr wrapToInfEngineBlob(const Mat& m);
+InferenceEngine::TBlob<float>::Ptr wrapToInfEngineBlob(const Mat& m, InferenceEngine::Layout layout = InferenceEngine::Layout::ANY);
 
-InferenceEngine::TBlob<float>::Ptr wrapToInfEngineBlob(const Mat& m, const std::vector<size_t>& shape);
+InferenceEngine::TBlob<float>::Ptr wrapToInfEngineBlob(const Mat& m, const std::vector<size_t>& shape, InferenceEngine::Layout layout);
 
 InferenceEngine::DataPtr infEngineDataNode(const Ptr<BackendWrapper>& ptr);
 
-// Fuses convolution weights and biases with channel-wise scales and shifts.
-void fuseConvWeights(const std::shared_ptr<InferenceEngine::ConvolutionLayer>& conv,
-                     const Mat& w, const Mat& b = Mat());
+Mat infEngineBlobToMat(const InferenceEngine::Blob::Ptr& blob);
+
+// Convert Inference Engine blob with FP32 precision to FP16 precision.
+// Allocates memory for a new blob.
+InferenceEngine::TBlob<int16_t>::Ptr convertFp16(const InferenceEngine::Blob::Ptr& blob);
 
 // This is a fake class to run networks from Model Optimizer. Objects of that
 // class simulate responses of layers are imported by OpenCV and supported by
@@ -151,7 +157,6 @@ private:
     InferenceEngine::DataPtr output;
 };
 
-
 #endif  // HAVE_INF_ENGINE
 
 bool haveInfEngine();
index db657ee..ea79119 100644 (file)
@@ -100,6 +100,8 @@ public:
 
 TEST_P(DNNTestNetwork, AlexNet)
 {
+    if (backend == DNN_BACKEND_INFERENCE_ENGINE && target != DNN_TARGET_CPU)
+        throw SkipTestException("");
     processNet("dnn/bvlc_alexnet.caffemodel", "dnn/bvlc_alexnet.prototxt",
                Size(227, 227), "prob",
                target == DNN_TARGET_OPENCL ? "dnn/halide_scheduler_opencl_alexnet.yml" :
@@ -108,6 +110,8 @@ TEST_P(DNNTestNetwork, AlexNet)
 
 TEST_P(DNNTestNetwork, ResNet_50)
 {
+    if (backend == DNN_BACKEND_INFERENCE_ENGINE && target == DNN_TARGET_OPENCL_FP16)
+        throw SkipTestException("");
     processNet("dnn/ResNet-50-model.caffemodel", "dnn/ResNet-50-deploy.prototxt",
                Size(224, 224), "prob",
                target == DNN_TARGET_OPENCL ? "dnn/halide_scheduler_opencl_resnet_50.yml" :
@@ -116,6 +120,8 @@ TEST_P(DNNTestNetwork, ResNet_50)
 
 TEST_P(DNNTestNetwork, SqueezeNet_v1_1)
 {
+    if (backend == DNN_BACKEND_INFERENCE_ENGINE && target == DNN_TARGET_OPENCL_FP16)
+        throw SkipTestException("");
     processNet("dnn/squeezenet_v1.1.caffemodel", "dnn/squeezenet_v1.1.prototxt",
                Size(227, 227), "prob",
                target == DNN_TARGET_OPENCL ? "dnn/halide_scheduler_opencl_squeezenet_v1_1.yml" :
@@ -124,6 +130,8 @@ TEST_P(DNNTestNetwork, SqueezeNet_v1_1)
 
 TEST_P(DNNTestNetwork, GoogLeNet)
 {
+    if (backend == DNN_BACKEND_INFERENCE_ENGINE && target == DNN_TARGET_OPENCL_FP16)
+        throw SkipTestException("");
     processNet("dnn/bvlc_googlenet.caffemodel", "dnn/bvlc_googlenet.prototxt",
                Size(224, 224), "prob");
 }
@@ -147,7 +155,9 @@ TEST_P(DNNTestNetwork, ENet)
 
 TEST_P(DNNTestNetwork, MobileNet_SSD_Caffe)
 {
-    if (backend == DNN_BACKEND_HALIDE) throw SkipTestException("");
+    if (backend == DNN_BACKEND_HALIDE ||
+        backend == DNN_BACKEND_INFERENCE_ENGINE && target != DNN_TARGET_CPU)
+        throw SkipTestException("");
     Mat sample = imread(findDataFile("dnn/street.png", false));
     Mat inp = blobFromImage(sample, 1.0f / 127.5, Size(300, 300), Scalar(127.5, 127.5, 127.5), false);
 
@@ -157,7 +167,9 @@ TEST_P(DNNTestNetwork, MobileNet_SSD_Caffe)
 
 TEST_P(DNNTestNetwork, MobileNet_SSD_TensorFlow)
 {
-    if (backend == DNN_BACKEND_HALIDE) throw SkipTestException("");
+    if (backend == DNN_BACKEND_HALIDE ||
+        backend == DNN_BACKEND_INFERENCE_ENGINE && target != DNN_TARGET_CPU)
+        throw SkipTestException("");
     Mat sample = imread(findDataFile("dnn/street.png", false));
     Mat inp = blobFromImage(sample, 1.0f / 127.5, Size(300, 300), Scalar(127.5, 127.5, 127.5), false);
     processNet("dnn/ssd_mobilenet_v1_coco.pb", "dnn/ssd_mobilenet_v1_coco.pbtxt",
@@ -177,35 +189,45 @@ TEST_P(DNNTestNetwork, SSD_VGG16)
 TEST_P(DNNTestNetwork, OpenPose_pose_coco)
 {
     if (backend == DNN_BACKEND_HALIDE) throw SkipTestException("");
+    double l1 = target == DNN_TARGET_OPENCL_FP16 ? 3e-5 : 1e-5;
+    double lInf = target == DNN_TARGET_OPENCL_FP16 ? 3e-3 : 1e-4;
     processNet("dnn/openpose_pose_coco.caffemodel", "dnn/openpose_pose_coco.prototxt",
-               Size(368, 368), "");
+               Size(368, 368), "", "", l1, lInf);
 }
 
 TEST_P(DNNTestNetwork, OpenPose_pose_mpi)
 {
     if (backend == DNN_BACKEND_HALIDE) throw SkipTestException("");
+    double l1 = target == DNN_TARGET_OPENCL_FP16 ? 4e-5 : 1e-5;
+    double lInf = target == DNN_TARGET_OPENCL_FP16 ? 7e-3 : 1e-4;
     processNet("dnn/openpose_pose_mpi.caffemodel", "dnn/openpose_pose_mpi.prototxt",
-               Size(368, 368), "");
+               Size(368, 368), "", "", l1, lInf);
 }
 
 TEST_P(DNNTestNetwork, OpenPose_pose_mpi_faster_4_stages)
 {
     if (backend == DNN_BACKEND_HALIDE) throw SkipTestException("");
+    double l1 = target == DNN_TARGET_OPENCL_FP16 ? 5e-5 : 1e-5;
+    double lInf = target == DNN_TARGET_OPENCL_FP16 ? 5e-3 : 1e-4;
     // The same .caffemodel but modified .prototxt
     // See https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/pose/poseParameters.cpp
     processNet("dnn/openpose_pose_mpi.caffemodel", "dnn/openpose_pose_mpi_faster_4_stages.prototxt",
-               Size(368, 368), "");
+               Size(368, 368), "", "", l1, lInf);
 }
 
 TEST_P(DNNTestNetwork, OpenFace)
 {
-    if (backend == DNN_BACKEND_HALIDE) throw SkipTestException("");
+    if (backend == DNN_BACKEND_HALIDE ||
+        backend == DNN_BACKEND_INFERENCE_ENGINE && target != DNN_TARGET_CPU)
+        throw SkipTestException("");
     processNet("dnn/openface_nn4.small2.v1.t7", "", Size(96, 96), "");
 }
 
 TEST_P(DNNTestNetwork, opencv_face_detector)
 {
-    if (backend == DNN_BACKEND_HALIDE) throw SkipTestException("");
+    if (backend == DNN_BACKEND_HALIDE ||
+        backend == DNN_BACKEND_INFERENCE_ENGINE && target != DNN_TARGET_CPU)
+        throw SkipTestException("");
     Mat img = imread(findDataFile("gpu/lbpcascade/er.png", false));
     Mat inp = blobFromImage(img, 1.0, Size(), Scalar(104.0, 177.0, 123.0), false, false);
     processNet("dnn/opencv_face_detector.caffemodel", "dnn/opencv_face_detector.prototxt",
@@ -214,13 +236,23 @@ TEST_P(DNNTestNetwork, opencv_face_detector)
 
 TEST_P(DNNTestNetwork, Inception_v2_SSD_TensorFlow)
 {
-    if (backend == DNN_BACKEND_HALIDE) throw SkipTestException("");
+    if (backend == DNN_BACKEND_HALIDE ||
+        backend == DNN_BACKEND_INFERENCE_ENGINE && target != DNN_TARGET_CPU)
+        throw SkipTestException("");
     Mat sample = imread(findDataFile("dnn/street.png", false));
     Mat inp = blobFromImage(sample, 1.0f / 127.5, Size(300, 300), Scalar(127.5, 127.5, 127.5), false);
     processNet("dnn/ssd_inception_v2_coco_2017_11_17.pb", "dnn/ssd_inception_v2_coco_2017_11_17.pbtxt",
                inp, "detection_out");
 }
 
+TEST_P(DNNTestNetwork, DenseNet_121)
+{
+    if (backend == DNN_BACKEND_HALIDE ||
+        backend == DNN_BACKEND_INFERENCE_ENGINE && target == DNN_TARGET_OPENCL_FP16)
+        throw SkipTestException("");
+    processNet("dnn/DenseNet_121.caffemodel", "dnn/DenseNet_121.prototxt", Size(224, 224), "", "caffe");
+}
+
 const tuple<DNNBackend, DNNTarget> testCases[] = {
 #ifdef HAVE_HALIDE
     tuple<DNNBackend, DNNTarget>(DNN_BACKEND_HALIDE, DNN_TARGET_CPU),
@@ -228,6 +260,8 @@ const tuple<DNNBackend, DNNTarget> testCases[] = {
 #endif
 #ifdef HAVE_INF_ENGINE
     tuple<DNNBackend, DNNTarget>(DNN_BACKEND_INFERENCE_ENGINE, DNN_TARGET_CPU),
+    tuple<DNNBackend, DNNTarget>(DNN_BACKEND_INFERENCE_ENGINE, DNN_TARGET_OPENCL),
+    tuple<DNNBackend, DNNTarget>(DNN_BACKEND_INFERENCE_ENGINE, DNN_TARGET_OPENCL_FP16),
 #endif
     tuple<DNNBackend, DNNTarget>(DNN_BACKEND_DEFAULT, DNN_TARGET_OPENCL)
 };
index b4bb97d..54c9ce6 100644 (file)
@@ -53,7 +53,7 @@ namespace opencv_test {
 using namespace cv::dnn;
 
 CV_ENUM(DNNBackend, DNN_BACKEND_DEFAULT, DNN_BACKEND_HALIDE, DNN_BACKEND_INFERENCE_ENGINE)
-CV_ENUM(DNNTarget, DNN_TARGET_CPU, DNN_TARGET_OPENCL)
+CV_ENUM(DNNTarget, DNN_TARGET_CPU, DNN_TARGET_OPENCL, DNN_TARGET_OPENCL_FP16)
 
 static testing::internal::ParamGenerator<DNNTarget> availableDnnTargets()
 {