let StridedSlice layer support const input
authorzoom <zhongwl2018@mail.sustech.edu.cn>
Wed, 12 Oct 2022 03:47:31 +0000 (11:47 +0800)
committerzoom <zhongwl2018@mail.sustech.edu.cn>
Wed, 12 Oct 2022 03:50:44 +0000 (11:50 +0800)
modules/dnn/src/tensorflow/tf_importer.cpp

index 96e0af9..44e70ba 100644 (file)
@@ -1706,6 +1706,19 @@ void TFImporter::parseStridedSlice(tensorflow::GraphDef& net, const tensorflow::
     layerParams.set("begin", DictValue::arrayInt((int*)begins.data, begins.total()));
     layerParams.set("end", DictValue::arrayInt((int*)ends.data, ends.total()));
 
+    Pin inp = parsePin(layer.input(0));
+    if (value_id.find(inp.name) != value_id.end())
+    {
+        // The input is constant.
+        LayerParams lp;
+        lp.name = inp.name;
+        lp.type = "Const";
+        lp.blobs.push_back(getTensorContent(getConstBlob(layer, value_id, 0)));
+
+        int constInpId = dstNet.addLayer(lp.name, lp.type, lp);
+        layer_id[lp.name] = constInpId;
+    }
+
     int id = dstNet.addLayer(name, "Slice", layerParams);
     layer_id[name] = id;