[Realizer] Change slice realizer to get connection
authorJihoon Lee <jhoon.it.lee@samsung.com>
Thu, 16 Dec 2021 13:07:19 +0000 (22:07 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Tue, 28 Dec 2021 12:42:16 +0000 (21:42 +0900)
This patch change slice realizer to get connection

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Jihoon Lee <jhoon.it.lee@samsung.com>
nntrainer/compiler/slice_realizer.cpp
nntrainer/compiler/slice_realizer.h
nntrainer/models/neuralnet.cpp
test/unittest/compiler/unittest_realizer.cpp

index 70d17be6e5c5742677b6858aa9e246be5b0cb068..4d181f9da2451230bc7870aba8e294f08bd4f21c 100644 (file)
@@ -10,6 +10,8 @@
  * @bug No known bugs except for NYI items
  */
 
+#include <connection.h>
+#include <iterator>
 #include <layer_node.h>
 #include <slice_realizer.h>
 
 
 namespace nntrainer {
 
-SliceRealizer::SliceRealizer(const std::vector<std::string> &start_layers,
-                             const std::vector<std::string> &end_layers) :
-  start_layers(start_layers),
-  end_layers(end_layers.begin(), end_layers.end()) {}
+SliceRealizer::SliceRealizer(const std::vector<Connection> &start_layers,
+                             const std::vector<Connection> &end_layers) {
+  /// discard index information as it is not needed as it is not really needed
+  this->start_layers.reserve(start_layers.size());
+
+  std::transform(start_layers.begin(), start_layers.end(),
+                 std::back_inserter(this->start_layers),
+                 [](const Connection &c) { return c.getName(); });
+
+  std::transform(end_layers.begin(), end_layers.end(),
+                 std::inserter(this->end_layers, this->end_layers.begin()),
+                 [](const Connection &c) { return c.getName(); });
+}
 
 SliceRealizer::~SliceRealizer() {}
 
index b4de028c161cb0824b08d8795659451fd7129e73..420e5bd4456b942909cc5fa7b63aa36f96c908bf 100644 (file)
@@ -21,6 +21,8 @@
 
 namespace nntrainer {
 
+class Connection;
+
 /**
  * @brief Graph realizer class which slice graph representation
  *
@@ -30,11 +32,11 @@ public:
   /**
    * @brief Construct a new Slice Realizer object
    *
-   * @param start_layers start layers
-   * @param end_layers end layers
+   * @param start_connections start layers
+   * @param end_connections end layers
    */
-  SliceRealizer(const std::vector<std::string> &start_layers,
-                const std::vector<std::string> &end_layers);
+  SliceRealizer(const std::vector<Connection> &start_connections,
+                const std::vector<Connection> &end_connections);
 
   /**
    * @brief Destroy the Graph Realizer object
index e424c1ada23e4d9b3834b53c5d75761d54dbc38f..59fc88694f8a30d6535f48f835e6895c5a0108f2 100644 (file)
@@ -943,13 +943,16 @@ void NeuralNetwork::addWithReferenceLayers(
   auto start_layers_ = normalize(start_layers);
   auto end_layers_ = normalize(end_layers);
 
-  auto start_conns_ =
+  auto start_conns =
     std::vector<Connection>(start_layers.begin(), start_layers.end());
 
+  auto end_conns =
+    std::vector<Connection>(end_layers.begin(), end_layers.end());
+
   std::vector<std::unique_ptr<GraphRealizer>> realizers;
 
-  realizers.emplace_back(new PreviousInputRealizer(start_conns_));
-  realizers.emplace_back(new SliceRealizer(start_layers_, end_layers_));
+  realizers.emplace_back(new PreviousInputRealizer(start_conns));
+  realizers.emplace_back(new SliceRealizer(start_conns, end_conns));
 
   if (!input_layers_.empty()) {
     realizers.emplace_back(new InputRealizer(start_layers_, input_layers_));
index 06177e8a54be67820b65f19a5c253be5159372dc..5385e47cacbfbe6852f602bab76b66581105a8b2 100644 (file)
@@ -464,7 +464,18 @@ TEST(SliceRealizer, slice_01_p) {
     {"fully_connected", {"name=d2", "input_layers=c1"}},
   };
 
-  SliceRealizer r({"a1", "b1", "b2"}, {"a1", "d1", "d2"});
+  using C = Connection;
+  SliceRealizer r(
+    {
+      C("a1"),
+      C("b1"),
+      C("b2"),
+    },
+    {
+      C("a1"),
+      C("d1"),
+      C("d2"),
+    });
 
   realizeAndEqual(r, before, after);
 }
@@ -498,7 +509,7 @@ TEST(SliceRealizer, slice_02_p) {
     {"concat", {"name=c1", "input_layers=a1, b1"}},
   };
 
-  SliceRealizer r({"a1"}, {"c1"});
+  SliceRealizer r({Connection("a1")}, {Connection("c1")});
 
   realizeAndEqual(r, before, after);
 }