Fix TF Split layer
authorLiubov Batanina <piccione-mail@yandex.ru>
Wed, 17 Jul 2019 12:50:50 +0000 (15:50 +0300)
committerLiubov Batanina <piccione-mail@yandex.ru>
Wed, 17 Jul 2019 12:50:50 +0000 (15:50 +0300)
modules/dnn/include/opencv2/dnn/all_layers.hpp
modules/dnn/src/layers/slice_layer.cpp
modules/dnn/src/tensorflow/tf_importer.cpp
modules/dnn/test/test_tf_importer.cpp

index 868a8f0..75cba09 100644 (file)
@@ -366,6 +366,7 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
          */
         std::vector<std::vector<Range> > sliceRanges;
         int axis;
+        int num_split;
 
         static Ptr<SliceLayer> create(const LayerParams &params);
     };
index 73d6a30..7640d46 100644 (file)
@@ -61,6 +61,7 @@ public:
     {
         setParamsFrom(params);
         axis = params.get<int>("axis", 1);
+        num_split = params.get<int>("num_split", 0);
         if (params.has("slice_point"))
         {
             CV_Assert(!params.has("begin") && !params.has("size") && !params.has("end"));
@@ -141,9 +142,10 @@ public:
         else  // Divide input blob on equal parts by axis.
         {
             CV_Assert(0 <= axis && axis < inpShape.size());
-            CV_Assert(requiredOutputs > 0 && inpShape[axis] % requiredOutputs == 0);
-            inpShape[axis] /= requiredOutputs;
-            outputs.resize(requiredOutputs, inpShape);
+            int splits = num_split ? num_split : requiredOutputs;
+            CV_Assert(splits > 0 && inpShape[axis] % splits == 0);
+            inpShape[axis] /= splits;
+            outputs.resize(splits, inpShape);
         }
         return false;
     }
index c38b250..e546d9e 100644 (file)
@@ -1410,6 +1410,9 @@ void TFImporter::populateNet(Net dstNet)
                 axis = toNCHW(axis);
             layerParams.set("axis", axis);
 
+            if (hasLayerAttr(layer, "num_split"))
+                layerParams.set("num_split", getLayerAttr(layer, "num_split").i());
+
             int id = dstNet.addLayer(name, "Slice", layerParams);
             layer_id[name] = id;
 
index 2dae678..0357b8e 100644 (file)
@@ -350,6 +350,11 @@ TEST_P(Test_TensorFlow_layers, l2_normalize_3d)
     runTensorFlowNet("l2_normalize_3d");
 }
 
+TEST_P(Test_TensorFlow_layers, Split)
+{
+    runTensorFlowNet("split");
+}
+
 class Test_TensorFlow_nets : public DNNTestLayer {};
 
 TEST_P(Test_TensorFlow_nets, MobileNet_SSD)