[locomotiv] Implement Session::get_output() (#3423)
author박천교/On-Device Lab(SR)/Engineer/삼성전자 <ch.bahk@samsung.com>
Thu, 9 May 2019 07:40:33 +0000 (16:40 +0900)
committer박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Thu, 9 May 2019 07:40:33 +0000 (16:40 +0900)
This commit implements how session get its output. Related test for this
feature added as well.

Signed-off-by: Cheongyo Bahk <ch.bahk@samsung.com>
contrib/locomotiv/include/locomotiv/Session.h
contrib/locomotiv/src/Session.cpp
contrib/locomotiv/src/Session.test.cpp

index a70f3f1..6bca285 100644 (file)
@@ -55,6 +55,11 @@ public:
    */
   void infer();
 
+  /**
+   * @brief Get output of graph as NodeData
+   *
+   * @note May return nullptr, for example, when graph output not yet calculated
+   */
   const NodeData *get_output(uint32_t index);
 
 private:
index cd4b52b..22ba44c 100644 (file)
@@ -72,4 +72,10 @@ void Session::infer()
   }
 }
 
+const NodeData *Session::get_output(uint32_t index)
+{
+  auto output_node = _graph->outputs()->at(index)->node();
+  return annot_data(output_node);
+}
+
 } // namespace locomotiv
index fbeec4d..0f6cff8 100644 (file)
@@ -139,6 +139,10 @@ TEST(Session, inference_identity)
     // Multiple run is possible
     ASSERT_NO_THROW(s.infer());
 
-    // TODO get and check output
+    auto output_data = s.get_output(0);
+    ASSERT_NE(output_data, nullptr);
+    ASSERT_EQ(output_data->dtype(), loco::DataType::FLOAT32);
+    ASSERT_EQ(*(output_data->shape()), Shape{1});
+    ASSERT_EQ(output_data->as_f32_bufptr()->at(Index{0}), 3.14f);
   }
 }