From b57727477f0ffb67a26cc86affb5107dc1c249f7 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Mon, 20 May 2019 18:29:12 +0900 Subject: [PATCH] [moco] Add UpdateQueue member (#3537) This will add UpdateQueue attribute in GraphBuilderContext Signed-off-by: SaeHie Park --- contrib/moco/lib/frontend/tf/src/Frontend.cpp | 3 ++- contrib/moco/lib/frontend/tf/src/GraphBuilderContext.h | 8 ++++++-- contrib/moco/lib/frontend/tf/src/GraphBuilderContext.test.cpp | 4 +++- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/contrib/moco/lib/frontend/tf/src/Frontend.cpp b/contrib/moco/lib/frontend/tf/src/Frontend.cpp index 700ad4a..2863303 100644 --- a/contrib/moco/lib/frontend/tf/src/Frontend.cpp +++ b/contrib/moco/lib/frontend/tf/src/Frontend.cpp @@ -103,8 +103,9 @@ void convert_graph(const moco::tf::ModelSignature &signature, tensorflow::GraphD { auto nodes = stdex::make_unique(); auto input_names = stdex::make_unique(); + auto updates = stdex::make_unique(); - moco::tf::GraphBuilderContext gb_context(graph, nodes.get(), input_names.get()); + moco::tf::GraphBuilderContext gb_context(graph, nodes.get(), input_names.get(), updates.get()); // Building a loco graph // 1. Convert all the nodes to loco::Node diff --git a/contrib/moco/lib/frontend/tf/src/GraphBuilderContext.h b/contrib/moco/lib/frontend/tf/src/GraphBuilderContext.h index b49c578..64aee5b 100644 --- a/contrib/moco/lib/frontend/tf/src/GraphBuilderContext.h +++ b/contrib/moco/lib/frontend/tf/src/GraphBuilderContext.h @@ -67,6 +67,7 @@ private: MapNameNode_t _namenode; MapNodeNames_t _nodenames; + MapNameNode_t _table; }; /** @@ -110,8 +111,9 @@ private: class GraphBuilderContext { public: - GraphBuilderContext(loco::Graph *g, SymbolTable *nodes, SymbolTable *input_names) - : _g(g), _nodes(nodes), _input_names(input_names) + GraphBuilderContext(loco::Graph *g, SymbolTable *nodes, SymbolTable *input_names, + UpdateQueue *updates) + : _g(g), _nodes(nodes), _input_names(input_names), _updates(updates) { // DO NOTHING } @@ -123,11 +125,13 @@ public: loco::Graph *graph() { return _g; } SymbolTable *nodes() { return _nodes; } SymbolTable *input_names() { return _input_names; } + UpdateQueue *updates() { return _updates; } private: loco::Graph *_g; SymbolTable *_nodes; SymbolTable *_input_names; + UpdateQueue *_updates; }; } // namespace tf diff --git a/contrib/moco/lib/frontend/tf/src/GraphBuilderContext.test.cpp b/contrib/moco/lib/frontend/tf/src/GraphBuilderContext.test.cpp index 51a5ebe..d5c502d 100644 --- a/contrib/moco/lib/frontend/tf/src/GraphBuilderContext.test.cpp +++ b/contrib/moco/lib/frontend/tf/src/GraphBuilderContext.test.cpp @@ -25,12 +25,14 @@ TEST(GraphBuilderContext, ctor) auto graph = loco::make_graph(); moco::tf::SymbolTable nodes; moco::tf::SymbolTable input_names; + moco::tf::UpdateQueue updates; - moco::tf::GraphBuilderContext context(graph.get(), &nodes, &input_names); + moco::tf::GraphBuilderContext context(graph.get(), &nodes, &input_names, &updates); ASSERT_EQ(context.graph(), graph.get()); ASSERT_EQ(context.nodes(), &nodes); ASSERT_EQ(context.input_names(), &input_names); + ASSERT_EQ(context.updates(), &updates); } TEST(SymbolTable, node_name) -- 2.7.4