[moco/tf] Add more methods in Convert (#4070)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Wed, 3 Jul 2019 23:52:39 +0000 (08:52 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Wed, 3 Jul 2019 23:52:39 +0000 (08:52 +0900)
* [moco/tf] Add more methods in Convert

This will add more methods in Convert for import

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
* use ASSERT_EQ

contrib/moco-tf/src/Convert.cpp
contrib/moco-tf/src/Convert.h
contrib/moco-tf/src/Convert.test.cpp

index 6f90bb4..1cdf749 100644 (file)
@@ -106,6 +106,18 @@ bool get_bool_attr(const tensorflow::NodeDef &node, const std::string &attr_name
   return attr.b();
 }
 
+std::vector<int64_t> as_int64_list(const tensorflow::AttrValue_ListValue &lv)
+{
+  std::vector<int64_t> vi;
+  int isize = lv.i_size();
+
+  vi.resize(isize);
+  for (int i = 0; i < isize; ++i)
+    vi[i] = lv.i(i);
+
+  return vi;
+}
+
 loco::DataType as_loco_datatype(const tensorflow::DataType dtype)
 {
   switch (dtype)
@@ -128,10 +140,21 @@ loco::DataType as_loco_datatype(const tensorflow::DataType dtype)
   throw std::runtime_error{"Unsupported tensorflow dtype: " + tensorflow::DataType_Name(dtype)};
 }
 
+const DataLayout as_DataLayout(const std::string &layout)
+{
+  if (layout == "NHWC")
+    return moco::tf::DataLayout::NHWC;
+  else if (layout == "NCHW")
+    return moco::tf::DataLayout::NCHW;
+  else
+    throw std::runtime_error("unknown data layout");
+}
+
 const DataLayout get_data_layout(const tensorflow::NodeDef &node, const std::string &attr_name)
 {
   auto layout = get_string_attr(node, attr_name);
 
+  // TODO use as_DataLayout()
   if (layout == "NHWC")
     return moco::tf::DataLayout::NHWC;
   else if (layout == "NCHW")
index b980144..d4ce5aa 100644 (file)
@@ -45,6 +45,7 @@ int64_t get_int_attr(const tensorflow::NodeDef &node, const std::string &attr_na
 float get_float_attr(const tensorflow::NodeDef &node, const std::string &attr_name);
 bool get_bool_attr(const tensorflow::NodeDef &node, const std::string &attr_name);
 
+std::vector<int64_t> as_int64_list(const tensorflow::AttrValue_ListValue &lv);
 loco::DataType as_loco_datatype(const tensorflow::DataType dtype);
 
 /**
@@ -56,6 +57,8 @@ enum class DataLayout
   NCHW,
 };
 
+const DataLayout as_DataLayout(const std::string &data_layout);
+
 const DataLayout get_data_layout(const tensorflow::NodeDef &node, const std::string &attr_name);
 
 } // namespace tf
index 4f3b6f9..81768cd 100644 (file)
@@ -39,6 +39,12 @@ void prepare_test_node(tensorflow::NodeDef &node)
   shape->add_dim()->set_size(2);
   shape->add_dim()->set_size(4);
   shape->add_dim()->set_size(8);
+
+  auto *list = (*node.mutable_attr())["list_1"].mutable_list();
+  list->add_i(1);
+  list->add_i(20);
+  list->add_i(1LL << 40);
+  list->add_i(-(1LL << 40));
 }
 
 } // namespace
@@ -85,3 +91,23 @@ TEST(moco_Convert, string_toupper)
 
   ASSERT_EQ(convert, "HELLO WORLD!!!");
 }
+
+TEST(moco_Convert, attr_ilist)
+{
+  tensorflow::NodeDef node;
+  prepare_test_node(node);
+
+  const auto &p_list = moco::tf::get_list_attr(node, "list_1");
+  auto i_list = moco::tf::as_int64_list(p_list);
+  ASSERT_EQ(i_list.size(), 4);
+  ASSERT_EQ(i_list.at(0), 1);
+  ASSERT_EQ(i_list.at(1), 20);
+  ASSERT_EQ(i_list.at(2), 1LL << 40);
+  ASSERT_EQ(i_list.at(3), -(1LL << 40));
+}
+
+TEST(moco_Convert, to_DataLayout)
+{
+  ASSERT_EQ(moco::tf::as_DataLayout("NHWC"), moco::tf::DataLayout::NHWC);
+  ASSERT_EQ(moco::tf::as_DataLayout("NCHW"), moco::tf::DataLayout::NCHW);
+}