From 363efc45e7b84ee3f2692a01d5d1eaebcdf5e380 Mon Sep 17 00:00:00 2001 From: Jihoon Lee Date: Thu, 18 Nov 2021 16:02:30 +0900 Subject: [PATCH] [Realizer] Implement previous input realizer This patch implement previous input realizer. **Self evaluation:** 1. Build test: [X]Passed [ ]Failed [ ]Skipped 2. Run test: [X]Passed [ ]Failed [ ]Skipped Signed-off-by: Jihoon Lee --- nntrainer/compiler/previous_input_realizer.cpp | 46 ++++++++++++++++++++++++-- nntrainer/compiler/previous_input_realizer.h | 7 ++-- test/unittest/compiler/unittest_realizer.cpp | 4 +-- 3 files changed, 51 insertions(+), 6 deletions(-) diff --git a/nntrainer/compiler/previous_input_realizer.cpp b/nntrainer/compiler/previous_input_realizer.cpp index 03ba1b0..14cc205 100644 --- a/nntrainer/compiler/previous_input_realizer.cpp +++ b/nntrainer/compiler/previous_input_realizer.cpp @@ -9,22 +9,64 @@ * @author Jihoon Lee * @bug No known bugs except for NYI items */ +#include #include #include +#include #include +#include +#include #include namespace nntrainer { PreviousInputRealizer::PreviousInputRealizer( - const std::vector &identified_input) {} + const std::vector &identified_inputs_) : + identified_inputs(identified_inputs_) {} PreviousInputRealizer::~PreviousInputRealizer() {} GraphRepresentation PreviousInputRealizer::realize(const GraphRepresentation &reference) { - return GraphRepresentation(); + GraphRepresentation processed(reference.begin(), reference.end()); + + /** + * @brief for node has input connection, below function determines if the node + * should be input node or add input_layers from previous layer + * + */ + auto is_actually_an_input_node = [this](const LayerNode &node) { + return node.hasInputShapeProperty() or + std::any_of(identified_inputs.begin(), identified_inputs.end(), + [&node](auto &name) { return node.getName() == name; }); + }; + + for (auto iter = processed.begin(); iter != processed.end(); ++iter) { + auto &node = *iter; + if (node->getNumInputConnections() != 0) { + continue; + } + + if (is_actually_an_input_node(*node)) { + continue; + } + + NNTR_THROW_IF(iter == processed.begin(), std::invalid_argument) + << "First node must be identified as an input if it is qualified to be " + "input, name: " + << node->getName(); + + auto &prev_node = *(iter - 1); + ml_logi( + "%s is identified as a non-input node and default input layer(%s) is " + "being set ", + node->getName().c_str(), prev_node->getName().c_str()); + + node->setInputLayers({prev_node->getName()}); + } + + return processed; } } // namespace nntrainer diff --git a/nntrainer/compiler/previous_input_realizer.h b/nntrainer/compiler/previous_input_realizer.h index 1fdd528..15abc11 100644 --- a/nntrainer/compiler/previous_input_realizer.h +++ b/nntrainer/compiler/previous_input_realizer.h @@ -30,10 +30,10 @@ public: /** * @brief Construct a new Previous Input Realizer object * - * @param identified_input node that is identified as an input, this must not + * @param identified_inputs node that is identified as an input, this must not * connect to other nodes automatically */ - PreviousInputRealizer(const std::vector &identified_input); + PreviousInputRealizer(const std::vector &identified_inputs); /** * @brief Destroy the Graph Realizer object @@ -46,6 +46,9 @@ public: * */ GraphRepresentation realize(const GraphRepresentation &reference) override; + +private: + std::vector identified_inputs; /**< inputs are identified */ }; } // namespace nntrainer diff --git a/test/unittest/compiler/unittest_realizer.cpp b/test/unittest/compiler/unittest_realizer.cpp index a1093c1..96bca86 100644 --- a/test/unittest/compiler/unittest_realizer.cpp +++ b/test/unittest/compiler/unittest_realizer.cpp @@ -231,7 +231,7 @@ TEST(InputRealizer, remap_p) { realizeAndEqual(r, before, after); } -TEST(PreviousInputRealizer, DISABLED_previous_p) { +TEST(PreviousInputRealizer, previous_p) { { /// realization without identifying custom input std::vector before = { {"fully_connected", {"name=fc1", "input_shape=1"}}, // model input @@ -267,7 +267,7 @@ TEST(PreviousInputRealizer, DISABLED_previous_p) { } } -TEST(PreviousInputRealizer, DISABLED_user_not_identifying_first_input_n) { +TEST(PreviousInputRealizer, user_not_identifying_first_input_n) { /// realization without identifying custom input std::vector before = { {"fully_connected", {"name=fc1"}}, // this should be model input but -- 2.7.4