From 6b9d1aaa7c37729f3651094077e8eec0833b4ba1 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Staff=20Engineer/=EC=82=BC=EC=84=B1?= =?utf8?q?=EC=A0=84=EC=9E=90?= Date: Wed, 5 Sep 2018 09:11:06 +0900 Subject: [PATCH] [enco.caffe] Correctly enumerate network outputs (#1327) With this commit, caffe frontend now correctly enumerates network outputs. Signed-off-by: Jonghyun Park --- contrib/enco/frontend/caffe/src/Frontend.cpp | 44 +++++++++++++++++++++++++--- contrib/enco/test/caffe/009/BUILD | 0 contrib/enco/test/caffe/009/INFERENCE | 0 contrib/enco/test/caffe/009/test.prototxt | 12 ++++++++ 4 files changed, 52 insertions(+), 4 deletions(-) create mode 100644 contrib/enco/test/caffe/009/BUILD create mode 100644 contrib/enco/test/caffe/009/INFERENCE create mode 100644 contrib/enco/test/caffe/009/test.prototxt diff --git a/contrib/enco/frontend/caffe/src/Frontend.cpp b/contrib/enco/frontend/caffe/src/Frontend.cpp index 249ac26..b424e9d 100644 --- a/contrib/enco/frontend/caffe/src/Frontend.cpp +++ b/contrib/enco/frontend/caffe/src/Frontend.cpp @@ -89,15 +89,51 @@ enco::Bundle Frontend::load(void) const std::map shape_ctx; std::map bag_ctx; - std::set top; + std::set bags; + std::map def_count; + std::map use_count; + + auto def = [&bags, &def_count, &use_count](const std::string &name) { + if (bags.find(name) == bags.end()) + { + bags.insert(name); + def_count[name] = 0; + use_count[name] = 0; + } + + def_count.at(name) += 1; + }; + + auto use = [&use_count](const std::string &name) { use_count.at(name) += 1; }; + + auto outputs = [&bags, &def_count, &use_count](void) { + std::set res; + + for (const auto &bag : bags) + { + if (def_count.at(bag) > use_count.at(bag)) + { + res.insert(bag); + } + } + + return res; + }; for (const auto &layer : _prototxt->layer()) { assert(layer.has_name()); assert(layer.has_type()); - top.clear(); - top.insert(layer.top().begin(), layer.top().end()); + for (uint32_t n = 0; n < layer.top().size(); ++n) + { + def(layer.top(n)); + } + + for (uint32_t n = 0; n < layer.bottom().size(); ++n) + { + use(layer.bottom(n)); + } if (layer.type() == "Input") { @@ -337,7 +373,7 @@ enco::Bundle Frontend::load(void) const } // Finalize: Create output for each top blob - for (const auto &name : top) + for (const auto &name : outputs()) { const auto &shape = shape_ctx.at(name); auto bag = bag_ctx.at(name); diff --git a/contrib/enco/test/caffe/009/BUILD b/contrib/enco/test/caffe/009/BUILD new file mode 100644 index 0000000..e69de29 diff --git a/contrib/enco/test/caffe/009/INFERENCE b/contrib/enco/test/caffe/009/INFERENCE new file mode 100644 index 0000000..e69de29 diff --git a/contrib/enco/test/caffe/009/test.prototxt b/contrib/enco/test/caffe/009/test.prototxt new file mode 100644 index 0000000..eece8b7 --- /dev/null +++ b/contrib/enco/test/caffe/009/test.prototxt @@ -0,0 +1,12 @@ +layer { + name: "data1" + type: "Input" + top: "data1" + input_param { shape: { dim: 1 dim: 3 dim: 15 dim: 15 } } +} +layer { + name: "data2" + type: "Input" + top: "data2" + input_param { shape: { dim: 1 dim: 3 dim: 15 dim: 15 } } +} -- 2.7.4