Merge pull request #9305 from dkurt:public_dnn_importer_is_deprecated
[platform/upstream/opencv.git] / modules / dnn / src / tensorflow / tf_importer.cpp
index 0797986..8e1f18e 100644 (file)
@@ -85,11 +85,38 @@ static Mat getTensorContent(const tensorflow::TensorProto &tensor)
     switch (tensor.dtype())
     {
         case tensorflow::DT_FLOAT:
-            return Mat(1, content.size() / sizeof(float), CV_32FC1, (void*)content.c_str()).clone();
+        {
+            if (!content.empty())
+                return Mat(1, content.size() / sizeof(float), CV_32FC1, (void*)content.c_str()).clone();
+            else
+            {
+                const RepeatedField<float>& field = tensor.float_val();
+                CV_Assert(!field.empty());
+                return Mat(1, field.size(), CV_32FC1, (void*)field.data()).clone();
+            }
+        }
         case tensorflow::DT_DOUBLE:
-            return Mat(1, content.size() / sizeof(double), CV_64FC1, (void*)content.c_str()).clone();
+        {
+            if (!content.empty())
+                return Mat(1, content.size() / sizeof(double), CV_64FC1, (void*)content.c_str()).clone();
+            else
+            {
+                const RepeatedField<double>& field = tensor.double_val();
+                CV_Assert(!field.empty());
+                return Mat(1, field.size(), CV_64FC1, (void*)field.data()).clone();
+            }
+        }
         case tensorflow::DT_INT32:
-            return Mat(1, content.size() / sizeof(int32_t), CV_32SC1, (void*)content.c_str()).clone();
+        {
+            if (!content.empty())
+                return Mat(1, content.size() / sizeof(int32_t), CV_32SC1, (void*)content.c_str()).clone();
+            else
+            {
+                const RepeatedField<int32_t>& field = tensor.int_val();
+                CV_Assert(!field.empty());
+                return Mat(1, field.size(), CV_32SC1, (void*)field.data()).clone();
+            }
+        }
         case tensorflow::DT_HALF:
         {
             Mat halfs;
@@ -573,7 +600,7 @@ void TFImporter::populateNet(Net dstNet)
         if(layers_to_ignore.find(li) != layers_to_ignore.end())
             continue;
 
-        if (type == "Conv2D" || type == "SpaceToBatchND")
+        if (type == "Conv2D" || type == "SpaceToBatchND" || type == "DepthwiseConv2dNative")
         {
             // The first node of dilated convolution subgraph.
             // Extract input node, dilation rate and paddings.
@@ -621,7 +648,28 @@ void TFImporter::populateNet(Net dstNet)
             }
 
             kernelFromTensor(getConstBlob(layer, value_id), layerParams.blobs[0]);
-            const int* kshape = layerParams.blobs[0].size.p;
+            int* kshape = layerParams.blobs[0].size.p;
+            if (type == "DepthwiseConv2dNative")
+            {
+                const int chMultiplier = kshape[0];
+                const int inCh = kshape[1];
+                const int height = kshape[2];
+                const int width = kshape[3];
+
+                Mat copy = layerParams.blobs[0].clone();
+                float* src = (float*)copy.data;
+                float* dst = (float*)layerParams.blobs[0].data;
+                for (int i = 0; i < chMultiplier; ++i)
+                    for (int j = 0; j < inCh; ++j)
+                        for (int s = 0; s < height * width; ++s)
+                            {
+                                int src_i = (i * inCh + j) * height * width + s;
+                                int dst_i = (j * chMultiplier + i) * height* width + s;
+                                dst[dst_i] = src[src_i];
+                            }
+                kshape[0] = inCh * chMultiplier;
+                kshape[1] = 1;
+            }
             layerParams.set("kernel_h", kshape[2]);
             layerParams.set("kernel_w", kshape[3]);
             layerParams.set("num_output", kshape[0]);
@@ -689,6 +737,10 @@ void TFImporter::populateNet(Net dstNet)
             layerParams.blobs.resize(1);
 
             StrIntVector next_layers = getNextLayers(net, name, "BiasAdd");
+            if (next_layers.empty())
+            {
+                next_layers = getNextLayers(net, name, "Add");
+            }
             if (next_layers.size() == 1) {
                 layerParams.set("bias_term", true);
                 layerParams.blobs.resize(2);
@@ -840,20 +892,20 @@ void TFImporter::populateNet(Net dstNet)
             {
                 // Multiplication by constant.
                 CV_Assert(layer.input_size() == 2);
+                Mat scaleMat = getTensorContent(getConstBlob(layer, value_id));
+                CV_Assert(scaleMat.type() == CV_32FC1);
 
-                float scale;
-                if (!getConstBlob(layer, value_id).float_val().empty())
-                    scale = getConstBlob(layer, value_id).float_val()[0];
-                else
+                int id;
+                if (scaleMat.total() == 1)  // is a scalar.
                 {
-                    Mat scaleMat;
-                    blobFromTensor(getConstBlob(layer, value_id), scaleMat);
-                    CV_Assert(scaleMat.total() == 1 && scaleMat.type() == CV_32FC1);
-                    scale = scaleMat.at<float>(0, 0);
+                    layerParams.set("scale", scaleMat.at<float>(0));
+                    id = dstNet.addLayer(name, "Power", layerParams);
+                }
+                else  // is a vector
+                {
+                    layerParams.blobs.resize(1, scaleMat);
+                    id = dstNet.addLayer(name, "Scale", layerParams);
                 }
-                layerParams.set("scale", scale);
-
-                int id = dstNet.addLayer(name, "Power", layerParams);
                 layer_id[name] = id;
 
                 Pin inp0 = parsePin(layer.input(0));
@@ -1006,12 +1058,13 @@ void TFImporter::populateNet(Net dstNet)
         }
         else if (type == "Abs" || type == "Tanh" || type == "Sigmoid" ||
                  type == "Relu" || type == "Elu" || type == "Softmax" ||
-                 type == "Identity")
+                 type == "Identity" || type == "Relu6")
         {
             std::string dnnType = type;
             if (type == "Abs") dnnType = "AbsVal";
             else if (type == "Tanh") dnnType = "TanH";
             else if (type == "Relu") dnnType = "ReLU";
+            else if (type == "Relu6") dnnType = "ReLU6";
             else if (type == "Elu") dnnType = "ELU";
 
             int id = dstNet.addLayer(name, dnnType, layerParams);