[locomotiv] Session for subgraph (#4348)
author박천교/On-Device Lab(SR)/Engineer/삼성전자 <ch.bahk@samsung.com>
Fri, 19 Jul 2019 01:50:35 +0000 (10:50 +0900)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Fri, 19 Jul 2019 01:50:35 +0000 (10:50 +0900)
* [locomotiv] Session for subgraph

This commit supports locomotiv Session to run inference on subgraph
only, by setting user defined custom outputs. Compatibility to existing
'full graph' Session is preserved.

Signed-off-by: Cheongyo Bahk <ch.bahk@samsung.com>
* Review fix: get output by node

* Review fix: efficient ctor call

* Add warning

compiler/locomotiv/include/locomotiv/Session.h
compiler/locomotiv/src/Session.cpp
compiler/locomotiv/src/Session.test.cpp

index f82cad9..018ef76 100644 (file)
@@ -22,6 +22,7 @@
 #include <loco.h>
 
 #include <memory>
+#include <vector>
 
 namespace locomotiv
 {
@@ -33,7 +34,24 @@ class Session final
 {
 public:
   Session() = delete;
-  Session(loco::Graph *g) : _graph(g)
+
+  /// @brief  Make Session for graph with graph outputs themselves
+  Session(loco::Graph *g) : Session(g, loco::output_nodes(g))
+  {
+    // DO NOTHING
+  }
+
+  /**
+   * @brief  Make Session for graph with selective custom outputs. Only
+   *         subgraph to calculate given outputs would be executed.
+   *
+   * @note  Set required inputs for given outputs, or inference may fail.
+   * @note  custom_outputs don't need to be graph output, but can be any nodes
+   *        in the middle of the graph.
+   * @warn  This approach may fail in case of graph with control flow
+   */
+  Session(loco::Graph *g, const std::vector<loco::Node *> &custom_outputs)
+      : _graph(g), _outputs(custom_outputs)
   {
     // DO NOTHING
   }
@@ -62,7 +80,7 @@ public:
   void infer();
 
   /// @brief Get number of graph outputs held by this Session
-  uint32_t output_size() const { return _graph->outputs()->size(); }
+  uint32_t output_size() const { return _outputs.size(); }
 
   /**
    * @brief Get output of graph as NodeData
@@ -71,8 +89,11 @@ public:
    */
   const NodeData *get_output(uint32_t index);
 
+  const loco::Node *get_output_node(uint32_t index) { return _outputs.at(index); }
+
 private:
   loco::Graph *_graph;
+  std::vector<loco::Node *> _outputs;
 };
 
 } // namespace locomotiv
index cde38dd..bff632d 100644 (file)
@@ -73,13 +73,7 @@ void Session::set_input(uint32_t index, std::unique_ptr<NodeData> &&data)
 
 void Session::infer()
 {
-  std::vector<loco::Node *> output_vec;
-  for (uint32_t i = 0; i < _graph->outputs()->size(); ++i)
-  {
-    output_vec.push_back(_graph->outputs()->at(i)->node());
-  }
-
-  auto schedules = loco::postorder_traversal(output_vec);
+  auto schedules = loco::postorder_traversal(_outputs);
 
   for (auto node : schedules)
   {
@@ -91,7 +85,7 @@ const NodeData *Session::get_output(uint32_t index)
 {
   assert(index < output_size());
 
-  auto output_node = _graph->outputs()->at(index)->node();
+  auto output_node = _outputs.at(index);
   return annot_data(output_node);
 }
 
index 1eeddc7..2619f03 100644 (file)
@@ -175,6 +175,121 @@ TEST(Session, inference_identity)
   }
 }
 
+TEST(Session, session_for_subgraph)
+{
+  /*
+   * Make following graph:
+   *   ConstGen_1 --
+   *                \
+   *   ConstGen_2 --- TensorConcat_1 --- TensorConcat_3 --- Push
+   *                                   /
+   *   ConstGen_3 --- TensorConcat_2 --
+   *                /
+   *   ConstGen_4 --
+   */
+  auto g = loco::make_graph();
+
+  auto c1 = g->nodes()->create<loco::ConstGen>();
+  auto c2 = g->nodes()->create<loco::ConstGen>();
+  auto c3 = g->nodes()->create<loco::ConstGen>();
+  auto c4 = g->nodes()->create<loco::ConstGen>();
+
+  c1->dtype(loco::DataType::FLOAT32);
+  c2->dtype(loco::DataType::FLOAT32);
+  c3->dtype(loco::DataType::FLOAT32);
+  c4->dtype(loco::DataType::FLOAT32);
+  c1->shape({1});
+  c2->shape({1});
+  c3->shape({1});
+  c4->shape({1});
+  c1->size<loco::DataType::FLOAT32>(1);
+  c2->size<loco::DataType::FLOAT32>(1);
+  c3->size<loco::DataType::FLOAT32>(1);
+  c4->size<loco::DataType::FLOAT32>(1);
+
+  c1->at<loco::DataType::FLOAT32>(0) = 0.1f;
+  c2->at<loco::DataType::FLOAT32>(0) = 0.2f;
+  c3->at<loco::DataType::FLOAT32>(0) = 0.3f;
+  c4->at<loco::DataType::FLOAT32>(0) = 0.4f;
+
+  auto t1 = g->nodes()->create<loco::TensorConcat>();
+  auto t2 = g->nodes()->create<loco::TensorConcat>();
+  auto t3 = g->nodes()->create<loco::TensorConcat>();
+
+  // Note: default concat axis is 0
+  t1->lhs(c1);
+  t1->rhs(c2);
+  t2->lhs(c3);
+  t2->rhs(c4);
+  t3->lhs(t1);
+  t3->rhs(t2);
+
+  auto push = g->nodes()->create<loco::Push>();
+  push->from(t3);
+
+  {
+    // Session to get t1 only
+    locomotiv::Session s(g.get(), {t1});
+    ASSERT_EQ(s.output_size(), 1);
+    ASSERT_EQ(s.get_output_node(0), dynamic_cast<loco::Node *>(t1));
+
+    s.infer();
+
+    auto t1_data = s.get_output(0);
+    ASSERT_NE(t1_data, nullptr);
+    ASSERT_EQ(*(t1_data->shape()), Shape{2});
+
+    auto t1_buf = t1_data->as_f32_bufptr();
+    ASSERT_EQ(t1_buf->at({0}), 0.1f);
+    ASSERT_EQ(t1_buf->at({1}), 0.2f);
+  }
+
+  {
+    // Session to get t2 only
+    locomotiv::Session s(g.get(), {t2});
+    ASSERT_EQ(s.output_size(), 1);
+    ASSERT_EQ(s.get_output_node(0), dynamic_cast<loco::Node *>(t2));
+
+    s.infer();
+
+    auto t2_data = s.get_output(0);
+    ASSERT_NE(t2_data, nullptr);
+    ASSERT_EQ(*(t2_data->shape()), Shape{2});
+
+    auto t2_buf = t2_data->as_f32_bufptr();
+    ASSERT_EQ(t2_buf->at({0}), 0.3f);
+    ASSERT_EQ(t2_buf->at({1}), 0.4f);
+  }
+
+  {
+    // Session to get t2 and push
+    locomotiv::Session s(g.get(), {t2, push});
+    ASSERT_EQ(s.output_size(), 2);
+    ASSERT_EQ(s.get_output_node(0), dynamic_cast<loco::Node *>(t2));
+    ASSERT_EQ(s.get_output_node(1), dynamic_cast<loco::Node *>(push));
+
+    s.infer();
+
+    auto t2_data = s.get_output(0);
+    ASSERT_NE(t2_data, nullptr);
+    ASSERT_EQ(*(t2_data->shape()), Shape{2});
+
+    auto t2_buf = t2_data->as_f32_bufptr();
+    ASSERT_EQ(t2_buf->at({0}), 0.3f);
+    ASSERT_EQ(t2_buf->at({1}), 0.4f);
+
+    auto push_data = s.get_output(1);
+    ASSERT_NE(push_data, nullptr);
+    ASSERT_EQ(*(push_data->shape()), Shape{4});
+
+    auto push_buf = push_data->as_f32_bufptr();
+    ASSERT_EQ(push_buf->at({0}), 0.1f);
+    ASSERT_EQ(push_buf->at({1}), 0.2f);
+    ASSERT_EQ(push_buf->at({2}), 0.3f);
+    ASSERT_EQ(push_buf->at({3}), 0.4f);
+  }
+}
+
 // Below here is internal test for locomotiv, i.e. not public usage of locomotiv
 #include "NodeDataImpl.h"
 #include "NodeDomain.h"