From 962caff80751b212bf0a9d64d182690656cb1ff8 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Thu, 4 Jul 2019 08:52:39 +0900 Subject: [PATCH] [moco/tf] Add more methods in Convert (#4070) * [moco/tf] Add more methods in Convert This will add more methods in Convert for import Signed-off-by: SaeHie Park * use ASSERT_EQ --- contrib/moco-tf/src/Convert.cpp | 23 +++++++++++++++++++++++ contrib/moco-tf/src/Convert.h | 3 +++ contrib/moco-tf/src/Convert.test.cpp | 26 ++++++++++++++++++++++++++ 3 files changed, 52 insertions(+) diff --git a/contrib/moco-tf/src/Convert.cpp b/contrib/moco-tf/src/Convert.cpp index 6f90bb4..1cdf749 100644 --- a/contrib/moco-tf/src/Convert.cpp +++ b/contrib/moco-tf/src/Convert.cpp @@ -106,6 +106,18 @@ bool get_bool_attr(const tensorflow::NodeDef &node, const std::string &attr_name return attr.b(); } +std::vector as_int64_list(const tensorflow::AttrValue_ListValue &lv) +{ + std::vector vi; + int isize = lv.i_size(); + + vi.resize(isize); + for (int i = 0; i < isize; ++i) + vi[i] = lv.i(i); + + return vi; +} + loco::DataType as_loco_datatype(const tensorflow::DataType dtype) { switch (dtype) @@ -128,10 +140,21 @@ loco::DataType as_loco_datatype(const tensorflow::DataType dtype) throw std::runtime_error{"Unsupported tensorflow dtype: " + tensorflow::DataType_Name(dtype)}; } +const DataLayout as_DataLayout(const std::string &layout) +{ + if (layout == "NHWC") + return moco::tf::DataLayout::NHWC; + else if (layout == "NCHW") + return moco::tf::DataLayout::NCHW; + else + throw std::runtime_error("unknown data layout"); +} + const DataLayout get_data_layout(const tensorflow::NodeDef &node, const std::string &attr_name) { auto layout = get_string_attr(node, attr_name); + // TODO use as_DataLayout() if (layout == "NHWC") return moco::tf::DataLayout::NHWC; else if (layout == "NCHW") diff --git a/contrib/moco-tf/src/Convert.h b/contrib/moco-tf/src/Convert.h index b980144..d4ce5aa 100644 --- a/contrib/moco-tf/src/Convert.h +++ b/contrib/moco-tf/src/Convert.h @@ -45,6 +45,7 @@ int64_t get_int_attr(const tensorflow::NodeDef &node, const std::string &attr_na float get_float_attr(const tensorflow::NodeDef &node, const std::string &attr_name); bool get_bool_attr(const tensorflow::NodeDef &node, const std::string &attr_name); +std::vector as_int64_list(const tensorflow::AttrValue_ListValue &lv); loco::DataType as_loco_datatype(const tensorflow::DataType dtype); /** @@ -56,6 +57,8 @@ enum class DataLayout NCHW, }; +const DataLayout as_DataLayout(const std::string &data_layout); + const DataLayout get_data_layout(const tensorflow::NodeDef &node, const std::string &attr_name); } // namespace tf diff --git a/contrib/moco-tf/src/Convert.test.cpp b/contrib/moco-tf/src/Convert.test.cpp index 4f3b6f9..81768cd 100644 --- a/contrib/moco-tf/src/Convert.test.cpp +++ b/contrib/moco-tf/src/Convert.test.cpp @@ -39,6 +39,12 @@ void prepare_test_node(tensorflow::NodeDef &node) shape->add_dim()->set_size(2); shape->add_dim()->set_size(4); shape->add_dim()->set_size(8); + + auto *list = (*node.mutable_attr())["list_1"].mutable_list(); + list->add_i(1); + list->add_i(20); + list->add_i(1LL << 40); + list->add_i(-(1LL << 40)); } } // namespace @@ -85,3 +91,23 @@ TEST(moco_Convert, string_toupper) ASSERT_EQ(convert, "HELLO WORLD!!!"); } + +TEST(moco_Convert, attr_ilist) +{ + tensorflow::NodeDef node; + prepare_test_node(node); + + const auto &p_list = moco::tf::get_list_attr(node, "list_1"); + auto i_list = moco::tf::as_int64_list(p_list); + ASSERT_EQ(i_list.size(), 4); + ASSERT_EQ(i_list.at(0), 1); + ASSERT_EQ(i_list.at(1), 20); + ASSERT_EQ(i_list.at(2), 1LL << 40); + ASSERT_EQ(i_list.at(3), -(1LL << 40)); +} + +TEST(moco_Convert, to_DataLayout) +{ + ASSERT_EQ(moco::tf::as_DataLayout("NHWC"), moco::tf::DataLayout::NHWC); + ASSERT_EQ(moco::tf::as_DataLayout("NCHW"), moco::tf::DataLayout::NCHW); +} -- 2.7.4