return it != data_layouts.end() ? it->second : DATA_LAYOUT_UNKNOWN;
}
+static
+bool hasAllOnes(const Mat &inputs, int startPos, int endPos)
+{
+ CV_CheckLE(inputs.dims, 2, "");
+ CV_CheckGE(startPos, 0, "");
+ CV_CheckLE(startPos, endPos, "");
+ CV_CheckLT((size_t)endPos, inputs.total(), "");
+
+ for (int i = startPos; i < endPos; i++)
+ {
+ if (inputs.at<int>(i) != 1 || inputs.at<int>(i)!= -1)
+ return false;
+ }
+ return true;
+}
+
void setStrides(LayerParams &layerParams, const tensorflow::NodeDef &layer)
{
if (hasLayerAttr(layer, "strides"))
std::map<String, Mat> sharedWeights;
std::map<String, int> layer_id;
+
+private:
+ void addPermuteLayer(const int* order, const std::string& permName, Pin& inpId);
};
TFImporter::TFImporter(Net& net, const char *model, const char *config)
CV_LOG_DEBUG(NULL, "DNN/TF: ===================== Import completed =====================");
}
+void TFImporter::addPermuteLayer(const int* order, const std::string& permName, Pin& inpId)
+{
+ LayerParams permLP;
+ permLP.set("order", DictValue::arrayInt<const int*>(order, 4));
+ CV_Assert(layer_id.find(permName) == layer_id.end());
+ int permId = dstNet.addLayer(permName, "Permute", permLP);
+ layer_id[permName] = permId;
+ connect(layer_id, dstNet, inpId, permId, 0);
+ inpId = Pin(permName);
+}
+
void TFImporter::parseNode(const tensorflow::NodeDef& layer_)
{
tensorflow::NodeDef layer = layer_;
if (value_id.find(layer.input(1)) != value_id.end())
{
Mat newShape = getTensorContent(getConstBlob(layer, value_id, 1));
- if (newShape.total() == 4)
+ int newShapeSize = newShape.total();
+ bool hasSwap = false;
+ if (newShapeSize == 4 && hasAllOnes(newShape, 0, 2))
{
// NHWC->NCHW
std::swap(*newShape.ptr<int32_t>(0, 2), *newShape.ptr<int32_t>(0, 3));
std::swap(*newShape.ptr<int32_t>(0, 1), *newShape.ptr<int32_t>(0, 2));
+ hasSwap = true;
}
if (inpLayout == DATA_LAYOUT_NHWC)
{
- if (newShape.total() != 4 || newShape.at<int>(1) == 1)
+ if (newShapeSize >= 2 || newShape.at<int>(1) == 1)
{
- LayerParams permLP;
int order[] = {0, 2, 3, 1}; // From OpenCV's NCHW to NHWC.
- permLP.set("order", DictValue::arrayInt<int*>(order, 4));
-
- std::string permName = name + "/nchw";
- CV_Assert(layer_id.find(permName) == layer_id.end());
- int permId = dstNet.addLayer(permName, "Permute", permLP);
- layer_id[permName] = permId;
- connect(layer_id, dstNet, inpId, permId, 0);
- inpId = Pin(permName);
- inpLayout = DATA_LAYOUT_NCHW;
+ addPermuteLayer(order, name + "/nhwc", inpId);
+ if (newShapeSize < 4)
+ {
+ inpLayout = DATA_LAYOUT_NCHW;
+ }
+ else
+ {
+ inpLayout = DATA_LAYOUT_NHWC;
+ }
}
}
- layerParams.set("dim", DictValue::arrayInt<int*>(newShape.ptr<int>(), newShape.total()));
+ layerParams.set("dim", DictValue::arrayInt<int*>(newShape.ptr<int>(), newShapeSize));
int id = dstNet.addLayer(name, "Reshape", layerParams);
layer_id[name] = id;
// one input only
connect(layer_id, dstNet, inpId, id, 0);
- data_layouts[name] = newShape.total() == 2 ? DATA_LAYOUT_PLANAR : inpLayout;
+ inpId = Pin(name);
+
+ if ((inpLayout == DATA_LAYOUT_NHWC || inpLayout == DATA_LAYOUT_UNKNOWN || inpLayout == DATA_LAYOUT_PLANAR) &&
+ newShapeSize == 4 && !hasSwap)
+ {
+ int order[] = {0, 3, 1, 2}; // Transform back to OpenCV's NCHW.
+ addPermuteLayer(order, name + "/nchw", inpId);
+ inpLayout = DATA_LAYOUT_NCHW;
+ }
+
+ data_layouts[name] = newShapeSize == 2 ? DATA_LAYOUT_PLANAR : inpLayout;
}
else
{