[Fix] Flatten realizer to maintain original name
authorJihoon Lee <jhoon.it.lee@samsung.com>
Fri, 17 Dec 2021 06:04:27 +0000 (15:04 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Tue, 25 Jan 2022 05:40:29 +0000 (14:40 +0900)
This patch changes flatten realizer to maintain original node name.
Please see the included test patch to get the concept.

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Jihoon Lee <jhoon.it.lee@samsung.com>
nntrainer/compiler/flatten_realizer.cpp
test/unittest/compiler/unittest_realizer.cpp

index ce2ebb4..d84e658 100644 (file)
@@ -10,6 +10,8 @@
  * @bug No known bugs except for NYI items
  */
 #include <flatten_realizer.h>
+#include <remap_realizer.h>
+#include <unordered_map>
 
 #include <flatten_layer.h>
 #include <layer_node.h>
@@ -23,20 +25,50 @@ FlattenRealizer::realize(const GraphRepresentation &reference) {
   GraphRepresentation processed;
   processed.reserve(reference.size());
 
+  std::unordered_map<std::string /**< layer_name */,
+                     std::string /**< flatten_layer_name */>
+    remap_table;
+  std::vector<LayerNode *> flatten_nodes;
+  std::unordered_map<std::string /**< temp_layer_name */,
+                     std::string /**< layer_name */>
+    recovery_table;
+
   for (auto &node : reference) {
     /// @note: [node] type=flatten; flatten=true; is awkward but allowed.
     /// There is no reason to prohibit this.
     processed.push_back(node);
     if (node->getFlatten() && !node->getDistribute()) {
+      node->setProperty({"flatten=false"});
+
       auto layer_name = node->getName();
+
+      auto flatten_name = layer_name + "/flatten_realized";
+      auto temp_name = flatten_name + "/temp";
+
+      remap_table.insert({layer_name, flatten_name});
+      recovery_table.insert({temp_name, layer_name});
+
       auto flatten_node =
-        createLayerNode(FlattenLayer::type, {"name=" + layer_name});
-      node->setProperty({"flatten=false"});
-      node->setProperty({"name=" + layer_name + "/flatten_realized"});
-      flatten_node->setProperty({"input_layers=" + node->getName()});
+        createLayerNode(FlattenLayer::type, {"name=" + flatten_name});
+      flatten_node->setProperty({"input_layers=" + temp_name});
       processed.push_back(std::move(flatten_node));
     }
   }
+  RemapRealizer remap_others([&remap_table](std::string &name, unsigned &idx) {
+    if (auto iter = remap_table.find(name); iter != remap_table.end()) {
+      name = iter->second;
+    }
+  });
+
+  RemapRealizer recover_temp(
+    [&recovery_table](std::string &name, unsigned &idx) {
+      if (auto iter = recovery_table.find(name); iter != recovery_table.end()) {
+        name = iter->second;
+      }
+    });
+
+  processed = remap_others.realize(processed);
+  processed = recover_temp.realize(processed);
 
   return processed;
 }
index e0a8e18..967883d 100644 (file)
@@ -47,12 +47,15 @@ static void realizeAndEqual(GraphRealizer &realizer,
 TEST(FlattenRealizer, flatten_p) {
   FlattenRealizer fr;
 
-  LayerRepresentation input1 = {"fully_connected",
-                                {"name=layer1", "flatten=true"}};
-  LayerRepresentation expected1 = {"fully_connected",
-                                   {"name=layer1/flatten_realized"}};
+  LayerRepresentation input1 = {
+    "fully_connected",
+    {"name=layer1", "flatten=true"},
+  };
+  LayerRepresentation expected1 = {"fully_connected", {"name=layer1"}};
   LayerRepresentation expected2 = {
-    "flatten", {"name=layer1", "input_layers=layer1/flatten_realized"}};
+    "flatten",
+    {"name=layer1/flatten_realized", "input_layers=layer1"},
+  };
 
   realizeAndEqual(fr, {input1}, {expected1, expected2});
 }