[Recurrent] Support connection for recurrents
authorJihoon Lee <jhoon.it.lee@samsung.com>
Mon, 6 Dec 2021 08:31:29 +0000 (17:31 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Mon, 6 Dec 2021 13:19:51 +0000 (22:19 +0900)
This patch supports connections for recurrent inputs and outputs

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Jihoon Lee <jhoon.it.lee@samsung.com>
nntrainer/compiler/recurrent_realizer.cpp
nntrainer/compiler/recurrent_realizer.h
test/unittest/compiler/unittest_realizer.cpp

index d4a5e9e..f2c1cef 100644 (file)
  * @bug No known bugs except for NYI items
  */
 #include <algorithm>
-#include <recurrent_realizer.h>
+#include <stdexcept>
 
 #include <common_properties.h>
+#include <connection.h>
 #include <grucell.h>
 #include <input_layer.h>
 #include <layer_node.h>
 #include <lstmcell_core.h>
 #include <nntrainer_error.h>
 #include <node_exporter.h>
+#include <recurrent_realizer.h>
 #include <remap_realizer.h>
 #include <rnncell.h>
-#include <stdexcept>
 #include <util_func.h>
 #include <zoneout_lstmcell.h>
 
-#include <iostream>
-
 namespace nntrainer {
 
 namespace props {
@@ -51,7 +50,7 @@ UnrollFor::UnrollFor(const unsigned &value) { set(value); }
  * @brief Property for recurrent inputs
  *
  */
-class RecurrentInput final : public Name {
+class RecurrentInput final : public Property<Connection> {
 public:
   /**
    * @brief Construct a new Recurrent Input object
@@ -64,19 +63,19 @@ public:
    *
    * @param name name
    */
-  RecurrentInput(const std::string &name);
+  RecurrentInput(const Connection &name);
   static constexpr const char *key = "recurrent_input";
-  using prop_tag = str_prop_tag;
+  using prop_tag = connection_prop_tag;
 };
 
 RecurrentInput::RecurrentInput() {}
-RecurrentInput::RecurrentInput(const std::string &name) { set(name); };
+RecurrentInput::RecurrentInput(const Connection &con) { set(con); };
 
 /**
  * @brief Property for recurrent outputs
  *
  */
-class RecurrentOutput final : public Name {
+class RecurrentOutput final : public Property<Connection> {
 public:
   /**
    * @brief Construct a new Recurrent Output object
@@ -89,13 +88,13 @@ public:
    *
    * @param name name
    */
-  RecurrentOutput(const std::string &name);
+  RecurrentOutput(const Connection &name);
   static constexpr const char *key = "recurrent_output";
-  using prop_tag = str_prop_tag;
+  using prop_tag = connection_prop_tag;
 };
 
 RecurrentOutput::RecurrentOutput() {}
-RecurrentOutput::RecurrentOutput(const std::string &name) { set(name); };
+RecurrentOutput::RecurrentOutput(const Connection &con) { set(con); };
 } // namespace props
 
 RecurrentRealizer::RecurrentRealizer(
@@ -110,8 +109,6 @@ RecurrentRealizer::RecurrentRealizer(
     std::vector<props::AsSequence>(), props::UnrollFor(1))) {
   auto left = loadProperties(properties, *recurrent_props);
 
-  /// @todo check input, output number matches
-  /// @todo check if as sequence is subset of recurrent output
   /// @note AsSequence must be identifier based (not connection based) for now
   /// consider A(layer) outputs a0, a1 connection and a0 needs return seq
   /// Then it is impossible to locate a0 and a1 with the same name unless we
@@ -133,7 +130,7 @@ RecurrentRealizer::RecurrentRealizer(
                                                 seq) != end_layers.end();
                              }),
                 std::invalid_argument)
-    << "as_sequence property must be subset of recurrent_outputs";
+    << "as_sequence property must be subset of end_layers";
 
   std::unordered_set<std::string> check_seqs;
   for (auto &name : as_sequence) {
@@ -189,16 +186,8 @@ RecurrentRealizer::~RecurrentRealizer() {}
 GraphRepresentation
 RecurrentRealizer::realize(const GraphRepresentation &reference) {
 
-  auto step0_verify_and_prepare = [this, &reference]() {
-    for (auto &node : reference) {
-      if (recurrent_info.count(node->getName())) {
-        /// @todo this does not have to be the restriction as we
-        /// are supporting connections (#1760)
-        NNTR_THROW_IF(node->getNumInputConnections() != 1,
-                      std::invalid_argument)
-          << "recurrent input must have single connection: " << node->getName();
-      }
-    }
+  auto step0_verify_and_prepare = []() {
+    /// empty intended
   };
 
   /**
@@ -249,18 +238,15 @@ RecurrentRealizer::realize(const GraphRepresentation &reference) {
         });
 
       /// 2. override first output name to $name/$idx - 1
-      auto &name = node->getName();
-      auto suffix_len = std::string("/0").length();
-      if (auto iter =
-            recurrent_info.find(name.substr(0, name.length() - suffix_len));
-          iter != recurrent_info.end()) {
-        std::string output_name =
-          iter->second + "/" + std::to_string(time_idx - 1);
-        new_node->remapConnections(
-          [&output_name](std::string &name, unsigned &idx) {
-            /// @todo alter only when idx matches
-            name = output_name;
-          });
+      for (auto &[recurrent_input, recurrent_output] : recurrent_info) {
+        if (node->getName() != recurrent_input.getName() + "/0") {
+          continue;
+        }
+        new_node->setInputConnectionIndex(recurrent_input.getIndex(),
+                                          recurrent_output.getIndex());
+        new_node->setInputConnectionName(recurrent_input.getIndex(),
+                                         recurrent_output.getName() + "/" +
+                                           std::to_string(time_idx - 1));
       }
 
       /// 3. set shared_from
index 8099538..5315f2e 100644 (file)
@@ -22,6 +22,7 @@
 #include <unordered_set>
 #include <vector>
 
+#include <connection.h>
 namespace nntrainer {
 
 namespace props {
@@ -94,7 +95,7 @@ private:
   std::unordered_set<std::string>
     sequenced_return_layers; /**< sequenced return layers, subset of end_layers
                               */
-  std::unordered_map<std::string, std::string>
+  std::unordered_map<Connection, Connection>
     recurrent_info;                           /**< final output layers id */
   std::unique_ptr<PropTypes> recurrent_props; /**< recurrent properties */
 };
index 292e573..bba9d76 100644 (file)
@@ -281,15 +281,119 @@ TEST(RecurrentRealizer, recurrent_multi_inout_return_seq_p) {
   realizeAndEqual(r, before, expected);
 }
 
-TEST(DISABLED_RecurrentRealizer,
-     recurrent_multi_inout_using_connection_return_seq_p) {
-  /// NYI, just for specification
-  EXPECT_TRUE(false);
+TEST(RecurrentRealizer, recurrent_multi_inout_using_connection_return_seq_p) {
+
+  RecurrentRealizer r(
+    {
+      "unroll_for=3",
+      "as_sequence=fc_out",
+      "recurrent_input=lstm,add(2)",
+      "recurrent_output=fc_out,split(1)",
+    },
+    {"source", "source2", "source3"}, {"fc_out"});
+
+  /// @note for below graph,
+  /// 1. fc_out feeds back to lstm
+  /// 2. output_dummy feeds back to source2_dummy
+  /// ========================================================
+  /// lstm        -------- addition - split ---- fc_out (to_lstm)
+  /// source2_dummy   --/                  \----- (to addition 3)
+  std::vector<LayerRepresentation> before = {
+    {"lstm", {"name=lstm", "input_layers=source"}},
+    {"addition", {"name=add", "input_layers=lstm,source2,source3"}},
+    {"split", {"name=split", "input_layers=add"}},
+    {"fully_connected", {"name=fc_out", "input_layers=split(0)"}},
+  };
+
+  std::vector<LayerRepresentation> expected = {
+    /// timestep 0
+    {"lstm",
+     {"name=lstm/0", "input_layers=source", "max_timestep=3", "timestep=0"}},
+    {"addition", {"name=add/0", "input_layers=lstm/0,source2,source3"}},
+    {"split", {"name=split/0", "input_layers=add/0"}},
+    {"fully_connected", {"name=fc_out/0", "input_layers=split/0(0)"}},
+
+    /// timestep 1
+    {"lstm",
+     {"name=lstm/1", "input_layers=fc_out/0", "shared_from=lstm/0",
+      "max_timestep=3", "timestep=1"}},
+    {"addition",
+     {"name=add/1", "input_layers=lstm/1,source2,split/0(1)",
+      "shared_from=add/0"}},
+    {"split", {"name=split/1", "input_layers=add/1", "shared_from=split/0"}},
+    {"fully_connected",
+     {"name=fc_out/1", "input_layers=split/1(0)", "shared_from=fc_out/0"}},
+
+    /// timestep 2
+    {"lstm",
+     {"name=lstm/2", "input_layers=fc_out/1", "shared_from=lstm/0",
+      "max_timestep=3", "timestep=2"}},
+    {"addition",
+     {"name=add/2", "input_layers=lstm/2,source2,split/1(1)",
+      "shared_from=add/0"}},
+    {"split", {"name=split/2", "input_layers=add/2", "shared_from=split/0"}},
+    {"fully_connected",
+     {"name=fc_out/2", "input_layers=split/2(0)", "shared_from=fc_out/0"}},
+    {"concat", {"name=fc_out", "input_layers=fc_out/0,fc_out/1,fc_out/2"}},
+  };
+
+  realizeAndEqual(r, before, expected);
 }
 
-TEST(DISABLED_RecurrentRealizer, recurrent_multi_inout_using_connection_p) {
-  /// NYI, just for specification
-  EXPECT_TRUE(false);
+TEST(RecurrentRealizer, recurrent_multi_inout_using_connection_p) {
+  RecurrentRealizer r(
+    {
+      "unroll_for=3",
+      "recurrent_input=lstm,add(2)",
+      "recurrent_output=fc_out,split(1)",
+    },
+    {"source", "source2", "source3"}, {"fc_out"});
+
+  /// @note for below graph,
+  /// 1. fc_out feeds back to lstm
+  /// 2. output_dummy feeds back to source2_dummy
+  /// ========================================================
+  /// lstm        -------- addition - split ---- fc_out (to_lstm)
+  /// source2_dummy   --/                  \----- (to addition 3)
+  std::vector<LayerRepresentation> before = {
+    {"lstm", {"name=lstm", "input_layers=source"}},
+    {"addition", {"name=add", "input_layers=lstm,source2,source3"}},
+    {"split", {"name=split", "input_layers=add"}},
+    {"fully_connected", {"name=fc_out", "input_layers=split(0)"}},
+  };
+
+  std::vector<LayerRepresentation> expected = {
+    /// timestep 0
+    {"lstm",
+     {"name=lstm/0", "input_layers=source", "max_timestep=3", "timestep=0"}},
+    {"addition", {"name=add/0", "input_layers=lstm/0,source2,source3"}},
+    {"split", {"name=split/0", "input_layers=add/0"}},
+    {"fully_connected", {"name=fc_out/0", "input_layers=split/0(0)"}},
+
+    /// timestep 1
+    {"lstm",
+     {"name=lstm/1", "input_layers=fc_out/0", "shared_from=lstm/0",
+      "max_timestep=3", "timestep=1"}},
+    {"addition",
+     {"name=add/1", "input_layers=lstm/1,source2,split/0(1)",
+      "shared_from=add/0"}},
+    {"split", {"name=split/1", "input_layers=add/1", "shared_from=split/0"}},
+    {"fully_connected",
+     {"name=fc_out/1", "input_layers=split/1(0)", "shared_from=fc_out/0"}},
+
+    /// timestep 2
+    {"lstm",
+     {"name=lstm/2", "input_layers=fc_out/1", "shared_from=lstm/0",
+      "max_timestep=3", "timestep=2"}},
+    {"addition",
+     {"name=add/2", "input_layers=lstm/2,source2,split/1(1)",
+      "shared_from=add/0"}},
+    {"split", {"name=split/2", "input_layers=add/2", "shared_from=split/0"}},
+    {"fully_connected",
+     {"name=fc_out", "input_layers=split/2(0)", "shared_from=fc_out/0"}},
+  };
+
+  realizeAndEqual(r, before, expected);
 }
 
 TEST(RemapRealizer, remap_p) {