loco::Graph *graph)
{
auto nodes = stdex::make_unique<moco::tf::SymbolTable>();
- auto input_names = stdex::make_unique<moco::tf::SymbolTable>();
auto updates = stdex::make_unique<moco::tf::UpdateQueue>();
- moco::tf::GraphBuilderContext gb_context(graph, nodes.get(), input_names.get(), updates.get());
+ moco::tf::GraphBuilderContext gb_context(graph, nodes.get(), updates.get());
// Building a loco graph
// 1. Convert all the nodes to loco::Node
return iter->second;
}
-void SymbolTable::list(loco::Node *node, const std::string &name)
-{
- MapNodeNames_t::iterator iter = _nodenames.find(node);
-
- if (iter == _nodenames.end())
- {
- // add a new vector for the first name
- _nodenames[node] = {name};
- return;
- }
-
- _nodenames[node].push_back(name);
-}
-
-unsigned SymbolTable::size(loco::Node *node)
-{
- MapNodeNames_t::iterator iter = _nodenames.find(node);
-
- if (iter == _nodenames.end())
- {
- return 0;
- }
-
- return iter->second.size();
-}
-
-const std::string &SymbolTable::name(loco::Node *node, unsigned index)
-{
- MapNodeNames_t::iterator iter = _nodenames.find(node);
-
- if (iter == _nodenames.end())
- {
- throw std::runtime_error{"Error: Cannot find names given node"};
- }
-
- if (index >= iter->second.size())
- {
- throw std::runtime_error{"Error: Invalid name index for given node"};
- }
-
- return iter->second.at(index);
-}
-
void UpdateQueue::enroll(std::unique_ptr<GraphUpdate> &&update)
{
_queue.push_back(std::move(update));
*/
loco::Node *node(const std::string &node_name) const;
- /**
- * @brief Registers multiple (appends) names for a node
- * Table is independent with registering with enroll()
- */
- void list(loco::Node *node, const std::string &name);
- /**
- * @brief Returns number of listed(registered) names for a node
- */
- unsigned size(loco::Node *node);
- /**
- * @brief Queries listed(registered) with node and index(from 0 to size-1)
- * Will throw runtime_error if node is not found or index is out of bounds
- */
- const std::string &name(loco::Node *node, unsigned index);
-
private:
using MapNameNode_t = std::map<std::string, loco::Node *>;
- using MapNodeNames_t = std::map<loco::Node *, std::vector<std::string>>;
- MapNameNode_t _namenode;
- MapNodeNames_t _nodenames;
MapNameNode_t _table;
};
class GraphBuilderContext
{
public:
- GraphBuilderContext(loco::Graph *g, SymbolTable *nodes, SymbolTable *input_names,
- UpdateQueue *updates)
- : _g(g), _nodes(nodes), _input_names(input_names), _updates(updates)
+ GraphBuilderContext(loco::Graph *g, SymbolTable *nodes, UpdateQueue *updates)
+ : _g(g), _nodes(nodes), _updates(updates)
{
// DO NOTHING
}
public:
loco::Graph *graph() { return _g; }
SymbolTable *nodes() { return _nodes; }
- SymbolTable *input_names() { return _input_names; }
UpdateQueue *updates() { return _updates; }
private:
loco::Graph *_g;
SymbolTable *_nodes;
- SymbolTable *_input_names;
UpdateQueue *_updates;
};
{
auto graph = loco::make_graph();
moco::tf::SymbolTable nodes;
- moco::tf::SymbolTable input_names;
moco::tf::UpdateQueue updates;
- moco::tf::GraphBuilderContext context(graph.get(), &nodes, &input_names, &updates);
+ moco::tf::GraphBuilderContext context(graph.get(), &nodes, &updates);
ASSERT_EQ(context.graph(), graph.get());
ASSERT_EQ(context.nodes(), &nodes);
- ASSERT_EQ(context.input_names(), &input_names);
ASSERT_EQ(context.updates(), &updates);
}
// unregistered name should throw
EXPECT_THROW(table.node(invalid), std::runtime_error);
}
-
-TEST(SymbolTable, name_node)
-{
- moco::tf::SymbolTable table;
- loco::Push push_node;
- std::string in1("in1");
- std::string in2("in2");
-
- ASSERT_EQ(table.size(&push_node), 0);
-
- table.list(&push_node, in1);
- table.list(&push_node, in2);
- unsigned size = table.size(&push_node);
- ASSERT_EQ(size, 2);
- ASSERT_EQ(in1, table.name(&push_node, 0));
- ASSERT_EQ(in2, table.name(&push_node, 1));
- EXPECT_THROW(table.name(&push_node, 2), std::runtime_error);
-}
loco::Graph *graph = context->graph();
SymbolTable *nodes = context->nodes();
- SymbolTable *input_names = context->input_names();
UpdateQueue *updates = context->updates();
// Create a "Forward" node for Identity
loco::Graph *graph = context->graph();
SymbolTable *nodes = context->nodes();
- SymbolTable *input_names = context->input_names();
UpdateQueue *updates = context->updates();
// Create a "ReLU" node for Relu