[moco_tf] add get_float_attr (#3973)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Wed, 26 Jun 2019 00:48:16 +0000 (09:48 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Wed, 26 Jun 2019 00:48:16 +0000 (09:48 +0900)
This will add get_float_attr method for accessing attribute float value for tensorflow node

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
contrib/moco-tf/src/Convert.cpp
contrib/moco-tf/src/Convert.h

index 03e5119..8232f2d 100644 (file)
@@ -90,6 +90,14 @@ int64_t get_int_attr(const tensorflow::NodeDef &node, const std::string &attr_na
   return attr.i();
 }
 
+float get_float_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
+{
+  assert(has_attr(node, attr_name));
+  const auto &attr = node.attr().at(attr_name);
+  assert(attr.value_case() == tensorflow::AttrValue::kF);
+  return attr.f();
+}
+
 loco::DataType as_loco_datatype(const tensorflow::DataType dtype)
 {
   switch (dtype)
index 921fcff..83ba3c1 100644 (file)
@@ -42,6 +42,7 @@ const tensorflow::AttrValue_ListValue &get_list_attr(const tensorflow::NodeDef &
                                                      const std::string &attr_name);
 const std::string &get_string_attr(const tensorflow::NodeDef &node, const std::string &attr_name);
 int64_t get_int_attr(const tensorflow::NodeDef &node, const std::string &attr_name);
+float get_float_attr(const tensorflow::NodeDef &node, const std::string &attr_name);
 
 loco::DataType as_loco_datatype(const tensorflow::DataType dtype);