Fix slice layer from TensorFlow
authorDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Wed, 31 Jan 2018 16:12:37 +0000 (19:12 +0300)
committerDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Wed, 31 Jan 2018 16:12:37 +0000 (19:12 +0300)
modules/dnn/src/layers/slice_layer.cpp
modules/dnn/src/tensorflow/tf_importer.cpp
modules/dnn/test/test_tf_importer.cpp

index 18758b9..171e4f7 100644 (file)
@@ -91,7 +91,7 @@ public:
                 {
                     int size = sizeOrEnd;
                     CV_Assert(size == -1 || size > 0);  // -1 value means range [start, axis_size).
-                    sliceRanges[0][i].end = start > 0 ? start + size : -1;  // We'll finalize a negative value later.
+                    sliceRanges[0][i].end = size > 0 ? (start + size) : -1;  // We'll finalize a negative value later.
                 }
                 else
                 {
index 7e03b59..ccb028b 100644 (file)
@@ -1119,21 +1119,21 @@ void TFImporter::populateNet(Net dstNet)
             // input: "Slice/begin"
             // input: "Slice/size"
             CV_Assert(layer.input_size() == 3);
+            Mat begins = getTensorContent(getConstBlob(layer, value_id, 1));
+            Mat sizes = getTensorContent(getConstBlob(layer, value_id, 2));
+            CV_Assert(!begins.empty(), !sizes.empty(), begins.type() == CV_32SC1,
+                      sizes.type() == CV_32SC1);
 
-            const tensorflow::TensorProto begins = getConstBlob(layer, value_id, 1);
-            const tensorflow::TensorProto sizes = getConstBlob(layer, value_id, 2);
-            std::string beginsData = begins.tensor_content();
-            std::string sizesData = sizes.tensor_content();
-            CV_Assert(begins.dtype() == tensorflow::DT_INT32);
-            CV_Assert(sizes.dtype() == tensorflow::DT_INT32);
-            CV_Assert(!beginsData.empty());
-            CV_Assert(!sizesData.empty());
-            CV_Assert(beginsData.size() == sizesData.size());
-
-            layerParams.set("begin", DictValue::arrayInt((int*)beginsData.c_str(),
-                                                         beginsData.size() / 4));
-            layerParams.set("size", DictValue::arrayInt((int*)sizesData.c_str(),
-                                                        sizesData.size() / 4));
+            if (begins.total() == 4)
+            {
+                // Perhabs, we have an NHWC order. Swap it to NCHW.
+                std::swap(*begins.ptr<int32_t>(0, 2), *begins.ptr<int32_t>(0, 3));
+                std::swap(*begins.ptr<int32_t>(0, 1), *begins.ptr<int32_t>(0, 2));
+                std::swap(*sizes.ptr<int32_t>(0, 2), *sizes.ptr<int32_t>(0, 3));
+                std::swap(*sizes.ptr<int32_t>(0, 1), *sizes.ptr<int32_t>(0, 2));
+            }
+            layerParams.set("begin", DictValue::arrayInt((int*)begins.data, begins.total()));
+            layerParams.set("size", DictValue::arrayInt((int*)sizes.data, sizes.total()));
 
             int id = dstNet.addLayer(name, "Slice", layerParams);
             layer_id[name] = id;
index 0b4dc64..bfc7443 100644 (file)
@@ -301,6 +301,11 @@ TEST(Test_TensorFlow, resize_nearest_neighbor)
     runTensorFlowNet("resize_nearest_neighbor");
 }
 
+TEST(Test_TensorFlow, slice)
+{
+    runTensorFlowNet("slice_4d");
+}
+
 TEST(Test_TensorFlow, memory_read)
 {
     double l1 = 1e-5;