return attr.tensor();
}
+const ::tensorflow::AttrValue_ListValue &get_list_attr(const tensorflow::NodeDef &node,
+ const std::string &attr_name)
+{
+ assert(has_attr(node, attr_name));
+ const auto &attr = node.attr().at(attr_name);
+ assert(attr.value_case() == tensorflow::AttrValue::kList);
+ return attr.list();
+}
+
+const std::string &get_string_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
+{
+ assert(has_attr(node, attr_name));
+ const auto &attr = node.attr().at(attr_name);
+ assert(attr.value_case() == tensorflow::AttrValue::kS);
+ return attr.s();
+}
+
loco::DataType as_loco_datatype(const tensorflow::DataType dtype)
{
switch (dtype)
throw std::runtime_error{"Unsupported tensorflow dtype: " + tensorflow::DataType_Name(dtype)};
}
+const DataLayout get_data_layout(const tensorflow::NodeDef &node, const std::string &attr_name)
+{
+ auto layout = get_string_attr(node, attr_name);
+
+ if (layout == "NHWC")
+ return moco::tf::DataLayout::NHWC;
+ else if (layout == "NCHW")
+ return moco::tf::DataLayout::NCHW;
+ else
+ throw std::runtime_error("unknown data layout");
+}
+
} // namespace tf
} // namespace moco
const std::string &attr_name);
const tensorflow::TensorProto &get_tensor_attr(const tensorflow::NodeDef &node,
const std::string &attr_name);
+const tensorflow::AttrValue_ListValue &get_list_attr(const tensorflow::NodeDef &node,
+ const std::string &attr_name);
+const std::string &get_string_attr(const tensorflow::NodeDef &node, const std::string &attr_name);
loco::DataType as_loco_datatype(const tensorflow::DataType dtype);
+/**
+ * @brief Class to represent Tensorflow "data_format" attr.
+ */
+enum class DataLayout
+{
+ NHWC,
+ NCHW,
+};
+
+const DataLayout get_data_layout(const tensorflow::NodeDef &node, const std::string &attr_name);
+
} // namespace tf
} // namespace moco