[moco/tf] Converting functions for MaxPool2D (#3549)
author윤현식/On-Device Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com>
Tue, 21 May 2019 04:12:43 +0000 (13:12 +0900)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Tue, 21 May 2019 04:12:43 +0000 (13:12 +0900)
This commit adds converting funtion for MaxPool2D.

Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
contrib/moco/lib/frontend/tf/src/Convert.cpp
contrib/moco/lib/frontend/tf/src/Convert.h

index dd73fa8..4b28d1d 100644 (file)
@@ -57,6 +57,23 @@ const tensorflow::TensorProto &get_tensor_attr(const tensorflow::NodeDef &node,
   return attr.tensor();
 }
 
+const ::tensorflow::AttrValue_ListValue &get_list_attr(const tensorflow::NodeDef &node,
+                                                       const std::string &attr_name)
+{
+  assert(has_attr(node, attr_name));
+  const auto &attr = node.attr().at(attr_name);
+  assert(attr.value_case() == tensorflow::AttrValue::kList);
+  return attr.list();
+}
+
+const std::string &get_string_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
+{
+  assert(has_attr(node, attr_name));
+  const auto &attr = node.attr().at(attr_name);
+  assert(attr.value_case() == tensorflow::AttrValue::kS);
+  return attr.s();
+}
+
 loco::DataType as_loco_datatype(const tensorflow::DataType dtype)
 {
   switch (dtype)
@@ -79,5 +96,17 @@ loco::DataType as_loco_datatype(const tensorflow::DataType dtype)
   throw std::runtime_error{"Unsupported tensorflow dtype: " + tensorflow::DataType_Name(dtype)};
 }
 
+const DataLayout get_data_layout(const tensorflow::NodeDef &node, const std::string &attr_name)
+{
+  auto layout = get_string_attr(node, attr_name);
+
+  if (layout == "NHWC")
+    return moco::tf::DataLayout::NHWC;
+  else if (layout == "NCHW")
+    return moco::tf::DataLayout::NCHW;
+  else
+    throw std::runtime_error("unknown data layout");
+}
+
 } // namespace tf
 } // namespace moco
index 19990f6..8c21ee2 100644 (file)
@@ -37,9 +37,23 @@ const tensorflow::TensorShapeProto &get_shape_attr(const tensorflow::NodeDef &no
                                                    const std::string &attr_name);
 const tensorflow::TensorProto &get_tensor_attr(const tensorflow::NodeDef &node,
                                                const std::string &attr_name);
+const tensorflow::AttrValue_ListValue &get_list_attr(const tensorflow::NodeDef &node,
+                                                     const std::string &attr_name);
+const std::string &get_string_attr(const tensorflow::NodeDef &node, const std::string &attr_name);
 
 loco::DataType as_loco_datatype(const tensorflow::DataType dtype);
 
+/**
+ * @brief Class to represent Tensorflow "data_format" attr.
+ */
+enum class DataLayout
+{
+  NHWC,
+  NCHW,
+};
+
+const DataLayout get_data_layout(const tensorflow::NodeDef &node, const std::string &attr_name);
+
 } // namespace tf
 } // namespace moco