From: Jihoon Lee Date: Mon, 6 Dec 2021 05:27:09 +0000 (+0900) Subject: [Recurrent] Support multiple sequence X-Git-Tag: accepted/tizen/unified/20220323.062643~110 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=e76a186c8b4a7bea1d8d30ffd9f208cfd2ec78e8;p=platform%2Fcore%2Fml%2Fnntrainer.git [Recurrent] Support multiple sequence This patch updates recurrent realizer to suupporting multiple sequence with layer name not boolean with `as_sequence` property **Self evaluation:** 1. Build test: [X]Passed [ ]Failed [ ]Skipped 2. Run test: [X]Passed [ ]Failed [ ]Skipped Signed-off-by: Jihoon Lee --- diff --git a/nntrainer/compiler/recurrent_realizer.cpp b/nntrainer/compiler/recurrent_realizer.cpp index 7a68691..350cee1 100644 --- a/nntrainer/compiler/recurrent_realizer.cpp +++ b/nntrainer/compiler/recurrent_realizer.cpp @@ -26,6 +26,8 @@ #include #include +#include + namespace nntrainer { namespace props { @@ -102,7 +104,7 @@ RecurrentRealizer::RecurrentRealizer( end_layers(end_layers), recurrent_props( new PropTypes(props::RecurrentInput(), props::RecurrentOutput(), - props::ReturnSequences(false), props::UnrollFor(1))) { + std::vector(), props::UnrollFor(1))) { auto left = loadProperties(properties, *recurrent_props); auto throw_if_empty = [](auto &&prop) { @@ -111,10 +113,16 @@ RecurrentRealizer::RecurrentRealizer( << getPropKey(prop); }; - throw_if_empty(std::get<0>(*recurrent_props)); - throw_if_empty(std::get<1>(*recurrent_props)); - throw_if_empty(std::get<2>(*recurrent_props)); - throw_if_empty(std::get<3>(*recurrent_props)); + /// @todo check input, output number matches + /// @todo check if as sequence is subset of recurrent output + /// @note AsSequence must be identifier based (not connection based) for now + /// consider A(layer) outputs a0, a1 connection and a0 needs return seq + /// Then it is impossible to locate a0 and a1 with the same name unless we + /// have some kind of multi,multiout identity layer. Until this is supported, + /// AsSequenced stays as identifier based + throw_if_empty(std::get<0>(*recurrent_props)); // input + throw_if_empty(std::get<1>(*recurrent_props)); // ouput + throw_if_empty(std::get<3>(*recurrent_props)); // unroll for NNTR_THROW_IF(!left.empty(), std::invalid_argument) << "There is unparesed properties"; } @@ -264,15 +272,13 @@ RecurrentRealizer::realize(const GraphRepresentation &reference) { * @todo support connection using node->remapConnection */ auto naive_output = [](const GraphRepresentation &reference_, - unsigned unroll_for) { - /// last output's index is removed so that it can be directly an output - auto suffix = "/" + std::to_string(unroll_for - 1); - RemapRealizer r([suffix](std::string &name) { - if (endswith(name, suffix)) { - auto pos = name.find_last_of('/'); - if (pos != std::string::npos) { - name = name.substr(0, pos); - } + const std::string &con, unsigned unroll_for) { + auto target = con + "/" + std::to_string(unroll_for - 1); + RemapRealizer r([target, con](std::string &name) { + std::cout << name << " vs " << target << '\n'; + if (name == target) { + std::cout << "matched, setting to con: " << con << '\n'; + name = con; } }); @@ -285,7 +291,7 @@ RecurrentRealizer::realize(const GraphRepresentation &reference) { * */ auto concat_output = [this](const GraphRepresentation &reference_, - unsigned unroll_for) { + const std::string &con, unsigned unroll_for) { GraphRepresentation processed(reference_.begin(), reference_.end()); for (auto &end : end_layers) { @@ -309,11 +315,27 @@ RecurrentRealizer::realize(const GraphRepresentation &reference) { auto step3_connect_output = [this, naive_output, concat_output](const GraphRepresentation &reference_, unsigned unroll_for) { - bool return_sequence = - std::get(*recurrent_props).get(); - /// @todo return_sequence will become a sequenced_output_layers - return return_sequence ? concat_output(reference_, unroll_for) - : naive_output(reference_, unroll_for); + auto sequenced_layers = + std::get>(*recurrent_props); + + std::unordered_set check_seqs; + for (auto &name : sequenced_layers) { + check_seqs.emplace(name.get()); + }; + + /// @note below is inefficient way of processing nodes consider optimize + /// below as needed by calling remap realizer only once + std::vector output_conns = { + std::get(*recurrent_props)}; + + auto processed = reference_; + for (auto &name : output_conns) { + processed = check_seqs.count(name) + ? concat_output(processed, name, unroll_for) + : naive_output(processed, name, unroll_for); + } + + return processed; }; auto unroll_for = std::get(*recurrent_props).get(); diff --git a/nntrainer/compiler/recurrent_realizer.h b/nntrainer/compiler/recurrent_realizer.h index eae0fe3..4a6a678 100644 --- a/nntrainer/compiler/recurrent_realizer.h +++ b/nntrainer/compiler/recurrent_realizer.h @@ -25,7 +25,7 @@ namespace nntrainer { namespace props { class UnrollFor; -class ReturnSequences; +class AsSequence; class OutputLayer; class RecurrentInput; class RecurrentOutput; @@ -83,8 +83,9 @@ public: GraphRepresentation realize(const GraphRepresentation &reference) override; private: - using PropTypes = std::tuple; + using PropTypes = + std::tuple, props::UnrollFor>; std::unordered_set input_layers; /**< external input layers */ std::vector end_layers; /**< final output layers id */ diff --git a/nntrainer/layers/common_properties.h b/nntrainer/layers/common_properties.h index 64b2581..0e27544 100644 --- a/nntrainer/layers/common_properties.h +++ b/nntrainer/layers/common_properties.h @@ -543,6 +543,17 @@ public: }; /** + * @brief Identifiers to locate a connection which should be returned as whole + * used in recurrent realizer + * + */ +class AsSequence : public Name { +public: + static constexpr const char *key = "as_sequence"; + using prop_tag = str_prop_tag; +}; + +/** * @brief Number of class * @todo deprecate this */ diff --git a/test/unittest/compiler/unittest_realizer.cpp b/test/unittest/compiler/unittest_realizer.cpp index 32744ce..c731d1f 100644 --- a/test/unittest/compiler/unittest_realizer.cpp +++ b/test/unittest/compiler/unittest_realizer.cpp @@ -59,9 +59,9 @@ TEST(FlattenRealizer, flatten_p) { TEST(RecurrentRealizer, recurrent_no_return_sequence_p) { - RecurrentRealizer r({"unroll_for=3", "return_sequences=false", - "recurrent_input=fc_in", "recurrent_output=fc_out"}, - {"source"}, {"fc_out"}); + RecurrentRealizer r( + {"unroll_for=3", "recurrent_input=fc_in", "recurrent_output=fc_out"}, + {"source"}, {"fc_out"}); std::vector before = { {"fully_connected", {"name=fc_in", "input_layers=source"}}, @@ -75,17 +75,17 @@ TEST(RecurrentRealizer, recurrent_no_return_sequence_p) { {"fully_connected", {"name=fc_out/1", "input_layers=fc_in/1", "shared_from=fc_out/0"}}, {"fully_connected", - {"name=fc_in", "input_layers=fc_out/1", "shared_from=fc_in/0"}}, + {"name=fc_in/2", "input_layers=fc_out/1", "shared_from=fc_in/0"}}, {"fully_connected", - {"name=fc_out", "input_layers=fc_in", "shared_from=fc_out/0"}}, + {"name=fc_out", "input_layers=fc_in/2", "shared_from=fc_out/0"}}, }; realizeAndEqual(r, before, expected); } -TEST(DISABLED_RecurrentRealizer, recurrent_return_sequence_single_p) { +TEST(RecurrentRealizer, recurrent_return_sequence_single_p) { - RecurrentRealizer r({"unroll_for=3", "return_sequences=fc_out", + RecurrentRealizer r({"unroll_for=3", "as_sequence=fc_out", "recurrent_input=lstm", "recurrent_output=fc_out"}, {"source"}, {"fc_out"}); diff --git a/test/unittest/models/unittest_models_recurrent.cpp b/test/unittest/models/unittest_models_recurrent.cpp index b1aa7fd..7f40c2d 100644 --- a/test/unittest/models/unittest_models_recurrent.cpp +++ b/test/unittest/models/unittest_models_recurrent.cpp @@ -98,7 +98,6 @@ std::unique_ptr makeFC() { ml::train::ReferenceLayersType::RECURRENT, { "unroll_for=2", - "return_sequences=false", "recurrent_input=a1", "recurrent_output=a2", }); @@ -130,7 +129,6 @@ std::unique_ptr makeFCClipped() { ml::train::ReferenceLayersType::RECURRENT, { "unroll_for=2", - "return_sequences=false", "recurrent_input=a1", "recurrent_output=a2", }); @@ -160,7 +158,7 @@ static std::unique_ptr makeSingleLSTM() { ml::train::ReferenceLayersType::RECURRENT, { "unroll_for=2", - "return_sequences=true", + "as_sequence=a1", "recurrent_input=a1", "recurrent_output=a1", }); @@ -191,7 +189,7 @@ static std::unique_ptr makeStackedLSTM() { ml::train::ReferenceLayersType::RECURRENT, { "unroll_for=2", - "return_sequences=true", + "as_sequence=a2", "recurrent_input=a1", "recurrent_output=a2", }); @@ -221,7 +219,7 @@ static std::unique_ptr makeSingleLSTMCell() { ml::train::ReferenceLayersType::RECURRENT, { "unroll_for=2", - "return_sequences=true", + "as_sequence=a1", "recurrent_input=a1", "recurrent_output=a1", }); @@ -252,7 +250,7 @@ static std::unique_ptr makeStackedLSTMCell() { ml::train::ReferenceLayersType::RECURRENT, { "unroll_for=2", - "return_sequences=true", + "as_sequence=a2", "recurrent_input=a1", "recurrent_output=a2", }); @@ -351,7 +349,7 @@ static std::unique_ptr makeSingleRNNCell() { {"a1"}, ml::train::ReferenceLayersType::RECURRENT, { "unroll_for=2", - "return_sequences=true", + "as_sequence=a1", "recurrent_input=a1", "recurrent_output=a1", }); @@ -382,7 +380,7 @@ static std::unique_ptr makeStackedRNNCell() { {"a2"}, ml::train::ReferenceLayersType::RECURRENT, { "unroll_for=2", - "return_sequences=true", + "as_sequence=a2", "recurrent_input=a1", "recurrent_output=a2", }); @@ -412,7 +410,7 @@ static std::unique_ptr makeSingleGRUCell() { {"a1"}, ml::train::ReferenceLayersType::RECURRENT, { "unroll_for=2", - "return_sequences=true", + "as_sequence=a1", "recurrent_input=a1", "recurrent_output=a1", }); @@ -443,7 +441,7 @@ static std::unique_ptr makeStackedGRUCell() { {"a2"}, ml::train::ReferenceLayersType::RECURRENT, { "unroll_for=2", - "return_sequences=true", + "as_sequence=a2", "recurrent_input=a1", "recurrent_output=a2", }); diff --git a/test/unittest/unittest_nntrainer_models.cpp b/test/unittest/unittest_nntrainer_models.cpp index 12febe4..c4d07b9 100644 --- a/test/unittest/unittest_nntrainer_models.cpp +++ b/test/unittest/unittest_nntrainer_models.cpp @@ -990,7 +990,7 @@ TEST(nntrainerModels, loadFromLayersRecurrent_p) { {"fc2"}, ml::train::ReferenceLayersType::RECURRENT, { "unroll_for=3", - "return_sequences=true", + "as_sequence=fc2", "recurrent_input=fc1", "recurrent_output=fc2", });