From 0a296e08d3e214c16cce0dc319d3d40485c6d489 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=B2=9C=EA=B5=90/On-Device=20Lab=28SR=29/Enginee?= =?utf8?q?r/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Mon, 20 May 2019 16:38:23 +0900 Subject: [PATCH] [locomotiv] Session's graph input/output size (#3531) * [locomotiv] Session's graph input/output size This commit introduces Session::input_size() and Session::output_size() which are getter for graph input and output size it possesses. Signed-off-by: Cheongyo Bahk * Make them const method * Use const variable for magic number --- contrib/locomotiv/README.md | 2 +- contrib/locomotiv/include/locomotiv/Session.h | 6 ++++++ contrib/locomotiv/src/Session.test.cpp | 28 +++++++++++++++++++++++++++ 3 files changed, 35 insertions(+), 1 deletion(-) diff --git a/contrib/locomotiv/README.md b/contrib/locomotiv/README.md index e0d452c..1c71c3b 100644 --- a/contrib/locomotiv/README.md +++ b/contrib/locomotiv/README.md @@ -16,7 +16,7 @@ loco::Graph *graph; // Open interpreter session locomotiv::Session sess(graph); -for (uint32_t i = 0; i < graph->inputs()->size(); ++i) +for (uint32_t i = 0; i < s.input_size(); ++i) { Buffer buffer; // ... building buffer ... diff --git a/contrib/locomotiv/include/locomotiv/Session.h b/contrib/locomotiv/include/locomotiv/Session.h index 4fb913d..45e9e28 100644 --- a/contrib/locomotiv/include/locomotiv/Session.h +++ b/contrib/locomotiv/include/locomotiv/Session.h @@ -41,6 +41,9 @@ public: /// @brief Free all node annotations of the graph assigned by this Session ~Session(); + /// @brief Get number of graph inputs held by this Session + uint32_t input_size() const { return _graph->inputs()->size(); } + /** * @brief Set graph input at specific index by NodeData. * @@ -58,6 +61,9 @@ public: */ void infer(); + /// @brief Get number of graph outputs held by this Session + uint32_t output_size() const { return _graph->outputs()->size(); } + /** * @brief Get output of graph as NodeData * diff --git a/contrib/locomotiv/src/Session.test.cpp b/contrib/locomotiv/src/Session.test.cpp index 0f6cff8..a375490 100644 --- a/contrib/locomotiv/src/Session.test.cpp +++ b/contrib/locomotiv/src/Session.test.cpp @@ -29,6 +29,34 @@ using nncc::core::ADT::tensor::Shape; using nncc::core::ADT::tensor::LexicalLayout; using nncc::core::ADT::tensor::make_buffer; +TEST(Session, graph_IO_size) +{ + // Make graph + auto g = loco::make_graph(); + + // inputs + const uint32_t inputs = 2; + for (uint32_t i = 0; i < inputs; ++i) + { + auto pull = g->nodes()->create(); + g->inputs()->create()->node(pull); + } + + // outputs + const uint32_t outputs = 3; + for (uint32_t o = 0; o < outputs; ++o) + { + auto push = g->nodes()->create(); + g->outputs()->create()->node(push); + } + + // Make session + locomotiv::Session s(g.get()); + + ASSERT_EQ(s.input_size(), inputs); + ASSERT_EQ(s.output_size(), outputs); +} + TEST(Session, set_input) { // Make graph -- 2.7.4