[moco] Produce TFPush for graph output (#8647)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Thu, 31 Oct 2019 07:47:27 +0000 (16:47 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Thu, 31 Oct 2019 07:47:27 +0000 (16:47 +0900)
* [moco] Produce TFPush for graph output

This will update to produce TFPush for graph output node

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
* update comment

compiler/moco/import/src/Importer.cpp

index 01ff8c8..d73472c 100644 (file)
@@ -21,6 +21,7 @@
 #include "moco/Import/GraphBuilderRegistry.h"
 
 #include <moco/IR/Nodes/TFPlaceholder.h>
+#include <moco/IR/TFNode.h>
 
 #include <moco/Log.h>
 
@@ -48,7 +49,7 @@ void convert_graph(const moco::GraphBuilderSource &source, const moco::ModelSign
   // 1. Convert all the nodes to loco::Node
   // 2. Connect inputs: set all node input(from a string) to actual node object
   // 3. Set graph input
-  // 4. Create loco::Push node and set input and set graph output
+  // 4. Create moco::TFPush node and set graph output
 
   /**
    * @brief Prepare tensorflow::NodeDef search table from name
@@ -150,23 +151,24 @@ void convert_graph(const moco::GraphBuilderSource &source, const moco::ModelSign
   }
 
   /**
-   * @brief 4. Create loco::Push node and set graph input and output
+   * @brief 4. Create moco::TFPush node and set graph output
    */
   for (auto output : signature.outputs())
   {
     auto output_node = tensor_names->node(output);
     assert(output_node);
 
-    // create loco::Push for output of graph
-    auto push_node = graph->nodes()->create<loco::Push>();
-    push_node->from(output_node); // set input of Push to output node
+    // create moco::TFPush for output of graph
+    auto push_node = graph->nodes()->create<moco::TFPush>();
+    push_node->from(output_node); // set input of TFPush to output node
 
     // set the graph output name and node object
     auto graph_output = graph->outputs()->create();
     graph_output->name(output.nodeName());
+    push_node->index(graph_output->index());
+
     // TODO Support other types
     graph_output->dtype(loco::DataType::FLOAT32);
-    loco::link(graph_output, push_node);
   }
 
   // validate graph