return attr.b();
}
+std::vector<int64_t> as_int64_list(const tensorflow::AttrValue_ListValue &lv)
+{
+ std::vector<int64_t> vi;
+ int isize = lv.i_size();
+
+ vi.resize(isize);
+ for (int i = 0; i < isize; ++i)
+ vi[i] = lv.i(i);
+
+ return vi;
+}
+
loco::DataType as_loco_datatype(const tensorflow::DataType dtype)
{
switch (dtype)
throw std::runtime_error{"Unsupported tensorflow dtype: " + tensorflow::DataType_Name(dtype)};
}
+const DataLayout as_DataLayout(const std::string &layout)
+{
+ if (layout == "NHWC")
+ return moco::tf::DataLayout::NHWC;
+ else if (layout == "NCHW")
+ return moco::tf::DataLayout::NCHW;
+ else
+ throw std::runtime_error("unknown data layout");
+}
+
const DataLayout get_data_layout(const tensorflow::NodeDef &node, const std::string &attr_name)
{
auto layout = get_string_attr(node, attr_name);
+ // TODO use as_DataLayout()
if (layout == "NHWC")
return moco::tf::DataLayout::NHWC;
else if (layout == "NCHW")
float get_float_attr(const tensorflow::NodeDef &node, const std::string &attr_name);
bool get_bool_attr(const tensorflow::NodeDef &node, const std::string &attr_name);
+std::vector<int64_t> as_int64_list(const tensorflow::AttrValue_ListValue &lv);
loco::DataType as_loco_datatype(const tensorflow::DataType dtype);
/**
NCHW,
};
+const DataLayout as_DataLayout(const std::string &data_layout);
+
const DataLayout get_data_layout(const tensorflow::NodeDef &node, const std::string &attr_name);
} // namespace tf
shape->add_dim()->set_size(2);
shape->add_dim()->set_size(4);
shape->add_dim()->set_size(8);
+
+ auto *list = (*node.mutable_attr())["list_1"].mutable_list();
+ list->add_i(1);
+ list->add_i(20);
+ list->add_i(1LL << 40);
+ list->add_i(-(1LL << 40));
}
} // namespace
ASSERT_EQ(convert, "HELLO WORLD!!!");
}
+
+TEST(moco_Convert, attr_ilist)
+{
+ tensorflow::NodeDef node;
+ prepare_test_node(node);
+
+ const auto &p_list = moco::tf::get_list_attr(node, "list_1");
+ auto i_list = moco::tf::as_int64_list(p_list);
+ ASSERT_EQ(i_list.size(), 4);
+ ASSERT_EQ(i_list.at(0), 1);
+ ASSERT_EQ(i_list.at(1), 20);
+ ASSERT_EQ(i_list.at(2), 1LL << 40);
+ ASSERT_EQ(i_list.at(3), -(1LL << 40));
+}
+
+TEST(moco_Convert, to_DataLayout)
+{
+ ASSERT_EQ(moco::tf::as_DataLayout("NHWC"), moco::tf::DataLayout::NHWC);
+ ASSERT_EQ(moco::tf::as_DataLayout("NCHW"), moco::tf::DataLayout::NCHW);
+}