From 8e1ff18c83c63687dc0f0b8bba6f79300469cfe5 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=B2=9C=EA=B5=90/On-Device=20Lab=28SR=29/Enginee?= =?utf8?q?r/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Fri, 19 Jul 2019 10:50:35 +0900 Subject: [PATCH] [locomotiv] Session for subgraph (#4348) * [locomotiv] Session for subgraph This commit enables a locomotiv Session to run inference on only a subgraph by setting user-defined custom outputs. Compatibility with the existing 'full graph' Session is preserved. Signed-off-by: Cheongyo Bahk * Review fix: get output by node * Review fix: efficient ctor call * Add warning --- compiler/locomotiv/include/locomotiv/Session.h | 25 +++++- compiler/locomotiv/src/Session.cpp | 10 +-- compiler/locomotiv/src/Session.test.cpp | 115 +++++++++++++++++++++++++ 3 files changed, 140 insertions(+), 10 deletions(-) diff --git a/compiler/locomotiv/include/locomotiv/Session.h b/compiler/locomotiv/include/locomotiv/Session.h index f82cad9..018ef76 100644 --- a/compiler/locomotiv/include/locomotiv/Session.h +++ b/compiler/locomotiv/include/locomotiv/Session.h @@ -22,6 +22,7 @@ #include #include +#include namespace locomotiv { @@ -33,7 +34,24 @@ class Session final { public: Session() = delete; - Session(loco::Graph *g) : _graph(g) + + /// @brief Make Session for graph with graph outputs themselves + Session(loco::Graph *g) : Session(g, loco::output_nodes(g)) + { + // DO NOTHING + } + + /** + * @brief Make Session for graph with selective custom outputs. Only + * subgraph to calculate given outputs would be executed. + * + * @note Set required inputs for given outputs, or inference may fail. + * @note custom_outputs don't need to be graph output, but can be any nodes + * in the middle of the graph. 
+ * @warn This approach may fail in case of graph with control flow + */ + Session(loco::Graph *g, const std::vector &custom_outputs) + : _graph(g), _outputs(custom_outputs) { // DO NOTHING } @@ -62,7 +80,7 @@ public: void infer(); /// @brief Get number of graph outputs held by this Session - uint32_t output_size() const { return _graph->outputs()->size(); } + uint32_t output_size() const { return _outputs.size(); } /** * @brief Get output of graph as NodeData @@ -71,8 +89,11 @@ public: */ const NodeData *get_output(uint32_t index); + const loco::Node *get_output_node(uint32_t index) { return _outputs.at(index); } + private: loco::Graph *_graph; + std::vector _outputs; }; } // namespace locomotiv diff --git a/compiler/locomotiv/src/Session.cpp b/compiler/locomotiv/src/Session.cpp index cde38dd..bff632d 100644 --- a/compiler/locomotiv/src/Session.cpp +++ b/compiler/locomotiv/src/Session.cpp @@ -73,13 +73,7 @@ void Session::set_input(uint32_t index, std::unique_ptr &&data) void Session::infer() { - std::vector output_vec; - for (uint32_t i = 0; i < _graph->outputs()->size(); ++i) - { - output_vec.push_back(_graph->outputs()->at(i)->node()); - } - - auto schedules = loco::postorder_traversal(output_vec); + auto schedules = loco::postorder_traversal(_outputs); for (auto node : schedules) { @@ -91,7 +85,7 @@ const NodeData *Session::get_output(uint32_t index) { assert(index < output_size()); - auto output_node = _graph->outputs()->at(index)->node(); + auto output_node = _outputs.at(index); return annot_data(output_node); } diff --git a/compiler/locomotiv/src/Session.test.cpp b/compiler/locomotiv/src/Session.test.cpp index 1eeddc7..2619f03 100644 --- a/compiler/locomotiv/src/Session.test.cpp +++ b/compiler/locomotiv/src/Session.test.cpp @@ -175,6 +175,121 @@ TEST(Session, inference_identity) } } +TEST(Session, session_for_subgraph) +{ + /* + * Make following graph: + * ConstGen_1 -- + * \ + * ConstGen_2 --- TensorConcat_1 --- TensorConcat_3 --- Push + * / + * ConstGen_3 
--- TensorConcat_2 -- + * / + * ConstGen_4 -- + */ + auto g = loco::make_graph(); + + auto c1 = g->nodes()->create(); + auto c2 = g->nodes()->create(); + auto c3 = g->nodes()->create(); + auto c4 = g->nodes()->create(); + + c1->dtype(loco::DataType::FLOAT32); + c2->dtype(loco::DataType::FLOAT32); + c3->dtype(loco::DataType::FLOAT32); + c4->dtype(loco::DataType::FLOAT32); + c1->shape({1}); + c2->shape({1}); + c3->shape({1}); + c4->shape({1}); + c1->size(1); + c2->size(1); + c3->size(1); + c4->size(1); + + c1->at(0) = 0.1f; + c2->at(0) = 0.2f; + c3->at(0) = 0.3f; + c4->at(0) = 0.4f; + + auto t1 = g->nodes()->create(); + auto t2 = g->nodes()->create(); + auto t3 = g->nodes()->create(); + + // Note: default concat axis is 0 + t1->lhs(c1); + t1->rhs(c2); + t2->lhs(c3); + t2->rhs(c4); + t3->lhs(t1); + t3->rhs(t2); + + auto push = g->nodes()->create(); + push->from(t3); + + { + // Session to get t1 only + locomotiv::Session s(g.get(), {t1}); + ASSERT_EQ(s.output_size(), 1); + ASSERT_EQ(s.get_output_node(0), dynamic_cast(t1)); + + s.infer(); + + auto t1_data = s.get_output(0); + ASSERT_NE(t1_data, nullptr); + ASSERT_EQ(*(t1_data->shape()), Shape{2}); + + auto t1_buf = t1_data->as_f32_bufptr(); + ASSERT_EQ(t1_buf->at({0}), 0.1f); + ASSERT_EQ(t1_buf->at({1}), 0.2f); + } + + { + // Session to get t2 only + locomotiv::Session s(g.get(), {t2}); + ASSERT_EQ(s.output_size(), 1); + ASSERT_EQ(s.get_output_node(0), dynamic_cast(t2)); + + s.infer(); + + auto t2_data = s.get_output(0); + ASSERT_NE(t2_data, nullptr); + ASSERT_EQ(*(t2_data->shape()), Shape{2}); + + auto t2_buf = t2_data->as_f32_bufptr(); + ASSERT_EQ(t2_buf->at({0}), 0.3f); + ASSERT_EQ(t2_buf->at({1}), 0.4f); + } + + { + // Session to get t2 and push + locomotiv::Session s(g.get(), {t2, push}); + ASSERT_EQ(s.output_size(), 2); + ASSERT_EQ(s.get_output_node(0), dynamic_cast(t2)); + ASSERT_EQ(s.get_output_node(1), dynamic_cast(push)); + + s.infer(); + + auto t2_data = s.get_output(0); + ASSERT_NE(t2_data, nullptr); + 
ASSERT_EQ(*(t2_data->shape()), Shape{2}); + + auto t2_buf = t2_data->as_f32_bufptr(); + ASSERT_EQ(t2_buf->at({0}), 0.3f); + ASSERT_EQ(t2_buf->at({1}), 0.4f); + + auto push_data = s.get_output(1); + ASSERT_NE(push_data, nullptr); + ASSERT_EQ(*(push_data->shape()), Shape{4}); + + auto push_buf = push_data->as_f32_bufptr(); + ASSERT_EQ(push_buf->at({0}), 0.1f); + ASSERT_EQ(push_buf->at({1}), 0.2f); + ASSERT_EQ(push_buf->at({2}), 0.3f); + ASSERT_EQ(push_buf->at({3}), 0.4f); + } +} + // Below here is internal test for locomotiv, i.e. not public usage of locomotiv #include "NodeDataImpl.h" #include "NodeDomain.h" -- 2.7.4