#include <fstream>
#include <algorithm>
#include <string>
+#include <queue>
#include "tf_graph_simplifier.hpp"
#endif
}
}
-// If all inputs of specific layer have the same data layout we can say that
-// this layer's output has this data layout too. Returns DATA_LAYOUT_UNKNOWN otherwise.
-static int predictOutputDataLayout(const tensorflow::NodeDef& layer, const std::map<String, int>& data_layouts)
+static int getDataLayout(const tensorflow::NodeDef& layer)
{
if (hasLayerAttr(layer, "data_format"))
{
else
CV_Error(Error::StsParseError, "Unknown data_format value: " + format);
}
+ return DATA_LAYOUT_UNKNOWN;
+}
+
+static inline std::string getNodeName(const std::string& tensorName)
+{
+ return tensorName.substr(0, tensorName.rfind(':'));
+}
+
+// If all inputs of specific layer have the same data layout we can say that
+// this layer's output has this data layout too. Returns DATA_LAYOUT_UNKNOWN otherwise.
+static int predictOutputDataLayout(const tensorflow::GraphDef& net,
+ const tensorflow::NodeDef& layer,
+ const std::map<String, int>& data_layouts)
+{
+ int layout = getDataLayout(layer);
+ if (layout != DATA_LAYOUT_UNKNOWN)
+ return layout;
// Determine layout by layer's inputs
- int layout = DATA_LAYOUT_UNKNOWN;
std::map<String, int>::const_iterator it;
for (int i = 0, n = layer.input_size(); i < n; ++i)
{
- it = data_layouts.find(layer.input(i).substr(0, layer.input(i).rfind(':')));
+ it = data_layouts.find(getNodeName(layer.input(i)));
if (it != data_layouts.end())
{
- if (it->second == DATA_LAYOUT_UNKNOWN)
- return DATA_LAYOUT_UNKNOWN;
- else if (it->second != layout)
+ if (layout != DATA_LAYOUT_UNKNOWN)
{
- if (layout == DATA_LAYOUT_UNKNOWN)
- layout = it->second;
- else
+ if (it->second != layout && it->second != DATA_LAYOUT_UNKNOWN)
return DATA_LAYOUT_UNKNOWN;
}
+ else
+ layout = it->second;
}
}
- return layout;
+
+ if (layout != DATA_LAYOUT_UNKNOWN)
+ return layout;
+
+ // Determine layout by layer's consumers recursively.
+ it = data_layouts.find(layer.name());
+ CV_Assert(it != data_layouts.end());
+ return it->second;
}
void TFImporter::populateNet(Net dstNet)
int layersSize = net.node_size();
std::map<String, int> data_layouts;
+ // Pre-fill data layouts where they are set explicitly.
+ // Assuming that nodes are in topological order
+ for (int i = net.node_size() - 1; i >= 0; --i)
+ {
+ const tensorflow::NodeDef& layer = net.node(i);
+ std::string name = layer.name();
+
+ int layout = getDataLayout(layer);
+ std::map<String, int>::iterator it = data_layouts.find(name);
+ if (it != data_layouts.end())
+ {
+ if (layout != DATA_LAYOUT_UNKNOWN)
+ {
+ if (it->second == DATA_LAYOUT_UNKNOWN)
+ it->second = layout;
+ else if (it->second != layout)
+ {
+ it->second = DATA_LAYOUT_UNKNOWN;
+ layout = DATA_LAYOUT_UNKNOWN;
+ }
+ }
+ else
+ layout = it->second;
+ }
+ else
+ data_layouts[name] = layout;
+
+ // Specify input layers to have the same data layout.
+ for (int j = 0; j < layer.input_size(); ++j)
+ {
+ name = getNodeName(layer.input(j));
+ it = data_layouts.find(name);
+ if (it != data_layouts.end())
+ {
+ if (layout != DATA_LAYOUT_UNKNOWN)
+ {
+ if (it->second == DATA_LAYOUT_UNKNOWN)
+ it->second = layout;
+ else if (it->second != layout)
+ it->second = DATA_LAYOUT_UNKNOWN;
+ }
+ }
+ else
+ data_layouts[name] = layout;
+ }
+ }
// find all Const layers for params
std::map<String, int> value_id;
if(layers_to_ignore.find(name) != layers_to_ignore.end())
continue;
- data_layouts[name] = predictOutputDataLayout(layer, data_layouts);
+ int predictedLayout = predictOutputDataLayout(net, layer, data_layouts);
+ data_layouts[name] = predictedLayout;
if (type == "Conv2D" || type == "SpaceToBatchND" || type == "DepthwiseConv2dNative")
{
// one input only
connect(layer_id, dstNet, inpId, id, 0);
+ data_layouts[name] = DATA_LAYOUT_UNKNOWN;
}
else if (type == "Flatten" || type == "Squeeze")
{
{
int axisId = (type == "Concat" ? 0 : layer.input_size() - 1);
int axis = getConstBlob(layer, value_id, axisId).int_val().Get(0);
- layerParams.set("axis", 0 <= axis && axis < 4 ? toNCHW(axis) : axis);
+
+ if (data_layouts[name] == DATA_LAYOUT_NHWC)
+ axis = toNCHW(axis);
+ layerParams.set("axis", axis);
int id = dstNet.addLayer(name, "Concat", layerParams);
layer_id[name] = id;