[moco/ONNX] Convert onnx graph to loco graph (#3368)
author남궁석/On-Device Lab(SR)/Engineer/삼성전자 <sk.namkoong@samsung.com>
Mon, 29 Apr 2019 02:06:41 +0000 (11:06 +0900)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Mon, 29 Apr 2019 02:06:41 +0000 (11:06 +0900)
* [moco/ONNX] Convert onnx graph to loco graph

This commit will enable converting ONNX graph to loco graph

Signed-off-by: Seok NamKoong <sk.namkoong@samsung.com>
* fix type name

contrib/moco/lib/frontend/onnx/src/Frontend.cpp

index 3251eee..e327b72 100644 (file)
@@ -71,6 +71,94 @@ void load_onnx(const std::string &path, moco::onnx::Frontend::FileType type,
   }
 }
 
+void convert_graph(::onnx::GraphProto &onnx_graph_proto, loco::Graph *graph)
+{
+  auto nodes = stdex::make_unique<moco::onnx::SymbolTable>();
+  auto input_names = stdex::make_unique<moco::onnx::SymbolTable>();
+
+  moco::onnx::GraphBuilderContext gb_context(graph, nodes.get(), input_names.get());
+
+  // Building a loco graph
+  // 1. Convert all the nodes to loco::Node
+  // 2. Connect inputs: set all node input(from a string) to actual node object
+  // 3. Set graph input
+  // 4. Create loco::Push node and set input and set graph output
+
+  // 1. Convert all the nodes to loco::Node
+  for (const auto &n : onnx_graph_proto.node())
+  {
+    if (const auto *graph_builder = moco::onnx::GraphBuilderRegistry::get().lookup(n.op_type()))
+    {
+      if (!graph_builder->validate(n))
+      {
+        throw std::runtime_error{"Invalid operator: " + n.op_type()};
+      }
+
+      graph_builder->build(n, &gb_context);
+    }
+    else
+    {
+      throw std::runtime_error{"Not supported: " + n.op_type()};
+    }
+  }
+
+  // 2. Connect inputs: set all node input(from a string) to actual node object
+  loco::Graph::NodeContext *graph_nodes = graph->nodes();
+  uint32_t nodes_count = graph_nodes->size();
+  for (uint32_t n = 0; n < nodes_count; ++n)
+  {
+    loco::Node *node_to_set = graph_nodes->at(n);
+
+    unsigned int names_size = input_names->size(node_to_set);
+    assert(names_size == node_to_set->arity());
+    for (unsigned int i = 0; i < names_size; ++i)
+    {
+      auto input_name = input_names->name(node_to_set, i);
+      auto node = nodes->node(input_name);
+
+      // TODO use enum instead of dynamic_cast
+      loco::Forward *forward_node = dynamic_cast<loco::Forward *>(node_to_set);
+      if (forward_node != nullptr)
+        forward_node->input(node);
+    }
+  }
+
+  // 3. Set graph input
+  for (int i = 0; i < onnx_graph_proto.input_size(); i++)
+  {
+    auto input = onnx_graph_proto.input(i).name();
+
+    auto node = nodes->node(input);
+    assert(node != nullptr);
+
+    auto graph_input = graph->inputs()->create();
+
+    loco::Pull *pull_node = dynamic_cast<loco::Pull *>(node);
+    assert(pull_node != nullptr);
+
+    graph_input->name(input);
+    graph_input->node(pull_node);
+  }
+
+  // 4. Create loco::Push node and set input and set graph output
+  for (int i = 0; i < onnx_graph_proto.output_size(); i++)
+  {
+    auto output = onnx_graph_proto.output(i).name();
+
+    auto output_node = nodes->node(output);
+    assert(output_node);
+
+    // create loco::Push for output of graph
+    auto push_node = graph->nodes()->create<loco::Push>();
+    push_node->from(output_node); // set input of Push to output node
+
+    // set the graph output name and node object
+    auto graph_output = graph->outputs()->create();
+    graph_output->name(output);
+    graph_output->node(push_node);
+  }
+}
+
 } // namespace
 
 namespace moco
@@ -86,12 +174,17 @@ Frontend::Frontend()
 std::unique_ptr<loco::Graph> Frontend::load(const char *modelfile, FileType type) const
 {
   ::onnx::ModelProto onnx_model_proto;
+  ::onnx::GraphProto onnx_graph_proto;
 
   load_onnx(modelfile, type, onnx_model_proto);
 
-  // TODO convert onnx graph to loco graph
+  onnx_graph_proto = onnx_model_proto.graph();
+
+  auto graph = loco::make_graph();
+
+  convert_graph(onnx_graph_proto, graph.get());
 
-  throw std::runtime_error{"NYI"};
+  return std::move(graph);
 }
 
 } // namespace onnx