Fix TensorFlow split layer
authorDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Mon, 2 Oct 2017 19:44:42 +0000 (22:44 +0300)
committerDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Mon, 2 Oct 2017 19:44:42 +0000 (22:44 +0300)
modules/dnn/src/layers/slice_layer.cpp
modules/dnn/src/tensorflow/tf_importer.cpp
modules/dnn/test/test_tf_importer.cpp

index b824a06..c7db0f4 100644 (file)
@@ -116,7 +116,7 @@ public:
         }
         else  // Divide input blob on equal parts by axis.
         {
-            CV_Assert(0 < axis && axis < inpShape.size());
+            CV_Assert(0 <= axis && axis < inpShape.size());
             CV_Assert(requiredOutputs > 0 && inpShape[axis] % requiredOutputs == 0);
             inpShape[axis] /= requiredOutputs;
             outputs.resize(requiredOutputs, inpShape);
index 55d9fd2..3065a1f 100644 (file)
@@ -866,8 +866,6 @@ void TFImporter::populateNet(Net dstNet)
             CV_Assert(layer.input_size() == 2);
             // num_split
             // 1st blob is dims tensor
-            layerParams.set("slice_point", DictValue::arrayReal((double*)0, 0));
-
             int axis = getConstBlob(layer, value_id, 0).int_val().Get(0);
             layerParams.set("axis", toNCHW[axis]);
 
index f382507..3f89dd4 100644 (file)
@@ -170,4 +170,9 @@ TEST(Test_TensorFlow, lstm)
     runTensorFlowNet("lstm");
 }
 
+TEST(Test_TensorFlow, split)
+{
+    runTensorFlowNet("split_equals");
+}
+
 }