static int toNCHW[] = {0, 2, 3, 1};
+// This values are used to indicate layer output's data layout where it's possible.
+enum DataLayout
+{
+ DATA_LAYOUT_NHWC,
+ DATA_LAYOUT_NCHW,
+ DATA_LAYOUT_UNKNOWN
+};
+
typedef std::vector<std::pair<String, int> > StrIntVector;
struct Pin
}
}
+// If all inputs of specific layer have the same data layout we can say that
+// this layer's output has this data layout too. Returns DATA_LAYOUT_UNKNOWN otherwise.
+static int predictOutputDataLayout(const tensorflow::NodeDef& layer, const std::map<String, int>& data_layouts)
+{
+ int layout = DATA_LAYOUT_UNKNOWN;
+ std::map<String, int>::const_iterator it;
+ for (int i = 0, n = layer.input_size(); i < n; ++i)
+ {
+ it = data_layouts.find(layer.input(i));
+ if (it != data_layouts.end())
+ {
+ if (it->second == DATA_LAYOUT_UNKNOWN)
+ return DATA_LAYOUT_UNKNOWN;
+ else if (it->second != layout)
+ {
+ if (layout == DATA_LAYOUT_UNKNOWN)
+ layout = it->second;
+ else
+ return DATA_LAYOUT_UNKNOWN;
+ }
+ }
+ }
+ return layout;
+}
+
void TFImporter::populateNet(Net dstNet)
{
RemoveIdentityOps(netBin);
int layersSize = net.node_size();
+ std::map<String, int> data_layouts;
+
// find all Const layers for params
std::map<String, int> value_id;
addConstNodes(netBin, value_id, layers_to_ignore);
if(layers_to_ignore.find(name) != layers_to_ignore.end())
continue;
+ data_layouts[name] = predictOutputDataLayout(layer, data_layouts);
+
if (type == "Conv2D" || type == "SpaceToBatchND" || type == "DepthwiseConv2dNative")
{
// The first node of dilated convolution subgraph.
// one input only
connect(layer_id, dstNet, parsePin(input), id, 0);
+
+ if (hasLayerAttr(layer, "data_format"))
+ {
+ std::string format = getLayerAttr(layer, "data_format").s();
+ if (format == "NHWC")
+ data_layouts[name] = DATA_LAYOUT_NHWC;
+ else if (format == "NCHW")
+ data_layouts[name] = DATA_LAYOUT_NCHW;
+ else
+ CV_Error(Error::StsParseError, "Unknown data_format value: " + format);
+ }
+ else
+ data_layouts[name] = DATA_LAYOUT_NHWC;
}
else if (type == "BiasAdd" || type == "Add")
{
// one input only
int input_blob_index = kernel_blob_index == 0 ? 1 : 0;
connect(layer_id, dstNet, parsePin(layer.input(input_blob_index)), id, 0);
+ data_layouts[name] = DATA_LAYOUT_UNKNOWN;
}
else if (type == "Reshape")
{
- layerParams.set("dim", parseDims(getConstBlob(layer, value_id, 1)));
+ Pin inpId = parsePin(layer.input(0));
+ DictValue newShape = parseDims(getConstBlob(layer, value_id, 1));
+
+ if (newShape.size() != 4 && data_layouts[layer.input(0)] == DATA_LAYOUT_NHWC)
+ {
+ LayerParams permLP;
+ int order[] = {0, 2, 3, 1}; // From OpenCV's NCHW to NHWC.
+ permLP.set("order", DictValue::arrayInt<int*>(order, 4));
+
+ std::string permName = name + "/nchw";
+ CV_Assert(layer_id.find(permName) == layer_id.end());
+ int permId = dstNet.addLayer(permName, "Permute", permLP);
+ layer_id[permName] = permId;
+ connect(layer_id, dstNet, inpId, permId, 0);
+ inpId = Pin(permName);
+ }
+ layerParams.set("dim", newShape);
int id = dstNet.addLayer(name, "Reshape", layerParams);
layer_id[name] = id;
// one input only
- connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
+ connect(layer_id, dstNet, inpId, id, 0);
+ data_layouts[name] = DATA_LAYOUT_UNKNOWN;
}
else if (type == "Flatten")
{
+ Pin inpId = parsePin(layer.input(0));
+ if (data_layouts[layer.input(0)] == DATA_LAYOUT_NHWC)
+ {
+ LayerParams permLP;
+ int order[] = {0, 2, 3, 1}; // From OpenCV's NCHW to NHWC.
+ permLP.set("order", DictValue::arrayInt<int*>(order, 4));
+
+ std::string permName = name + "/nchw";
+ CV_Assert(layer_id.find(permName) == layer_id.end());
+ int permId = dstNet.addLayer(permName, "Permute", permLP);
+ layer_id[permName] = permId;
+ connect(layer_id, dstNet, inpId, permId, 0);
+ inpId = Pin(permName);
+ }
int id = dstNet.addLayer(name, "Flatten", layerParams);
layer_id[name] = id;
- connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
+ connect(layer_id, dstNet, inpId, id, 0);
+ data_layouts[name] = DATA_LAYOUT_UNKNOWN;
}
else if (type == "Transpose")
{
int* permData = (int*)perm.data;
if (perm.total() == 4)
{
- for (int i = 0; i < 4; ++i)
- permData[i] = toNCHW[permData[i]];
+ // Only NHWC <-> NCHW permutations are allowed. OpenCV is always
+ // keep NCHW layout this way.
+ if (data_layouts[layer.input(0)] == DATA_LAYOUT_NHWC)
+ {
+ if (permData[0] == 0 && permData[1] == 3 && permData[2] == 1 && permData[3] == 2)
+ {
+ // in TensorFlow: NHWC->NCHW
+ // in OpenCV: NCHW->NCHW
+ data_layouts[name] = DATA_LAYOUT_NCHW;
+ }
+ else if (permData[0] == 0 && permData[1] == 1 && permData[2] == 2 && permData[3] == 3)
+ {
+ // in TensorFlow: NHWC->NHWC
+ // in OpenCV: NCHW->NCHW
+ data_layouts[name] = DATA_LAYOUT_NHWC;
+ }
+ else
+ CV_Assert(Error::StsParseError, "Only NHWC <-> NCHW permutations are allowed.");
+ }
+ else if (data_layouts[layer.input(0)] == DATA_LAYOUT_NCHW)
+ {
+ if (permData[0] == 0 && permData[1] == 2 && permData[2] == 3 && permData[3] == 1)
+ {
+ // in TensorFlow: NCHW->NHWC
+ // in OpenCV: NCHW->NCHW
+ data_layouts[name] = DATA_LAYOUT_NHWC;
+ }
+ else if (permData[0] == 0 && permData[1] == 1 && permData[2] == 2 && permData[3] == 3)
+ {
+ // in TensorFlow: NCHW->NCHW
+ // in OpenCV: NCHW->NCHW
+ data_layouts[name] = DATA_LAYOUT_NCHW;
+ }
+ else
+ CV_Assert(Error::StsParseError, "Only NHWC <-> NCHW permutations are allowed.");
+ }
+ int id = dstNet.addLayer(name, "Identity", layerParams);
+ layer_id[name] = id;
+ connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
}
- layerParams.set("order", DictValue::arrayInt<int*>(permData, perm.total()));
+ else
+ {
+ layerParams.set("order", DictValue::arrayInt<int*>(permData, perm.total()));
- int id = dstNet.addLayer(name, "Permute", layerParams);
- layer_id[name] = id;
+ int id = dstNet.addLayer(name, "Permute", layerParams);
+ layer_id[name] = id;
- // one input only
- connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
+ // one input only
+ connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
+ data_layouts[name] = DATA_LAYOUT_UNKNOWN;
+ }
}
else if (type == "Const")
{
// one input only
connect(layer_id, dstNet, parsePin(layer.input(1)), id, 0);
+ data_layouts[name] = DATA_LAYOUT_UNKNOWN;
}
else if (type == "ResizeNearestNeighbor")
{
layer_id[name] = id;
connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
connect(layer_id, dstNet, parsePin(layer.input(1)), id, 1);
+ data_layouts[name] = DATA_LAYOUT_UNKNOWN;
}
else if (type == "DetectionOutput")
{
layer_id[name] = id;
for (int i = 0; i < 3; ++i)
connect(layer_id, dstNet, parsePin(layer.input(i)), id, i);
+ data_layouts[name] = DATA_LAYOUT_UNKNOWN;
}
else if (type == "Abs" || type == "Tanh" || type == "Sigmoid" ||
type == "Relu" || type == "Elu" || type == "Softmax" ||