From: 윤현식/On-Device Lab(SR)/Principal Engineer/삼성전자 Date: Tue, 21 May 2019 04:12:43 +0000 (+0900) Subject: [moco/tf] Converting functions for MaxPool2D (#3549) X-Git-Tag: nncc_backup~548 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=18bb262cd336f1399edd424ba7cb4eb587200f42;p=platform%2Fcore%2Fml%2Fnnfw.git [moco/tf] Converting functions for MaxPool2D (#3549) This commit adds converting funtion for MaxPool2D. Signed-off-by: Hyun Sik Yoon --- diff --git a/contrib/moco/lib/frontend/tf/src/Convert.cpp b/contrib/moco/lib/frontend/tf/src/Convert.cpp index dd73fa8..4b28d1d 100644 --- a/contrib/moco/lib/frontend/tf/src/Convert.cpp +++ b/contrib/moco/lib/frontend/tf/src/Convert.cpp @@ -57,6 +57,23 @@ const tensorflow::TensorProto &get_tensor_attr(const tensorflow::NodeDef &node, return attr.tensor(); } +const ::tensorflow::AttrValue_ListValue &get_list_attr(const tensorflow::NodeDef &node, + const std::string &attr_name) +{ + assert(has_attr(node, attr_name)); + const auto &attr = node.attr().at(attr_name); + assert(attr.value_case() == tensorflow::AttrValue::kList); + return attr.list(); +} + +const std::string &get_string_attr(const tensorflow::NodeDef &node, const std::string &attr_name) +{ + assert(has_attr(node, attr_name)); + const auto &attr = node.attr().at(attr_name); + assert(attr.value_case() == tensorflow::AttrValue::kS); + return attr.s(); +} + loco::DataType as_loco_datatype(const tensorflow::DataType dtype) { switch (dtype) @@ -79,5 +96,17 @@ loco::DataType as_loco_datatype(const tensorflow::DataType dtype) throw std::runtime_error{"Unsupported tensorflow dtype: " + tensorflow::DataType_Name(dtype)}; } +const DataLayout get_data_layout(const tensorflow::NodeDef &node, const std::string &attr_name) +{ + auto layout = get_string_attr(node, attr_name); + + if (layout == "NHWC") + return moco::tf::DataLayout::NHWC; + else if (layout == "NCHW") + return moco::tf::DataLayout::NCHW; + else + throw std::runtime_error("unknown data layout"); +} + } // namespace tf } // namespace moco diff --git a/contrib/moco/lib/frontend/tf/src/Convert.h b/contrib/moco/lib/frontend/tf/src/Convert.h index 19990f6..8c21ee2 100644 --- a/contrib/moco/lib/frontend/tf/src/Convert.h +++ b/contrib/moco/lib/frontend/tf/src/Convert.h @@ -37,9 +37,23 @@ const tensorflow::TensorShapeProto &get_shape_attr(const tensorflow::NodeDef &no const std::string &attr_name); const tensorflow::TensorProto &get_tensor_attr(const tensorflow::NodeDef &node, const std::string &attr_name); +const tensorflow::AttrValue_ListValue &get_list_attr(const tensorflow::NodeDef &node, + const std::string &attr_name); +const std::string &get_string_attr(const tensorflow::NodeDef &node, const std::string &attr_name); loco::DataType as_loco_datatype(const tensorflow::DataType dtype); +/** + * @brief Class to represent Tensorflow "data_format" attr. + */ +enum class DataLayout +{ + NHWC, + NCHW, +}; + +const DataLayout get_data_layout(const tensorflow::NodeDef &node, const std::string &attr_name); + } // namespace tf } // namespace moco