From 0092556ae1fd433024f0ac228be0d6e8d58c77f9 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Staff=20Engineer/=EC=82=BC=EC=84=B1?= =?utf8?q?=EC=A0=84=EC=9E=90?= Date: Wed, 5 Sep 2018 13:09:57 +0900 Subject: [PATCH] [enco] Decide pooling method from spec (#1339) This commit revises PoolingSpec to correctly decide pooling method from a model specification given as prototxt. Signed-off-by: Jonghyun Park --- contrib/enco/frontend/caffe/src/PoolingSpec.cpp | 21 ++++++++++ contrib/enco/frontend/caffe/src/PoolingSpec.h | 9 ++++ .../enco/frontend/caffe/src/PoolingSpec.test.cpp | 48 ++++++++++++++++++++++ 3 files changed, 78 insertions(+) diff --git a/contrib/enco/frontend/caffe/src/PoolingSpec.cpp b/contrib/enco/frontend/caffe/src/PoolingSpec.cpp index 7c5915e..232ba4d 100644 --- a/contrib/enco/frontend/caffe/src/PoolingSpec.cpp +++ b/contrib/enco/frontend/caffe/src/PoolingSpec.cpp @@ -1,5 +1,6 @@ #include "PoolingSpec.h" +#include #include PoolingSpec::PoolingSpec(const ::caffe::PoolingParameter ¶m) : _param(param) @@ -7,6 +8,26 @@ PoolingSpec::PoolingSpec(const ::caffe::PoolingParameter ¶m) : _param(param) // DO NOTHING } +PoolingMethod PoolingSpec::method(void) const +{ + if (!_param.has_pool()) + { + // Default pooling method is MAX + // Reference: http://caffe.berkeleyvision.org/tutorial/layers/pooling.html + return PoolingMethod::Max; + } + + std::map<::caffe::PoolingParameter_PoolMethod, PoolingMethod> methods; + + // NOTE STOCHASTIC Pooling is not supported, yet + // TODO Support STOCHASTIC Pooling + methods[::caffe::PoolingParameter_PoolMethod_MAX] = PoolingMethod::Max; + methods[::caffe::PoolingParameter_PoolMethod_AVE] = PoolingMethod::Avg; + + assert(_param.has_pool()); + return methods.at(_param.pool()); +} + uint32_t PoolingSpec::window_height(void) const { // NOTE Global pooling is not supported, yet diff --git a/contrib/enco/frontend/caffe/src/PoolingSpec.h b/contrib/enco/frontend/caffe/src/PoolingSpec.h index d2ffa3a..13fd67e 100644 --- a/contrib/enco/frontend/caffe/src/PoolingSpec.h +++ b/contrib/enco/frontend/caffe/src/PoolingSpec.h @@ -5,6 +5,12 @@ #include +enum class PoolingMethod +{ + Max, + Avg +}; + class PoolingSpec { public: @@ -15,6 +21,9 @@ public: void ifm_shape(const nncc::core::ADT::tensor::Shape &shape) { _ifm_shape = shape; } public: + PoolingMethod method(void) const; + +public: uint32_t window_height(void) const; uint32_t window_width(void) const; diff --git a/contrib/enco/frontend/caffe/src/PoolingSpec.test.cpp b/contrib/enco/frontend/caffe/src/PoolingSpec.test.cpp index 0372896..c1e3ad6 100644 --- a/contrib/enco/frontend/caffe/src/PoolingSpec.test.cpp +++ b/contrib/enco/frontend/caffe/src/PoolingSpec.test.cpp @@ -23,6 +23,18 @@ using nncc::foundation::make_unique; namespace { +bool from_txt(std::istream &is, ::caffe::PoolingParameter &pooling) +{ + ::google::protobuf::io::IstreamInputStream iis{&is}; + return google::protobuf::TextFormat::Parse(&iis, &pooling); +} + +template bool from_txt(const std::string &txt, T &out) +{ + std::stringstream ss{txt}; + return from_txt(ss, out); +} + class SequentialBuilder { public: @@ -192,3 +204,39 @@ TEST_F(PoolingSpecTest, stride_for_all) ASSERT_EQ(expected, obtained); } } + +TEST_F(PoolingSpecTest, method_none) +{ + const char *prototxt = ""; + + ::caffe::PoolingParameter param; + from_txt(prototxt, param); + + PoolingSpec spec{param}; + + ASSERT_EQ(spec.method(), PoolingMethod::Max); +} + +TEST_F(PoolingSpecTest, method_max) +{ + const char *prototxt = "pool: MAX"; + + ::caffe::PoolingParameter param; + from_txt(prototxt, param); + + PoolingSpec spec{param}; + + ASSERT_EQ(spec.method(), PoolingMethod::Max); +} + +TEST_F(PoolingSpecTest, method_avg) +{ + const char *prototxt = "pool: AVE"; + + ::caffe::PoolingParameter param; + from_txt(prototxt, param); + + PoolingSpec spec{param}; + + ASSERT_EQ(spec.method(), PoolingMethod::Avg); +} -- 2.7.4