From 54604aeea33dec2a03827bb8e627a260eaeac2bf Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Wed, 26 Jun 2019 09:48:16 +0900 Subject: [PATCH] [moco_tf] add get_float_attr (#3973) This will add get_float_attr method for accessing attribute float value for tensorflow node Signed-off-by: SaeHie Park --- contrib/moco-tf/src/Convert.cpp | 8 ++++++++ contrib/moco-tf/src/Convert.h | 1 + 2 files changed, 9 insertions(+) diff --git a/contrib/moco-tf/src/Convert.cpp b/contrib/moco-tf/src/Convert.cpp index 03e5119..8232f2d 100644 --- a/contrib/moco-tf/src/Convert.cpp +++ b/contrib/moco-tf/src/Convert.cpp @@ -90,6 +90,14 @@ int64_t get_int_attr(const tensorflow::NodeDef &node, const std::string &attr_na return attr.i(); } +float get_float_attr(const tensorflow::NodeDef &node, const std::string &attr_name) +{ + assert(has_attr(node, attr_name)); + const auto &attr = node.attr().at(attr_name); + assert(attr.value_case() == tensorflow::AttrValue::kF); + return attr.f(); +} + loco::DataType as_loco_datatype(const tensorflow::DataType dtype) { switch (dtype) diff --git a/contrib/moco-tf/src/Convert.h b/contrib/moco-tf/src/Convert.h index 921fcff..83ba3c1 100644 --- a/contrib/moco-tf/src/Convert.h +++ b/contrib/moco-tf/src/Convert.h @@ -42,6 +42,7 @@ const tensorflow::AttrValue_ListValue &get_list_attr(const tensorflow::NodeDef & const std::string &attr_name); const std::string &get_string_attr(const tensorflow::NodeDef &node, const std::string &attr_name); int64_t get_int_attr(const tensorflow::NodeDef &node, const std::string &attr_name); +float get_float_attr(const tensorflow::NodeDef &node, const std::string &attr_name); loco::DataType as_loco_datatype(const tensorflow::DataType dtype); -- 2.7.4