}
}
+void convert_graph(::onnx::GraphProto &onnx_graph_proto, loco::Graph *graph)
+{
+ auto nodes = stdex::make_unique<moco::onnx::SymbolTable>();
+ auto input_names = stdex::make_unique<moco::onnx::SymbolTable>();
+
+ moco::onnx::GraphBuilderContext gb_context(graph, nodes.get(), input_names.get());
+
+ // Building a loco graph
+ // 1. Convert all the nodes to loco::Node
+ // 2. Connect inputs: set all node input(from a string) to actual node object
+ // 3. Set graph input
+ // 4. Create loco::Push node and set input and set graph output
+
+ // 1. Convert all the nodes to loco::Node
+ for (const auto &n : onnx_graph_proto.node())
+ {
+ if (const auto *graph_builder = moco::onnx::GraphBuilderRegistry::get().lookup(n.op_type()))
+ {
+ if (!graph_builder->validate(n))
+ {
+ throw std::runtime_error{"Invalid operator: " + n.op_type()};
+ }
+
+ graph_builder->build(n, &gb_context);
+ }
+ else
+ {
+ throw std::runtime_error{"Not supported: " + n.op_type()};
+ }
+ }
+
+ // 2. Connect inputs: set all node input(from a string) to actual node object
+ loco::Graph::NodeContext *graph_nodes = graph->nodes();
+ uint32_t nodes_count = graph_nodes->size();
+ for (uint32_t n = 0; n < nodes_count; ++n)
+ {
+ loco::Node *node_to_set = graph_nodes->at(n);
+
+ unsigned int names_size = input_names->size(node_to_set);
+ assert(names_size == node_to_set->arity());
+ for (unsigned int i = 0; i < names_size; ++i)
+ {
+ auto input_name = input_names->name(node_to_set, i);
+ auto node = nodes->node(input_name);
+
+ // TODO use enum instead of dynamic_cast
+ loco::Forward *forward_node = dynamic_cast<loco::Forward *>(node_to_set);
+ if (forward_node != nullptr)
+ forward_node->input(node);
+ }
+ }
+
+ // 3. Set graph input
+ for (int i = 0; i < onnx_graph_proto.input_size(); i++)
+ {
+ auto input = onnx_graph_proto.input(i).name();
+
+ auto node = nodes->node(input);
+ assert(node != nullptr);
+
+ auto graph_input = graph->inputs()->create();
+
+ loco::Pull *pull_node = dynamic_cast<loco::Pull *>(node);
+ assert(pull_node != nullptr);
+
+ graph_input->name(input);
+ graph_input->node(pull_node);
+ }
+
+ // 4. Create loco::Push node and set input and set graph output
+ for (int i = 0; i < onnx_graph_proto.output_size(); i++)
+ {
+ auto output = onnx_graph_proto.output(i).name();
+
+ auto output_node = nodes->node(output);
+ assert(output_node);
+
+ // create loco::Push for output of graph
+ auto push_node = graph->nodes()->create<loco::Push>();
+ push_node->from(output_node); // set input of Push to output node
+
+ // set the graph output name and node object
+ auto graph_output = graph->outputs()->create();
+ graph_output->name(output);
+ graph_output->node(push_node);
+ }
+}
+
} // namespace
namespace moco
std::unique_ptr<loco::Graph> Frontend::load(const char *modelfile, FileType type) const
{
::onnx::ModelProto onnx_model_proto;
+ ::onnx::GraphProto onnx_graph_proto;
load_onnx(modelfile, type, onnx_model_proto);
- // TODO convert onnx graph to loco graph
+ onnx_graph_proto = onnx_model_proto.graph();
+
+ auto graph = loco::make_graph();
+
+ convert_graph(onnx_graph_proto, graph.get());
- throw std::runtime_error{"NYI"};
+ return std::move(graph);
}
} // namespace onnx