return layer.attr().at(name);
}
+static int getDataLayout(const tensorflow::NodeDef& layer)
+{
+ if (hasLayerAttr(layer, "data_format"))
+ {
+ std::string format = getLayerAttr(layer, "data_format").s();
+ if (format == "NHWC" || format == "channels_last")
+ return DATA_LAYOUT_NHWC;
+ else if (format == "NCHW" || format == "channels_first")
+ return DATA_LAYOUT_NCHW;
+ else
+ CV_Error(Error::StsParseError, "Unknown data_format value: " + format);
+ }
+ return DATA_LAYOUT_UNKNOWN;
+}
+
void setStrides(LayerParams &layerParams, const tensorflow::NodeDef &layer)
{
if (hasLayerAttr(layer, "strides"))
{
const tensorflow::AttrValue& val = getLayerAttr(layer, "strides");
+ int dimX, dimY, dimC;
+ int layout = getDataLayout(layer);
+ if (layout == DATA_LAYOUT_NCHW)
+ {
+ dimC = 1; dimY = 2; dimX = 3;
+ }
+ else
+ {
+ dimY = 1; dimX = 2; dimC = 3;
+ }
if (val.list().i_size() != 4 ||
- val.list().i(0) != 1 || val.list().i(3) != 1)
+ val.list().i(0) != 1 || val.list().i(dimC) != 1)
CV_Error(Error::StsError, "Unsupported strides");
- layerParams.set("stride_h", static_cast<int>(val.list().i(1)));
- layerParams.set("stride_w", static_cast<int>(val.list().i(2)));
+ layerParams.set("stride_h", static_cast<int>(val.list().i(dimY)));
+ layerParams.set("stride_w", static_cast<int>(val.list().i(dimX)));
}
}
if (hasLayerAttr(layer, "ksize"))
{
const tensorflow::AttrValue& val = getLayerAttr(layer, "ksize");
+ int dimX, dimY, dimC;
+ int layout = getDataLayout(layer);
+ if (layout == DATA_LAYOUT_NCHW)
+ {
+ dimC = 1; dimY = 2; dimX = 3;
+ }
+ else
+ {
+ dimY = 1; dimX = 2; dimC = 3;
+ }
if (val.list().i_size() != 4 ||
- val.list().i(0) != 1 || val.list().i(3) != 1)
+ val.list().i(0) != 1 || val.list().i(dimC) != 1)
CV_Error(Error::StsError, "Unsupported ksize");
- layerParams.set("kernel_h", static_cast<int>(val.list().i(1)));
- layerParams.set("kernel_w", static_cast<int>(val.list().i(2)));
+ layerParams.set("kernel_h", static_cast<int>(val.list().i(dimY)));
+ layerParams.set("kernel_w", static_cast<int>(val.list().i(dimX)));
}
else
{
}
}
-static int getDataLayout(const tensorflow::NodeDef& layer)
-{
- if (hasLayerAttr(layer, "data_format"))
- {
- std::string format = getLayerAttr(layer, "data_format").s();
- if (format == "NHWC" || format == "channels_last")
- return DATA_LAYOUT_NHWC;
- else if (format == "NCHW" || format == "channels_first")
- return DATA_LAYOUT_NCHW;
- else
- CV_Error(Error::StsParseError, "Unknown data_format value: " + format);
- }
- return DATA_LAYOUT_UNKNOWN;
-}
-
static inline std::string getNodeName(const std::string& tensorName)
{
return tensorName.substr(0, tensorName.rfind(':'));