[locomotiv] Session's graph input/output size (#3531)
author박천교/On-Device Lab(SR)/Engineer/삼성전자 <ch.bahk@samsung.com>
Mon, 20 May 2019 07:38:23 +0000 (16:38 +0900)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Mon, 20 May 2019 07:38:23 +0000 (16:38 +0900)
* [locomotiv] Session's graph input/output size

This commit introduces Session::input_size() and Session::output_size()
which are getter for graph input and output size it possesses.

Signed-off-by: Cheongyo Bahk <ch.bahk@samsung.com>
* Make them const method

* Use const variable for magic number

contrib/locomotiv/README.md
contrib/locomotiv/include/locomotiv/Session.h
contrib/locomotiv/src/Session.test.cpp

index e0d452c..1c71c3b 100644 (file)
@@ -16,7 +16,7 @@ loco::Graph *graph;
 // Open interpreter session
 locomotiv::Session sess(graph);
 
-for (uint32_t i = 0; i < graph->inputs()->size(); ++i)
+for (uint32_t i = 0; i < s.input_size(); ++i)
 {
   Buffer<type> buffer;
   // ... building buffer ...
index 4fb913d..45e9e28 100644 (file)
@@ -41,6 +41,9 @@ public:
   /// @brief Free all node annotations of the graph assigned by this Session
   ~Session();
 
+  /// @brief Get number of graph inputs held by this Session
+  uint32_t input_size() const { return _graph->inputs()->size(); }
+
   /**
    * @brief Set graph input at specific index by NodeData.
    *
@@ -58,6 +61,9 @@ public:
    */
   void infer();
 
+  /// @brief Get number of graph outputs held by this Session
+  uint32_t output_size() const { return _graph->outputs()->size(); }
+
   /**
    * @brief Get output of graph as NodeData
    *
index 0f6cff8..a375490 100644 (file)
@@ -29,6 +29,34 @@ using nncc::core::ADT::tensor::Shape;
 using nncc::core::ADT::tensor::LexicalLayout;
 using nncc::core::ADT::tensor::make_buffer;
 
+TEST(Session, graph_IO_size)
+{
+  // Make graph
+  auto g = loco::make_graph();
+
+  // inputs
+  const uint32_t inputs = 2;
+  for (uint32_t i = 0; i < inputs; ++i)
+  {
+    auto pull = g->nodes()->create<loco::Pull>();
+    g->inputs()->create()->node(pull);
+  }
+
+  // outputs
+  const uint32_t outputs = 3;
+  for (uint32_t o = 0; o < outputs; ++o)
+  {
+    auto push = g->nodes()->create<loco::Push>();
+    g->outputs()->create()->node(push);
+  }
+
+  // Make session
+  locomotiv::Session s(g.get());
+
+  ASSERT_EQ(s.input_size(), inputs);
+  ASSERT_EQ(s.output_size(), outputs);
+}
+
 TEST(Session, set_input)
 {
   // Make graph