From 857e849bdfede75b46f56dddb6358cddbcffe45b Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Thu, 14 Nov 2019 18:04:44 +0900 Subject: [PATCH] [moco] Introduce find_node_byname (#8938) This will intrdocue moco GraphHelper with find_node_byname() that quries node by name. Signed-off-by: SaeHie Park --- compiler/moco/import/include/moco/GraphHelper.h | 59 +++++++++++++++++++++++++ compiler/moco/import/src/Importer.test.cpp | 5 +-- 2 files changed, 60 insertions(+), 4 deletions(-) create mode 100644 compiler/moco/import/include/moco/GraphHelper.h diff --git a/compiler/moco/import/include/moco/GraphHelper.h b/compiler/moco/import/include/moco/GraphHelper.h new file mode 100644 index 0000000..fad62af --- /dev/null +++ b/compiler/moco/import/include/moco/GraphHelper.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __MOCO_GRAPH_HELPER_H__ +#define __MOCO_GRAPH_HELPER_H__ + +#include + +#include + +namespace moco +{ + +/** + * @brief find_node_byname() will return a node with type T with given name + * in graph g + * + * @note this uses simple linear search, but can speed up with better + * algorithms when needed. + */ +template T *find_node_byname(loco::Graph *g, const char *name) +{ + T *first_node = nullptr; + loco::Graph::NodeContext *nodes = g->nodes(); + uint32_t count = nodes->size(); + + for (uint32_t i = 0; i < count; ++i) + { + auto tfnode = dynamic_cast(nodes->at(i)); + if (tfnode != nullptr) + { + if (tfnode->name() == name) + { + // if tfnode is NOT type of T then return will be nullptr + // this is OK cause the user wanted to get type T but it isn't + return dynamic_cast(tfnode); + } + } + } + + return nullptr; +} + +} // namespace moco + +#endif // __MOCO_GRAPH_HELPER_H__ diff --git a/compiler/moco/import/src/Importer.test.cpp b/compiler/moco/import/src/Importer.test.cpp index 6573e8d..2387339 100644 --- a/compiler/moco/import/src/Importer.test.cpp +++ b/compiler/moco/import/src/Importer.test.cpp @@ -15,6 +15,7 @@ */ #include "moco/Importer.h" +#include "moco/GraphHelper.h" #include @@ -210,9 +211,6 @@ TEST(TensorFlowImport, find_node_by_name) auto tfidentity = find_first_node_bytype(graph.get()); ASSERT_NE(tfidentity, nullptr); ASSERT_NE(tfidentity->input(), nullptr); - -// TODO make this test pass -#if 0 ASSERT_STREQ(tfidentity->name().c_str(), "output/identity"); auto query_node = moco::find_node_byname(graph.get(), "Foo/w_min"); @@ -222,5 +220,4 @@ TEST(TensorFlowImport, find_node_by_name) auto query_node2 = moco::find_node_byname(graph.get(), "Foo/w_max"); ASSERT_NE(query_node2, nullptr); ASSERT_STREQ(query_node2->name().c_str(), "Foo/w_max"); -#endif } -- 2.7.4