From 93ba90266c8c7c24b103464f957dab80ca916c11 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Mon, 3 Jun 2019 07:22:55 +0900 Subject: [PATCH] [moco/tf] Add get_int_attr (#3643) This will add get_int_attr() to read int type attribute value from TensorFlow NodeDef Signed-off-by: SaeHie Park --- contrib/moco/lib/frontend/tf/src/Convert.cpp | 8 ++++++++ contrib/moco/lib/frontend/tf/src/Convert.h | 1 + 2 files changed, 9 insertions(+) diff --git a/contrib/moco/lib/frontend/tf/src/Convert.cpp b/contrib/moco/lib/frontend/tf/src/Convert.cpp index 4b28d1d..cc3c89d 100644 --- a/contrib/moco/lib/frontend/tf/src/Convert.cpp +++ b/contrib/moco/lib/frontend/tf/src/Convert.cpp @@ -74,6 +74,14 @@ const std::string &get_string_attr(const tensorflow::NodeDef &node, const std::s return attr.s(); } +int64_t get_int_attr(const tensorflow::NodeDef &node, const std::string &attr_name) +{ + assert(has_attr(node, attr_name)); + const auto &attr = node.attr().at(attr_name); + assert(attr.value_case() == tensorflow::AttrValue::kI); + return attr.i(); +} + loco::DataType as_loco_datatype(const tensorflow::DataType dtype) { switch (dtype) diff --git a/contrib/moco/lib/frontend/tf/src/Convert.h b/contrib/moco/lib/frontend/tf/src/Convert.h index 8c21ee2..a9d9606 100644 --- a/contrib/moco/lib/frontend/tf/src/Convert.h +++ b/contrib/moco/lib/frontend/tf/src/Convert.h @@ -40,6 +40,7 @@ const tensorflow::TensorProto &get_tensor_attr(const tensorflow::NodeDef &node, const tensorflow::AttrValue_ListValue &get_list_attr(const tensorflow::NodeDef &node, const std::string &attr_name); const std::string &get_string_attr(const tensorflow::NodeDef &node, const std::string &attr_name); +int64_t get_int_attr(const tensorflow::NodeDef &node, const std::string &attr_name); loco::DataType as_loco_datatype(const tensorflow::DataType dtype); -- 2.7.4