[enco.caffe] Correctly enumerate network outputs (#1327)
author박종현/동작제어Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Wed, 5 Sep 2018 00:11:06 +0000 (09:11 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Wed, 5 Sep 2018 00:11:06 +0000 (09:11 +0900)
With this commit, caffe frontend now correctly enumerates network
outputs.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
contrib/enco/frontend/caffe/src/Frontend.cpp
contrib/enco/test/caffe/009/BUILD [new file with mode: 0644]
contrib/enco/test/caffe/009/INFERENCE [new file with mode: 0644]
contrib/enco/test/caffe/009/test.prototxt [new file with mode: 0644]

index 249ac26..b424e9d 100644 (file)
@@ -89,15 +89,51 @@ enco::Bundle Frontend::load(void) const
   std::map<std::string, tensor::Shape> shape_ctx;
   std::map<std::string, coco::Bag *> bag_ctx;
 
-  std::set<std::string> top;
+  std::set<std::string> bags;
+  std::map<std::string, uint32_t> def_count;
+  std::map<std::string, uint32_t> use_count;
+
+  auto def = [&bags, &def_count, &use_count](const std::string &name) {
+    if (bags.find(name) == bags.end())
+    {
+      bags.insert(name);
+      def_count[name] = 0;
+      use_count[name] = 0;
+    }
+
+    def_count.at(name) += 1;
+  };
+
+  auto use = [&use_count](const std::string &name) { use_count.at(name) += 1; };
+
+  auto outputs = [&bags, &def_count, &use_count](void) {
+    std::set<std::string> res;
+
+    for (const auto &bag : bags)
+    {
+      if (def_count.at(bag) > use_count.at(bag))
+      {
+        res.insert(bag);
+      }
+    }
+
+    return res;
+  };
 
   for (const auto &layer : _prototxt->layer())
   {
     assert(layer.has_name());
     assert(layer.has_type());
 
-    top.clear();
-    top.insert(layer.top().begin(), layer.top().end());
+    for (uint32_t n = 0; n < layer.top().size(); ++n)
+    {
+      def(layer.top(n));
+    }
+
+    for (uint32_t n = 0; n < layer.bottom().size(); ++n)
+    {
+      use(layer.bottom(n));
+    }
 
     if (layer.type() == "Input")
     {
@@ -337,7 +373,7 @@ enco::Bundle Frontend::load(void) const
   }
 
   // Finalize: Create output for each top blob
-  for (const auto &name : top)
+  for (const auto &name : outputs())
   {
     const auto &shape = shape_ctx.at(name);
     auto bag = bag_ctx.at(name);
diff --git a/contrib/enco/test/caffe/009/BUILD b/contrib/enco/test/caffe/009/BUILD
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/contrib/enco/test/caffe/009/INFERENCE b/contrib/enco/test/caffe/009/INFERENCE
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/contrib/enco/test/caffe/009/test.prototxt b/contrib/enco/test/caffe/009/test.prototxt
new file mode 100644 (file)
index 0000000..eece8b7
--- /dev/null
@@ -0,0 +1,12 @@
+layer {
+  name: "data1"
+  type: "Input"
+  top: "data1"
+  input_param { shape: { dim: 1 dim: 3 dim: 15 dim: 15 } }
+}
+layer {
+  name: "data2"
+  type: "Input"
+  top: "data2"
+  input_param { shape: { dim: 1 dim: 3 dim: 15 dim: 15 } }
+}